|
@@ -20,21 +20,48 @@ if pymongo:
|
|
|
from bson.binary import Binary
|
|
|
except ImportError: # pragma: no cover
|
|
|
from pymongo.binary import Binary # noqa
|
|
|
+ from pymongo.errors import InvalidDocument # noqa
|
|
|
else: # pragma: no cover
|
|
|
Binary = None # noqa
|
|
|
+ InvalidDocument = None # noqa
|
|
|
|
|
|
from kombu.syn import detect_environment
|
|
|
from kombu.utils import cached_property
|
|
|
-
|
|
|
+from kombu.exceptions import EncodeError
|
|
|
+from kombu.serialization import register, disable_insecure_serializers
|
|
|
from celery import states
|
|
|
from celery.exceptions import ImproperlyConfigured
|
|
|
from celery.five import string_t
|
|
|
from celery.utils.timeutils import maybe_timedelta
|
|
|
+from celery.result import AsyncResult
|
|
|
|
|
|
from .base import BaseBackend
|
|
|
|
|
|
__all__ = ['MongoBackend']
|
|
|
|
|
|
# Serializer names whose encoded payload is wrapped in ``Binary`` by
# ``MongoBackend.encode`` before it is stored.
BINARY_CODECS = frozenset(['pickle','msgpack'])
|
|
|
+
|
|
|
+# Register a pass-through bson serializer that returns the document unchanged.
|
|
|
class bson_serializer(object):
    """Pass-through serializer used for the ``bson`` content type.

    Outgoing documents are returned untouched so the storage layer can
    persist them natively; incoming string payloads are opportunistically
    decoded as JSON.
    """

    @staticmethod
    def loads(obj, *args, **kwargs):
        """Return *obj*, JSON-decoding it first when it is a string.

        Extra positional and keyword arguments are accepted and ignored.
        If ``anyjson`` cannot be imported, or the string cannot be decoded,
        the original string is returned unchanged.
        """
        if not isinstance(obj, string_t):
            return obj
        try:
            # Imported lazily so a missing anyjson only disables decoding.
            from anyjson import loads
            return loads(obj)
        except Exception:
            # Deliberate best effort: fall back to the raw string.  Catching
            # ``Exception`` instead of using a bare ``except`` lets
            # KeyboardInterrupt and SystemExit propagate.
            return obj

    @staticmethod
    def dumps(obj, *args, **kwargs):
        """Return *obj* unchanged; extra arguments are ignored."""
        return obj
|
|
|
+
|
|
|
# Make the pass-through serializer selectable under the name 'bson'.
register('bson', bson_serializer.loads, bson_serializer.dumps,
         content_type='application/data',
         content_encoding='utf-8')

# NOTE(review): this call runs at import time; it presumably restricts kombu
# to the listed serializers for the whole process -- confirm this is intended.
disable_insecure_serializers(['json','bson'])
|
|
|
|
|
|
class Bunch(object):
|
|
|
|
|
@@ -43,6 +70,7 @@ class Bunch(object):
|
|
|
|
|
|
|
|
|
class MongoBackend(BaseBackend):
|
|
|
+
|
|
|
host = 'localhost'
|
|
|
port = 27017
|
|
|
user = None
|
|
@@ -64,10 +92,16 @@ class MongoBackend(BaseBackend):
|
|
|
|
|
|
"""
|
|
|
self.options = {}
|
|
|
+
|
|
|
super(MongoBackend, self).__init__(*args, **kwargs)
|
|
|
self.expires = kwargs.get('expires') or maybe_timedelta(
|
|
|
self.app.conf.CELERY_TASK_RESULT_EXPIRES)
|
|
|
|
|
|
+ # Small hack to bypass the standard kombu loads: MongoDB returns
|
|
|
+ # strings, which would otherwise never get decoded.
|
|
|
+ if self.serializer == 'bson':
|
|
|
+ self.decode = self.decode_bson
|
|
|
+
|
|
|
if not pymongo:
|
|
|
raise ImproperlyConfigured(
|
|
|
'You need to install the pymongo library to use the '
|
|
@@ -88,6 +122,9 @@ class MongoBackend(BaseBackend):
|
|
|
self.taskmeta_collection = config.pop(
|
|
|
'taskmeta_collection', self.taskmeta_collection,
|
|
|
)
|
|
|
+ self.groupmeta_collection = config.pop(
|
|
|
+ 'groupmeta_collection', self.taskmeta_collection,
|
|
|
+ )
|
|
|
|
|
|
self.options = dict(config, **config.pop('options', None) or {})
|
|
|
|
|
@@ -101,6 +138,7 @@ class MongoBackend(BaseBackend):
|
|
|
# Specifying backend as an URL
|
|
|
self.host = url
|
|
|
|
|
|
+
|
|
|
def _get_connection(self):
|
|
|
"""Connect to the MongoDB server."""
|
|
|
if self._connection is None:
|
|
@@ -132,25 +170,50 @@ class MongoBackend(BaseBackend):
|
|
|
del(self.database)
|
|
|
self._connection = None
|
|
|
|
|
|
def encode(self, data):
    """Encode *data* with the configured serializer.

    When the serializer is listed in ``BINARY_CODECS`` the encoded payload
    is wrapped in ``Binary`` before being returned; otherwise it is
    returned as produced by the parent class.
    """
    encoded = super(MongoBackend, self).encode(data)
    if self.serializer not in BINARY_CODECS:
        return encoded
    # Payloads from these serializers are stored wrapped in Binary.
    return Binary(encoded)
|
|
|
+
|
|
|
def decode_bson(self, data):
    """Decode *data* by delegating to ``bson_serializer.loads``."""
    decoded = bson_serializer.loads(data)
    return decoded
|
|
|
+
|
|
|
def encode_result(self, result, status):
    """Prepare *result* for storage according to *status*.

    An ``Exception`` instance reported with one of ``EXCEPTION_STATES`` goes
    through ``prepare_exception``; anything else goes through
    ``prepare_value``.
    """
    is_error_state = status in self.EXCEPTION_STATES
    if is_error_state and isinstance(result, Exception):
        return self.prepare_exception(result)
    return self.prepare_value(result)
|
|
|
+
|
|
|
def _store_result(self, task_id, result, status,
                  traceback=None, request=None, **kwargs):
    """Store return value and status of an executed task.

    The result, traceback and children are each passed through
    ``self.encode`` and saved in the task collection under ``task_id``.
    Returns *result* unchanged.

    Raises ``EncodeError`` when saving the document raises
    ``InvalidDocument``.
    """
    # Fields are filled in one by one, in the same order as the stored
    # document lists them.
    document = {'_id': task_id, 'status': status}
    document['result'] = self.encode(result)
    document['date_done'] = datetime.utcnow()
    document['traceback'] = self.encode(traceback)
    document['children'] = self.encode(
        self.current_task_children(request),
    )

    try:
        self.collection.save(document)
    except InvalidDocument as exc:
        raise EncodeError(exc)

    return result
|
|
|
|
|
|
def _get_task_meta_for(self, task_id):
|
|
|
"""Get task metadata for a task by id."""
|
|
|
|
|
|
- obj = self.collection.find_one({'_id': task_id})
|
|
|
+ # If the task collection doesn't contain it, try the group
|
|
|
+ # collection: the id could belong to a group result instead.
|
|
|
+ obj = self.collection.find_one({'_id': task_id}) or \
|
|
|
+ self.group_collection.find_one({'_id': task_id})
|
|
|
if not obj:
|
|
|
return {'status': states.PENDING, 'result': None}
|
|
|
|
|
@@ -167,22 +230,29 @@ class MongoBackend(BaseBackend):
|
|
|
|
|
|
def _save_group(self, group_id, result):
    """Save the group result.

    Only the ``id`` of each member of *result* is encoded and stored, in
    the group collection, under ``group_id``.  Returns *result* unchanged.
    """
    member_ids = [member.id for member in result]

    document = {'_id': group_id}
    document['result'] = self.encode(member_ids)
    document['date_done'] = datetime.utcnow()
    self.group_collection.save(document)

    return result
|
|
|
|
|
|
def _restore_group(self, group_id):
|
|
|
"""Get the result for a group by id."""
|
|
|
- obj = self.collection.find_one({'_id': group_id})
|
|
|
+ obj = self.group_collection.find_one({'_id': group_id})
|
|
|
if not obj:
|
|
|
return
|
|
|
|
|
|
+ tasks = self.decode(obj['result'])
|
|
|
+
|
|
|
+ tasks = [ AsyncResult(task) for task in tasks ]
|
|
|
+
|
|
|
meta = {
|
|
|
'task_id': obj['_id'],
|
|
|
- 'result': self.decode(obj['result']),
|
|
|
+ 'result': tasks,
|
|
|
'date_done': obj['date_done'],
|
|
|
}
|
|
|
|
|
@@ -190,7 +260,7 @@ class MongoBackend(BaseBackend):
|
|
|
|
|
|
def _delete_group(self, group_id):
    """Delete a group by id."""
    selector = {'_id': group_id}
    self.group_collection.remove(selector)
|
|
|
|
|
|
def _forget(self, task_id):
|
|
|
"""
|
|
@@ -209,6 +279,9 @@ class MongoBackend(BaseBackend):
|
|
|
self.collection.remove(
|
|
|
{'date_done': {'$lt': self.app.now() - self.expires}},
|
|
|
)
|
|
|
+ self.group_collection.remove(
|
|
|
+ {'date_done': {'$lt': self.app.now() - self.expires}},
|
|
|
+ )
|
|
|
|
|
|
def __reduce__(self, args=(), kwargs={}):
|
|
|
kwargs.update(
|
|
@@ -240,3 +313,13 @@ class MongoBackend(BaseBackend):
|
|
|
# in the background. Once completed cleanup will be much faster
|
|
|
collection.ensure_index('date_done', background='true')
|
|
|
return collection
|
|
|
+
|
|
|
@cached_property
def group_collection(self):
    """Get the collection that holds group metadata."""
    groups = self.database[self.groupmeta_collection]

    # Make sure date_done is indexed; if the index is missing it is built
    # in the background.  Once completed cleanup will be much faster.
    groups.ensure_index('date_done', background='true')
    return groups
|