# -*- coding: utf-8 -*-
"""
    celery.backends.base
    ~~~~~~~~~~~~~~~~~~~~

    Result backend base classes.

    - :class:`BaseBackend` defines the interface.

    - :class:`BaseDictBackend` assumes the fields are stored in a dict.

    - :class:`KeyValueStoreBackend` is a common base class
      using K/V semantics like _get and _put.

"""
from __future__ import absolute_import

import time
import sys

from datetime import timedelta

from billiard.einfo import ExceptionInfo
from kombu import serialization
from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8

from celery import states
from celery.app import current_task
from celery.datastructures import LRUCache
from celery.exceptions import ChordError, TaskRevokedError, TimeoutError
from celery.result import from_serializable, GroupResult
from celery.utils import timeutils
from celery.utils.serialization import (
    get_pickled_exception,
    get_pickleable_exception,
    create_exception_cls,
)

EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])
is_py3k = sys.version_info >= (3, 0)


def unpickle_backend(cls, args, kwargs):
    """Returns an unpickled backend."""
    return cls(*args, **kwargs)


class BaseBackend(object):
    """Base backend class."""
    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument which is for each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    def __init__(self, *args, **kwargs):
        from celery.app import app_or_default
        self.app = app_or_default(kwargs.get('app'))
        self.serializer = kwargs.get('serializer',
                                     self.app.conf.CELERY_RESULT_SERIALIZER)
        (self.content_type,
         self.content_encoding,
         self.encoder) = serialization.registry._encoders[self.serializer]

    def encode(self, data):
        _, _, payload = serialization.encode(data, serializer=self.serializer)
        return payload

    def decode(self, payload):
        payload = is_py3k and payload or str(payload)
        return serialization.decode(payload,
                                    content_type=self.content_type,
                                    content_encoding=self.content_encoding)
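
    # A minimal round-trip sketch (hypothetical `backend` instance of a
    # concrete subclass; assumes the default pickle result serializer):
    #
    #     payload = backend.encode({'result': 42})
    #     backend.decode(payload)  # -> {'result': 42}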

    def prepare_expires(self, value, type=None):
        if value is None:
            value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
        if isinstance(value, timedelta):
            value = timeutils.timedelta_seconds(value)
        if value is not None and type:
            return type(value)
        return value
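
    # Sketch (hypothetical values): a timedelta expiry is normalized to
    # seconds, optionally cast by `type`:
    #
    #     backend.prepare_expires(timedelta(days=1), type=int)  # -> 86400
    #     backend.prepare_expires(None)   # falls back to the
    #                                     # CELERY_TASK_RESULT_EXPIRES setting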

    def encode_result(self, result, status):
        if status in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        else:
            return self.prepare_value(result)

    def store_result(self, task_id, result, status, traceback=None):
        """Store the result and status of a task."""
        raise NotImplementedError(
            'store_result is not supported by this backend.')

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, status=states.STARTED)

    def mark_as_done(self, task_id, result):
        """Mark task as successfully executed."""
        return self.store_result(task_id, result, status=states.SUCCESS)

    def mark_as_failure(self, task_id, exc, traceback=None):
        """Mark task as executed with failure.  Stores the exception."""
        return self.store_result(task_id, exc, status=states.FAILURE,
                                 traceback=traceback)

    def fail_from_current_stack(self, task_id, exc=None):
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            ei = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, ei.traceback)
            return ei
        finally:
            del(tb)

    def mark_as_retry(self, task_id, exc, traceback=None):
        """Mark task as being retried.  Stores the current
        exception (if any)."""
        return self.store_result(task_id, exc, status=states.RETRY,
                                 traceback=traceback)

    def mark_as_revoked(self, task_id, reason=''):
        return self.store_result(task_id, TaskRevokedError(reason),
                                 status=states.REVOKED, traceback=None)

    def prepare_exception(self, exc):
        """Prepare exception for serialization."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickleable_exception(exc)
        return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if self.serializer in EXCEPTION_ABLE_CODECS:
            return get_pickled_exception(exc)
        return create_exception_cls(from_utf8(exc['exc_type']),
                                    sys.modules[__name__])(exc['exc_message'])
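
    # Sketch of the non-pickle path (assumes a json-configured result
    # serializer, where exceptions travel as plain dicts):
    #
    #     data = backend.prepare_exception(KeyError('foo'))
    #     # -> {'exc_type': 'KeyError', 'exc_message': "'foo'"}
    #     backend.exception_to_python(data)  # -> a rebuilt KeyError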

    def prepare_value(self, result):
        """Prepare value for storage."""
        if isinstance(result, GroupResult):
            return result.serializable()
        return result

    def forget(self, task_id):
        raise NotImplementedError('%s does not implement forget.' % (
            self.__class__))

    def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
        """Wait for task and return its result.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        If `timeout` is not :const:`None`, this raises the
        :class:`celery.exceptions.TimeoutError` exception if the operation
        takes longer than `timeout` seconds.

        """
        time_elapsed = 0.0

        while 1:
            status = self.get_status(task_id)
            if status == states.SUCCESS:
                return self.get_result(task_id)
            elif status in states.PROPAGATE_STATES:
                result = self.get_result(task_id)
                if propagate:
                    raise result
                return result
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')
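
    # Usage sketch (hypothetical task id; requires a subclass implementing
    # get_status/get_result):
    #
    #     result = backend.wait_for(task_id, timeout=10.0, interval=0.5)
    #
    # With propagate=True (the default) a failed task re-raises its stored
    # exception here; with propagate=False the exception instance is
    # returned instead.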

    def cleanup(self):
        """Backend cleanup. Is run by
        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
        pass

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""
        pass

    def get_status(self, task_id):
        """Get the status of a task."""
        raise NotImplementedError(
            'get_status is not supported by this backend.')

    def get_result(self, task_id):
        """Get the result of a task."""
        raise NotImplementedError(
            'get_result is not supported by this backend.')

    def get_children(self, task_id):
        raise NotImplementedError(
            'get_children is not supported by this backend.')

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        raise NotImplementedError(
            'get_traceback is not supported by this backend.')

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        raise NotImplementedError(
            'save_group is not supported by this backend.')

    def restore_group(self, group_id, cache=True):
        """Get the result of a group."""
        raise NotImplementedError(
            'restore_group is not supported by this backend.')

    def delete_group(self, group_id):
        raise NotImplementedError(
            'delete_group is not supported by this backend.')

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        raise NotImplementedError(
            'reload_task_result is not supported by this backend.')

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        raise NotImplementedError(
            'reload_group_result is not supported by this backend.')

    def on_chord_part_return(self, task, propagate=True):
        pass

    def fallback_chord_unlock(self, group_id, body, result=None,
                              countdown=1, **kwargs):
        kwargs['result'] = [r.id for r in result]
        self.app.tasks['celery.chord_unlock'].apply_async(
            (group_id, body, ), kwargs, countdown=countdown,
        )
    on_chord_apply = fallback_chord_unlock

    def current_task_children(self):
        current = current_task()
        if current:
            return [r.serializable() for r in current.request.children]

    def __reduce__(self, args=(), kwargs={}):
        return (unpickle_backend, (self.__class__, args, kwargs))

    def is_cached(self, task_id):
        return False
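

# A minimal in-memory backend sketch (hypothetical, for illustration only;
# a real backend must also handle expiry, groups, and encoding):
#
#     class MemoryBackend(BaseDictBackend):
#         _store = {}
#
#         def _store_result(self, task_id, result, status,
#                           traceback=None, **kwargs):
#             self._store[task_id] = {'status': status, 'result': result,
#                                     'traceback': traceback}
#
#         def _get_task_meta_for(self, task_id):
#             return self._store.get(
#                 task_id, {'status': states.PENDING, 'result': None})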


class BaseDictBackend(BaseBackend):

    def __init__(self, *args, **kwargs):
        super(BaseDictBackend, self).__init__(*args, **kwargs)
        self._cache = LRUCache(limit=kwargs.get('max_cached_results') or
                               self.app.conf.CELERY_MAX_CACHED_RESULTS)

    def is_cached(self, task_id):
        return task_id in self._cache

    def store_result(self, task_id, result, status, traceback=None, **kwargs):
        """Store task result and status."""
        result = self.encode_result(result, status)
        self._store_result(task_id, result, status, traceback, **kwargs)
        return result

    def forget(self, task_id):
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('%s does not implement forget.' % (
            self.__class__))

    def get_status(self, task_id):
        """Get the status of a task."""
        return self.get_task_meta(task_id)['status']

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        meta = self.get_task_meta(task_id)
        if meta['status'] in self.EXCEPTION_STATES:
            return self.exception_to_python(meta['result'])
        else:
            return meta['result']

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def get_task_meta(self, task_id, cache=True):
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass

        meta = self._get_task_meta_for(task_id)
        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta
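
    # Sketch of the caching behavior above (hypothetical id):
    #
    #     backend.get_task_meta('uuid')               # hits the store
    #     backend.get_task_meta('uuid')               # cached once status
    #                                                 # reached SUCCESS
    #     backend.get_task_meta('uuid', cache=False)  # always hits the store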

    def reload_task_result(self, task_id):
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        self._cache[group_id] = self.get_group_meta(group_id,
                                                    cache=False)

    def get_group_meta(self, group_id, cache=True):
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass

        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)


class KeyValueStoreBackend(BaseDictBackend):
    task_keyprefix = ensure_bytes('celery-task-meta-')
    group_keyprefix = ensure_bytes('celery-taskset-meta-')
    chord_keyprefix = ensure_bytes('chord-unlock-')
    implements_incr = False

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many.')

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method.')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr.')

    def expire(self, key, value):
        pass

    def get_key_for_task(self, task_id):
        """Get the cache key for a task by id."""
        return self.task_keyprefix + ensure_bytes(task_id)

    def get_key_for_group(self, group_id):
        """Get the cache key for a group by id."""
        return self.group_keyprefix + ensure_bytes(group_id)

    def get_key_for_chord(self, group_id):
        """Get the cache key for the chord waiting on group with given id."""
        return self.chord_keyprefix + ensure_bytes(group_id)

    def _strip_prefix(self, key):
        """Takes bytes, emits string."""
        key = ensure_bytes(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)
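
    # Key layout sketch (hypothetical id):
    #
    #     backend.get_key_for_task('uuid')  # -> b'celery-task-meta-uuid'
    #     backend._strip_prefix(b'celery-task-meta-uuid')  # -> 'uuid'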

    def _mget_to_results(self, values, keys):
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return dict((self._strip_prefix(k), self.decode(v))
                        for k, v in values.iteritems()
                        if v is not None)
        else:
            # client returns list so need to recreate mapping.
            return dict((bytes_to_str(keys[i]), self.decode(value))
                        for i, value in enumerate(values)
                        if value is not None)

    def get_many(self, task_ids, timeout=None, interval=0.5):
        ids = set(task_ids)
        cached_ids = set()
        for task_id in ids:
            try:
                cached = self._cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in states.READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)

        ids ^= cached_ids
        iterations = 0
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys)
            self._cache.update(r)
            ids ^= set(map(bytes_to_str, r))
            for key, value in r.iteritems():
                yield bytes_to_str(key), value
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out (%s)' % (timeout, ))
            time.sleep(interval)  # don't busy loop.
            iterations += 1
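
    # Usage sketch (hypothetical ids): get_many is a generator yielding
    # (task_id, meta) pairs as results become ready:
    #
    #     for task_id, meta in backend.get_many(['id1', 'id2'], timeout=10):
    #         print(task_id, meta['status'])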

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, status, traceback=None):
        meta = {'status': status, 'result': result, 'traceback': traceback,
                'children': self.current_task_children()}
        self.set(self.get_key_for_task(task_id), self.encode(meta))
        return result

    def _save_group(self, group_id, result):
        self.set(self.get_key_for_group(group_id),
                 self.encode({'result': result.serializable()}))
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task metadata for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            return {'status': states.PENDING, 'result': None}
        return self.decode(meta)
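
    # Note the PENDING fallback above: an unknown or expired task id is
    # indistinguishable from one never seen, e.g. (hypothetical id):
    #
    #     backend._get_task_meta_for('no-such-id')
    #     # -> {'status': states.PENDING, 'result': None}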

    def _restore_group(self, group_id):
        """Get group metadata for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            if isinstance(result, (list, tuple)):
                return {'result': from_serializable(result)}
            return meta

    def on_chord_apply(self, group_id, body, result=None, **kwargs):
        if self.implements_incr:
            self.app.GroupResult(group_id, result).save()
        else:
            self.fallback_chord_unlock(group_id, body, result, **kwargs)

    def on_chord_part_return(self, task, propagate=None):
        if not self.implements_incr:
            return
        from celery import subtask
        from celery.result import GroupResult
        app = self.app
        if propagate is None:
            propagate = self.app.conf.CELERY_CHORD_PROPAGATES
        gid = task.request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        deps = GroupResult.restore(gid, backend=task.backend)
        val = self.incr(key)
        if val >= len(deps):
            j = deps.join_native if deps.supports_native_join else deps.join
            callback = subtask(task.request.chord)
            try:
                ret = j(propagate=propagate)
            except Exception, exc:
                try:
                    culprit = deps._failed_join_report().next()
                    reason = 'Dependency %s raised %r' % (culprit.id, exc)
                except StopIteration:
                    reason = repr(exc)
                app._tasks[callback.task].backend.fail_from_current_stack(
                    callback.id, exc=ChordError(reason),
                )
            else:
                try:
                    callback.delay(ret)
                except Exception, exc:
                    app._tasks[callback.task].backend.fail_from_current_stack(
                        callback.id,
                        exc=ChordError('Callback error: %r' % (exc, )),
                    )
            finally:
                deps.delete()
                self.client.delete(key)
        else:
            self.expire(key, 86400)
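

# Sketch of the counter-based chord protocol implemented above
# (hypothetical group of three header tasks): each completing task calls
# on_chord_part_return, which atomically increments the chord key:
#
#     self.incr('chord-unlock-<gid>')  # -> 1, 2, then 3
#
# Only the call that reaches len(deps) joins the group and executes the
# callback; earlier calls merely refresh the key's one-day (86400 s) expiry.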


class DisabledBackend(BaseBackend):
    _cache = {}   # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        pass

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(
            'No result backend configured.  '
            'Please see the documentation for more information.')
    wait_for = get_status = get_result = get_traceback = _is_disabled