
Adds Chord support for the cache backend.

The implementation is efficient and should make chords on the cache
backend just as good as, or better than, chords on the Redis backend.

(Dummy cache support has also been added, but being in-memory it is
rather limited and intended for unit tests only.)
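
For context, a rough sketch of the kind of workflow this enables. The
add/tsum tasks are hypothetical and the import path is the chord API as
it exists around this era of Celery (it may differ by version); with
this change the same chord also completes when the result backend is
one of the cache backends:

    # Hypothetical tasks used only to illustrate a chord; with this patch
    # the callback is applied by the counter kept in the cache backend.
    from celery.task import task
    from celery.task.chords import chord

    @task
    def add(x, y):
        return x + y

    @task
    def tsum(numbers):
        return sum(numbers)

    # header: ten add() tasks; body: tsum() applied to the collected results
    result = chord(add.subtask((i, i)) for i in range(10))(tsum.subtask())
    print(result.get())   # -> 90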

Also adds unittests and fixes the memcache backend's test reset code:

If one changes the common cache backend now shared in 'class
test_CacheBackend' to memcached instead of the dummy backend (for
example, to actually exercise that backend), the code currently fails
because the _imp variable is not reset in this one case. Fix this small
omission, and make some of the other code clearer (explicit tuple) and
more consistent (order of the mock context managers).
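
A minimal sketch of the reset the tests depend on (this only restates
what the patch below does; it is not new API):

    # celery.backends.cache memoizes the imported client library in the
    # module-level _imp list; tests that mock pylibmc/memcache must clear
    # it, or a previously imported client leaks into the next test case.
    from celery.backends import cache

    cache._imp = [None]        # forget any previously resolved client
    cache.get_best_memcache()  # re-resolves pylibmc/memcache under the mock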

Closes #533

Signed-off-by: Dan McGee <dan@archlinux.org>
Dan McGee, commit 93258b8bb4
+ 10 - 5
celery/backends/base.py

@@ -13,7 +13,7 @@ from .. import states
 from ..datastructures import LRUCache
 from ..exceptions import TimeoutError, TaskRevokedError
 from ..utils import timeutils
-from ..utils.encoding import from_utf8
+from ..utils.encoding import ensure_bytes, from_utf8
 from ..utils.serialization import (get_pickled_exception,
                                    get_pickleable_exception,
                                    create_exception_cls)
@@ -203,7 +203,7 @@ class BaseBackend(object):
         raise NotImplementedError(
                 "reload_taskset_result is not supported by this backend.")

-    def on_chord_part_return(self, task):
+    def on_chord_part_return(self, task, propagate=False):
         pass

     def on_chord_apply(self, setid, body, result=None, **kwargs):
@@ -301,6 +301,7 @@ class BaseDictBackend(BaseBackend):
 class KeyValueStoreBackend(BaseDictBackend):
     task_keyprefix = "celery-task-meta-"
     taskset_keyprefix = "celery-taskset-meta-"
+    chord_keyprefix = "chord-unlock-"

     def get(self, key):
         raise NotImplementedError("Must implement the get method.")
@@ -316,11 +317,15 @@ class KeyValueStoreBackend(BaseDictBackend):

     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
-        return self.task_keyprefix + task_id
+        return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)

     def get_key_for_taskset(self, taskset_id):
-        """Get the cache key for a task by id."""
-        return self.taskset_keyprefix + taskset_id
+        """Get the cache key for a taskset by id."""
+        return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
+
+    def get_key_for_chord(self, taskset_id):
+        """Get the cache key for the chord waiting on taskset with given id."""
+        return ensure_bytes(self.chord_keyprefix) + ensure_bytes(taskset_id)

     def _strip_prefix(self, key):
         for prefix in self.task_keyprefix, self.taskset_keyprefix:

+ 21 - 8
celery/backends/cache.py

@@ -4,7 +4,6 @@ from __future__ import absolute_import
 from ..datastructures import LRUCache
 from ..exceptions import ImproperlyConfigured
 from ..utils import cached_property
-from ..utils.encoding import ensure_bytes

 from .base import KeyValueStoreBackend

@@ -24,7 +23,7 @@ def import_best_memcache():
                 raise ImproperlyConfigured(
                         "Memcached backend requires either the 'pylibmc' "
                         "or 'memcache' library")
-        _imp[0] = is_pylibmc, memcache
+        _imp[0] = (is_pylibmc, memcache)
     return _imp[0]


@@ -55,6 +54,9 @@ class DummyClient(object):
     def delete(self, key, *args, **kwargs):
         self.cache.pop(key, None)

+    def incr(self, key, delta=1):
+        return self.cache.incr(key, delta)
+

 backends = {"memcache": lambda: get_best_memcache,
             "memcached": lambda: get_best_memcache,
@@ -85,12 +87,6 @@ class CacheBackend(KeyValueStoreBackend):
                     "following backends: %s" % (self.backend,
                                                 ", ".join(backends.keys())))

-    def get_key_for_task(self, task_id):
-        return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)
-
-    def get_key_for_taskset(self, taskset_id):
-        return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
-
     def get(self, key):
         return self.client.get(key)

@@ -103,6 +99,23 @@ class CacheBackend(KeyValueStoreBackend):
     def delete(self, key):
         return self.client.delete(key)

+    def on_chord_apply(self, setid, body, result=None, **kwargs):
+        key = self.get_key_for_chord(setid)
+        self.client.set(key, '0', time=86400)
+
+    def on_chord_part_return(self, task, propagate=False):
+        from ..task.sets import subtask
+        from ..result import TaskSetResult
+        setid = task.request.taskset
+        if not setid:
+            return
+        key = self.get_key_for_chord(setid)
+        deps = TaskSetResult.restore(setid, backend=task.backend)
+        if self.client.incr(key) >= deps.total:
+            subtask(task.request.chord).delay(deps.join(propagate=propagate))
+            deps.delete()
+            self.client.delete(key)
+
     @cached_property
     def client(self):
         return self.Client(self.servers, **self.options)

+ 7 - 5
celery/backends/redis.py

@@ -74,19 +74,21 @@ class RedisBackend(KeyValueStoreBackend):

     def on_chord_apply(self, setid, body, result=None, **kwargs):
         self.app.TaskSetResult(setid, result).save()
-        pass

-    def on_chord_part_return(self, task, propagate=False,
-            keyprefix="chord-unlock-%s"):
+    def on_chord_part_return(self, task, propagate=False):
         from ..task.sets import subtask
         from ..result import TaskSetResult
         setid = task.request.taskset
-        key = keyprefix % setid
+        if not setid:
+            return
+        key = self.get_key_for_chord(setid)
         deps = TaskSetResult.restore(setid, backend=task.backend)
         if self.client.incr(key) >= deps.total:
             subtask(task.request.chord).delay(deps.join(propagate=propagate))
             deps.delete()
-        self.client.expire(key, 86400)
+            self.client.delete(key)
+        else:
+            self.client.expire(key, 86400)

     @cached_property
     def client(self):

+ 8 - 0
celery/datastructures.py

@@ -374,6 +374,14 @@ class LRUCache(UserDict):
                 pass
     itervalues = _iterate_values

+    def incr(self, key, delta=1):
+        with self.mutex:
+            # this acts as memcached does: store as a string, but return an
+            # integer as long as the key exists and the value can be cast
+            newval = int(self.data.pop(key)) + delta
+            self[key] = str(newval)
+            return newval
+

 class TokenBucket(object):
     """Token Bucket Algorithm.

+ 58 - 33
celery/tests/test_backends/test_cache.py

@@ -6,10 +6,14 @@ import types

 from contextlib import contextmanager

+from mock import Mock, patch
+
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
 from celery.exceptions import ImproperlyConfigured
+from celery.registry import tasks
 from celery.result import AsyncResult
+from celery.task import subtask
 from celery.utils import uuid
 from celery.utils.encoding import str_to_bytes

@@ -24,60 +28,80 @@ class SomeClass(object):

 class test_CacheBackend(unittest.TestCase):

-    def test_mark_as_done(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid = uuid()
+    def setUp(self):
+        self.tb = CacheBackend(backend="memory://")
+        self.tid = uuid()

-        self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertIsNone(tb.get_result(tid))
+    def test_mark_as_done(self):
+        self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
+        self.assertIsNone(self.tb.get_result(self.tid))

-        tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_status(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
+        self.tb.mark_as_done(self.tid, 42)
+        self.assertEqual(self.tb.get_status(self.tid), states.SUCCESS)
+        self.assertEqual(self.tb.get_result(self.tid), 42)

     def test_is_pickled(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid2 = uuid()
         result = {"foo": "baz", "bar": SomeClass(12345)}
-        tb.mark_as_done(tid2, result)
+        self.tb.mark_as_done(self.tid, result)
         # is serialized properly.
-        rindb = tb.get_result(tid2)
+        rindb = self.tb.get_result(self.tid)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)

     def test_mark_as_failure(self):
-        tb = CacheBackend(backend="memory://")
-
-        tid3 = uuid()
         try:
             raise KeyError("foo")
         except KeyError, exception:
-            pass
-            tb.mark_as_failure(tid3, exception)
-            self.assertEqual(tb.get_status(tid3), states.FAILURE)
-            self.assertIsInstance(tb.get_result(tid3), KeyError)
+            self.tb.mark_as_failure(self.tid, exception)
+            self.assertEqual(self.tb.get_status(self.tid), states.FAILURE)
+            self.assertIsInstance(self.tb.get_result(self.tid), KeyError)

-    def test_mget(self):
+    def test_on_chord_apply(self):
         tb = CacheBackend(backend="memory://")
-        tb.set("foo", 1)
-        tb.set("bar", 2)
+        tb.on_chord_apply("setid", [])
+
+    @patch("celery.result.TaskSetResult")
+    def test_on_chord_part_return(self, setresult):
+        tb = CacheBackend(backend="memory://")
+
+        deps = Mock()
+        deps.total = 2
+        setresult.restore.return_value = deps
+        task = Mock()
+        task.name = "foobarbaz"
+        try:
+            tasks["foobarbaz"] = task
+            task.request.chord = subtask(task)
+            task.request.taskset = "setid"

-        self.assertDictEqual(tb.mget(["foo", "bar"]),
+            tb.on_chord_apply(task.request.taskset, [])
+
+            self.assertFalse(deps.join.called)
+            tb.on_chord_part_return(task)
+            self.assertFalse(deps.join.called)
+
+            tb.on_chord_part_return(task)
+            deps.join.assert_called_with(propagate=False)
+            deps.delete.assert_called_with()
+
+        finally:
+            tasks.pop("foobarbaz")
+
+    def test_mget(self):
+        self.tb.set("foo", 1)
+        self.tb.set("bar", 2)
+
+        self.assertDictEqual(self.tb.mget(["foo", "bar"]),
                              {"foo": 1, "bar": 2})

     def test_forget(self):
-        tb = CacheBackend(backend="memory://")
-        tid = uuid()
-        tb.mark_as_done(tid, {"foo": "bar"})
-        x = AsyncResult(tid, backend=tb)
+        self.tb.mark_as_done(self.tid, {"foo": "bar"})
+        x = AsyncResult(self.tid, backend=self.tb)
         x.forget()
         self.assertIsNone(x.result)

     def test_process_cleanup(self):
-        tb = CacheBackend(backend="memory://")
-        tb.process_cleanup()
+        self.tb.process_cleanup()

     def test_expires_as_int(self):
         tb = CacheBackend(backend="memory://", expires=10)
@@ -129,8 +153,8 @@ class MockCacheMixin(object):
 class test_get_best_memcache(unittest.TestCase, MockCacheMixin):

     def test_pylibmc(self):
-        with reset_modules("celery.backends.cache"):
-            with self.mock_pylibmc():
+        with self.mock_pylibmc():
+            with reset_modules("celery.backends.cache"):
                 from celery.backends import cache
                 cache._imp = [None]
                 self.assertEqual(cache.get_best_memcache().__module__,
@@ -157,6 +181,7 @@ class test_get_best_memcache(unittest.TestCase, MockCacheMixin):
         with self.mock_pylibmc():
             with reset_modules("celery.backends.cache"):
                 from celery.backends import cache
+                cache._imp = [None]
                 cache.get_best_memcache(behaviors={"foo": "bar"})
                 self.assertTrue(cache._imp[0])
                 cache.get_best_memcache()

+ 3 - 3
docs/userguide/tasksets.rst

@@ -244,14 +244,14 @@ Example implementation:
         unlock_chord.retry(countdown=interval, max_retries=max_retries)


-This is used by all result backends except Redis, which increments a
+This is used by all result backends except Redis and Memcached, which increment a
 counter after each task in the header, then applying the callback when the
 counter exceeds the number of tasks in the set. *Note:* chords do not properly
 work with Redis before version 2.2; you will need to upgrade to at least 2.2 to
 use them.

-The Redis approach is a much better solution, but not easily implemented
-in other backends (suggestions welcome!).
+The Redis and Memcached approach is a much better solution, but not easily
+implemented in other backends (suggestions welcome!).


 .. note::