浏览代码

Implements join_native() + iter_native() for Redis and Cache backends

Ask Solem 14 年之前
父节点
当前提交
a40ab46dc8
共有 5 个文件被更改,包括 86 次插入和 20 次删除
  1. 2 2
      celery/backends/amqp.py
  2. 52 3
      celery/backends/base.py
  3. 7 0
      celery/backends/cache.py
  4. 3 0
      celery/backends/redis.py
  5. 22 15
      celery/result.py

+ 2 - 2
celery/backends/amqp.py

@@ -190,7 +190,7 @@ class AMQPBackend(BaseDictBackend):
             with self._create_consumer(binding, channel) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
 
-    def get_many(self, task_ids, timeout=None):
+    def get_many(self, task_ids, timeout=None, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
             cached_ids = set()
@@ -210,7 +210,7 @@ class AMQPBackend(BaseDictBackend):
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r.keys())
-                    for ready_id, ready_meta in r.items():
+                    for ready_id, ready_meta in r.iteritems():
                         yield ready_id, ready_meta
 
     def reload_task_result(self, task_id):

+ 52 - 3
celery/backends/base.py

@@ -19,6 +19,8 @@ class BaseBackend(object):
 
     TimeoutError = TimeoutError
 
+    can_get_many = False
+
     def __init__(self, *args, **kwargs):
         from celery.app import app_or_default
         self.app = app_or_default(kwargs.get("app"))
@@ -248,10 +250,15 @@ class BaseDictBackend(BaseBackend):
 
 
 class KeyValueStoreBackend(BaseDictBackend):
+    task_keyprefix = "celery-task-meta-"
+    taskset_keyprefix = "celery-taskset-meta-"
 
     def get(self, key):
         raise NotImplementedError("Must implement the get method.")
 
+    def mget(self, keys):
+        raise NotImplementedError("Does not support get_many")
+
     def set(self, key, value):
         raise NotImplementedError("Must implement the set method.")
 
@@ -260,11 +267,53 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
-        return "celery-task-meta-%s" % task_id
+        return self.task_keyprefix + task_id
 
-    def get_key_for_taskset(self, task_id):
+    def get_key_for_taskset(self, taskset_id):
         """Get the cache key for a task by id."""
-        return "celery-taskset-meta-%s" % task_id
+        return self.taskset_keyprefix + taskset_id
+
+    def _strip_prefix(self, key):
+        for prefix in self.task_keyprefix, self.taskset_keyprefix:
+            if key.startswith(prefix):
+                return key[len(prefix):]
+        return key
+
+    def _mget_to_results(self, values, keys):
+        if hasattr(values, "items"):
+            # client returns dict so mapping preserved.
+            return dict((self._strip_prefix(k), pickle.loads(str(v)))
+                            for k, v in values.iteritems()
+                                if v is not None)
+        else:
+            # client returns list so need to recreate mapping.
+            return dict((keys[i], pickle.loads(str(value)))
+                            for i, value in enumerate(values)
+                                if value is not None)
+
+    def get_many(self, task_ids, timeout=None, interval=0.5):
+        interval = 0.5 if interval is None else interval
+        ids = set(task_ids)
+        cached_ids = set()
+        for task_id in ids:
+            try:
+                cached = self._cache[task_id]
+            except KeyError:
+                pass
+            else:
+                if cached["status"] in states.READY_STATES:
+                    yield task_id, cached
+                    cached_ids.add(task_id)
+
+        ids ^= cached_ids
+        while ids:
+            keys = list(ids)
+            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
+                                                    for k in keys]), keys)
+            self._cache.update(r)
+            ids ^= set(r.keys())
+            for key, value in r.iteritems():
+                yield key, value
+            time.sleep(interval)  # don't busy loop.
 
     def _forget(self, task_id):
         self.delete(self.get_key_for_task(task_id))

+ 7 - 0
celery/backends/cache.py

@@ -41,6 +41,10 @@ class DummyClient(object):
     def get(self, key, *args, **kwargs):
         return self.cache.get(key)
 
+    def get_multi(self, keys):
+        cache = self.cache
+        return dict((k, cache[k]) for k in keys if k in cache)
+
     def set(self, key, value, *args, **kwargs):
         self.cache[key] = value
 
@@ -77,6 +81,9 @@ class CacheBackend(KeyValueStoreBackend):
     def get(self, key):
         return self.client.get(key)
 
+    def mget(self, keys):
+        return self.client.get_multi(keys)
+
     def set(self, key, value):
         return self.client.set(key, value, self.expires)
 

+ 3 - 0
celery/backends/redis.py

@@ -57,6 +57,9 @@ class RedisBackend(KeyValueStoreBackend):
     def get(self, key):
         return self.client.get(key)
 
+    def mget(self, keys):
+        return self.client.mget(keys)
+
     def set(self, key, value):
         client = self.client
         client.set(key, value)

+ 22 - 15
celery/result.py

@@ -390,33 +390,40 @@ class ResultSet(object):
                                        interval=interval))
         return results
 
-    def iter_native(self, timeout=None):
-        backend = self.results[0].backend
-        ids = [result.task_id for result in self.results]
-        return backend.get_many(ids, timeout=timeout)
-
-    def join_native(self, timeout=None, propagate=True):
-        """Backend optimized version of :meth:`join`.
+    def iter_native(self, timeout=None, interval=None):
+        """Backend optimized version of :meth:`iterate`.
 
         .. versionadded:: 2.2
 
         Note that this does not support collecting the results
         for different task types using different backends.
 
-        This is currently only supported by the AMQP result backend.
+        This is currently only supported by the AMQP, Redis and cache
+        result backends.
 
         """
         backend = self.results[0].backend
-        results = [None for _ in xrange(len(self.results))]
-
         ids = [result.task_id for result in self.results]
-        states = dict(backend.get_many(ids, timeout=timeout))
+        return backend.get_many(ids, timeout=timeout, interval=interval)
 
-        for task_id, meta in states.items():
-            index = self.results.index(task_id)
-            results[index] = meta["result"]
+    def join_native(self, timeout=None, propagate=True, interval=0.5):
+        """Backend optimized version of :meth:`join`.
 
-        return list(results)
+        .. versionadded:: 2.2
+
+        Note that this does not support collecting the results
+        for different task types using different backends.
+
+        This is currently only supported by the AMQP, Redis and cache
+        result backends.
+
+        """
+        results = self.results
+        acc = [None for _ in xrange(self.total)]
+        for task_id, meta in self.iter_native(timeout=timeout,
+                                              interval=interval):
+            acc[results.index(task_id)] = meta["result"]
+        return acc
 
     @property
     def total(self):