Ask Solem hace 14 años
padre
commit
e5cf80c137
Se han modificado 3 ficheros con 76 adiciones y 115 borrados
  1. 41 50
      celery/task/base.py
  2. 2 1
      celery/task/http.py
  3. 33 64
      celery/task/sets.py

+ 41 - 50
celery/task/base.py

@@ -13,7 +13,7 @@ from celery.schedules import maybe_schedule
 from celery.utils import deprecated, mattrgetter, gen_unique_id, \
                          fun_takes_kwargs
 from celery.utils.functional import wraps
-from celery.utils.timeutils import timedelta_seconds
+from celery.utils import timeutils
 
 from celery.task import sets
 
@@ -64,30 +64,26 @@ class TaskType(type):
     """
 
     def __new__(cls, name, bases, attrs):
-        super_new = super(TaskType, cls).__new__
+        new = super(TaskType, cls).__new__
         task_module = attrs["__module__"]
 
-        # Abstract class, remove the abstract attribute so
-        # any class inheriting from this won't be abstract by default.
+        # Abstract class: abstract attribute should not be inherited.
         if attrs.pop("abstract", None) or not attrs.get("autoregister", True):
-            return super_new(cls, name, bases, attrs)
+            return new(cls, name, bases, attrs)
 
-        # Automatically generate missing name.
+        # Automatically generate missing/empty name.
         if not attrs.get("name"):
-            task_name = ".".join([sys.modules[task_module].__name__, name])
-            attrs["name"] = task_name
+            attrs["name"] = '.'.join([sys.modules[task_module].__name__, name])
 
         # Because of the way import happens (recursively)
         # we may or may not be the first time the task tries to register
-        # with the framework. There should only be one class for each task
+        # with the framework.  There should only be one class for each task
         # name, so we always return the registered version.
-
         task_name = attrs["name"]
         if task_name not in tasks:
-            task_cls = super_new(cls, name, bases, attrs)
+            task_cls = new(cls, name, bases, attrs)
             if task_module == "__main__" and task_cls.app.main:
-                task_name = task_cls.name = ".".join([task_cls.app.main,
-                                                      name])
+                task_name = task_cls.name = '.'.join([task_cls.app.main, name])
             tasks.register(task_cls)
         task = tasks[task_name].__class__
         return task
@@ -97,16 +93,11 @@ class TaskType(type):
 
 
 class BaseTask(object):
-    """A Celery task.
-
-    All subclasses of :class:`Task` must define the :meth:`run` method,
-    which is the actual method the `celery` daemon executes.
+    """Task base class.
 
-    The :meth:`run` method can take use of the default keyword arguments,
-    as listed in the :meth:`run` documentation.
-
-    The resulting class is callable, which if called will apply the
-    :meth:`run` method.
+    When called, tasks apply the :meth:`run` method.  This method must
+    be defined by all tasks (that is, unless the :meth:`__call__` method
+    is overridden).
 
     """
     __metaclass__ = TaskType
@@ -123,10 +114,10 @@ class BaseTask(object):
     abstract = True
 
     #: If disabled the worker will not forward magic keyword arguments.
-    #: Depracted and scheduled for removal in v3.0.
+    #: Deprecated and scheduled for removal in v3.0.
     accept_magic_kwargs = False
 
-    #: Current request context (when task is executed).
+    #: Request context (set when task is applied).
     request = Context()
 
     #: Destination queue.  The queue needs to exist
@@ -184,21 +175,19 @@ class BaseTask(object):
     #: If enabled an e-mail will be sent to :setting:`ADMINS` whenever a task
     #: of this type fails.
     send_error_emails = False
-
     disable_error_emails = False                            # FIXME
 
     #: List of exception types to send error e-mails for.
     error_whitelist = ()
 
-    #: The name of a serializer that has been registered with
+    #: The name of a serializer that is registered with
     #: :mod:`kombu.serialization.registry`.  Default is `"pickle"`.
     serializer = "pickle"
 
     #: The result store backend used for this task.
     backend = None
 
-    #: If disabled the task will not be automatically registered
-    #: in the task registry.
+    #: If disabled this task won't be registered automatically.
     autoregister = True
 
     #: If enabled the task will report its status as "started" when the task
@@ -214,7 +203,7 @@ class BaseTask(object):
     #: :setting:`CELERY_TRACK_STARTED` setting.
     track_started = False
 
-    #: When enabled  messages for this task will be acknowledged **after**
+    #: When enabled messages for this task will be acknowledged **after**
     #: the task has been executed, and not *just before* which is the
     #: default behavior.
     #:
@@ -267,16 +256,22 @@ class BaseTask(object):
 
         :rtype :class:`~celery.app.amqp.TaskPublisher`:
 
-        Please be sure to close the connection after use::
+        .. warning::
 
-            >>> publisher = self.get_publisher()
-            >>> # ... do something with publisher
-            >>> publisher.connection.close()
+            If you don't specify a connection, one will automatically
+            be established for you; in that case you need to close this
+            connection after use.
 
-        The connection can also be used as a context::
+            Please be sure to close the connection after use::
 
-            >>> with self.get_publisher() as publisher:
-            ...     # ... do something with publisher
+                >>> publisher = self.get_publisher()
+                >>> # ... do something with publisher
+                >>> publisher.connection.close()
+
+            or use the publisher as a context::
+
+                >>> with self.get_publisher() as publisher:
+                ...     # ... do something with publisher
 
         """
         if exchange is None:
@@ -500,27 +495,25 @@ class BaseTask(object):
         to convey that the rest of the block will not be executed.
 
         """
+        max_retries = self.max_retries
         request = self.request
         if args is None:
             args = request.args
         if kwargs is None:
             kwargs = request.kwargs
-
         delivery_info = request.delivery_info
+
         if delivery_info:
             options.setdefault("exchange", delivery_info.get("exchange"))
             options.setdefault("routing_key", delivery_info.get("routing_key"))
+        countdown = options.setdefault("countdown", self.default_retry_delay)
+        options.update({"retries": request.retries + 1,
+                        "task_id": request.id})
 
-        options["retries"] = request.retries + 1
-        options["task_id"] = request.id
-        options["countdown"] = options.get("countdown",
-                                           self.default_retry_delay)
-        max_exc = exc or self.MaxRetriesExceededError(
-                "Can't retry %s[%s] args:%s kwargs:%s" % (
-                    self.name, options["task_id"], args, kwargs))
-        max_retries = self.max_retries
         if max_retries is not None and options["retries"] > max_retries:
-            raise max_exc
+            raise exc or self.MaxRetriesExceededError(
+                            "Can't retry %s[%s] args:%s kwargs:%s" % (
+                                self.name, options["task_id"], args, kwargs))
 
         # If task was executed eagerly using apply(),
         # then the retry must also be executed eagerly.
@@ -528,10 +521,8 @@ class BaseTask(object):
             return self.apply(args=args, kwargs=kwargs, **options).get()
 
         self.apply_async(args=args, kwargs=kwargs, **options)
-
         if throw:
-            message = "Retry in %d seconds." % options["countdown"]
-            raise RetryTaskError(message, exc)
+            raise RetryTaskError("Retry in %d seconds" % (countdown, ), exc)
 
     @classmethod
     def apply(self, args=None, kwargs=None, **options):
@@ -818,7 +809,7 @@ class PeriodicTask(Task):
         Doesn't account for negative timedeltas.
 
         """
-        return timedelta_seconds(delta)
+        return timeutils.timedelta_seconds(delta)
 
     def is_due(self, last_run_at):
         """Returns tuple of two items `(is_due, next_time_to_run)`,

+ 2 - 1
celery/task/http.py

@@ -3,6 +3,7 @@ from urllib import urlencode
 from urlparse import urlparse
 
 from anyjson import deserialize
+from kombu.utils import kwdict
 
 from celery import __version__ as celery_version
 from celery.task.base import Task as BaseTask
@@ -24,7 +25,7 @@ class UnknownStatusError(InvalidResponseError):
 
 
 def maybe_utf8(value):
-    """Encode utf-8 value, only if the value is actually utf-8."""
+    """Encode to utf-8, only if the value is Unicode."""
     if isinstance(value, unicode):
         return value.encode("utf-8")
     return value

+ 33 - 64
celery/task/sets.py

@@ -3,7 +3,7 @@ import warnings
 from celery import registry
 from celery.app import app_or_default
 from celery.datastructures import AttributeDict
-from celery.utils import gen_unique_id
+from celery.utils import cached_property, gen_unique_id
 from celery.utils.compat import UserList
 
 TASKSET_DEPRECATION_TEXT = """\
@@ -18,7 +18,6 @@ this so the syntax has been changed to:
     ts = TaskSet(tasks=[
             %(cls)s.subtask(args1, kwargs1, options1),
             %(cls)s.subtask(args2, kwargs2, options2),
-            %(cls)s.subtask(args3, kwargs3, options3),
             ...
             %(cls)s.subtask(argsN, kwargsN, optionsN),
     ])
@@ -53,13 +52,11 @@ class subtask(AttributeDict):
 
     """
 
-    def __init__(self, task=None, args=None, kwargs=None, options=None,
-            **extra):
+    def __init__(self, task=None, args=None, kwargs=None, options=None, **ex):
         init = super(subtask, self).__init__
 
         if isinstance(task, dict):
-            # Use the values from a dict.
-            return init(task)
+            return init(task)  # works like dict(d)
 
         # Also supports using task class/instance instead of string name.
         try:
@@ -68,7 +65,7 @@ class subtask(AttributeDict):
             task_name = task
 
         init(task=task_name, args=tuple(args or ()),
-                             kwargs=dict(kwargs or {}, **extra),
+                             kwargs=dict(kwargs or {}, **ex),
                              options=options or {})
 
     def delay(self, *argmerge, **kwmerge):
@@ -81,7 +78,7 @@ class subtask(AttributeDict):
         args = tuple(args) + tuple(self.args)
         kwargs = dict(self.kwargs, **kwargs)
         options = dict(self.options, **options)
-        return self.get_type().apply(args, kwargs, **options)
+        return self.type.apply(args, kwargs, **options)
 
     def apply_async(self, args=(), kwargs={}, **options):
         """Apply this task asynchronously."""
@@ -89,42 +86,48 @@ class subtask(AttributeDict):
         args = tuple(args) + tuple(self.args)
         kwargs = dict(self.kwargs, **kwargs)
         options = dict(self.options, **options)
-        return self.get_type().apply_async(args, kwargs, **options)
+        return self.type.apply_async(args, kwargs, **options)
 
     def get_type(self):
-        # For JSON serialization, the task class is lazily loaded,
+        return self.type
+
+    def __reduce__(self):
+        # for serialization, the task type is lazily loaded,
         # and not stored in the dict itself.
+        return (self.__class__, (dict(self), ), None)
+
+    def __repr__(self, kwformat=lambda i: "%s=%r" % i, sep=', '):
+        kw = self["kwargs"]
+        return "%s(%s%s%s)" % (self["task"], sep.join(map(repr, self["args"])),
+                kw and sep or "", sep.join(map(kwformat, kw.iteritems())))
+
+    @cached_property
+    def type(self):
         return registry.tasks[self.task]
 
 
 class TaskSet(UserList):
     """A task containing several subtasks, making it possible
-    to track how many, or when all of the tasks has been completed.
+    to track how many, or when all of the tasks have been completed.
 
     :param tasks: A list of :class:`subtask` instances.
 
-    .. attribute:: total
-
-        Total number of subtasks in this task set.
-
     Example::
 
-        >>> from djangofeeds.tasks import RefreshFeedTask
-        >>> from celery.task.sets import TaskSet, subtask
-        >>> urls = ("http://cnn.com/rss",
-        ...         "http://bbc.co.uk/rss",
-        ...         "http://xkcd.com/rss")
-        >>> subtasks = [RefreshFeedTask.subtask(kwargs={"feed_url": url})
-        ...                 for url in urls]
-        >>> taskset = TaskSet(tasks=subtasks)
+        >>> urls = ("http://cnn.com/rss", "http://bbc.co.uk/rss")
+        >>> taskset = TaskSet(refresh_feed.subtask((url, )) for url in urls)
         >>> taskset_result = taskset.apply_async()
-        >>> list_of_return_values = taskset_result.join()
+        >>> list_of_return_values = taskset_result.join()  # *expensive*
 
     """
     _task = None                # compat
     _task_name = None           # compat
 
+    #: Total number of subtasks in this set.
+    total = None
+
     def __init__(self, task=None, tasks=None, app=None, Publisher=None):
+        self.app = app_or_default(app)
         if task is not None:
             if hasattr(task, "__iter__"):
                 tasks = task
@@ -138,14 +141,13 @@ class TaskSet(UserList):
                 warnings.warn(TASKSET_DEPRECATION_TEXT % {
                                 "cls": task.__class__.__name__},
                               DeprecationWarning)
-
-        self.app = app_or_default(app)
         self.data = list(tasks or [])
         self.total = len(self.tasks)
         self.Publisher = Publisher or self.app.amqp.TaskPublisher
 
     def apply_async(self, connection=None, connect_timeout=None,
             publisher=None):
+        """Apply taskset."""
         return self.app.with_default_connection(self._apply_async)(
                     connection=connection,
                     connect_timeout=connect_timeout,
@@ -153,43 +155,13 @@ class TaskSet(UserList):
 
     def _apply_async(self, connection=None, connect_timeout=None,
             publisher=None):
-        """Run all tasks in the taskset.
-
-        Returns a :class:`celery.result.TaskSetResult` instance.
-
-        Example
-
-            >>> ts = TaskSet(tasks=(
-            ...         RefreshFeedTask.subtask(["http://foo.com/rss"]),
-            ...         RefreshFeedTask.subtask(["http://bar.com/rss"]),
-            ... ))
-            >>> result = ts.apply_async()
-            >>> result.taskset_id
-            "d2c9b261-8eff-4bfb-8459-1e1b72063514"
-            >>> result.subtask_ids
-            ["b4996460-d959-49c8-aeb9-39c530dcde25",
-            "598d2d18-ab86-45ca-8b4f-0779f5d6a3cb"]
-            >>> result.waiting()
-            True
-            >>> time.sleep(10)
-            >>> result.ready()
-            True
-            >>> result.successful()
-            True
-            >>> result.failed()
-            False
-            >>> result.join()
-            [True, True]
-
-        """
         if self.app.conf.CELERY_ALWAYS_EAGER:
             return self.apply()
 
         taskset_id = gen_unique_id()
         pub = publisher or self.Publisher(connection=connection)
         try:
-            results = [task.apply_async(taskset_id=taskset_id,
-                                        publisher=pub)
+            results = [task.apply_async(taskset_id=taskset_id, publisher=pub)
                             for task in self.tasks]
         finally:
             if not publisher:  # created by us.
@@ -198,13 +170,10 @@ class TaskSet(UserList):
         return self.app.TaskSetResult(taskset_id, results)
 
     def apply(self):
-        """Applies the taskset locally."""
-        taskset_id = gen_unique_id()
-
-        # this will be filled with EagerResults.
-        results = [task.apply(taskset_id=taskset_id)
-                        for task in self.tasks]
-        return self.app.TaskSetResult(taskset_id, results)
+        """Applies the taskset locally by blocking until all tasks return."""
+        setid = gen_unique_id()
+        return self.app.TaskSetResult(setid, [task.apply(taskset_id=setid)
+                                                for task in self.tasks])
 
     @property
     def tasks(self):