|
@@ -28,44 +28,15 @@ else: # pragma: no cover
|
|
|
from kombu.syn import detect_environment
|
|
|
from kombu.utils import cached_property
|
|
|
from kombu.exceptions import EncodeError
|
|
|
-from kombu.serialization import register, disable_insecure_serializers
|
|
|
from celery import states
|
|
|
from celery.exceptions import ImproperlyConfigured
|
|
|
from celery.five import string_t
|
|
|
from celery.utils.timeutils import maybe_timedelta
|
|
|
-from celery.result import AsyncResult
|
|
|
|
|
|
from .base import BaseBackend
|
|
|
|
|
|
__all__ = ['MongoBackend']
|
|
|
|
|
|
-BINARY_CODECS = frozenset(['pickle', 'msgpack'])
|
|
|
-
|
|
|
-# register a fake bson serializer which will return the document as it is
|
|
|
-
|
|
|
-
|
|
|
-class bson_serializer():
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def loads(obj, *args, **kwargs):
|
|
|
- if isinstance(obj, string_t):
|
|
|
- try:
|
|
|
- from anyjson import loads
|
|
|
- return loads(obj)
|
|
|
- except:
|
|
|
- pass
|
|
|
- return obj
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def dumps(obj, *args, **kwargs):
|
|
|
- return obj
|
|
|
-
|
|
|
-register('bson', bson_serializer.loads, bson_serializer.dumps,
|
|
|
- content_type='application/data',
|
|
|
- content_encoding='utf-8')
|
|
|
-
|
|
|
-disable_insecure_serializers(['json', 'bson'])
|
|
|
-
|
|
|
|
|
|
class Bunch(object):
|
|
|
|
|
@@ -102,11 +73,6 @@ class MongoBackend(BaseBackend):
|
|
|
self.expires = kwargs.get('expires') or maybe_timedelta(
|
|
|
self.app.conf.CELERY_TASK_RESULT_EXPIRES)
|
|
|
|
|
|
- # little hack to get over standard kombu loads because
|
|
|
- # mongo return strings which don't get decoded!
|
|
|
- if self.serializer == 'bson':
|
|
|
- self.decode = self.decode_bson
|
|
|
-
|
|
|
if not pymongo:
|
|
|
raise ImproperlyConfigured(
|
|
|
'You need to install the pymongo library to use the '
|
|
@@ -175,21 +141,15 @@ class MongoBackend(BaseBackend):
|
|
|
self._connection = None
|
|
|
|
|
|
def encode(self, data):
|
|
|
- payload = super(MongoBackend, self).encode(data)
|
|
|
- # serializer which are in a unsupported format (pickle/binary)
|
|
|
- if self.serializer in BINARY_CODECS:
|
|
|
- payload = Binary(payload)
|
|
|
-
|
|
|
- return payload
|
|
|
-
|
|
|
- def decode_bson(self, data):
|
|
|
- return bson_serializer.loads(data)
|
|
|
+ if self.serializer == 'bson':
|
|
|
+ # mongodb handles serialization
|
|
|
+ return data
|
|
|
+ return super(MongoBackend, self).encode(data)
|
|
|
|
|
|
- def encode_result(self, result, status):
|
|
|
- if status in self.EXCEPTION_STATES and isinstance(result, Exception):
|
|
|
- return self.prepare_exception(result)
|
|
|
- else:
|
|
|
- return self.prepare_value(result)
|
|
|
+ def decode(self, data):
|
|
|
+ if self.serializer == 'bson':
|
|
|
+ return data
|
|
|
+ return super(MongoBackend, self).decode(data)
|
|
|
|
|
|
def _store_result(self, task_id, result, status,
|
|
|
traceback=None, request=None, **kwargs):
|
|
@@ -213,24 +173,20 @@ class MongoBackend(BaseBackend):
|
|
|
|
|
|
def _get_task_meta_for(self, task_id):
|
|
|
"""Get task metadata for a task by id."""
|
|
|
-
|
|
|
# if collection don't contain it try searching in the
|
|
|
# group_collection it could be a groupresult instead
|
|
|
obj = self.collection.find_one({'_id': task_id}) or \
|
|
|
self.group_collection.find_one({'_id': task_id})
|
|
|
- if not obj:
|
|
|
- return {'status': states.PENDING, 'result': None}
|
|
|
-
|
|
|
- meta = {
|
|
|
- 'task_id': obj['_id'],
|
|
|
- 'status': obj['status'],
|
|
|
- 'result': self.decode(obj['result']),
|
|
|
- 'date_done': obj['date_done'],
|
|
|
- 'traceback': self.decode(obj['traceback']),
|
|
|
- 'children': self.decode(obj['children']),
|
|
|
- }
|
|
|
-
|
|
|
- return meta
|
|
|
+ if obj:
|
|
|
+ return {
|
|
|
+ 'task_id': obj['_id'],
|
|
|
+ 'status': obj['status'],
|
|
|
+ 'result': self.decode(obj['result']),
|
|
|
+ 'date_done': obj['date_done'],
|
|
|
+ 'traceback': self.decode(obj['traceback']),
|
|
|
+ 'children': self.decode(obj['children']),
|
|
|
+ }
|
|
|
+ return {'status': states.PENDING, 'result': None}
|
|
|
|
|
|
def _save_group(self, group_id, result):
|
|
|
"""Save the group result."""
|
|
@@ -247,20 +203,15 @@ class MongoBackend(BaseBackend):
|
|
|
def _restore_group(self, group_id):
|
|
|
"""Get the result for a group by id."""
|
|
|
obj = self.group_collection.find_one({'_id': group_id})
|
|
|
- if not obj:
|
|
|
- return
|
|
|
-
|
|
|
- tasks = self.decode(obj['result'])
|
|
|
-
|
|
|
- tasks = [AsyncResult(task) for task in tasks]
|
|
|
-
|
|
|
- meta = {
|
|
|
- 'task_id': obj['_id'],
|
|
|
- 'result': tasks,
|
|
|
- 'date_done': obj['date_done'],
|
|
|
- }
|
|
|
-
|
|
|
- return meta
|
|
|
+ if obj:
|
|
|
+ tasks = [self.app.AsyncResult(task)
|
|
|
+ for task in self.decode(obj['result'])]
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'task_id': obj['_id'],
|
|
|
+ 'result': tasks,
|
|
|
+ 'date_done': obj['date_done'],
|
|
|
+ }
|
|
|
|
|
|
def _delete_group(self, group_id):
|
|
|
"""Delete a group by id."""
|