mongodb.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -*- coding: utf-8 -*-
  2. """MongoDB result store backend."""
  3. from datetime import datetime, timedelta
  4. from kombu.utils.objects import cached_property
  5. from kombu.utils.url import maybe_sanitize_url
  6. from kombu.exceptions import EncodeError
  7. from celery import states
  8. from celery.exceptions import ImproperlyConfigured
  9. from .base import BaseBackend
  10. try:
  11. import pymongo
  12. except ImportError: # pragma: no cover
  13. pymongo = None # noqa
  14. if pymongo:
  15. try:
  16. from bson.binary import Binary
  17. except ImportError: # pragma: no cover
  18. from pymongo.binary import Binary # noqa
  19. from pymongo.errors import InvalidDocument # noqa
  20. else: # pragma: no cover
  21. Binary = None # noqa
  22. class InvalidDocument(Exception): # noqa
  23. pass
  24. __all__ = ['MongoBackend']
  25. class MongoBackend(BaseBackend):
  26. """MongoDB result backend.
  27. Raises:
  28. celery.exceptions.ImproperlyConfigured:
  29. if module :pypi:`pymongo` is not available.
  30. """
  31. mongo_host = None
  32. host = 'localhost'
  33. port = 27017
  34. user = None
  35. password = None
  36. database_name = 'celery'
  37. taskmeta_collection = 'celery_taskmeta'
  38. groupmeta_collection = 'celery_groupmeta'
  39. max_pool_size = 10
  40. options = None
  41. supports_autoexpire = False
  42. _connection = None
  43. def __init__(self, app=None, **kwargs):
  44. self.options = {}
  45. super().__init__(app, **kwargs)
  46. if not pymongo:
  47. raise ImproperlyConfigured(
  48. 'You need to install the pymongo library to use the '
  49. 'MongoDB backend.')
  50. # Set option defaults
  51. for key, value in self._prepare_client_options().items():
  52. self.options.setdefault(key, value)
  53. # update conf with mongo uri data, only if uri was given
  54. if self.url:
  55. if self.url == 'mongodb://':
  56. self.url += 'localhost'
  57. uri_data = pymongo.uri_parser.parse_uri(self.url)
  58. # build the hosts list to create a mongo connection
  59. hostslist = [
  60. '{0}:{1}'.format(x[0], x[1]) for x in uri_data['nodelist']
  61. ]
  62. self.user = uri_data['username']
  63. self.password = uri_data['password']
  64. self.mongo_host = hostslist
  65. if uri_data['database']:
  66. # if no database is provided in the uri, use default
  67. self.database_name = uri_data['database']
  68. self.options.update(uri_data['options'])
  69. # update conf with specific settings
  70. config = self.app.conf.get('mongodb_backend_settings')
  71. if config is not None:
  72. if not isinstance(config, dict):
  73. raise ImproperlyConfigured(
  74. 'MongoDB backend settings should be grouped in a dict')
  75. config = dict(config) # don't modify original
  76. if 'host' in config or 'port' in config:
  77. # these should take over uri conf
  78. self.mongo_host = None
  79. self.host = config.pop('host', self.host)
  80. self.port = config.pop('port', self.port)
  81. self.mongo_host = config.pop('mongo_host', self.mongo_host)
  82. self.user = config.pop('user', self.user)
  83. self.password = config.pop('password', self.password)
  84. self.database_name = config.pop('database', self.database_name)
  85. self.taskmeta_collection = config.pop(
  86. 'taskmeta_collection', self.taskmeta_collection,
  87. )
  88. self.groupmeta_collection = config.pop(
  89. 'groupmeta_collection', self.groupmeta_collection,
  90. )
  91. self.options.update(config.pop('options', {}))
  92. self.options.update(config)
  93. def _prepare_client_options(self):
  94. if pymongo.version_tuple >= (3,):
  95. return {'maxPoolSize': self.max_pool_size}
  96. else: # pragma: no cover
  97. return {'max_pool_size': self.max_pool_size,
  98. 'auto_start_request': False}
  99. def _get_connection(self):
  100. """Connect to the MongoDB server."""
  101. if self._connection is None:
  102. from pymongo import MongoClient
  103. host = self.mongo_host
  104. if not host:
  105. # The first pymongo.Connection() argument (host) can be
  106. # a list of ['host:port'] elements or a mongodb connection
  107. # URI. If this is the case, don't use self.port
  108. # but let pymongo get the port(s) from the URI instead.
  109. # This enables the use of replica sets and sharding.
  110. # See pymongo.Connection() for more info.
  111. host = self.host
  112. if (isinstance(host, str) and
  113. not host.startswith('mongodb://')):
  114. host = 'mongodb://{0}:{1}'.format(host, self.port)
  115. # don't change self.options
  116. conf = dict(self.options)
  117. conf['host'] = host
  118. self._connection = MongoClient(**conf)
  119. return self._connection
  120. def encode(self, data):
  121. if self.serializer == 'bson':
  122. # mongodb handles serialization
  123. return data
  124. return super().encode(data)
  125. def decode(self, data):
  126. if self.serializer == 'bson':
  127. return data
  128. return super().decode(data)
  129. def _store_result(self, task_id, result, state,
  130. traceback=None, request=None, **kwargs):
  131. """Store return value and state of an executed task."""
  132. meta = {
  133. '_id': task_id,
  134. 'status': state,
  135. 'result': self.encode(result),
  136. 'date_done': datetime.utcnow(),
  137. 'traceback': self.encode(traceback),
  138. 'children': self.encode(
  139. self.current_task_children(request),
  140. ),
  141. }
  142. try:
  143. self.collection.save(meta)
  144. except InvalidDocument as exc:
  145. raise EncodeError(exc)
  146. return result
  147. def _get_task_meta_for(self, task_id):
  148. """Get task meta-data for a task by id."""
  149. obj = self.collection.find_one({'_id': task_id})
  150. if obj:
  151. return self.meta_from_decoded({
  152. 'task_id': obj['_id'],
  153. 'status': obj['status'],
  154. 'result': self.decode(obj['result']),
  155. 'date_done': obj['date_done'],
  156. 'traceback': self.decode(obj['traceback']),
  157. 'children': self.decode(obj['children']),
  158. })
  159. return {'status': states.PENDING, 'result': None}
  160. def _save_group(self, group_id, result):
  161. """Save the group result."""
  162. self.group_collection.save({
  163. '_id': group_id,
  164. 'result': self.encode([i.id for i in result]),
  165. 'date_done': datetime.utcnow(),
  166. })
  167. return result
  168. def _restore_group(self, group_id):
  169. """Get the result for a group by id."""
  170. obj = self.group_collection.find_one({'_id': group_id})
  171. if obj:
  172. return {
  173. 'task_id': obj['_id'],
  174. 'date_done': obj['date_done'],
  175. 'result': [
  176. self.app.AsyncResult(task)
  177. for task in self.decode(obj['result'])
  178. ],
  179. }
  180. def _delete_group(self, group_id):
  181. """Delete a group by id."""
  182. self.group_collection.remove({'_id': group_id})
  183. def _forget(self, task_id):
  184. """Remove result from MongoDB.
  185. Raises:
  186. pymongo.exceptions.OperationsError:
  187. if the task_id could not be removed.
  188. """
  189. # By using safe=True, this will wait until it receives a response from
  190. # the server. Likewise, it will raise an OperationsError if the
  191. # response was unable to be completed.
  192. self.collection.remove({'_id': task_id})
  193. def cleanup(self):
  194. """Delete expired meta-data."""
  195. self.collection.remove(
  196. {'date_done': {'$lt': self.app.now() - self.expires_delta}},
  197. )
  198. self.group_collection.remove(
  199. {'date_done': {'$lt': self.app.now() - self.expires_delta}},
  200. )
  201. def __reduce__(self, args=(), kwargs={}):
  202. return super().__reduce__(
  203. args, dict(kwargs, expires=self.expires, url=self.url))
  204. def _get_database(self):
  205. conn = self._get_connection()
  206. db = conn[self.database_name]
  207. if self.user and self.password:
  208. if not db.authenticate(self.user, self.password):
  209. raise ImproperlyConfigured(
  210. 'Invalid MongoDB username or password.')
  211. return db
  212. @cached_property
  213. def database(self):
  214. """Get database from MongoDB connection.
  215. performs authentication if necessary.
  216. """
  217. return self._get_database()
  218. @cached_property
  219. def collection(self):
  220. """Get the meta-data task collection."""
  221. collection = self.database[self.taskmeta_collection]
  222. # Ensure an index on date_done is there, if not process the index
  223. # in the background. Once completed cleanup will be much faster
  224. collection.ensure_index('date_done', background='true')
  225. return collection
  226. @cached_property
  227. def group_collection(self):
  228. """Get the meta-data task collection."""
  229. collection = self.database[self.groupmeta_collection]
  230. # Ensure an index on date_done is there, if not process the index
  231. # in the background. Once completed cleanup will be much faster
  232. collection.ensure_index('date_done', background='true')
  233. return collection
  234. @cached_property
  235. def expires_delta(self):
  236. return timedelta(seconds=self.expires)
  237. def as_uri(self, include_password=False):
  238. """Return the backend as an URI.
  239. Arguments:
  240. include_password (bool): Password censored if disabled.
  241. """
  242. if not self.url:
  243. return 'mongodb://'
  244. if include_password:
  245. return self.url
  246. if ',' not in self.url:
  247. return maybe_sanitize_url(self.url)
  248. uri1, remainder = self.url.split(',', 1)
  249. return ','.join([maybe_sanitize_url(uri1), remainder])