
Adds chord callback support for the amqp backend

Ask Solem · 13 years ago · commit 9bac5498eb

+ 2 - 0
celery/backends/amqp.py

@@ -37,6 +37,8 @@ class AMQPBackend(BaseDictBackend):
 
     BacklogLimitExceeded = BacklogLimitExceeded
 
+    supports_native_join = True
+
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
             **kwargs):

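With supports_native_join enabled, a ResultSet backed by this backend can drain all results through the backend's get_many() in a single pass instead of polling each AsyncResult in turn. A minimal sketch of the dispatch this flag enables (collect() is a hypothetical helper, not part of this commit):

    def collect(res):
        # join_native() fetches via the backend's get_many();
        # plain join() polls each member result individually.
        if res.supports_native_join:
            return res.join_native()
        return res.join()
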
+ 5 - 1
celery/backends/base.py

@@ -40,6 +40,9 @@ class BaseBackend(object):
     #: argument which is for each pass.
     subpolling_interval = None
 
+    #: If true, the backend must implement :meth:`get_many`.
+    supports_native_join = False
+
     def __init__(self, *args, **kwargs):
         from ..app import app_or_default
         self.app = app_or_default(kwargs.get("app"))
@@ -203,8 +206,9 @@ class BaseBackend(object):
     def on_chord_part_return(self, task):
         pass
 
-    def on_chord_apply(self, setid, body, *args, **kwargs):
+    def on_chord_apply(self, setid, body, result=None, **kwargs):
         from ..registry import tasks
+        kwargs["result"] = [r.task_id for r in result]
         tasks["celery.chord_unlock"].apply_async((setid, body, ), kwargs,
                                                  countdown=1)
 

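The base implementation now flattens the group's results to bare task ids before scheduling celery.chord_unlock, because only the ids can be serialized into the task message. Roughly what happens on each side (the set id string is a placeholder):

    from celery.result import AsyncResult, TaskSetResult
    from celery.utils import uuid

    results = [AsyncResult(uuid()) for _ in range(3)]
    payload = [r.task_id for r in results]   # plain strings, message-safe
    # celery.chord_unlock later rebuilds the set from those ids:
    restored = TaskSetResult("set-id-placeholder", map(AsyncResult, payload))
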
+ 1 - 0
celery/backends/cache.py

@@ -64,6 +64,7 @@ backends = {"memcache": lambda: get_best_memcache,
 
 class CacheBackend(KeyValueStoreBackend):
     servers = None
+    supports_native_join = True
 
     def __init__(self, expires=None, backend=None, options={}, **kwargs):
         super(CacheBackend, self).__init__(self, **kwargs)

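The cache (memcached) backend gets the same flag, so a memcached result store also qualifies for the native join path. Illustrative settings for exercising it (host and port are assumptions):

    CELERY_RESULT_BACKEND = "cache"
    CELERY_CACHE_BACKEND = "memcache://127.0.0.1:11211/"
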
+ 4 - 2
celery/backends/redis.py

@@ -32,6 +32,8 @@ class RedisBackend(KeyValueStoreBackend):
     #: default Redis password (:const:`None`)
     password = None
 
+    supports_native_join = True
+
     def __init__(self, host=None, port=None, db=None, password=None,
             expires=None, **kwargs):
         super(RedisBackend, self).__init__(**kwargs)
@@ -70,7 +72,7 @@ class RedisBackend(KeyValueStoreBackend):
     def delete(self, key):
         self.client.delete(key)
 
-    def on_chord_apply(self, setid, *args, **kwargs):
-        pass
+    def on_chord_apply(self, setid, body, result=None, **kwargs):
+        self.app.TaskSetResult(setid, result).save()
 
     def on_chord_part_return(self, task, propagate=False,

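Redis takes a different route: on_chord_apply only persists the TaskSetResult and schedules no chord_unlock task, leaving on_chord_part_return to restore the set by id as each member finishes. A rough sketch of that pattern, assuming the fields available on the task request at the time (the body shown is illustrative, not this commit's code):

    def on_chord_part_return(self, task, propagate=False):
        from celery.result import TaskSetResult
        from celery.task.sets import subtask
        setid = task.request.taskset
        if setid is None:
            return
        # restore the set saved by on_chord_apply; fire the callback
        # once every member has produced a result
        deps = TaskSetResult.restore(setid, backend=task.backend)
        if deps is not None and deps.ready():
            subtask(task.request.chord).delay(deps.join(propagate=propagate))
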
+ 4 - 0
celery/result.py

@@ -451,6 +451,10 @@ class ResultSet(object):
         """Deprecated alias to :attr:`results`."""
         return self.results
 
+    @property
+    def supports_native_join(self):
+        return self.results[0].backend.supports_native_join
+
 
 class TaskSetResult(ResultSet):
     """An instance of this class is returned by

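The property delegates to the first member's backend, which assumes a non-empty set whose members all share one backend. It lets callers pick the cheapest join strategy, e.g. (ts stands in for some existing TaskSet):

    res = ts.apply_async()
    join = res.join_native if res.supports_native_join else res.join
    values = join()
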
+ 9 - 8
celery/task/chords.py

@@ -12,7 +12,7 @@
 from __future__ import absolute_import
 
 from .. import current_app
-from ..result import TaskSetResult
+from ..result import AsyncResult, TaskSetResult
 from ..utils import uuid
 
 from .sets import TaskSet, subtask
@@ -20,11 +20,11 @@ from .sets import TaskSet, subtask
 
 @current_app.task(name="celery.chord_unlock", max_retries=None)
 def _unlock_chord(setid, callback, interval=1, propagate=False,
-        max_retries=None):
-    result = TaskSetResult.restore(setid)
+        max_retries=None, result=None):
+    result = TaskSetResult(setid, map(AsyncResult, result))
     if result.ready():
-        subtask(callback).delay(result.join(propagate=propagate))
-        result.delete()
+        j = result.join_native if result.supports_native_join else result.join
+        subtask(callback).delay(j(propagate=propagate))
     else:
         _unlock_chord.retry(countdown=interval, max_retries=max_retries)
 
@@ -43,10 +43,11 @@ class Chord(current_app.Task):
             tid = uuid()
             task.options.update(task_id=tid, chord=body)
             r.append(current_app.AsyncResult(tid))
-        current_app.TaskSetResult(setid, r).save()
-        self.backend.on_chord_apply(setid, body, interval,
+        self.backend.on_chord_apply(setid, body,
+                                    interval=interval,
                                     max_retries=max_retries,
-                                    propagate=propagate)
+                                    propagate=propagate,
+                                    result=r)
         return set.apply_async(taskset_id=setid)
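
End to end, applying a chord now threads the member results into on_chord_apply, and celery.chord_unlock rebuilds the TaskSetResult from their ids, joining natively where the backend allows it. A usage sketch (add and tsum are illustrative tasks, not part of this change):

    from celery.task import task
    from celery.task.chords import chord

    @task
    def add(x, y):
        return x + y

    @task
    def tsum(numbers):
        return sum(numbers)

    # header tasks run in parallel; tsum receives the list of their
    # return values once every member is ready
    chord(add.subtask((i, i)) for i in range(10))(tsum.subtask())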