# -*- coding: utf-8 -*-
"""
    celery.backends.base
    ~~~~~~~~~~~~~~~~~~~~

    Result backend base classes.

    - :class:`BaseBackend` defines the interface.

    - :class:`KeyValueStoreBackend` is a common base class
      using K/V semantics like _get and _put.

"""
from __future__ import absolute_import

import time
import sys

from datetime import timedelta

from billiard.einfo import ExceptionInfo
from kombu import serialization
from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8

from celery import states
from celery.app import current_task
from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
from celery.five import items
from celery.result import from_serializable, GroupResult
from celery.utils import timeutils
from celery.utils.functional import LRUCache
from celery.utils.serialization import (
    get_pickled_exception,
    get_pickleable_exception,
    create_exception_cls,
)

EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])
PY3 = sys.version_info >= (3, 0)


def unpickle_backend(cls, args, kwargs):
    """Returns an unpickled backend."""
    from celery import current_app
    return cls(*args, app=current_app._get_current_object(), **kwargs)


class BaseBackend(object):
    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument which is for each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    #: If true the backend must automatically expire results.
    #: The daily backend_cleanup periodic task will not be triggered
    #: in this case.
    supports_autoexpire = False

    def __init__(self, app, serializer=None,
                 max_cached_results=None, **kwargs):
        self.app = app
        conf = self.app.conf
        self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
        (self.content_type,
         self.content_encoding,
         self.encoder) = serialization.registry._encoders[self.serializer]
        self._cache = LRUCache(
            limit=max_cached_results or conf.CELERY_MAX_CACHED_RESULTS,
        )

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, status=states.STARTED)

    def mark_as_done(self, task_id, result):
        """Mark task as successfully executed."""
        return self.store_result(task_id, result, status=states.SUCCESS)

    def mark_as_failure(self, task_id, exc, traceback=None):
        """Mark task as executed with failure. Stores the exception."""
        return self.store_result(task_id, exc, status=states.FAILURE,
                                 traceback=traceback)

    def fail_from_current_stack(self, task_id, exc=None):
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            ei = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, ei.traceback)
            return ei
        finally:
            del(tb)

    def mark_as_retry(self, task_id, exc, traceback=None):
        """Mark task as being retried. Stores the current
        exception (if any)."""
        return self.store_result(task_id, exc, status=states.RETRY,
                                 traceback=traceback)

    def mark_as_revoked(self, task_id, reason=''):
        return self.store_result(task_id, TaskRevokedError(reason),
                                 status=states.REVOKED, traceback=None)
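
    # Illustrative sketch (not part of the original module): the mark_as_*
    # helpers are thin wrappers around store_result(). Assuming a configured
    # concrete backend and a hypothetical task id:
    #
    #     backend = app.backend
    #     backend.mark_as_done('hypothetical-task-id', 42)
    #     assert backend.get_result('hypothetical-task-id') == 42
    #     backend.mark_as_failure('hypothetical-task-id', KeyError('boom'))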

    def prepare_exception(self, exc):
        """Prepare exception for serialization."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickleable_exception(exc)
        return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickled_exception(exc)
        return create_exception_cls(from_utf8(exc['exc_type']),
                                    sys.modules[__name__])(exc['exc_message'])
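
    # Illustrative sketch: with a serializer that cannot carry exception
    # objects (e.g. json), failures round-trip through the dict form above.
    # The payload values here are examples only:
    #
    #     payload = {'exc_type': 'KeyError', 'exc_message': "'boom'"}
    #     exc = backend.exception_to_python(payload)
    #     # exc is an instance of a dynamically created Exception subclass
    #     # named 'KeyError', registered in this module's namespace.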

    def prepare_value(self, result):
        """Prepare value for storage."""
        if isinstance(result, GroupResult):
            return result.serializable()
        return result

    def encode(self, data):
        _, _, payload = serialization.encode(data, serializer=self.serializer)
        return payload

    def decode(self, payload):
        # on Python 2 the payload may arrive as unicode, so coerce it
        # to a native string before handing it to kombu.
        payload = PY3 and payload or str(payload)
        return serialization.decode(payload,
                                    content_type=self.content_type,
                                    content_encoding=self.content_encoding)
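
    # Illustrative round trip (assumes whatever result serializer the app
    # is configured with, e.g. the default pickle or json):
    #
    #     payload = backend.encode({'status': 'SUCCESS', 'result': 42})
    #     assert backend.decode(payload)['result'] == 42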

    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
        """Wait for task and return its result.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        If `timeout` is not :const:`None`, this raises the
        :class:`celery.exceptions.TimeoutError` exception if the operation
        takes longer than `timeout` seconds.

        """
        time_elapsed = 0.0

        while 1:
            status = self.get_status(task_id)
            if status == states.SUCCESS:
                return self.get_result(task_id)
            elif status in states.PROPAGATE_STATES:
                result = self.get_result(task_id)
                if propagate:
                    raise result
                return result
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')
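
    # Illustrative usage: polls the backend every `interval` seconds until
    # the task is ready (the task id here is hypothetical):
    #
    #     result = backend.wait_for('hypothetical-task-id',
    #                               timeout=10.0, interval=0.5)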

    def prepare_expires(self, value, type=None):
        if value is None:
            value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
        if isinstance(value, timedelta):
            value = timeutils.timedelta_seconds(value)
        if value is not None and type:
            return type(value)
        return value
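
    # Illustrative: a timedelta expiry is normalized to seconds, optionally
    # cast by the caller-supplied `type`:
    #
    #     backend.prepare_expires(timedelta(hours=1), type=int)  # -> 3600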

    def encode_result(self, result, status):
        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        else:
            return self.prepare_value(result)

    def is_cached(self, task_id):
        return task_id in self._cache

    def store_result(self, task_id, result, status, traceback=None, **kwargs):
        """Update task state and result."""
        result = self.encode_result(result, status)
        self._store_result(task_id, result, status, traceback, **kwargs)
        return result

    def forget(self, task_id):
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('backend does not implement forget.')

    def get_status(self, task_id):
        """Get the status of a task."""
        return self.get_task_meta(task_id)['status']

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        meta = self.get_task_meta(task_id)
        if meta['status'] in self.EXCEPTION_STATES:
            return self.exception_to_python(meta['result'])
        else:
            return meta['result']

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def get_task_meta(self, task_id, cache=True):
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass

        meta = self._get_task_meta_for(task_id)
        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        self._cache[group_id] = self.get_group_meta(group_id, cache=False)

    def get_group_meta(self, group_id, cache=True):
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass

        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)

    def cleanup(self):
        """Backend cleanup. Is run by
        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
        pass

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""
        pass

    def on_task_call(self, producer, task_id):
        return {}

    def on_chord_part_return(self, task, propagate=False):
        pass

    def fallback_chord_unlock(self, group_id, body, result=None,
                              countdown=1, **kwargs):
        kwargs['result'] = [r.serializable() for r in result]
        self.app.tasks['celery.chord_unlock'].apply_async(
            (group_id, body, ), kwargs, countdown=countdown,
        )
    on_chord_apply = fallback_chord_unlock

    def current_task_children(self):
        current = current_task()
        if current:
            return [r.serializable() for r in current.request.children]

    def __reduce__(self, args=(), kwargs={}):
        return (unpickle_backend, (self.__class__, args, kwargs))
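
    # Illustrative note: __reduce__ makes backend instances picklable by
    # recording only the class and constructor arguments; unpickling rebinds
    # the backend to whatever app is current in the unpickling process:
    #
    #     import pickle
    #     backend2 = pickle.loads(pickle.dumps(backend))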


BaseDictBackend = BaseBackend  # XXX compat


class KeyValueStoreBackend(BaseBackend):
    task_keyprefix = ensure_bytes('celery-task-meta-')
    group_keyprefix = ensure_bytes('celery-taskset-meta-')
    chord_keyprefix = ensure_bytes('chord-unlock-')
    implements_incr = False

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many')

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        pass

    def get_key_for_task(self, task_id):
        """Get the cache key for a task by id."""
        return self.task_keyprefix + ensure_bytes(task_id)

    def get_key_for_group(self, group_id):
        """Get the cache key for a group by id."""
        return self.group_keyprefix + ensure_bytes(group_id)

    def get_key_for_chord(self, group_id):
        """Get the cache key for the chord waiting on group with given id."""
        return self.chord_keyprefix + ensure_bytes(group_id)
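
    # Illustrative key layout for a hypothetical id:
    #
    #     get_key_for_task('abc')   -> b'celery-task-meta-abc'
    #     get_key_for_group('abc')  -> b'celery-taskset-meta-abc'
    #     get_key_for_chord('abc')  -> b'chord-unlock-abc'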

    def _strip_prefix(self, key):
        """Takes bytes, emits string."""
        key = ensure_bytes(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _mget_to_results(self, values, keys):
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return dict((self._strip_prefix(k), self.decode(v))
                        for k, v in items(values)
                        if v is not None)
        else:
            # client returns list so need to recreate mapping.
            return dict((bytes_to_str(keys[i]), self.decode(value))
                        for i, value in enumerate(values)
                        if value is not None)

    def get_many(self, task_ids, timeout=None, interval=0.5):
        ids = set(task_ids)
        cached_ids = set()
        for task_id in ids:
            try:
                cached = self._cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in states.READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids.difference_update(cached_ids)
        iterations = 0
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys)
            self._cache.update(r)
            ids.difference_update(set(bytes_to_str(v) for v in r))
            for key, value in items(r):
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            time.sleep(interval)  # don't busy loop.
            iterations += 1
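
    # Illustrative usage: get_many is a generator yielding (task_id, meta)
    # pairs as results become ready (the ids here are hypothetical):
    #
    #     for task_id, meta in backend.get_many(['id1', 'id2'], timeout=10):
    #         print(task_id, meta['status'], meta['result'])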

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, status, traceback=None):
        meta = {'status': status, 'result': result, 'traceback': traceback,
                'children': self.current_task_children()}
        self.set(self.get_key_for_task(task_id), self.encode(meta))
        return result

    def _save_group(self, group_id, result):
        self.set(self.get_key_for_group(group_id),
                 self.encode({'result': result.serializable()}))
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            return {'status': states.PENDING, 'result': None}
        return self.decode(meta)

    def _restore_group(self, group_id):
        """Get group metadata for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = from_serializable(result, self.app)
            return meta

    def on_chord_apply(self, group_id, body, result=None, **kwargs):
        if self.implements_incr:
            self.save_group(group_id, self.app.GroupResult(group_id, result))
        else:
            self.fallback_chord_unlock(group_id, body, result, **kwargs)

    def on_chord_part_return(self, task, propagate=None):
        if not self.implements_incr:
            return
        from celery import subtask
        from celery.result import GroupResult
        app = self.app
        if propagate is None:
            propagate = self.app.conf.CELERY_CHORD_PROPAGATES
        gid = task.request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        deps = GroupResult.restore(gid, backend=task.backend)
        # each completed part increments the counter; the last part to
        # finish joins the group and applies the chord callback.
        val = self.incr(key)
        if val >= len(deps):
            j = deps.join_native if deps.supports_native_join else deps.join
            callback = subtask(task.request.chord)
            try:
                ret = j(propagate=propagate)
            except Exception as exc:
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)
                app._tasks[callback.task].backend.fail_from_current_stack(
                    callback.id, exc=ChordError(reason),
                )
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:
                    app._tasks[callback.task].backend.fail_from_current_stack(
                        callback.id,
                        exc=ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                self.client.delete(key)
        else:
            # keep the counter around in case parts are still pending.
            self.expire(key, 86400)
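
    # Illustrative context: this implements the counting strategy for
    # chords, e.g. for a chord like
    #
    #     chord([add.s(2, 2), add.s(4, 4)])(tsum.s())
    #
    # where `add` and `tsum` are hypothetical tasks; each header task
    # triggers on_chord_part_return() when it completes.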


class DisabledBackend(BaseBackend):
    _cache = {}   # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        pass

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(
            'No result backend configured. '
            'Please see the documentation for more information.')
    wait_for = get_status = get_result = get_traceback = _is_disabled
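
    # Illustrative behavior: with this backend, storing results is a no-op
    # and any attempt to read one raises NotImplementedError, e.g.:
    #
    #     app.backend.get_result('hypothetical-task-id')  # raises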