Selaa lähdekoodia

Introducing Chords (taskset callbacks)

Ask Solem 14 vuotta sitten
vanhempi
commit
332ac0fdf7

+ 19 - 0
Changelog

@@ -64,6 +64,25 @@ News
     A limit of :const:`None` or 0 means no limit, and connections will be
     established and closed every time.
 
+* Introducing Chords (taskset callbacks).
+
+    A chord is a task that only executes after all of the tasks in a taskset
+    have finished executing.  It's a fancy term for "taskset callbacks"
+    adopted from
+    `Cω <http://research.microsoft.com/en-us/um/cambridge/projects/comega/>`_.
+
+    It works with all result backends, but the best implementation is
+    currently provided by the Redis result backend.
+
+    Here's an example chord::
+
+        >>> chord(add.subtask((i, i))
+        ...         for i in xrange(100))(tsum.subtask()).get()
+        9900
+
+    Please read the :ref:`Chord section in the user guide <chords>`, if you
+    want to know more.
+
 .. _version-2.2.5:
 
 2.2.5

+ 4 - 2
celery/app/amqp.py

@@ -181,7 +181,7 @@ class TaskPublisher(messaging.Publisher):
             countdown=None, eta=None, task_id=None, taskset_id=None,
             expires=None, exchange=None, exchange_type=None,
             event_dispatcher=None, retry=None, retry_policy=None,
-            queue=None, now=None, retries=0, **kwargs):
+            queue=None, now=None, retries=0, chord=None, **kwargs):
         """Send task message."""
 
         connection = self.connection
@@ -221,7 +221,9 @@ class TaskPublisher(messaging.Publisher):
                 "kwargs": task_kwargs or {},
                 "retries": retries or 0,
                 "eta": eta,
-                "expires": expires}
+                "expires": expires,
+                "chord": chord}
+
         if taskset_id:
             body["taskset"] = taskset_id
 

+ 12 - 2
celery/backends/base.py

@@ -143,6 +143,16 @@ class BaseBackend(object):
         raise NotImplementedError(
                 "reload_taskset_result is not supported by this backend.")
 
+    def on_chord_part_return(self, task):
+        pass
+
+    def on_chord_apply(self, setid, body):
+        from celery.registry import tasks
+        tasks["celery.chord_unlock"].apply_async((setid, body, ), countdown=1)
+
+    def __reduce__(self):
+        return (self.__class__, ())
+
 
 class BaseDictBackend(BaseBackend):
 
@@ -244,8 +254,8 @@ class KeyValueStoreBackend(BaseDictBackend):
         return result
 
     def _save_taskset(self, taskset_id, result):
-        meta = {"result": result}
-        self.set(self.get_key_for_taskset(taskset_id), pickle.dumps(meta))
+        self.set(self.get_key_for_taskset(taskset_id),
+                 pickle.dumps({"result": result}))
         return result
 
     def _get_task_meta_for(self, task_id):

+ 13 - 0
celery/backends/pyredis.py

@@ -4,6 +4,8 @@ from kombu.utils import cached_property
 
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
+from celery.result import TaskSetResult
+from celery.task.sets import subtask
 from celery.utils import timeutils
 
 try:
@@ -79,6 +81,17 @@ class RedisBackend(KeyValueStoreBackend):
     def process_cleanup(self):
         self.close()
 
+    def on_chord_apply(self, setid, body):
+        pass
+
+    def on_chord_part_return(self, task, keyprefix="chord-unlock-%s"):
+        setid = task.request.taskset
+        key = keyprefix % setid
+        deps = TaskSetResult.restore(setid, backend=task.backend)
+        if self.client.incr(key) >= deps.total:
+            subtask(task.request.chord).delay(deps.join())
+        self.client.expire(key, 86400)
+
     @cached_property
     def client(self):
         return self.redis.Redis(host=self.redis_host,

+ 4 - 0
celery/result.py

@@ -468,6 +468,7 @@ class TaskSetResult(ResultSet):
         if backend is None:
             backend = self.app.backend
         backend.save_taskset(self.taskset_id, self)
+        return self
 
     @classmethod
     def restore(self, taskset_id, backend=None):
@@ -480,6 +481,9 @@ class TaskSetResult(ResultSet):
         """Depreacted.   Use ``iter(self.results)`` instead."""
         return iter(self.results)
 
+    def __reduce__(self):
+        return (self.__class__, (self.taskset_id, self.results))
+
 
 class EagerResult(BaseAsyncResult):
     """Result that we know has already been executed."""

+ 1 - 0
celery/task/__init__.py

@@ -4,6 +4,7 @@ import warnings
 from celery.app import app_or_default
 from celery.task.base import Task, PeriodicTask
 from celery.task.sets import TaskSet, subtask
+from celery.task.chord import chord
 from celery.task.control import discard_all
 
 __all__ = ["Task", "TaskSet", "PeriodicTask", "subtask", "discard_all"]

+ 3 - 1
celery/task/base.py

@@ -48,6 +48,7 @@ class Context(threading.local):
     is_eager = False
     delivery_info = None
     taskset = None
+    chord = None
 
     def update(self, d, **kwargs):
         self.__dict__.update(d, **kwargs)
@@ -647,7 +648,8 @@ class BaseTask(object):
         The return value of this handler is ignored.
 
         """
-        pass
+        if self.request.chord:
+            self.backend.on_chord_part_return(self)
 
     def on_failure(self, exc, task_id, args, kwargs, einfo=None):
         """Error handler.

+ 41 - 0
celery/task/chord.py

@@ -0,0 +1,41 @@
+from kombu.utils import gen_unique_id
+
+from celery import current_app
+from celery.result import TaskSetResult
+from celery.task.sets import TaskSet, subtask
+
+@current_app.task(name="celery.chord_unlock", max_retries=None)
+def _unlock_chord(setid, callback, interval=1, max_retries=None):
+    result = TaskSetResult.restore(setid)
+    if result.ready():
+        return subtask(callback).delay(result.join())
+    _unlock_chord.retry(countdown=interval, max_retries=max_retries)
+
+
+class Chord(current_app.Task):
+    name = "celery.chord"
+
+    def run(self, set, body):
+        if not isinstance(set, TaskSet):
+            set = TaskSet(set)
+        r = []
+        setid = gen_unique_id()
+        for task in set.tasks:
+            uuid = gen_unique_id()
+            task.options.update(task_id=uuid, chord=body)
+            r.append(current_app.AsyncResult(uuid))
+        ts = current_app.TaskSetResult(setid, r).save()
+        self.backend.on_chord_apply(setid, body)
+        return set.apply_async(taskset_id=setid)
+
+
+class chord(object):
+    Chord = Chord
+
+    def __init__(self, tasks):
+        self.tasks = tasks
+
+    def __call__(self, body):
+        uuid = body.options.setdefault("task_id", gen_unique_id())
+        self.Chord.apply_async((list(self.tasks), body))
+        return body.type.app.AsyncResult(uuid)

+ 8 - 2
celery/worker/job.py

@@ -196,6 +196,9 @@ class TaskRequest(object):
     #: When the task expires.
     expires = None
 
+    #: Body of a chord depending on this task.
+    chord = None
+
     #: Callback called when the task should be acknowledged.
     on_ack = None
 
@@ -246,7 +249,7 @@ class TaskRequest(object):
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
             email_subject=None, email_body=None, logger=None,
             eventer=None, eta=None, expires=None, app=None,
-            taskset_id=None, **opts):
+            taskset_id=None, chord=None, **opts):
         self.app = app_or_default(app)
         self.task_name = task_name
         self.task_id = task_id
@@ -256,6 +259,7 @@ class TaskRequest(object):
         self.kwargs = kwargs
         self.eta = eta
         self.expires = expires
+        self.chord = chord
         self.on_ack = on_ack
         self.delivery_info = delivery_info or {}
         self.hostname = hostname or socket.gethostname()
@@ -290,6 +294,7 @@ class TaskRequest(object):
                    taskset_id=body.get("taskset", None),
                    args=body["args"],
                    kwargs=kwdict(kwargs),
+                   chord=body.get("chord"),
                    retries=body.get("retries", 0),
                    eta=maybe_iso8601(body.get("eta")),
                    expires=maybe_iso8601(body.get("expires")),
@@ -304,7 +309,8 @@ class TaskRequest(object):
                 "taskset": self.taskset_id,
                 "retries": self.retries,
                 "is_eager": False,
-                "delivery_info": self.delivery_info}
+                "delivery_info": self.delivery_info,
+                "chord": self.chord}
 
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.

+ 87 - 62
docs/userguide/tasksets.rst

@@ -159,81 +159,106 @@ It supports the following operations:
     and return a list with them ordered by the order of which they
     were called.
 
+.. _chords:
 
-Task set callbacks
-------------------
+Chords
+======
 
 
-Simple, but may take a long time before your callback is called:
+A chord is a task that only executes after all of the tasks in a taskset
+have finished executing.
 
 
+Let's calculate the sum of the expression
+:math:`1 + 1 + 2 + 2 + 3 + 3 ... n + n` up to a hundred digits.
+
+First we need two tasks, :func:`add` and :func:`tsum` (:func:`sum` is
+already a standard function):
+
 .. code-block:: python
 
-    from celery import current_app
-    from celery.task import subtask
+    from celery.task import task
+
+    @task
+    def add(x, y):
+        return x + y
+
+    @task
+    def tsum(numbers):
+        return sum(numbers)
+
+
+Now we can use a chord to calculate each addition step in parallel, and then
+get the sum of the resulting numbers::
+
+    >>> from celery.task import chord
+    >>> from tasks import add, tsum
+
+    >>> chord(add.subtask((i, i))
+    ...     for i in xrange(100))(tsum.subtask()).get()
+    9900
+
+
+This is obviously a very contrived example; the overhead of messaging and
+synchronization makes this a lot slower than its Python counterpart::
 
-    def join_taskset(setid, subtasks, callback, interval=15, max_retries=None):
-        result = TaskSetResult(setid, subtasks)
-        if result.ready():
-            return subtask(callback).delay(result.join())
-        join_taskset.retry(countdown=interval, max_retries=max_retries)
+    sum(i + i for i in xrange(100))
 
+The synchronization step is costly, so you should avoid using chords as much
+as possible. Still, the chord is a powerful primitive to have in your toolbox
+as synchronization is a required step for many parallel algorithms.
 
+Let's break the chord expression down::
 
-Using Redis and atomic counters:
+    >>> callback = tsum.subtask()
+    >>> header = [add.subtask((i, i)) for i in xrange(100)]
+    >>> result = chord(header)(callback)
+    >>> result.get()
+    9900
 
+Remember, the callback can only be executed after all of the tasks in the
+header have returned.  Each step in the header is executed as a task, in
+parallel, possibly on different nodes.  The callback is then applied with
+the return value of each task in the header.  The task id returned by
+:meth:`chord` is the id of the callback, so you can wait for it to complete
+and get the final return value (but remember to :ref:`never have a task wait
+for other tasks <task-synchronous-subtasks>`).
+
+.. _chord-important-notes:
+
+Important Notes
+---------------
+
+By default the synchronization step is implemented by having a recurring task
+poll the completion of the taskset every second, applying the subtask when
+ready.
+
+Example implementation:
 
 .. code-block:: python
 
-    from celery import current_app
-    from celery.task import Task, TaskSet
-    from celery.result import TaskSetResult
-    from celery.utils import gen_unique_id, cached_property
-    from redis import Redis
-    from time import sleep
-
-    class supports_taskset_callback(Task):
-        abstract = True
-        accept_magic_kwargs = False
-
-        def after_return(self, \*args, \*\*kwargs):
-            if self.request.taskset:
-                callback = self.request.kwargs.get("callback")
-                if callback:
-                    setid = self.request.taskset
-                    # task set must be saved in advance, so the task doesn't
-                    # try to restore it before that happens.  This is why we
-                    # use the `apply_presaved_taskset` below.
-                    result = TaskSetResult.restore(setid)
-                    current = self.redis.incr("taskset-" + setid)
-                    if current >= result.total:
-                        r = subtask(callback).delay(result.join())
-
-        @cached_property
-        def redis(self):
-            return Redis(host="localhost", port=6379)
-
-    @task(base=supports_taskset_callback)
-    def add(x, y, \*\*kwargs):
-        return x + y
+    def unlock_chord(taskset, callback, interval=1, max_retries=None):
+        if taskset.ready():
+            return subtask(callback).delay(taskset.join())
+        unlock_chord.retry(countdown=interval, max_retries=max_retries)
 
-    @task
-    def sum_of(numbers):
-        print("TASKSET READY: %r" % (sum(numbers), ))
-
-    def apply_presaved_taskset(tasks):
-        r = []
-        setid = gen_unique_id()
-        for task in tasks:
-            uuid = gen_unique_id()
-            task.options["task_id"] = uuid
-            r.append((task, current_app.AsyncResult(uuid)))
-        ts = current_app.TaskSetResult(setid, [task[1] for task in r])
-        ts.save()
-        return TaskSet(task[0] for task in r).apply_async(taskset_id=setid)
-
-
-    # sum of 100 add tasks
-    result = apply_presaved_taskset(
-                add.subtask((i, i), {"callback": sum_of.subtask()})
-                    for i in xrange(100))
+
+This is used by all result backends except Redis, which increments a
+counter after each task in the header, then applies the callback when the
+counter reaches the number of tasks in the set.
+
+The Redis approach is a much better solution, but not easily implemented
+in other backends (suggestions welcome!)
+
+
+.. note::
+
+    If you are using chords with the Redis result backend and also overriding
+    the :meth:`Task.after_return` method, you need to make sure to call the
+    super method or else the chord callback will not be applied.
+
+    .. code-block:: python
+
+        def after_return(self, *args, **kwargs):
+            do_something()
+            super(MyTask, self).after_return(*args, **kwargs)