Quellcode durchsuchen

Add a groupmeta_collection option to save group results in a separate collection from task results, and add a fake 'bson' kombu encoder that lets pymongo serialize data natively in MongoDB

ocean1 vor 11 Jahren
Ursprung
Commit
4f5e8b80b7
1 geänderte Dateien mit 95 neuen und 12 gelöschten Zeilen
  1. 95 12
      celery/backends/mongodb.py

+ 95 - 12
celery/backends/mongodb.py

@@ -20,21 +20,48 @@ if pymongo:
         from bson.binary import Binary
     except ImportError:                     # pragma: no cover
         from pymongo.binary import Binary   # noqa
+    from pymongo.errors import InvalidDocument # noqa
 else:                                       # pragma: no cover
     Binary = None                           # noqa
+    InvalidDocument = None                  # noqa
 
 from kombu.syn import detect_environment
 from kombu.utils import cached_property
-
+from kombu.exceptions import EncodeError
+from kombu.serialization import register, disable_insecure_serializers
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.five import string_t
 from celery.utils.timeutils import maybe_timedelta
+from celery.result import AsyncResult
 
 from .base import BaseBackend
 
 __all__ = ['MongoBackend']
 
+BINARY_CODECS = frozenset(['pickle','msgpack'])
+
+# Register a fake 'bson' serializer that returns the document as-is.
+class bson_serializer():
+    @staticmethod
+    def loads(obj, *args, **kwargs):
+        if isinstance(obj,string_t):
+            try:
+                from anyjson import loads
+                return loads(obj)
+            except:
+                pass
+        return obj
+
+    @staticmethod
+    def dumps(obj, *args, **kwargs):
+        return obj
+
+register('bson', bson_serializer.loads, bson_serializer.dumps,
+    content_type='application/data',
+    content_encoding='utf-8')
+
+disable_insecure_serializers(['json','bson'])
 
 class Bunch(object):
 
@@ -43,6 +70,7 @@ class Bunch(object):
 
 
 class MongoBackend(BaseBackend):
+
     host = 'localhost'
     port = 27017
     user = None
@@ -64,10 +92,16 @@ class MongoBackend(BaseBackend):
 
         """
         self.options = {}
+
         super(MongoBackend, self).__init__(*args, **kwargs)
         self.expires = kwargs.get('expires') or maybe_timedelta(
             self.app.conf.CELERY_TASK_RESULT_EXPIRES)
 
+        # Small hack to bypass the standard kombu loads, because
+        # MongoDB returns strings, which would otherwise not get decoded.
+        if self.serializer == 'bson':
+            self.decode = self.decode_bson
+
         if not pymongo:
             raise ImproperlyConfigured(
                 'You need to install the pymongo library to use the '
@@ -88,6 +122,9 @@ class MongoBackend(BaseBackend):
             self.taskmeta_collection = config.pop(
                 'taskmeta_collection', self.taskmeta_collection,
             )
+            self.groupmeta_collection = config.pop(
+                'groupmeta_collection', self.taskmeta_collection,
+            )
 
             self.options = dict(config, **config.pop('options', None) or {})
 
@@ -101,6 +138,7 @@ class MongoBackend(BaseBackend):
             # Specifying backend as an URL
             self.host = url
 
+
     def _get_connection(self):
         """Connect to the MongoDB server."""
         if self._connection is None:
@@ -132,25 +170,50 @@ class MongoBackend(BaseBackend):
             del(self.database)
             self._connection = None
 
+    def encode(self, data):
+        payload = super(MongoBackend, self).encode(data)
+        # Serializers that produce an unsupported binary format (e.g. pickle) get wrapped in Binary.
+        if self.serializer in BINARY_CODECS:
+            payload = Binary(payload)
+
+        return payload
+
+    def decode_bson(self, data):
+        return bson_serializer.loads(data)
+
+    def encode_result(self, result, status):
+        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
+            return self.prepare_exception(result)
+        else:
+            return self.prepare_value(result)
+
     def _store_result(self, task_id, result, status,
                       traceback=None, request=None, **kwargs):
         """Store return value and status of an executed task."""
+
         meta = {'_id': task_id,
                 'status': status,
-                'result': Binary(self.encode(result)),
+                'result': self.encode(result),
                 'date_done': datetime.utcnow(),
-                'traceback': Binary(self.encode(traceback)),
-                'children': Binary(self.encode(
+                'traceback': self.encode(traceback),
+                'children': self.encode(
                     self.current_task_children(request),
-                ))}
-        self.collection.save(meta)
+                )}
+
+        try:
+            self.collection.save(meta)
+        except InvalidDocument as exc:
+            raise EncodeError(exc)
 
         return result
 
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
 
-        obj = self.collection.find_one({'_id': task_id})
+        # If the task collection doesn't contain it, try searching the
+        # group_collection: the id could belong to a group result instead.
+        obj = self.collection.find_one({'_id': task_id}) or \
+              self.group_collection.find_one({'_id': task_id})
         if not obj:
             return {'status': states.PENDING, 'result': None}
 
@@ -167,22 +230,29 @@ class MongoBackend(BaseBackend):
 
     def _save_group(self, group_id, result):
         """Save the group result."""
+
+        task_ids = [ i.id for i in result ]
+
         meta = {'_id': group_id,
-                'result': Binary(self.encode(result)),
+                'result': self.encode(task_ids),
                 'date_done': datetime.utcnow()}
-        self.collection.save(meta)
+        self.group_collection.save(meta)
 
         return result
 
     def _restore_group(self, group_id):
         """Get the result for a group by id."""
-        obj = self.collection.find_one({'_id': group_id})
+        obj = self.group_collection.find_one({'_id': group_id})
         if not obj:
             return
 
+        tasks = self.decode(obj['result'])
+
+        tasks = [ AsyncResult(task) for task in tasks ]
+
         meta = {
             'task_id': obj['_id'],
-            'result': self.decode(obj['result']),
+            'result': tasks,
             'date_done': obj['date_done'],
         }
 
@@ -190,7 +260,7 @@ class MongoBackend(BaseBackend):
 
     def _delete_group(self, group_id):
         """Delete a group by id."""
-        self.collection.remove({'_id': group_id})
+        self.group_collection.remove({'_id': group_id})
 
     def _forget(self, task_id):
         """
@@ -209,6 +279,9 @@ class MongoBackend(BaseBackend):
         self.collection.remove(
             {'date_done': {'$lt': self.app.now() - self.expires}},
         )
+        self.group_collection.remove(
+            {'date_done': {'$lt': self.app.now() - self.expires}},
+        )
 
     def __reduce__(self, args=(), kwargs={}):
         kwargs.update(
@@ -240,3 +313,13 @@ class MongoBackend(BaseBackend):
         # in the background. Once completed cleanup will be much faster
         collection.ensure_index('date_done', background='true')
         return collection
+
+    @cached_property
+    def group_collection(self):
+        """Get the group metadata collection."""
+        collection = self.database[self.groupmeta_collection]
+
+        # Ensure an index on date_done is there, if not process the index
+        # in the background. Once completed cleanup will be much faster
+        collection.ensure_index('date_done', background='true')
+        return collection