Explorar o código

Cleanup chain.prepare

Ask Solem %!s(int64=9) %!d(string=hai) anos
pai
achega
9a03964bf8
Modificáronse 2 ficheiros con 22 adicións e 32 borrados
  1. 18 28
      celery/canvas.py
  2. 4 4
      celery/tests/app/test_builtins.py

+ 18 - 28
celery/canvas.py

@@ -415,20 +415,12 @@ class chain(Signature):
             )
 
         if results:
-            # make sure we can do a link() and link_error() on a chain object.
-            if self._use_link:
-                # old task protocol used link for chains, last is last.
-                if link:
-                    tasks[-1].set(link=link)
-                tasks[0].apply_async(**options)
-                return results[-1]
-            else:
-                # -- using chain message field means last task is first.
-                if link:
-                    tasks[0].set(link=link)
-                first_task = tasks.pop()
-                first_task.apply_async(chain=tasks, **options)
-                return results[0]
+            if link:
+                tasks[0].set(link=link)
+            first_task = tasks.pop()
+            first_task.apply_async(
+                chain=tasks if not use_link else None, **options)
+            return results[0]
 
     def freeze(self, _id=None, group_id=None, chord=None, root_id=None):
         _, results = self._frozen = self.prepare_steps(
@@ -452,17 +444,15 @@ class chain(Signature):
             use_link = False
         steps = deque(tasks)
 
-        steps_pop = steps.popleft if use_link else steps.pop
-        steps_extend = steps.extendleft if use_link else steps.extend
-        extend_order = reversed if use_link else noop
+        steps_pop = steps.pop
+        steps_extend = steps.extend
 
         next_step = prev_task = prev_res = None
         tasks, results = [], []
         i = 0
         while steps:
             task = steps_pop()
-            last_task = not steps if use_link else not i
-            first_task = not i if use_link else not steps
+            is_first_task, is_last_task = not steps, not i
 
             if not isinstance(task, abstract.CallableSignature):
                 task = from_dict(task, app=app)
@@ -471,19 +461,19 @@ class chain(Signature):
 
             # first task gets partial args from chain
             if clone:
-                task = task.clone(args) if not i else task.clone()
-            elif first_task:
+                task = task.clone(args) if is_first_task else task.clone()
+            elif is_first_task:
                 task.args = tuple(args) + tuple(task.args)
 
             if isinstance(task, chain):
                 # splice the chain
-                steps_extend(extend_order(task.tasks))
+                steps_extend(task.tasks)
                 continue
             elif isinstance(task, group):
-                if (steps if use_link else prev_task):
+                if prev_task:
                     # automatically upgrade group(...) | s to chord(group, s)
                     try:
-                        next_step = steps_pop() if use_link else prev_task
+                        next_step = prev_task
                         # for chords we freeze by pretending it's a normal
                         # signature instead of a group.
                         res = Signature.freeze(next_step, root_id=root_id)
@@ -494,7 +484,7 @@ class chain(Signature):
                     except IndexError:
                         pass  # no callback, so keep as group.
 
-            if last_task:
+            if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
                 # in the chain.  If the chord is part of a chord/group
                 # then that chord/group must synchronize based on the
@@ -512,9 +502,9 @@ class chain(Signature):
             if prev_task:
                 if use_link:
                     # link previous task to this task.
-                    prev_task.link(task)
+                    task.link(prev_task)
                     if not res.parent:
-                        res.parent = prev_res
+                        prev_res.parent = res.parent
                 else:
                     prev_res.parent = res
 
@@ -686,7 +676,7 @@ class group(Signature):
                     task = from_dict(task, app=app)
                 if isinstance(task, group):
                     # needs yield_from :(
-                    unroll = task._prepared(
+                    unroll = task_prepared(
                         task.tasks, partial_args, group_id, root_id, app,
                     )
                     for taskN, resN in unroll:

+ 4 - 4
celery/tests/app/test_builtins.py

@@ -142,14 +142,14 @@ class test_chain(BuiltinsCase):
         )
         c._use_link = True
         tasks, _ = c.prepare_steps((), c.tasks)
-        self.assertIsInstance(tasks[0], chord)
-        self.assertTrue(tasks[0].body.options['link'])
-        self.assertTrue(tasks[0].body.options['link'][0].options['link'])
+        self.assertIsInstance(tasks[-1], chord)
+        self.assertTrue(tasks[-1].body.options['link'])
+        self.assertTrue(tasks[-1].body.options['link'][0].options['link'])
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True
         tasks2, _ = c2.prepare_steps((), c2.tasks)
-        self.assertIsInstance(tasks2[1], group)
+        self.assertIsInstance(tasks2[0], group)
 
     def test_group_to_chord__protocol_2(self):
         c = (