Bläddra i källkod

Add optional taskset_id parameter to TaskSet.apply and TaskSet.apply_async.

Branko Čibej 14 år sedan
förälder
incheckning
f9e7c42237
2 ändrade filer med 15 tillägg och 9 borttagningar
  1. 10 9
      celery/task/sets.py
  2. 5 0
      celery/tests/test_task/test_task.py

+ 10 - 9
celery/task/sets.py

@@ -148,32 +148,33 @@ class TaskSet(UserList):
         self.Publisher = Publisher or self.app.amqp.TaskPublisher
 
     def apply_async(self, connection=None, connect_timeout=None,
-            publisher=None):
+            publisher=None, taskset_id=None):
         """Apply taskset."""
         return self.app.with_default_connection(self._apply_async)(
                     connection=connection,
                     connect_timeout=connect_timeout,
-                    publisher=publisher)
+                    publisher=publisher,
+                    taskset_id=taskset_id)
 
     def _apply_async(self, connection=None, connect_timeout=None,
-            publisher=None):
+            publisher=None, taskset_id=None):
         if self.app.conf.CELERY_ALWAYS_EAGER:
-            return self.apply()
+            return self.apply(taskset_id=taskset_id)
 
-        taskset_id = gen_unique_id()
+        setid = taskset_id or gen_unique_id()
         pub = publisher or self.Publisher(connection=connection)
         try:
-            results = [task.apply_async(taskset_id=taskset_id, publisher=pub)
+            results = [task.apply_async(taskset_id=setid, publisher=pub)
                             for task in self.tasks]
         finally:
             if not publisher:  # created by us.
                 pub.close()
 
-        return self.app.TaskSetResult(taskset_id, results)
+        return self.app.TaskSetResult(setid, results)
 
-    def apply(self):
+    def apply(self, taskset_id=None):
         """Applies the taskset locally by blocking until all tasks return."""
-        setid = gen_unique_id()
+        setid = taskset_id or gen_unique_id()
         return self.app.TaskSetResult(setid, [task.apply(taskset_id=setid)
                                                 for task in self.tasks])
 

+ 5 - 0
celery/tests/test_task/test_task.py

@@ -435,6 +435,11 @@ class TestTaskSet(unittest.TestCase):
                     increment_by=m.get("kwargs", {}).get("increment_by"))
         self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
 
+    def test_named_taskset(self):
+        prefix = "test_named_taskset-"
+        ts = task.TaskSet([return_True_task.subtask([1])])
+        res = ts.apply(taskset_id=prefix+gen_unique_id())
+        self.assertTrue(res.taskset_id.startswith(prefix))
 
 class TestTaskApply(unittest.TestCase):