Browse Source

Cosmetics for #1990

Ask Solem 10 years ago
parent
commit
639b40f630
1 changed files with 27 additions and 76 deletions
  1. 27 76
      celery/backends/mongodb.py

+ 27 - 76
celery/backends/mongodb.py

@@ -28,44 +28,15 @@ else:                                       # pragma: no cover
 from kombu.syn import detect_environment
 from kombu.utils import cached_property
 from kombu.exceptions import EncodeError
-from kombu.serialization import register, disable_insecure_serializers
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.five import string_t
 from celery.utils.timeutils import maybe_timedelta
-from celery.result import AsyncResult
 
 from .base import BaseBackend
 
 __all__ = ['MongoBackend']
 
-BINARY_CODECS = frozenset(['pickle', 'msgpack'])
-
-# register a fake bson serializer which will return the document as it is
-
-
-class bson_serializer():
-
-    @staticmethod
-    def loads(obj, *args, **kwargs):
-        if isinstance(obj, string_t):
-            try:
-                from anyjson import loads
-                return loads(obj)
-            except:
-                pass
-        return obj
-
-    @staticmethod
-    def dumps(obj, *args, **kwargs):
-        return obj
-
-register('bson', bson_serializer.loads, bson_serializer.dumps,
-         content_type='application/data',
-         content_encoding='utf-8')
-
-disable_insecure_serializers(['json', 'bson'])
-
 
 class Bunch(object):
 
@@ -102,11 +73,6 @@ class MongoBackend(BaseBackend):
         self.expires = kwargs.get('expires') or maybe_timedelta(
             self.app.conf.CELERY_TASK_RESULT_EXPIRES)
 
-        # little hack to get over standard kombu loads because
-        # mongo return strings which don't get decoded!
-        if self.serializer == 'bson':
-            self.decode = self.decode_bson
-
         if not pymongo:
             raise ImproperlyConfigured(
                 'You need to install the pymongo library to use the '
@@ -175,21 +141,15 @@ class MongoBackend(BaseBackend):
             self._connection = None
 
     def encode(self, data):
-        payload = super(MongoBackend, self).encode(data)
-        # serializer which are in a unsupported format (pickle/binary)
-        if self.serializer in BINARY_CODECS:
-            payload = Binary(payload)
-
-        return payload
-
-    def decode_bson(self, data):
-        return bson_serializer.loads(data)
+        if self.serializer == 'bson':
+            # mongodb handles serialization
+            return data
+        return super(MongoBackend, self).encode(data)
 
-    def encode_result(self, result, status):
-        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
-            return self.prepare_exception(result)
-        else:
-            return self.prepare_value(result)
+    def decode(self, data):
+        if self.serializer == 'bson':
+            return data
+        return super(MongoBackend, self).decode(data)
 
     def _store_result(self, task_id, result, status,
                       traceback=None, request=None, **kwargs):
@@ -213,24 +173,20 @@ class MongoBackend(BaseBackend):
 
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
-
         # if collection don't contain it try searching in the
         # group_collection it could be a groupresult instead
         obj = self.collection.find_one({'_id': task_id}) or \
             self.group_collection.find_one({'_id': task_id})
-        if not obj:
-            return {'status': states.PENDING, 'result': None}
-
-        meta = {
-            'task_id': obj['_id'],
-            'status': obj['status'],
-            'result': self.decode(obj['result']),
-            'date_done': obj['date_done'],
-            'traceback': self.decode(obj['traceback']),
-            'children': self.decode(obj['children']),
-        }
-
-        return meta
+        if obj:
+            return {
+                'task_id': obj['_id'],
+                'status': obj['status'],
+                'result': self.decode(obj['result']),
+                'date_done': obj['date_done'],
+                'traceback': self.decode(obj['traceback']),
+                'children': self.decode(obj['children']),
+            }
+        return {'status': states.PENDING, 'result': None}
 
     def _save_group(self, group_id, result):
         """Save the group result."""
@@ -247,20 +203,15 @@ class MongoBackend(BaseBackend):
     def _restore_group(self, group_id):
         """Get the result for a group by id."""
         obj = self.group_collection.find_one({'_id': group_id})
-        if not obj:
-            return
-
-        tasks = self.decode(obj['result'])
-
-        tasks = [AsyncResult(task) for task in tasks]
-
-        meta = {
-            'task_id': obj['_id'],
-            'result': tasks,
-            'date_done': obj['date_done'],
-        }
-
-        return meta
+        if obj:
+            tasks = [self.app.AsyncResult(task)
+                     for task in self.decode(obj['result'])]
+
+            return {
+                'task_id': obj['_id'],
+                'result': tasks,
+                'date_done': obj['date_done'],
+            }
 
     def _delete_group(self, group_id):
         """Delete a group by id."""