瀏覽代碼

Merge branch 'ocean1/mongodb_fix'

Ask Solem 10 年之前
父節點
當前提交
a9dce32690
共有 1 個文件被更改,包括 72 次插入34 次删除
  1. 72 34
      celery/backends/mongodb.py

+ 72 - 34
celery/backends/mongodb.py

@@ -20,12 +20,14 @@ if pymongo:
         from bson.binary import Binary
     except ImportError:                     # pragma: no cover
         from pymongo.binary import Binary   # noqa
+    from pymongo.errors import InvalidDocument  # noqa
 else:                                       # pragma: no cover
     Binary = None                           # noqa
+    InvalidDocument = None                  # noqa
 
 from kombu.syn import detect_environment
 from kombu.utils import cached_property
-
+from kombu.exceptions import EncodeError
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.five import string_t
@@ -43,12 +45,14 @@ class Bunch(object):
 
 
 class MongoBackend(BaseBackend):
+
     host = 'localhost'
     port = 27017
     user = None
     password = None
     database_name = 'celery'
     taskmeta_collection = 'celery_taskmeta'
+    groupmeta_collection = 'celery_groupmeta'
     max_pool_size = 10
     options = None
 
@@ -64,6 +68,7 @@ class MongoBackend(BaseBackend):
 
         """
         self.options = {}
+
         super(MongoBackend, self).__init__(*args, **kwargs)
         self.expires = kwargs.get('expires') or maybe_timedelta(
             self.app.conf.CELERY_TASK_RESULT_EXPIRES)
@@ -88,6 +93,9 @@ class MongoBackend(BaseBackend):
             self.taskmeta_collection = config.pop(
                 'taskmeta_collection', self.taskmeta_collection,
             )
+            self.groupmeta_collection = config.pop(
+                'groupmeta_collection', self.groupmeta_collection,
+            )
 
             self.options = dict(config, **config.pop('options', None) or {})
 
@@ -131,65 +139,82 @@ class MongoBackend(BaseBackend):
             del(self.database)
             self._connection = None
 
+    def encode(self, data):
+        if self.serializer == 'bson':
+            # mongodb handles serialization
+            return data
+        return super(MongoBackend, self).encode(data)
+
+    def decode(self, data):
+        if self.serializer == 'bson':
+            return data
+        return super(MongoBackend, self).decode(data)
+
     def _store_result(self, task_id, result, status,
                       traceback=None, request=None, **kwargs):
         """Store return value and status of an executed task."""
+
         meta = {'_id': task_id,
                 'status': status,
-                'result': Binary(self.encode(result)),
+                'result': self.encode(result),
                 'date_done': datetime.utcnow(),
-                'traceback': Binary(self.encode(traceback)),
-                'children': Binary(self.encode(
+                'traceback': self.encode(traceback),
+                'children': self.encode(
                     self.current_task_children(request),
-                ))}
-        self.collection.save(meta)
+                )}
+
+        try:
+            self.collection.save(meta)
+        except InvalidDocument as exc:
+            raise EncodeError(exc)
 
         return result
 
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
-
-        obj = self.collection.find_one({'_id': task_id})
-        if not obj:
-            return {'status': states.PENDING, 'result': None}
-
-        meta = {
-            'task_id': obj['_id'],
-            'status': obj['status'],
-            'result': self.decode(obj['result']),
-            'date_done': obj['date_done'],
-            'traceback': self.decode(obj['traceback']),
-            'children': self.decode(obj['children']),
-        }
-
-        return meta
+        # if collection don't contain it try searching in the
+        # group_collection it could be a groupresult instead
+        obj = self.collection.find_one({'_id': task_id}) or \
+            self.group_collection.find_one({'_id': task_id})
+        if obj:
+            return {
+                'task_id': obj['_id'],
+                'status': obj['status'],
+                'result': self.decode(obj['result']),
+                'date_done': obj['date_done'],
+                'traceback': self.decode(obj['traceback']),
+                'children': self.decode(obj['children']),
+            }
+        return {'status': states.PENDING, 'result': None}
 
     def _save_group(self, group_id, result):
         """Save the group result."""
+
+        task_ids = [i.id for i in result]
+
         meta = {'_id': group_id,
-                'result': Binary(self.encode(result)),
+                'result': self.encode(task_ids),
                 'date_done': datetime.utcnow()}
-        self.collection.save(meta)
+        self.group_collection.save(meta)
 
         return result
 
     def _restore_group(self, group_id):
         """Get the result for a group by id."""
-        obj = self.collection.find_one({'_id': group_id})
-        if not obj:
-            return
-
-        meta = {
-            'task_id': obj['_id'],
-            'result': self.decode(obj['result']),
-            'date_done': obj['date_done'],
-        }
+        obj = self.group_collection.find_one({'_id': group_id})
+        if obj:
+            tasks = [self.app.AsyncResult(task)
+                     for task in self.decode(obj['result'])]
 
-        return meta
+            return {
+                'task_id': obj['_id'],
+                'result': tasks,
+                'date_done': obj['date_done'],
+            }
 
     def _delete_group(self, group_id):
         """Delete a group by id."""
-        self.collection.remove({'_id': group_id})
+        self.group_collection.remove({'_id': group_id})
 
     def _forget(self, task_id):
         """
@@ -208,6 +233,9 @@ class MongoBackend(BaseBackend):
         self.collection.remove(
             {'date_done': {'$lt': self.app.now() - self.expires}},
         )
+        self.group_collection.remove(
+            {'date_done': {'$lt': self.app.now() - self.expires}},
+        )
 
     def __reduce__(self, args=(), kwargs={}):
         kwargs.update(
@@ -239,3 +267,13 @@ class MongoBackend(BaseBackend):
         # in the background. Once completed cleanup will be much faster
         collection.ensure_index('date_done', background='true')
         return collection
+
+    @cached_property
+    def group_collection(self):
+        """Get the metadata task collection."""
+        collection = self.database[self.groupmeta_collection]
+
+        # Ensure an index on date_done is there, if not process the index
+        # in the background. Once completed cleanup will be much faster
+        collection.ensure_index('date_done', background='true')
+        return collection