Przeglądaj źródła

Adds Task.add_to_chord and Task.replace_in_chord

Ask Solem 10 lat temu
rodzic
commit
133f2a1aec

+ 28 - 1
celery/app/task.py

@@ -16,7 +16,7 @@ from celery import current_app
 from celery import states
 from celery._state import _task_stack
 from celery.canvas import signature
-from celery.exceptions import MaxRetriesExceededError, Reject, Retry
+from celery.exceptions import Ignore, MaxRetriesExceededError, Reject, Retry
 from celery.five import class_property, items
 from celery.result import EagerResult
 from celery.utils import uuid, maybe_reraise
@@ -686,6 +686,33 @@ class Task(object):
         with self.app.events.default_dispatcher(hostname=req.hostname) as d:
             return d.send(type_, uuid=req.id, **fields)
 
+    def replace_in_chord(self, sig):
+        sig.freeze(self.request.id,
+                   group_id=self.request.group,
+                   chord=self.request.chord,
+                   root_id=self.request.root_id)
+        sig.delay()
+        raise Ignore('Chord member replaced by new task')
+
+    def add_to_chord(self, sig, lazy=False):
+        """Add signature to the chord the current task is a member of.
+
+        :param sig: Signature to extend chord with.
+        :param lazy: If enabled the new task will not actually be called,
+                      and ``sig.delay()`` must be called manually.
+
+        Currently only supported by the Redis result backend when
+        ``?new_join=1`` is enabled.
+
+        """
+        if not self.request.chord:
+            raise ValueError('Current task is not member of any chord')
+        result = sig.freeze(group_id=self.request.group,
+                            chord=self.request.chord,
+                            root_id=self.request.root_id)
+        self.backend.add_to_chord(self.request.group, result)
+        return sig.delay() if not lazy else sig
+
     def update_state(self, task_id=None, state=None, meta=None):
         """Update task state.
 

+ 3 - 0
celery/backends/base.py

@@ -335,6 +335,9 @@ class BaseBackend(object):
     def on_task_call(self, producer, task_id):
         return {}
 
+    def add_to_chord(self, chord_id, result):
+        raise NotImplementedError('Backend does not support add_to_chord')
+
     def on_chord_part_return(self, task, state, result, propagate=False):
         pass
 

+ 12 - 3
celery/backends/redis.py

@@ -177,6 +177,9 @@ class RedisBackend(KeyValueStoreBackend):
     def expire(self, key, value):
         return self.client.expire(key, value)
 
+    def add_to_chord(self, group_id, result):
+        self.client.incr(self.get_key_for_group(group_id, '.t'), 1)
+
     def _unpack_chord_result(self, tup, decode,
                              PROPAGATE_STATES=states.PROPAGATE_STATES):
         _, tid, state, retval = decode(tup)
@@ -201,21 +204,27 @@ class RedisBackend(KeyValueStoreBackend):
 
         client = self.client
         jkey = self.get_key_for_group(gid, '.j')
+        tkey = self.get_key_for_group(gid, '.t')
         result = self.encode_result(result, state)
-        _, readycount, _ = client.pipeline()                            \
+        _, readycount, totaldiff, _, _ = client.pipeline()              \
             .rpush(jkey, self.encode([1, tid, state, result]))          \
             .llen(jkey)                                                 \
+            .get(tkey)                                                  \
             .expire(jkey, 86400)                                        \
+            .expire(tkey, 86400)                                        \
             .execute()
 
+        totaldiff = int(totaldiff or 0)
+
         try:
             callback = maybe_signature(request.chord, app=app)
-            total = callback['chord_size']
+            total = callback['chord_size'] + totaldiff
             if readycount >= total:
                 decode, unpack = self.decode, self._unpack_chord_result
-                resl, _ = client.pipeline()     \
+                resl, _, _ = client.pipeline()  \
                     .lrange(jkey, 0, total)     \
                     .delete(jkey)               \
+                    .delete(tkey)               \
                     .execute()
                 try:
                     callback.delay([unpack(tup, decode) for tup in resl])

+ 5 - 0
celery/tests/backends/test_base.py

@@ -188,6 +188,11 @@ class test_BaseBackend_dict(AppCase):
         b.save_group('foofoo', 'xxx')
         b._save_group.assert_called_with('foofoo', 'xxx')
 
+    def test_add_to_chord_interface(self):
+        b = BaseBackend(self.app)
+        with self.assertRaises(NotImplementedError):
+            b.add_to_chord('group_id', 'sig')
+
     def test_forget_interface(self):
         b = BaseBackend(self.app)
         with self.assertRaises(NotImplementedError):

+ 14 - 4
celery/tests/backends/test_redis.py

@@ -12,7 +12,8 @@ from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 
 from celery.tests.case import (
-    AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
+    AppCase, Mock, MockCallbacks, SkipTest,
+    call, depends_on_current_app, patch,
 )
 
 
@@ -194,6 +195,12 @@ class test_RedisBackend(AppCase):
         self.assertEqual(b.on_chord_part_return, b._new_chord_return)
         self.assertEqual(b.apply_chord, b._new_chord_apply)
 
+    def test_add_to_chord(self):
+        b = self.Backend('redis://?new_join=True', app=self.app)
+        gid = uuid()
+        b.add_to_chord(gid, 'sig')
+        b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
+
     def test_default_is_old_join(self):
         b = self.Backend(app=self.app)
         self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
@@ -250,9 +257,12 @@ class test_RedisBackend(AppCase):
             self.assertTrue(b.client.rpush.call_count)
             b.client.rpush.reset_mock()
         self.assertTrue(b.client.lrange.call_count)
-        gkey = b.get_key_for_group('group_id', '.j')
-        b.client.delete.assert_called_with(gkey)
-        b.client.expire.assert_called_witeh(gkey, 86400)
+        jkey = b.get_key_for_group('group_id', '.j')
+        tkey = b.get_key_for_group('group_id', '.t')
+        b.client.delete.assert_has_calls([call(jkey), call(tkey)])
+        b.client.expire.assert_has_calls([
+            call(jkey, 86400), call(tkey, 86400),
+        ])
 
     def test_process_cleanup(self):
         self.Backend(app=self.app, new_join=True).process_cleanup()

+ 52 - 1
celery/tests/tasks/test_chord.py

@@ -2,7 +2,7 @@ from __future__ import absolute_import
 
 from contextlib import contextmanager
 
-from celery import group
+from celery import group, uuid
 from celery import canvas
 from celery import result
 from celery.exceptions import ChordError, Retry
@@ -219,6 +219,57 @@ class test_chord(ChordCase):
             chord.run = prev
 
 
+class test_add_to_chord(AppCase):
+
+    def setup(self):
+
+        @self.app.task(shared=False)
+        def add(x, y):
+            return x + y
+        self.add = add
+
+        @self.app.task(shared=False, bind=True)
+        def adds(self, sig, lazy=False):
+            return self.add_to_chord(sig, lazy)
+        self.adds = adds
+
+    def test_add_to_chord(self):
+        self.app.backend = Mock(name='backend')
+
+        sig = self.add.s(2, 2)
+        sig.delay = Mock(name='sig.delay')
+        self.adds.request.group = uuid()
+        self.adds.request.id = uuid()
+
+        with self.assertRaises(ValueError):
+            # task not part of chord
+            self.adds.run(sig)
+        self.adds.request.chord = self.add.s()
+
+        res1 = self.adds.run(sig, True)
+        self.assertEqual(res1, sig)
+        self.assertTrue(sig.options['task_id'])
+        self.assertEqual(sig.options['group_id'], self.adds.request.group)
+        self.assertEqual(sig.options['chord'], self.adds.request.chord)
+        self.assertFalse(sig.delay.called)
+        self.app.backend.add_to_chord.assert_called_with(
+            self.adds.request.group, sig.freeze(),
+        )
+
+        self.app.backend.reset_mock()
+        sig2 = self.add.s(4, 4)
+        sig2.delay = Mock(name='sig2.delay')
+        res2 = self.adds.run(sig2)
+        self.assertEqual(res2, sig2.delay.return_value)
+        self.assertTrue(sig2.options['task_id'])
+        self.assertEqual(sig2.options['group_id'], self.adds.request.group)
+        self.assertEqual(sig2.options['chord'], self.adds.request.chord)
+        sig2.delay.assert_called_with()
+        self.app.backend.add_to_chord.assert_called_with(
+            self.adds.request.group, sig2.freeze(),
+        )
+
+
 class test_Chord_task(ChordCase):
 
     def test_run(self):

+ 11 - 0
funtests/stress/stress/app.py

@@ -121,6 +121,17 @@ def segfault():
     assert False, 'should not get here'
 
 
+@app.task(bind=True)
+def chord_adds(self, x):
+    self.add_to_chord(add.s(x, x))
+    return 42
+
+
+@app.task(bind=True)
+def chord_replace(self, x):
+    return self.replace_in_chord(add.s(x, x))
+
+
 @app.task
 def raising(exc=KeyError()):
     raise exc

+ 1 - 1
funtests/stress/stress/templates.py

@@ -91,7 +91,7 @@ class redis(default):
 
 @template()
 class redistore(default):
-    CELERY_RESULT_BACKEND = 'redis://'
+    CELERY_RESULT_BACKEND = 'redis://?new_join=1'
 
 
 @template()