Ver Fonte

current_task and Task.request are now LocalStack's. Maybe fixes #521

Ask Solem há 13 anos atrás
pai
commit
e2b052c2d8

+ 1 - 1
celery/__compat__.py

@@ -89,7 +89,7 @@ class class_property(object):
 
     def __init__(self, fget=None, fset=None):
         assert fget and isinstance(fget, classmethod)
-        assert fset and isinstance(fset, classmethod)
+        assert isinstance(fset, classmethod) if fset else True
         self.__get = fget
         self.__set = fset
 

+ 4 - 5
celery/app/state.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 import threading
 
-from celery.local import Proxy
+from celery.local import Proxy, LocalStack
 
 default_app = None
 
@@ -12,11 +12,10 @@ class _TLS(threading.local):
     #: sets this, so it will always contain the last instantiated app,
     #: and is the default app returned by :func:`app_or_default`.
     current_app = None
-
-    #: The currently executing task.
-    current_task = None
 _tls = _TLS()
 
+_task_stack = LocalStack()
+
 
 def set_default_app(app):
     global default_app
@@ -28,7 +27,7 @@ def get_current_app():
 
 
 def get_current_task():
-    return getattr(_tls, "current_task", None)
+    return _task_stack.top
 
 
 current_app = Proxy(get_current_app)

+ 20 - 7
celery/app/task.py

@@ -14,7 +14,6 @@ from __future__ import absolute_import
 
 import logging
 import sys
-import threading
 
 from kombu import Exchange
 from kombu.utils import cached_property
@@ -24,6 +23,7 @@ from celery import states
 from celery.__compat__ import class_property
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
+from celery.local import LocalStack
 from celery.result import EagerResult
 from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
 from celery.utils.functional import mattrgetter, maybe_list
@@ -43,7 +43,7 @@ extract_exec_options = mattrgetter("queue", "routing_key",
                                    "compression", "expires")
 
 
-class Context(threading.local):
+class Context(object):
     # Default context
     logfile = None
     loglevel = None
@@ -61,8 +61,11 @@ class Context(threading.local):
     errbacks = None
     _children = None   # see property
 
-    def update(self, d, **kwargs):
-        self.__dict__.update(d, **kwargs)
+    def __init__(self, *args, **kwargs):
+        self.update(*args, **kwargs)
+
+    def update(self, *args, **kwargs):
+        self.__dict__.update(*args, **kwargs)
 
     def clear(self):
         self.__dict__.clear()
@@ -172,9 +175,6 @@ class BaseTask(object):
     #: Deprecated and scheduled for removal in v3.0.
     accept_magic_kwargs = False
 
-    #: Request context (set when task is applied).
-    request = Context()
-
     #: Destination queue.  The queue needs to exist
     #: in :setting:`CELERY_QUEUES`.  The `routing_key`, `exchange` and
     #: `exchange_type` attributes will be ignored if this is set.
@@ -324,6 +324,9 @@ class BaseTask(object):
         if not was_bound:
             self.annotate()
 
+        self.request_stack = LocalStack()
+        self.request_stack.push(Context())
+
         # PeriodicTask uses this to add itself to the PeriodicTask schedule.
         self.on_bound(app)
 
@@ -845,6 +848,12 @@ class BaseTask(object):
         """
         request.execute_using_pool(pool, loglevel, logfile)
 
+    def push_request(self, *args, **kwargs):
+        self.request_stack.push(Context(*args, **kwargs))
+
+    def pop_request(self):
+        self.request_stack.pop()
+
     def __repr__(self):
         """`repr(task)`"""
         return "<@task: %s>" % (self.name, )
@@ -853,6 +862,10 @@ class BaseTask(object):
     def logger(self):
         return self.get_logger()
 
+    @property
+    def request(self):
+        return self.request_stack.top
+
     @property
     def __name__(self):
         return self.__class__.__name__

+ 2 - 2
celery/contrib/batches.py

@@ -79,14 +79,14 @@ def consume_queue(queue):
 
 
 def apply_batches_task(task, args, loglevel, logfile):
-    task.request.update({"loglevel": loglevel, "logfile": logfile})
+    task.push_request(loglevel=loglevel, logfile=logfile)
     try:
         result = task(*args)
     except Exception, exc:
         result = None
         task.logger.error("Error: %r", exc, exc_info=True)
     finally:
-        task.request.clear()
+        task.pop_request()
     return result
 
 

+ 3 - 2
celery/events/__init__.py

@@ -91,6 +91,8 @@ class EventDispatcher(object):
         self.on_disabled = set()
 
         self.enabled = enabled
+        if not connection and channel:
+            self.connection = channel.connection.client
         if self.enabled:
             self.enable()
 
@@ -151,8 +153,7 @@ class EventDispatcher(object):
     def close(self):
         """Close the event dispatcher."""
         self.mutex.locked() and self.mutex.release()
-        if self.publisher is not None:
-            self.publisher = None
+        self.publisher = None
 
 
 class EventReceiver(object):

+ 212 - 0
celery/local.py

@@ -7,12 +7,25 @@
     needs to be loaded as soon as possible, and that
     shall not load any third party modules.
 
+    Parts of this module is Copyright by Werkzeug Team.
+
     :copyright: (c) 2009 - 2012 by Ask Solem.
     :license: BSD, see LICENSE for more details.
 
 """
 from __future__ import absolute_import
 
+# since each thread has its own greenlet we can just use those as identifiers
+# for the context.  If greenlets are not available we fall back to the
+# current thread ident.
+try:
+    from greenlet import getcurrent as get_ident
+except ImportError:  # pragma: no cover
+    try:
+        from thread import get_ident  # noqa
+    except ImportError:  # pragma: no cover
+        from dummy_thread import get_ident  # noqa
+
 
 def try_import(module, default=None):
     """Try to import and return module, or return
@@ -201,3 +214,202 @@ def maybe_evaluate(obj):
         return obj.__maybe_evaluate__()
     except AttributeError:
         return obj
+
+
+def release_local(local):
+    """Releases the contents of the local for the current context.
+    This makes it possible to use locals without a manager.
+
+    Example::
+
+        >>> loc = Local()
+        >>> loc.foo = 42
+        >>> release_local(loc)
+        >>> hasattr(loc, 'foo')
+        False
+
+    With this function one can release :class:`Local` objects as well
+    as :class:`StackLocal` objects.  However it is not possible to
+    release data held by proxies that way, one always has to retain
+    a reference to the underlying local object in order to be able
+    to release it.
+
+    .. versionadded:: 0.6.1
+    """
+    local.__release_local__()
+
+
+class Local(object):
+    __slots__ = ('__storage__', '__ident_func__')
+
+    def __init__(self):
+        object.__setattr__(self, '__storage__', {})
+        object.__setattr__(self, '__ident_func__', get_ident)
+
+    def __iter__(self):
+        return iter(self.__storage__.items())
+
+    def __call__(self, proxy):
+        """Create a proxy for a name."""
+        return Proxy(self, proxy)
+
+    def __release_local__(self):
+        self.__storage__.pop(self.__ident_func__(), None)
+
+    def __getattr__(self, name):
+        try:
+            return self.__storage__[self.__ident_func__()][name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name, value):
+        ident = self.__ident_func__()
+        storage = self.__storage__
+        try:
+            storage[ident][name] = value
+        except KeyError:
+            storage[ident] = {name: value}
+
+    def __delattr__(self, name):
+        try:
+            del self.__storage__[self.__ident_func__()][name]
+        except KeyError:
+            raise AttributeError(name)
+
+
+class LocalStack(object):
+    """This class works similar to a :class:`Local` but keeps a stack
+    of objects instead.  This is best explained with an example::
+
+        >>> ls = LocalStack()
+        >>> ls.push(42)
+        >>> ls.top
+        42
+        >>> ls.push(23)
+        >>> ls.top
+        23
+        >>> ls.pop()
+        23
+        >>> ls.top
+        42
+
+    They can be force released by using a :class:`LocalManager` or with
+    the :func:`release_local` function but the correct way is to pop the
+    item from the stack after using.  When the stack is empty it will
+    no longer be bound to the current context (and as such released).
+
+    By calling the stack without arguments it returns a proxy that resolves to
+    the topmost item on the stack.
+
+    """
+
+    def __init__(self):
+        self._local = Local()
+
+    def __release_local__(self):
+        self._local.__release_local__()
+
+    def _get__ident_func__(self):
+        return self._local.__ident_func__
+
+    def _set__ident_func__(self, value):
+        object.__setattr__(self._local, '__ident_func__', value)
+    __ident_func__ = property(_get__ident_func__, _set__ident_func__)
+    del _get__ident_func__, _set__ident_func__
+
+    def __call__(self):
+        def _lookup():
+            rv = self.top
+            if rv is None:
+                raise RuntimeError('object unbound')
+            return rv
+        return Proxy(_lookup)
+
+    def push(self, obj):
+        """Pushes a new item to the stack"""
+        rv = getattr(self._local, 'stack', None)
+        if rv is None:
+            self._local.stack = rv = []
+        rv.append(obj)
+        return rv
+
+    def pop(self):
+        """Removes the topmost item from the stack, will return the
+        old value or `None` if the stack was already empty.
+        """
+        stack = getattr(self._local, 'stack', None)
+        if stack is None:
+            return None
+        elif len(stack) == 1:
+            release_local(self._local)
+            return stack[-1]
+        else:
+            return stack.pop()
+
+    @property
+    def top(self):
+        """The topmost item on the stack.  If the stack is empty,
+        `None` is returned.
+        """
+        try:
+            return self._local.stack[-1]
+        except (AttributeError, IndexError):
+            return None
+
+
+class LocalManager(object):
+    """Local objects cannot manage themselves. For that you need a local
+    manager.  You can pass a local manager multiple locals or add them later
+    by appending them to `manager.locals`.  Everytime the manager cleans up
+    it, will clean up all the data left in the locals for this context.
+
+    The `ident_func` parameter can be added to override the default ident
+    function for the wrapped locals.
+
+    .. versionchanged:: 0.6.1
+       Instead of a manager the :func:`release_local` function can be used
+       as well.
+
+    .. versionchanged:: 0.7
+       `ident_func` was added.
+    """
+
+    def __init__(self, locals=None, ident_func=None):
+        if locals is None:
+            self.locals = []
+        elif isinstance(locals, Local):
+            self.locals = [locals]
+        else:
+            self.locals = list(locals)
+        if ident_func is not None:
+            self.ident_func = ident_func
+            for local in self.locals:
+                object.__setattr__(local, '__ident_func__', ident_func)
+        else:
+            self.ident_func = get_ident
+
+    def get_ident(self):
+        """Return the context identifier the local objects use internally for
+        this context.  You cannot override this method to change the behavior
+        but use it to link other context local objects (such as SQLAlchemy's
+        scoped sessions) to the Werkzeug locals.
+
+        .. versionchanged:: 0.7
+           You can pass a different ident function to the local manager that
+           will then be propagated to all the locals passed to the
+           constructor.
+        """
+        return self.ident_func()
+
+    def cleanup(self):
+        """Manually clean up the data in the locals for this context.  Call
+        this at the end of the request or use `make_middleware()`.
+        """
+        for local in self.locals:
+            release_local(local)
+
+    def __repr__(self):
+        return '<%s storages: %d>' % (
+            self.__class__.__name__,
+            len(self.locals)
+        )

+ 8 - 2
celery/task/base.py

@@ -14,14 +14,15 @@
 from __future__ import absolute_import
 
 from celery import current_app
-from celery.__compat__ import reclassmethod
+from celery.__compat__ import class_property, reclassmethod
 from celery.app.task import Context, TaskType, BaseTask  # noqa
 from celery.schedules import maybe_schedule
 
 #: list of methods that must be classmethods in the old API.
 _COMPAT_CLASSMETHODS = (
     "get_logger", "establish_connection", "get_publisher", "get_consumer",
-    "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask")
+    "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask",
+    "push_request", "pop_request")
 
 
 class Task(BaseTask):
@@ -36,6 +37,11 @@ class Task(BaseTask):
     for name in _COMPAT_CLASSMETHODS:
         locals()[name] = reclassmethod(getattr(BaseTask, name))
 
+    @classmethod
+    def _get_request(self):
+        return self.request_stack.top
+    request = class_property(_get_request)
+
 
 class PeriodicTask(Task):
     """A periodic task is a task that adds itself to the

+ 11 - 10
celery/task/trace.py

@@ -28,8 +28,8 @@ from kombu.utils import kwdict
 
 from celery import current_app
 from celery import states, signals
-from celery.app.state import _tls
-from celery.app.task import BaseTask
+from celery.app.state import _task_stack
+from celery.app.task import BaseTask, Context
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import RetryTaskError
 from celery.utils.serialization import get_pickleable_exception
@@ -146,15 +146,15 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
 
     task_on_success = task.on_success
     task_after_return = task.after_return
-    task_request = task.request
 
     store_result = backend.store_result
     backend_cleanup = backend.process_cleanup
 
     pid = os.getpid()
 
-    update_request = task_request.update
-    clear_request = task_request.clear
+    request_stack = task.request_stack
+    push_request = request_stack.push
+    pop_request = request_stack.pop
     on_chord_part_return = backend.on_chord_part_return
 
     from celery import canvas
@@ -164,9 +164,10 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         R = I = None
         kwargs = kwdict(kwargs)
         try:
-            _tls.current_task = task
-            update_request(request or {}, args=args,
-                           called_directly=False, kwargs=kwargs)
+            _task_stack.push(task)
+            task_request = Context(request or {}, args=args,
+                                   called_directly=False, kwargs=kwargs)
+            push_request(task_request)
             try:
                 # -*- PRE -*-
                 send_prerun(sender=task, task_id=uuid, task=task,
@@ -220,8 +221,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 send_postrun(sender=task, task_id=uuid, task=task,
                             args=args, kwargs=kwargs, retval=retval)
             finally:
-                _tls.current_task = None
-                clear_request()
+                _task_stack.pop()
+                pop_request()
                 if not eager:
                     try:
                         backend_cleanup()

+ 3 - 3
celery/tests/app/test_builtins.py

@@ -4,7 +4,7 @@ from mock import Mock
 
 from celery import current_app as app, group, task, chord
 from celery.app import builtins
-from celery.app.state import _tls
+from celery.app.state import _task_stack
 from celery.tests.utils import Case
 
 
@@ -61,13 +61,13 @@ class test_group(Case):
         x.apply_async()
 
     def test_apply_async_with_parent(self):
-        _tls.current_task = add
+        _task_stack.push(add)
         try:
             x = group([add.s(4, 4), add.s(8, 8)])
             x.apply_async()
             self.assertTrue(add.request.children)
         finally:
-            _tls.current_task = None
+            _task_stack.pop()
 
 
 class test_chain(Case):

+ 4 - 4
celery/tests/app/test_log.py

@@ -230,12 +230,12 @@ class test_task_logger(test_default_logger):
             pass
         test_task.logger.handlers = []
         self.task = test_task
-        from celery.app.state import _tls
-        _tls.current_task = test_task
+        from celery.app.state import _task_stack
+        _task_stack.push(test_task)
 
     def tearDown(self):
-        from celery.app.state import _tls
-        _tls.current_task = None
+        from celery.app.state import _task_stack
+        _task_stack.pop()
 
     def setup_logger(self, *args, **kwargs):
         return log.setup_task_loggers(*args, **kwargs)

+ 7 - 4
celery/tests/events/test_events.py

@@ -132,10 +132,11 @@ class test_EventReceiver(AppCase):
         def my_handler(event):
             got_event[0] = True
 
-        r = events.EventReceiver(object(),
+        connection = Mock()
+        connection.transport_cls = "memory"
+        r = events.EventReceiver(connection,
                                  handlers={"world-war": my_handler},
-                                 node_id="celery.tests",
-                                 )
+                                 node_id="celery.tests")
         r._receive(message, object())
         self.assertTrue(got_event[0])
 
@@ -148,7 +149,9 @@ class test_EventReceiver(AppCase):
         def my_handler(event):
             got_event[0] = True
 
-        r = events.EventReceiver(object(), node_id="celery.tests")
+        connection = Mock()
+        connection.transport_cls = "memory"
+        r = events.EventReceiver(connection, node_id="celery.tests")
         events.EventReceiver.handlers["*"] = my_handler
         try:
             r._receive(message, object())

+ 0 - 94
celery/tests/tasks/test_context.py

@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-"
 from __future__ import absolute_import
 
-import threading
-
 from celery.task.base import Context
 from celery.tests.utils import Case
 
@@ -22,21 +20,6 @@ def get_context_as_dict(ctx, getter=getattr):
 default_context = get_context_as_dict(Context())
 
 
-# Manipulate the a context in a separate thread
-class ContextManipulator(threading.Thread):
-    def __init__(self, ctx, *args):
-        super(ContextManipulator, self).__init__()
-        self.daemon = True
-        self.ctx = ctx
-        self.args = args
-        self.result = None
-
-    def run(self):
-        for func, args in self.args:
-            func(self.ctx, *args)
-        self.result = get_context_as_dict(self.ctx)
-
-
 class test_Context(Case):
 
     def test_default_context(self):
@@ -45,14 +28,6 @@ class test_Context(Case):
         defaults = dict(default_context, children=[])
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
-    def test_default_context_threaded(self):
-        ctx = Context()
-        worker = ContextManipulator(ctx)
-        worker.start()
-        worker.join()
-        self.assertDictEqual(worker.result, default_context)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_updated_context(self):
         expected = dict(default_context)
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
@@ -62,26 +37,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
 
-    def test_updated_contex_threadedt(self):
-        expected_a = dict(default_context)
-        changes_a = dict(id="a", args=["some", 1], wibble="wobble")
-        expected_a.update(changes_a)
-        expected_b = dict(default_context)
-        changes_b = dict(id="b", args=["other", 2], weasel="woozle")
-        expected_b.update(changes_b)
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx, (Context.update, [changes_a]))
-        worker_b = ContextManipulator(ctx, (Context.update, [changes_b]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, expected_a)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_modified_context(self):
         expected = dict(default_context)
         ctx = Context()
@@ -92,34 +47,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
 
-    def test_modified_contex_threadedt(self):
-        expected_a = dict(default_context)
-        expected_a["id"] = "a"
-        expected_a["args"] = ["some", 1]
-        expected_a["wibble"] = "wobble"
-        expected_b = dict(default_context)
-        expected_b["id"] = "b"
-        expected_b["args"] = ["other", 2]
-        expected_b["weasel"] = "woozle"
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx,
-                                      (setattr, ["id", "a"]),
-                                      (setattr, ["args", ["some", 1]]),
-                                      (setattr, ["wibble", "wobble"]))
-        worker_b = ContextManipulator(ctx,
-                                      (setattr, ["id", "b"]),
-                                      (setattr, ["args", ["other", 2]]),
-                                      (setattr, ["weasel", "woozle"]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, expected_a)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_cleared_context(self):
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
         ctx = Context()
@@ -129,27 +56,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), defaults)
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
-    def test_cleared_context_threaded(self):
-        changes_a = dict(id="a", args=["some", 1], wibble="wobble")
-        expected_b = dict(default_context)
-        changes_b = dict(id="b", args=["other", 2], weasel="woozle")
-        expected_b.update(changes_b)
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx,
-                                      (Context.update, [changes_a]),
-                                      (Context.clear, []))
-        worker_b = ContextManipulator(ctx,
-                                      (Context.update, [changes_b]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, default_context)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_context_get(self):
         expected = dict(default_context)
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")

+ 4 - 4
celery/tests/tasks/test_sets.py

@@ -148,13 +148,13 @@ class test_TaskSet(Case):
         @current_app.task
         def xyz():
             pass
-        from celery.app.state import _tls
-        _tls.current_task = xyz
+        from celery.app.state import _task_stack
+        _task_stack.push(xyz)
         try:
             ts.apply_async(publisher=Publisher())
         finally:
-            _tls.current_task = None
-            xyz.request.clear()
+            _task_stack.pop()
+            xyz.pop_request()
 
     def test_apply(self):
 

+ 10 - 20
celery/tests/worker/test_request.py

@@ -67,8 +67,9 @@ class test_mro_lookup(Case):
 
 
 def jail(task_id, name, args, kwargs):
+    request = {"id": task_id}
     return eager_trace_task(current_app.tasks[name],
-                            task_id, args, kwargs, eager=False)[0]
+            task_id, args, kwargs, request=request, eager=False)[0]
 
 
 def on_ack(*args, **kwargs):
@@ -196,26 +197,16 @@ class test_trace_task(Case):
             mytask.ignore_result = False
 
     def test_execute_jail_failure(self):
-        u = uuid()
-        mytask_raising.request.update({"id": u})
-        try:
-            ret = jail(u, mytask_raising.name,
-                    [4], {})
-            self.assertIsInstance(ret, ExceptionInfo)
-            self.assertTupleEqual(ret.exception.args, (4, ))
-        finally:
-            mytask_raising.request.clear()
+        ret = jail(uuid(), mytask_raising.name,
+                   [4], {})
+        self.assertIsInstance(ret, ExceptionInfo)
+        self.assertTupleEqual(ret.exception.args, (4, ))
 
     def test_execute_ignore_result(self):
         task_id = uuid()
-        MyTaskIgnoreResult.request.update({"id": task_id})
-        try:
-            ret = jail(task_id, MyTaskIgnoreResult.name,
-                       [4], {})
-            self.assertEqual(ret, 256)
-            self.assertFalse(AsyncResult(task_id).ready())
-        finally:
-            MyTaskIgnoreResult.request.clear()
+        ret = jail(task_id, MyTaskIgnoreResult.name, [4], {})
+        self.assertEqual(ret, 256)
+        self.assertFalse(AsyncResult(task_id).ready())
 
 
 class MockEventDispatcher(object):
@@ -557,10 +548,9 @@ class test_TaskRequest(Case):
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
 
-        @task_dec
+        @task_dec(request=None)
         def raising():
             raise KeyError("baz")
-        raising.request = None
 
         with self.assertWarnsRegex(RuntimeWarning,
                 r'Exception raised outside'):