base.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.backends.base
  4. ~~~~~~~~~~~~~~~~~~~~
  5. Result backend base classes.
  6. - :class:`BaseBackend` defines the interface.
  7. - :class:`KeyValueStoreBackend` is a common base class
  8. using K/V semantics like _get and _put.
  9. """
  10. from __future__ import absolute_import
  11. import time
  12. import sys
  13. from datetime import timedelta
  14. from itertools import imap
  15. from kombu import serialization
  16. from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
  17. from celery import states
  18. from celery.app import current_task
  19. from celery.datastructures import LRUCache
  20. from celery.exceptions import TimeoutError, TaskRevokedError
  21. from celery.result import from_serializable, GroupResult
  22. from celery.utils import timeutils
  23. from celery.utils.serialization import (
  24. get_pickled_exception,
  25. get_pickleable_exception,
  26. create_exception_cls,
  27. )
  28. EXCEPTION_ABLE_CODECS = frozenset(['pickle', 'yaml'])
  29. is_py3k = sys.version_info >= (3, 0)
  30. def unpickle_backend(cls, args, kwargs):
  31. """Returns an unpickled backend."""
  32. return cls(*args, **kwargs)
  33. class BaseBackend(object):
  34. READY_STATES = states.READY_STATES
  35. UNREADY_STATES = states.UNREADY_STATES
  36. EXCEPTION_STATES = states.EXCEPTION_STATES
  37. TimeoutError = TimeoutError
  38. #: Time to sleep between polling each individual item
  39. #: in `ResultSet.iterate`. as opposed to the `interval`
  40. #: argument which is for each pass.
  41. subpolling_interval = None
  42. #: If true the backend must implement :meth:`get_many`.
  43. supports_native_join = False
  44. def __init__(self, app=None, serializer=None, max_cached_results=None,
  45. **kwargs):
  46. from celery.app import app_or_default
  47. self.app = app_or_default(app)
  48. self.serializer = serializer or self.app.conf.CELERY_RESULT_SERIALIZER
  49. (self.content_type,
  50. self.content_encoding,
  51. self.encoder) = serialization.registry._encoders[self.serializer]
  52. self._cache = LRUCache(limit=max_cached_results or
  53. self.app.conf.CELERY_MAX_CACHED_RESULTS)
  54. def mark_as_started(self, task_id, **meta):
  55. """Mark a task as started"""
  56. return self.store_result(task_id, meta, status=states.STARTED)
  57. def mark_as_done(self, task_id, result):
  58. """Mark task as successfully executed."""
  59. return self.store_result(task_id, result, status=states.SUCCESS)
  60. def mark_as_failure(self, task_id, exc, traceback=None):
  61. """Mark task as executed with failure. Stores the execption."""
  62. return self.store_result(task_id, exc, status=states.FAILURE,
  63. traceback=traceback)
  64. def mark_as_retry(self, task_id, exc, traceback=None):
  65. """Mark task as being retries. Stores the current
  66. exception (if any)."""
  67. return self.store_result(task_id, exc, status=states.RETRY,
  68. traceback=traceback)
  69. def mark_as_revoked(self, task_id, reason=''):
  70. return self.store_result(task_id, TaskRevokedError(reason),
  71. status=states.REVOKED, traceback=None)
  72. def prepare_exception(self, exc):
  73. """Prepare exception for serialization."""
  74. if self.serializer in EXCEPTION_ABLE_CODECS:
  75. return get_pickleable_exception(exc)
  76. return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
  77. def exception_to_python(self, exc):
  78. """Convert serialized exception to Python exception."""
  79. if self.serializer in EXCEPTION_ABLE_CODECS:
  80. return get_pickled_exception(exc)
  81. return create_exception_cls(from_utf8(exc['exc_type']),
  82. sys.modules[__name__])(exc['exc_message'])
  83. def prepare_value(self, result):
  84. """Prepare value for storage."""
  85. if isinstance(result, GroupResult):
  86. return result.serializable()
  87. return result
  88. def encode(self, data):
  89. _, _, payload = serialization.encode(data, serializer=self.serializer)
  90. return payload
  91. def decode(self, payload):
  92. payload = is_py3k and payload or str(payload)
  93. return serialization.decode(payload,
  94. content_type=self.content_type,
  95. content_encoding=self.content_encoding)
  96. def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
  97. """Wait for task and return its result.
  98. If the task raises an exception, this exception
  99. will be re-raised by :func:`wait_for`.
  100. If `timeout` is not :const:`None`, this raises the
  101. :class:`celery.exceptions.TimeoutError` exception if the operation
  102. takes longer than `timeout` seconds.
  103. """
  104. time_elapsed = 0.0
  105. while 1:
  106. status = self.get_status(task_id)
  107. if status == states.SUCCESS:
  108. return self.get_result(task_id)
  109. elif status in states.PROPAGATE_STATES:
  110. result = self.get_result(task_id)
  111. if propagate:
  112. raise result
  113. return result
  114. # avoid hammering the CPU checking status.
  115. time.sleep(interval)
  116. time_elapsed += interval
  117. if timeout and time_elapsed >= timeout:
  118. raise TimeoutError('The operation timed out.')
  119. def prepare_expires(self, value, type=None):
  120. if value is None:
  121. value = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  122. if isinstance(value, timedelta):
  123. value = timeutils.timedelta_seconds(value)
  124. if value is not None and type:
  125. return type(value)
  126. return value
  127. def encode_result(self, result, status):
  128. if status in self.EXCEPTION_STATES and isinstance(result, Exception):
  129. return self.prepare_exception(result)
  130. else:
  131. return self.prepare_value(result)
  132. def store_result(self, task_id, result, status, traceback=None, **kwargs):
  133. """Update task state and result."""
  134. result = self.encode_result(result, status)
  135. self._store_result(task_id, result, status, traceback, **kwargs)
  136. return result
  137. def forget(self, task_id):
  138. self._cache.pop(task_id, None)
  139. self._forget(task_id)
  140. def _forget(self, task_id):
  141. raise NotImplementedError('backend does not implement forget.')
  142. def get_status(self, task_id):
  143. """Get the status of a task."""
  144. return self.get_task_meta(task_id)['status']
  145. def get_traceback(self, task_id):
  146. """Get the traceback for a failed task."""
  147. return self.get_task_meta(task_id).get('traceback')
  148. def get_result(self, task_id):
  149. """Get the result of a task."""
  150. meta = self.get_task_meta(task_id)
  151. if meta['status'] in self.EXCEPTION_STATES:
  152. return self.exception_to_python(meta['result'])
  153. else:
  154. return meta['result']
  155. def get_children(self, task_id):
  156. """Get the list of subtasks sent by a task."""
  157. try:
  158. return self.get_task_meta(task_id)['children']
  159. except KeyError:
  160. pass
  161. def get_task_meta(self, task_id, cache=True):
  162. if cache:
  163. try:
  164. return self._cache[task_id]
  165. except KeyError:
  166. pass
  167. meta = self._get_task_meta_for(task_id)
  168. if cache and meta.get('status') == states.SUCCESS:
  169. self._cache[task_id] = meta
  170. return meta
  171. def reload_task_result(self, task_id):
  172. """Reload task result, even if it has been previously fetched."""
  173. self._cache[task_id] = self.get_task_meta(task_id, cache=False)
  174. def reload_group_result(self, group_id):
  175. """Reload group result, even if it has been previously fetched."""
  176. self._cache[group_id] = self.get_group_meta(group_id, cache=False)
  177. def get_group_meta(self, group_id, cache=True):
  178. if cache:
  179. try:
  180. return self._cache[group_id]
  181. except KeyError:
  182. pass
  183. meta = self._restore_group(group_id)
  184. if cache and meta is not None:
  185. self._cache[group_id] = meta
  186. return meta
  187. def restore_group(self, group_id, cache=True):
  188. """Get the result for a group."""
  189. meta = self.get_group_meta(group_id, cache=cache)
  190. if meta:
  191. return meta['result']
  192. def save_group(self, group_id, result):
  193. """Store the result of an executed group."""
  194. return self._save_group(group_id, result)
  195. def delete_group(self, group_id):
  196. self._cache.pop(group_id, None)
  197. return self._delete_group(group_id)
  198. def cleanup(self):
  199. """Backend cleanup. Is run by
  200. :class:`celery.task.DeleteExpiredTaskMetaTask`."""
  201. pass
  202. def process_cleanup(self):
  203. """Cleanup actions to do at the end of a task worker process."""
  204. pass
  205. def on_chord_part_return(self, task, propagate=False):
  206. pass
  207. def fallback_chord_unlock(self, group_id, body, result=None, **kwargs):
  208. kwargs['result'] = [r.id for r in result]
  209. self.app.tasks['celery.chord_unlock'].apply_async((group_id, body, ),
  210. kwargs, countdown=1)
  211. on_chord_apply = fallback_chord_unlock
  212. def current_task_children(self):
  213. current = current_task()
  214. if current:
  215. return [r.serializable() for r in current.request.children]
  216. def __reduce__(self, args=(), kwargs={}):
  217. return (unpickle_backend, (self.__class__, args, kwargs))
  218. BaseDictBackend = BaseBackend # XXX compat
  219. class KeyValueStoreBackend(BaseBackend):
  220. task_keyprefix = ensure_bytes('celery-task-meta-')
  221. group_keyprefix = ensure_bytes('celery-taskset-meta-')
  222. chord_keyprefix = ensure_bytes('chord-unlock-')
  223. implements_incr = False
  224. def get(self, key):
  225. raise NotImplementedError('Must implement the get method.')
  226. def mget(self, keys):
  227. raise NotImplementedError('Does not support get_many')
  228. def set(self, key, value):
  229. raise NotImplementedError('Must implement the set method.')
  230. def delete(self, key):
  231. raise NotImplementedError('Must implement the delete method')
  232. def incr(self, key):
  233. raise NotImplementedError('Does not implement incr')
  234. def expire(self, key, value):
  235. pass
  236. def get_key_for_task(self, task_id):
  237. """Get the cache key for a task by id."""
  238. return self.task_keyprefix + ensure_bytes(task_id)
  239. def get_key_for_group(self, group_id):
  240. """Get the cache key for a group by id."""
  241. return self.group_keyprefix + ensure_bytes(group_id)
  242. def get_key_for_chord(self, group_id):
  243. """Get the cache key for the chord waiting on group with given id."""
  244. return self.chord_keyprefix + ensure_bytes(group_id)
  245. def _strip_prefix(self, key):
  246. """Takes bytes, emits string."""
  247. key = ensure_bytes(key)
  248. for prefix in self.task_keyprefix, self.group_keyprefix:
  249. if key.startswith(prefix):
  250. return bytes_to_str(key[len(prefix):])
  251. return bytes_to_str(key)
  252. def _mget_to_results(self, values, keys):
  253. if hasattr(values, 'items'):
  254. # client returns dict so mapping preserved.
  255. return dict((self._strip_prefix(k), self.decode(v))
  256. for k, v in values.iteritems()
  257. if v is not None)
  258. else:
  259. # client returns list so need to recreate mapping.
  260. return dict((bytes_to_str(keys[i]), self.decode(value))
  261. for i, value in enumerate(values)
  262. if value is not None)
  263. def get_many(self, task_ids, timeout=None, interval=0.5):
  264. ids = set(task_ids)
  265. cached_ids = set()
  266. for task_id in ids:
  267. try:
  268. cached = self._cache[task_id]
  269. except KeyError:
  270. pass
  271. else:
  272. if cached['status'] in states.READY_STATES:
  273. yield bytes_to_str(task_id), cached
  274. cached_ids.add(task_id)
  275. ids ^= cached_ids
  276. iterations = 0
  277. while ids:
  278. keys = list(ids)
  279. r = self._mget_to_results(self.mget([self.get_key_for_task(k)
  280. for k in keys]), keys)
  281. self._cache.update(r)
  282. ids ^= set(imap(bytes_to_str, r))
  283. for key, value in r.iteritems():
  284. yield bytes_to_str(key), value
  285. if timeout and iterations * interval >= timeout:
  286. raise TimeoutError('Operation timed out ({0})'.format(timeout))
  287. time.sleep(interval) # don't busy loop.
  288. iterations += 1
  289. def _forget(self, task_id):
  290. self.delete(self.get_key_for_task(task_id))
  291. def _store_result(self, task_id, result, status, traceback=None):
  292. meta = {'status': status, 'result': result, 'traceback': traceback,
  293. 'children': self.current_task_children()}
  294. self.set(self.get_key_for_task(task_id), self.encode(meta))
  295. return result
  296. def _save_group(self, group_id, result):
  297. self.set(self.get_key_for_group(group_id),
  298. self.encode({'result': result.serializable()}))
  299. return result
  300. def _delete_group(self, group_id):
  301. self.delete(self.get_key_for_group(group_id))
  302. def _get_task_meta_for(self, task_id):
  303. """Get task metadata for a task by id."""
  304. meta = self.get(self.get_key_for_task(task_id))
  305. if not meta:
  306. return {'status': states.PENDING, 'result': None}
  307. return self.decode(meta)
  308. def _restore_group(self, group_id):
  309. """Get task metadata for a task by id."""
  310. meta = self.get(self.get_key_for_group(group_id))
  311. # previously this was always pickled, but later this
  312. # was extended to support other serializers, so the
  313. # structure is kind of weird.
  314. if meta:
  315. meta = self.decode(meta)
  316. result = meta['result']
  317. if isinstance(result, (list, tuple)):
  318. return {'result': from_serializable(result)}
  319. return meta
  320. def on_chord_apply(self, group_id, body, result=None, **kwargs):
  321. if self.implements_incr:
  322. self.app.GroupResult(group_id, result).save()
  323. else:
  324. self.fallback_chord_unlock(group_id, body, result, **kwargs)
  325. def on_chord_part_return(self, task, propagate=False):
  326. if not self.implements_incr:
  327. return
  328. from celery import subtask
  329. from celery.result import GroupResult
  330. gid = task.request.group
  331. if not gid:
  332. return
  333. key = self.get_key_for_chord(gid)
  334. deps = GroupResult.restore(gid, backend=task.backend)
  335. val = self.incr(key)
  336. if val >= len(deps):
  337. subtask(task.request.chord).delay(deps.join(propagate=propagate))
  338. deps.delete()
  339. self.client.delete(key)
  340. else:
  341. self.expire(key, 86400)
  342. class DisabledBackend(BaseBackend):
  343. _cache = {} # need this attribute to reset cache in tests.
  344. def store_result(self, *args, **kwargs):
  345. pass
  346. def _is_disabled(self, *args, **kwargs):
  347. raise NotImplementedError('No result backend configured. '
  348. 'Please see the documentation for more information.')
  349. wait_for = get_status = get_result = get_traceback = _is_disabled