Procházet zdrojové kódy

[Canvas] Support special case of group(A.s() | group(B.s() | C.S()))

Ask Solem před 9 roky
rodič
revize
1e3fcaa969
1 změnil soubory, kde provedl 26 přidání a 4 odebrání
  1. 26 4
      celery/canvas.py

+ 26 - 4
celery/canvas.py

@@ -21,6 +21,7 @@ from itertools import chain as _chain
 from kombu.utils import cached_property, fxrange, reprcall, uuid
 
 from celery._state import current_app, get_current_worker_task
+from celery.result import GroupResult
 from celery.utils.functional import (
     maybe_list, is_list, regen,
     chunks as _chunks,
@@ -368,6 +369,7 @@ class chain(Signature):
             self, 'celery.chain', (), {'tasks': tasks}, **options
         )
         self.subtask_type = 'chain'
+        self._frozen = None
 
     def __call__(self, *args, **kwargs):
         if self.tasks:
@@ -387,10 +389,14 @@ class chain(Signature):
         app = app or self.app
         args = (tuple(args) + tuple(self.args)
                 if args and not self.immutable else self.args)
-        tasks, results = self.prepare_steps(
-            args, self.tasks, root_id, link_error, app,
-            task_id, group_id, chord,
-        )
+
+        try:
+            tasks, results = self._frozen
+        except (AttributeError, ValueError):
+            tasks, results = self.prepare_steps(
+                args, self.tasks, root_id, link_error, app,
+                task_id, group_id, chord,
+            )
         if results:
             # make sure we can do a link() and link_error() on a chain object.
             if link:
@@ -398,6 +404,12 @@ class chain(Signature):
             tasks[0].apply_async(**options)
             return results[-1]
 
+    def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
+        _, results = self._frozen = self.prepare_steps(
+            (), self.tasks, root_id, None, self.app, _id, group_id, chord,
+        )
+        return results[-1]
+
     def prepare_steps(self, args, tasks,
                       root_id=None, link_error=None, app=None,
                       last_task_id=None, group_id=None, chord_body=None,
@@ -665,6 +677,16 @@ class group(Signature):
         result = self.app.GroupResult(
             group_id, list(self._apply_tasks(tasks, producer, app, **options)),
         )
+
+        # - Special case of group(A.s() | group(B.s(), C.s()))
+        # That is, group with single item that is a chain but the
+        # last task in that chain is a group.
+        #
+        # We cannot actually support arbitrary GroupResults in chains,
+        # but this special case we can.
+        if len(result) == 1 and isinstance(result[0], GroupResult):
+            result = result[0]
+
         parent_task = get_current_worker_task()
         if add_to_parent and parent_task:
             parent_task.add_trail(result)