Browse Source

Care about chain's link_error tasks (#4240)

* Assign link_error sugnatures to each task

issue #4232
Chain could have link_error signatures for error processing.
Joining chains copies tasks from other chain to original one.
Thus copying loses other chain's link_error signatures.
Assigning chain's link_error signatures to each task could have
the same effect.
Each task from other chain are cloned to leave original ones as is.

* Clone full chain's state

In appending task to chain clone full chain's state and
append task to chain's tasks.

* Fix chaining to chains

* Add test for keeping link_error on chaining

* Fix chaining group to chain

* Fix indentation according to PEP8

* Fix blank lines

* Avoid reduce function usage

* Move common code to separate method
Anton 7 years ago
parent
commit
9a1064e71b
2 changed files with 63 additions and 12 deletions
  1. 17 12
      celery/canvas.py
  2. 46 0
      t/unit/tasks/test_canvas.py

+ 17 - 12
celery/canvas.py

@@ -395,23 +395,19 @@ class Signature(dict):
             other = maybe_unroll_group(other)
             if isinstance(self, _chain):
                 # chain | group() -> chain
-                sig = self.clone()
-                sig.tasks.append(other)
-                return sig
+                return _chain(seq_concat_item(
+                    self.unchain_tasks(), other), app=self._app)
             # task | group() -> chain
             return _chain(self, other, app=self.app)
 
         if not isinstance(self, _chain) and isinstance(other, _chain):
             # task | chain -> chain
-            return _chain(
-                seq_concat_seq((self,), other.tasks), app=self._app)
+            return _chain(seq_concat_seq(
+                (self,), other.unchain_tasks()), app=self._app)
         elif isinstance(other, _chain):
             # chain | chain -> chain
-            sig = self.clone()
-            if isinstance(sig.tasks, tuple):
-                sig.tasks = list(sig.tasks)
-            sig.tasks.extend(other.tasks)
-            return sig
+            return _chain(seq_concat_seq(
+                self.unchain_tasks(), other.unchain_tasks()), app=self._app)
         elif isinstance(self, chord):
             # chord(ONE, body) | other -> ONE | body | other
             # chord with one header task is unecessary.
@@ -436,8 +432,8 @@ class Signature(dict):
                     return sig
                 else:
                     # chain | task -> chain
-                    return _chain(
-                        seq_concat_item(self.tasks, other), app=self._app)
+                    return _chain(seq_concat_item(
+                        self.unchain_tasks(), other), app=self._app)
             # task | task -> chain
             return _chain(self, other, app=self._app)
         return NotImplemented
@@ -557,6 +553,15 @@ class _chain(Signature):
         ]
         return s
 
+    def unchain_tasks(self):
+        # Clone chain's tasks assigning sugnatures from link_error
+        # to each task
+        tasks = [t.clone() for t in self.tasks]
+        for sig in self.options.get('link_error', []):
+            for task in tasks:
+                task.link_error(sig)
+        return tasks
+
     def apply_async(self, args=(), kwargs={}, **options):
         # python is best at unpacking kwargs, so .run is here to do that.
         app = self.app

+ 46 - 0
t/unit/tasks/test_canvas.py

@@ -188,6 +188,52 @@ class test_Signature(CanvasCase):
         s = signature('xxx.not.registered', app=self.app)
         assert s._apply_async
 
+    def test_keeping_link_error_on_chaining(self):
+        x = self.add.s(2, 2) | self.mul.s(4)
+        assert isinstance(x, _chain)
+        x.link_error(SIG)
+        assert SIG in x.options['link_error']
+
+        t = signature(SIG)
+        z = x | t
+        assert isinstance(z, _chain)
+        assert t in z.tasks
+        assert not z.options.get('link_error')
+        assert SIG in z.tasks[0].options['link_error']
+        assert not z.tasks[2].options.get('link_error')
+        assert SIG in x.options['link_error']
+        assert t not in x.tasks
+        assert not x.tasks[0].options.get('link_error')
+
+        z = t | x
+        assert isinstance(z, _chain)
+        assert t in z.tasks
+        assert not z.options.get('link_error')
+        assert SIG in z.tasks[1].options['link_error']
+        assert not z.tasks[0].options.get('link_error')
+        assert SIG in x.options['link_error']
+        assert t not in x.tasks
+        assert not x.tasks[0].options.get('link_error')
+
+        y = self.add.s(4, 4) | self.div.s(2)
+        assert isinstance(y, _chain)
+
+        z = x | y
+        assert isinstance(z, _chain)
+        assert not z.options.get('link_error')
+        assert SIG in z.tasks[0].options['link_error']
+        assert not z.tasks[2].options.get('link_error')
+        assert SIG in x.options['link_error']
+        assert not x.tasks[0].options.get('link_error')
+
+        z = y | x
+        assert isinstance(z, _chain)
+        assert not z.options.get('link_error')
+        assert SIG in z.tasks[3].options['link_error']
+        assert not z.tasks[1].options.get('link_error')
+        assert SIG in x.options['link_error']
+        assert not x.tasks[0].options.get('link_error')
+
 
 class test_xmap_xstarmap(CanvasCase):