Browse Source

Make TaskSet.run() respect message options from the Task class. Closes #16.

Ask Solem 16 years ago
parent
commit
03d30a32de
5 changed files with 42 additions and 38 deletions
  1. 1 1
      README
  2. 1 1
      celery/__init__.py
  3. 3 3
      celery/messaging.py
  4. 5 10
      celery/result.py
  5. 32 23
      celery/task.py

+ 1 - 1
README

@@ -2,7 +2,7 @@
 celery - Distributed Task Queue for Django.
 ============================================
 
-:Version: 0.3.8
+:Version: 0.3.11
 
 Introduction
 ============

+ 1 - 1
celery/__init__.py

@@ -1,5 +1,5 @@
 """Distributed Task Queue for Django"""
-VERSION = (0, 3, 10)
+VERSION = (0, 3, 11)
 __version__ = ".".join(map(str, VERSION))
 __author__ = "Ask Solem"
 __contact__ = "askh@opera.com"

+ 3 - 3
celery/messaging.py

@@ -25,8 +25,8 @@ class TaskPublisher(Publisher):
         return self._delay_task(task_name=task_name, task_args=task_args,
                                 task_kwargs=task_kwargs, **kwargs)
 
-    def delay_task_in_set(self, task_name, taskset_id, task_args,
-            task_kwargs, **kwargs):
+    def delay_task_in_set(self, taskset_id, task_name, task_args, task_kwargs,
+            **kwargs):
         """Delay a task which part of a task set."""
         return self._delay_task(task_name=task_name, part_of_set=taskset_id,
                                 task_args=task_args, task_kwargs=task_kwargs,
@@ -46,7 +46,7 @@ class TaskPublisher(Publisher):
         immediate = kwargs.get("immediate")
         mandatory = kwargs.get("mandatory")
         routing_key = kwargs.get("routing_key")
-
+    
         task_args = task_args or []
         task_kwargs = task_kwargs or {}
         task_id = task_id or str(uuid.uuid4())

+ 5 - 10
celery/result.py

@@ -149,26 +149,21 @@ class TaskSetResult(object):
     single entity.
 
     :option taskset_id: see :attr:`taskset_id`.
-    :option subtask_ids: see :attr:`subtask_ids`.
+    :option subtasks see :attr:`subtasks`.
 
     .. attribute:: taskset_id
 
         The UUID of the taskset itself.
 
-    .. attribute:: subtask_ids
-
-        The list of task UUID's for all of the subtasks.
-
     .. attribute:: subtasks
 
         A list of :class:`AsyncResult`` instances for all of the subtasks.
 
     """
 
-    def __init__(self, taskset_id, subtask_ids):
+    def __init__(self, taskset_id, subtasks):
         self.taskset_id = taskset_id
-        self.subtask_ids = subtask_ids
-        self.subtasks = map(AsyncResult, self.subtask_ids)
+        self.subtasks = subtasks
 
     def itersubtasks(self):
         """Taskset subtask iterator.
@@ -239,8 +234,8 @@ class TaskSetResult(object):
         :raises: The exception if any of the tasks raised an exception.
 
         """
-        results = dict([(task_id, AsyncResult(task_id))
-                            for task_id in self.subtask_ids])
+        results = dict([(subtask.task_id, AsyncResult(subtask.task_id))
+                            for subtask in self.subtasks])
         while results:
             for task_id, pending_result in results.items():
                 if pending_result.status == "DONE":

+ 32 - 23
celery/task.py

@@ -12,13 +12,14 @@ from celery.registry import tasks
 from datetime import timedelta
 from celery.backends import default_backend
 from celery.result import AsyncResult, TaskSetResult
+from django.utils.functional import curry
 import uuid
 import pickle
 
 
 def apply_async(task, args=None, kwargs=None, routing_key=None,
         immediate=None, mandatory=None, connection=None,
-        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None):
+        connect_timeout=AMQP_CONNECTION_TIMEOUT, priority=None, **opts):
     """Run a task asynchronously by the celery daemon(s).
 
     :param task: The task to run (a callable object, or a :class:`Task`
@@ -47,27 +48,34 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     :keyword priority: The task priority, a number between ``0`` and ``9``.
 
     """
-    if not args:
-        args = []
-    if not kwargs:
-        kwargs = []
-    message_opts = {"routing_key": routing_key,
-                    "immediate": immediate,
-                    "mandatory": mandatory,
-                    "priority": priority}
-    for option_name, option_value in message_opts.items():
-        message_opts[option_name] = getattr(task, option_name, option_value)
+    args = args or []
+    kwargs = kwargs or {}
+    routing_key = routing_key or getattr(task, "routing_key", None)
+    immediate = immediate or getattr(task, "immediate", None)
+    mandatory = mandatory or getattr(task, "mandatory", None)
+    priority = priority or getattr(task, "priority", None)
+    taskset_id = opts.get("taskset_id")
+    publisher = opts.get("publisher")
 
     need_to_close_connection = False
-    if not connection:
-        connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
-        need_to_close_connection = True
+    if not publisher:
+        if not connection:
+            connection = DjangoAMQPConnection(connect_timeout=connect_timeout)
+            need_to_close_connection = True
+        publisher = TaskPublisher(connection=connection)
+
+    delay_task = publisher.delay_task
+    if taskset_id:
+        delay_task = curry(publisher.delay_task_in_set, taskset_id)
+        
+    task_id = delay_task(task.name, args, kwargs,
+                         routing_key=routing_key, mandatory=mandatory,
+                         immediate=immediate, priority=priority)
 
-    publisher = TaskPublisher(connection=connection)
-    task_id = publisher.delay_task(task.name, args, kwargs, **message_opts)
-    publisher.close()
     if need_to_close_connection:
+        publisher.close()
         connection.close()
+
     return AsyncResult(task_id)
 
 
@@ -323,9 +331,12 @@ class TaskSet(object):
     def __init__(self, task, args):
         try:
             task_name = task.name
+            task_obj = task
         except AttributeError:
             task_name = task
+            task_obj = tasks[task_name]
 
+        self.task = task_obj
         self.task_name = task_name
         self.arguments = args
         self.total = len(args)
@@ -363,14 +374,12 @@ class TaskSet(object):
         taskset_id = str(uuid.uuid4())
         conn = DjangoAMQPConnection(connect_timeout=connect_timeout)
         publisher = TaskPublisher(connection=conn)
-        subtask_ids = [publisher.delay_task_in_set(task_name=self.task_name,
-                                                   taskset_id=taskset_id,
-                                                   task_args=arg,
-                                                   task_kwargs=kwarg)
-                        for arg, kwarg in self.arguments]
+        subtasks = [apply_async(self.task, args, kwargs,
+                                taskset_id=taskset_id, publisher=publisher)
+                        for args, kwargs in self.arguments]
         publisher.close()
         conn.close()
-        return TaskSetResult(taskset_id, subtask_ids)
+        return TaskSetResult(taskset_id, subtasks)
 
     def iterate(self):
         """Iterate over the results returned after calling :meth:`run`.