浏览代码

Task registry is no longer global, but needs testing before merge, and hopefully cleaning up

Conflicts:

	celery/app/base.py
	celery/bin/celeryctl.py
Ask Solem 13 年之前
父节点
当前提交
22ceb946b5

+ 1 - 2
celery/app/__init__.py

@@ -15,7 +15,6 @@ from __future__ import absolute_import
 import os
 import os
 import threading
 import threading
 
 
-from .. import registry
 from ..utils import cached_property, instantiate
 from ..utils import cached_property, instantiate
 
 
 from . import annotations
 from . import annotations
@@ -182,7 +181,7 @@ class App(base.BaseApp):
                         "run": staticmethod(fun),
                         "run": staticmethod(fun),
                         "__doc__": fun.__doc__,
                         "__doc__": fun.__doc__,
                         "__module__": fun.__module__}, **options))()
                         "__module__": fun.__module__}, **options))()
-                return registry.tasks[T.name]  # global instance.
+                return self._tasks[T.name]  # global instance.
 
 
             return _create_task_cls
             return _create_task_cls
 
 

+ 7 - 3
celery/app/base.py

@@ -89,12 +89,13 @@ class BaseApp(object):
     loader_cls = "celery.loaders.app:AppLoader"
     loader_cls = "celery.loaders.app:AppLoader"
     log_cls = "celery.log:Logging"
     log_cls = "celery.log:Logging"
     control_cls = "celery.task.control:Control"
     control_cls = "celery.task.control:Control"
+    registry_cls = "celery.app.registry:TaskRegistry"
 
 
     _pool = None
     _pool = None
 
 
     def __init__(self, main=None, loader=None, backend=None,
     def __init__(self, main=None, loader=None, backend=None,
             amqp=None, events=None, log=None, control=None,
             amqp=None, events=None, log=None, control=None,
-            set_as_current=True, accept_magic_kwargs=False, **kwargs):
+            set_as_current=True, accept_magic_kwargs=False, tasks=None, **kwargs):
         self.main = main
         self.main = main
         self.amqp_cls = amqp or self.amqp_cls
         self.amqp_cls = amqp or self.amqp_cls
         self.backend_cls = backend or self.backend_cls
         self.backend_cls = backend or self.backend_cls
@@ -105,6 +106,8 @@ class BaseApp(object):
         self.set_as_current = set_as_current
         self.set_as_current = set_as_current
         self.accept_magic_kwargs = accept_magic_kwargs
         self.accept_magic_kwargs = accept_magic_kwargs
         self.clock = LamportClock()
         self.clock = LamportClock()
+        self.registry_cls = self.registry_cls if tasks is None else tasks
+        self._tasks = instantiate(self.registry_cls)
 
 
         self.on_init()
         self.on_init()
 
 
@@ -388,5 +391,6 @@ class BaseApp(object):
 
 
     @cached_property
     @cached_property
     def tasks(self):
     def tasks(self):
-        from ..registry import tasks
-        return tasks
+        from .task.builtins import load_builtins
+        load_builtins(self)
+        return self._tasks

+ 57 - 0
celery/app/registry.py

@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.app.registry
+    ~~~~~~~~~~~~~~~~~~~
+
+    Registry of available tasks.
+
+    :copyright: (c) 2009 - 2012 by Ask Solem.
+    :license: BSD, see LICENSE for more details.
+
+"""
+from __future__ import absolute_import
+
+import inspect
+
+from .. import current_app
+from ..exceptions import NotRegistered
+
+
+class TaskRegistry(dict):
+    NotRegistered = NotRegistered
+
+    def register(self, task):
+        """Register a task in the task registry.
+
+        The task will be automatically instantiated if not already an
+        instance.
+
+        """
+        self[task.name] = inspect.isclass(task) and task() or task
+
+    def unregister(self, name):
+        """Unregister task by name.
+
+        :param name: name of the task to unregister, or a
+            :class:`celery.task.base.Task` with a valid `name` attribute.
+
+        :raises celery.exceptions.NotRegistered: if the task has not
+            been registered.
+
+        """
+        try:
+            # Might be a task class
+            name = name.name
+        except AttributeError:
+            pass
+        self.pop(name)
+
+    def pop(self, key, *args):
+        try:
+            return dict.pop(self, key, *args)
+        except KeyError:
+            raise self.NotRegistered(key)
+
+
+def _unpickle_task(name):
+    return current_app.tasks[name]

+ 7 - 6
celery/app/task/__init__.py

@@ -15,16 +15,18 @@ from __future__ import absolute_import
 import sys
 import sys
 import threading
 import threading
 
 
+from ... import current_app
 from ... import states
 from ... import states
 from ...datastructures import ExceptionInfo
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import eager_trace_task
 from ...execute.trace import eager_trace_task
-from ...registry import tasks, _unpickle_task
 from ...result import EagerResult
 from ...result import EagerResult
 from ...utils import (fun_takes_kwargs, instantiate,
 from ...utils import (fun_takes_kwargs, instantiate,
                       mattrgetter, uuid, maybe_reraise)
                       mattrgetter, uuid, maybe_reraise)
 from ...utils.mail import ErrorMail
 from ...utils.mail import ErrorMail
 
 
+from ..registry import _unpickle_task
+
 extract_exec_options = mattrgetter("queue", "routing_key",
 extract_exec_options = mattrgetter("queue", "routing_key",
                                    "exchange", "immediate",
                                    "exchange", "immediate",
                                    "mandatory", "priority",
                                    "mandatory", "priority",
@@ -76,6 +78,7 @@ class TaskType(type):
 
 
     def __new__(cls, name, bases, attrs):
     def __new__(cls, name, bases, attrs):
         new = super(TaskType, cls).__new__
         new = super(TaskType, cls).__new__
+        app = attrs.get("app") or current_app
         task_module = attrs.get("__module__") or "__main__"
         task_module = attrs.get("__module__") or "__main__"
 
 
         if "__call__" in attrs:
         if "__call__" in attrs:
@@ -118,6 +121,7 @@ class TaskType(type):
         # we may or may not be the first time the task tries to register
         # we may or may not be the first time the task tries to register
         # with the framework.  There should only be one class for each task
         # with the framework.  There should only be one class for each task
         # name, so we always return the registered version.
         # name, so we always return the registered version.
+        tasks = app._tasks
         task_name = attrs["name"]
         task_name = attrs["name"]
         if task_name not in tasks:
         if task_name not in tasks:
             task_cls = new(cls, name, bases, attrs)
             task_cls = new(cls, name, bases, attrs)
@@ -127,7 +131,7 @@ class TaskType(type):
         task = tasks[task_name].__class__
         task = tasks[task_name].__class__
 
 
         # decorate with annotations from config.
         # decorate with annotations from config.
-        task.app.annotate_task(task)
+        app.annotate_task(task)
         return task
         return task
 
 
     def __repr__(cls):
     def __repr__(cls):
@@ -270,9 +274,6 @@ class BaseTask(object):
     #: Default task expiry time.
     #: Default task expiry time.
     expires = None
     expires = None
 
 
-    #: The type of task *(no longer used)*.
-    type = "regular"
-
     #: Execution strategy used, or the qualified name of one.
     #: Execution strategy used, or the qualified name of one.
     Strategy = "celery.worker.strategy:default"
     Strategy = "celery.worker.strategy:default"
 
 
@@ -590,7 +591,7 @@ class BaseTask(object):
                                 options.pop("throw", None))
                                 options.pop("throw", None))
 
 
         # Make sure we get the task instance, not class.
         # Make sure we get the task instance, not class.
-        task = tasks[self.name]
+        task = self.app._tasks[self.name]
 
 
         request = {"id": task_id,
         request = {"id": task_id,
                    "retries": retries,
                    "retries": retries,

+ 44 - 0
celery/app/task/builtins.py

@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+_builtins = []
+
+def builtin_task(constructor):
+    _builtins.append(constructor)
+    return constructor
+
+
+@builtin_task
+def add_backend_cleanup_task(app):
+
+    @app.task(name="celery.backend_cleanup")
+    def backend_cleanup():
+        app.backend.cleanup()
+
+    return backend_cleanup
+
+
+@builtin_task
+def add_unlock_chord_task(app):
+
+    @app.task(name="celery.chord_unlock", max_retries=None)
+    def unlock_chord(setid, callback, interval=1, propagate=False,
+            max_retries=None, result=None):
+        from ...result import AsyncResult, TaskSetResult
+        from ...task.sets import subtask
+
+        result = TaskSetResult(setid, map(AsyncResult, result))
+        if result.ready():
+            j = result.join_native if result.supports_native_join else result.join
+            subtask(callback).delay(j(propagate=propagate))
+        else:
+            unlock_chord.retry(countdown=interval, max_retries=max_retries)
+
+    return unlock_chord
+
+
+def load_builtins(app):
+    for constructor in _builtins:
+        constructor(app)
+
+

+ 1 - 2
celery/apps/worker.py

@@ -174,8 +174,7 @@ class Worker(configurated):
         self.loader.init_worker()
         self.loader.init_worker()
 
 
     def tasklist(self, include_builtins=True):
     def tasklist(self, include_builtins=True):
-        from ..registry import tasks
-        tasklist = tasks.keys()
+        tasklist = self.app.tasks.keys()
         if not include_builtins:
         if not include_builtins:
             tasklist = filter(lambda s: not s.startswith("celery."),
             tasklist = filter(lambda s: not s.startswith("celery."),
                               tasklist)
                               tasklist)

+ 2 - 3
celery/backends/base.py

@@ -207,10 +207,9 @@ class BaseBackend(object):
         pass
         pass
 
 
     def on_chord_apply(self, setid, body, result=None, **kwargs):
     def on_chord_apply(self, setid, body, result=None, **kwargs):
-        from ..registry import tasks
         kwargs["result"] = [r.task_id for r in result]
         kwargs["result"] = [r.task_id for r in result]
-        tasks["celery.chord_unlock"].apply_async((setid, body, ), kwargs,
-                                                 countdown=1)
+        self.app.tasks["celery.chord_unlock"].apply_async((setid, body, ),
+                                                          kwargs, countdown=1)
 
 
     def __reduce__(self, args=(), kwargs={}):
     def __reduce__(self, args=(), kwargs={}):
         return (unpickle_backend, (self.__class__, args, kwargs))
         return (unpickle_backend, (self.__class__, args, kwargs))

+ 1 - 2
celery/beat.py

@@ -27,7 +27,6 @@ from kombu.utils import reprcall
 
 
 from . import __version__
 from . import __version__
 from . import platforms
 from . import platforms
-from . import registry
 from . import signals
 from . import signals
 from . import current_app
 from . import current_app
 from .app import app_or_default
 from .app import app_or_default
@@ -215,7 +214,7 @@ class Scheduler(object):
         # so we have that done if an exception is raised (doesn't schedule
         # so we have that done if an exception is raised (doesn't schedule
         # forever.)
         # forever.)
         entry = self.reserve(entry)
         entry = self.reserve(entry)
-        task = registry.tasks.get(entry.task)
+        task = self.app.tasks.get(entry.task)
 
 
         try:
         try:
             if task:
             if task:

+ 2 - 4
celery/bin/celeryctl.py

@@ -229,12 +229,11 @@ class result(Command):
     )
     )
 
 
     def run(self, task_id, *args, **kwargs):
     def run(self, task_id, *args, **kwargs):
-        from .. import registry
         result_cls = self.app.AsyncResult
         result_cls = self.app.AsyncResult
         task = kwargs.get("task")
         task = kwargs.get("task")
 
 
         if task:
         if task:
-            result_cls = registry.tasks[task].AsyncResult
+            result_cls = self.app.tasks[task].AsyncResult
         result = result_cls(task_id)
         result = result_cls(task_id)
         self.out(self.prettify(result.get())[1])
         self.out(self.prettify(result.get())[1])
 result = command(result)
 result = command(result)
@@ -378,7 +377,6 @@ class shell(Command):
     def run(self, force_ipython=False, force_bpython=False,
     def run(self, force_ipython=False, force_bpython=False,
             force_python=False, without_tasks=False, eventlet=False,
             force_python=False, without_tasks=False, eventlet=False,
             gevent=False, **kwargs):
             gevent=False, **kwargs):
-        from .. import registry
         if eventlet:
         if eventlet:
             import_module("celery.concurrency.eventlet")
             import_module("celery.concurrency.eventlet")
         if gevent:
         if gevent:
@@ -388,7 +386,7 @@ class shell(Command):
 
 
         if not without_tasks:
         if not without_tasks:
             self.locals.update(dict((task.__name__, task)
             self.locals.update(dict((task.__name__, task)
-                                for task in registry.tasks.itervalues()))
+                                for task in self.app.tasks.itervalues()))
 
 
         if force_python:
         if force_python:
             return self.invoke_fallback_shell()
             return self.invoke_fallback_shell()

+ 1 - 2
celery/execute/trace.py

@@ -30,7 +30,6 @@ from .. import current_app
 from .. import states, signals
 from .. import states, signals
 from ..datastructures import ExceptionInfo
 from ..datastructures import ExceptionInfo
 from ..exceptions import RetryTaskError
 from ..exceptions import RetryTaskError
-from ..registry import tasks
 from ..utils.serialization import get_pickleable_exception
 from ..utils.serialization import get_pickleable_exception
 
 
 send_prerun = signals.task_prerun.send
 send_prerun = signals.task_prerun.send
@@ -107,7 +106,7 @@ class TraceInfo(object):
 
 
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         Info=TraceInfo, eager=False, propagate=False):
         Info=TraceInfo, eager=False, propagate=False):
-    task = task or tasks[name]
+    task = task or current_app.tasks[name]
     loader = loader or current_app.loader
     loader = loader or current_app.loader
     backend = task.backend
     backend = task.backend
     ignore_result = task.ignore_result
     ignore_result = task.ignore_result

+ 3 - 70
celery/registry.py

@@ -1,73 +1,6 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.registry
-    ~~~~~~~~~~~~~~~
-
-    Registry of available tasks.
-
-    :copyright: (c) 2009 - 2012 by Ask Solem.
-    :license: BSD, see LICENSE for more details.
-
-"""
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
-import inspect
-
-from .exceptions import NotRegistered
-
-
-class TaskRegistry(dict):
-    NotRegistered = NotRegistered
-
-    def regular(self):
-        """Get all regular task types."""
-        return self.filter_types("regular")
-
-    def periodic(self):
-        """Get all periodic task types."""
-        return self.filter_types("periodic")
-
-    def register(self, task):
-        """Register a task in the task registry.
-
-        The task will be automatically instantiated if not already an
-        instance.
-
-        """
-        self[task.name] = inspect.isclass(task) and task() or task
-
-    def unregister(self, name):
-        """Unregister task by name.
-
-        :param name: name of the task to unregister, or a
-            :class:`celery.task.base.Task` with a valid `name` attribute.
-
-        :raises celery.exceptions.NotRegistered: if the task has not
-            been registered.
-
-        """
-        try:
-            # Might be a task class
-            name = name.name
-        except AttributeError:
-            pass
-        self.pop(name)
-
-    def filter_types(self, type):
-        """Return all tasks of a specific type."""
-        return dict((name, task) for name, task in self.iteritems()
-                                    if task.type == type)
-
-    def pop(self, key, *args):
-        try:
-            return dict.pop(self, key, *args)
-        except KeyError:
-            raise self.NotRegistered(key)
-
-
-#: Global task registry.
-tasks = TaskRegistry()
-
+from . import current_app
+from .local import Proxy
 
 
-def _unpickle_task(name):
-    return tasks[name]
+tasks = Proxy(lambda: current_app.tasks)

+ 1 - 1
celery/result.py

@@ -20,8 +20,8 @@ from itertools import imap
 from . import current_app
 from . import current_app
 from . import states
 from . import states
 from .app import app_or_default
 from .app import app_or_default
+from .app.registry import _unpickle_task
 from .exceptions import TimeoutError
 from .exceptions import TimeoutError
-from .registry import _unpickle_task
 from .utils.compat import OrderedDict
 from .utils.compat import OrderedDict
 
 
 
 

+ 1 - 4
celery/task/__init__.py

@@ -89,7 +89,4 @@ def periodic_task(*args, **options):
     """
     """
     return task(**dict({"base": PeriodicTask}, **options))
     return task(**dict({"base": PeriodicTask}, **options))
 
 
-
-@task(name="celery.backend_cleanup")
-def backend_cleanup():
-    backend_cleanup.backend.cleanup()
+backend_cleanup = Proxy(lambda: current_app.tasks["celery.backend_cleanup"])

+ 0 - 1
celery/task/base.py

@@ -78,7 +78,6 @@ class PeriodicTask(Task):
     """
     """
     abstract = True
     abstract = True
     ignore_result = True
     ignore_result = True
-    type = "periodic"
     relative = False
     relative = False
     options = None
     options = None
 
 

+ 1 - 13
celery/task/chords.py

@@ -12,21 +12,9 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from .. import current_app
 from .. import current_app
-from ..result import AsyncResult, TaskSetResult
 from ..utils import uuid
 from ..utils import uuid
 
 
-from .sets import TaskSet, subtask
-
-
-@current_app.task(name="celery.chord_unlock", max_retries=None)
-def _unlock_chord(setid, callback, interval=1, propagate=False,
-        max_retries=None, result=None):
-    result = TaskSetResult(setid, map(AsyncResult, result))
-    if result.ready():
-        j = result.join_native if result.supports_native_join else result.join
-        subtask(callback).delay(j(propagate=propagate))
-    else:
-        _unlock_chord.retry(countdown=interval, max_retries=max_retries)
+from .sets import TaskSet
 
 
 
 
 class Chord(current_app.Task):
 class Chord(current_app.Task):

+ 2 - 2
celery/task/sets.py

@@ -12,7 +12,7 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
-from .. import registry
+from .. import current_app
 from ..app import app_or_default
 from ..app import app_or_default
 from ..datastructures import AttributeDict
 from ..datastructures import AttributeDict
 from ..utils import cached_property, reprcall, uuid
 from ..utils import cached_property, reprcall, uuid
@@ -93,7 +93,7 @@ class subtask(AttributeDict):
 
 
     @cached_property
     @cached_property
     def type(self):
     def type(self):
-        return registry.tasks[self.task]
+        return current_app.tasks[self.task]
 
 
 
 
 def maybe_subtask(t):
 def maybe_subtask(t):

+ 1 - 2
celery/tests/test_app/test_beat.py

@@ -6,7 +6,6 @@ from datetime import datetime, timedelta
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery import beat
 from celery import beat
-from celery import registry
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.task.base import Task
 from celery.task.base import Task
@@ -166,7 +165,7 @@ class test_Scheduler(Case):
             def apply_async(cls, *args, **kwargs):
             def apply_async(cls, *args, **kwargs):
                 through_task[0] = True
                 through_task[0] = True
 
 
-        assert MockTask.name in registry.tasks
+        assert MockTask.name in MockTask.app.tasks
 
 
         scheduler = mScheduler()
         scheduler = mScheduler()
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))

+ 4 - 4
celery/tests/test_backends/test_base.py

@@ -7,6 +7,7 @@ import types
 from mock import Mock
 from mock import Mock
 from nose import SkipTest
 from nose import SkipTest
 
 
+from celery import current_app
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.utils import serialization
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import subclass_exception
@@ -98,14 +99,13 @@ class test_BaseBackend_interface(Case):
             b.forget("SOMExx-N0nex1stant-IDxx-")
             b.forget("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
-        from celery.registry import tasks
-        p, tasks[unlock] = tasks.get(unlock), Mock()
+        p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
         try:
         try:
             b.on_chord_apply("dakj221", "sdokqweok",
             b.on_chord_apply("dakj221", "sdokqweok",
                              result=map(AsyncResult, [1, 2, 3]))
                              result=map(AsyncResult, [1, 2, 3]))
-            self.assertTrue(tasks[unlock].apply_async.call_count)
+            self.assertTrue(current_app.tasks[unlock].apply_async.call_count)
         finally:
         finally:
-            tasks[unlock] = p
+            current_app.tasks[unlock] = p
 
 
 
 
 class test_exception_pickle(Case):
 class test_exception_pickle(Case):

+ 3 - 3
celery/tests/test_backends/test_cache.py

@@ -8,10 +8,10 @@ from contextlib import contextmanager
 
 
 from mock import Mock, patch
 from mock import Mock, patch
 
 
+from celery import current_app
 from celery import states
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.registry import tasks
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.task import subtask
 from celery.task import subtask
 from celery.utils import uuid
 from celery.utils import uuid
@@ -70,7 +70,7 @@ class test_CacheBackend(Case):
         task = Mock()
         task = Mock()
         task.name = "foobarbaz"
         task.name = "foobarbaz"
         try:
         try:
-            tasks["foobarbaz"] = task
+            current_app.tasks["foobarbaz"] = task
             task.request.chord = subtask(task)
             task.request.chord = subtask(task)
             task.request.taskset = "setid"
             task.request.taskset = "setid"
 
 
@@ -85,7 +85,7 @@ class test_CacheBackend(Case):
             deps.delete.assert_called_with()
             deps.delete.assert_called_with()
 
 
         finally:
         finally:
-            tasks.pop("foobarbaz")
+            current_app.tasks.pop("foobarbaz")
 
 
     def test_mget(self):
     def test_mget(self):
         self.tb.set("foo", 1)
         self.tb.set("foo", 1)

+ 2 - 3
celery/tests/test_backends/test_redis_unit.py

@@ -7,7 +7,6 @@ from mock import Mock, patch
 from celery import current_app
 from celery import current_app
 from celery import states
 from celery import states
 from celery.result import AsyncResult
 from celery.result import AsyncResult
-from celery.registry import tasks
 from celery.task import subtask
 from celery.task import subtask
 from celery.utils import cached_property, uuid
 from celery.utils import cached_property, uuid
 from celery.utils.timeutils import timedelta_seconds
 from celery.utils.timeutils import timedelta_seconds
@@ -129,7 +128,7 @@ class test_RedisBackend(Case):
         task = Mock()
         task = Mock()
         task.name = "foobarbaz"
         task.name = "foobarbaz"
         try:
         try:
-            tasks["foobarbaz"] = task
+            current_app.tasks["foobarbaz"] = task
             task.request.chord = subtask(task)
             task.request.chord = subtask(task)
             task.request.taskset = "setid"
             task.request.taskset = "setid"
 
 
@@ -143,7 +142,7 @@ class test_RedisBackend(Case):
 
 
             self.assertTrue(b.client.expire.call_count)
             self.assertTrue(b.client.expire.call_count)
         finally:
         finally:
-            tasks.pop("foobarbaz")
+            current_app.tasks.pop("foobarbaz")
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         self.Backend().process_cleanup()
         self.Backend().process_cleanup()

+ 1 - 1
celery/tests/test_slow/test_buckets.py

@@ -7,7 +7,7 @@ import time
 from functools import partial
 from functools import partial
 from itertools import chain, izip
 from itertools import chain, izip
 
 
-from celery.registry import TaskRegistry
+from celery.app.registry import TaskRegistry
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import timeutils
 from celery.utils import timeutils
 from celery.utils import uuid
 from celery.utils import uuid

+ 49 - 32
celery/tests/test_task/test_chord.py

@@ -1,11 +1,14 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from mock import patch
 from mock import patch
+from contextlib import contexmanager
 
 
 from celery import current_app
 from celery import current_app
+from celery import result
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.task import chords
 from celery.task import chords
-from celery.task import TaskSet
+from celery.task import task, TaskSet
 from celery.tests.utils import AppCase, Mock
 from celery.tests.utils import AppCase, Mock
 
 
 passthru = lambda x: x
 passthru = lambda x: x
@@ -35,47 +38,61 @@ class TSR(chords.TaskSetResult):
         return self.value
         return self.value
 
 
 
 
-class test_unlock_chord_task(AppCase):
+@contextmanager
+def patch_unlock_retry():
+    unlock = current_app.tasks["celery.chord_unlock"]
+    retry = Mock()
+    prev, unlock.retry = unlock.retry, retry
+    yield unlock, retry
+    unlock.retry = prev
 
 
-    @patch("celery.task.chords._unlock_chord.retry")
-    def test_unlock_ready(self, retry):
-        callback.apply_async = Mock()
 
 
-        pts, chords.TaskSetResult = chords.TaskSetResult, TSR
-        subtask, chords.subtask = chords.subtask, passthru
-        try:
-            chords._unlock_chord("setid", callback.subtask(),
-                    result=map(AsyncResult, [1, 2, 3]))
-        finally:
-            chords.subtask = subtask
-            chords.TaskSetResult = pts
-        callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
-        # did not retry
-        self.assertFalse(retry.call_count)
-
-    @patch("celery.task.chords.TaskSetResult")
-    @patch("celery.task.chords._unlock_chord.retry")
-    def test_when_not_ready(self, retry, TaskSetResult):
-        callback.apply_async = Mock()
+class test_unlock_chord_task(AppCase):
+
+    @patch("celery.result.TaskSetResult")
+    def test_unlock_ready(self, TaskSetResult):
+        tasks = current_app.tasks
 
 
         class NeverReady(TSR):
         class NeverReady(TSR):
             is_ready = False
             is_ready = False
 
 
-        pts, chords.TaskSetResult = chords.TaskSetResult, NeverReady
+        @task
+        def callback(*args, **kwargs):
+            pass
+
+        pts, result.TaskSetResult  = result.TaskSetResult, NeverReady
+        callback.apply_async = Mock()
         try:
         try:
-            chords._unlock_chord("setid", callback.subtask, interval=10,
-                                max_retries=30,
-                                result=map(AsyncResult, [1, 2, 3]))
-            self.assertFalse(callback.apply_async.call_count)
-            # did retry
-            chords._unlock_chord.retry.assert_called_with(countdown=10,
-                                                          max_retries=30)
+            with patch_unlock_retry() as (unlock, retry):
+                result = Mock(attrs=dict(ready=lambda: True,
+                                        join=lambda **kw: [2, 4, 8, 6]))
+                TaskSetResult.restore = lambda setid: result
+                subtask, chords.subtask = chords.subtask, passthru
+                try:
+                    unlock("setid", callback,
+                           result=map(AsyncResult, [1, 2, 3]))
+                finally:
+                    chords.subtask = subtask
+                callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
+                result.delete.assert_called_with()
+                # did not retry
+                self.assertFalse(retry.call_count)
         finally:
         finally:
-            chords.TaskSetResult = pts
+            result.TaskSetResult = pts
+
+    @patch("celery.result.TaskSetResult")
+    def test_when_not_ready(self, TaskSetResult):
+        with patch_unlock_retry() as (unlock, retry):
+            callback = Mock()
+            result = Mock(attrs=dict(ready=lambda: False))
+            TaskSetResult.restore = lambda setid: result
+            unlock("setid", callback, interval=10, max_retries=30,)
+            self.assertFalse(callback.delay.call_count)
+            # did retry
+            unlock.retry.assert_called_with(countdown=10, max_retries=30)
 
 
     def test_is_in_registry(self):
     def test_is_in_registry(self):
-        from celery.registry import tasks
-        self.assertIn("celery.chord_unlock", tasks)
+        self.assertIn("celery.chord_unlock", current_app.tasks)
 
 
 
 
 class test_chord(AppCase):
 class test_chord(AppCase):

+ 2 - 10
celery/tests/test_task/test_registry.py

@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
-from celery import registry
+from celery.app.registry import TaskRegistry
 from celery.task import Task, PeriodicTask
 from celery.task import Task, PeriodicTask
 from celery.tests.utils import Case
 from celery.tests.utils import Case
 
 
@@ -36,7 +36,7 @@ class TestTaskRegistry(Case):
         self.assertIn(task_name, r)
         self.assertIn(task_name, r)
 
 
     def test_task_registry(self):
     def test_task_registry(self):
-        r = registry.TaskRegistry()
+        r = TaskRegistry()
         self.assertIsInstance(r, dict,
         self.assertIsInstance(r, dict,
                 "TaskRegistry is mapping")
                 "TaskRegistry is mapping")
 
 
@@ -53,14 +53,6 @@ class TestTaskRegistry(Case):
         self.assertIsInstance(tasks.get(TestPeriodicTask.name),
         self.assertIsInstance(tasks.get(TestPeriodicTask.name),
                                    TestPeriodicTask)
                                    TestPeriodicTask)
 
 
-        regular = r.regular()
-        self.assertIn(TestTask.name, regular)
-        self.assertNotIn(TestPeriodicTask.name, regular)
-
-        periodic = r.periodic()
-        self.assertNotIn(TestTask.name, periodic)
-        self.assertIn(TestPeriodicTask.name, periodic)
-
         self.assertIsInstance(r[TestTask.name], TestTask)
         self.assertIsInstance(r[TestTask.name], TestTask)
         self.assertIsInstance(r[TestPeriodicTask.name],
         self.assertIsInstance(r[TestPeriodicTask.name],
                                    TestPeriodicTask)
                                    TestPeriodicTask)

+ 2 - 3
celery/tests/test_worker/test_worker_control.py

@@ -12,7 +12,6 @@ from mock import Mock, patch
 from celery import current_app
 from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.task import task
-from celery.registry import tasks
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 from celery.worker import WorkController as _WC
 from celery.worker import WorkController as _WC
@@ -275,9 +274,9 @@ class test_ControlPanel(Case):
                 self.ready_queue = self.ReadyQueue()
                 self.ready_queue = self.ReadyQueue()
 
 
         consumer = Consumer()
         consumer = Consumer()
-        panel = self.create_panel(consumer=consumer)
+        panel = self.create_panel(app=current_app, consumer=consumer)
 
 
-        task = tasks[mytask.name]
+        task = current_app.tasks[mytask.name]
         old_rate_limit = task.rate_limit
         old_rate_limit = task.rate_limit
         try:
         try:
             panel.handle("rate_limit", arguments=dict(task_name=task.name,
             panel.handle("rate_limit", arguments=dict(task_name=task.name,

+ 3 - 2
celery/tests/test_worker/test_worker_job.py

@@ -14,6 +14,7 @@ from kombu.transport.base import Message
 from mock import Mock
 from mock import Mock
 from nose import SkipTest
 from nose import SkipTest
 
 
+from celery import current_app
 from celery import states
 from celery import states
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
@@ -22,7 +23,6 @@ from celery.exceptions import (RetryTaskError,
                                WorkerLostError, InvalidTaskError)
                                WorkerLostError, InvalidTaskError)
 from celery.execute.trace import eager_trace_task, TraceInfo
 from celery.execute.trace import eager_trace_task, TraceInfo
 from celery.log import setup_logger
 from celery.log import setup_logger
-from celery.registry import tasks
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 from celery.task import task as task_dec
 from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.task.base import Task
@@ -39,7 +39,8 @@ some_kwargs_scratchpad = {}
 
 
 
 
 def jail(task_id, name, args, kwargs):
 def jail(task_id, name, args, kwargs):
-    return eager_trace_task(tasks[name], task_id, args, kwargs, eager=False)[0]
+    return eager_trace_task(current_app.tasks[name],
+                            task_id, args, kwargs, eager=False)[0]
 
 
 
 
 def on_ack(*args, **kwargs):
 def on_ack(*args, **kwargs):

+ 1 - 2
celery/worker/__init__.py

@@ -25,7 +25,6 @@ from kombu.utils.finalize import Finalize
 
 
 from .. import abstract
 from .. import abstract
 from .. import concurrency as _concurrency
 from .. import concurrency as _concurrency
-from .. import registry
 from ..app import app_or_default
 from ..app import app_or_default
 from ..app.abstract import configurated, from_config
 from ..app.abstract import configurated, from_config
 from ..exceptions import SystemTerminate
 from ..exceptions import SystemTerminate
@@ -131,7 +130,7 @@ class Queues(abstract.Component):
                 # just send task directly to pool, skip the mediator.
                 # just send task directly to pool, skip the mediator.
                 w.ready_queue.put = w.process_task
                 w.ready_queue.put = w.process_task
         else:
         else:
-            w.ready_queue = TaskBucket(task_registry=registry.tasks)
+            w.ready_queue = TaskBucket(task_registry=self.app.tasks)
 
 
 
 
 class Timers(abstract.Component):
 class Timers(abstract.Component):

+ 1 - 2
celery/worker/consumer.py

@@ -85,7 +85,6 @@ from ..abstract import StartStopComponent
 from ..app import app_or_default
 from ..app import app_or_default
 from ..datastructures import AttributeDict
 from ..datastructures import AttributeDict
 from ..exceptions import InvalidTaskError
 from ..exceptions import InvalidTaskError
-from ..registry import tasks
 from ..utils import noop
 from ..utils import noop
 from ..utils import timer2
 from ..utils import timer2
 from ..utils.encoding import safe_repr
 from ..utils.encoding import safe_repr
@@ -317,7 +316,7 @@ class Consumer(object):
 
 
     def update_strategies(self):
     def update_strategies(self):
         S = self.strategies
         S = self.strategies
-        for task in tasks.itervalues():
+        for task in self.app.tasks.itervalues():
             S[task.name] = task.start_strategy(self.app, self)
             S[task.name] = task.start_strategy(self.app, self)
 
 
     def start(self):
     def start(self):

+ 4 - 4
celery/worker/control.py

@@ -14,7 +14,6 @@ from __future__ import absolute_import
 from datetime import datetime
 from datetime import datetime
 
 
 from ..platforms import signals as _signals
 from ..platforms import signals as _signals
-from ..registry import tasks
 from ..utils import timeutils
 from ..utils import timeutils
 from ..utils.compat import UserDict
 from ..utils.compat import UserDict
 from ..utils.encoding import safe_repr
 from ..utils.encoding import safe_repr
@@ -97,7 +96,7 @@ def rate_limit(panel, task_name, rate_limit, **kwargs):
         return {"error": "Invalid rate limit string: %s" % exc}
         return {"error": "Invalid rate limit string: %s" % exc}
 
 
     try:
     try:
-        tasks[task_name].rate_limit = rate_limit
+        panel.app.tasks[task_name].rate_limit = rate_limit
     except KeyError:
     except KeyError:
         panel.logger.error("Rate limit attempt for unknown task %s",
         panel.logger.error("Rate limit attempt for unknown task %s",
                            task_name, exc_info=True)
                            task_name, exc_info=True)
@@ -122,7 +121,7 @@ def rate_limit(panel, task_name, rate_limit, **kwargs):
 @Panel.register
 @Panel.register
 def time_limit(panel, task_name=None, hard=None, soft=None, **kwargs):
 def time_limit(panel, task_name=None, hard=None, soft=None, **kwargs):
     try:
     try:
-        task = tasks[task_name]
+        task = panel.app.tasks[task_name]
     except KeyError:
     except KeyError:
         panel.logger.error("Change time limit attempt for unknown task %s",
         panel.logger.error("Change time limit attempt for unknown task %s",
                            task_name, exc_info=True)
                            task_name, exc_info=True)
@@ -195,6 +194,7 @@ def dump_revoked(panel, **kwargs):
 
 
 @Panel.register
 @Panel.register
 def dump_tasks(panel, **kwargs):
 def dump_tasks(panel, **kwargs):
+    tasks = panel.app.tasks
 
 
     def _extract_info(task):
     def _extract_info(task):
         fields = dict((field, str(getattr(task, field, None)))
         fields = dict((field, str(getattr(task, field, None)))
@@ -206,7 +206,7 @@ def dump_tasks(panel, **kwargs):
         return "%s [%s]" % (task.name, " ".join(info))
         return "%s [%s]" % (task.name, " ".join(info))
 
 
     info = map(_extract_info, (tasks[task]
     info = map(_extract_info, (tasks[task]
-                                        for task in sorted(tasks.keys())))
+                                    for task in sorted(tasks.keys())))
     panel.logger.debug("* Dump of currently registered tasks:\n%s",
     panel.logger.debug("* Dump of currently registered tasks:\n%s",
                        "\n".join(info))
                        "\n".join(info))
 
 

+ 4 - 4
celery/worker/job.py

@@ -19,9 +19,9 @@ import sys
 
 
 from datetime import datetime
 from datetime import datetime
 
 
+from .. import current_app
 from .. import exceptions
 from .. import exceptions
 from ..datastructures import ExceptionInfo
 from ..datastructures import ExceptionInfo
-from ..registry import tasks
 from ..app import app_or_default
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..platforms import set_mp_process_title as setps
 from ..platforms import set_mp_process_title as setps
@@ -47,7 +47,7 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
         >>> trace_task(name, *args, **kwargs)[0]
         >>> trace_task(name, *args, **kwargs)[0]
 
 
     """
     """
-    task = tasks[name]
+    task = current_app.tasks[name]
     try:
     try:
         hostname = opts.get("hostname")
         hostname = opts.get("hostname")
         setps("celeryd", name, hostname, rate_limit=True)
         setps("celeryd", name, hostname, rate_limit=True)
@@ -114,7 +114,7 @@ class Request(object):
         self.logger = logger or self.app.log.get_default_logger()
         self.logger = logger or self.app.log.get_default_logger()
         self.eventer = eventer
         self.eventer = eventer
         self.connection_errors = connection_errors or ()
         self.connection_errors = connection_errors or ()
-        self.task = task or tasks[name]
+        self.task = task or self.app.tasks[name]
         self.acknowledged = self._already_revoked = False
         self.acknowledged = self._already_revoked = False
         self.time_start = self.worker_pid = self._terminate_on_ack = None
         self.time_start = self.worker_pid = self._terminate_on_ack = None
         self._tzlocal = None
         self._tzlocal = None
@@ -392,7 +392,7 @@ class Request(object):
                                         "name": self.name,
                                         "name": self.name,
                                         "hostname": self.hostname}})
                                         "hostname": self.hostname}})
 
 
-        task_obj = tasks.get(self.name, object)
+        task_obj = self.app.tasks.get(self.name, object)
         task_obj.send_error_email(context, exc_info.exception)
         task_obj.send_error_email(context, exc_info.exception)
 
 
     def acknowledge(self):
     def acknowledge(self):