ソースを参照

Redis backend chord optimization: Avoid save group at apply and pipeline + O(1) join.

This change is backward incompatible and so is not enabled by default.

To enable this optimization you have to set the `new_join` option
and it must be enabled by all clients and workers part of the chord::

    redis://?new_join=1
Ask Solem 11 年 前
コミット
f09b0413aa

+ 1 - 0
celery/app/builtins.py

@@ -352,6 +352,7 @@ def add_chord_task(app):
             if eager:
                 return header.apply(args=partial_args, task_id=group_id)
 
+            body.setdefault('chord_size', len(header.tasks))
             results = [AsyncResult(prepare_member(task, body, group_id))
                        for task in header.tasks]
 

+ 1 - 1
celery/app/trace.py

@@ -272,7 +272,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                 # -* POST *-
                 if state not in IGNORE_STATES:
                     if task_request.chord:
-                        on_chord_part_return(task)
+                        on_chord_part_return(task, state, R)
                     if task_after_return:
                         task_after_return(
                             state, retval, uuid, args, kwargs, None,

+ 17 - 8
celery/backends/base.py

@@ -311,7 +311,7 @@ class BaseBackend(object):
     def on_task_call(self, producer, task_id):
         return {}
 
-    def on_chord_part_return(self, task, propagate=False):
+    def on_chord_part_return(self, task, state, result, propagate=False):
         pass
 
     def fallback_chord_unlock(self, group_id, body, result=None,
@@ -374,17 +374,26 @@ class KeyValueStoreBackend(BaseBackend):
     def expire(self, key, value):
         pass
 
-    def get_key_for_task(self, task_id):
+    def get_key_for_task(self, task_id, key=''):
         """Get the cache key for a task by id."""
-        return self.task_keyprefix + self.key_t(task_id)
+        key_t = self.key_t
+        return ''.join([
+            self.task_keyprefix, key_t(task_id), key_t(key),
+        ])
 
-    def get_key_for_group(self, group_id):
+    def get_key_for_group(self, group_id, key=''):
         """Get the cache key for a group by id."""
-        return self.group_keyprefix + self.key_t(group_id)
+        key_t = self.key_t
+        return ''.join([
+            self.group_keyprefix, key_t(group_id), key_t(key),
+        ])
 
-    def get_key_for_chord(self, group_id):
+    def get_key_for_chord(self, group_id, key=''):
         """Get the cache key for the chord waiting on group with given id."""
-        return self.chord_keyprefix + self.key_t(group_id)
+        key_t = self.key_t
+        return ''.join([
+            self.chord_keyprefix, key_t(group_id), key_t(key),
+        ])
 
     def _strip_prefix(self, key):
         """Takes bytes, emits string."""
@@ -479,7 +488,7 @@ class KeyValueStoreBackend(BaseBackend):
         self.save_group(group_id, self.app.GroupResult(group_id, result))
         return header(*partial_args, task_id=group_id)
 
-    def on_chord_part_return(self, task, propagate=None):
+    def on_chord_part_return(self, task, state, result, propagate=None):
         if not self.implements_incr:
             return
         app = self.app

+ 69 - 3
celery/backends/redis.py

@@ -13,9 +13,11 @@ from functools import partial
 from kombu.utils import cached_property, retry_over_time
 from kombu.utils.url import _parse_url
 
-from celery.exceptions import ImproperlyConfigured
+from celery import states
+from celery.canvas import maybe_signature
+from celery.exceptions import ChordError, ImproperlyConfigured
 from celery.five import string_t
-from celery.utils import deprecated_property
+from celery.utils import deprecated_property, strtobool
 from celery.utils.functional import dictfilter
 from celery.utils.log import get_logger
 from celery.utils.timeutils import humanize_seconds
@@ -56,7 +58,7 @@ class RedisBackend(KeyValueStoreBackend):
 
     def __init__(self, host=None, port=None, db=None, password=None,
                  expires=None, max_connections=None, url=None,
-                 connection_pool=None, **kwargs):
+                 connection_pool=None, new_join=False, **kwargs):
         super(RedisBackend, self).__init__(**kwargs)
         conf = self.app.conf
         if self.redis is None:
@@ -90,6 +92,14 @@ class RedisBackend(KeyValueStoreBackend):
         self.url = url
         self.expires = self.prepare_expires(expires, type=int)
 
+        try:
+            new_join = strtobool(self.connparams.pop('new_join'))
+        except KeyError:
+            pass
+        if new_join:
+            self.apply_chord = self._new_chord_apply
+            self.on_chord_part_return = self._new_chord_return
+
         self.connection_errors, self.channel_errors = get_redis_error_classes()
 
     def _params_from_url(self, url, defaults):
@@ -165,6 +175,62 @@ class RedisBackend(KeyValueStoreBackend):
     def expire(self, key, value):
         return self.client.expire(key, value)
 
+    def _unpack_chord_result(self, tup, decode,
+                             PROPAGATE_STATES=states.PROPAGATE_STATES):
+        _, tid, state, retval = decode(tup)
+        if state in PROPAGATE_STATES:
+            raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
+        return retval
+
+    def _new_chord_apply(self, header, partial_args, group_id, body,
+                         result=None, **options):
+        # avoids saving the group in the redis db.
+        return header(*partial_args, task_id=group_id)
+
+    def _new_chord_return(self, task, state, result, propagate=None,
+                          PROPAGATE_STATES=states.PROPAGATE_STATES):
+        app = self.app
+        if propagate is None:
+            propagate = self.app.conf.CELERY_CHORD_PROPAGATES
+        request = task.request
+        tid, gid = request.id, request.group
+        if not gid or not tid:
+            return
+
+        client = self.client
+        jkey = self.get_key_for_group(gid, '.j')
+        result = self.encode_result(result, state)
+        _, readycount, _ = client.pipeline()                            \
+            .rpush(jkey, self.encode([1, tid, state, result]))          \
+            .llen(jkey)                                                 \
+            .expire(jkey, 86400)                                        \
+            .execute()
+
+        try:
+            callback = maybe_signature(request.chord, app=app)
+            total = callback['chord_size']
+            if readycount >= total:
+                decode, unpack = self.decode, self._unpack_chord_result
+                resl, _ = client.pipeline()     \
+                    .lrange(jkey, 0, total)     \
+                    .delete(jkey)               \
+                    .execute()
+                try:
+                    callback.delay([unpack(tup, decode) for tup in resl])
+                except Exception as exc:
+                    app._tasks[callback.task].backend.fail_from_current_stack(
+                        callback.id,
+                        exc=ChordError('Callback error: {0!r}'.format(exc)),
+                    )
+        except ChordError as exc:
+            app._tasks[callback.task].backend.fail_from_current_stack(
+                callback.id, exc=exc,
+            )
+        except Exception as exc:
+            app._tasks[callback.task].backend.fail_from_current_stack(
+                callback.id, exc=ChordError('Join error: {0!r}').format(exc),
+            )
+
     @property
     def ConnectionPool(self):
         if self._ConnectionPool is None:

+ 1 - 1
celery/tests/app/test_app.py

@@ -644,7 +644,7 @@ class test_App(AppCase):
 
 class test_defaults(AppCase):
 
-    def test_str_to_bool(self):
+    def test_strtobool(self):
         for s in ('false', 'no', '0'):
             self.assertFalse(defaults.strtobool(s))
         for s in ('true', 'yes', '1'):

+ 9 - 7
celery/tests/backends/test_base.py

@@ -62,7 +62,7 @@ class test_BaseBackend_interface(AppCase):
             self.b.forget('SOMExx-N0nex1stant-IDxx-')
 
     def test_on_chord_part_return(self):
-        self.b.on_chord_part_return(None)
+        self.b.on_chord_part_return(None, None, None)
 
     def test_apply_chord(self, unlock='celery.chord_unlock'):
         self.app.tasks[unlock] = Mock()
@@ -246,7 +246,7 @@ class test_KeyValueStoreBackend(AppCase):
 
     def test_on_chord_part_return(self):
         assert not self.b.implements_incr
-        self.b.on_chord_part_return(None)
+        self.b.on_chord_part_return(None, None, None)
 
     def test_get_store_delete_result(self):
         tid = uuid()
@@ -282,12 +282,14 @@ class test_KeyValueStoreBackend(AppCase):
     def test_chord_part_return_no_gid(self):
         self.b.implements_incr = True
         task = Mock()
+        state = 'SUCCESS'
+        result = 10
         task.request.group = None
         self.b.get_key_for_chord = Mock()
         self.b.get_key_for_chord.side_effect = AssertionError(
             'should not get here',
         )
-        self.assertIsNone(self.b.on_chord_part_return(task))
+        self.assertIsNone(self.b.on_chord_part_return(task, state, result))
 
     @contextmanager
     def _chord_part_context(self, b):
@@ -315,14 +317,14 @@ class test_KeyValueStoreBackend(AppCase):
 
     def test_chord_part_return_propagate_set(self):
         with self._chord_part_context(self.b) as (task, deps, _):
-            self.b.on_chord_part_return(task, propagate=True)
+            self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=True)
             self.assertFalse(self.b.expire.called)
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(propagate=True, timeout=3.0)
 
     def test_chord_part_return_propagate_default(self):
         with self._chord_part_context(self.b) as (task, deps, _):
-            self.b.on_chord_part_return(task, propagate=None)
+            self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=None)
             self.assertFalse(self.b.expire.called)
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(
@@ -334,7 +336,7 @@ class test_KeyValueStoreBackend(AppCase):
         with self._chord_part_context(self.b) as (task, deps, callback):
             deps._failed_join_report = lambda: iter([])
             deps.join_native.side_effect = KeyError('foo')
-            self.b.on_chord_part_return(task)
+            self.b.on_chord_part_return(task, 'SUCCESS', 10)
             self.assertTrue(self.b.fail_from_current_stack.called)
             args = self.b.fail_from_current_stack.call_args
             exc = args[1]['exc']
@@ -348,7 +350,7 @@ class test_KeyValueStoreBackend(AppCase):
                 self.app.AsyncResult('culprit'),
             ])
             deps.join_native.side_effect = KeyError('foo')
-            b.on_chord_part_return(task)
+            b.on_chord_part_return(task, 'SUCCESS', 10)
             self.assertTrue(b.fail_from_current_stack.called)
             args = b.fail_from_current_stack.call_args
             exc = args[1]['exc']

+ 2 - 2
celery/tests/backends/test_cache.py

@@ -86,10 +86,10 @@ class test_CacheBackend(AppCase):
         tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
         self.assertFalse(deps.join_native.called)
-        tb.on_chord_part_return(task)
+        tb.on_chord_part_return(task, 'SUCCESS', 10)
         self.assertFalse(deps.join_native.called)
 
-        tb.on_chord_part_return(task)
+        tb.on_chord_part_return(task, 'SUCCESS', 10)
         deps.join_native.assert_called_with(propagate=True, timeout=3.0)
         deps.delete.assert_called_with()
 

+ 86 - 74
celery/tests/backends/test_redis.py

@@ -4,52 +4,55 @@ from datetime import timedelta
 
 from pickle import loads, dumps
 
-from kombu.utils import cached_property, uuid
-
 from celery import signature
 from celery import states
 from celery import group
+from celery import uuid
 from celery.datastructures import AttributeDict
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.timeutils import timedelta_seconds
 
 from celery.tests.case import (
-    AppCase, Mock, SkipTest, depends_on_current_app, patch,
+    AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
 )
 
 
-class Redis(object):
+class Connection(object):
+    connected = True
+
+    def disconnect(self):
+        self.connected = False
 
-    class Connection(object):
-        connected = True
 
-        def disconnect(self):
-            self.connected = False
+class Pipeline(object):
 
-    class Pipeline(object):
+    def __init__(self, client):
+        self.client = client
+        self.steps = []
 
-        def __init__(self, client):
-            self.client = client
-            self.steps = []
+    def __getattr__(self, attr):
 
-        def __getattr__(self, attr):
+        def add_step(*args, **kwargs):
+            self.steps.append((getattr(self.client, attr), args, kwargs))
+            return self
+        return add_step
 
-            def add_step(*args, **kwargs):
-                self.steps.append((getattr(self.client, attr), args, kwargs))
-                return self
-            return add_step
+    def execute(self):
+        return [step(*a, **kw) for step, a, kw in self.steps]
 
-        def execute(self):
-            return [step(*a, **kw) for step, a, kw in self.steps]
+
+class Redis(MockCallbacks):
+    Connection = Connection
+    Pipeline = Pipeline
 
     def __init__(self, host=None, port=None, db=None, password=None, **kw):
         self.host = host
         self.port = port
         self.db = db
         self.password = password
-        self.connection = self.Connection()
         self.keyspace = {}
         self.expiry = {}
+        self.connection = self.Connection()
 
     def get(self, key):
         return self.keyspace.get(key)
@@ -63,12 +66,10 @@ class Redis(object):
 
     def expire(self, key, expires):
         self.expiry[key] = expires
+        return expires
 
     def delete(self, key):
-        self.keyspace.pop(key)
-
-    def publish(self, key, value):
-        pass
+        return bool(self.keyspace.pop(key, None))
 
     def pipeline(self):
         return self.Pipeline(self)
@@ -91,41 +92,34 @@ class redis(object):
 class test_RedisBackend(AppCase):
 
     def get_backend(self):
-        from celery.backends import redis
+        from celery.backends.redis import RedisBackend
 
-        class RedisBackend(redis.RedisBackend):
+        class _RedisBackend(RedisBackend):
             redis = redis
 
-        return RedisBackend
+        return _RedisBackend
 
     def setup(self):
         self.Backend = self.get_backend()
 
-        class MockBackend(self.Backend):
-
-            @cached_property
-            def client(self):
-                return Mock()
-
-        self.MockBackend = MockBackend
-
     @depends_on_current_app
     def test_reduce(self):
         try:
             from celery.backends.redis import RedisBackend
-            x = RedisBackend(app=self.app)
+            x = RedisBackend(app=self.app, new_join=True)
             self.assertTrue(loads(dumps(x)))
         except ImportError:
             raise SkipTest('redis not installed')
 
     def test_no_redis(self):
-        self.MockBackend.redis = None
+        self.Backend.redis = None
         with self.assertRaises(ImproperlyConfigured):
-            self.MockBackend(app=self.app)
+            self.Backend(app=self.app, new_join=True)
 
     def test_url(self):
-        x = self.MockBackend(
+        x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
+            new_join=True,
         )
         self.assertTrue(x.connparams)
         self.assertEqual(x.connparams['host'], 'vandelay.com')
@@ -134,8 +128,9 @@ class test_RedisBackend(AppCase):
         self.assertEqual(x.connparams['password'], 'bosco')
 
     def test_socket_url(self):
-        x = self.MockBackend(
+        x = self.Backend(
             'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
+            new_join=True,
         )
         self.assertTrue(x.connparams)
         self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
@@ -148,8 +143,9 @@ class test_RedisBackend(AppCase):
         self.assertEqual(x.connparams['db'], 3)
 
     def test_compat_propertie(self):
-        x = self.MockBackend(
+        x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
+            new_join=True,
         )
         with self.assertPendingDeprecation():
             self.assertEqual(x.host, 'vandelay.com')
@@ -167,71 +163,85 @@ class test_RedisBackend(AppCase):
             'CELERY_ACCEPT_CONTENT': ['json'],
             'CELERY_TASK_RESULT_EXPIRES': None,
         })
-        self.MockBackend(app=self.app)
+        self.Backend(app=self.app, new_join=True)
 
     def test_expires_defaults_to_config(self):
         self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
-        b = self.Backend(expires=None, app=self.app)
+        b = self.Backend(expires=None, app=self.app, new_join=True)
         self.assertEqual(b.expires, 10)
 
     def test_expires_is_int(self):
-        b = self.Backend(expires=48, app=self.app)
+        b = self.Backend(expires=48, app=self.app, new_join=True)
         self.assertEqual(b.expires, 48)
 
+    def test_set_new_join_from_url_query(self):
+        b = self.Backend('redis://?new_join=True;foobar=1', app=self.app)
+        self.assertEqual(b.on_chord_part_return, b._new_chord_return)
+        self.assertEqual(b.apply_chord, b._new_chord_apply)
+
+    def test_default_is_old_join(self):
+        b = self.Backend(app=self.app)
+        self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
+        self.assertNotEqual(b.apply_chord, b._new_chord_apply)
+
     def test_expires_is_None(self):
-        b = self.Backend(expires=None, app=self.app)
+        b = self.Backend(expires=None, app=self.app, new_join=True)
         self.assertEqual(b.expires, timedelta_seconds(
             self.app.conf.CELERY_TASK_RESULT_EXPIRES))
 
     def test_expires_is_timedelta(self):
-        b = self.Backend(expires=timedelta(minutes=1), app=self.app)
+        b = self.Backend(
+            expires=timedelta(minutes=1), app=self.app, new_join=1,
+        )
         self.assertEqual(b.expires, 60)
 
     def test_apply_chord(self):
-        self.Backend(app=self.app).apply_chord(
+        self.Backend(app=self.app, new_join=True).apply_chord(
             group(app=self.app), (), 'group_id', {},
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
 
     def test_mget(self):
-        b = self.MockBackend(app=self.app)
+        b = self.Backend(app=self.app, new_join=True)
         self.assertTrue(b.mget(['a', 'b', 'c']))
         b.client.mget.assert_called_with(['a', 'b', 'c'])
 
     def test_set_no_expire(self):
-        b = self.MockBackend(app=self.app)
+        b = self.Backend(app=self.app, new_join=True)
         b.expires = None
         b.set('foo', 'bar')
 
     @patch('celery.result.GroupResult.restore')
     def test_on_chord_part_return(self, restore):
-        b = self.MockBackend(app=self.app)
-        deps = Mock()
-        deps.__len__ = Mock()
-        deps.__len__.return_value = 10
-        restore.return_value = deps
-        b.client.incr.return_value = 1
-        task = Mock()
-        task.name = 'foobarbaz'
-        self.app.tasks['foobarbaz'] = task
-        task.request.chord = signature(task)
-        task.request.group = 'group_id'
-
-        b.on_chord_part_return(task)
-        self.assertTrue(b.client.incr.call_count)
-
-        b.client.incr.return_value = len(deps)
-        b.on_chord_part_return(task)
-        deps.join_native.assert_called_with(propagate=True, timeout=3.0)
-        deps.delete.assert_called_with()
-
-        self.assertTrue(b.client.expire.call_count)
+        b = self.Backend(app=self.app, new_join=True)
+
+        def create_task():
+            tid = uuid()
+            task = Mock(name='task-{0}'.format(tid))
+            task.name = 'foobarbaz'
+            self.app.tasks['foobarbaz'] = task
+            task.request.chord = signature(task)
+            task.request.id = tid
+            task.request.chord['chord_size'] = 10
+            task.request.group = 'group_id'
+            return task
+
+        tasks = [create_task() for i in range(10)]
+
+        for i in range(10):
+            b.on_chord_part_return(tasks[i], states.SUCCESS, i)
+            self.assertTrue(b.client.rpush.call_count)
+            b.client.rpush.reset_mock()
+        self.assertTrue(b.client.lrange.call_count)
+        gkey = b.get_key_for_group('group_id', '.j')
+        b.client.delete.assert_called_with(gkey)
+        b.client.expire.assert_called_witeh(gkey, 86400)
 
     def test_process_cleanup(self):
-        self.Backend(app=self.app).process_cleanup()
+        self.Backend(app=self.app, new_join=True).process_cleanup()
 
     def test_get_set_forget(self):
-        b = self.Backend(app=self.app)
+        b = self.Backend(app=self.app, new_join=True)
         tid = uuid()
         b.store_result(tid, 42, states.SUCCESS)
         self.assertEqual(b.get_status(tid), states.SUCCESS)
@@ -240,8 +250,10 @@ class test_RedisBackend(AppCase):
         self.assertEqual(b.get_status(tid), states.PENDING)
 
     def test_set_expires(self):
-        b = self.Backend(expires=512, app=self.app)
+        b = self.Backend(expires=512, app=self.app, new_join=True)
         tid = uuid()
         key = b.get_key_for_task(tid)
         b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(b.client.expiry[key], 512)
+        b.client.expire.assert_called_with(
+            key, 512,
+        )

+ 21 - 0
celery/tests/case.py

@@ -171,6 +171,27 @@ def ContextMock(*args, **kwargs):
     return obj
 
 
+def _bind(f, o):
+    @wraps(f)
+    def bound_meth(*fargs, **fkwargs):
+        return f(o, *fargs, **fkwargs)
+    return bound_meth
+
+
+class MockCallbacks(object):
+
+    def __new__(cls, *args, **kwargs):
+        r = Mock(name=cls.__name__)
+        cls.__init__.__func__(r, *args, **kwargs)
+        for key, value in items(vars(cls)):
+            if key not in ('__dict__', '__weakref__', '__new__', '__init__'):
+                if inspect.ismethod(value) or inspect.isfunction(value):
+                    r.__getattr__(key).side_effect = _bind(value, r)
+                else:
+                    r.__setattr__(key, value)
+        return r
+
+
 def skip_unless_module(module):
 
     def _inner(fun):

+ 1 - 1
celery/tests/tasks/test_trace.py

@@ -101,7 +101,7 @@ class test_trace(TraceCase):
         add.backend = Mock()
 
         self.trace(add, (2, 2), {}, request={'chord': uuid()})
-        add.backend.on_chord_part_return.assert_called_with(add)
+        add.backend.on_chord_part_return.assert_called_with(add, 'SUCCESS', 4)
 
     def test_when_backend_cleanup_raises(self):
 

+ 3 - 1
funtests/stress/stress/templates.py

@@ -69,7 +69,9 @@ class default(object):
 @template()
 class redis(default):
     BROKER_URL = os.environ.get('CSTRESS_BROKER', 'redis://')
-    CELERY_RESULT_BACKEND = os.environ.get('CSTRESS_bACKEND', 'redis://')
+    CELERY_RESULT_BACKEND = os.environ.get(
+        'CSTRESS_bACKEND', 'redis://?new_join=1',
+    )
     BROKER_TRANSPORT_OPTIONS = {'fanout_prefix': True}