Task registry is no longer global, but needs testing before merge, and hopefully some cleanup

Conflicts:

	celery/app/base.py
	celery/bin/celeryctl.py
Ask Solem authored 13 years ago · parent commit 22ceb946b5
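
The gist of the change: the module-level registry in celery.registry is replaced by one TaskRegistry per app, reachable as app.tasks. A minimal usage sketch, with names taken from the diffs below (the app and task names are illustrative only):

    # Before: every app shared the module-level registry.
    from celery.registry import tasks
    tasks["tasks.add"]

    # After: each app owns its registry, exposed lazily as app.tasks.
    from celery.app import App

    app = App("example")

    @app.task(name="tasks.add")
    def add(x, y):
        return x + y

    assert app.tasks["tasks.add"] is add   # the registered, app-local instance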

+ 1 - 2
celery/app/__init__.py

@@ -15,7 +15,6 @@ from __future__ import absolute_import
 import os
 import threading
 
-from .. import registry
 from ..utils import cached_property, instantiate
 
 from . import annotations
@@ -182,7 +181,7 @@ class App(base.BaseApp):
                         "run": staticmethod(fun),
                         "__doc__": fun.__doc__,
                         "__module__": fun.__module__}, **options))()
-                return registry.tasks[T.name]  # global instance.
+                return self._tasks[T.name]  # app-local instance.
 
             return _create_task_cls
 

+ 7 - 3
celery/app/base.py

@@ -89,12 +89,13 @@ class BaseApp(object):
     loader_cls = "celery.loaders.app:AppLoader"
     log_cls = "celery.log:Logging"
     control_cls = "celery.task.control:Control"
+    registry_cls = "celery.app.registry:TaskRegistry"
 
     _pool = None
 
     def __init__(self, main=None, loader=None, backend=None,
             amqp=None, events=None, log=None, control=None,
-            set_as_current=True, accept_magic_kwargs=False, **kwargs):
+            set_as_current=True, accept_magic_kwargs=False, tasks=None, **kwargs):
         self.main = main
         self.amqp_cls = amqp or self.amqp_cls
         self.backend_cls = backend or self.backend_cls
@@ -105,6 +106,8 @@ class BaseApp(object):
         self.set_as_current = set_as_current
         self.accept_magic_kwargs = accept_magic_kwargs
         self.clock = LamportClock()
+        self.registry_cls = self.registry_cls if tasks is None else tasks
+        self._tasks = instantiate(self.registry_cls)
 
         self.on_init()
 
@@ -388,5 +391,6 @@ class BaseApp(object):
 
     @cached_property
     def tasks(self):
-        from ..registry import tasks
-        return tasks
+        from .task.builtins import load_builtins
+        load_builtins(self)
+        return self._tasks

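
A side effect of the new constructor argument above: an app can be created with a custom registry class via the tasks keyword, which is handed to instantiate(). A hedged sketch (the LoudRegistry subclass is made up for illustration, and this assumes instantiate() accepts a class as well as a dotted path):

    from celery.app import App
    from celery.app.registry import TaskRegistry

    class LoudRegistry(TaskRegistry):
        # Hypothetical subclass that logs every registration.
        def register(self, task):
            print("registering %r" % (task, ))
            super(LoudRegistry, self).register(task)

    app = App("example", tasks=LoudRegistry)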
+ 57 - 0
celery/app/registry.py

@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.app.registry
+    ~~~~~~~~~~~~~~~~~~~
+
+    Registry of available tasks.
+
+    :copyright: (c) 2009 - 2012 by Ask Solem.
+    :license: BSD, see LICENSE for more details.
+
+"""
+from __future__ import absolute_import
+
+import inspect
+
+from .. import current_app
+from ..exceptions import NotRegistered
+
+
+class TaskRegistry(dict):
+    NotRegistered = NotRegistered
+
+    def register(self, task):
+        """Register a task in the task registry.
+
+        The task will be automatically instantiated if not already an
+        instance.
+
+        """
+        self[task.name] = inspect.isclass(task) and task() or task
+
+    def unregister(self, name):
+        """Unregister task by name.
+
+        :param name: name of the task to unregister, or a
+            :class:`celery.task.base.Task` with a valid `name` attribute.
+
+        :raises celery.exceptions.NotRegistered: if the task has not
+            been registered.
+
+        """
+        try:
+            # Might be a task class
+            name = name.name
+        except AttributeError:
+            pass
+        self.pop(name)
+
+    def pop(self, key, *args):
+        try:
+            return dict.pop(self, key, *args)
+        except KeyError:
+            raise self.NotRegistered(key)
+
+
+def _unpickle_task(name):
+    return current_app.tasks[name]

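
For reference, the relocated TaskRegistry behaves as before: registering a class instantiates it, and unregistering an unknown name raises NotRegistered. A small sketch (DummyTask is a stand-in; real tasks subclass celery.task.Task):

    from celery.app.registry import TaskRegistry
    from celery.exceptions import NotRegistered

    registry = TaskRegistry()

    class DummyTask(object):
        name = "tests.dummy"

    registry.register(DummyTask)            # classes are instantiated on register
    assert isinstance(registry["tests.dummy"], DummyTask)

    registry.unregister("tests.dummy")
    try:
        registry.unregister("tests.dummy")  # unknown names raise NotRegistered
    except NotRegistered:
        pass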
+ 7 - 6
celery/app/task/__init__.py

@@ -15,16 +15,18 @@ from __future__ import absolute_import
 import sys
 import threading
 
+from ... import current_app
 from ... import states
 from ...datastructures import ExceptionInfo
 from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...execute.trace import eager_trace_task
-from ...registry import tasks, _unpickle_task
 from ...result import EagerResult
 from ...utils import (fun_takes_kwargs, instantiate,
                       mattrgetter, uuid, maybe_reraise)
 from ...utils.mail import ErrorMail
 
+from ..registry import _unpickle_task
+
 extract_exec_options = mattrgetter("queue", "routing_key",
                                    "exchange", "immediate",
                                    "mandatory", "priority",
@@ -76,6 +78,7 @@ class TaskType(type):
 
     def __new__(cls, name, bases, attrs):
         new = super(TaskType, cls).__new__
+        app = attrs.get("app") or current_app
         task_module = attrs.get("__module__") or "__main__"
 
         if "__call__" in attrs:
@@ -118,6 +121,7 @@ class TaskType(type):
         # we may or may not be the first time the task tries to register
         # with the framework.  There should only be one class for each task
         # name, so we always return the registered version.
+        tasks = app._tasks
         task_name = attrs["name"]
         if task_name not in tasks:
             task_cls = new(cls, name, bases, attrs)
@@ -127,7 +131,7 @@ class TaskType(type):
         task = tasks[task_name].__class__
 
         # decorate with annotations from config.
-        task.app.annotate_task(task)
+        app.annotate_task(task)
         return task
 
     def __repr__(cls):
@@ -270,9 +274,6 @@ class BaseTask(object):
     #: Default task expiry time.
     expires = None
 
-    #: The type of task *(no longer used)*.
-    type = "regular"
-
     #: Execution strategy used, or the qualified name of one.
     Strategy = "celery.worker.strategy:default"
 
@@ -590,7 +591,7 @@ class BaseTask(object):
                                 options.pop("throw", None))
 
         # Make sure we get the task instance, not class.
-        task = tasks[self.name]
+        task = self.app._tasks[self.name]
 
         request = {"id": task_id,
                    "retries": retries,

+ 44 - 0
celery/app/task/builtins.py

@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+_builtins = []
+
+def builtin_task(constructor):
+    _builtins.append(constructor)
+    return constructor
+
+
+@builtin_task
+def add_backend_cleanup_task(app):
+
+    @app.task(name="celery.backend_cleanup")
+    def backend_cleanup():
+        app.backend.cleanup()
+
+    return backend_cleanup
+
+
+@builtin_task
+def add_unlock_chord_task(app):
+
+    @app.task(name="celery.chord_unlock", max_retries=None)
+    def unlock_chord(setid, callback, interval=1, propagate=False,
+            max_retries=None, result=None):
+        from ...result import AsyncResult, TaskSetResult
+        from ...task.sets import subtask
+
+        result = TaskSetResult(setid, map(AsyncResult, result))
+        if result.ready():
+            j = result.join_native if result.supports_native_join else result.join
+            subtask(callback).delay(j(propagate=propagate))
+        else:
+            unlock_chord.retry(countdown=interval, max_retries=max_retries)
+
+    return unlock_chord
+
+
+def load_builtins(app):
+    for constructor in _builtins:
+        constructor(app)
+
+

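
Taken together with the base.py hunk, the builtin constructors now run lazily: the first access to app.tasks calls load_builtins(app), so the celery.* tasks are registered per app instead of at import time. Roughly:

    from celery.app import App

    app = App("example")
    # Accessing app.tasks for the first time triggers load_builtins(app).
    assert "celery.backend_cleanup" in app.tasks
    assert "celery.chord_unlock" in app.tasks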
+ 1 - 2
celery/apps/worker.py

@@ -174,8 +174,7 @@ class Worker(configurated):
         self.loader.init_worker()
 
     def tasklist(self, include_builtins=True):
-        from ..registry import tasks
-        tasklist = tasks.keys()
+        tasklist = self.app.tasks.keys()
         if not include_builtins:
             tasklist = filter(lambda s: not s.startswith("celery."),
                               tasklist)

+ 2 - 3
celery/backends/base.py

@@ -207,10 +207,9 @@ class BaseBackend(object):
         pass
 
     def on_chord_apply(self, setid, body, result=None, **kwargs):
-        from ..registry import tasks
         kwargs["result"] = [r.task_id for r in result]
-        tasks["celery.chord_unlock"].apply_async((setid, body, ), kwargs,
-                                                 countdown=1)
+        self.app.tasks["celery.chord_unlock"].apply_async((setid, body, ),
+                                                          kwargs, countdown=1)
 
     def __reduce__(self, args=(), kwargs={}):
         return (unpickle_backend, (self.__class__, args, kwargs))

+ 1 - 2
celery/beat.py

@@ -27,7 +27,6 @@ from kombu.utils import reprcall
 
 from . import __version__
 from . import platforms
-from . import registry
 from . import signals
 from . import current_app
 from .app import app_or_default
@@ -215,7 +214,7 @@ class Scheduler(object):
         # so we have that done if an exception is raised (doesn't schedule
         # forever.)
         entry = self.reserve(entry)
-        task = registry.tasks.get(entry.task)
+        task = self.app.tasks.get(entry.task)
 
         try:
             if task:

+ 2 - 4
celery/bin/celeryctl.py

@@ -229,12 +229,11 @@ class result(Command):
     )
 
     def run(self, task_id, *args, **kwargs):
-        from .. import registry
         result_cls = self.app.AsyncResult
         task = kwargs.get("task")
 
         if task:
-            result_cls = registry.tasks[task].AsyncResult
+            result_cls = self.app.tasks[task].AsyncResult
         result = result_cls(task_id)
         self.out(self.prettify(result.get())[1])
 result = command(result)
@@ -378,7 +377,6 @@ class shell(Command):
     def run(self, force_ipython=False, force_bpython=False,
             force_python=False, without_tasks=False, eventlet=False,
             gevent=False, **kwargs):
-        from .. import registry
         if eventlet:
             import_module("celery.concurrency.eventlet")
         if gevent:
@@ -388,7 +386,7 @@ class shell(Command):
 
         if not without_tasks:
             self.locals.update(dict((task.__name__, task)
-                                for task in registry.tasks.itervalues()))
+                                for task in self.app.tasks.itervalues()))
 
         if force_python:
             return self.invoke_fallback_shell()

+ 1 - 2
celery/execute/trace.py

@@ -30,7 +30,6 @@ from .. import current_app
 from .. import states, signals
 from ..datastructures import ExceptionInfo
 from ..exceptions import RetryTaskError
-from ..registry import tasks
 from ..utils.serialization import get_pickleable_exception
 
 send_prerun = signals.task_prerun.send
@@ -107,7 +106,7 @@ class TraceInfo(object):
 
 def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
         Info=TraceInfo, eager=False, propagate=False):
-    task = task or tasks[name]
+    task = task or current_app.tasks[name]
     loader = loader or current_app.loader
     backend = task.backend
     ignore_result = task.ignore_result

+ 3 - 70
celery/registry.py

@@ -1,73 +1,6 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.registry
-    ~~~~~~~~~~~~~~~
-
-    Registry of available tasks.
-
-    :copyright: (c) 2009 - 2012 by Ask Solem.
-    :license: BSD, see LICENSE for more details.
-
-"""
 from __future__ import absolute_import
 
-import inspect
-
-from .exceptions import NotRegistered
-
-
-class TaskRegistry(dict):
-    NotRegistered = NotRegistered
-
-    def regular(self):
-        """Get all regular task types."""
-        return self.filter_types("regular")
-
-    def periodic(self):
-        """Get all periodic task types."""
-        return self.filter_types("periodic")
-
-    def register(self, task):
-        """Register a task in the task registry.
-
-        The task will be automatically instantiated if not already an
-        instance.
-
-        """
-        self[task.name] = inspect.isclass(task) and task() or task
-
-    def unregister(self, name):
-        """Unregister task by name.
-
-        :param name: name of the task to unregister, or a
-            :class:`celery.task.base.Task` with a valid `name` attribute.
-
-        :raises celery.exceptions.NotRegistered: if the task has not
-            been registered.
-
-        """
-        try:
-            # Might be a task class
-            name = name.name
-        except AttributeError:
-            pass
-        self.pop(name)
-
-    def filter_types(self, type):
-        """Return all tasks of a specific type."""
-        return dict((name, task) for name, task in self.iteritems()
-                                    if task.type == type)
-
-    def pop(self, key, *args):
-        try:
-            return dict.pop(self, key, *args)
-        except KeyError:
-            raise self.NotRegistered(key)
-
-
-#: Global task registry.
-tasks = TaskRegistry()
-
+from . import current_app
+from .local import Proxy
 
-def _unpickle_task(name):
-    return tasks[name]
+tasks = Proxy(lambda: current_app.tasks)

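
The old import path survives as a compatibility shim: celery.registry.tasks is now a Proxy that resolves to the current app's registry on each access, so legacy lookups keep working. Roughly:

    from celery import current_app
    from celery.registry import tasks   # now a Proxy, not the registry itself

    assert tasks["celery.backend_cleanup"] is current_app.tasks["celery.backend_cleanup"]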
+ 1 - 1
celery/result.py

@@ -20,8 +20,8 @@ from itertools import imap
 from . import current_app
 from . import states
 from .app import app_or_default
+from .app.registry import _unpickle_task
 from .exceptions import TimeoutError
-from .registry import _unpickle_task
 from .utils.compat import OrderedDict
 
 

+ 1 - 4
celery/task/__init__.py

@@ -89,7 +89,4 @@ def periodic_task(*args, **options):
     """
     return task(**dict({"base": PeriodicTask}, **options))
 
-
-@task(name="celery.backend_cleanup")
-def backend_cleanup():
-    backend_cleanup.backend.cleanup()
+backend_cleanup = Proxy(lambda: current_app.tasks["celery.backend_cleanup"])

+ 0 - 1
celery/task/base.py

@@ -78,7 +78,6 @@ class PeriodicTask(Task):
     """
     abstract = True
     ignore_result = True
-    type = "periodic"
     relative = False
     options = None
 

+ 1 - 13
celery/task/chords.py

@@ -12,21 +12,9 @@
 from __future__ import absolute_import
 
 from .. import current_app
-from ..result import AsyncResult, TaskSetResult
 from ..utils import uuid
 
-from .sets import TaskSet, subtask
-
-
-@current_app.task(name="celery.chord_unlock", max_retries=None)
-def _unlock_chord(setid, callback, interval=1, propagate=False,
-        max_retries=None, result=None):
-    result = TaskSetResult(setid, map(AsyncResult, result))
-    if result.ready():
-        j = result.join_native if result.supports_native_join else result.join
-        subtask(callback).delay(j(propagate=propagate))
-    else:
-        _unlock_chord.retry(countdown=interval, max_retries=max_retries)
+from .sets import TaskSet
 
 
 class Chord(current_app.Task):

+ 2 - 2
celery/task/sets.py

@@ -12,7 +12,7 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-from .. import registry
+from .. import current_app
 from ..app import app_or_default
 from ..datastructures import AttributeDict
 from ..utils import cached_property, reprcall, uuid
@@ -93,7 +93,7 @@ class subtask(AttributeDict):
 
     @cached_property
     def type(self):
-        return registry.tasks[self.task]
+        return current_app.tasks[self.task]
 
 
 def maybe_subtask(t):

+ 1 - 2
celery/tests/test_app/test_beat.py

@@ -6,7 +6,6 @@ from datetime import datetime, timedelta
 from nose import SkipTest
 
 from celery import beat
-from celery import registry
 from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.task.base import Task
@@ -166,7 +165,7 @@ class test_Scheduler(Case):
             def apply_async(cls, *args, **kwargs):
                 through_task[0] = True
 
-        assert MockTask.name in registry.tasks
+        assert MockTask.name in MockTask.app.tasks
 
         scheduler = mScheduler()
         scheduler.apply_async(scheduler.Entry(task=MockTask.name))

+ 4 - 4
celery/tests/test_backends/test_base.py

@@ -7,6 +7,7 @@ import types
 from mock import Mock
 from nose import SkipTest
 
+from celery import current_app
 from celery.result import AsyncResult
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
@@ -98,14 +99,13 @@ class test_BaseBackend_interface(Case):
             b.forget("SOMExx-N0nex1stant-IDxx-")
 
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
-        from celery.registry import tasks
-        p, tasks[unlock] = tasks.get(unlock), Mock()
+        p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
         try:
             b.on_chord_apply("dakj221", "sdokqweok",
                              result=map(AsyncResult, [1, 2, 3]))
-            self.assertTrue(tasks[unlock].apply_async.call_count)
+            self.assertTrue(current_app.tasks[unlock].apply_async.call_count)
         finally:
-            tasks[unlock] = p
+            current_app.tasks[unlock] = p
 
 
 class test_exception_pickle(Case):

+ 3 - 3
celery/tests/test_backends/test_cache.py

@@ -8,10 +8,10 @@ from contextlib import contextmanager
 
 from mock import Mock, patch
 
+from celery import current_app
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
-from celery.registry import tasks
 from celery.result import AsyncResult
 from celery.task import subtask
 from celery.utils import uuid
@@ -70,7 +70,7 @@ class test_CacheBackend(Case):
         task = Mock()
         task.name = "foobarbaz"
         try:
-            tasks["foobarbaz"] = task
+            current_app.tasks["foobarbaz"] = task
             task.request.chord = subtask(task)
             task.request.taskset = "setid"
 
@@ -85,7 +85,7 @@ class test_CacheBackend(Case):
             deps.delete.assert_called_with()
 
         finally:
-            tasks.pop("foobarbaz")
+            current_app.tasks.pop("foobarbaz")
 
     def test_mget(self):
         self.tb.set("foo", 1)

+ 2 - 3
celery/tests/test_backends/test_redis_unit.py

@@ -7,7 +7,6 @@ from mock import Mock, patch
 from celery import current_app
 from celery import states
 from celery.result import AsyncResult
-from celery.registry import tasks
 from celery.task import subtask
 from celery.utils import cached_property, uuid
 from celery.utils.timeutils import timedelta_seconds
@@ -129,7 +128,7 @@ class test_RedisBackend(Case):
         task = Mock()
         task.name = "foobarbaz"
         try:
-            tasks["foobarbaz"] = task
+            current_app.tasks["foobarbaz"] = task
             task.request.chord = subtask(task)
             task.request.taskset = "setid"
 
@@ -143,7 +142,7 @@ class test_RedisBackend(Case):
 
             self.assertTrue(b.client.expire.call_count)
         finally:
-            tasks.pop("foobarbaz")
+            current_app.tasks.pop("foobarbaz")
 
     def test_process_cleanup(self):
         self.Backend().process_cleanup()

+ 1 - 1
celery/tests/test_slow/test_buckets.py

@@ -7,7 +7,7 @@ import time
 from functools import partial
 from itertools import chain, izip
 
-from celery.registry import TaskRegistry
+from celery.app.registry import TaskRegistry
 from celery.task.base import Task
 from celery.utils import timeutils
 from celery.utils import uuid

+ 49 - 32
celery/tests/test_task/test_chord.py

@@ -1,11 +1,14 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 from mock import patch
+from contextlib import contextmanager
 
 from celery import current_app
+from celery import result
 from celery.result import AsyncResult
 from celery.task import chords
-from celery.task import TaskSet
+from celery.task import task, TaskSet
 from celery.tests.utils import AppCase, Mock
 
 passthru = lambda x: x
@@ -35,47 +38,61 @@ class TSR(chords.TaskSetResult):
         return self.value
 
 
-class test_unlock_chord_task(AppCase):
+@contextmanager
+def patch_unlock_retry():
+    unlock = current_app.tasks["celery.chord_unlock"]
+    retry = Mock()
+    prev, unlock.retry = unlock.retry, retry
+    try:
+        yield unlock, retry
+    finally:
+        unlock.retry = prev
 
-    @patch("celery.task.chords._unlock_chord.retry")
-    def test_unlock_ready(self, retry):
-        callback.apply_async = Mock()
 
-        pts, chords.TaskSetResult = chords.TaskSetResult, TSR
-        subtask, chords.subtask = chords.subtask, passthru
-        try:
-            chords._unlock_chord("setid", callback.subtask(),
-                    result=map(AsyncResult, [1, 2, 3]))
-        finally:
-            chords.subtask = subtask
-            chords.TaskSetResult = pts
-        callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
-        # did not retry
-        self.assertFalse(retry.call_count)
-
-    @patch("celery.task.chords.TaskSetResult")
-    @patch("celery.task.chords._unlock_chord.retry")
-    def test_when_not_ready(self, retry, TaskSetResult):
-        callback.apply_async = Mock()
+class test_unlock_chord_task(AppCase):
+
+    @patch("celery.result.TaskSetResult")
+    def test_unlock_ready(self, TaskSetResult):
+        tasks = current_app.tasks
 
         class NeverReady(TSR):
             is_ready = False
 
-        pts, chords.TaskSetResult = chords.TaskSetResult, NeverReady
+        @task
+        def callback(*args, **kwargs):
+            pass
+
+        pts, result.TaskSetResult = result.TaskSetResult, NeverReady
+        callback.apply_async = Mock()
         try:
-            chords._unlock_chord("setid", callback.subtask, interval=10,
-                                max_retries=30,
-                                result=map(AsyncResult, [1, 2, 3]))
-            self.assertFalse(callback.apply_async.call_count)
-            # did retry
-            chords._unlock_chord.retry.assert_called_with(countdown=10,
-                                                          max_retries=30)
+            with patch_unlock_retry() as (unlock, retry):
+                res = Mock(attrs=dict(ready=lambda: True,
+                                      join=lambda **kw: [2, 4, 8, 6]))
+                TaskSetResult.restore = lambda setid: res
+                subtask, chords.subtask = chords.subtask, passthru
+                try:
+                    unlock("setid", callback,
+                           result=map(AsyncResult, [1, 2, 3]))
+                finally:
+                    chords.subtask = subtask
+                callback.apply_async.assert_called_with(([2, 4, 8, 6], ), {})
+                res.delete.assert_called_with()
+                # did not retry
+                self.assertFalse(retry.call_count)
         finally:
-            chords.TaskSetResult = pts
+            result.TaskSetResult = pts
+
+    @patch("celery.result.TaskSetResult")
+    def test_when_not_ready(self, TaskSetResult):
+        with patch_unlock_retry() as (unlock, retry):
+            callback = Mock()
+            result = Mock(attrs=dict(ready=lambda: False))
+            TaskSetResult.restore = lambda setid: result
+            unlock("setid", callback, interval=10, max_retries=30,)
+            self.assertFalse(callback.delay.call_count)
+            # did retry
+            unlock.retry.assert_called_with(countdown=10, max_retries=30)
 
     def test_is_in_registry(self):
-        from celery.registry import tasks
-        self.assertIn("celery.chord_unlock", tasks)
+        self.assertIn("celery.chord_unlock", current_app.tasks)
 
 
 class test_chord(AppCase):

+ 2 - 10
celery/tests/test_task/test_registry.py

@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-from celery import registry
+from celery.app.registry import TaskRegistry
 from celery.task import Task, PeriodicTask
 from celery.tests.utils import Case
 
@@ -36,7 +36,7 @@ class TestTaskRegistry(Case):
         self.assertIn(task_name, r)
 
     def test_task_registry(self):
-        r = registry.TaskRegistry()
+        r = TaskRegistry()
         self.assertIsInstance(r, dict,
                 "TaskRegistry is mapping")
 
@@ -53,14 +53,6 @@ class TestTaskRegistry(Case):
         self.assertIsInstance(tasks.get(TestPeriodicTask.name),
                                    TestPeriodicTask)
 
-        regular = r.regular()
-        self.assertIn(TestTask.name, regular)
-        self.assertNotIn(TestPeriodicTask.name, regular)
-
-        periodic = r.periodic()
-        self.assertNotIn(TestTask.name, periodic)
-        self.assertIn(TestPeriodicTask.name, periodic)
-
         self.assertIsInstance(r[TestTask.name], TestTask)
         self.assertIsInstance(r[TestPeriodicTask.name],
                                    TestPeriodicTask)

+ 2 - 3
celery/tests/test_worker/test_worker_control.py

@@ -12,7 +12,6 @@ from mock import Mock, patch
 from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.task import task
-from celery.registry import tasks
 from celery.utils import uuid
 from celery.utils.timer2 import Timer
 from celery.worker import WorkController as _WC
@@ -275,9 +274,9 @@ class test_ControlPanel(Case):
                 self.ready_queue = self.ReadyQueue()
 
         consumer = Consumer()
-        panel = self.create_panel(consumer=consumer)
+        panel = self.create_panel(app=current_app, consumer=consumer)
 
-        task = tasks[mytask.name]
+        task = current_app.tasks[mytask.name]
         old_rate_limit = task.rate_limit
         try:
             panel.handle("rate_limit", arguments=dict(task_name=task.name,

+ 3 - 2
celery/tests/test_worker/test_worker_job.py

@@ -14,6 +14,7 @@ from kombu.transport.base import Message
 from mock import Mock
 from nose import SkipTest
 
+from celery import current_app
 from celery import states
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
@@ -22,7 +23,6 @@ from celery.exceptions import (RetryTaskError,
                                WorkerLostError, InvalidTaskError)
 from celery.execute.trace import eager_trace_task, TraceInfo
 from celery.log import setup_logger
-from celery.registry import tasks
 from celery.result import AsyncResult
 from celery.task import task as task_dec
 from celery.task.base import Task
@@ -39,7 +39,8 @@ some_kwargs_scratchpad = {}
 
 
 def jail(task_id, name, args, kwargs):
-    return eager_trace_task(tasks[name], task_id, args, kwargs, eager=False)[0]
+    return eager_trace_task(current_app.tasks[name],
+                            task_id, args, kwargs, eager=False)[0]
 
 
 def on_ack(*args, **kwargs):

+ 1 - 2
celery/worker/__init__.py

@@ -25,7 +25,6 @@ from kombu.utils.finalize import Finalize
 
 from .. import abstract
 from .. import concurrency as _concurrency
-from .. import registry
 from ..app import app_or_default
 from ..app.abstract import configurated, from_config
 from ..exceptions import SystemTerminate
@@ -131,7 +130,7 @@ class Queues(abstract.Component):
                 # just send task directly to pool, skip the mediator.
                 w.ready_queue.put = w.process_task
         else:
-            w.ready_queue = TaskBucket(task_registry=registry.tasks)
+            w.ready_queue = TaskBucket(task_registry=self.app.tasks)
 
 
 class Timers(abstract.Component):

+ 1 - 2
celery/worker/consumer.py

@@ -85,7 +85,6 @@ from ..abstract import StartStopComponent
 from ..app import app_or_default
 from ..datastructures import AttributeDict
 from ..exceptions import InvalidTaskError
-from ..registry import tasks
 from ..utils import noop
 from ..utils import timer2
 from ..utils.encoding import safe_repr
@@ -317,7 +316,7 @@ class Consumer(object):
 
     def update_strategies(self):
         S = self.strategies
-        for task in tasks.itervalues():
+        for task in self.app.tasks.itervalues():
             S[task.name] = task.start_strategy(self.app, self)
 
     def start(self):

+ 4 - 4
celery/worker/control.py

@@ -14,7 +14,6 @@ from __future__ import absolute_import
 from datetime import datetime
 
 from ..platforms import signals as _signals
-from ..registry import tasks
 from ..utils import timeutils
 from ..utils.compat import UserDict
 from ..utils.encoding import safe_repr
@@ -97,7 +96,7 @@ def rate_limit(panel, task_name, rate_limit, **kwargs):
         return {"error": "Invalid rate limit string: %s" % exc}
 
     try:
-        tasks[task_name].rate_limit = rate_limit
+        panel.app.tasks[task_name].rate_limit = rate_limit
     except KeyError:
         panel.logger.error("Rate limit attempt for unknown task %s",
                            task_name, exc_info=True)
@@ -122,7 +121,7 @@ def rate_limit(panel, task_name, rate_limit, **kwargs):
 @Panel.register
 def time_limit(panel, task_name=None, hard=None, soft=None, **kwargs):
     try:
-        task = tasks[task_name]
+        task = panel.app.tasks[task_name]
     except KeyError:
         panel.logger.error("Change time limit attempt for unknown task %s",
                            task_name, exc_info=True)
@@ -195,6 +194,7 @@ def dump_revoked(panel, **kwargs):
 
 @Panel.register
 def dump_tasks(panel, **kwargs):
+    tasks = panel.app.tasks
 
     def _extract_info(task):
         fields = dict((field, str(getattr(task, field, None)))
@@ -206,7 +206,7 @@ def dump_tasks(panel, **kwargs):
         return "%s [%s]" % (task.name, " ".join(info))
 
     info = map(_extract_info, (tasks[task]
-                                        for task in sorted(tasks.keys())))
+                                    for task in sorted(tasks.keys())))
     panel.logger.debug("* Dump of currently registered tasks:\n%s",
                        "\n".join(info))
 

+ 4 - 4
celery/worker/job.py

@@ -19,9 +19,9 @@ import sys
 
 from datetime import datetime
 
+from .. import current_app
 from .. import exceptions
 from ..datastructures import ExceptionInfo
-from ..registry import tasks
 from ..app import app_or_default
 from ..execute.trace import build_tracer, trace_task, report_internal_error
 from ..platforms import set_mp_process_title as setps
@@ -47,7 +47,7 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
         >>> trace_task(name, *args, **kwargs)[0]
 
     """
-    task = tasks[name]
+    task = current_app.tasks[name]
     try:
         hostname = opts.get("hostname")
         setps("celeryd", name, hostname, rate_limit=True)
@@ -114,7 +114,7 @@ class Request(object):
         self.logger = logger or self.app.log.get_default_logger()
         self.eventer = eventer
         self.connection_errors = connection_errors or ()
-        self.task = task or tasks[name]
+        self.task = task or self.app.tasks[name]
         self.acknowledged = self._already_revoked = False
         self.time_start = self.worker_pid = self._terminate_on_ack = None
         self._tzlocal = None
@@ -392,7 +392,7 @@ class Request(object):
                                         "name": self.name,
                                         "hostname": self.hostname}})
 
-        task_obj = tasks.get(self.name, object)
+        task_obj = self.app.tasks.get(self.name, object)
         task_obj.send_error_email(context, exc_info.exception)
 
     def acknowledge(self):