فهرست منبع

Adds shared task decorator (from celery import shared_task), for library authors

Ask Solem 12 سال پیش
والد
کامیت
359375ee72
7فایلهای تغییر یافته به همراه132 افزوده شده و 47 حذف شده
  1. 6 6
      celery/__init__.py
  2. 23 0
      celery/_state.py
  3. 56 0
      celery/app/__init__.py
  4. 13 9
      celery/app/base.py
  5. 3 2
      celery/app/builtins.py
  6. 4 30
      celery/app/task.py
  7. 27 0
      celery/utils/__init__.py

+ 6 - 6
celery/__init__.py

@@ -21,12 +21,12 @@ from .__compat__ import recreate_module
 
 old_module, new_module = recreate_module(__name__,  # pragma: no cover
     by_module={
-        'celery.app':       ['Celery', 'bugreport'],
-        'celery.app.task':  ['Task'],
-        'celery._state':    ['current_app', 'current_task'],
-        'celery.canvas':    ['chain', 'chord', 'chunks',
-                             'group', 'subtask', 'xmap', 'xstarmap'],
-        'celery.utils':     ['uuid'],
+        'celery.app':          ['Celery', 'bugreport', 'shared_task'],
+        'celery.app.task':     ['Task'],
+        'celery._state':       ['current_app', 'current_task'],
+        'celery.canvas':       ['chain', 'chord', 'chunks',
+                                'group', 'subtask', 'xmap', 'xstarmap'],
+        'celery.utils':        ['uuid'],
     },
     direct={'task': 'celery.task'},
     __package__='celery', __file__=__file__,

+ 23 - 0
celery/_state.py

@@ -12,6 +12,7 @@
 from __future__ import absolute_import
 
 import threading
+import weakref
 
 from celery.local import Proxy
 from celery.utils.threads import LocalStack
@@ -62,3 +63,25 @@ def get_current_worker_task():
 
 current_app = Proxy(get_current_app)
 current_task = Proxy(get_current_task)
+
+#: WeakSet does not seem to work properly,
+#: it doesn't recognize when objects go out of scope.
+_apps = set()
+
+
+def _register_app(app):
+    _apps.add(weakref.ref(app))
+
+
+def _get_active_apps():
+    dirty = []
+    try:
+        for appref in _apps:
+            app = appref()
+            if app is None:
+                dirty.append(appref)
+            else:
+                yield app
+    finally:
+        while dirty:
+            _apps.discard(dirty.pop())

+ 56 - 0
celery/app/__init__.py

@@ -7,6 +7,7 @@
 
 """
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import os
 
@@ -16,8 +17,11 @@ from celery._state import (  # noqa
         set_default_app,
         get_current_app as current_app,
         get_current_task as current_task,
+        _get_active_apps,
 )
+from celery.utils import gen_task_name
 
+from .builtins import shared_task as _shared_task
 from .base import Celery, AppPickler  # noqa
 
 #: Proxy always returning the app set as default.
@@ -80,3 +84,55 @@ else:
     disable_trace()
 
 App = Celery  # XXX Compat
+
+
+def shared_task(*args, **kwargs):
+    """Task decorator that creates shared tasks,
+    and returns a proxy that always returns the task from the current apps
+    task registry.
+
+    This can be used by library authors to create tasks that will work
+    for any app environment.
+
+    Example:
+
+        >>> from celery import Celery, shared_task
+        >>> @shared_task
+        ... def add(x, y):
+        ...     return x + y
+
+        >>> app1 = Celery(broker='amqp://')
+        >>> add.app is app1
+        True
+
+        >>> app2 = Celery(broker='redis://')
+        >>> add.app is app2
+
+    """
+
+    def create_shared_task(**options):
+
+        def __inner(fun):
+            name = options.get('name')
+            # Set as shared task so that unfinalized apps,
+            # and future apps will load the task.
+            _shared_task(lambda app: app._task_from_fun(fun, **options))
+
+            # Force all finalized apps to take this task as well.
+            for app in _get_active_apps():
+                if app.finalized:
+                    with app._finalize_mutex:
+                        app._task_from_fun(fun, **options)
+
+            # Returns a proxy that always gets the task from the current
+            # apps task registry.
+            def task_by_cons():
+                app = current_app()
+                return app.tasks[name or gen_task_name(app,
+                            fun.__name__, fun.__module__)]
+            return Proxy(task_by_cons)
+        return __inner
+
+    if len(args) == 1 and callable(args[0]):
+        return create_shared_task(**kwargs)(args[0])
+    return create_shared_task(**kwargs)

+ 13 - 9
celery/app/base.py

@@ -15,6 +15,7 @@ from collections import deque
 from contextlib import contextmanager
 from copy import deepcopy
 from functools import wraps
+from threading import Lock
 
 from billiard.util import register_after_fork
 from kombu.clocks import LamportClock
@@ -24,7 +25,7 @@ from celery import platforms
 from celery.exceptions import AlwaysEagerIgnored
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
-from celery._state import _task_stack, _tls, get_current_app
+from celery._state import _task_stack, _tls, get_current_app, _register_app
 from celery.utils.functional import first
 from celery.utils.imports import instantiate, symbol_by_name
 
@@ -73,6 +74,7 @@ class Celery(object):
         self.accept_magic_kwargs = accept_magic_kwargs
 
         self.finalized = False
+        self._finalize_mutex = Lock()
         self._pending = deque()
         self._tasks = tasks
         if not isinstance(self._tasks, TaskRegistry):
@@ -89,6 +91,7 @@ class Celery(object):
         if self.set_as_current:
             self.set_current()
         self.on_init()
+        _register_app(self)
 
     def set_current(self):
         _tls.current_app = self
@@ -148,16 +151,17 @@ class Celery(object):
         return task
 
     def finalize(self):
-        if not self.finalized:
-            self.finalized = True
-            load_shared_tasks(self)
+        with self._finalize_mutex:
+            if not self.finalized:
+                self.finalized = True
+                load_shared_tasks(self)
 
-            pending = self._pending
-            while pending:
-                maybe_evaluate(pending.pop())
+                pending = self._pending
+                while pending:
+                    maybe_evaluate(pending.pop())
 
-            for task in self._tasks.itervalues():
-                task.bind(self)
+                for task in self._tasks.itervalues():
+                    task.bind(self)
 
     def config_from_object(self, obj, silent=False):
         del(self.conf)

+ 3 - 2
celery/app/builtins.py

@@ -89,6 +89,7 @@ def add_map_task(app):
     def xmap(task, it):
         task = subtask(task).type
         return list(map(task, it))
+    return xmap
 
 
 @shared_task
@@ -99,6 +100,7 @@ def add_starmap_task(app):
     def xstarmap(task, it):
         task = subtask(task).type
         return list(starmap(task, it))
+    return xstarmap
 
 
 @shared_task
@@ -108,6 +110,7 @@ def add_chunk_task(app):
     @app.task(name='celery.chunks')
     def chunks(task, it, n):
         return _chunks.apply_chunks(task, it, n)
+    return chunks
 
 
 @shared_task
@@ -204,7 +207,6 @@ def add_chain_task(app):
                 res = task.apply((prev.get(), ) if prev else ())
                 res.parent, prev = prev, res
             return res
-
     return Chain
 
 
@@ -270,5 +272,4 @@ def add_chord_task(app):
                                            **options)
             return maybe_subtask(body).apply(
                         args=(res.get(propagate=propagate).get().join(), ))
-
     return Chord

+ 4 - 30
celery/app/task.py

@@ -9,9 +9,6 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import os
-import sys
-
 from celery import current_app
 from celery import states
 from celery.__compat__ import class_property
@@ -19,7 +16,7 @@ from celery._state import get_current_worker_task, _task_stack
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.result import EagerResult
-from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
+from celery.utils import gen_task_name, fun_takes_kwargs, uuid, maybe_reraise
 from celery.utils.functional import mattrgetter, maybe_list
 from celery.utils.imports import instantiate
 from celery.utils.mail import ErrorMail
@@ -34,12 +31,6 @@ extract_exec_options = mattrgetter(
     'serializer', 'delivery_mode', 'compression',
 )
 
-#: Billiard sets this when execv is enabled.
-#: We use it to find out the name of the original ``__main__``
-#: module, so that we can properly rewrite the name of the
-#: task to be that of ``App.main``.
-MP_MAIN_FILE = os.environ.get('MP_MAIN_FILE') or None
-
 
 class Context(object):
     # Default context
@@ -112,15 +103,9 @@ class TaskType(type):
         app = attrs['_app'] = _app1 or _app2 or current_app
 
         # - Automatically generate missing/empty name.
-        autoname = False
-        if not attrs.get('name'):
-            try:
-                module_name = sys.modules[task_module].__name__
-            except KeyError:  # pragma: no cover
-                # Fix for manage.py shell_plus (Issue #366).
-                module_name = task_module
-            attrs['name'] = '.'.join(filter(None, [module_name, name]))
-            autoname = True
+        task_name = attrs.get('name')
+        if not task_name:
+            attrs['name'] = task_name = gen_task_name(app, name, task_module)
 
         # - Create and register class.
         # Because of the way import happens (recursively)
@@ -128,17 +113,6 @@ class TaskType(type):
         # with the framework.  There should only be one class for each task
         # name, so we always return the registered version.
         tasks = app._tasks
-
-        # - If the task module is used as the __main__ script
-        # - we need to rewrite the module part of the task name
-        # - to match App.main.
-        if MP_MAIN_FILE and sys.modules[task_module].__file__ == MP_MAIN_FILE:
-            # - see comment about :envvar:`MP_MAIN_FILE` above.
-            task_module = '__main__'
-        if autoname and task_module == '__main__' and app.main:
-            attrs['name'] = '.'.join([app.main, name])
-
-        task_name = attrs['name']
         if task_name not in tasks:
             tasks.register(new(cls, name, bases, attrs))
         instance = tasks[task_name]

+ 27 - 0
celery/utils/__init__.py

@@ -10,6 +10,7 @@ from __future__ import absolute_import
 from __future__ import with_statement
 
 import operator
+import os
 import sys
 import threading
 import traceback
@@ -35,6 +36,12 @@ DEPRECATION_FMT = """
     version %(removal)s. %(alternative)s
 """
 
+#: Billiard sets this when execv is enabled.
+#: We use it to find out the name of the original ``__main__``
+#: module, so that we can properly rewrite the name of the
+#: task to be that of ``App.main``.
+MP_MAIN_FILE = os.environ.get('MP_MAIN_FILE') or None
+
 
 def warn_deprecated(description=None, deprecation=None, removal=None,
         alternative=None):
@@ -173,6 +180,26 @@ def strtobool(term, table={'false': False, 'no': False, '0': False,
             raise TypeError('Cannot coerce %r to type bool' % (term, ))
     return term
 
+
+def gen_task_name(app, name, module_name):
+    try:
+        module = sys.modules[module_name]
+    except KeyError:
+        # Fix for manage.py shell_plus (Issue #366)
+        module = None
+
+    if module is not None:
+        module_name = module.__name__
+        # - If the task module is used as the __main__ script
+        # - we need to rewrite the module part of the task name
+        # - to match App.main.
+        if MP_MAIN_FILE and module.__file__ == MP_MAIN_FILE:
+            # - see comment about :envvar:`MP_MAIN_FILE` above.
+            module_name = '__main__'
+    if module_name == '__main__' and app.main:
+        return '.'.join([app.main, name])
+    return '.'.join(filter(None, [module_name, name]))
+
 # ------------------------------------------------------------------------ #
 # > XXX Compat
 from .log import LOG_LEVELS     # noqa