Browse Source

passing callbacks to accumulate when replacing tasks #2722

nicolasunravel 8 years ago
parent
commit
e89f4ca864
2 changed files with 45 additions and 5 deletions
  1. 2 0
      celery/app/task.py
  2. 43 5
      celery/tests/tasks/test_tasks.py

+ 2 - 0
celery/app/task.py

@@ -777,6 +777,8 @@ class Task(object):
         if isinstance(sig, group):
             sig |= self.app.tasks['celery.accumulate'].s(index=0).set(
                 chord=chord,
+                link=self.request.callbacks,
+                link_error=self.request.errbacks,
             )
             chord = None
         sig.freeze(self.request.id,

+ 43 - 5
celery/tests/tasks/test_tasks.py

@@ -48,11 +48,13 @@ class TasksCase(AppCase):
         def increment_counter(self, increment_by=1):
             self.count += increment_by or 1
             return self.count
+
         self.increment_counter = increment_counter
 
         @self.app.task(shared=False)
         def raising():
             raise KeyError('foo')
+
         self.raising = raising
 
         @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
@@ -66,6 +68,7 @@ class TasksCase(AppCase):
                 return arg1
             else:
                 raise self.retry(countdown=0, max_retries=rmax)
+
         self.retry_task = retry_task
 
         @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
@@ -76,6 +79,7 @@ class TasksCase(AppCase):
                 return 42
             else:
                 raise self.retry(countdown=0)
+
         self.retry_task_noargs = retry_task_noargs
 
         @self.app.task(bind=True, max_retries=3, iterations=0,
@@ -87,6 +91,7 @@ class TasksCase(AppCase):
             if retries >= 3:
                 return arg1
             raise self.retry(countdown=0)
+
         self.retry_task_mockapply = retry_task_mockapply
 
         @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
@@ -102,20 +107,23 @@ class TasksCase(AppCase):
                 except MyCustomException as exc:
                     kwargs.update(kwarg=kwarg)
                     raise self.retry(countdown=0, exc=exc)
+
         self.retry_task_customexc = retry_task_customexc
 
         @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
                        shared=False)
         def autoretry_task_no_kwargs(self, a, b):
             self.iterations += 1
-            return a/b
+            return a / b
+
         self.autoretry_task_no_kwargs = autoretry_task_no_kwargs
 
         @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
                        retry_kwargs={'max_retries': 5}, shared=False)
         def autoretry_task(self, a, b):
             self.iterations += 1
-            return a/b
+            return a / b
+
         self.autoretry_task = autoretry_task
 
 
@@ -271,13 +279,16 @@ class test_tasks(TasksCase):
         @self.app.task(shared=True)
         def xxx():
             pass
+
         self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
 
     @patch('celery.app.task.current_app')
     @depends_on_current_app
     def test_bind__no_app(self, current_app):
+
         class XTask(Task):
             _app = None
+
         XTask._app = None
         XTask.__bound__ = False
         XTask.bind = Mock(name='bind')
@@ -400,15 +411,41 @@ class test_tasks(TasksCase):
         mytask.request.id = 'fb'
         mytask.send_event('task-foo', id=3122)
         mytask.app.events.default_dispatcher().send.assert_called_with(
-            'task-foo', uuid='fb', id=3122,
-        )
+            'task-foo', uuid='fb', id=3122)
 
     def test_replace(self):
         sig1 = Mock(name='sig1')
         with self.assertRaises(Ignore):
             self.mytask.replace(sig1)
 
-    def test_replace__group(self):
+    def test_replace_callback(self):
+        c = group([self.mytask.s()], app=self.app)
+        c.freeze = Mock(name='freeze')
+        c.delay = Mock(name='delay')
+        self.mytask.request.id = 'id'
+        self.mytask.request.group = 'group'
+        self.mytask.request.root_id = 'root_id'
+        self.mytask.request.callbacks = 'callbacks'
+
+        class TaskMock(Mock):
+            def __json__(self):
+                return "whatever"
+
+        mocked_signature = TaskMock(name='s',
+                                    # side_effect=mocked_s
+                                    )
+        accumulate_mock = TaskMock(name='accumulate',
+                                   s=mocked_signature)
+        self.mytask.app.tasks['celery.accumulate'] = accumulate_mock
+
+        try:
+            self.mytask.replace(c)
+        except Ignore:
+            mocked_signature.return_value.set. \
+                assert_called_with(chord=None,
+                                   link="callbacks")
+
+    def test_replace_group(self):
         c = group([self.mytask.s()], app=self.app)
         c.freeze = Mock(name='freeze')
         c.delay = Mock(name='delay')
@@ -466,6 +503,7 @@ class test_tasks(TasksCase):
             @self.app.task(shared=False)
             def task():
                 pass
+
             task.annotate()
             self.assertEqual(task.FOO, 'BAR')