|
@@ -4,14 +4,18 @@ import sys
|
|
|
|
|
|
from datetime import timedelta
|
|
|
|
|
|
+from kombu import serialization
|
|
|
+
|
|
|
from celery import states
|
|
|
from celery.exceptions import TimeoutError, TaskRevokedError
|
|
|
from celery.utils import timeutils
|
|
|
-from celery.utils.serialization import pickle, get_pickled_exception
|
|
|
+from celery.utils.serialization import get_pickled_exception
|
|
|
from celery.utils.serialization import get_pickleable_exception
|
|
|
from celery.utils.serialization import create_exception_cls
|
|
|
from celery.datastructures import LocalCache
|
|
|
|
|
|
+EXCEPTION_ABLE_CODECS = frozenset(["pickle", "yaml"])
|
|
|
+
|
|
|
|
|
|
def unpickle_backend(cls, args, kwargs):
|
|
|
"""Returns an unpickled backend."""
|
|
@@ -34,6 +38,11 @@ class BaseBackend(object):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
from celery.app import app_or_default
|
|
|
self.app = app_or_default(kwargs.get("app"))
|
|
|
+ self.serializer = kwargs.get("serializer",
|
|
|
+ self.app.conf.CELERY_RESULT_SERIALIZER)
|
|
|
+ (self.content_type,
|
|
|
+ self.content_encoding,
|
|
|
+ self.encoder) = serialization.registry._encoders[self.serializer]
|
|
|
|
|
|
def prepare_expires(self, value, type=None):
|
|
|
if value is None:
|
|
@@ -80,16 +89,13 @@ class BaseBackend(object):
|
|
|
|
|
|
def prepare_exception(self, exc):
|
|
|
"""Prepare exception for serialization."""
|
|
|
- if (self.app.conf["CELERY_RESULT_SERIALIZER"] in ("pickle", "yaml")):
|
|
|
+ if self.serializer in EXCEPTION_ABLE_CODECS:
|
|
|
return get_pickleable_exception(exc)
|
|
|
- return {
|
|
|
- "exc_type": type(exc).__name__,
|
|
|
- "exc_message": str(exc),
|
|
|
- }
|
|
|
+ return {"exc_type": type(exc).__name__, "exc_message": str(exc)}
|
|
|
|
|
|
def exception_to_python(self, exc):
|
|
|
"""Convert serialized exception to Python exception."""
|
|
|
- if (self.app.conf["CELERY_RESULT_SERIALIZER"] in ("pickle", "yaml")):
|
|
|
+ if self.serializer in EXCEPTION_ABLE_CODECS:
|
|
|
return get_pickled_exception(exc)
|
|
|
return create_exception_cls(exc["exc_type"].encode("utf-8"),
|
|
|
sys.modules[__name__])
|
|
@@ -306,12 +312,12 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
def _mget_to_results(self, values, keys):
|
|
|
if hasattr(values, "items"):
|
|
|
# client returns dict so mapping preserved.
|
|
|
- return dict((self._strip_prefix(k), pickle.loads(str(v)))
|
|
|
+ return dict((self._strip_prefix(k), self.decode(v))
|
|
|
for k, v in values.iteritems()
|
|
|
if v is not None)
|
|
|
else:
|
|
|
# client returns list so need to recreate mapping.
|
|
|
- return dict((keys[i], pickle.loads(str(value)))
|
|
|
+ return dict((keys[i], self.decode(value))
|
|
|
for i, value in enumerate(values)
|
|
|
if value is not None)
|
|
|
|
|
@@ -342,14 +348,23 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
def _forget(self, task_id):
|
|
|
self.delete(self.get_key_for_task(task_id))
|
|
|
|
|
|
+ def encode(self, data):
|
|
|
+ _, _, payload = serialization.encode(data, serializer=self.serializer)
|
|
|
+ return payload
|
|
|
+
|
|
|
+ def decode(self, payload):
|
|
|
+ return serialization.decode(str(payload),
|
|
|
+ content_type=self.content_type,
|
|
|
+ content_encoding=self.content_encoding)
|
|
|
+
|
|
|
def _store_result(self, task_id, result, status, traceback=None):
|
|
|
meta = {"status": status, "result": result, "traceback": traceback}
|
|
|
- self.set(self.get_key_for_task(task_id), pickle.dumps(meta))
|
|
|
+ self.set(self.get_key_for_task(task_id), self.encode(meta))
|
|
|
return result
|
|
|
|
|
|
def _save_taskset(self, taskset_id, result):
|
|
|
self.set(self.get_key_for_taskset(taskset_id),
|
|
|
- pickle.dumps({"result": result}))
|
|
|
+ self.encode({"result": result}))
|
|
|
return result
|
|
|
|
|
|
def _delete_taskset(self, taskset_id):
|
|
@@ -360,14 +375,13 @@ class KeyValueStoreBackend(BaseDictBackend):
|
|
|
meta = self.get(self.get_key_for_task(task_id))
|
|
|
if not meta:
|
|
|
return {"status": states.PENDING, "result": None}
|
|
|
- return pickle.loads(str(meta))
|
|
|
+ return self.decode(meta)
|
|
|
|
|
|
def _restore_taskset(self, taskset_id):
|
|
|
"""Get task metadata for a task by id."""
|
|
|
meta = self.get(self.get_key_for_taskset(taskset_id))
|
|
|
if meta:
|
|
|
- meta = pickle.loads(str(meta))
|
|
|
- return meta
|
|
|
+ return self.decode(meta)
|
|
|
|
|
|
|
|
|
class DisabledBackend(BaseBackend):
|
|
@@ -378,8 +392,4 @@ class DisabledBackend(BaseBackend):
|
|
|
def _is_disabled(self, *args, **kwargs):
|
|
|
raise NotImplementedError("No result backend configured. "
|
|
|
"Please see the documentation for more information.")
|
|
|
-
|
|
|
- wait_for = _is_disabled
|
|
|
- get_status = _is_disabled
|
|
|
- get_result = _is_disabled
|
|
|
- get_traceback = _is_disabled
|
|
|
+ wait_for = get_status = get_result = get_traceback = _is_disabled
|