浏览代码

.task.sets.subtask can now be used to pass around tasks as callbacks.

subtask is now a dict subclass, meaning it can safely be serialized
using JSON and other serializers that doesn't support complex objects.

Also if the first argument to subtask() is another subtask or a dict,
it will ignore the rest of the argument and use the values from the dict
instead, this means the task handling the callback can use subtask() like a
type to convert a subtask that has been serialized into a dict back into a
subtask::

    @task()
    def add(x, y, callback=None):
        result = x + y
        if callback:
            subtask(callback).apply_async(result)

subtask.apply_async now takes varargs that will be prepended to the stored
arguments, this is added for easy callback invocation. The new signature
is subtask.apply_async(*argmerge, **execopts)::

if callback is::

    callback = subtask("tasks.add", args=(10, ))

and callback is applied like::

    subtask(callback).apply_async(result)

The final arguments will be::

    (result, 10)

In addition to these changes the Task now has a new classmethod
to easily create subtasks:

    Task.subtask(args, kwargs, options)

which means we can launch the add task above with a callback like this::

    add.delay(16, 16, callback=add.subtask(args=(10, ))

this will result in (16 + 16) + 10
(allthough that is not the result stored by the first task id, rather
the second)
Ask Solem 15 年之前
父节点
当前提交
2810ab8b14
共有 4 个文件被更改,包括 95 次插入35 次删除
  1. 15 4
      celery/datastructures.py
  2. 8 1
      celery/task/base.py
  3. 51 17
      celery/task/sets.py
  4. 21 13
      docs/userguide/tasks.rst

+ 15 - 4
celery/datastructures.py

@@ -1,17 +1,28 @@
 from __future__ import generators
 from __future__ import generators
-"""
 
 
-Custom Datastructures
-
-"""
 import time
 import time
 import traceback
 import traceback
+
 from UserList import UserList
 from UserList import UserList
 from Queue import Queue, Empty as QueueEmpty
 from Queue import Queue, Empty as QueueEmpty
 
 
 from celery.utils.compat import OrderedDict
 from celery.utils.compat import OrderedDict
 
 
 
 
+class AttributeDict(dict):
+    """Dict subclass with attribute access."""
+
+    def __getattr__(self, key):
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError("'%s' object has no attribute '%s'" % (
+                    self.__class__.__name__, key))
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+
 class PositionQueue(UserList):
 class PositionQueue(UserList):
     """A positional queue of a specific length, with slots that are either
     """A positional queue of a specific length, with slots that are either
     filled or unfilled. When all of the positions are filled, the queue
     filled or unfilled. When all of the positions are filled, the queue

+ 8 - 1
celery/task/base.py

@@ -13,7 +13,7 @@ from celery.messaging import establish_connection as _establish_connection
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 
 
 from celery.task.schedules import schedule
 from celery.task.schedules import schedule
-from celery.task.sets import TaskSet
+from celery.task.sets import TaskSet, subtask
 
 
 
 
 class TaskType(type):
 class TaskType(type):
@@ -500,6 +500,13 @@ class Task(object):
             kind = "%s(Task)" % self.__class__.__name__
             kind = "%s(Task)" % self.__class__.__name__
         return "<%s: %s (%s)>" % (kind, self.name, self.type)
         return "<%s: %s (%s)>" % (kind, self.name, self.type)
 
 
+    @classmethod
+    def subtask(cls, *args, **kwargs):
+        """Returns a :class:`~celery.task.sets.subtask` object for
+        this task that wraps arguments and execution options
+        for a single task invocation."""
+        return subtask(cls, *args, **kwargs)
+
 
 
 class PeriodicTask(Task):
 class PeriodicTask(Task):
     """A periodic task is a task that behaves like a :manpage:`cron` job.
     """A periodic task is a task that behaves like a :manpage:`cron` job.

+ 51 - 17
celery/task/sets.py

@@ -1,39 +1,72 @@
 from UserList import UserList
 from UserList import UserList
 
 
 from celery import conf
 from celery import conf
+from celery import registry
+from celery.datastructures import AttributeDict
 from celery.messaging import establish_connection, with_connection
 from celery.messaging import establish_connection, with_connection
 from celery.messaging import TaskPublisher
 from celery.messaging import TaskPublisher
 from celery.result import TaskSetResult
 from celery.result import TaskSetResult
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
 
 
 
 
-class subtask(object):
-    """A subtask part of a :class:`TaskSet`.
+class subtask(AttributeDict):
+    """Class that wraps the arguments and execution options
+    for a single task invocation.
 
 
-    :param task: The task class.
+    Used as the parts in a :class:`TaskSet` or to safely
+    pass tasks around as callbacks.
+
+    :param task: Either a task class/instance, or the name of a task.
     :keyword args: Positional arguments to apply.
     :keyword args: Positional arguments to apply.
     :keyword kwargs: Keyword arguments to apply.
     :keyword kwargs: Keyword arguments to apply.
     :keyword options: Additional options to
     :keyword options: Additional options to
       :func:`celery.execute.apply_async`.
       :func:`celery.execute.apply_async`.
 
 
+    Note that if the first argument is a :class:`dict`, the other
+    arguments will be ignored and the values in the dict will be used
+    instead.
+
+        >>> s = subtask("tasks.add", args=(2, 2))
+        >>> subtask(s)
+        {"task": "tasks.add", args=(2, 2), kwargs={}, options={}}
+
     """
     """
 
 
-    def __init__(self, task, args=None, kwargs=None, options=None):
-        self.task = task
-        self.args = args or ()
-        self.kwargs = kwargs or {}
-        self.options = options or {}
+    def __init__(self, task=None, args=None, kwargs=None, options=None,
+            **extra):
+        init = super(subtask, self).__init__
+
+        if isinstance(task, dict):
+            # Use the values from a dict.
+            return init(task)
 
 
-    def apply(self, taskset_id):
+        # Also supports using task class/instance instead of string name.
+        try:
+            task_name = task.name
+        except AttributeError:
+            task_name = task
+
+        init(task=task_name, args=tuple(args or ()), kwargs=kwargs or (),
+             options=options or ())
+
+    def apply(self, *argmerge, **execopts):
         """Apply this task locally."""
         """Apply this task locally."""
-        return self.task.apply(self.args, self.kwargs,
-                               taskset_id=taskset_id, **self.options)
+        # For callbacks: extra args are prepended to the stored args.
+        args = tuple(argmerge) + tuple(self.args)
+        return self.get_type().apply(args, self.kwargs,
+                                     **dict(self.options, **execopts))
 
 
-    def apply_async(self, taskset_id, publisher):
+    def apply_async(self, *argmerge, **execopts):
         """Apply this task asynchronously."""
         """Apply this task asynchronously."""
-        return self.task.apply_async(self.args, self.kwargs,
-                                     taskset_id=taskset_id,
-                                     publisher=publisher, **self.options)
+        # For callbacks: extra args are prepended to the stored args.
+        args = tuple(argmerge) + tuple(self.args)
+        return self.get_type().apply_async(args, self.kwargs,
+                                           **dict(self.options, **execopts))
+
+    def get_type(self):
+        # For JSON serialization, the task class is lazily loaded,
+        # and not stored in the dict itself.
+        return registry.tasks[self.task]
 
 
 
 
 class TaskSet(UserList):
 class TaskSet(UserList):
@@ -115,7 +148,8 @@ class TaskSet(UserList):
                                     connect_timeout=connect_timeout)
                                     connect_timeout=connect_timeout)
         publisher = TaskPublisher(connection=conn)
         publisher = TaskPublisher(connection=conn)
         try:
         try:
-            results = [task.apply_async(taskset_id, publisher)
+            results = [task.apply_async(taskset_id=taskset_id,
+                                        publisher=publisher)
                             for task in self.tasks]
                             for task in self.tasks]
         finally:
         finally:
             publisher.close()
             publisher.close()
@@ -128,7 +162,7 @@ class TaskSet(UserList):
         taskset_id = gen_unique_id()
         taskset_id = gen_unique_id()
 
 
         # This will be filled with EagerResults.
         # This will be filled with EagerResults.
-        return TaskSetResult(taskset_id, [task.apply(taskset_id)
+        return TaskSetResult(taskset_id, [task.apply(taskset_id=taskset_id)
                                             for task in self.tasks])
                                             for task in self.tasks])
 
 
     @property
     @property

+ 21 - 13
docs/userguide/tasks.rst

@@ -107,8 +107,8 @@ the worker log:
 .. code-block:: python
 .. code-block:: python
 
 
     class AddTask(Task):
     class AddTask(Task):
-        def run(self, x, y, **kwargs):
-            logger = self.get_logger(**kwargs)
+        def run(self, x, y, \*\*kwargs):
+            logger = self.get_logger(\*\*kwargs)
             logger.info("Adding %s + %s" % (x, y))
             logger.info("Adding %s + %s" % (x, y))
             return x + y
             return x + y
 
 
@@ -117,8 +117,8 @@ or using the decorator syntax:
 .. code-block:: python
 .. code-block:: python
 
 
     @task()
     @task()
-    def add(x, y, **kwargs):
-        logger = add.get_logger(**kwargs)
+    def add(x, y, \*\*kwargs):
+        logger = add.get_logger(\*\*kwargs)
         logger.info("Adding %s + %s" % (x, y))
         logger.info("Adding %s + %s" % (x, y))
         return x + y
         return x + y
 
 
@@ -136,7 +136,7 @@ It will do the right thing, and respect the
 .. code-block:: python
 .. code-block:: python
 
 
     @task()
     @task()
-    def send_twitter_status(oauth, tweet, **kwargs):
+    def send_twitter_status(oauth, tweet, \*\*kwargs):
         try:
         try:
             twitter = Twitter(oauth)
             twitter = Twitter(oauth)
             twitter.update_status(tweet)
             twitter.update_status(tweet)
@@ -173,7 +173,7 @@ You can also provide the ``countdown`` argument to
     class MyTask(Task):
     class MyTask(Task):
         default_retry_delay = 30 * 60 # retry in 30 minutes
         default_retry_delay = 30 * 60 # retry in 30 minutes
 
 
-        def run(self, x, y, **kwargs):
+        def run(self, x, y, \*\*kwargs):
             try:
             try:
                 ...
                 ...
             except Exception, exc:
             except Exception, exc:
@@ -380,8 +380,8 @@ blog/tasks.py
 
 
 
 
     @task
     @task
-    def spam_filter(comment_id, remote_addr=None, **kwargs):
-            logger = spam_filter.get_logger(**kwargs)
+    def spam_filter(comment_id, remote_addr=None, \*\*kwargs):
+            logger = spam_filter.get_logger(\*\*kwargs)
             logger.info("Running spam filter for comment %s" % comment_id)
             logger.info("Running spam filter for comment %s" % comment_id)
 
 
             comment = Comment.objects.get(pk=comment_id)
             comment = Comment.objects.get(pk=comment_id)
@@ -530,26 +530,34 @@ Good:
     @task(ignore_result=True)
     @task(ignore_result=True)
     def update_page_info(url):
     def update_page_info(url):
         # fetch_page -> parse_page -> store_page
         # fetch_page -> parse_page -> store_page
-        fetch_page.delay(url, callback=callback,
-                         callback_args=(store_page_info.delay, ))
+        fetch_page.delay(url, callback=subtask(parse_page,
+                                    callback=subtask(store_page_info)))
 
 
     @task(ignore_result=True)
     @task(ignore_result=True)
-    def fetch_page(url, callback=None, callback_args=()):
+    def fetch_page(url, callback=None):
         page = myparser.parse_document(page)
         page = myparser.parse_document(page)
         if callback:
         if callback:
-            callback(page, \*callback_args)
+            # The callback may have been serialized with JSON,
+            # so best practice is to convert the subtask dict back
+            # into a subtask object.
+            subtask(callback).apply_async(page)
 
 
     @task(ignore_result=True)
     @task(ignore_result=True)
     def parse_page(url, page, callback=None):
     def parse_page(url, page, callback=None):
         info = myparser.parse_document(page)
         info = myparser.parse_document(page)
         if callback:
         if callback:
-            callback(url, info)
+            subtask(callback).apply_async(url, info)
 
 
     @task(ignore_result=True)
     @task(ignore_result=True)
     def store_page_info(url, info):
     def store_page_info(url, info):
         PageInfo.objects.create(url, info)
         PageInfo.objects.create(url, info)
 
 
 
 
+We use :class:`~celery.task.sets.subtask` here to safely pass
+around the callback task. :class:`~celery.task.sets.subtask` is a 
+subclass of dict used to wrap the arguments and execution options
+for a single task invocation.
+
 
 
 Performance and Strategies
 Performance and Strategies
 ==========================
 ==========================