Преглед изворни кода

chain.clone lost tasks when argument to chain is generator

Ask Solem пре 8 година
родитељ
комит
d7dad441ca
2 измењених фајлова са 12 додато и 2 уклоњено
  1. 6 2
      celery/canvas.py
  2. 6 0
      t/unit/tasks/test_canvas.py

+ 6 - 2
celery/canvas.py

@@ -748,7 +748,7 @@ class chain(Signature):
                 tasks = d['kwargs']['tasks'] = list(tasks)
             # First task must be signature object to get app
             tasks[0] = maybe_signature(tasks[0], app=app)
-        return _upgrade(d, chain(*tasks, app=app, **d['options']))
+        return _upgrade(d, chain(tasks, app=app, **d['options']))
 
     @property
     def app(self):
@@ -761,6 +761,9 @@ class chain(Signature):
         return app or current_app
 
     def __repr__(self):
+        if not self.tasks:
+            return '<{0}@{1:#x}: empty>'.format(
+                type(self).__name__, id(self))
         return ' | '.join(repr(t) for t in self.tasks)
 
 
@@ -865,7 +868,7 @@ def _maybe_group(tasks, app):
     if isinstance(tasks, dict):
         tasks = signature(tasks, app=app)
 
-    if isinstance(tasks, group):
+    if isinstance(tasks, (group, chain)):
         tasks = tasks.tasks
     elif isinstance(tasks, abstract.CallableSignature):
         tasks = [tasks]
@@ -1199,6 +1202,7 @@ class chord(Signature):
     def apply_async(self, args=(), kwargs={}, task_id=None,
                     producer=None, publisher=None, connection=None,
                     router=None, result_cls=None, **options):
+        kwargs = kwargs or {}
         args = (tuple(args) + tuple(self.args)
                 if args and not self.immutable else self.args)
         body = kwargs.pop('body', None) or self.kwargs['body']

+ 6 - 0
t/unit/tasks/test_canvas.py

@@ -232,6 +232,12 @@ class test_chunks(CanvasCase):
 
 class test_chain(CanvasCase):
 
+    def test_clone_preserves_state(self):
+        x = chain(self.add.s(i, i) for i in range(10))
+        assert x.clone().tasks == x.tasks
+        assert x.clone().kwargs == x.kwargs
+        assert x.clone().args == x.args
+
     def test_repr(self):
         x = self.add.s(2, 2) | self.add.s(2)
         assert repr(x) == '%s(2, 2) | %s(2)' % (self.add.name, self.add.name)