浏览代码

current_task and Task.request are now LocalStack's. Maybe fixes #521

Ask Solem 13 年之前
父节点
当前提交
e2b052c2d8

+ 1 - 1
celery/__compat__.py

@@ -89,7 +89,7 @@ class class_property(object):
 
 
     def __init__(self, fget=None, fset=None):
     def __init__(self, fget=None, fset=None):
         assert fget and isinstance(fget, classmethod)
         assert fget and isinstance(fget, classmethod)
-        assert fset and isinstance(fset, classmethod)
+        assert isinstance(fset, classmethod) if fset else True
         self.__get = fget
         self.__get = fget
         self.__set = fset
         self.__set = fset
 
 

+ 4 - 5
celery/app/state.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 
 import threading
 import threading
 
 
-from celery.local import Proxy
+from celery.local import Proxy, LocalStack
 
 
 default_app = None
 default_app = None
 
 
@@ -12,11 +12,10 @@ class _TLS(threading.local):
     #: sets this, so it will always contain the last instantiated app,
     #: sets this, so it will always contain the last instantiated app,
     #: and is the default app returned by :func:`app_or_default`.
     #: and is the default app returned by :func:`app_or_default`.
     current_app = None
     current_app = None
-
-    #: The currently executing task.
-    current_task = None
 _tls = _TLS()
 _tls = _TLS()
 
 
+_task_stack = LocalStack()
+
 
 
 def set_default_app(app):
 def set_default_app(app):
     global default_app
     global default_app
@@ -28,7 +27,7 @@ def get_current_app():
 
 
 
 
 def get_current_task():
 def get_current_task():
-    return getattr(_tls, "current_task", None)
+    return _task_stack.top
 
 
 
 
 current_app = Proxy(get_current_app)
 current_app = Proxy(get_current_app)

+ 20 - 7
celery/app/task.py

@@ -14,7 +14,6 @@ from __future__ import absolute_import
 
 
 import logging
 import logging
 import sys
 import sys
-import threading
 
 
 from kombu import Exchange
 from kombu import Exchange
 from kombu.utils import cached_property
 from kombu.utils import cached_property
@@ -24,6 +23,7 @@ from celery import states
 from celery.__compat__ import class_property
 from celery.__compat__ import class_property
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
+from celery.local import LocalStack
 from celery.result import EagerResult
 from celery.result import EagerResult
 from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
 from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.functional import mattrgetter, maybe_list
@@ -43,7 +43,7 @@ extract_exec_options = mattrgetter("queue", "routing_key",
                                    "compression", "expires")
                                    "compression", "expires")
 
 
 
 
-class Context(threading.local):
+class Context(object):
     # Default context
     # Default context
     logfile = None
     logfile = None
     loglevel = None
     loglevel = None
@@ -61,8 +61,11 @@ class Context(threading.local):
     errbacks = None
     errbacks = None
     _children = None   # see property
     _children = None   # see property
 
 
-    def update(self, d, **kwargs):
-        self.__dict__.update(d, **kwargs)
+    def __init__(self, *args, **kwargs):
+        self.update(*args, **kwargs)
+
+    def update(self, *args, **kwargs):
+        self.__dict__.update(*args, **kwargs)
 
 
     def clear(self):
     def clear(self):
         self.__dict__.clear()
         self.__dict__.clear()
@@ -172,9 +175,6 @@ class BaseTask(object):
     #: Deprecated and scheduled for removal in v3.0.
     #: Deprecated and scheduled for removal in v3.0.
     accept_magic_kwargs = False
     accept_magic_kwargs = False
 
 
-    #: Request context (set when task is applied).
-    request = Context()
-
     #: Destination queue.  The queue needs to exist
     #: Destination queue.  The queue needs to exist
     #: in :setting:`CELERY_QUEUES`.  The `routing_key`, `exchange` and
     #: in :setting:`CELERY_QUEUES`.  The `routing_key`, `exchange` and
     #: `exchange_type` attributes will be ignored if this is set.
     #: `exchange_type` attributes will be ignored if this is set.
@@ -324,6 +324,9 @@ class BaseTask(object):
         if not was_bound:
         if not was_bound:
             self.annotate()
             self.annotate()
 
 
+        self.request_stack = LocalStack()
+        self.request_stack.push(Context())
+
         # PeriodicTask uses this to add itself to the PeriodicTask schedule.
         # PeriodicTask uses this to add itself to the PeriodicTask schedule.
         self.on_bound(app)
         self.on_bound(app)
 
 
@@ -845,6 +848,12 @@ class BaseTask(object):
         """
         """
         request.execute_using_pool(pool, loglevel, logfile)
         request.execute_using_pool(pool, loglevel, logfile)
 
 
+    def push_request(self, *args, **kwargs):
+        self.request_stack.push(Context(*args, **kwargs))
+
+    def pop_request(self):
+        self.request_stack.pop()
+
     def __repr__(self):
     def __repr__(self):
         """`repr(task)`"""
         """`repr(task)`"""
         return "<@task: %s>" % (self.name, )
         return "<@task: %s>" % (self.name, )
@@ -853,6 +862,10 @@ class BaseTask(object):
     def logger(self):
     def logger(self):
         return self.get_logger()
         return self.get_logger()
 
 
+    @property
+    def request(self):
+        return self.request_stack.top
+
     @property
     @property
     def __name__(self):
     def __name__(self):
         return self.__class__.__name__
         return self.__class__.__name__

+ 2 - 2
celery/contrib/batches.py

@@ -79,14 +79,14 @@ def consume_queue(queue):
 
 
 
 
 def apply_batches_task(task, args, loglevel, logfile):
 def apply_batches_task(task, args, loglevel, logfile):
-    task.request.update({"loglevel": loglevel, "logfile": logfile})
+    task.push_request(loglevel=loglevel, logfile=logfile)
     try:
     try:
         result = task(*args)
         result = task(*args)
     except Exception, exc:
     except Exception, exc:
         result = None
         result = None
         task.logger.error("Error: %r", exc, exc_info=True)
         task.logger.error("Error: %r", exc, exc_info=True)
     finally:
     finally:
-        task.request.clear()
+        task.pop_request()
     return result
     return result
 
 
 
 

+ 3 - 2
celery/events/__init__.py

@@ -91,6 +91,8 @@ class EventDispatcher(object):
         self.on_disabled = set()
         self.on_disabled = set()
 
 
         self.enabled = enabled
         self.enabled = enabled
+        if not connection and channel:
+            self.connection = channel.connection.client
         if self.enabled:
         if self.enabled:
             self.enable()
             self.enable()
 
 
@@ -151,8 +153,7 @@ class EventDispatcher(object):
     def close(self):
     def close(self):
         """Close the event dispatcher."""
         """Close the event dispatcher."""
         self.mutex.locked() and self.mutex.release()
         self.mutex.locked() and self.mutex.release()
-        if self.publisher is not None:
-            self.publisher = None
+        self.publisher = None
 
 
 
 
 class EventReceiver(object):
 class EventReceiver(object):

+ 212 - 0
celery/local.py

@@ -7,12 +7,25 @@
     needs to be loaded as soon as possible, and that
     needs to be loaded as soon as possible, and that
     shall not load any third party modules.
     shall not load any third party modules.
 
 
+    Parts of this module is Copyright by Werkzeug Team.
+
     :copyright: (c) 2009 - 2012 by Ask Solem.
     :copyright: (c) 2009 - 2012 by Ask Solem.
     :license: BSD, see LICENSE for more details.
     :license: BSD, see LICENSE for more details.
 
 
 """
 """
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+# since each thread has its own greenlet we can just use those as identifiers
+# for the context.  If greenlets are not available we fall back to the
+# current thread ident.
+try:
+    from greenlet import getcurrent as get_ident
+except ImportError:  # pragma: no cover
+    try:
+        from thread import get_ident  # noqa
+    except ImportError:  # pragma: no cover
+        from dummy_thread import get_ident  # noqa
+
 
 
 def try_import(module, default=None):
 def try_import(module, default=None):
     """Try to import and return module, or return
     """Try to import and return module, or return
@@ -201,3 +214,202 @@ def maybe_evaluate(obj):
         return obj.__maybe_evaluate__()
         return obj.__maybe_evaluate__()
     except AttributeError:
     except AttributeError:
         return obj
         return obj
+
+
+def release_local(local):
+    """Releases the contents of the local for the current context.
+    This makes it possible to use locals without a manager.
+
+    Example::
+
+        >>> loc = Local()
+        >>> loc.foo = 42
+        >>> release_local(loc)
+        >>> hasattr(loc, 'foo')
+        False
+
+    With this function one can release :class:`Local` objects as well
+    as :class:`StackLocal` objects.  However it is not possible to
+    release data held by proxies that way, one always has to retain
+    a reference to the underlying local object in order to be able
+    to release it.
+
+    .. versionadded:: 0.6.1
+    """
+    local.__release_local__()
+
+
+class Local(object):
+    __slots__ = ('__storage__', '__ident_func__')
+
+    def __init__(self):
+        object.__setattr__(self, '__storage__', {})
+        object.__setattr__(self, '__ident_func__', get_ident)
+
+    def __iter__(self):
+        return iter(self.__storage__.items())
+
+    def __call__(self, proxy):
+        """Create a proxy for a name."""
+        return Proxy(self, proxy)
+
+    def __release_local__(self):
+        self.__storage__.pop(self.__ident_func__(), None)
+
+    def __getattr__(self, name):
+        try:
+            return self.__storage__[self.__ident_func__()][name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name, value):
+        ident = self.__ident_func__()
+        storage = self.__storage__
+        try:
+            storage[ident][name] = value
+        except KeyError:
+            storage[ident] = {name: value}
+
+    def __delattr__(self, name):
+        try:
+            del self.__storage__[self.__ident_func__()][name]
+        except KeyError:
+            raise AttributeError(name)
+
+
+class LocalStack(object):
+    """This class works similar to a :class:`Local` but keeps a stack
+    of objects instead.  This is best explained with an example::
+
+        >>> ls = LocalStack()
+        >>> ls.push(42)
+        >>> ls.top
+        42
+        >>> ls.push(23)
+        >>> ls.top
+        23
+        >>> ls.pop()
+        23
+        >>> ls.top
+        42
+
+    They can be force released by using a :class:`LocalManager` or with
+    the :func:`release_local` function but the correct way is to pop the
+    item from the stack after using.  When the stack is empty it will
+    no longer be bound to the current context (and as such released).
+
+    By calling the stack without arguments it returns a proxy that resolves to
+    the topmost item on the stack.
+
+    """
+
+    def __init__(self):
+        self._local = Local()
+
+    def __release_local__(self):
+        self._local.__release_local__()
+
+    def _get__ident_func__(self):
+        return self._local.__ident_func__
+
+    def _set__ident_func__(self, value):
+        object.__setattr__(self._local, '__ident_func__', value)
+    __ident_func__ = property(_get__ident_func__, _set__ident_func__)
+    del _get__ident_func__, _set__ident_func__
+
+    def __call__(self):
+        def _lookup():
+            rv = self.top
+            if rv is None:
+                raise RuntimeError('object unbound')
+            return rv
+        return Proxy(_lookup)
+
+    def push(self, obj):
+        """Pushes a new item to the stack"""
+        rv = getattr(self._local, 'stack', None)
+        if rv is None:
+            self._local.stack = rv = []
+        rv.append(obj)
+        return rv
+
+    def pop(self):
+        """Removes the topmost item from the stack, will return the
+        old value or `None` if the stack was already empty.
+        """
+        stack = getattr(self._local, 'stack', None)
+        if stack is None:
+            return None
+        elif len(stack) == 1:
+            release_local(self._local)
+            return stack[-1]
+        else:
+            return stack.pop()
+
+    @property
+    def top(self):
+        """The topmost item on the stack.  If the stack is empty,
+        `None` is returned.
+        """
+        try:
+            return self._local.stack[-1]
+        except (AttributeError, IndexError):
+            return None
+
+
+class LocalManager(object):
+    """Local objects cannot manage themselves. For that you need a local
+    manager.  You can pass a local manager multiple locals or add them later
+    by appending them to `manager.locals`.  Everytime the manager cleans up
+    it, will clean up all the data left in the locals for this context.
+
+    The `ident_func` parameter can be added to override the default ident
+    function for the wrapped locals.
+
+    .. versionchanged:: 0.6.1
+       Instead of a manager the :func:`release_local` function can be used
+       as well.
+
+    .. versionchanged:: 0.7
+       `ident_func` was added.
+    """
+
+    def __init__(self, locals=None, ident_func=None):
+        if locals is None:
+            self.locals = []
+        elif isinstance(locals, Local):
+            self.locals = [locals]
+        else:
+            self.locals = list(locals)
+        if ident_func is not None:
+            self.ident_func = ident_func
+            for local in self.locals:
+                object.__setattr__(local, '__ident_func__', ident_func)
+        else:
+            self.ident_func = get_ident
+
+    def get_ident(self):
+        """Return the context identifier the local objects use internally for
+        this context.  You cannot override this method to change the behavior
+        but use it to link other context local objects (such as SQLAlchemy's
+        scoped sessions) to the Werkzeug locals.
+
+        .. versionchanged:: 0.7
+           You can pass a different ident function to the local manager that
+           will then be propagated to all the locals passed to the
+           constructor.
+        """
+        return self.ident_func()
+
+    def cleanup(self):
+        """Manually clean up the data in the locals for this context.  Call
+        this at the end of the request or use `make_middleware()`.
+        """
+        for local in self.locals:
+            release_local(local)
+
+    def __repr__(self):
+        return '<%s storages: %d>' % (
+            self.__class__.__name__,
+            len(self.locals)
+        )

+ 8 - 2
celery/task/base.py

@@ -14,14 +14,15 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from celery import current_app
 from celery import current_app
-from celery.__compat__ import reclassmethod
+from celery.__compat__ import class_property, reclassmethod
 from celery.app.task import Context, TaskType, BaseTask  # noqa
 from celery.app.task import Context, TaskType, BaseTask  # noqa
 from celery.schedules import maybe_schedule
 from celery.schedules import maybe_schedule
 
 
 #: list of methods that must be classmethods in the old API.
 #: list of methods that must be classmethods in the old API.
 _COMPAT_CLASSMETHODS = (
 _COMPAT_CLASSMETHODS = (
     "get_logger", "establish_connection", "get_publisher", "get_consumer",
     "get_logger", "establish_connection", "get_publisher", "get_consumer",
-    "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask")
+    "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask",
+    "push_request", "pop_request")
 
 
 
 
 class Task(BaseTask):
 class Task(BaseTask):
@@ -36,6 +37,11 @@ class Task(BaseTask):
     for name in _COMPAT_CLASSMETHODS:
     for name in _COMPAT_CLASSMETHODS:
         locals()[name] = reclassmethod(getattr(BaseTask, name))
         locals()[name] = reclassmethod(getattr(BaseTask, name))
 
 
+    @classmethod
+    def _get_request(self):
+        return self.request_stack.top
+    request = class_property(_get_request)
+
 
 
 class PeriodicTask(Task):
 class PeriodicTask(Task):
     """A periodic task is a task that adds itself to the
     """A periodic task is a task that adds itself to the

+ 11 - 10
celery/task/trace.py

@@ -28,8 +28,8 @@ from kombu.utils import kwdict
 
 
 from celery import current_app
 from celery import current_app
 from celery import states, signals
 from celery import states, signals
-from celery.app.state import _tls
-from celery.app.task import BaseTask
+from celery.app.state import _task_stack
+from celery.app.task import BaseTask, Context
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import RetryTaskError
 from celery.exceptions import RetryTaskError
 from celery.utils.serialization import get_pickleable_exception
 from celery.utils.serialization import get_pickleable_exception
@@ -146,15 +146,15 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
 
 
     task_on_success = task.on_success
     task_on_success = task.on_success
     task_after_return = task.after_return
     task_after_return = task.after_return
-    task_request = task.request
 
 
     store_result = backend.store_result
     store_result = backend.store_result
     backend_cleanup = backend.process_cleanup
     backend_cleanup = backend.process_cleanup
 
 
     pid = os.getpid()
     pid = os.getpid()
 
 
-    update_request = task_request.update
-    clear_request = task_request.clear
+    request_stack = task.request_stack
+    push_request = request_stack.push
+    pop_request = request_stack.pop
     on_chord_part_return = backend.on_chord_part_return
     on_chord_part_return = backend.on_chord_part_return
 
 
     from celery import canvas
     from celery import canvas
@@ -164,9 +164,10 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         R = I = None
         R = I = None
         kwargs = kwdict(kwargs)
         kwargs = kwdict(kwargs)
         try:
         try:
-            _tls.current_task = task
-            update_request(request or {}, args=args,
-                           called_directly=False, kwargs=kwargs)
+            _task_stack.push(task)
+            task_request = Context(request or {}, args=args,
+                                   called_directly=False, kwargs=kwargs)
+            push_request(task_request)
             try:
             try:
                 # -*- PRE -*-
                 # -*- PRE -*-
                 send_prerun(sender=task, task_id=uuid, task=task,
                 send_prerun(sender=task, task_id=uuid, task=task,
@@ -220,8 +221,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 send_postrun(sender=task, task_id=uuid, task=task,
                 send_postrun(sender=task, task_id=uuid, task=task,
                             args=args, kwargs=kwargs, retval=retval)
                             args=args, kwargs=kwargs, retval=retval)
             finally:
             finally:
-                _tls.current_task = None
-                clear_request()
+                _task_stack.pop()
+                pop_request()
                 if not eager:
                 if not eager:
                     try:
                     try:
                         backend_cleanup()
                         backend_cleanup()

+ 3 - 3
celery/tests/app/test_builtins.py

@@ -4,7 +4,7 @@ from mock import Mock
 
 
 from celery import current_app as app, group, task, chord
 from celery import current_app as app, group, task, chord
 from celery.app import builtins
 from celery.app import builtins
-from celery.app.state import _tls
+from celery.app.state import _task_stack
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
 
 
@@ -61,13 +61,13 @@ class test_group(Case):
         x.apply_async()
         x.apply_async()
 
 
     def test_apply_async_with_parent(self):
     def test_apply_async_with_parent(self):
-        _tls.current_task = add
+        _task_stack.push(add)
         try:
         try:
             x = group([add.s(4, 4), add.s(8, 8)])
             x = group([add.s(4, 4), add.s(8, 8)])
             x.apply_async()
             x.apply_async()
             self.assertTrue(add.request.children)
             self.assertTrue(add.request.children)
         finally:
         finally:
-            _tls.current_task = None
+            _task_stack.pop()
 
 
 
 
 class test_chain(Case):
 class test_chain(Case):

+ 4 - 4
celery/tests/app/test_log.py

@@ -230,12 +230,12 @@ class test_task_logger(test_default_logger):
             pass
             pass
         test_task.logger.handlers = []
         test_task.logger.handlers = []
         self.task = test_task
         self.task = test_task
-        from celery.app.state import _tls
-        _tls.current_task = test_task
+        from celery.app.state import _task_stack
+        _task_stack.push(test_task)
 
 
     def tearDown(self):
     def tearDown(self):
-        from celery.app.state import _tls
-        _tls.current_task = None
+        from celery.app.state import _task_stack
+        _task_stack.pop()
 
 
     def setup_logger(self, *args, **kwargs):
     def setup_logger(self, *args, **kwargs):
         return log.setup_task_loggers(*args, **kwargs)
         return log.setup_task_loggers(*args, **kwargs)

+ 7 - 4
celery/tests/events/test_events.py

@@ -132,10 +132,11 @@ class test_EventReceiver(AppCase):
         def my_handler(event):
         def my_handler(event):
             got_event[0] = True
             got_event[0] = True
 
 
-        r = events.EventReceiver(object(),
+        connection = Mock()
+        connection.transport_cls = "memory"
+        r = events.EventReceiver(connection,
                                  handlers={"world-war": my_handler},
                                  handlers={"world-war": my_handler},
-                                 node_id="celery.tests",
-                                 )
+                                 node_id="celery.tests")
         r._receive(message, object())
         r._receive(message, object())
         self.assertTrue(got_event[0])
         self.assertTrue(got_event[0])
 
 
@@ -148,7 +149,9 @@ class test_EventReceiver(AppCase):
         def my_handler(event):
         def my_handler(event):
             got_event[0] = True
             got_event[0] = True
 
 
-        r = events.EventReceiver(object(), node_id="celery.tests")
+        connection = Mock()
+        connection.transport_cls = "memory"
+        r = events.EventReceiver(connection, node_id="celery.tests")
         events.EventReceiver.handlers["*"] = my_handler
         events.EventReceiver.handlers["*"] = my_handler
         try:
         try:
             r._receive(message, object())
             r._receive(message, object())

+ 0 - 94
celery/tests/tasks/test_context.py

@@ -1,8 +1,6 @@
 # -*- coding: utf-8 -*-"
 # -*- coding: utf-8 -*-"
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-import threading
-
 from celery.task.base import Context
 from celery.task.base import Context
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
@@ -22,21 +20,6 @@ def get_context_as_dict(ctx, getter=getattr):
 default_context = get_context_as_dict(Context())
 default_context = get_context_as_dict(Context())
 
 
 
 
-# Manipulate the a context in a separate thread
-class ContextManipulator(threading.Thread):
-    def __init__(self, ctx, *args):
-        super(ContextManipulator, self).__init__()
-        self.daemon = True
-        self.ctx = ctx
-        self.args = args
-        self.result = None
-
-    def run(self):
-        for func, args in self.args:
-            func(self.ctx, *args)
-        self.result = get_context_as_dict(self.ctx)
-
-
 class test_Context(Case):
 class test_Context(Case):
 
 
     def test_default_context(self):
     def test_default_context(self):
@@ -45,14 +28,6 @@ class test_Context(Case):
         defaults = dict(default_context, children=[])
         defaults = dict(default_context, children=[])
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
 
-    def test_default_context_threaded(self):
-        ctx = Context()
-        worker = ContextManipulator(ctx)
-        worker.start()
-        worker.join()
-        self.assertDictEqual(worker.result, default_context)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_updated_context(self):
     def test_updated_context(self):
         expected = dict(default_context)
         expected = dict(default_context)
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
@@ -62,26 +37,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
 
 
-    def test_updated_contex_threadedt(self):
-        expected_a = dict(default_context)
-        changes_a = dict(id="a", args=["some", 1], wibble="wobble")
-        expected_a.update(changes_a)
-        expected_b = dict(default_context)
-        changes_b = dict(id="b", args=["other", 2], weasel="woozle")
-        expected_b.update(changes_b)
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx, (Context.update, [changes_a]))
-        worker_b = ContextManipulator(ctx, (Context.update, [changes_b]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, expected_a)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_modified_context(self):
     def test_modified_context(self):
         expected = dict(default_context)
         expected = dict(default_context)
         ctx = Context()
         ctx = Context()
@@ -92,34 +47,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(ctx), expected)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
         self.assertDictEqual(get_context_as_dict(Context()), default_context)
 
 
-    def test_modified_contex_threadedt(self):
-        expected_a = dict(default_context)
-        expected_a["id"] = "a"
-        expected_a["args"] = ["some", 1]
-        expected_a["wibble"] = "wobble"
-        expected_b = dict(default_context)
-        expected_b["id"] = "b"
-        expected_b["args"] = ["other", 2]
-        expected_b["weasel"] = "woozle"
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx,
-                                      (setattr, ["id", "a"]),
-                                      (setattr, ["args", ["some", 1]]),
-                                      (setattr, ["wibble", "wobble"]))
-        worker_b = ContextManipulator(ctx,
-                                      (setattr, ["id", "b"]),
-                                      (setattr, ["args", ["other", 2]]),
-                                      (setattr, ["weasel", "woozle"]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, expected_a)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_cleared_context(self):
     def test_cleared_context(self):
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
         ctx = Context()
         ctx = Context()
@@ -129,27 +56,6 @@ class test_Context(Case):
         self.assertDictEqual(get_context_as_dict(ctx), defaults)
         self.assertDictEqual(get_context_as_dict(ctx), defaults)
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
         self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
 
-    def test_cleared_context_threaded(self):
-        changes_a = dict(id="a", args=["some", 1], wibble="wobble")
-        expected_b = dict(default_context)
-        changes_b = dict(id="b", args=["other", 2], weasel="woozle")
-        expected_b.update(changes_b)
-        ctx = Context()
-
-        worker_a = ContextManipulator(ctx,
-                                      (Context.update, [changes_a]),
-                                      (Context.clear, []))
-        worker_b = ContextManipulator(ctx,
-                                      (Context.update, [changes_b]))
-        worker_a.start()
-        worker_b.start()
-        worker_a.join()
-        worker_b.join()
-
-        self.assertDictEqual(worker_a.result, default_context)
-        self.assertDictEqual(worker_b.result, expected_b)
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
     def test_context_get(self):
     def test_context_get(self):
         expected = dict(default_context)
         expected = dict(default_context)
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")
         changes = dict(id="unique id", args=["some", 1], wibble="wobble")

+ 4 - 4
celery/tests/tasks/test_sets.py

@@ -148,13 +148,13 @@ class test_TaskSet(Case):
         @current_app.task
         @current_app.task
         def xyz():
         def xyz():
             pass
             pass
-        from celery.app.state import _tls
-        _tls.current_task = xyz
+        from celery.app.state import _task_stack
+        _task_stack.push(xyz)
         try:
         try:
             ts.apply_async(publisher=Publisher())
             ts.apply_async(publisher=Publisher())
         finally:
         finally:
-            _tls.current_task = None
-            xyz.request.clear()
+            _task_stack.pop()
+            xyz.pop_request()
 
 
     def test_apply(self):
     def test_apply(self):
 
 

+ 10 - 20
celery/tests/worker/test_request.py

@@ -67,8 +67,9 @@ class test_mro_lookup(Case):
 
 
 
 
 def jail(task_id, name, args, kwargs):
 def jail(task_id, name, args, kwargs):
+    request = {"id": task_id}
     return eager_trace_task(current_app.tasks[name],
     return eager_trace_task(current_app.tasks[name],
-                            task_id, args, kwargs, eager=False)[0]
+            task_id, args, kwargs, request=request, eager=False)[0]
 
 
 
 
 def on_ack(*args, **kwargs):
 def on_ack(*args, **kwargs):
@@ -196,26 +197,16 @@ class test_trace_task(Case):
             mytask.ignore_result = False
             mytask.ignore_result = False
 
 
     def test_execute_jail_failure(self):
     def test_execute_jail_failure(self):
-        u = uuid()
-        mytask_raising.request.update({"id": u})
-        try:
-            ret = jail(u, mytask_raising.name,
-                    [4], {})
-            self.assertIsInstance(ret, ExceptionInfo)
-            self.assertTupleEqual(ret.exception.args, (4, ))
-        finally:
-            mytask_raising.request.clear()
+        ret = jail(uuid(), mytask_raising.name,
+                   [4], {})
+        self.assertIsInstance(ret, ExceptionInfo)
+        self.assertTupleEqual(ret.exception.args, (4, ))
 
 
     def test_execute_ignore_result(self):
     def test_execute_ignore_result(self):
         task_id = uuid()
         task_id = uuid()
-        MyTaskIgnoreResult.request.update({"id": task_id})
-        try:
-            ret = jail(task_id, MyTaskIgnoreResult.name,
-                       [4], {})
-            self.assertEqual(ret, 256)
-            self.assertFalse(AsyncResult(task_id).ready())
-        finally:
-            MyTaskIgnoreResult.request.clear()
+        ret = jail(task_id, MyTaskIgnoreResult.name, [4], {})
+        self.assertEqual(ret, 256)
+        self.assertFalse(AsyncResult(task_id).ready())
 
 
 
 
 class MockEventDispatcher(object):
 class MockEventDispatcher(object):
@@ -557,10 +548,9 @@ class test_TaskRequest(Case):
         def _error_exec(self, *args, **kwargs):
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
             raise KeyError("baz")
 
 
-        @task_dec
+        @task_dec(request=None)
         def raising():
         def raising():
             raise KeyError("baz")
             raise KeyError("baz")
-        raising.request = None
 
 
         with self.assertWarnsRegex(RuntimeWarning,
         with self.assertWarnsRegex(RuntimeWarning,
                 r'Exception raised outside'):
                 r'Exception raised outside'):