Browse Source

Canvas: chain(X, Y, Z) is now the same as X | Y | Z (related to discussion in #3323)

Ask Solem 8 years ago
parent
commit
c718a6b50f
2 changed files with 64 additions and 34 deletions
  1. 52 23
      celery/canvas.py
  2. 12 11
      t/unit/tasks/test_canvas.py

+ 52 - 23
celery/canvas.py

@@ -7,13 +7,14 @@
 """
 from __future__ import absolute_import, unicode_literals
 
+import itertools
+import operator
 import sys
 
 from collections import MutableSequence, deque
 from copy import deepcopy
 from functools import partial as _partial, reduce
 from operator import itemgetter
-from itertools import chain as _chain
 
 from kombu.utils.functional import fxrange, reprcall
 from kombu.utils.objects import cached_property
@@ -241,9 +242,11 @@ class Signature(dict):
     _app = _type = None
 
     @classmethod
-    def register_type(cls, subclass, name=None):
-        cls.TYPES[name or subclass.__name__] = subclass
-        return subclass
+    def register_type(cls, name=None):
+        def _inner(subclass):
+            cls.TYPES[name or subclass.__name__] = subclass
+            return subclass
+        return _inner
 
     @classmethod
     def from_dict(cls, d, app=None):
@@ -473,34 +476,43 @@ class Signature(dict):
 
         "unchain" if you will, but with links intact.
         """
-        return list(_chain.from_iterable(_chain(
+        return list(itertools.chain.from_iterable(itertools.chain(
             [[self]],
             (link.flatten_links()
                 for link in maybe_list(self.options.get('link')) or [])
         )))
 
     def __or__(self, other):
+        # These could be implemented in each individual class,
+        # I'm sure, but for now we have this.
+        if isinstance(other, chord) and len(other.tasks) == 1:
+            # chord with one header -> header[0] | body
+            other = other.tasks[0] | other.body
         if isinstance(self, group):
             if isinstance(other, group):
                 # group() | group() -> single group
-                return group(_chain(self.tasks, other.tasks), app=self.app)
+                return group(itertools.chain(self.tasks, other.tasks), app=self.app)
             # group() | task -> chord
+            if len(self.tasks) == 1:
+                # group(ONE.s()) | other -> ONE.s() | other
+                # Issue #3323
+                return self.tasks[0] | other
             return chord(self, body=other, app=self._app)
         elif isinstance(other, group):
             # unroll group with one member
             other = maybe_unroll_group(other)
-            if isinstance(self, chain):
+            if isinstance(self, _chain):
                 # chain | group() -> chain
                 sig = self.clone()
                 sig.tasks.append(other)
                 return sig
             # task | group() -> chain
-            return chain(self, other, app=self.app)
-        if not isinstance(self, chain) and isinstance(other, chain):
+            return _chain(self, other, app=self.app)
+        if not isinstance(self, _chain) and isinstance(other, _chain):
             # task | chain -> chain
-            return chain(
+            return _chain(
                 _seq_concat_seq((self,), other.tasks), app=self._app)
-        elif isinstance(other, chain):
+        elif isinstance(other, _chain):
             # chain | chain -> chain
             sig = self.clone()
             if isinstance(sig.tasks, tuple):
@@ -508,12 +520,16 @@ class Signature(dict):
             sig.tasks.extend(other.tasks)
             return sig
         elif isinstance(self, chord):
+            # chord(ONE, body) | other -> ONE | body | other
+            # chord with one header task is unecessary.
+            if len(self.tasks) == 1:
+                return self.tasks[0] | self.body | other
             # chord | task ->  attach to body
             sig = self.clone()
             sig.body = sig.body | other
             return sig
         elif isinstance(other, Signature):
-            if isinstance(self, chain):
+            if isinstance(self, _chain):
                 if isinstance(self.tasks[-1], group):
                     # CHAIN [last item is group] | TASK -> chord
                     sig = self.clone()
@@ -527,10 +543,10 @@ class Signature(dict):
                     return sig
                 else:
                     # chain | task -> chain
-                    return chain(
+                    return _chain(
                         _seq_concat_item(self.tasks, other), app=self._app)
             # task | task -> chain
-            return chain(self, other, app=self._app)
+            return _chain(self, other, app=self._app)
         return NotImplemented
 
     def election(self):
@@ -611,9 +627,9 @@ class Signature(dict):
         'immutable', 'Flag set if no longer accepts new arguments')
 
 
-@Signature.register_type
+@Signature.register_type(name='chain')
 @python_2_unicode_compatible
-class chain(Signature):
+class _chain(Signature):
     """Chain tasks together.
 
     Each tasks follows one another,
@@ -869,6 +885,19 @@ class chain(Signature):
             ' | '.join(repr(t) for t in self.tasks))
 
 
+class chain(_chain):
+    # could be function, but must be able to reference as :class:`chain`.
+
+    def __new__(self, *tasks, **kwargs):
+        # This forces `chain(X, Y, Z)` to work the same way as `X | Y | Z`
+        if not kwargs and tasks:
+            if len(tasks) == 1 and is_list(tasks[0]):
+                # ensure chain(generator_expression) works.
+                tasks = tasks[0]
+            return reduce(operator.or_, tasks)
+        return super(chain, self).__new__(self, *tasks, **kwargs)
+
+
 class _basemap(Signature):
     _task_name = None
     _unpack_args = itemgetter('task', 'it')
@@ -894,7 +923,7 @@ class _basemap(Signature):
         )
 
 
-@Signature.register_type
+@Signature.register_type()
 @python_2_unicode_compatible
 class xmap(_basemap):
     """Map operation for tasks.
@@ -912,7 +941,7 @@ class xmap(_basemap):
             task.task, truncate(repr(it), 100))
 
 
-@Signature.register_type
+@Signature.register_type()
 @python_2_unicode_compatible
 class xstarmap(_basemap):
     """Map operation for tasks, using star arguments."""
@@ -925,7 +954,7 @@ class xstarmap(_basemap):
             task.task, truncate(repr(it), 100))
 
 
-@Signature.register_type
+@Signature.register_type()
 class chunks(Signature):
     """Partition of tasks in n chunks."""
 
@@ -979,7 +1008,7 @@ def _maybe_group(tasks, app):
     return tasks
 
 
-@Signature.register_type
+@Signature.register_type()
 @python_2_unicode_compatible
 class group(Signature):
     """Creates a group of tasks to be executed in parallel.
@@ -1220,7 +1249,7 @@ class group(Signature):
         return app if app is not None else current_app
 
 
-@Signature.register_type
+@Signature.register_type()
 @python_2_unicode_compatible
 class chord(Signature):
     r"""Barrier synchronization primitive.
@@ -1385,12 +1414,12 @@ class chord(Signature):
             if isinstance(self.body, chain):
                 return _shorten_names(
                     self.body.tasks[0]['task'],
-                    '({0} | {1!r})'.format(
+                    '%({0} | {1!r})'.format(
                         self.body.tasks[0].reprcall(self.tasks),
                         chain(self.body.tasks[1:], app=self._app),
                     ),
                 )
-            return _shorten_names(
+            return '%' + _shorten_names(
                 self.body['task'], self.body.reprcall(self.tasks))
         return '<chord without body: {0.tasks!r}>'.format(self)
 

+ 12 - 11
t/unit/tasks/test_canvas.py

@@ -5,6 +5,7 @@ from celery._state import _task_stack
 from celery.canvas import (
     Signature,
     chain,
+    _chain,
     group,
     chord,
     signature,
@@ -147,16 +148,16 @@ class test_Signature(CanvasCase):
 
     def test_OR(self):
         x = self.add.s(2, 2) | self.mul.s(4)
-        assert isinstance(x, chain)
+        assert isinstance(x, _chain)
         y = self.add.s(4, 4) | self.div.s(2)
         z = x | y
-        assert isinstance(y, chain)
-        assert isinstance(z, chain)
+        assert isinstance(y, _chain)
+        assert isinstance(z, _chain)
         assert len(z.tasks) == 4
         with pytest.raises(TypeError):
             x | 10
         ax = self.add.s(2, 2) | (self.add.s(4) | self.add.s(8))
-        assert isinstance(ax, chain)
+        assert isinstance(ax, _chain)
         assert len(ax.tasks), 3 == 'consolidates chain to chain'
 
     def test_INVERT(self):
@@ -329,9 +330,9 @@ class test_chain(CanvasCase):
             self.add.s(20),
             self.add.s(30)
         )
-        c._use_link = False
-        tasks, _ = c.prepare_steps((), c.tasks)
-        assert isinstance(tasks[-1], chord)
+        assert isinstance(c, chord)
+        assert isinstance(c.body, _chain)
+        assert len(c.body.tasks) == 3
 
         c2 = self.add.s(2, 2) | group(self.add.s(i, i) for i in range(10))
         c2._use_link = False
@@ -367,8 +368,8 @@ class test_chain(CanvasCase):
 
     def test_reverse(self):
         x = self.add.s(2, 2) | self.add.s(2)
-        assert isinstance(signature(x), chain)
-        assert isinstance(signature(dict(x)), chain)
+        assert isinstance(signature(x), _chain)
+        assert isinstance(signature(dict(x)), _chain)
 
     def test_always_eager(self):
         self.app.conf.task_always_eager = True
@@ -401,9 +402,9 @@ class test_chain(CanvasCase):
     def test_from_dict_no_args__with_args(self):
         x = dict(self.add.s(2, 2) | self.add.s(4))
         x['args'] = None
-        assert isinstance(chain.from_dict(x), chain)
+        assert isinstance(chain.from_dict(x), _chain)
         x['args'] = (2,)
-        assert isinstance(chain.from_dict(x), chain)
+        assert isinstance(chain.from_dict(x), _chain)
 
     def test_accepts_generator_argument(self):
         x = chain(self.add.s(i) for i in range(10))