瀏覽代碼

KeyValueStoreBackend and BaseDictBackend now respects CELERY_RESULT_SERIALIZER.

Closes #435.
Steeve Morin 13 年之前
父節點
當前提交
1d20a4babb
共有 1 個文件被更改,包括 29 次插入19 次删除
  1. 29 19
      celery/backends/base.py

+ 29 - 19
celery/backends/base.py

@@ -4,14 +4,18 @@ import sys
 
 from datetime import timedelta
 
+from kombu import serialization
+
 from celery import states
 from celery.exceptions import TimeoutError, TaskRevokedError
 from celery.utils import timeutils
-from celery.utils.serialization import pickle, get_pickled_exception
+from celery.utils.serialization import get_pickled_exception
 from celery.utils.serialization import get_pickleable_exception
 from celery.utils.serialization import create_exception_cls
 from celery.datastructures import LocalCache
 
+EXCEPTION_ABLE_CODECS = frozenset(["pickle", "yaml"])
+
 
 def unpickle_backend(cls, args, kwargs):
     """Returns an unpickled backend."""
@@ -34,6 +38,11 @@ class BaseBackend(object):
     def __init__(self, *args, **kwargs):
         from celery.app import app_or_default
         self.app = app_or_default(kwargs.get("app"))
+        self.serializer = kwargs.get("serializer",
+                                     self.app.conf.CELERY_RESULT_SERIALIZER)
+        (self.content_type,
+         self.content_encoding,
+         self.encoder) = serialization.registry._encoders[self.serializer]
 
     def prepare_expires(self, value, type=None):
         if value is None:
@@ -80,16 +89,13 @@ class BaseBackend(object):
 
     def prepare_exception(self, exc):
         """Prepare exception for serialization."""
-        if (self.app.conf["CELERY_RESULT_SERIALIZER"] in ("pickle", "yaml")):
+        if self.serializer in EXCEPTION_ABLE_CODECS:
             return get_pickleable_exception(exc)
-        return {
-            "exc_type": type(exc).__name__,
-            "exc_message": str(exc),
-        }
+        return {"exc_type": type(exc).__name__, "exc_message": str(exc)}
 
     def exception_to_python(self, exc):
         """Convert serialized exception to Python exception."""
-        if (self.app.conf["CELERY_RESULT_SERIALIZER"] in ("pickle", "yaml")):
+        if self.serializer in EXCEPTION_ABLE_CODECS:
             return get_pickled_exception(exc)
         return create_exception_cls(exc["exc_type"].encode("utf-8"),
                                     sys.modules[__name__])
@@ -306,12 +312,12 @@ class KeyValueStoreBackend(BaseDictBackend):
     def _mget_to_results(self, values, keys):
         if hasattr(values, "items"):
             # client returns dict so mapping preserved.
-            return dict((self._strip_prefix(k), pickle.loads(str(v)))
+            return dict((self._strip_prefix(k), self.decode(v))
                             for k, v in values.iteritems()
                                 if v is not None)
         else:
             # client returns list so need to recreate mapping.
-            return dict((keys[i], pickle.loads(str(value)))
+            return dict((keys[i], self.decode(value))
                             for i, value in enumerate(values)
                                 if value is not None)
 
@@ -342,14 +348,23 @@ class KeyValueStoreBackend(BaseDictBackend):
     def _forget(self, task_id):
         self.delete(self.get_key_for_task(task_id))
 
+    def encode(self, data):
+        _, _, payload = serialization.encode(data, serializer=self.serializer)
+        return payload
+
+    def decode(self, payload):
+        return serialization.decode(str(payload),
+                                    content_type=self.content_type,
+                                    content_encoding=self.content_encoding)
+
     def _store_result(self, task_id, result, status, traceback=None):
         meta = {"status": status, "result": result, "traceback": traceback}
-        self.set(self.get_key_for_task(task_id), pickle.dumps(meta))
+        self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 
     def _save_taskset(self, taskset_id, result):
         self.set(self.get_key_for_taskset(taskset_id),
-                 pickle.dumps({"result": result}))
+                 self.encode({"result": result}))
         return result
 
     def _delete_taskset(self, taskset_id):
@@ -360,14 +375,13 @@ class KeyValueStoreBackend(BaseDictBackend):
         meta = self.get(self.get_key_for_task(task_id))
         if not meta:
             return {"status": states.PENDING, "result": None}
-        return pickle.loads(str(meta))
+        return self.decode(meta)
 
     def _restore_taskset(self, taskset_id):
         """Get task metadata for a task by id."""
         meta = self.get(self.get_key_for_taskset(taskset_id))
         if meta:
-            meta = pickle.loads(str(meta))
-            return meta
+            return self.decode(meta)
 
 
 class DisabledBackend(BaseBackend):
@@ -378,8 +392,4 @@ class DisabledBackend(BaseBackend):
     def _is_disabled(self, *args, **kwargs):
         raise NotImplementedError("No result backend configured.  "
                 "Please see the documentation for more information.")
-
-    wait_for = _is_disabled
-    get_status = _is_disabled
-    get_result = _is_disabled
-    get_traceback = _is_disabled
+    wait_for = get_status = get_result = get_traceback = _is_disabled