|
@@ -13,7 +13,7 @@ from .. import states
|
|
|
from ..datastructures import LRUCache
|
|
|
from ..exceptions import TimeoutError, TaskRevokedError
|
|
|
from ..utils import timeutils
|
|
|
-from ..utils.encoding import ensure_bytes, from_utf8
|
|
|
+from ..utils.encoding import bytes_to_str, ensure_bytes, from_utf8
|
|
|
from ..utils.serialization import (get_pickled_exception,
|
|
|
get_pickleable_exception,
|
|
|
create_exception_cls)
|
|
@@ -299,9 +299,9 @@ class BaseDictBackend(BaseBackend):
|
|
|
|
|
|
|
|
|
class KeyValueStoreBackend(BaseDictBackend):
|
|
|
- task_keyprefix = "celery-task-meta-"
|
|
|
- taskset_keyprefix = "celery-taskset-meta-"
|
|
|
- chord_keyprefix = "chord-unlock-"
|
|
|
+ task_keyprefix = ensure_bytes("celery-task-meta-")
|
|
|
+ taskset_keyprefix = ensure_bytes("celery-taskset-meta-")
|
|
|
+ chord_keyprefix = ensure_bytes("chord-unlock-")
|
|
|
|
|
|
def get(self, key):
|
|
|
raise NotImplementedError("Must implement the get method.")
|
|
@@ -317,21 +317,22 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
|
|
|
def get_key_for_task(self, task_id):
|
|
|
"""Get the cache key for a task by id."""
|
|
|
- return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)
|
|
|
+ return self.task_keyprefix + ensure_bytes(task_id)
|
|
|
|
|
|
def get_key_for_taskset(self, taskset_id):
|
|
|
"""Get the cache key for a taskset by id."""
|
|
|
- return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
|
|
|
+ return self.taskset_keyprefix + ensure_bytes(taskset_id)
|
|
|
|
|
|
def get_key_for_chord(self, taskset_id):
|
|
|
"""Get the cache key for the chord waiting on taskset with given id."""
|
|
|
- return ensure_bytes(self.chord_keyprefix) + ensure_bytes(taskset_id)
|
|
|
+ return self.chord_keyprefix + ensure_bytes(taskset_id)
|
|
|
|
|
|
def _strip_prefix(self, key):
|
|
|
+ """Takes bytes, emits string."""
|
|
|
for prefix in self.task_keyprefix, self.taskset_keyprefix:
|
|
|
if key.startswith(prefix):
|
|
|
- return key[len(prefix):]
|
|
|
- return key
|
|
|
+ return bytes_to_str(key[len(prefix):])
|
|
|
+ return bytes_to_str(key)
|
|
|
|
|
|
def _mget_to_results(self, values, keys):
|
|
|
if hasattr(values, "items"):
|
|
@@ -341,7 +342,7 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
if v is not None)
|
|
|
else:
|
|
|
# client returns list so need to recreate mapping.
|
|
|
- return dict((keys[i], self.decode(value))
|
|
|
+ return dict((bytes_to_str(keys[i]), self.decode(value))
|
|
|
for i, value in enumerate(values)
|
|
|
if value is not None)
|
|
|
|
|
@@ -355,7 +356,7 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
pass
|
|
|
else:
|
|
|
if cached["status"] in states.READY_STATES:
|
|
|
- yield task_id, cached
|
|
|
+ yield bytes_to_str(task_id), cached
|
|
|
cached_ids.add(task_id)
|
|
|
|
|
|
ids ^= cached_ids
|
|
@@ -365,9 +366,9 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
r = self._mget_to_results(self.mget([self.get_key_for_task(k)
|
|
|
for k in keys]), keys)
|
|
|
self._cache.update(r)
|
|
|
- ids ^= set(r)
|
|
|
+ ids ^= set(map(bytes_to_str, r))
|
|
|
for key, value in r.iteritems():
|
|
|
- yield key, value
|
|
|
+ yield bytes_to_str(key), value
|
|
|
if timeout and iterations * interval >= timeout:
|
|
|
raise TimeoutError("Operation timed out (%s)" % (timeout, ))
|
|
|
time.sleep(interval) # don't busy loop.
|