Bläddra i källkod

Implements join_native() + iter_native() for Redis and Cache backends

Ask Solem 14 år sedan
förälder
incheckning
a40ab46dc8
5 ändrade filer med 86 tillägg och 20 borttagningar
  1. 2 2
      celery/backends/amqp.py
  2. 52 3
      celery/backends/base.py
  3. 7 0
      celery/backends/cache.py
  4. 3 0
      celery/backends/redis.py
  5. 22 15
      celery/result.py

+ 2 - 2
celery/backends/amqp.py

@@ -190,7 +190,7 @@ class AMQPBackend(BaseDictBackend):
             with self._create_consumer(binding, channel) as consumer:
             with self._create_consumer(binding, channel) as consumer:
                 return self.drain_events(conn, consumer, timeout).values()[0]
                 return self.drain_events(conn, consumer, timeout).values()[0]
 
 
-    def get_many(self, task_ids, timeout=None):
+    def get_many(self, task_ids, timeout=None, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
             ids = set(task_ids)
             cached_ids = set()
             cached_ids = set()
@@ -210,7 +210,7 @@ class AMQPBackend(BaseDictBackend):
                 while ids:
                 while ids:
                     r = self.drain_events(conn, consumer, timeout)
                     r = self.drain_events(conn, consumer, timeout)
                     ids ^= set(r.keys())
                     ids ^= set(r.keys())
-                    for ready_id, ready_meta in r.items():
+                    for ready_id, ready_meta in r.iteritems():
                         yield ready_id, ready_meta
                         yield ready_id, ready_meta
 
 
     def reload_task_result(self, task_id):
     def reload_task_result(self, task_id):

+ 52 - 3
celery/backends/base.py

@@ -19,6 +19,8 @@ class BaseBackend(object):
 
 
     TimeoutError = TimeoutError
     TimeoutError = TimeoutError
 
 
+    can_get_many = False
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         from celery.app import app_or_default
         from celery.app import app_or_default
         self.app = app_or_default(kwargs.get("app"))
         self.app = app_or_default(kwargs.get("app"))
@@ -248,10 +250,15 @@ class BaseDictBackend(BaseBackend):
 
 
 
 
 class KeyValueStoreBackend(BaseDictBackend):
 class KeyValueStoreBackend(BaseDictBackend):
+    task_keyprefix = "celery-task-meta-"
+    taskset_keyprefix = "celery-taskset-meta-"
 
 
     def get(self, key):
     def get(self, key):
         raise NotImplementedError("Must implement the get method.")
         raise NotImplementedError("Must implement the get method.")
 
 
def mget(self, keys):
    """Batch-fetch several keys at once.

    Key/value backends that support batched retrieval (Redis ``MGET``,
    memcached ``get_multi``) override this; the base class cannot.
    """
    raise NotImplementedError("Does not support get_many")
     def set(self, key, value):
     def set(self, key, value):
         raise NotImplementedError("Must implement the set method.")
         raise NotImplementedError("Must implement the set method.")
 
 
@@ -260,11 +267,53 @@ class KeyValueStoreBackend(BaseDictBackend):
 
 
     def get_key_for_task(self, task_id):
     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
         """Get the cache key for a task by id."""
-        return "celery-task-meta-%s" % task_id
+        return self.task_keyprefix + task_id
 
 
def get_key_for_taskset(self, taskset_id):
    """Get the cache key for a taskset by id."""
    # NOTE: docstring previously said "for a task" — copy-paste from
    # get_key_for_task; this method keys tasksets.
    return self.taskset_keyprefix + taskset_id
+
+    def _strip_prefix(self, key):
+        for prefix in self.task_keyprefix, self.taskset_keyprefix:
+            if key.startswith(prefix):
+                return key[len(prefix):]
+        return key
+
+    def _mget_to_results(self, values, keys):
+        if hasattr(values, "items"):
+            # client returns dict so mapping preserved.
+            return dict((self._strip_prefix(k), pickle.loads(str(v)))
+                            for k, v in values.iteritems()
+                                if v is not None)
+        else:
+            # client returns list so need to recreate mapping.
+            return dict((keys[i], pickle.loads(str(value)))
+                            for i, value in enumerate(values)
+                                if value is not None)
+
def get_many(self, task_ids, timeout=None, interval=0.5):
    """Iterate over ``(task_id, meta)`` pairs as results become ready.

    Already-cached ready results are yielded first, then the store is
    polled with :meth:`mget` for the remaining ids.

    :param task_ids: ids of the tasks to collect results for.
    :keyword timeout: seconds to wait before raising
        :exc:`TimeoutError` (``None`` means wait forever).
    :keyword interval: seconds to sleep between polls; ``None`` falls
        back to 0.5 (``ResultSet.iter_native`` passes ``None`` through,
        which previously crashed ``time.sleep``).

    """
    # Guard against interval=None forwarded by iter_native().
    interval = 0.5 if interval is None else interval
    ids = set(task_ids)
    cached_ids = set()
    for task_id in ids:
        try:
            cached = self._cache[task_id]
        except KeyError:
            pass
        else:
            if cached["status"] in states.READY_STATES:
                yield task_id, cached
                cached_ids.add(task_id)  # fixed NameError: was `taskid`

    ids.difference_update(cached_ids)
    elapsed = 0.0
    while ids:
        keys = list(ids)
        r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                for k in keys]), keys)
        self._cache.update(r)
        ids.difference_update(set(r.keys()))
        for key, value in r.iteritems():
            yield key, value
        if ids:
            time.sleep(interval)  # don't busy loop.
            elapsed += interval
            # Honor `timeout` (the AMQP backend's get_many already does).
            if timeout is not None and elapsed >= timeout:
                raise self.TimeoutError("The operation timed out.")
 
 
     def _forget(self, task_id):
     def _forget(self, task_id):
         self.delete(self.get_key_for_task(task_id))
         self.delete(self.get_key_for_task(task_id))

+ 7 - 0
celery/backends/cache.py

@@ -41,6 +41,10 @@ class DummyClient(object):
     def get(self, key, *args, **kwargs):
     def get(self, key, *args, **kwargs):
         return self.cache.get(key)
         return self.cache.get(key)
 
 
def get_multi(self, keys):
    """Return a mapping of the requested keys to their cached values.

    Keys missing from the cache are silently omitted, mirroring the
    memcached ``get_multi`` contract.
    """
    found = {}
    for key in keys:
        try:
            found[key] = self.cache[key]
        except KeyError:
            pass
    return found
+
     def set(self, key, value, *args, **kwargs):
     def set(self, key, value, *args, **kwargs):
         self.cache[key] = value
         self.cache[key] = value
 
 
@@ -77,6 +81,9 @@ class CacheBackend(KeyValueStoreBackend):
     def get(self, key):
     def get(self, key):
         return self.client.get(key)
         return self.client.get(key)
 
 
def mget(self, keys):
    """Fetch many keys in one round-trip via the client's ``get_multi``."""
    client = self.client
    return client.get_multi(keys)
     def set(self, key, value):
     def set(self, key, value):
         return self.client.set(key, value, self.expires)
         return self.client.set(key, value, self.expires)
 
 

+ 3 - 0
celery/backends/redis.py

@@ -57,6 +57,9 @@ class RedisBackend(KeyValueStoreBackend):
     def get(self, key):
     def get(self, key):
         return self.client.get(key)
         return self.client.get(key)
 
 
def mget(self, keys):
    """Fetch many keys in one round-trip using Redis ``MGET``."""
    client = self.client
    return client.mget(keys)
     def set(self, key, value):
     def set(self, key, value):
         client = self.client
         client = self.client
         client.set(key, value)
         client.set(key, value)

+ 22 - 15
celery/result.py

@@ -390,33 +390,40 @@ class ResultSet(object):
                                        interval=interval))
                                        interval=interval))
         return results
         return results
 
 
def iter_native(self, timeout=None, interval=None):
    """Backend optimized version of :meth:`iterate`.

    .. versionadded:: 2.2

    Note that this does not support collecting the results
    for different task types using different backends.

    This is currently only supported by the AMQP, Redis and cache
    result backends.

    """
    # All results are assumed to share the first result's backend.
    ids = [result.task_id for result in self.results]
    return self.results[0].backend.get_many(ids, timeout=timeout,
                                            interval=interval)
 
def join_native(self, timeout=None, propagate=True, interval=0.5):
    """Backend optimized version of :meth:`join`.

    .. versionadded:: 2.2

    Note that this does not support collecting the results
    for different task types using different backends.

    This is currently only supported by the AMQP, Redis and cache
    result backends.

    :keyword propagate: accepted for signature compatibility with
        :meth:`join`; errors are NOT re-raised here — TODO confirm
        intended behavior.

    """
    ordered = [None] * self.total
    # Precompute each task id's position: the previous
    # ``results.index(task_id)`` performed an O(n) scan per collected
    # result (quadratic overall) and relied on AsyncResult.__eq__.
    position = dict((result.task_id, i)
                        for i, result in enumerate(self.results))
    for task_id, meta in self.iter_native(timeout=timeout,
                                          interval=interval):
        ordered[position[task_id]] = meta["result"]
    return ordered
 
 
     @property
     @property
     def total(self):
     def total(self):