Browse Source

Tasks can now have callbacks and errbacks, and dependencies are recorded

- The task message format have been updated with two new extension keys

    Both keys can be empty/undefined or a list of subtasks.

    - ``callbacks``

        Applied if the task exits successfully, with the result
        of the task as an argument.

    - ``errbacks``

        Applied if an error occurred while executing the task,
        with the uuid of the task as an argument.  Since it may not be possible
        to serialize the exception instance, it passes the uuid of the task
        instead.  The uuid can then be used to retrieve the exception and
        traceback of the task from the result backend.

- ``link`` and ``link_error`` keyword arguments has been added
  to ``apply_async``.

    The value passed can be either a subtask or a list of
    subtasks:

    .. code-block:: python

        add.apply_async((2, 2), link=mul.subtask())
        add.apply_async((2, 2), link=[mul.subtask(), echo.subtask()])

    Example error callback:

    .. code-block:: python

        @task
        def error_handler(uuid):
            result = AsyncResult(uuid)
            exc = result.get(propagate=False)
            print("Task %r raised exception: %r\n%r" % (
                exc, result.traceback))

        >>> add.apply_async((2, 2), link_error=error_handler)

- We now track what subtasks a task sends, and some result backends
  supports retrieving this information.

    - task.request.children

        Contains the result instances of the subtasks
        the currently executing task has applied.

    - AsyncResult.children

        Returns the tasks dependencies, as a list of
        ``AsyncResult``/``ResultSet`` instances.

    - AsyncResult.iterdeps

        Recursively iterates over the tasks dependencies,
        yielding `(parent, node)` tuples.

        Raises IncompleteStream if any of the dependencies
        has not returned yet.

    - AsyncResult.graph

        A ``DependencyGraph`` of the tasks dependencies.
        This can also be used to convert to dot format:

        .. code-block:: python

            with open("graph.dot") as fh:
                result.graph.to_dot(fh)

        which can than be used to produce an image::

            $ dot -Tpng graph.dot -o graph.png
Ask Solem 13 years ago
parent
commit
041cb3c20d

+ 2 - 16
celery/app/__init__.py

@@ -13,24 +13,14 @@
 from __future__ import absolute_import
 
 import os
-import threading
 
 from ..local import PromiseProxy
 from ..utils import cached_property, instantiate
 
 from . import annotations
 from . import base
-
-
-class _TLS(threading.local):
-    #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
-    #: sets this, so it will always contain the last instantiated app,
-    #: and is the default app returned by :func:`app_or_default`.
-    current_app = None
-
-    #: The currently executing task.
-    current_task = None
-_tls = _TLS()
+from .state import _tls
+from .state import current_task  # noqa
 
 
 class AppPickler(object):
@@ -249,10 +239,6 @@ def current_app():
     return getattr(_tls, "current_app", None) or default_app
 
 
-def current_task():
-    return getattr(_tls, "current_task", None)
-
-
 def _app_or_default(app=None):
     """Returns the app provided or the default app if none.
 

+ 5 - 2
celery/app/amqp.py

@@ -183,7 +183,8 @@ class TaskPublisher(messaging.Publisher):
             countdown=None, eta=None, task_id=None, taskset_id=None,
             expires=None, exchange=None, exchange_type=None,
             event_dispatcher=None, retry=None, retry_policy=None,
-            queue=None, now=None, retries=0, chord=None, **kwargs):
+            queue=None, now=None, retries=0, chord=None, callbacks=None,
+            errbacks=None, **kwargs):
         """Send task message."""
 
         connection = self.connection
@@ -224,7 +225,9 @@ class TaskPublisher(messaging.Publisher):
                 "retries": retries or 0,
                 "eta": eta,
                 "expires": expires,
-                "utc": self.utc}
+                "utc": self.utc,
+                "callbacks": callbacks,
+                "errbacks": errbacks}
         if taskset_id:
             body["taskset"] = taskset_id
         if chord:

+ 18 - 0
celery/app/state.py

@@ -0,0 +1,18 @@
+from __future__ import absolute_import
+
+import threading
+
+
+class _TLS(threading.local):
+    #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
+    #: sets this, so it will always contain the last instantiated app,
+    #: and is the default app returned by :func:`app_or_default`.
+    current_app = None
+
+    #: The currently executing task.
+    current_task = None
+_tls = _TLS()
+
+
+def current_task():
+    return getattr(_tls, "current_task", None)

+ 24 - 2
celery/app/task/__init__.py

@@ -22,9 +22,11 @@ from ...exceptions import MaxRetriesExceededError, RetryTaskError
 from ...result import EagerResult
 from ...utils import (fun_takes_kwargs, instantiate,
                       mattrgetter, uuid, maybe_reraise)
+from ...utils.functional import maybe_list
 from ...utils.mail import ErrorMail
 from ...utils.compat import fun_of_method
 
+from ..state import current_task
 from ..registry import _unpickle_task
 
 #: extracts options related to publishing a message from a dict.
@@ -57,6 +59,9 @@ class Context(threading.local):
     taskset = None
     chord = None
     called_directly = True
+    callbacks = None
+    errbacks = None
+    _children = None   # see property
 
     def update(self, d, **kwargs):
         self.__dict__.update(d, **kwargs)
@@ -73,6 +78,13 @@ class Context(threading.local):
     def __repr__(self):
         return "<Context: %r>" % (vars(self, ))
 
+    @property
+    def children(self):
+        # children must be an empy list for every thread
+        if self._children is None:
+            self._children = []
+        return self._children
+
 
 class TaskType(type):
     """Meta class for tasks.
@@ -442,7 +454,7 @@ class BaseTask(object):
 
     def apply_async(self, args=None, kwargs=None,
             task_id=None, publisher=None, connection=None,
-            router=None, queues=None, **options):
+            router=None, queues=None, link=None, link_error=None, **options):
         """Apply tasks asynchronously by sending a message.
 
         :keyword args: The positional arguments to pass on to the
@@ -519,6 +531,10 @@ class BaseTask(object):
                               :func:`kombu.compression.register`. Defaults to
                               the :setting:`CELERY_MESSAGE_COMPRESSION`
                               setting.
+        :keyword link: A single, or a list of subtasks to apply if the
+                       task exits successfully.
+        :keyword link_error: A single, or a list of subtasks to apply
+                      if an error occurs while executing the task.
 
         .. note::
             If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will
@@ -544,12 +560,18 @@ class BaseTask(object):
             task_id = publish.delay_task(self.name, args, kwargs,
                                          task_id=task_id,
                                          event_dispatcher=evd,
+                                         callbacks=maybe_list(link),
+                                         errbacks=maybe_list(link_error),
                                          **options)
         finally:
             if not publisher:
                 publish.release()
 
-        return self.AsyncResult(task_id)
+        result = self.AsyncResult(task_id)
+        parent = current_task()
+        if parent:
+            parent.request.children.append(result)
+        return result
 
     def retry(self, args=None, kwargs=None, exc=None, throw=True,
             eta=None, countdown=None, max_retries=None, **options):

+ 15 - 6
celery/backends/base.py

@@ -10,6 +10,7 @@ from datetime import timedelta
 from kombu import serialization
 
 from .. import states
+from ..app import current_task
 from ..datastructures import LRUCache
 from ..exceptions import TimeoutError, TaskRevokedError
 from ..utils import timeutils
@@ -174,6 +175,10 @@ class BaseBackend(object):
         raise NotImplementedError(
                 "get_result is not supported by this backend.")
 
+    def get_children(self, task_id):
+        raise NotImplementedError(
+                "get_children is not supported by this backend.")
+
     def get_traceback(self, task_id):
         """Get the traceback for a failed task."""
         raise NotImplementedError(
@@ -211,14 +216,10 @@ class BaseBackend(object):
         self.app.tasks["celery.chord_unlock"].apply_async((setid, body, ),
                                                           kwargs, countdown=1)
 
-    def _serializable_child(self):
-        if isinstance(node, ResultSet):
-            return (node, )
-
     def current_task_children(self):
         current = current_task()
         if current:
-            return
+            return [r.serializable() for r in current.request.children]
 
     def __reduce__(self, args=(), kwargs={}):
         return (unpickle_backend, (self.__class__, args, kwargs))
@@ -260,6 +261,13 @@ class BaseDictBackend(BaseBackend):
         else:
             return meta["result"]
 
+    def get_children(self, task_id):
+        """Get the list of subtasks sent by a task."""
+        try:
+            return self.get_task_meta(task_id)["children"]
+        except KeyError:
+            pass
+
     def get_task_meta(self, task_id, cache=True):
         if cache:
             try:
@@ -386,7 +394,8 @@ class KeyValueStoreBackend(BaseDictBackend):
         self.delete(self.get_key_for_task(task_id))
 
     def _store_result(self, task_id, result, status, traceback=None):
-        meta = {"status": status, "result": result, "traceback": traceback}
+        meta = {"status": status, "result": result, "traceback": traceback,
+                "children": self.current_task_children()}
         self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 

+ 1 - 1
celery/backends/redis.py

@@ -44,7 +44,7 @@ class RedisBackend(KeyValueStoreBackend):
         if self.redis is None:
             raise ImproperlyConfigured(
                     "You need to install the redis library in order to use "
-                  + "Redis result store backend.")
+                  + "the Redis result store backend.")
 
         # For compatibility with the old REDIS_* configuration keys.
         def _get(key):

+ 13 - 6
celery/datastructures.py

@@ -52,7 +52,7 @@ class DependencyGraph(object):
 
     def add_arc(self, obj):
         """Add an object to the graph."""
-        self.adjacent[obj] = []
+        self.adjacent.setdefault(obj, [])
 
     def add_edge(self, A, B):
         """Add an edge from object ``A`` to object ``B``
@@ -84,7 +84,10 @@ class DependencyGraph(object):
 
     def valency_of(self, obj):
         """Returns the velency (degree) of a vertex in the graph."""
-        l = [len(self[obj])]
+        try:
+            l = [len(self[obj])]
+        except KeyError:
+            return 0
         for node in self[obj]:
             l.append(self.valency_of(node))
         return sum(l)
@@ -183,6 +186,9 @@ class DependencyGraph(object):
     def __len__(self):
         return len(self.adjacent)
 
+    def __contains__(self, obj):
+        return obj in self.adjacent
+
     def _iterate_items(self):
         return self.adjacent.iteritems()
     items = iteritems = _iterate_items
@@ -192,10 +198,11 @@ class DependencyGraph(object):
 
     def repr_node(self, obj, level=1):
         output = ["%s(%s)" % (obj, self.valency_of(obj))]
-        for other in self[obj]:
-            d = "%s(%s)" % (other, self.valency_of(other))
-            output.append('     ' * level + d)
-            output.extend(self.repr_node(other, level + 1).split('\n')[1:])
+        if obj in self:
+            for other in self[obj]:
+                d = "%s(%s)" % (other, self.valency_of(other))
+                output.append('     ' * level + d)
+                output.extend(self.repr_node(other, level + 1).split('\n')[1:])
         return '\n'.join(output)
 
 

+ 4 - 0
celery/exceptions.py

@@ -98,3 +98,7 @@ class CPendingDeprecationWarning(PendingDeprecationWarning):
 
 class CDeprecationWarning(DeprecationWarning):
     pass
+
+
+class IncompleteStream(Exception):
+    """Found the end of a stream of data, but the data is not yet complete."""

+ 11 - 0
celery/execute/trace.py

@@ -159,6 +159,9 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     clear_request = task_request.clear
     on_chord_part_return = backend.on_chord_part_return
 
+    from celery.task import sets
+    subtask = sets.subtask
+
     def trace_task(uuid, args, kwargs, request=None):
         R = I = None
         try:
@@ -188,6 +191,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     I = Info(FAILURE, exc, sys.exc_info())
                     state, retval, einfo = I.state, I.retval, I.exc_info
                     R = I.handle_error_state(task, eager=eager)
+                    [subtask(errback).apply_async((uuid, ))
+                        for errback in task_request.errbacks or []]
                 except BaseException, exc:
                     raise
                 except:  # pragma: no cover
@@ -198,8 +203,14 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     I = Info(FAILURE, None, sys.exc_info())
                     state, retval, einfo = I.state, I.retval, I.exc_info
                     R = I.handle_error_state(task, eager=eager)
+                    [subtask(errback).apply_async((uuid, ))
+                        for errback in task_request.errbacks or []]
                 else:
                     task_on_success(retval, uuid, args, kwargs)
+                    # callback tasks must be applied before the result is
+                    # stored, so that result.children is populated.
+                    [subtask(callback).apply_async((retval, ))
+                        for callback in task_request.callbacks or []]
                     if publish_result:
                         store_result(uuid, retval, SUCCESS)
 

+ 1 - 0
celery/platforms.py

@@ -11,6 +11,7 @@
 
 """
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import errno
 import os

+ 59 - 20
celery/result.py

@@ -22,7 +22,9 @@ from . import current_app
 from . import states
 from .app import app_or_default
 from .app.registry import _unpickle_task
-from .exceptions import TimeoutError
+from .datastructures import DependencyGraph
+from .exceptions import IncompleteStream, TimeoutError
+from .utils import cached_property
 from .utils.compat import OrderedDict
 
 
@@ -30,6 +32,13 @@ def _unpickle_result(task_id, task_name):
     return _unpickle_task(task_name).AsyncResult(task_id)
 
 
+def from_serializable(r):
+    id, nodes = r
+    if nodes:
+        return TaskSetResult(id, map(AsyncResult(nodes)))
+    return AsyncResult(id)
+
+
 class AsyncResult(object):
     """Query task state.
 
@@ -54,7 +63,7 @@ class AsyncResult(object):
         self.task_name = task_name
 
     def serializable(self):
-        return self.id, []
+        return self.id, None
 
     def forget(self):
         """Forget about (and possibly remove the result of) this task."""
@@ -98,7 +107,7 @@ class AsyncResult(object):
                                               interval=interval)
     wait = get  # deprecated alias to :meth:`get`.
 
-    def collect(self, timeout=None, propagate=True):
+    def collect(self, intermediate=False, **kwargs):
         """Iterator, like :meth:`get` will wait for the task to complete,
         but will also follow :class:`AsyncResult` and :class:`ResultSet`
         returned by the task, yielding for each result in the tree.
@@ -128,18 +137,20 @@ class AsyncResult(object):
             [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
 
         """
-        stack = deque([self])
-        native_join = self.supports_native_join
+        for _, R in self.iterdeps():
+            yield R, R.get(**kwargs)
+
+    def iterdeps(self, intermediate=False):
+        stack = deque([(None, self)])
 
         while stack:
-            res = stack.popleft()
-            if isinstance(res, ResultSet):
-                j = res.join_native if native_join else res.join
-                stack.extend(j(timeout=timeout, propagate=propagate))
-            elif isinstance(res, AsyncResult):
-                stack.append(res.get(timeout=timeout, propagate=propagate))
+            parent, node = stack.popleft()
+            yield parent, node
+            if node.ready():
+                stack.extend((node, child) for child in node.children or [])
             else:
-                yield res
+                if not intermediate:
+                    raise IncompleteStream()
 
     def ready(self):
         """Returns :const:`True` if the task has been executed.
@@ -184,10 +195,28 @@ class AsyncResult(object):
             return (self.__class__, (self.id, self.backend,
                                      None, self.app))
 
+    def build_graph(self, intermediate=False):
+        graph = DependencyGraph()
+        for parent, node in self.iterdeps(intermediate=intermediate):
+            if parent:
+                graph.add_arc(parent)
+                graph.add_edge(parent, node)
+        return graph
+
+    @cached_property
+    def graph(self):
+        return self.build_graph()
+
     @property
     def supports_native_join(self):
         return self.backend.supports_native_join
 
+    @property
+    def children(self):
+        children = self.backend.get_children(self.id)
+        if children:
+            return map(from_serializable, children)
+
     @property
     def result(self):
         """When the task has been executed, this contains the return value.
@@ -389,6 +418,17 @@ class ResultSet(object):
             if timeout and elapsed >= timeout:
                 raise TimeoutError("The operation timed out")
 
+    def get(self, timeout=None, propagate=True, interval=0.5):
+        """See :meth:`join`
+
+        This is here for API compatibility with :class:`AsyncResult`,
+        in addition it uses :meth:`join_native` if available for the
+        current result backend.
+
+        """
+        return (self.join_native if self.supports_native_join else self.join)(
+                    timeout=timeout, propagate=propagate, interval=interval)
+
     def join(self, timeout=None, propagate=True, interval=0.5):
         """Gathers the results of all tasks as a list in order.
 
@@ -496,19 +536,19 @@ class TaskSetResult(ResultSet):
     It enables inspection of the tasks state and return values as
     a single entity.
 
-    :param taskset_id: The id of the taskset.
+    :param id: The id of the taskset.
     :param results: List of result instances.
 
     """
 
     #: The UUID of the taskset.
-    taskset_id = None
+    id = None
 
     #: List/iterator of results in the taskset
     results = None
 
-    def __init__(self, taskset_id, results=None, **kwargs):
-        self.taskset_id = taskset_id
+    def __init__(self, id, results=None, **kwargs):
+        self.id = id
 
         # XXX previously the "results" arg was named "subtasks".
         if "subtasks" in kwargs:
@@ -524,19 +564,18 @@ class TaskSetResult(ResultSet):
             >>> result = TaskSetResult.restore(taskset_id)
 
         """
-        return (backend or self.app.backend).save_taskset(self.taskset_id,
-                                                          self)
+        return (backend or self.app.backend).save_taskset(self.id, self)
 
     def delete(self, backend=None):
         """Remove this result if it was previously saved."""
-        (backend or self.app.backend).delete_taskset(self.taskset_id)
+        (backend or self.app.backend).delete_taskset(self.id)
 
     def itersubtasks(self):
         """Depreacted.   Use ``iter(self.results)`` instead."""
         return iter(self.results)
 
     def __reduce__(self):
-        return (self.__class__, (self.taskset_id, self.results))
+        return (self.__class__, (self.id, self.results))
 
     def serializable(self):
         return self.id, [r.serializable() for r in self.results]

+ 6 - 2
celery/task/sets.py

@@ -13,7 +13,7 @@ from __future__ import absolute_import
 from __future__ import with_statement
 
 from .. import current_app
-from ..app import app_or_default
+from ..app import app_or_default, current_task
 from ..datastructures import AttributeDict
 from ..utils import cached_property, reprcall, uuid
 from ..utils.compat import UserList
@@ -142,7 +142,11 @@ class TaskSet(UserList):
                 if not publisher:  # created by us.
                     pub.close()
 
-            return app.TaskSetResult(setid, results)
+            result = app.TaskSetResult(setid, results)
+            parent = current_task()
+            if parent:
+                parent.request.children.append(result)
+            return result
 
     def _async_results(self, taskset_id, publisher):
         return [task.apply_async(taskset_id=taskset_id, publisher=publisher)

+ 5 - 3
celery/tests/test_task/test_context.py

@@ -42,7 +42,8 @@ class TestTaskContext(Case):
     def test_default_context(self):
         # A bit of a tautological test, since it uses the same
         # initializer as the default_context constructor.
-        self.assertDictEqual(get_context_as_dict(Context()), default_context)
+        defaults = dict(default_context, children=[])
+        self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
     def test_default_context_threaded(self):
         ctx = Context()
@@ -124,8 +125,9 @@ class TestTaskContext(Case):
         ctx = Context()
         ctx.update(changes)
         ctx.clear()
-        self.assertDictEqual(get_context_as_dict(ctx), default_context)
-        self.assertDictEqual(get_context_as_dict(Context()), default_context)
+        defaults = dict(default_context, children=[])
+        self.assertDictEqual(get_context_as_dict(ctx), defaults)
+        self.assertDictEqual(get_context_as_dict(Context()), defaults)
 
     def test_cleared_context_threaded(self):
         changes_a = dict(id="a", args=["some", 1], wibble="wobble")

+ 3 - 1
celery/utils/functional.py

@@ -27,7 +27,9 @@ KEYWORD_MARK = object()
 
 
 def maybe_list(l):
-    if isinstance(l, Sequence):
+    if l is None:
+        return l
+    elif isinstance(l, Sequence):
         return l
     return [l]
 

+ 1 - 0
celery/worker/job.py

@@ -65,6 +65,7 @@ class Request(object):
     """A request for task execution."""
     __slots__ = ("app", "name", "id", "args", "kwargs",
                  "on_ack", "delivery_info", "hostname",
+                 "callbacks", "errbacks",
                  "logger", "eventer", "connection_errors",
                  "task", "eta", "expires",
                  "_does_debug", "_does_info", "request_dict",

+ 35 - 18
docs/internals/protocol.rst

@@ -11,41 +11,42 @@ Message format
 ==============
 
 * task
-    `string`
+    :`string`:
 
     Name of the task. **required**
 
 * id
-    `string`
+    :`string`:
 
     Unique id of the task (UUID). **required**
 
 * args
-    `list`
+    :`list`:
 
     List of arguments. Will be an empty list if not provided.
 
 * kwargs
-    `dictionary`
+    :`dictionary`:
 
     Dictionary of keyword arguments. Will be an empty dictionary if not
     provided.
 
 * retries
-    `int`
+    :`int`:
 
     Current number of times this task has been retried.
     Defaults to `0` if not specified.
 
 * eta
-    `string` (ISO 8601)
+    :`string` (ISO 8601):
 
     Estimated time of arrival. This is the date and time in ISO 8601
     format. If not provided the message is not scheduled, but will be
     executed asap.
 
 * expires
-    `string` (ISO 8601)
+    :`string` (ISO 8601):
+
     .. versionadded:: 2.0.2
 
     Expiration date. This is the date and time in ISO 8601 format.
@@ -64,24 +65,40 @@ to process it.
 
 
 * taskset
-  `string`
+    :`string`:
 
-  The taskset this task is part of.
+    The taskset this task is part of (if any).
 
 * chord
-  `object`
-  .. versionadded:: 2.3
+    :`subtask`:
+
+    .. versionadded:: 2.3
 
-  Signifies that this task is one of the header parts of a chord.  The value
-  of this key is the body of the cord that should be executed when all of
-  the tasks in the header has returned.
+    Signifies that this task is one of the header parts of a chord.  The value
+    of this key is the body of the cord that should be executed when all of
+    the tasks in the header has returned.
 
 * utc
-  `bool`
-  .. versionadded:: 2.5
+    :`bool`:
+
+    .. versionadded:: 2.5
+
+    If true time uses the UTC timezone, if not the current local timezone
+    should be used.
+
+* callbacks
+    :`<list>subtask`:
+
+    .. versionadded:: 2.6
+
+    A list of subtasks to apply if the task exited successfully.
+
+* errbacks
+    :`<list>subtask`:
+
+    .. versionadded:: 2.6
 
-  If true time uses the UTC timezone, if not the current local timezone
-  should be used.
+    A list of subtasks to apply if an error occurs while executing the task.
 
 Example message
 ===============