base.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # -*- coding: utf-8 -*-
  2. """celery.backends.base"""
  3. from __future__ import absolute_import
  4. import time
  5. import sys
  6. from datetime import timedelta
  7. from kombu import serialization
  8. from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
  9. from celery import states
  10. from celery.app import current_task
  11. from celery.datastructures import LRUCache
  12. from celery.exceptions import TimeoutError, TaskRevokedError
  13. from celery.result import from_serializable
  14. from celery.utils import timeutils
  15. from celery.utils.serialization import (
  16. get_pickled_exception,
  17. get_pickleable_exception,
  18. create_exception_cls,
  19. )
  20. EXCEPTION_ABLE_CODECS = frozenset(["pickle", "yaml"])
  21. is_py3k = sys.version_info >= (3, 0)
  22. def unpickle_backend(cls, args, kwargs):
  23. """Returns an unpickled backend."""
  24. return cls(*args, **kwargs)
  25. class BaseBackend(object):
  26. """Base backend class."""
  27. READY_STATES = states.READY_STATES
  28. UNREADY_STATES = states.UNREADY_STATES
  29. EXCEPTION_STATES = states.EXCEPTION_STATES
  30. TimeoutError = TimeoutError
  31. #: Time to sleep between polling each individual item
  32. #: in `ResultSet.iterate`. as opposed to the `interval`
  33. #: argument which is for each pass.
  34. subpolling_interval = None
  35. #: If true the backend must implement :meth:`get_many`.
  36. supports_native_join = False
  37. def __init__(self, *args, **kwargs):
  38. from celery.app import app_or_default
  39. self.app = app_or_default(kwargs.get("app"))
  40. self.serializer = kwargs.get("serializer",
  41. self.app.conf.CELERY_RESULT_SERIALIZER)
  42. (self.content_type,
  43. self.content_encoding,
  44. self.encoder) = serialization.registry._encoders[self.serializer]
  45. def encode(self, data):
  46. _, _, payload = serialization.encode(data, serializer=self.serializer)
  47. return payload
  48. def decode(self, payload):
  49. payload = is_py3k and payload or str(payload)
  50. return serialization.decode(payload,
  51. content_type=self.content_type,
  52. content_encoding=self.content_encoding)
  53. def prepare_expires(self, value, type=None):
  54. if value is None:
  55. value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  56. if isinstance(value, timedelta):
  57. value = timeutils.timedelta_seconds(value)
  58. if value is not None and type:
  59. return type(value)
  60. return value
  61. def encode_result(self, result, status):
  62. if status in self.EXCEPTION_STATES and isinstance(result, Exception):
  63. return self.prepare_exception(result)
  64. else:
  65. return self.prepare_value(result)
  66. def store_result(self, task_id, result, status, traceback=None):
  67. """Store the result and status of a task."""
  68. raise NotImplementedError(
  69. "store_result is not supported by this backend.")
  70. def mark_as_started(self, task_id, **meta):
  71. """Mark a task as started"""
  72. return self.store_result(task_id, meta, status=states.STARTED)
  73. def mark_as_done(self, task_id, result):
  74. """Mark task as successfully executed."""
  75. return self.store_result(task_id, result, status=states.SUCCESS)
  76. def mark_as_failure(self, task_id, exc, traceback=None):
  77. """Mark task as executed with failure. Stores the execption."""
  78. return self.store_result(task_id, exc, status=states.FAILURE,
  79. traceback=traceback)
  80. def mark_as_retry(self, task_id, exc, traceback=None):
  81. """Mark task as being retries. Stores the current
  82. exception (if any)."""
  83. return self.store_result(task_id, exc, status=states.RETRY,
  84. traceback=traceback)
  85. def mark_as_revoked(self, task_id):
  86. return self.store_result(task_id, TaskRevokedError(),
  87. status=states.REVOKED, traceback=None)
  88. def prepare_exception(self, exc):
  89. """Prepare exception for serialization."""
  90. if self.serializer in EXCEPTION_ABLE_CODECS:
  91. return get_pickleable_exception(exc)
  92. return {"exc_type": type(exc).__name__, "exc_message": str(exc)}
  93. def exception_to_python(self, exc):
  94. """Convert serialized exception to Python exception."""
  95. if self.serializer in EXCEPTION_ABLE_CODECS:
  96. return get_pickled_exception(exc)
  97. return create_exception_cls(from_utf8(exc["exc_type"]),
  98. sys.modules[__name__])(exc["exc_message"])
  99. def prepare_value(self, result):
  100. """Prepare value for storage."""
  101. return result
  102. def forget(self, task_id):
  103. raise NotImplementedError("%s does not implement forget." % (
  104. self.__class__))
  105. def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
  106. """Wait for task and return its result.
  107. If the task raises an exception, this exception
  108. will be re-raised by :func:`wait_for`.
  109. If `timeout` is not :const:`None`, this raises the
  110. :class:`celery.exceptions.TimeoutError` exception if the operation
  111. takes longer than `timeout` seconds.
  112. """
  113. time_elapsed = 0.0
  114. while 1:
  115. status = self.get_status(task_id)
  116. if status == states.SUCCESS:
  117. return self.get_result(task_id)
  118. elif status in states.PROPAGATE_STATES:
  119. result = self.get_result(task_id)
  120. if propagate:
  121. raise result
  122. return result
  123. # avoid hammering the CPU checking status.
  124. time.sleep(interval)
  125. time_elapsed += interval
  126. if timeout and time_elapsed >= timeout:
  127. raise TimeoutError("The operation timed out.")
  128. def cleanup(self):
  129. """Backend cleanup. Is run by
  130. :class:`celery.task.DeleteExpiredTaskMetaTask`."""
  131. pass
  132. def process_cleanup(self):
  133. """Cleanup actions to do at the end of a task worker process."""
  134. pass
  135. def get_status(self, task_id):
  136. """Get the status of a task."""
  137. raise NotImplementedError(
  138. "get_status is not supported by this backend.")
  139. def get_result(self, task_id):
  140. """Get the result of a task."""
  141. raise NotImplementedError(
  142. "get_result is not supported by this backend.")
  143. def get_children(self, task_id):
  144. raise NotImplementedError(
  145. "get_children is not supported by this backend.")
  146. def get_traceback(self, task_id):
  147. """Get the traceback for a failed task."""
  148. raise NotImplementedError(
  149. "get_traceback is not supported by this backend.")
  150. def save_taskset(self, taskset_id, result):
  151. """Store the result and status of a task."""
  152. raise NotImplementedError(
  153. "save_taskset is not supported by this backend.")
  154. def restore_taskset(self, taskset_id, cache=True):
  155. """Get the result of a taskset."""
  156. raise NotImplementedError(
  157. "restore_taskset is not supported by this backend.")
  158. def delete_taskset(self, taskset_id):
  159. raise NotImplementedError(
  160. "delete_taskset is not supported by this backend.")
  161. def reload_task_result(self, task_id):
  162. """Reload task result, even if it has been previously fetched."""
  163. raise NotImplementedError(
  164. "reload_task_result is not supported by this backend.")
  165. def reload_taskset_result(self, task_id):
  166. """Reload taskset result, even if it has been previously fetched."""
  167. raise NotImplementedError(
  168. "reload_taskset_result is not supported by this backend.")
  169. def on_chord_part_return(self, task, propagate=False):
  170. pass
  171. def fallback_chord_unlock(self, setid, body, result=None, **kwargs):
  172. kwargs["result"] = [r.id for r in result]
  173. self.app.tasks["celery.chord_unlock"].apply_async((setid, body, ),
  174. kwargs, countdown=1)
  175. on_chord_apply = fallback_chord_unlock
  176. def current_task_children(self):
  177. current = current_task()
  178. if current:
  179. return [r.serializable() for r in current.request.children]
  180. def __reduce__(self, args=(), kwargs={}):
  181. return (unpickle_backend, (self.__class__, args, kwargs))
  182. class BaseDictBackend(BaseBackend):
  183. def __init__(self, *args, **kwargs):
  184. super(BaseDictBackend, self).__init__(*args, **kwargs)
  185. self._cache = LRUCache(limit=kwargs.get("max_cached_results") or
  186. self.app.conf.CELERY_MAX_CACHED_RESULTS)
  187. def store_result(self, task_id, result, status, traceback=None, **kwargs):
  188. """Store task result and status."""
  189. result = self.encode_result(result, status)
  190. return self._store_result(task_id, result, status, traceback, **kwargs)
  191. def forget(self, task_id):
  192. self._cache.pop(task_id, None)
  193. self._forget(task_id)
  194. def _forget(self, task_id):
  195. raise NotImplementedError("%s does not implement forget." % (
  196. self.__class__))
  197. def get_status(self, task_id):
  198. """Get the status of a task."""
  199. return self.get_task_meta(task_id)["status"]
  200. def get_traceback(self, task_id):
  201. """Get the traceback for a failed task."""
  202. return self.get_task_meta(task_id).get("traceback")
  203. def get_result(self, task_id):
  204. """Get the result of a task."""
  205. meta = self.get_task_meta(task_id)
  206. if meta["status"] in self.EXCEPTION_STATES:
  207. return self.exception_to_python(meta["result"])
  208. else:
  209. return meta["result"]
  210. def get_children(self, task_id):
  211. """Get the list of subtasks sent by a task."""
  212. try:
  213. return self.get_task_meta(task_id)["children"]
  214. except KeyError:
  215. pass
  216. def get_task_meta(self, task_id, cache=True):
  217. if cache:
  218. try:
  219. return self._cache[task_id]
  220. except KeyError:
  221. pass
  222. meta = self._get_task_meta_for(task_id)
  223. if cache and meta.get("status") == states.SUCCESS:
  224. self._cache[task_id] = meta
  225. return meta
  226. def reload_task_result(self, task_id):
  227. self._cache[task_id] = self.get_task_meta(task_id, cache=False)
  228. def reload_taskset_result(self, taskset_id):
  229. self._cache[taskset_id] = self.get_taskset_meta(taskset_id,
  230. cache=False)
  231. def get_taskset_meta(self, taskset_id, cache=True):
  232. if cache:
  233. try:
  234. return self._cache[taskset_id]
  235. except KeyError:
  236. pass
  237. meta = self._restore_taskset(taskset_id)
  238. if cache and meta is not None:
  239. self._cache[taskset_id] = meta
  240. return meta
  241. def restore_taskset(self, taskset_id, cache=True):
  242. """Get the result for a taskset."""
  243. meta = self.get_taskset_meta(taskset_id, cache=cache)
  244. if meta:
  245. return meta["result"]
  246. def save_taskset(self, taskset_id, result):
  247. """Store the result of an executed taskset."""
  248. return self._save_taskset(taskset_id, result)
  249. def delete_taskset(self, taskset_id):
  250. self._cache.pop(taskset_id, None)
  251. return self._delete_taskset(taskset_id)
  252. class KeyValueStoreBackend(BaseDictBackend):
  253. task_keyprefix = ensure_bytes("celery-task-meta-")
  254. taskset_keyprefix = ensure_bytes("celery-taskset-meta-")
  255. chord_keyprefix = ensure_bytes("chord-unlock-")
  256. implements_incr = False
  257. def get(self, key):
  258. raise NotImplementedError("Must implement the get method.")
  259. def mget(self, keys):
  260. raise NotImplementedError("Does not support get_many")
  261. def set(self, key, value):
  262. raise NotImplementedError("Must implement the set method.")
  263. def delete(self, key):
  264. raise NotImplementedError("Must implement the delete method")
  265. def incr(self, key):
  266. raise NotImplementedError("Does not implement incr")
  267. def expire(self, key, value):
  268. pass
  269. def get_key_for_task(self, task_id):
  270. """Get the cache key for a task by id."""
  271. return self.task_keyprefix + ensure_bytes(task_id)
  272. def get_key_for_taskset(self, taskset_id):
  273. """Get the cache key for a taskset by id."""
  274. return self.taskset_keyprefix + ensure_bytes(taskset_id)
  275. def get_key_for_chord(self, taskset_id):
  276. """Get the cache key for the chord waiting on taskset with given id."""
  277. return self.chord_keyprefix + ensure_bytes(taskset_id)
  278. def _strip_prefix(self, key):
  279. """Takes bytes, emits string."""
  280. key = ensure_bytes(key)
  281. for prefix in self.task_keyprefix, self.taskset_keyprefix:
  282. if key.startswith(prefix):
  283. return bytes_to_str(key[len(prefix):])
  284. return bytes_to_str(key)
  285. def _mget_to_results(self, values, keys):
  286. if hasattr(values, "items"):
  287. # client returns dict so mapping preserved.
  288. return dict((self._strip_prefix(k), self.decode(v))
  289. for k, v in values.iteritems()
  290. if v is not None)
  291. else:
  292. # client returns list so need to recreate mapping.
  293. return dict((bytes_to_str(keys[i]), self.decode(value))
  294. for i, value in enumerate(values)
  295. if value is not None)
  296. def get_many(self, task_ids, timeout=None, interval=0.5):
  297. ids = set(task_ids)
  298. cached_ids = set()
  299. for task_id in ids:
  300. try:
  301. cached = self._cache[task_id]
  302. except KeyError:
  303. pass
  304. else:
  305. if cached["status"] in states.READY_STATES:
  306. yield bytes_to_str(task_id), cached
  307. cached_ids.add(task_id)
  308. ids ^= cached_ids
  309. iterations = 0
  310. while ids:
  311. keys = list(ids)
  312. r = self._mget_to_results(self.mget([self.get_key_for_task(k)
  313. for k in keys]), keys)
  314. self._cache.update(r)
  315. ids ^= set(map(bytes_to_str, r))
  316. for key, value in r.iteritems():
  317. yield bytes_to_str(key), value
  318. if timeout and iterations * interval >= timeout:
  319. raise TimeoutError("Operation timed out (%s)" % (timeout, ))
  320. time.sleep(interval) # don't busy loop.
  321. iterations += 0
  322. def _forget(self, task_id):
  323. self.delete(self.get_key_for_task(task_id))
  324. def _store_result(self, task_id, result, status, traceback=None):
  325. meta = {"status": status, "result": result, "traceback": traceback,
  326. "children": self.current_task_children()}
  327. self.set(self.get_key_for_task(task_id), self.encode(meta))
  328. return result
  329. def _save_taskset(self, taskset_id, result):
  330. self.set(self.get_key_for_taskset(taskset_id),
  331. self.encode({"result": result.serializable()}))
  332. return result
  333. def _delete_taskset(self, taskset_id):
  334. self.delete(self.get_key_for_taskset(taskset_id))
  335. def _get_task_meta_for(self, task_id):
  336. """Get task metadata for a task by id."""
  337. meta = self.get(self.get_key_for_task(task_id))
  338. if not meta:
  339. return {"status": states.PENDING, "result": None}
  340. return self.decode(meta)
  341. def _restore_taskset(self, taskset_id):
  342. """Get task metadata for a task by id."""
  343. meta = self.get(self.get_key_for_taskset(taskset_id))
  344. # previously this was always pickled, but later this
  345. # was extended to support other serializers, so the
  346. # structure is kind of weird.
  347. if meta:
  348. meta = self.decode(meta)
  349. result = meta["result"]
  350. if isinstance(result, (list, tuple)):
  351. return {"result": from_serializable(result)}
  352. return meta
  353. def on_chord_apply(self, setid, body, result=None, **kwargs):
  354. if self.implements_incr:
  355. self.app.TaskSetResult(setid, result).save()
  356. else:
  357. self.fallback_chord_unlock(setid, body, result, **kwargs)
  358. def on_chord_part_return(self, task, propagate=False):
  359. if not self.implements_incr:
  360. return
  361. from celery import subtask
  362. from celery.result import TaskSetResult
  363. setid = task.request.taskset
  364. if not setid:
  365. return
  366. key = self.get_key_for_chord(setid)
  367. deps = TaskSetResult.restore(setid, backend=task.backend)
  368. val = self.incr(key)
  369. if val >= deps.total:
  370. subtask(task.request.chord).delay(deps.join(propagate=propagate))
  371. deps.delete()
  372. self.client.delete(key)
  373. else:
  374. self.expire(key, 86400)
  375. class DisabledBackend(BaseBackend):
  376. _cache = {} # need this attribute to reset cache in tests.
  377. def store_result(self, *args, **kwargs):
  378. pass
  379. def _is_disabled(self, *args, **kwargs):
  380. raise NotImplementedError("No result backend configured. "
  381. "Please see the documentation for more information.")
  382. wait_for = get_status = get_result = get_traceback = _is_disabled