Browse Source

Canvas: Fixes error "dict object has no attr clone" (Closes #3381)

Ask Solem 8 years ago
parent
commit
b56f63a6af
2 changed files with 84 additions and 11 deletions
  1. 61 10
      celery/canvas.py
  2. 23 1
      t/unit/tasks/test_canvas.py

+ 61 - 10
celery/canvas.py

@@ -118,6 +118,34 @@ def _upgrade(fields, sig):
     return sig
 
 
+def _seq_concat_item(seq, item):
+    """Return copy of sequence seq with item added.
+
+    Returns:
+        Sequence: if seq is a tuple, the result will be a tuple,
+           otherwise it depends on the implementation of ``__add__``.
+    """
+    return seq + (item,) if isinstance(seq, tuple) else seq + [item]
+
+
+def _seq_concat_seq(a, b):
+    """Concatenate two sequences: ``a + b``.
+
+    Returns:
+        Sequence: The return value will depend on the largest sequence
+            - if b is larger and is a tuple, the return value will be a tuple.
+            - if a is larger and is a list, the return value will be a list,
+    """
+    # find the type of the largest sequence
+    prefer = type(max([a, b], key=len))
+    # convert the smallest list to the type of the largest sequence.
+    if not isinstance(a, prefer):
+        a = prefer(a)
+    if not isinstance(b, prefer):
+        b = prefer(b)
+    return a + b
+
+
 @abstract.CallableSignature.register
 @python_2_unicode_compatible
 class Signature(dict):
@@ -434,10 +462,12 @@ class Signature(dict):
             return chain(self, other, app=self._app)
         if not isinstance(self, chain) and isinstance(other, chain):
             # task | chain -> chain
-            return chain((self,) + other.tasks, app=self._app)
+            return chain(
+                _seq_concat_seq((self,), other.tasks), app=self._app)
         elif isinstance(other, chain):
             # chain | chain -> chain
-            return chain(*self.tasks + other.tasks, app=self._app)
+            return chain(
+                _seq_concat_seq(self.tasks, other.tasks), app=self._app)
         elif isinstance(self, chord):
             sig = self.clone()
             sig.body = sig.body | other
@@ -445,7 +475,8 @@ class Signature(dict):
         elif isinstance(other, Signature):
             if isinstance(self, chain):
                 # chain | task -> chain
-                return chain(*self.tasks + (other,), app=self._app)
+                return chain(
+                    _seq_concat_item(self.tasks, other), app=self._app)
             # task | task -> chain
             return chain(self, other, app=self._app)
         return NotImplemented
@@ -595,8 +626,12 @@ class chain(Signature):
             return self.apply_async(args, kwargs)
 
     def clone(self, *args, **kwargs):
+        to_signature = maybe_signature
         s = Signature.clone(self, *args, **kwargs)
-        s.kwargs['tasks'] = [sig.clone() for sig in s.kwargs['tasks']]
+        s.kwargs['tasks'] = [
+            to_signature(sig, app=self._app, clone=True)
+            for sig in s.kwargs['tasks']
+        ]
         return s
 
     def apply_async(self, args=(), kwargs={}, **options):
@@ -1281,7 +1316,7 @@ class chord(Signature):
         s = Signature.clone(self, *args, **kwargs)
         # need to make copy of body
         try:
-            s.kwargs['body'] = s.kwargs['body'].clone()
+            s.kwargs['body'] = maybe_signature(s.kwargs['body'], clone=True)
         except (AttributeError, KeyError):
             pass
         return s
@@ -1326,14 +1361,30 @@ def signature(varies, *args, **kwargs):
 subtask = signature   # XXX compat
 
 
-def maybe_signature(d, app=None):
-    """Ensure obj is a signature, or None."""
+def maybe_signature(d, app=None, clone=False):
+    """Ensure obj is a signature, or None.
+
+    Arguments:
+        d (Optional[Union[abstract.CallableSignature, Mapping]]):
+            Signature or dict-serialized signature.
+        app (celery.Celery):
+            App to bind signature to.
+        clone (bool):
+            If d' is already a signature, the signature
+           will be cloned when this flag is enabled.
+
+    Returns:
+        Optional[abstract.CallableSignature]
+    """
     if d is not None:
-        if (isinstance(d, dict) and
-                not isinstance(d, abstract.CallableSignature)):
+        if isinstance(d, abstract.CallableSignature):
+            if clone:
+                d = d.clone()
+        elif isinstance(d, dict):
             d = signature(d)
+
         if app is not None:
             d._app = app
-        return d
+    return d
 
 maybe_subtask = maybe_signature  # XXX compat

+ 23 - 1
t/unit/tasks/test_canvas.py

@@ -17,6 +17,7 @@ from celery.canvas import (
     _maybe_group,
     maybe_signature,
     maybe_unroll_group,
+    _seq_concat_seq,
 )
 from celery.result import EagerResult
 
@@ -41,6 +42,18 @@ class test_maybe_unroll_group:
         assert maybe_unroll_group(g) is g
 
 
+@pytest.mark.parametrize('a,b,expected', [
+    ((1, 2, 3), [4, 5], (1, 2, 3, 4, 5)),
+    ((1, 2), [3, 4, 5], [1, 2, 3, 4, 5]),
+    ([1, 2, 3], (4, 5), [1, 2, 3, 4, 5]),
+    ([1, 2], (3, 4, 5), (1, 2, 3, 4, 5)),
+])
+def test_seq_concat_seq(a, b, expected):
+    res = _seq_concat_seq(a, b)
+    assert type(res) is type(expected)  # noqa
+    assert res == expected
+
+
 class CanvasCase:
 
     def setup(self):
@@ -349,13 +362,22 @@ class test_chain(CanvasCase):
         tasks2, _ = c2.prepare_steps((), c2.tasks)
         assert isinstance(tasks2[0], group)
 
-    def test_group_to_chord__protocol_2(self):
+    def test_group_to_chord__protocol_2__or(self):
         c = (
             group([self.add.s(i, i) for i in range(5)], app=self.app) |
             self.add.s(10) |
             self.add.s(20) |
             self.add.s(30)
         )
+        assert isinstance(c, chord)
+
+    def test_group_to_chord__protocol_2(self):
+        c = chain(
+            group([self.add.s(i, i) for i in range(5)], app=self.app),
+            self.add.s(10),
+            self.add.s(20),
+            self.add.s(30)
+        )
         c._use_link = False
         tasks, _ = c.prepare_steps((), c.tasks)
         assert isinstance(tasks[-1], chord)