base.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
# -*- coding: utf-8 -*-
"""
celery.backends.base
~~~~~~~~~~~~~~~~~~~~

Result backend base classes.

- :class:`BaseBackend` defines the interface.

- :class:`KeyValueStoreBackend` is a common base class
  using K/V semantics like ``get`` and ``set``.

"""
  10. from __future__ import absolute_import
  11. import time
  12. import sys
  13. from datetime import timedelta
  14. from billiard.einfo import ExceptionInfo
  15. from kombu.serialization import (
  16. dumps, loads, prepare_accept_content,
  17. registry as serializer_registry,
  18. )
  19. from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
  20. from celery import states
  21. from celery import current_app, maybe_signature
  22. from celery.app import current_task
  23. from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
  24. from celery.five import items
  25. from celery.result import (
  26. GroupResult, ResultBase, allow_join_result, result_from_tuple,
  27. )
  28. from celery.utils.functional import LRUCache
  29. from celery.utils.log import get_logger
  30. from celery.utils.serialization import (
  31. get_pickled_exception,
  32. get_pickleable_exception,
  33. create_exception_cls,
  34. )
  35. __all__ = ['BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend']
  36. EXCEPTION_ABLE_CODECS = frozenset({'pickle'})
  37. PY3 = sys.version_info >= (3, 0)
  38. logger = get_logger(__name__)
  39. def unpickle_backend(cls, args, kwargs):
  40. """Return an unpickled backend."""
  41. return cls(*args, app=current_app._get_current_object(), **kwargs)
  42. class _nulldict(dict):
  43. def ignore(self, *a, **kw):
  44. pass
  45. __setitem__ = update = setdefault = ignore
  46. class BaseBackend(object):
  47. READY_STATES = states.READY_STATES
  48. UNREADY_STATES = states.UNREADY_STATES
  49. EXCEPTION_STATES = states.EXCEPTION_STATES
  50. TimeoutError = TimeoutError
  51. #: Time to sleep between polling each individual item
  52. #: in `ResultSet.iterate`. as opposed to the `interval`
  53. #: argument which is for each pass.
  54. subpolling_interval = None
  55. #: If true the backend must implement :meth:`get_many`.
  56. supports_native_join = False
  57. #: If true the backend must automatically expire results.
  58. #: The daily backend_cleanup periodic task will not be triggered
  59. #: in this case.
  60. supports_autoexpire = False
  61. #: Set to true if the backend is peristent by default.
  62. persistent = True
  63. retry_policy = {
  64. 'max_retries': 20,
  65. 'interval_start': 0,
  66. 'interval_step': 1,
  67. 'interval_max': 1,
  68. }
  69. def __init__(self, app,
  70. serializer=None, max_cached_results=None, accept=None,
  71. expires=None, expires_type=None, **kwargs):
  72. self.app = app
  73. conf = self.app.conf
  74. self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
  75. (self.content_type,
  76. self.content_encoding,
  77. self.encoder) = serializer_registry._encoders[self.serializer]
  78. cmax = max_cached_results or conf.CELERY_MAX_CACHED_RESULTS
  79. self._cache = _nulldict() if cmax == -1 else LRUCache(limit=cmax)
  80. self.expires = self.prepare_expires(expires, expires_type)
  81. self.accept = prepare_accept_content(
  82. conf.CELERY_ACCEPT_CONTENT if accept is None else accept,
  83. )
  84. def mark_as_started(self, task_id, **meta):
  85. """Mark a task as started"""
  86. return self.store_result(task_id, meta, status=states.STARTED)
  87. def mark_as_done(self, task_id, result, request=None):
  88. """Mark task as successfully executed."""
  89. return self.store_result(task_id, result,
  90. status=states.SUCCESS, request=request)
  91. def mark_as_failure(self, task_id, exc, traceback=None, request=None):
  92. """Mark task as executed with failure. Stores the exception."""
  93. return self.store_result(task_id, exc, status=states.FAILURE,
  94. traceback=traceback, request=request)
  95. def chord_error_from_stack(self, callback, exc=None):
  96. from celery import group
  97. app = self.app
  98. backend = app._tasks[callback.task].backend
  99. try:
  100. group(
  101. [app.signature(errback)
  102. for errback in callback.options.get('link_error') or []],
  103. app=app,
  104. ).apply_async((callback.id,))
  105. except Exception as eb_exc:
  106. return backend.fail_from_current_stack(callback.id, exc=eb_exc)
  107. else:
  108. return backend.fail_from_current_stack(callback.id, exc=exc)
  109. def fail_from_current_stack(self, task_id, exc=None):
  110. type_, real_exc, tb = sys.exc_info()
  111. try:
  112. exc = real_exc if exc is None else exc
  113. ei = ExceptionInfo((type_, exc, tb))
  114. self.mark_as_failure(task_id, exc, ei.traceback)
  115. return ei
  116. finally:
  117. del(tb)
  118. def mark_as_retry(self, task_id, exc, traceback=None, request=None):
  119. """Mark task as being retries. Stores the current
  120. exception (if any)."""
  121. return self.store_result(task_id, exc, status=states.RETRY,
  122. traceback=traceback, request=request)
  123. def mark_as_revoked(self, task_id, reason='', request=None):
  124. return self.store_result(task_id, TaskRevokedError(reason),
  125. status=states.REVOKED, traceback=None,
  126. request=request)
  127. def prepare_exception(self, exc, serializer=None):
  128. """Prepare exception for serialization."""
  129. serializer = self.serializer if serializer is None else serializer
  130. if serializer in EXCEPTION_ABLE_CODECS:
  131. return get_pickleable_exception(exc)
  132. return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
  133. def exception_to_python(self, exc):
  134. """Convert serialized exception to Python exception."""
  135. if not isinstance(exc, BaseException):
  136. exc = create_exception_cls(
  137. from_utf8(exc['exc_type']), __name__)(exc['exc_message'])
  138. if self.serializer in EXCEPTION_ABLE_CODECS:
  139. exc = get_pickled_exception(exc)
  140. return exc
  141. def prepare_value(self, result):
  142. """Prepare value for storage."""
  143. if self.serializer != 'pickle' and isinstance(result, ResultBase):
  144. return result.as_tuple()
  145. return result
  146. def encode(self, data):
  147. _, _, payload = dumps(data, serializer=self.serializer)
  148. return payload
  149. def meta_from_decoded(self, meta):
  150. if meta['status'] in self.EXCEPTION_STATES:
  151. meta['result'] = self.exception_to_python(meta['result'])
  152. return meta
  153. def decode_result(self, payload):
  154. return self.meta_from_decoded(self.decode(payload))
  155. def decode(self, payload):
  156. payload = PY3 and payload or str(payload)
  157. return loads(payload,
  158. content_type=self.content_type,
  159. content_encoding=self.content_encoding,
  160. accept=self.accept)
  161. def wait_for(self, task_id,
  162. timeout=None, interval=0.5, no_ack=True, on_interval=None):
  163. """Wait for task and return its result.
  164. If the task raises an exception, this exception
  165. will be re-raised by :func:`wait_for`.
  166. If `timeout` is not :const:`None`, this raises the
  167. :class:`celery.exceptions.TimeoutError` exception if the operation
  168. takes longer than `timeout` seconds.
  169. """
  170. time_elapsed = 0.0
  171. while 1:
  172. meta = self.get_task_meta(task_id)
  173. if meta['status'] in states.READY_STATES:
  174. return meta
  175. if on_interval:
  176. on_interval()
  177. # avoid hammering the CPU checking status.
  178. time.sleep(interval)
  179. time_elapsed += interval
  180. if timeout and time_elapsed >= timeout:
  181. raise TimeoutError('The operation timed out.')
  182. def prepare_expires(self, value, type=None):
  183. if value is None:
  184. value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  185. if isinstance(value, timedelta):
  186. value = value.total_seconds()
  187. if value is not None and type:
  188. return type(value)
  189. return value
  190. def prepare_persistent(self, enabled=None):
  191. if enabled is not None:
  192. return enabled
  193. p = self.app.conf.CELERY_RESULT_PERSISTENT
  194. return self.persistent if p is None else p
  195. def encode_result(self, result, status):
  196. if status in self.EXCEPTION_STATES and isinstance(result, Exception):
  197. return self.prepare_exception(result)
  198. else:
  199. return self.prepare_value(result)
  200. def is_cached(self, task_id):
  201. return task_id in self._cache
  202. def store_result(self, task_id, result, status,
  203. traceback=None, request=None, **kwargs):
  204. """Update task state and result."""
  205. result = self.encode_result(result, status)
  206. self._store_result(task_id, result, status, traceback,
  207. request=request, **kwargs)
  208. return result
  209. def forget(self, task_id):
  210. self._cache.pop(task_id, None)
  211. self._forget(task_id)
  212. def _forget(self, task_id):
  213. raise NotImplementedError('backend does not implement forget.')
  214. def get_status(self, task_id):
  215. """Get the status of a task."""
  216. return self.get_task_meta(task_id)['status']
  217. def get_traceback(self, task_id):
  218. """Get the traceback for a failed task."""
  219. return self.get_task_meta(task_id).get('traceback')
  220. def get_result(self, task_id):
  221. """Get the result of a task."""
  222. return self.get_task_meta(task_id).get('result')
  223. def get_children(self, task_id):
  224. """Get the list of subtasks sent by a task."""
  225. try:
  226. return self.get_task_meta(task_id)['children']
  227. except KeyError:
  228. pass
  229. def get_task_meta(self, task_id, cache=True):
  230. if cache:
  231. try:
  232. return self._cache[task_id]
  233. except KeyError:
  234. pass
  235. meta = self._get_task_meta_for(task_id)
  236. if cache and meta.get('status') == states.SUCCESS:
  237. self._cache[task_id] = meta
  238. return meta
  239. def reload_task_result(self, task_id):
  240. """Reload task result, even if it has been previously fetched."""
  241. self._cache[task_id] = self.get_task_meta(task_id, cache=False)
  242. def reload_group_result(self, group_id):
  243. """Reload group result, even if it has been previously fetched."""
  244. self._cache[group_id] = self.get_group_meta(group_id, cache=False)
  245. def get_group_meta(self, group_id, cache=True):
  246. if cache:
  247. try:
  248. return self._cache[group_id]
  249. except KeyError:
  250. pass
  251. meta = self._restore_group(group_id)
  252. if cache and meta is not None:
  253. self._cache[group_id] = meta
  254. return meta
  255. def restore_group(self, group_id, cache=True):
  256. """Get the result for a group."""
  257. meta = self.get_group_meta(group_id, cache=cache)
  258. if meta:
  259. return meta['result']
  260. def save_group(self, group_id, result):
  261. """Store the result of an executed group."""
  262. return self._save_group(group_id, result)
  263. def delete_group(self, group_id):
  264. self._cache.pop(group_id, None)
  265. return self._delete_group(group_id)
  266. def cleanup(self):
  267. """Backend cleanup. Is run by
  268. :class:`celery.task.DeleteExpiredTaskMetaTask`."""
  269. pass
  270. def process_cleanup(self):
  271. """Cleanup actions to do at the end of a task worker process."""
  272. pass
  273. def on_task_call(self, producer, task_id):
  274. return {}
  275. def add_to_chord(self, chord_id, result):
  276. raise NotImplementedError('Backend does not support add_to_chord')
  277. def on_chord_part_return(self, task, state, result, propagate=False):
  278. pass
  279. def fallback_chord_unlock(self, group_id, body, result=None,
  280. countdown=1, **kwargs):
  281. kwargs['result'] = [r.as_tuple() for r in result]
  282. self.app.tasks['celery.chord_unlock'].apply_async(
  283. (group_id, body,), kwargs, countdown=countdown,
  284. )
  285. def apply_chord(self, header, partial_args, group_id, body,
  286. options={}, **kwargs):
  287. fixed_options = {k: v for k,v in options.items() if k!='task_id'}
  288. result = header(*partial_args, task_id=group_id, **fixed_options or {})
  289. self.fallback_chord_unlock(group_id, body, **kwargs)
  290. return result
  291. def current_task_children(self, request=None):
  292. request = request or getattr(current_task(), 'request', None)
  293. if request:
  294. return [r.as_tuple() for r in getattr(request, 'children', [])]
  295. def __reduce__(self, args=(), kwargs={}):
  296. return (unpickle_backend, (self.__class__, args, kwargs))
  297. BaseDictBackend = BaseBackend # XXX compat
  298. class KeyValueStoreBackend(BaseBackend):
  299. key_t = ensure_bytes
  300. task_keyprefix = 'celery-task-meta-'
  301. group_keyprefix = 'celery-taskset-meta-'
  302. chord_keyprefix = 'chord-unlock-'
  303. implements_incr = False
  304. def __init__(self, *args, **kwargs):
  305. if hasattr(self.key_t, '__func__'):
  306. self.key_t = self.key_t.__func__ # remove binding
  307. self._encode_prefixes()
  308. super(KeyValueStoreBackend, self).__init__(*args, **kwargs)
  309. if self.implements_incr:
  310. self.apply_chord = self._apply_chord_incr
  311. def _encode_prefixes(self):
  312. self.task_keyprefix = self.key_t(self.task_keyprefix)
  313. self.group_keyprefix = self.key_t(self.group_keyprefix)
  314. self.chord_keyprefix = self.key_t(self.chord_keyprefix)
  315. def get(self, key):
  316. raise NotImplementedError('Must implement the get method.')
  317. def mget(self, keys):
  318. raise NotImplementedError('Does not support get_many')
  319. def set(self, key, value):
  320. raise NotImplementedError('Must implement the set method.')
  321. def delete(self, key):
  322. raise NotImplementedError('Must implement the delete method')
  323. def incr(self, key):
  324. raise NotImplementedError('Does not implement incr')
  325. def expire(self, key, value):
  326. pass
  327. def get_key_for_task(self, task_id, key=''):
  328. """Get the cache key for a task by id."""
  329. key_t = self.key_t
  330. return key_t('').join([
  331. self.task_keyprefix, key_t(task_id), key_t(key),
  332. ])
  333. def get_key_for_group(self, group_id, key=''):
  334. """Get the cache key for a group by id."""
  335. key_t = self.key_t
  336. return key_t('').join([
  337. self.group_keyprefix, key_t(group_id), key_t(key),
  338. ])
  339. def get_key_for_chord(self, group_id, key=''):
  340. """Get the cache key for the chord waiting on group with given id."""
  341. key_t = self.key_t
  342. return key_t('').join([
  343. self.chord_keyprefix, key_t(group_id), key_t(key),
  344. ])
  345. def _strip_prefix(self, key):
  346. """Takes bytes, emits string."""
  347. key = self.key_t(key)
  348. for prefix in self.task_keyprefix, self.group_keyprefix:
  349. if key.startswith(prefix):
  350. return bytes_to_str(key[len(prefix):])
  351. return bytes_to_str(key)
  352. def _filter_ready(self, values, READY_STATES=states.READY_STATES):
  353. for k, v in values:
  354. if v is not None:
  355. v = self.decode_result(v)
  356. if v['status'] in READY_STATES:
  357. yield k, v
  358. def _mget_to_results(self, values, keys):
  359. if hasattr(values, 'items'):
  360. # client returns dict so mapping preserved.
  361. return {
  362. self._strip_prefix(k): v
  363. for k, v in self._filter_ready(items(values))
  364. }
  365. else:
  366. # client returns list so need to recreate mapping.
  367. return {
  368. bytes_to_str(keys[i]): v
  369. for i, v in self._filter_ready(enumerate(values))
  370. }
  371. def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
  372. on_message=None,
  373. READY_STATES=states.READY_STATES):
  374. interval = 0.5 if interval is None else interval
  375. ids = task_ids if isinstance(task_ids, set) else set(task_ids)
  376. cached_ids = set()
  377. cache = self._cache
  378. for task_id in ids:
  379. try:
  380. cached = cache[task_id]
  381. except KeyError:
  382. pass
  383. else:
  384. if cached['status'] in READY_STATES:
  385. yield bytes_to_str(task_id), cached
  386. cached_ids.add(task_id)
  387. ids.difference_update(cached_ids)
  388. iterations = 0
  389. while ids:
  390. keys = list(ids)
  391. r = self._mget_to_results(self.mget([self.get_key_for_task(k)
  392. for k in keys]), keys)
  393. cache.update(r)
  394. ids.difference_update({bytes_to_str(v) for v in r})
  395. for key, value in items(r):
  396. if on_message is not None:
  397. on_message(value)
  398. yield bytes_to_str(key), value
  399. if timeout and iterations * interval >= timeout:
  400. raise TimeoutError('Operation timed out ({0})'.format(timeout))
  401. time.sleep(interval) # don't busy loop.
  402. iterations += 1
  403. def _forget(self, task_id):
  404. self.delete(self.get_key_for_task(task_id))
  405. def _store_result(self, task_id, result, status,
  406. traceback=None, request=None, **kwargs):
  407. meta = {'status': status, 'result': result, 'traceback': traceback,
  408. 'children': self.current_task_children(request)}
  409. self.set(self.get_key_for_task(task_id), self.encode(meta))
  410. return result
  411. def _save_group(self, group_id, result):
  412. self.set(self.get_key_for_group(group_id),
  413. self.encode({'result': result.as_tuple()}))
  414. return result
  415. def _delete_group(self, group_id):
  416. self.delete(self.get_key_for_group(group_id))
  417. def _get_task_meta_for(self, task_id):
  418. """Get task metadata for a task by id."""
  419. meta = self.get(self.get_key_for_task(task_id))
  420. if not meta:
  421. return {'status': states.PENDING, 'result': None}
  422. return self.decode_result(meta)
  423. def _restore_group(self, group_id):
  424. """Get task metadata for a task by id."""
  425. meta = self.get(self.get_key_for_group(group_id))
  426. # previously this was always pickled, but later this
  427. # was extended to support other serializers, so the
  428. # structure is kind of weird.
  429. if meta:
  430. meta = self.decode(meta)
  431. result = meta['result']
  432. meta['result'] = result_from_tuple(result, self.app)
  433. return meta
  434. def _apply_chord_incr(self, header, partial_args, group_id, body,
  435. result=None, options={}, **kwargs):
  436. self.save_group(group_id, self.app.GroupResult(group_id, result))
  437. fixed_options = {k: v for k,v in options.items() if k != 'task_id'}
  438. return header(*partial_args, task_id=group_id, **fixed_options or {})
  439. def on_chord_part_return(self, task, state, result, propagate=None):
  440. if not self.implements_incr:
  441. return
  442. app = self.app
  443. if propagate is None:
  444. propagate = app.conf.CELERY_CHORD_PROPAGATES
  445. gid = task.request.group
  446. if not gid:
  447. return
  448. key = self.get_key_for_chord(gid)
  449. try:
  450. deps = GroupResult.restore(gid, backend=task.backend)
  451. except Exception as exc:
  452. callback = maybe_signature(task.request.chord, app=app)
  453. logger.error('Chord %r raised: %r', gid, exc, exc_info=1)
  454. return self.chord_error_from_stack(
  455. callback,
  456. ChordError('Cannot restore group: {0!r}'.format(exc)),
  457. )
  458. if deps is None:
  459. try:
  460. raise ValueError(gid)
  461. except ValueError as exc:
  462. callback = maybe_signature(task.request.chord, app=app)
  463. logger.error('Chord callback %r raised: %r', gid, exc,
  464. exc_info=1)
  465. return self.chord_error_from_stack(
  466. callback,
  467. ChordError('GroupResult {0} no longer exists'.format(gid)),
  468. )
  469. val = self.incr(key)
  470. size = len(deps)
  471. if val > size:
  472. logger.warning('Chord counter incremented too many times for %r',
  473. gid)
  474. elif val == size:
  475. callback = maybe_signature(task.request.chord, app=app)
  476. j = deps.join_native if deps.supports_native_join else deps.join
  477. try:
  478. with allow_join_result():
  479. ret = j(timeout=3.0, propagate=propagate)
  480. except Exception as exc:
  481. try:
  482. culprit = next(deps._failed_join_report())
  483. reason = 'Dependency {0.id} raised {1!r}'.format(
  484. culprit, exc,
  485. )
  486. except StopIteration:
  487. reason = repr(exc)
  488. logger.error('Chord %r raised: %r', gid, reason, exc_info=1)
  489. self.chord_error_from_stack(callback, ChordError(reason))
  490. else:
  491. try:
  492. callback.delay(ret)
  493. except Exception as exc:
  494. logger.error('Chord %r raised: %r', gid, exc, exc_info=1)
  495. self.chord_error_from_stack(
  496. callback,
  497. ChordError('Callback error: {0!r}'.format(exc)),
  498. )
  499. finally:
  500. deps.delete()
  501. self.client.delete(key)
  502. else:
  503. self.expire(key, 86400)
  504. class DisabledBackend(BaseBackend):
  505. _cache = {} # need this attribute to reset cache in tests.
  506. def store_result(self, *args, **kwargs):
  507. pass
  508. def _is_disabled(self, *args, **kwargs):
  509. raise NotImplementedError(
  510. 'No result backend configured. '
  511. 'Please see the documentation for more information.')
  512. wait_for = get_status = get_result = get_traceback = _is_disabled