|  | @@ -20,21 +20,48 @@ if pymongo:
 | 
	
		
			
				|  |  |          from bson.binary import Binary
 | 
	
		
			
				|  |  |      except ImportError:                     # pragma: no cover
 | 
	
		
			
				|  |  |          from pymongo.binary import Binary   # noqa
 | 
	
		
			
				|  |  | +    from pymongo.errors import InvalidDocument # noqa
 | 
	
		
			
				|  |  |  else:                                       # pragma: no cover
 | 
	
		
			
				|  |  |      Binary = None                           # noqa
 | 
	
		
			
				|  |  | +    InvalidDocument = None                  # noqa
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from kombu.syn import detect_environment
 | 
	
		
			
				|  |  |  from kombu.utils import cached_property
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +from kombu.exceptions import EncodeError
 | 
	
		
			
				|  |  | +from kombu.serialization import register, disable_insecure_serializers
 | 
	
		
			
				|  |  |  from celery import states
 | 
	
		
			
				|  |  |  from celery.exceptions import ImproperlyConfigured
 | 
	
		
			
				|  |  |  from celery.five import string_t
 | 
	
		
			
				|  |  |  from celery.utils.timeutils import maybe_timedelta
 | 
	
		
			
				|  |  | +from celery.result import AsyncResult
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from .base import BaseBackend
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  __all__ = ['MongoBackend']
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +BINARY_CODECS = frozenset(['pickle','msgpack'])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#register a fake bson serializer which will return the document as it is
 | 
	
		
			
				|  |  | +class bson_serializer():
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def loads(obj, *args, **kwargs):
 | 
	
		
			
				|  |  | +        if isinstance(obj,string_t):
 | 
	
		
			
				|  |  | +            try:
 | 
	
		
			
				|  |  | +                from anyjson import loads
 | 
	
		
			
				|  |  | +                return loads(obj)
 | 
	
		
			
				|  |  | +            except:
 | 
	
		
			
				|  |  | +                pass
 | 
	
		
			
				|  |  | +        return obj
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    def dumps(obj, *args, **kwargs):
 | 
	
		
			
				|  |  | +        return obj
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +register('bson', bson_serializer.loads, bson_serializer.dumps,
 | 
	
		
			
				|  |  | +    content_type='application/data',
 | 
	
		
			
				|  |  | +    content_encoding='utf-8')
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +disable_insecure_serializers(['json','bson'])
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class Bunch(object):
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -43,6 +70,7 @@ class Bunch(object):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      host = 'localhost'
 | 
	
		
			
				|  |  |      port = 27017
 | 
	
		
			
				|  |  |      user = None
 | 
	
	
		
			
				|  | @@ -64,10 +92,16 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          self.options = {}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          super(MongoBackend, self).__init__(*args, **kwargs)
 | 
	
		
			
				|  |  |          self.expires = kwargs.get('expires') or maybe_timedelta(
 | 
	
		
			
				|  |  |              self.app.conf.CELERY_TASK_RESULT_EXPIRES)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        # little hack to get over standard kombu loads because
 | 
	
		
			
				|  |  | +        # mongo return strings which don't get decoded!
 | 
	
		
			
				|  |  | +        if self.serializer == 'bson':
 | 
	
		
			
				|  |  | +            self.decode = self.decode_bson
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          if not pymongo:
 | 
	
		
			
				|  |  |              raise ImproperlyConfigured(
 | 
	
		
			
				|  |  |                  'You need to install the pymongo library to use the '
 | 
	
	
		
			
				|  | @@ -88,6 +122,9 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |              self.taskmeta_collection = config.pop(
 | 
	
		
			
				|  |  |                  'taskmeta_collection', self.taskmeta_collection,
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  | +            self.groupmeta_collection = config.pop(
 | 
	
		
			
				|  |  | +                'groupmeta_collection', self.taskmeta_collection,
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              self.options = dict(config, **config.pop('options', None) or {})
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -101,6 +138,7 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |              # Specifying backend as an URL
 | 
	
		
			
				|  |  |              self.host = url
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _get_connection(self):
 | 
	
		
			
				|  |  |          """Connect to the MongoDB server."""
 | 
	
		
			
				|  |  |          if self._connection is None:
 | 
	
	
		
			
				|  | @@ -132,25 +170,50 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |              del(self.database)
 | 
	
		
			
				|  |  |              self._connection = None
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def encode(self, data):
 | 
	
		
			
				|  |  | +        payload = super(MongoBackend, self).encode(data)
 | 
	
		
			
				|  |  | +        #serializer which are in a unsupported format (pickle/binary)
 | 
	
		
			
				|  |  | +        if self.serializer in BINARY_CODECS:
 | 
	
		
			
				|  |  | +            payload = Binary(payload)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return payload
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def decode_bson(self, data):
 | 
	
		
			
				|  |  | +        return bson_serializer.loads(data)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def encode_result(self, result, status):
 | 
	
		
			
				|  |  | +        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
 | 
	
		
			
				|  |  | +            return self.prepare_exception(result)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            return self.prepare_value(result)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def _store_result(self, task_id, result, status,
 | 
	
		
			
				|  |  |                        traceback=None, request=None, **kwargs):
 | 
	
		
			
				|  |  |          """Store return value and status of an executed task."""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          meta = {'_id': task_id,
 | 
	
		
			
				|  |  |                  'status': status,
 | 
	
		
			
				|  |  | -                'result': Binary(self.encode(result)),
 | 
	
		
			
				|  |  | +                'result': self.encode(result),
 | 
	
		
			
				|  |  |                  'date_done': datetime.utcnow(),
 | 
	
		
			
				|  |  | -                'traceback': Binary(self.encode(traceback)),
 | 
	
		
			
				|  |  | -                'children': Binary(self.encode(
 | 
	
		
			
				|  |  | +                'traceback': self.encode(traceback),
 | 
	
		
			
				|  |  | +                'children': self.encode(
 | 
	
		
			
				|  |  |                      self.current_task_children(request),
 | 
	
		
			
				|  |  | -                ))}
 | 
	
		
			
				|  |  | -        self.collection.save(meta)
 | 
	
		
			
				|  |  | +                )}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            self.collection.save(meta)
 | 
	
		
			
				|  |  | +        except InvalidDocument as exc:
 | 
	
		
			
				|  |  | +            raise EncodeError(exc)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return result
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _get_task_meta_for(self, task_id):
 | 
	
		
			
				|  |  |          """Get task metadata for a task by id."""
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        obj = self.collection.find_one({'_id': task_id})
 | 
	
		
			
				|  |  | +        # if collection don't contain it try searching in the
 | 
	
		
			
				|  |  | +        # group_collection it could be a groupresult instead
 | 
	
		
			
				|  |  | +        obj = self.collection.find_one({'_id': task_id}) or \
 | 
	
		
			
				|  |  | +              self.group_collection.find_one({'_id': task_id})
 | 
	
		
			
				|  |  |          if not obj:
 | 
	
		
			
				|  |  |              return {'status': states.PENDING, 'result': None}
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -167,22 +230,29 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _save_group(self, group_id, result):
 | 
	
		
			
				|  |  |          """Save the group result."""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        task_ids = [ i.id for i in result ]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          meta = {'_id': group_id,
 | 
	
		
			
				|  |  | -                'result': Binary(self.encode(result)),
 | 
	
		
			
				|  |  | +                'result': self.encode(task_ids),
 | 
	
		
			
				|  |  |                  'date_done': datetime.utcnow()}
 | 
	
		
			
				|  |  | -        self.collection.save(meta)
 | 
	
		
			
				|  |  | +        self.group_collection.save(meta)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          return result
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _restore_group(self, group_id):
 | 
	
		
			
				|  |  |          """Get the result for a group by id."""
 | 
	
		
			
				|  |  | -        obj = self.collection.find_one({'_id': group_id})
 | 
	
		
			
				|  |  | +        obj = self.group_collection.find_one({'_id': group_id})
 | 
	
		
			
				|  |  |          if not obj:
 | 
	
		
			
				|  |  |              return
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        tasks = self.decode(obj['result'])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        tasks = [ AsyncResult(task) for task in tasks ]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          meta = {
 | 
	
		
			
				|  |  |              'task_id': obj['_id'],
 | 
	
		
			
				|  |  | -            'result': self.decode(obj['result']),
 | 
	
		
			
				|  |  | +            'result': tasks,
 | 
	
		
			
				|  |  |              'date_done': obj['date_done'],
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -190,7 +260,7 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _delete_group(self, group_id):
 | 
	
		
			
				|  |  |          """Delete a group by id."""
 | 
	
		
			
				|  |  | -        self.collection.remove({'_id': group_id})
 | 
	
		
			
				|  |  | +        self.group_collection.remove({'_id': group_id})
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _forget(self, task_id):
 | 
	
		
			
				|  |  |          """
 | 
	
	
		
			
				|  | @@ -209,6 +279,9 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |          self.collection.remove(
 | 
	
		
			
				|  |  |              {'date_done': {'$lt': self.app.now() - self.expires}},
 | 
	
		
			
				|  |  |          )
 | 
	
		
			
				|  |  | +        self.group_collection.remove(
 | 
	
		
			
				|  |  | +            {'date_done': {'$lt': self.app.now() - self.expires}},
 | 
	
		
			
				|  |  | +        )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __reduce__(self, args=(), kwargs={}):
 | 
	
		
			
				|  |  |          kwargs.update(
 | 
	
	
		
			
				|  | @@ -240,3 +313,13 @@ class MongoBackend(BaseBackend):
 | 
	
		
			
				|  |  |          # in the background. Once completed cleanup will be much faster
 | 
	
		
			
				|  |  |          collection.ensure_index('date_done', background='true')
 | 
	
		
			
				|  |  |          return collection
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @cached_property
 | 
	
		
			
				|  |  | +    def group_collection(self):
 | 
	
		
			
				|  |  | +        """Get the metadata task collection."""
 | 
	
		
			
				|  |  | +        collection = self.database[self.groupmeta_collection]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Ensure an index on date_done is there, if not process the index
 | 
	
		
			
				|  |  | +        # in the background. Once completed cleanup will be much faster
 | 
	
		
			
				|  |  | +        collection.ensure_index('date_done', background='true')
 | 
	
		
			
				|  |  | +        return collection
 |