Browse Source

Fixes canvas issues: Do not precalculate parent_id/root_id, and always keep result.parent references

Ask Solem 8 years ago
parent
commit
6ed45eccb9
4 changed files with 69 additions and 152 deletions
  1. 3 0
      celery/app/amqp.py
  2. 6 8
      celery/app/base.py
  3. 54 54
      celery/canvas.py
  4. 6 90
      t/unit/tasks/test_canvas.py

+ 3 - 0
celery/app/amqp.py

@@ -342,6 +342,9 @@ class AMQP(object):
             if chord:
                 chord = utf8dict(chord)
 
+        if not root_id:  # empty root_id defaults to task_id
+            root_id = task_id
+
         return task_message(
             headers={
                 'lang': 'py',

+ 6 - 8
celery/app/base.py

@@ -690,15 +690,13 @@ class Celery(object):
         options = router.route(
             options, route_name or name, args, kwargs, task_type)
 
-        if root_id is None:
-            parent, have_parent = self.current_worker_task, True
+        if not root_id or not parent_id:
+            parent = self.current_worker_task
             if parent:
-                root_id = parent.request.root_id or parent.request.id
-        if parent_id is None:
-            if not have_parent:
-                parent, have_parent = self.current_worker_task, True
-            if parent:
-                parent_id = parent.request.id
+                if not root_id:
+                    root_id = parent.request.root_id or parent.request.id
+                if not parent_id:
+                    parent_id = parent.request.id
 
         message = amqp.create_task_message(
             task_id, name, args, kwargs, countdown, eta, group_id,

+ 54 - 54
celery/canvas.py

@@ -388,9 +388,6 @@ class Signature(dict):
     def set_immutable(self, immutable):
         self.immutable = immutable
 
-    def set_parent_id(self, parent_id):
-        self.parent_id = parent_id
-
     def _with_list_option(self, key):
         items = self.options.setdefault(key, [])
         if not isinstance(items, MutableSequence):
@@ -457,26 +454,48 @@ class Signature(dict):
             # group() | task -> chord
             return chord(self, body=other, app=self._app)
         elif isinstance(other, group):
-            # task | group() -> unroll group with one member
+            # unroll group with one member
             other = maybe_unroll_group(other)
-            return chain(self, other, app=self._app)
+            if isinstance(self, chain):
+                # chain | group() -> chain
+                sig = self.clone()
+                sig.tasks.append(other)
+                return sig
+            # task | group() -> chain
+            return chain(self, other, app=self.app)
         if not isinstance(self, chain) and isinstance(other, chain):
             # task | chain -> chain
             return chain(
                 _seq_concat_seq((self,), other.tasks), app=self._app)
         elif isinstance(other, chain):
             # chain | chain -> chain
-            return chain(
-                _seq_concat_seq(self.tasks, other.tasks), app=self._app)
+            sig = self.clone()
+            if isinstance(sig.tasks, tuple):
+                sig.tasks = list(sig.tasks)
+            sig.tasks.extend(other.tasks)
+            return sig
         elif isinstance(self, chord):
+            # chord | task -> attach to body
             sig = self.clone()
             sig.body = sig.body | other
             return sig
         elif isinstance(other, Signature):
             if isinstance(self, chain):
-                # chain | task -> chain
-                return chain(
-                    _seq_concat_item(self.tasks, other), app=self._app)
+                if isinstance(self.tasks[-1], group):
+                    # CHAIN [last item is group] | TASK -> chord
+                    sig = self.clone()
+                    sig.tasks[-1] = chord(
+                        sig.tasks[-1], other, app=self._app)
+                    return sig
+                elif isinstance(self.tasks[-1], chord):
+                    # CHAIN [last item is chord] | TASK -> chain with chord body.
+                    sig = self.clone()
+                    sig.tasks[-1].body = sig.tasks[-1].body | other
+                    return sig
+                else:
+                    # chain | task -> chain
+                    return chain(
+                        _seq_concat_item(self.tasks, other), app=self._app)
             # task | task -> chain
             return chain(self, other, app=self._app)
         return NotImplemented
@@ -712,7 +731,7 @@ class chain(Signature):
         steps_extend = steps.extend
 
         prev_task = None
-        prev_res = prev_prev_res = None
+        prev_res = None
         tasks, results = [], []
         i = 0
         while steps:
@@ -745,7 +764,6 @@ class chain(Signature):
                     task, body=prev_task,
                     task_id=prev_res.task_id, root_id=root_id, app=app,
                 )
-                prev_res = prev_prev_res
 
             if is_last_task:
                 # chain(task_id=id) means task id is set for the last task
@@ -763,27 +781,12 @@ class chain(Signature):
             i += 1
 
             if prev_task:
-                prev_task.set_parent_id(task.id)
-
                 if use_link:
                     # link previous task to this task.
                     task.link(prev_task)
 
-                if prev_res:
-                    if isinstance(prev_task, chord):
-                        # If previous task was a chord,
-                        # the freeze above would have set a parent for
-                        # us, but we'd be overwriting it here.
-
-                        # so fix this relationship so it's:
-                        #     chord body -> group -> THIS RES
-                        assert isinstance(prev_res.parent, GroupResult)
-                        prev_res.parent.parent = res
-                    else:
-                        prev_res.parent = res
-
-            if is_first_task and parent_id is not None:
-                task.set_parent_id(parent_id)
+                if prev_res and not prev_res.parent:
+                    prev_res.parent = res
 
             if link_error:
                 for errback in maybe_list(link_error):
@@ -792,14 +795,18 @@ class chain(Signature):
             tasks.append(task)
             results.append(res)
 
-            prev_task, prev_prev_res, prev_res = (
-                task, prev_res, res,
-            )
-
-        if root_id is None and tasks:
-            root_id = tasks[-1].id
-            for task in reversed(tasks):
-                task.options['root_id'] = root_id
+            prev_task, prev_res = task, res
+            if isinstance(task, chord):
+                # If the task is a chord and its body is a chain,
+                # the chain has already been prepared, and res is
+                # set to the last task in the callback chain.
+
+                # We need to change that so that it points to the
+                # group result object.
+                node = res
+                while node.parent:
+                    node = node.parent
+                prev_res = node
         return tasks, results
 
     def apply(self, args=(), kwargs={}, **options):
@@ -1112,10 +1119,6 @@ class group(Signature):
             options.pop('task_id', uuid()))
         return options, group_id, options.get('root_id')
 
-    def set_parent_id(self, parent_id):
-        for task in self.tasks:
-            task.set_parent_id(parent_id)
-
     def freeze(self, _id=None, group_id=None, chord=None,
                root_id=None, parent_id=None):
         # pylint: disable=redefined-outer-name
@@ -1238,21 +1241,19 @@ class chord(Signature):
             self.tasks = group(self.tasks, app=self.app)
         header_result = self.tasks.freeze(
             parent_id=parent_id, root_id=root_id, chord=self.body)
-        bodyres = self.body.freeze(
-            _id, parent_id=header_result.id, root_id=root_id)
-        bodyres.parent = header_result
+        bodyres = self.body.freeze(_id, root_id=root_id)
+        # we need to link the body result back to the group result,
+        # but the body may actually be a chain,
+        # so find the first result without a parent
+        node = bodyres
+        while node:
+            if not node.parent:
+                node.parent = header_result
+                break
+            node = node.parent
         self.id = self.tasks.id
-        self.body.set_parent_id(self.id)
         return bodyres
 
-    def set_parent_id(self, parent_id):
-        tasks = self.tasks
-        if isinstance(tasks, group):
-            tasks = tasks.tasks
-        for task in tasks:
-            task.set_parent_id(parent_id)
-        self.parent_id = parent_id
-
     def apply_async(self, args=(), kwargs={}, task_id=None,
                     producer=None, publisher=None, connection=None,
                     router=None, result_cls=None, **options):
@@ -1304,7 +1305,6 @@ class chord(Signature):
 
         results = header.freeze(
             group_id=group_id, chord=body, root_id=root_id).results
-        body.set_parent_id(group_id)
         bodyres = body.freeze(task_id, root_id=root_id)
 
         parent = app.backend.apply_chord(

+ 6 - 90
t/unit/tasks/test_canvas.py

@@ -253,7 +253,7 @@ class test_chain(CanvasCase):
 
     def test_repr(self):
         x = self.add.s(2, 2) | self.add.s(2)
-        assert repr(x) == '%s(2, 2) | %s(2)' % (self.add.name, self.add.name)
+        assert repr(x) == '%s(2, 2) | add(2)' % (self.add.name,)
 
     def test_apply_async(self):
         c = self.add.s(2, 2) | self.add.s(4) | self.add.s(8)
@@ -262,43 +262,6 @@ class test_chain(CanvasCase):
         assert result.parent.parent
         assert result.parent.parent.parent is None
 
-    def test_group_to_chord__freeze_parent_id(self):
-        def using_freeze(c):
-            c.freeze(parent_id='foo', root_id='root')
-            return c._frozen[0]
-        self.assert_group_to_chord_parent_ids(using_freeze)
-
-    def assert_group_to_chord_parent_ids(self, freezefun):
-        c = (
-            self.add.s(5, 5) |
-            group([self.add.s(i, i) for i in range(5)], app=self.app) |
-            self.add.si(10, 10) |
-            self.add.si(20, 20) |
-            self.add.si(30, 30)
-        )
-        tasks = freezefun(c)
-        assert tasks[-1].parent_id == 'foo'
-        assert tasks[-1].root_id == 'root'
-        assert tasks[-2].parent_id == tasks[-1].id
-        assert tasks[-2].root_id == 'root'
-        assert tasks[-2].body.parent_id == tasks[-2].tasks.id
-        assert tasks[-2].body.parent_id == tasks[-2].id
-        assert tasks[-2].body.root_id == 'root'
-        assert tasks[-2].tasks.tasks[0].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[0].root_id == 'root'
-        assert tasks[-2].tasks.tasks[1].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[1].root_id == 'root'
-        assert tasks[-2].tasks.tasks[2].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[2].root_id == 'root'
-        assert tasks[-2].tasks.tasks[3].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[3].root_id == 'root'
-        assert tasks[-2].tasks.tasks[4].parent_id == tasks[-1].id
-        assert tasks[-2].tasks.tasks[4].root_id == 'root'
-        assert tasks[-3].parent_id == tasks[-2].body.id
-        assert tasks[-3].root_id == 'root'
-        assert tasks[-4].parent_id == tasks[-3].id
-        assert tasks[-4].root_id == 'root'
-
     def test_splices_chains(self):
         c = chain(
             self.add.s(5, 5),
@@ -341,21 +304,12 @@ class test_chain(CanvasCase):
         assert tasks[-1].args[0] == 5
         assert isinstance(tasks[-2], chord)
         assert len(tasks[-2].tasks) == 5
-        assert tasks[-2].parent_id == tasks[-1].id
-        assert tasks[-2].root_id == tasks[-1].id
-        assert tasks[-2].body.args[0] == 10
-        assert tasks[-2].body.parent_id == tasks[-2].id
-
-        assert tasks[-3].args[0] == 20
-        assert tasks[-3].root_id == tasks[-1].id
-        assert tasks[-3].parent_id == tasks[-2].body.id
-
-        assert tasks[-4].args[0] == 30
-        assert tasks[-4].parent_id == tasks[-3].id
-        assert tasks[-4].root_id == tasks[-1].id
 
-        assert tasks[-2].body.options['link']
-        assert tasks[-2].body.options['link'][0].options['link']
+        body = tasks[-2].body
+        assert len(body.tasks) == 3
+        assert body.tasks[0].args[0] == 10
+        assert body.tasks[1].args[0] == 20
+        assert body.tasks[2].args[0] == 30
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = True
@@ -437,39 +391,6 @@ class test_chain(CanvasCase):
         assert chain(app=self.app)() is None
         assert chain(app=self.app).apply_async() is None
 
-    def test_root_id_parent_id(self):
-        self.app.conf.task_protocol = 2
-        c = chain(self.add.si(i, i) for i in range(4))
-        c.freeze()
-        tasks, _ = c._frozen
-        for i, task in enumerate(tasks):
-            assert task.root_id == tasks[-1].id
-            try:
-                assert task.parent_id == tasks[i + 1].id
-            except IndexError:
-                assert i == len(tasks) - 1
-            else:
-                valid_parents = i
-        assert valid_parents == len(tasks) - 2
-
-        self.assert_sent_with_ids(tasks[-1], tasks[-1].id, 'foo',
-                                  parent_id='foo')
-        assert tasks[-2].options['parent_id']
-        self.assert_sent_with_ids(tasks[-2], tasks[-1].id, tasks[-1].id)
-        self.assert_sent_with_ids(tasks[-3], tasks[-1].id, tasks[-2].id)
-        self.assert_sent_with_ids(tasks[-4], tasks[-1].id, tasks[-3].id)
-
-    def assert_sent_with_ids(self, task, rid, pid, **options):
-        self.app.amqp.send_task_message = Mock(name='send_task_message')
-        self.app.backend = Mock()
-        self.app.producer_or_acquire = ContextMock()
-
-        task.apply_async(**options)
-        self.app.amqp.send_task_message.assert_called()
-        message = self.app.amqp.send_task_message.call_args[0][2]
-        assert message.headers['parent_id'] == pid
-        assert message.headers['root_id'] == rid
-
     def test_call_no_tasks(self):
         x = chain()
         assert not x()
@@ -687,11 +608,6 @@ class test_chord(CanvasCase):
         x = chord(group(self.add.s(2, 2), self.add.s(4, 4), app=self.app))
         assert x.tasks
 
-    def test_set_parent_id(self):
-        x = chord(group(self.add.s(2, 2)))
-        x.tasks = [self.add.s(2, 2)]
-        x.set_parent_id('pid')
-
     def test_app_when_app(self):
         app = Mock(name='app')
         x = chord([self.add.s(4, 4)], app=app)