
Adds Chord support for the cache backend.

The implementation is efficient and should make chords with the cache
backend perform just as well as, or better than, the Redis backend.

(Dummy cache support has also been added, but being in-memory it is
rather limited and intended for unit tests only.)
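
For reference, enabling this code path is just a matter of selecting the
cache result backend. A minimal configuration sketch, assuming the standard
cache backend settings of this Celery series (the memcached address is only
a placeholder):

    # celeryconfig.py (sketch)
    CELERY_RESULT_BACKEND = "cache"

    # A real deployment points at one or more memcached servers:
    CELERY_CACHE_BACKEND = "memcached://127.0.0.1:11211/"

    # For unit tests only: the in-memory DummyClient added here.
    #CELERY_CACHE_BACKEND = "memory://"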

Also adds unit tests and fixes the memcache backend's test reset code:

If one changes the common cache backend now shared in 'class
test_CacheBackend' to memcached instead of the dummy backend (for
example, to actually exercise that backend), the code currently fails
because the _imp variable is not reset in this one case. Fix this small
omission and make some of the other code clearer (explicit tuple) and
more consistent (order of the nested mock context managers).
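
For context, import_best_memcache() memoizes its choice of client library
in the module-level one-element list _imp, so a test that needs the import
logic to run again has to clear it first. A minimal sketch of that reset
pattern, mirroring the updated tests below (the helper name is illustrative
only):

    from celery.backends import cache

    def reset_memcache_choice():
        # import_best_memcache() stores (is_pylibmc, memcache) in _imp[0];
        # clearing it forces the next call to redo the import and the
        # ImproperlyConfigured check.
        cache._imp = [None]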

Closes #533

Signed-off-by: Dan McGee <dan@archlinux.org>
Dan McGee, 13 years ago
commit 93258b8bb4

+ 10 - 5
celery/backends/base.py

@@ -13,7 +13,7 @@ from .. import states
 from ..datastructures import LRUCache
 from ..exceptions import TimeoutError, TaskRevokedError
 from ..utils import timeutils
-from ..utils.encoding import from_utf8
+from ..utils.encoding import ensure_bytes, from_utf8
 from ..utils.serialization import (get_pickled_exception,
                                    get_pickleable_exception,
                                    create_exception_cls)
@@ -203,7 +203,7 @@ class BaseBackend(object):
         raise NotImplementedError(
                 "reload_taskset_result is not supported by this backend.")
 
-    def on_chord_part_return(self, task):
+    def on_chord_part_return(self, task, propagate=False):
         pass
 
     def on_chord_apply(self, setid, body, result=None, **kwargs):
@@ -301,6 +301,7 @@ class BaseDictBackend(BaseBackend):
 class KeyValueStoreBackend(BaseDictBackend):
     task_keyprefix = "celery-task-meta-"
     taskset_keyprefix = "celery-taskset-meta-"
+    chord_keyprefix = "chord-unlock-"
 
     def get(self, key):
         raise NotImplementedError("Must implement the get method.")
@@ -316,11 +317,15 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
-        return self.task_keyprefix + task_id
+        return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)
 
     def get_key_for_taskset(self, taskset_id):
-        """Get the cache key for a task by id."""
-        return self.taskset_keyprefix + taskset_id
+        """Get the cache key for a taskset by id."""
+        return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
+
+    def get_key_for_chord(self, taskset_id):
+        """Get the cache key for the chord waiting on taskset with given id."""
+        return ensure_bytes(self.chord_keyprefix) + ensure_bytes(taskset_id)
 
     def _strip_prefix(self, key):
         for prefix in self.task_keyprefix, self.taskset_keyprefix:

+ 21 - 8
celery/backends/cache.py

@@ -4,7 +4,6 @@ from __future__ import absolute_import
 from ..datastructures import LRUCache
 from ..exceptions import ImproperlyConfigured
 from ..utils import cached_property
-from ..utils.encoding import ensure_bytes
 
 from .base import KeyValueStoreBackend
 
@@ -24,7 +23,7 @@ def import_best_memcache():
                 raise ImproperlyConfigured(
                         "Memcached backend requires either the 'pylibmc' "
                         "or 'memcache' library")
-        _imp[0] = is_pylibmc, memcache
+        _imp[0] = (is_pylibmc, memcache)
     return _imp[0]
 
 
@@ -55,6 +54,9 @@ class DummyClient(object):
     def delete(self, key, *args, **kwargs):
         self.cache.pop(key, None)
 
+    def incr(self, key, delta=1):
+        return self.cache.incr(key, delta)
+
 
 backends = {"memcache": lambda: get_best_memcache,
             "memcached": lambda: get_best_memcache,
@@ -85,12 +87,6 @@ class CacheBackend(KeyValueStoreBackend):
                     "following backends: %s" % (self.backend,
                                                 ", ".join(backends.keys())))
 
-    def get_key_for_task(self, task_id):
-        return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)
-
-    def get_key_for_taskset(self, taskset_id):
-        return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
-
     def get(self, key):
         return self.client.get(key)
 
@@ -103,6 +99,23 @@ class CacheBackend(KeyValueStoreBackend):
     def delete(self, key):
         return self.client.delete(key)
 
+    def on_chord_apply(self, setid, body, result=None, **kwargs):
+        key = self.get_key_for_chord(setid)
+        self.client.set(key, '0', time=86400)
+
+    def on_chord_part_return(self, task, propagate=False):
+        from ..task.sets import subtask
+        from ..result import TaskSetResult
+        setid = task.request.taskset
+        if not setid:
+            return
+        key = self.get_key_for_chord(setid)
+        deps = TaskSetResult.restore(setid, backend=task.backend)
+        if self.client.incr(key) >= deps.total:
+            subtask(task.request.chord).delay(deps.join(propagate=propagate))
+            deps.delete()
+            self.client.delete(key)
+
     @cached_property
     def client(self):
         return self.Client(self.servers, **self.options)
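
The chord bookkeeping above relies on only two client primitives: set() with
a TTL and an atomic incr(). A quick sketch of those semantics against a bare
memcached client, assuming the python-memcached library and a local server at
127.0.0.1:11211 (SETID stands in for a taskset id):

    import memcache

    mc = memcache.Client(["127.0.0.1:11211"])

    # on_chord_apply: seed the counter as a string with a one-day expiry.
    mc.set("chord-unlock-SETID", "0", time=86400)

    # on_chord_part_return: every finishing header task bumps the counter;
    # the increment happens server-side, so concurrent workers each see a
    # distinct value and exactly one of them observes the final count.
    print(mc.incr("chord-unlock-SETID"))   # 1
    print(mc.incr("chord-unlock-SETID"))   # 2

Because only the task that sees the counter reach deps.total applies the
callback, no polling task (unlock_chord) is needed for this backend.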

+ 7 - 5
celery/backends/redis.py

@@ -74,19 +74,21 @@ class RedisBackend(KeyValueStoreBackend):
 
     def on_chord_apply(self, setid, body, result=None, **kwargs):
         self.app.TaskSetResult(setid, result).save()
-        pass
 
-    def on_chord_part_return(self, task, propagate=False,
-            keyprefix="chord-unlock-%s"):
+    def on_chord_part_return(self, task, propagate=False):
         from ..task.sets import subtask
         from ..result import TaskSetResult
         setid = task.request.taskset
-        key = keyprefix % setid
+        if not setid:
+            return
+        key = self.get_key_for_chord(setid)
         deps = TaskSetResult.restore(setid, backend=task.backend)
         if self.client.incr(key) >= deps.total:
             subtask(task.request.chord).delay(deps.join(propagate=propagate))
             deps.delete()
-        self.client.expire(key, 86400)
+            self.client.delete(key)
+        else:
+            self.client.expire(key, 86400)
 
     @cached_property
     def client(self):

+ 8 - 0
celery/datastructures.py

@@ -374,6 +374,14 @@ class LRUCache(UserDict):
                 pass
     itervalues = _iterate_values
 
+    def incr(self, key, delta=1):
+        with self.mutex:
+            # this acts as memcached does: store as a string, but return an
+            # integer as long as it exists and we can cast it
+            newval = int(self.data.pop(key)) + delta
+            self[key] = str(newval)
+            return newval
+
 
 class TokenBucket(object):
     """Token Bucket Algorithm.

+ 58 - 33
celery/tests/test_backends/test_cache.py

@@ -6,10 +6,14 @@ import types
 
 from contextlib import contextmanager
 
+from mock import Mock, patch
+
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
+from celery.registry import tasks
 from celery.result import AsyncResult
+from celery.task import subtask
 from celery.utils import uuid
 from celery.utils.encoding import str_to_bytes
 
@@ -24,60 +28,80 @@ class SomeClass(object):
 
 class test_CacheBackend(unittest.TestCase):
 
-    def test_mark_as_done(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid = uuid()
+    def setUp(self):
+        self.tb = CacheBackend(backend="memory://")
+        self.tid = uuid()
 
-        self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertIsNone(tb.get_result(tid))
+    def test_mark_as_done(self):
+        self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
+        self.assertIsNone(self.tb.get_result(self.tid))
 
-        tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_status(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
+        self.tb.mark_as_done(self.tid, 42)
+        self.assertEqual(self.tb.get_status(self.tid), states.SUCCESS)
+        self.assertEqual(self.tb.get_result(self.tid), 42)
 
     def test_is_pickled(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid2 = uuid()
         result = {"foo": "baz", "bar": SomeClass(12345)}
-        tb.mark_as_done(tid2, result)
+        self.tb.mark_as_done(self.tid, result)
         # is serialized properly.
-        rindb = tb.get_result(tid2)
+        rindb = self.tb.get_result(self.tid)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)
 
     def test_mark_as_failure(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid3 = uuid()
         try:
             raise KeyError("foo")
         except KeyError, exception:
-            pass
-            tb.mark_as_failure(tid3, exception)
-            self.assertEqual(tb.get_status(tid3), states.FAILURE)
-            self.assertIsInstance(tb.get_result(tid3), KeyError)
+            self.tb.mark_as_failure(self.tid, exception)
+            self.assertEqual(self.tb.get_status(self.tid), states.FAILURE)
+            self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
 
-    def test_mget(self):
+    def test_on_chord_apply(self):
         tb = CacheBackend(backend="memory://")
-        tb.set("foo", 1)
-        tb.set("bar", 2)
+        tb.on_chord_apply("setid", [])
+
+    @patch("celery.result.TaskSetResult")
+    def test_on_chord_part_return(self, setresult):
+        tb = CacheBackend(backend="memory://")
+
+        deps = Mock()
+        deps.total = 2
+        setresult.restore.return_value = deps
+        task = Mock()
+        task.name = "foobarbaz"
+        try:
+            tasks["foobarbaz"] = task
+            task.request.chord = subtask(task)
+            task.request.taskset = "setid"
 
-        self.assertDictEqual(tb.mget(["foo", "bar"]),
+            tb.on_chord_apply(task.request.taskset, [])
+
+            self.assertFalse(deps.join.called)
+            tb.on_chord_part_return(task)
+            self.assertFalse(deps.join.called)
+
+            tb.on_chord_part_return(task)
+            deps.join.assert_called_with(propagate=False)
+            deps.delete.assert_called_with()
+
+        finally:
+            tasks.pop("foobarbaz")
+
+    def test_mget(self):
+        self.tb.set("foo", 1)
+        self.tb.set("bar", 2)
+
+        self.assertDictEqual(self.tb.mget(["foo", "bar"]),
                              {"foo": 1, "bar": 2})
 
     def test_forget(self):
-        tb = CacheBackend(backend="memory://")
-        tid = uuid()
-        tb.mark_as_done(tid, {"foo": "bar"})
-        x = AsyncResult(tid, backend=tb)
+        self.tb.mark_as_done(self.tid, {"foo": "bar"})
+        x = AsyncResult(self.tid, backend=self.tb)
         x.forget()
         self.assertIsNone(x.result)
 
     def test_process_cleanup(self):
-        tb = CacheBackend(backend="memory://")
-        tb.process_cleanup()
+        self.tb.process_cleanup()
 
     def test_expires_as_int(self):
         tb = CacheBackend(backend="memory://", expires=10)
@@ -129,8 +153,8 @@ class MockCacheMixin(object):
 class test_get_best_memcache(unittest.TestCase, MockCacheMixin):
 
     def test_pylibmc(self):
-        with reset_modules("celery.backends.cache"):
-            with self.mock_pylibmc():
+        with self.mock_pylibmc():
+            with reset_modules("celery.backends.cache"):
                 from celery.backends import cache
                 cache._imp = [None]
                 self.assertEqual(cache.get_best_memcache().__module__,
@@ -157,6 +181,7 @@ class test_get_best_memcache(unittest.TestCase, MockCacheMixin):
         with self.mock_pylibmc():
             with reset_modules("celery.backends.cache"):
                 from celery.backends import cache
+                cache._imp = [None]
                 cache.get_best_memcache(behaviors={"foo": "bar"})
                 self.assertTrue(cache._imp[0])
                 cache.get_best_memcache()

+ 3 - 3
docs/userguide/tasksets.rst

@@ -244,14 +244,14 @@ Example implementation:
         unlock_chord.retry(countdown=interval, max_retries=max_retries)
 
 
-This is used by all result backends except Redis, which increments a
+This is used by all result backends except Redis and Memcached, which increment a
 counter after each task in the header, then applying the callback when the
 counter exceeds the number of tasks in the set. *Note:* chords do not properly
 work with Redis before version 2.2; you will need to upgrade to at least 2.2 to
 use them.
 
-The Redis approach is a much better solution, but not easily implemented
-in other backends (suggestions welcome!).
+The Redis and Memcached approach is a much better solution, but not easily
+implemented in other backends (suggestions welcome!).
 
 
 .. note::
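
Tying it together, the user-facing chord API is unchanged; with this commit
the cache backend simply takes the counter path described above instead of
the polling task. A usage sketch in the style of this guide (the import path
and the example add/tsum tasks are assumptions, not part of this change):

    from celery.task import chord   # import path per this doc series (assumption)
    from tasks import add, tsum     # hypothetical example tasks

    # 100-task header; tsum() is applied to the list of results once the
    # chord counter reaches 100.
    result = chord(add.subtask((i, i)) for i in xrange(100))(tsum.subtask())
    print(result.get())             # 9900, the sum of the 100 partial results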