Browse Source

Adds Task.add_to_chord and Task.replace_in_chord

Ask Solem 10 years ago
parent
commit
133f2a1aec

+ 28 - 1
celery/app/task.py

@@ -16,7 +16,7 @@ from celery import current_app
 from celery import states
 from celery import states
 from celery._state import _task_stack
 from celery._state import _task_stack
 from celery.canvas import signature
 from celery.canvas import signature
-from celery.exceptions import MaxRetriesExceededError, Reject, Retry
+from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.five import class_property, items
 from celery.five import class_property, items
 from celery.result import EagerResult
 from celery.result import EagerResult
 from celery.utils import uuid, maybe_reraise
 from celery.utils import uuid, maybe_reraise
@@ -686,6 +686,33 @@ class Task(object):
         with self.app.events.default_dispatcher(hostname=req.hostname) as d:
         with self.app.events.default_dispatcher(hostname=req.hostname) as d:
             return d.send(type_, uuid=req.id, **fields)
             return d.send(type_, uuid=req.id, **fields)
 
 
+    def replace_in_chord(self, sig):
+        sig.freeze(self.request.id,
+                   group_id=self.request.group,
+                   chord=self.request.chord,
+                   root_id=self.request.root_id)
+        sig.delay()
+        raise Ignore('Chord member replaced by new task')
+
+    def add_to_chord(self, sig, lazy=False):
+        """Add signature to the chord the current task is a member of.
+
+        :param sig: Signature to extend chord with.
+        :param lazy: If enabled the new task will not actually be called,
+                      and ``sig.delay()`` must be called manually.
+
+        Currently only supported by the Redis result backend when
+        ``?new_join=1`` is enabled.
+
+        """
+        if not self.request.chord:
+            raise ValueError('Current task is not member of any chord')
+        result = sig.freeze(group_id=self.request.group,
+                            chord=self.request.chord,
+                            root_id=self.request.root_id)
+        self.backend.add_to_chord(self.request.group, result)
+        return sig.delay() if not lazy else sig
+
     def update_state(self, task_id=None, state=None, meta=None):
     def update_state(self, task_id=None, state=None, meta=None):
         """Update task state.
         """Update task state.
 
 

+ 3 - 0
celery/backends/base.py

@@ -335,6 +335,9 @@ class BaseBackend(object):
     def on_task_call(self, producer, task_id):
     def on_task_call(self, producer, task_id):
         return {}
         return {}
 
 
+    def add_to_chord(self, chord_id, result):
+        raise NotImplementedError('Backend does not support add_to_chord')
+
     def on_chord_part_return(self, task, state, result, propagate=False):
     def on_chord_part_return(self, task, state, result, propagate=False):
         pass
         pass
 
 

+ 12 - 3
celery/backends/redis.py

@@ -177,6 +177,9 @@ class RedisBackend(KeyValueStoreBackend):
     def expire(self, key, value):
     def expire(self, key, value):
         return self.client.expire(key, value)
         return self.client.expire(key, value)
 
 
+    def add_to_chord(self, group_id, result):
+        self.client.incr(self.get_key_for_group(group_id, '.t'), 1)
+
     def _unpack_chord_result(self, tup, decode,
     def _unpack_chord_result(self, tup, decode,
                              PROPAGATE_STATES=states.PROPAGATE_STATES):
                              PROPAGATE_STATES=states.PROPAGATE_STATES):
         _, tid, state, retval = decode(tup)
         _, tid, state, retval = decode(tup)
@@ -201,21 +204,27 @@ class RedisBackend(KeyValueStoreBackend):
 
 
         client = self.client
         client = self.client
         jkey = self.get_key_for_group(gid, '.j')
         jkey = self.get_key_for_group(gid, '.j')
+        tkey = self.get_key_for_group(gid, '.t')
         result = self.encode_result(result, state)
         result = self.encode_result(result, state)
-        _, readycount, _ = client.pipeline()                            \
+        _, readycount, totaldiff, _, _ = client.pipeline()              \
             .rpush(jkey, self.encode([1, tid, state, result]))          \
             .rpush(jkey, self.encode([1, tid, state, result]))          \
             .llen(jkey)                                                 \
             .llen(jkey)                                                 \
+            .get(tkey)                                                  \
             .expire(jkey, 86400)                                        \
             .expire(jkey, 86400)                                        \
+            .expire(tkey, 86400)                                        \
             .execute()
             .execute()
 
 
+        totaldiff = int(totaldiff or 0)
+
         try:
         try:
             callback = maybe_signature(request.chord, app=app)
             callback = maybe_signature(request.chord, app=app)
-            total = callback['chord_size']
+            total = callback['chord_size'] + totaldiff
             if readycount >= total:
             if readycount >= total:
                 decode, unpack = self.decode, self._unpack_chord_result
                 decode, unpack = self.decode, self._unpack_chord_result
-                resl, _ = client.pipeline()     \
+                resl, _, _ = client.pipeline()  \
                     .lrange(jkey, 0, total)     \
                     .lrange(jkey, 0, total)     \
                     .delete(jkey)               \
                     .delete(jkey)               \
+                    .delete(tkey)               \
                     .execute()
                     .execute()
                 try:
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])
                     callback.delay([unpack(tup, decode) for tup in resl])

+ 5 - 0
celery/tests/backends/test_base.py

@@ -188,6 +188,11 @@ class test_BaseBackend_dict(AppCase):
         b.save_group('foofoo', 'xxx')
         b.save_group('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
 
 
+    def test_add_to_chord_interface(self):
+        b = BaseBackend(self.app)
+        with self.assertRaises(NotImplementedError):
+            b.add_to_chord('group_id', 'sig')
+
     def test_forget_interface(self):
     def test_forget_interface(self):
         b = BaseBackend(self.app)
         b = BaseBackend(self.app)
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):

+ 14 - 4
celery/tests/backends/test_redis.py

@@ -12,7 +12,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
 from celery.tests.case import (
 from celery.tests.case import (
-    AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
+    AppCase, Mock, MockCallbacks, SkipTest,
+    call, depends_on_current_app, patch,
 )
 )
 
 
 
 
@@ -194,6 +195,12 @@ class test_RedisBackend(AppCase):
         self.assertEqual(b.on_chord_part_return, b._new_chord_return)
         self.assertEqual(b.on_chord_part_return, b._new_chord_return)
         self.assertEqual(b.apply_chord, b._new_chord_apply)
         self.assertEqual(b.apply_chord, b._new_chord_apply)
 
 
+    def test_add_to_chord(self):
+        b = self.Backend('redis://?new_join=True', app=self.app)
+        gid = uuid()
+        b.add_to_chord(gid, 'sig')
+        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
+
     def test_default_is_old_join(self):
     def test_default_is_old_join(self):
         b = self.Backend(app=self.app)
         b = self.Backend(app=self.app)
         self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
         self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
@@ -250,9 +257,12 @@ class test_RedisBackend(AppCase):
             self.assertTrue(b.client.rpush.call_count)
             self.assertTrue(b.client.rpush.call_count)
             b.client.rpush.reset_mock()
             b.client.rpush.reset_mock()
         self.assertTrue(b.client.lrange.call_count)
         self.assertTrue(b.client.lrange.call_count)
-        gkey = b.get_key_for_group('group_id', '.j')
-        b.client.delete.assert_called_with(gkey)
-        b.client.expire.assert_called_witeh(gkey, 86400)
+        jkey = b.get_key_for_group('group_id', '.j')
+        tkey = b.get_key_for_group('group_id', '.t')
+        b.client.delete.assert_has_calls([call(jkey), call(tkey)])
+        b.client.expire.assert_has_calls([
+            call(jkey, 86400), call(tkey, 86400),
+        ])
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         self.Backend(app=self.app, new_join=True).process_cleanup()
         self.Backend(app=self.app, new_join=True).process_cleanup()

+ 52 - 1
celery/tests/tasks/test_chord.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
-from celery import group
+from celery import group, uuid
 from celery import canvas
 from celery import canvas
 from celery import result
 from celery import result
 from celery.exceptions import ChordError, Retry
 from celery.exceptions import ChordError, Retry
@@ -219,6 +219,57 @@ class test_chord(ChordCase):
             chord.run = prev
             chord.run = prev
 
 
 
 
+class test_add_to_chord(AppCase):
+
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
+        @self.app.task(shared=False, bind=True)
+        def adds(self, sig, lazy=False):
+            return self.add_to_chord(sig, lazy)
+        self.adds = adds
+
+    def test_add_to_chord(self):
+        self.app.backend = Mock(name='backend')
+
+        sig = self.add.s(2, 2)
+        sig.delay = Mock(name='sig.delay')
+        self.adds.request.group = uuid()
+        self.adds.request.id = uuid()
+
+        with self.assertRaises(ValueError):
+            # task not part of chord
+            self.adds.run(sig)
+        self.adds.request.chord = self.add.s()
+
+        res1 = self.adds.run(sig, True)
+        self.assertEqual(res1, sig)
+        self.assertTrue(sig.options['task_id'])
+        self.assertEqual(sig.options['group_id'], self.adds.request.group)
+        self.assertEqual(sig.options['chord'], self.adds.request.chord)
+        self.assertFalse(sig.delay.called)
+        self.app.backend.add_to_chord.assert_called_with(
+            self.adds.request.group, sig.freeze(),
+        )
+
+        self.app.backend.reset_mock()
+        sig2 = self.add.s(4, 4)
+        sig2.delay = Mock(name='sig2.delay')
+        res2 = self.adds.run(sig2)
+        self.assertEqual(res2, sig2.delay.return_value)
+        self.assertTrue(sig2.options['task_id'])
+        self.assertEqual(sig2.options['group_id'], self.adds.request.group)
+        self.assertEqual(sig2.options['chord'], self.adds.request.chord)
+        sig2.delay.assert_called_with()
+        self.app.backend.add_to_chord.assert_called_with(
+            self.adds.request.group, sig2.freeze(),
+        )
+
+
 class test_Chord_task(ChordCase):
 class test_Chord_task(ChordCase):
 
 
     def test_run(self):
     def test_run(self):

+ 11 - 0
funtests/stress/stress/app.py

@@ -121,6 +121,17 @@ def segfault():
     assert False, 'should not get here'
     assert False, 'should not get here'
 
 
 
 
+@app.task(bind=True)
+def chord_adds(self, x):
+    self.add_to_chord(add.s(x, x))
+    return 42
+
+
+@app.task(bind=True)
+def chord_replace(self, x):
+    return self.replace_in_chord(add.s(x, x))
+
+
 @app.task
 @app.task
 def raising(exc=KeyError()):
 def raising(exc=KeyError()):
     raise exc
     raise exc

+ 1 - 1
funtests/stress/stress/templates.py

@@ -91,7 +91,7 @@ class redis(default):
 
 
 @template()
 @template()
 class redistore(default):
 class redistore(default):
-    CELERY_RESULT_BACKEND = 'redis://'
+    CELERY_RESULT_BACKEND = 'redis://?new_join=1'
 
 
 
 
 @template()
 @template()