|
@@ -19,6 +19,8 @@ class BaseBackend(object):
|
|
|
|
|
|
TimeoutError = TimeoutError
|
|
|
|
|
|
+ can_get_many = False
|
|
|
+
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
from celery.app import app_or_default
|
|
|
self.app = app_or_default(kwargs.get("app"))
|
|
@@ -248,10 +250,15 @@ class BaseDictBackend(BaseBackend):
|
|
|
|
|
|
|
|
|
class KeyValueStoreBackend(BaseDictBackend):
|
|
|
+ task_keyprefix = "celery-task-meta-"
|
|
|
+ taskset_keyprefix = "celery-taskset-meta-"
|
|
|
|
|
|
def get(self, key):
|
|
|
raise NotImplementedError("Must implement the get method.")
|
|
|
|
|
|
+ def mget(self, keys):
|
|
|
+ raise NotImplementedError("Does not support get_many")
|
|
|
+
|
|
|
def set(self, key, value):
|
|
|
raise NotImplementedError("Must implement the set method.")
|
|
|
|
|
@@ -260,11 +267,53 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
|
|
|
def get_key_for_task(self, task_id):
|
|
|
"""Get the cache key for a task by id."""
|
|
|
- return "celery-task-meta-%s" % task_id
|
|
|
+ return self.task_keyprefix + task_id
|
|
|
|
|
|
- def get_key_for_taskset(self, task_id):
|
|
|
+ def get_key_for_taskset(self, taskset_id):
|
|
|
"""Get the cache key for a task by id."""
|
|
|
- return "celery-taskset-meta-%s" % task_id
|
|
|
+ return self.taskset_keyprefix + taskset_id
|
|
|
+
|
|
|
+ def _strip_prefix(self, key):
|
|
|
+ for prefix in self.task_keyprefix, self.taskset_keyprefix:
|
|
|
+ if key.startswith(prefix):
|
|
|
+ return key[len(prefix):]
|
|
|
+ return key
|
|
|
+
|
|
|
+ def _mget_to_results(self, values, keys):
|
|
|
+ if hasattr(values, "items"):
|
|
|
+ # client returns dict so mapping preserved.
|
|
|
+ return dict((self._strip_prefix(k), pickle.loads(str(v)))
|
|
|
+ for k, v in values.iteritems()
|
|
|
+ if v is not None)
|
|
|
+ else:
|
|
|
+ # client returns list so need to recreate mapping.
|
|
|
+ return dict((keys[i], pickle.loads(str(value)))
|
|
|
+ for i, value in enumerate(values)
|
|
|
+ if value is not None)
|
|
|
+
|
|
|
+ def get_many(self, task_ids, timeout=None, interval=0.5):
|
|
|
+ ids = set(task_ids)
|
|
|
+ cached_ids = set()
|
|
|
+ for task_id in ids:
|
|
|
+ try:
|
|
|
+ cached = self._cache[task_id]
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ if cached["status"] in states.READY_STATES:
|
|
|
+ yield task_id, cached
|
|
|
+ cached_ids.add(taskid)
|
|
|
+
|
|
|
+ ids ^= cached_ids
|
|
|
+ while ids:
|
|
|
+ keys = list(ids)
|
|
|
+ r = self._mget_to_results(self.mget([self.get_key_for_task(k)
|
|
|
+ for k in keys]), keys)
|
|
|
+ self._cache.update(r)
|
|
|
+ ids ^= set(r.keys())
|
|
|
+ for key, value in r.iteritems():
|
|
|
+ yield key, value
|
|
|
+ time.sleep(interval) # don't busy loop.
|
|
|
|
|
|
def _forget(self, task_id):
|
|
|
self.delete(self.get_key_for_task(task_id))
|