base.py
# -*- coding: utf-8 -*-
"""
    celery.backends.base
    ~~~~~~~~~~~~~~~~~~~~

    Result backend base classes.

    - :class:`BaseBackend` defines the interface.

    - :class:`KeyValueStoreBackend` is a common base class
      using K/V semantics like _get and _put.

"""
from __future__ import absolute_import

import time
import sys

from datetime import timedelta

from billiard.einfo import ExceptionInfo
from kombu.serialization import (
    encode, decode, prepare_accept_encoding,
    registry as serializer_registry,
)
from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8

from celery import states
from celery.app import current_task
from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
from celery.five import items
from celery.result import from_serializable, GroupResult
from celery.utils import timeutils
from celery.utils.functional import LRUCache
from celery.utils.serialization import (
    get_pickled_exception,
    get_pickleable_exception,
    create_exception_cls,
)

EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])
PY3 = sys.version_info >= (3, 0)


def unpickle_backend(cls, args, kwargs):
    """Returns an unpickled backend."""
    from celery import current_app
    return cls(*args, app=current_app._get_current_object(), **kwargs)


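# Illustrative sketch, not part of the original module: ``__reduce__`` on
# :class:`BaseBackend` below pairs with :func:`unpickle_backend` so that
# unpickling re-binds the backend to the currently active app instead of
# serializing the app itself. The helper name here is hypothetical.
def _example_pickle_roundtrip(backend):
    """Round-trip a backend through pickle (assumes a configured app)."""
    import pickle
    data = pickle.dumps(backend)    # invokes BaseBackend.__reduce__
    restored = pickle.loads(data)   # invokes unpickle_backend(cls, args, kwargs)
    assert type(restored) is type(backend)
    return restored

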
class BaseBackend(object):
    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument which is for each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    #: If true the backend must automatically expire results.
    #: The daily backend_cleanup periodic task will not be triggered
    #: in this case.
    supports_autoexpire = False

    def __init__(self, app, serializer=None,
                 max_cached_results=None, accept=None, **kwargs):
        self.app = app
        conf = self.app.conf
        self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
        (self.content_type,
         self.content_encoding,
         self.encoder) = serializer_registry._encoders[self.serializer]
        self._cache = LRUCache(
            limit=max_cached_results or conf.CELERY_MAX_CACHED_RESULTS,
        )
        self.accept = prepare_accept_encoding(
            conf.CELERY_ACCEPT_CONTENT if accept is None else accept,
        )

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, status=states.STARTED)

    def mark_as_done(self, task_id, result):
        """Mark task as successfully executed."""
        return self.store_result(task_id, result, status=states.SUCCESS)

    def mark_as_failure(self, task_id, exc, traceback=None):
        """Mark task as executed with failure. Stores the exception."""
        return self.store_result(task_id, exc, status=states.FAILURE,
                                 traceback=traceback)

    def fail_from_current_stack(self, task_id, exc=None):
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            ei = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, ei.traceback)
            return ei
        finally:
            del(tb)

    def mark_as_retry(self, task_id, exc, traceback=None):
        """Mark task as being retried. Stores the current
        exception (if any)."""
        return self.store_result(task_id, exc, status=states.RETRY,
                                 traceback=traceback)

    def mark_as_revoked(self, task_id, reason=''):
        return self.store_result(task_id, TaskRevokedError(reason),
                                 status=states.REVOKED, traceback=None)

    def prepare_exception(self, exc):
        """Prepare exception for serialization."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickleable_exception(exc)
        return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickled_exception(exc)
        return create_exception_cls(from_utf8(exc['exc_type']),
                                    sys.modules[__name__])(exc['exc_message'])

    def prepare_value(self, result):
        """Prepare value for storage."""
        if isinstance(result, GroupResult):
            return result.serializable()
        return result

    def encode(self, data):
        _, _, payload = encode(data, serializer=self.serializer)
        return payload

    def decode(self, payload):
        payload = payload if PY3 else str(payload)
        return decode(payload,
                      content_type=self.content_type,
                      content_encoding=self.content_encoding,
                      accept=self.accept)

    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
        """Wait for task and return its result.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        If `timeout` is not :const:`None`, this raises the
        :class:`celery.exceptions.TimeoutError` exception if the operation
        takes longer than `timeout` seconds.

        """
        time_elapsed = 0.0

        while 1:
            status = self.get_status(task_id)
            if status == states.SUCCESS:
                return self.get_result(task_id)
            elif status in states.PROPAGATE_STATES:
                result = self.get_result(task_id)
                if propagate:
                    raise result
                return result
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')

    def prepare_expires(self, value, type=None):
        if value is None:
            value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
        if isinstance(value, timedelta):
            value = timeutils.timedelta_seconds(value)
        if value is not None and type:
            return type(value)
        return value

    def encode_result(self, result, status):
        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        else:
            return self.prepare_value(result)

    def is_cached(self, task_id):
        return task_id in self._cache

    def store_result(self, task_id, result, status, traceback=None, **kwargs):
        """Update task state and result."""
        result = self.encode_result(result, status)
        self._store_result(task_id, result, status, traceback, **kwargs)
        return result

    def forget(self, task_id):
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('backend does not implement forget.')

    def get_status(self, task_id):
        """Get the status of a task."""
        return self.get_task_meta(task_id)['status']

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        meta = self.get_task_meta(task_id)
        if meta['status'] in self.EXCEPTION_STATES:
            return self.exception_to_python(meta['result'])
        else:
            return meta['result']

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def get_task_meta(self, task_id, cache=True):
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass

        meta = self._get_task_meta_for(task_id)
        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        self._cache[group_id] = self.get_group_meta(group_id, cache=False)

    def get_group_meta(self, group_id, cache=True):
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass

        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)

    def cleanup(self):
        """Backend cleanup. Is run by
        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
        pass

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""
        pass

    def on_task_call(self, producer, task_id):
        return {}

    def on_chord_part_return(self, task, propagate=False):
        pass

    def fallback_chord_unlock(self, group_id, body, result=None,
                              countdown=1, **kwargs):
        kwargs['result'] = [r.serializable() for r in result]
        self.app.tasks['celery.chord_unlock'].apply_async(
            (group_id, body, ), kwargs, countdown=countdown,
        )
    on_chord_apply = fallback_chord_unlock

    def current_task_children(self):
        current = current_task()
        if current:
            return [r.serializable() for r in current.request.children]

    def __reduce__(self, args=(), kwargs={}):
        return (unpickle_backend, (self.__class__, args, kwargs))
BaseDictBackend = BaseBackend  # XXX compat


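# Illustrative sketch, not part of the original module: with a non-pickle
# serializer such as json, ``prepare_exception`` stores an exception as a
# plain ``{'exc_type': ..., 'exc_message': ...}`` dict, and
# ``exception_to_python`` rebuilds a raisable class from that dict.
# The helper name here is hypothetical.
def _example_exception_roundtrip(backend, exc):
    serialized = backend.prepare_exception(exc)
    # with serializer='json': {'exc_type': 'KeyError', 'exc_message': "'x'"}
    return backend.exception_to_python(serialized)

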
class KeyValueStoreBackend(BaseBackend):
    task_keyprefix = ensure_bytes('celery-task-meta-')
    group_keyprefix = ensure_bytes('celery-taskset-meta-')
    chord_keyprefix = ensure_bytes('chord-unlock-')
    implements_incr = False

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many')

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        pass

    def get_key_for_task(self, task_id):
        """Get the cache key for a task by id."""
        return self.task_keyprefix + ensure_bytes(task_id)

    def get_key_for_group(self, group_id):
        """Get the cache key for a group by id."""
        return self.group_keyprefix + ensure_bytes(group_id)

    def get_key_for_chord(self, group_id):
        """Get the cache key for the chord waiting on group with given id."""
        return self.chord_keyprefix + ensure_bytes(group_id)

    def _strip_prefix(self, key):
        """Takes bytes, emits string."""
        key = ensure_bytes(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _mget_to_results(self, values, keys):
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return dict((self._strip_prefix(k), self.decode(v))
                        for k, v in items(values)
                        if v is not None)
        else:
            # client returns list so need to recreate mapping.
            return dict((bytes_to_str(keys[i]), self.decode(value))
                        for i, value in enumerate(values)
                        if value is not None)

    def get_many(self, task_ids, timeout=None, interval=0.5):
        ids = set(task_ids)
        cached_ids = set()
        for task_id in ids:
            try:
                cached = self._cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in states.READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids.difference_update(cached_ids)
        iterations = 0
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys)
            self._cache.update(r)
            ids.difference_update(set(bytes_to_str(v) for v in r))
            for key, value in items(r):
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            time.sleep(interval)  # don't busy loop.
            iterations += 1

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, status, traceback=None):
        meta = {'status': status, 'result': result, 'traceback': traceback,
                'children': self.current_task_children()}
        self.set(self.get_key_for_task(task_id), self.encode(meta))
        return result

    def _save_group(self, group_id, result):
        self.set(self.get_key_for_group(group_id),
                 self.encode({'result': result.serializable()}))
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            return {'status': states.PENDING, 'result': None}
        return self.decode(meta)

    def _restore_group(self, group_id):
        """Get metadata for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = from_serializable(result, self.app)
            return meta

    def on_chord_apply(self, group_id, body, result=None, **kwargs):
        if self.implements_incr:
            self.save_group(group_id, self.app.GroupResult(group_id, result))
        else:
            self.fallback_chord_unlock(group_id, body, result, **kwargs)

    def on_chord_part_return(self, task, propagate=None):
        if not self.implements_incr:
            return
        from celery import subtask
        from celery.result import GroupResult
        app = self.app
        if propagate is None:
            propagate = self.app.conf.CELERY_CHORD_PROPAGATES
        gid = task.request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        deps = GroupResult.restore(gid, backend=task.backend)
        val = self.incr(key)
        if val >= len(deps):
            j = deps.join_native if deps.supports_native_join else deps.join
            callback = subtask(task.request.chord)
            try:
                ret = j(propagate=propagate)
            except Exception as exc:
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)
                app._tasks[callback.task].backend.fail_from_current_stack(
                    callback.id, exc=ChordError(reason),
                )
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:
                    app._tasks[callback.task].backend.fail_from_current_stack(
                        callback.id,
                        exc=ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                self.client.delete(key)
        else:
            self.expire(key, 86400)


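# Illustrative sketch, not part of the original module: the smallest useful
# :class:`KeyValueStoreBackend` subclass only has to provide get/mget/set/
# delete. This dict-backed class (hypothetical name) is enough for
# store_result, get_result and get_many to work within a single process.
class _ExampleDictStoreBackend(KeyValueStoreBackend):
    """In-memory K/V backend, for demonstration and tests only."""

    def __init__(self, *args, **kwargs):
        super(_ExampleDictStoreBackend, self).__init__(*args, **kwargs)
        self._data = {}

    def get(self, key):
        return self._data.get(key)

    def mget(self, keys):
        # a list return value exercises the list branch of _mget_to_results.
        return [self._data.get(key) for key in keys]

    def set(self, key, value):
        self._data[key] = value

    def delete(self, key):
        self._data.pop(key, None)

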
class DisabledBackend(BaseBackend):
    _cache = {}   # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        pass

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(
            'No result backend configured. '
            'Please see the documentation for more information.')
    wait_for = get_status = get_result = get_traceback = _is_disabled
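

# Illustrative usage sketch, not part of the original module. Assumes a
# default-configured app whose result serializer is permitted by
# CELERY_ACCEPT_CONTENT.
if __name__ == '__main__':  # pragma: no cover
    from celery import Celery

    app = Celery('example')
    backend = _ExampleDictStoreBackend(app=app)
    backend.mark_as_done('task-1', {'answer': 42})
    assert backend.get_status('task-1') == states.SUCCESS
    assert backend.get_result('task-1') == {'answer': 42}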