base.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.backends.base
  4. ~~~~~~~~~~~~~~~~~~~~
  5. Result backend base classes.
  6. - :class:`BaseBackend` defines the interface.
  7. - :class:`KeyValueStoreBackend` is a common base class
  8. using K/V semantics like _get and _put.
  9. """
  10. from __future__ import absolute_import
  11. import time
  12. import sys
  13. from datetime import timedelta
  14. from billiard.einfo import ExceptionInfo
  15. from kombu.serialization import (
  16. dumps, loads, prepare_accept_content,
  17. registry as serializer_registry,
  18. )
  19. from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
  20. from celery import states
  21. from celery import current_app, maybe_signature
  22. from celery.app import current_task
  23. from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
  24. from celery.five import items
  25. from celery.result import (
  26. GroupResult, ResultBase, allow_join_result, result_from_tuple,
  27. )
  28. from celery.utils import timeutils
  29. from celery.utils.functional import LRUCache
  30. from celery.utils.serialization import (
  31. get_pickled_exception,
  32. get_pickleable_exception,
  33. create_exception_cls,
  34. )
__all__ = ['BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend']

#: Serializers that can round-trip real exception objects; all others
#: get a ``{'exc_type': ..., 'exc_message': ...}`` description instead
#: (see ``BaseBackend.prepare_exception``).
EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])

#: True when running on Python 3 (used to decide whether payloads must
#: be coerced to native ``str`` before decoding).
PY3 = sys.version_info >= (3, 0)
  38. def unpickle_backend(cls, args, kwargs):
  39. """Return an unpickled backend."""
  40. return cls(*args, app=current_app._get_current_object(), **kwargs)
  41. class BaseBackend(object):
  42. READY_STATES = states.READY_STATES
  43. UNREADY_STATES = states.UNREADY_STATES
  44. EXCEPTION_STATES = states.EXCEPTION_STATES
  45. TimeoutError = TimeoutError
  46. #: Time to sleep between polling each individual item
  47. #: in `ResultSet.iterate`. as opposed to the `interval`
  48. #: argument which is for each pass.
  49. subpolling_interval = None
  50. #: If true the backend must implement :meth:`get_many`.
  51. supports_native_join = False
  52. #: If true the backend must automatically expire results.
  53. #: The daily backend_cleanup periodic task will not be triggered
  54. #: in this case.
  55. supports_autoexpire = False
  56. #: Set to true if the backend is peristent by default.
  57. persistent = True
  58. retry_policy = {
  59. 'max_retries': 20,
  60. 'interval_start': 0,
  61. 'interval_step': 1,
  62. 'interval_max': 1,
  63. }
  64. def __init__(self, app, serializer=None,
  65. max_cached_results=None, accept=None, **kwargs):
  66. self.app = app
  67. conf = self.app.conf
  68. self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
  69. (self.content_type,
  70. self.content_encoding,
  71. self.encoder) = serializer_registry._encoders[self.serializer]
  72. self._cache = LRUCache(
  73. limit=max_cached_results or conf.CELERY_MAX_CACHED_RESULTS,
  74. )
  75. self.accept = prepare_accept_content(
  76. conf.CELERY_ACCEPT_CONTENT if accept is None else accept,
  77. )
  78. def mark_as_started(self, task_id, **meta):
  79. """Mark a task as started"""
  80. return self.store_result(task_id, meta, status=states.STARTED)
  81. def mark_as_done(self, task_id, result, request=None):
  82. """Mark task as successfully executed."""
  83. return self.store_result(task_id, result,
  84. status=states.SUCCESS, request=request)
  85. def mark_as_failure(self, task_id, exc, traceback=None, request=None):
  86. """Mark task as executed with failure. Stores the execption."""
  87. return self.store_result(task_id, exc, status=states.FAILURE,
  88. traceback=traceback, request=request)
  89. def fail_from_current_stack(self, task_id, exc=None):
  90. type_, real_exc, tb = sys.exc_info()
  91. try:
  92. exc = real_exc if exc is None else exc
  93. ei = ExceptionInfo((type_, exc, tb))
  94. self.mark_as_failure(task_id, exc, ei.traceback)
  95. return ei
  96. finally:
  97. del(tb)
  98. def mark_as_retry(self, task_id, exc, traceback=None, request=None):
  99. """Mark task as being retries. Stores the current
  100. exception (if any)."""
  101. return self.store_result(task_id, exc, status=states.RETRY,
  102. traceback=traceback, request=request)
  103. def mark_as_revoked(self, task_id, reason='', request=None):
  104. return self.store_result(task_id, TaskRevokedError(reason),
  105. status=states.REVOKED, traceback=None,
  106. request=request)
  107. def prepare_exception(self, exc):
  108. """Prepare exception for serialization."""
  109. if self.serializer in EXCEPTION_ABLE_CODECS:
  110. return get_pickleable_exception(exc)
  111. return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
  112. def exception_to_python(self, exc):
  113. """Convert serialized exception to Python exception."""
  114. if self.serializer in EXCEPTION_ABLE_CODECS:
  115. return get_pickled_exception(exc)
  116. return create_exception_cls(
  117. from_utf8(exc['exc_type']), __name__)(exc['exc_message'])
  118. def prepare_value(self, result):
  119. """Prepare value for storage."""
  120. if self.serializer != 'pickle' and isinstance(result, ResultBase):
  121. return result.as_tuple()
  122. return result
  123. def encode(self, data):
  124. _, _, payload = dumps(data, serializer=self.serializer)
  125. return payload
  126. def decode(self, payload):
  127. payload = PY3 and payload or str(payload)
  128. return loads(payload,
  129. content_type=self.content_type,
  130. content_encoding=self.content_encoding,
  131. accept=self.accept)
  132. def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
  133. """Wait for task and return its result.
  134. If the task raises an exception, this exception
  135. will be re-raised by :func:`wait_for`.
  136. If `timeout` is not :const:`None`, this raises the
  137. :class:`celery.exceptions.TimeoutError` exception if the operation
  138. takes longer than `timeout` seconds.
  139. """
  140. time_elapsed = 0.0
  141. while 1:
  142. status = self.get_status(task_id)
  143. if status == states.SUCCESS:
  144. return self.get_result(task_id)
  145. elif status in states.PROPAGATE_STATES:
  146. result = self.get_result(task_id)
  147. if propagate:
  148. raise result
  149. return result
  150. # avoid hammering the CPU checking status.
  151. time.sleep(interval)
  152. time_elapsed += interval
  153. if timeout and time_elapsed >= timeout:
  154. raise TimeoutError('The operation timed out.')
  155. def prepare_expires(self, value, type=None):
  156. if value is None:
  157. value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  158. if isinstance(value, timedelta):
  159. value = timeutils.timedelta_seconds(value)
  160. if value is not None and type:
  161. return type(value)
  162. return value
  163. def prepare_persistent(self, enabled=None):
  164. if enabled is not None:
  165. return enabled
  166. p = self.app.conf.CELERY_RESULT_PERSISTENT
  167. return self.persistent if p is None else p
  168. def encode_result(self, result, status):
  169. if status in self.EXCEPTION_STATES and isinstance(result, Exception):
  170. return self.prepare_exception(result)
  171. else:
  172. return self.prepare_value(result)
  173. def is_cached(self, task_id):
  174. return task_id in self._cache
  175. def store_result(self, task_id, result, status,
  176. traceback=None, request=None, **kwargs):
  177. """Update task state and result."""
  178. result = self.encode_result(result, status)
  179. self._store_result(task_id, result, status, traceback,
  180. request=request, **kwargs)
  181. return result
  182. def forget(self, task_id):
  183. self._cache.pop(task_id, None)
  184. self._forget(task_id)
  185. def _forget(self, task_id):
  186. raise NotImplementedError('backend does not implement forget.')
  187. def get_status(self, task_id):
  188. """Get the status of a task."""
  189. return self.get_task_meta(task_id)['status']
  190. def get_traceback(self, task_id):
  191. """Get the traceback for a failed task."""
  192. return self.get_task_meta(task_id).get('traceback')
  193. def get_result(self, task_id):
  194. """Get the result of a task."""
  195. meta = self.get_task_meta(task_id)
  196. if meta['status'] in self.EXCEPTION_STATES:
  197. return self.exception_to_python(meta['result'])
  198. else:
  199. return meta['result']
  200. def get_children(self, task_id):
  201. """Get the list of subtasks sent by a task."""
  202. try:
  203. return self.get_task_meta(task_id)['children']
  204. except KeyError:
  205. pass
  206. def get_task_meta(self, task_id, cache=True):
  207. if cache:
  208. try:
  209. return self._cache[task_id]
  210. except KeyError:
  211. pass
  212. meta = self._get_task_meta_for(task_id)
  213. if cache and meta.get('status') == states.SUCCESS:
  214. self._cache[task_id] = meta
  215. return meta
  216. def reload_task_result(self, task_id):
  217. """Reload task result, even if it has been previously fetched."""
  218. self._cache[task_id] = self.get_task_meta(task_id, cache=False)
  219. def reload_group_result(self, group_id):
  220. """Reload group result, even if it has been previously fetched."""
  221. self._cache[group_id] = self.get_group_meta(group_id, cache=False)
  222. def get_group_meta(self, group_id, cache=True):
  223. if cache:
  224. try:
  225. return self._cache[group_id]
  226. except KeyError:
  227. pass
  228. meta = self._restore_group(group_id)
  229. if cache and meta is not None:
  230. self._cache[group_id] = meta
  231. return meta
  232. def restore_group(self, group_id, cache=True):
  233. """Get the result for a group."""
  234. meta = self.get_group_meta(group_id, cache=cache)
  235. if meta:
  236. return meta['result']
  237. def save_group(self, group_id, result):
  238. """Store the result of an executed group."""
  239. return self._save_group(group_id, result)
  240. def delete_group(self, group_id):
  241. self._cache.pop(group_id, None)
  242. return self._delete_group(group_id)
  243. def cleanup(self):
  244. """Backend cleanup. Is run by
  245. :class:`celery.task.DeleteExpiredTaskMetaTask`."""
  246. pass
  247. def process_cleanup(self):
  248. """Cleanup actions to do at the end of a task worker process."""
  249. pass
  250. def on_task_call(self, producer, task_id):
  251. return {}
  252. def on_chord_part_return(self, task, propagate=False):
  253. pass
  254. def fallback_chord_unlock(self, group_id, body, result=None,
  255. countdown=1, **kwargs):
  256. kwargs['result'] = [r.as_tuple() for r in result]
  257. self.app.tasks['celery.chord_unlock'].apply_async(
  258. (group_id, body, ), kwargs, countdown=countdown,
  259. )
  260. def apply_chord(self, header, partial_args, group_id, body, **options):
  261. result = header(*partial_args, task_id=group_id)
  262. self.fallback_chord_unlock(group_id, body, **options)
  263. return result
  264. def current_task_children(self, request=None):
  265. request = request or getattr(current_task(), 'request', None)
  266. if request:
  267. return [r.as_tuple() for r in getattr(request, 'children', [])]
  268. def __reduce__(self, args=(), kwargs={}):
  269. return (unpickle_backend, (self.__class__, args, kwargs))
BaseDictBackend = BaseBackend  # XXX compat alias: old name kept for backward compatibility
class KeyValueStoreBackend(BaseBackend):
    """Base class for backends storing results under prefixed string
    keys with simple get/set/delete (and optionally mget/incr/expire)
    semantics."""

    #: Callable coercing keys to the store's key type (bytes by default).
    key_t = ensure_bytes
    task_keyprefix = 'celery-task-meta-'
    group_keyprefix = 'celery-taskset-meta-'
    chord_keyprefix = 'chord-unlock-'
    #: Set by subclasses supporting atomic increment; enables the
    #: counter-based chord protocol (see :meth:`on_chord_part_return`).
    implements_incr = False

    def __init__(self, *args, **kwargs):
        if hasattr(self.key_t, '__func__'):
            self.key_t = self.key_t.__func__  # remove binding
        self._encode_prefixes()
        super(KeyValueStoreBackend, self).__init__(*args, **kwargs)
        if self.implements_incr:
            # use the counter-based chord implementation instead of
            # the polling fallback inherited from BaseBackend.
            self.apply_chord = self._apply_chord_incr

    def _encode_prefixes(self):
        # Coerce the key prefixes with key_t so concatenation with
        # coerced ids in get_key_for_* works.
        self.task_keyprefix = self.key_t(self.task_keyprefix)
        self.group_keyprefix = self.key_t(self.group_keyprefix)
        self.chord_keyprefix = self.key_t(self.chord_keyprefix)

    def get(self, key):
        """Retrieve the value stored at `key` (must be implemented)."""
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        """Retrieve the values for multiple keys at once (optional)."""
        raise NotImplementedError('Does not support get_many')

    def set(self, key, value):
        """Store `value` at `key` (must be implemented)."""
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        """Delete the value stored at `key` (must be implemented)."""
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        """Atomically increment the counter at `key` (optional)."""
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        # Optional: set a time-to-live on `key`; `value` appears to be
        # seconds given the 86400 passed by on_chord_part_return
        # (no-op by default).
        pass

    def get_key_for_task(self, task_id):
        """Get the cache key for a task by id."""
        return self.task_keyprefix + self.key_t(task_id)

    def get_key_for_group(self, group_id):
        """Get the cache key for a group by id."""
        return self.group_keyprefix + self.key_t(group_id)

    def get_key_for_chord(self, group_id):
        """Get the cache key for the chord waiting on group with given id."""
        return self.chord_keyprefix + self.key_t(group_id)

    def _strip_prefix(self, key):
        """Takes bytes, emits string."""
        key = self.key_t(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _mget_to_results(self, values, keys):
        # Normalize an mget result into ``{task_id: decoded_meta}``,
        # skipping missing (None) values.
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return dict((self._strip_prefix(k), self.decode(v))
                        for k, v in items(values)
                        if v is not None)
        else:
            # client returns list so need to recreate mapping.
            return dict((bytes_to_str(keys[i]), self.decode(value))
                        for i, value in enumerate(values)
                        if value is not None)

    def get_many(self, task_ids, timeout=None, interval=0.5,
                 READY_STATES=states.READY_STATES):
        """Iterate over ``(task_id, meta)`` pairs, yielding results as
        they become ready.

        Serves ready results from the local cache first, then polls
        the store with :meth:`mget` every `interval` seconds until all
        ids are ready or `timeout` (if given) is exceeded.
        """
        interval = 0.5 if interval is None else interval
        ids = task_ids if isinstance(task_ids, set) else set(task_ids)
        cached_ids = set()
        cache = self._cache
        for task_id in ids:
            try:
                cached = cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids.difference_update(cached_ids)
        iterations = 0
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys)
            cache.update(r)
            # only ids that came back in `r` are done; keep polling the rest.
            ids.difference_update(set(bytes_to_str(v) for v in r))
            for key, value in items(r):
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            time.sleep(interval)  # don't busy loop.
            iterations += 1

    def _forget(self, task_id):
        # Remove the stored task meta from the key/value store.
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, status,
                      traceback=None, request=None, **kwargs):
        # Encode the full meta mapping and write it under the task key.
        meta = {'status': status, 'result': result, 'traceback': traceback,
                'children': self.current_task_children(request)}
        self.set(self.get_key_for_task(task_id), self.encode(meta))
        return result

    def _save_group(self, group_id, result):
        # Store the group result in its tuple representation.
        self.set(self.get_key_for_group(group_id),
                 self.encode({'result': result.as_tuple()}))
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            # missing meta is reported as a pending task.
            return {'status': states.PENDING, 'result': None}
        return self.decode(meta)

    def _restore_group(self, group_id):
        """Get group metadata for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = result_from_tuple(result, self.app)
            return meta

    def _apply_chord_incr(self, header, partial_args, group_id, body,
                          result=None, **options):
        # Counter-based chord: save the group result up front, then
        # apply the header; on_chord_part_return counts completions.
        self.save_group(group_id, self.app.GroupResult(group_id, result))
        return header(*partial_args, task_id=group_id)

    def on_chord_part_return(self, task, propagate=None):
        """Called when a chord member returns: increments the chord
        counter and, when all members are done, joins the group and
        calls (or fails) the chord callback."""
        if not self.implements_incr:
            return
        app = self.app
        if propagate is None:
            propagate = self.app.conf.CELERY_CHORD_PROPAGATES
        gid = task.request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        try:
            deps = GroupResult.restore(gid, backend=task.backend)
        except Exception as exc:
            # cannot even restore the group: fail the callback.
            callback = maybe_signature(task.request.chord, app=self.app)
            return app._tasks[callback.task].backend.fail_from_current_stack(
                callback.id,
                exc=ChordError('Cannot restore group: {0!r}'.format(exc)),
            )
        if deps is None:
            # raise-and-catch so fail_from_current_stack has a
            # current exception with a traceback to record.
            try:
                raise ValueError(gid)
            except ValueError as exc:
                callback = maybe_signature(task.request.chord, app=self.app)
                task = app._tasks[callback.task]
                return task.backend.fail_from_current_stack(
                    callback.id,
                    exc=ChordError('GroupResult {0} no longer exists'.format(
                        gid,
                    ))
                )
        val = self.incr(key)
        if val >= len(deps):
            # last member finished: join the group and fire the callback.
            callback = maybe_signature(task.request.chord, app=self.app)
            j = deps.join_native if deps.supports_native_join else deps.join
            try:
                with allow_join_result():
                    ret = j(timeout=3.0, propagate=propagate)
            except Exception as exc:
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)
                app._tasks[callback.task].backend.fail_from_current_stack(
                    callback.id, exc=ChordError(reason),
                )
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:
                    app._tasks[callback.task].backend.fail_from_current_stack(
                        callback.id,
                        exc=ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                # NOTE(review): assumes the subclass provides a ``client``
                # attribute with a delete method — confirm in subclasses.
                self.client.delete(key)
        else:
            self.expire(key, 86400)
  452. class DisabledBackend(BaseBackend):
  453. _cache = {} # need this attribute to reset cache in tests.
  454. def store_result(self, *args, **kwargs):
  455. pass
  456. def _is_disabled(self, *args, **kwargs):
  457. raise NotImplementedError(
  458. 'No result backend configured. '
  459. 'Please see the documentation for more information.')
  460. wait_for = get_status = get_result = get_traceback = _is_disabled