mongodb.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # -*- coding: utf-8 -*-
  2. """MongoDB backend for celery."""
  3. from __future__ import absolute_import
  4. from datetime import datetime
  5. try:
  6. import pymongo
  7. except ImportError:
  8. pymongo = None # noqa
  9. from .. import states
  10. from ..exceptions import ImproperlyConfigured
  11. from ..utils.timeutils import maybe_timedelta
  12. from .base import BaseDictBackend
  13. class Bunch:
  14. def __init__(self, **kw):
  15. self.__dict__.update(kw)
  16. class MongoBackend(BaseDictBackend):
  17. mongodb_host = "localhost"
  18. mongodb_port = 27017
  19. mongodb_user = None
  20. mongodb_password = None
  21. mongodb_database = "celery"
  22. mongodb_taskmeta_collection = "celery_taskmeta"
  23. def __init__(self, *args, **kwargs):
  24. """Initialize MongoDB backend instance.
  25. :raises celery.exceptions.ImproperlyConfigured: if
  26. module :mod:`pymongo` is not available.
  27. """
  28. super(MongoBackend, self).__init__(*args, **kwargs)
  29. self.expires = kwargs.get("expires") or maybe_timedelta(
  30. self.app.conf.CELERY_TASK_RESULT_EXPIRES)
  31. if not pymongo:
  32. raise ImproperlyConfigured(
  33. "You need to install the pymongo library to use the "
  34. "MongoDB backend.")
  35. config = self.app.conf.get("CELERY_MONGODB_BACKEND_SETTINGS", None)
  36. if config is not None:
  37. if not isinstance(config, dict):
  38. raise ImproperlyConfigured(
  39. "MongoDB backend settings should be grouped in a dict")
  40. self.mongodb_host = config.get("host", self.mongodb_host)
  41. self.mongodb_port = int(config.get("port", self.mongodb_port))
  42. self.mongodb_user = config.get("user", self.mongodb_user)
  43. self.mongodb_password = config.get(
  44. "password", self.mongodb_password)
  45. self.mongodb_database = config.get(
  46. "database", self.mongodb_database)
  47. self.mongodb_taskmeta_collection = config.get(
  48. "taskmeta_collection", self.mongodb_taskmeta_collection)
  49. self._connection = None
  50. self._database = None
  51. def _get_connection(self):
  52. """Connect to the MongoDB server."""
  53. if self._connection is None:
  54. from pymongo.connection import Connection
  55. # The first pymongo.Connection() argument (host) can be
  56. # a list of ['host:port'] elements or a mongodb connection
  57. # URI. If this is the case, don't use self.mongodb_port
  58. # but let pymongo get the port(s) from the URI instead.
  59. # This enables the use of replica sets and sharding.
  60. # See pymongo.Connection() for more info.
  61. args = [self.mongodb_host]
  62. if isinstance(self.mongodb_host, basestring) \
  63. and not self.mongodb_host.startswith("mongodb://"):
  64. args.append(self.mongodb_port)
  65. self._connection = Connection(*args)
  66. return self._connection
  67. def _get_database(self):
  68. """"Get database from MongoDB connection and perform authentication
  69. if necessary."""
  70. if self._database is None:
  71. conn = self._get_connection()
  72. db = conn[self.mongodb_database]
  73. if self.mongodb_user and self.mongodb_password:
  74. auth = db.authenticate(self.mongodb_user,
  75. self.mongodb_password)
  76. if not auth:
  77. raise ImproperlyConfigured(
  78. "Invalid MongoDB username or password.")
  79. self._database = db
  80. return self._database
  81. def process_cleanup(self):
  82. if self._connection is not None:
  83. # MongoDB connection will be closed automatically when object
  84. # goes out of scope
  85. self._connection = None
  86. def _store_result(self, task_id, result, status, traceback=None):
  87. """Store return value and status of an executed task."""
  88. from pymongo.binary import Binary
  89. meta = {"_id": task_id,
  90. "status": status,
  91. "result": Binary(self.encode(result)),
  92. "date_done": datetime.utcnow(),
  93. "traceback": Binary(self.encode(traceback))}
  94. db = self._get_database()
  95. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  96. taskmeta_collection.save(meta, safe=True)
  97. return result
  98. def _get_task_meta_for(self, task_id):
  99. """Get task metadata for a task by id."""
  100. db = self._get_database()
  101. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  102. obj = taskmeta_collection.find_one({"_id": task_id})
  103. if not obj:
  104. return {"status": states.PENDING, "result": None}
  105. meta = {
  106. "task_id": obj["_id"],
  107. "status": obj["status"],
  108. "result": self.decode(obj["result"]),
  109. "date_done": obj["date_done"],
  110. "traceback": self.decode(obj["traceback"]),
  111. }
  112. return meta
  113. def _save_taskset(self, taskset_id, result):
  114. """Save the taskset result."""
  115. from pymongo.binary import Binary
  116. meta = {"_id": taskset_id,
  117. "result": Binary(self.encode(result)),
  118. "date_done": datetime.utcnow()}
  119. db = self._get_database()
  120. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  121. taskmeta_collection.save(meta, safe=True)
  122. return result
  123. def _restore_taskset(self, taskset_id):
  124. """Get the result for a taskset by id."""
  125. db = self._get_database()
  126. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  127. obj = taskmeta_collection.find_one({"_id": taskset_id})
  128. if not obj:
  129. return None
  130. meta = {
  131. "task_id": obj["_id"],
  132. "result": self.decode(obj["result"]),
  133. "date_done": obj["date_done"],
  134. }
  135. return meta
  136. def _delete_taskset(self, taskset_id):
  137. """Delete a taskset by id."""
  138. db = self._get_database()
  139. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  140. taskmeta_collection.remove({"_id": taskset_id})
  141. def cleanup(self):
  142. """Delete expired metadata."""
  143. db = self._get_database()
  144. taskmeta_collection = db[self.mongodb_taskmeta_collection]
  145. taskmeta_collection.remove({
  146. "date_done": {
  147. "$lt": datetime.utcnow() - self.expires,
  148. }
  149. })
  150. def __reduce__(self, args=(), kwargs={}):
  151. kwargs.update(
  152. dict(expires=self.expires))
  153. return super(MongoBackend, self).__reduce__(args, kwargs)