12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731 |
- # -*- coding: utf-8 -*-
- """Result backend base classes.
- - :class:`BaseBackend` defines the interface.
- - :class:`KeyValueStoreBackend` is a common base class
- using K/V semantics like _get and _put.
- """
- import sys
- import time
- from collections import namedtuple
- from datetime import timedelta
- from weakref import WeakValueDictionary
- from billiard.einfo import ExceptionInfo
- from kombu.serialization import (
- dumps, loads, prepare_accept_content,
- registry as serializer_registry,
- )
- from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
- from kombu.utils.url import maybe_sanitize_url
- from celery import states
- from celery import current_app, group, maybe_signature
- from celery.app import current_task
- from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
- from celery.result import (
- GroupResult, ResultBase, allow_join_result, result_from_tuple,
- )
- from celery.utils.collections import BufferMap
- from celery.utils.functional import LRUCache, arity_greater
- from celery.utils.log import get_logger
- from celery.utils.serialization import (
- get_pickled_exception,
- get_pickleable_exception,
- create_exception_cls,
- )
__all__ = ['BaseBackend', 'KeyValueStoreBackend', 'DisabledBackend']

#: Serializers that can carry exception objects natively (round-trip).
EXCEPTION_ABLE_CODECS = frozenset({'pickle'})

logger = get_logger(__name__)

#: Upper bound for the per-task buffered message map.
MESSAGE_BUFFER_MAX = 8192

#: Pair of mappings used to track pending results:
#: ``concrete`` holds strong references, ``weak`` holds weak references.
pending_results_t = namedtuple('pending_results_t', (
    'concrete', 'weak',
))
def unpickle_backend(cls, args, kwargs):
    """Return an unpickled backend.

    The backend is re-bound to the *current* app instead of the one it
    was pickled with, so unpickling works across processes.
    """
    app = current_app._get_current_object()
    return cls(*args, app=app, **kwargs)
- class _nulldict(dict):
- def ignore(self, *a, **kw):
- pass
- __setitem__ = update = setdefault = ignore
class Backend:
    """Base result backend interface.

    Concrete backends implement the storage primitives
    (``_store_result``, ``_get_task_meta_for``, ``_save_group``, ...);
    this class layers serialization, exception wrapping, result caching
    and chord bookkeeping on top of them.
    """

    READY_STATES = states.READY_STATES
    UNREADY_STATES = states.UNREADY_STATES
    EXCEPTION_STATES = states.EXCEPTION_STATES

    #: Re-exported so callers can catch ``backend.TimeoutError``.
    TimeoutError = TimeoutError

    #: Time to sleep between polling each individual item
    #: in `ResultSet.iterate`, as opposed to the `interval`
    #: argument which is for each pass.
    subpolling_interval = None

    #: If true the backend must implement :meth:`get_many`.
    supports_native_join = False

    #: If true the backend must automatically expire results.
    #: The daily backend_cleanup periodic task will not be triggered
    #: in this case.
    supports_autoexpire = False

    #: Set to true if the backend is persistent by default.
    persistent = True

    #: Default retry policy applied when storing results fails.
    retry_policy = {
        'max_retries': 20,
        'interval_start': 0,
        'interval_step': 1,
        'interval_max': 1,
    }

    def __init__(self, app,
                 serializer=None, max_cached_results=None, accept=None,
                 expires=None, expires_type=None, url=None, **kwargs):
        self.app = app
        conf = self.app.conf
        self.serializer = serializer or conf.result_serializer
        # Resolve (content_type, content_encoding, encoder) for the
        # configured serializer from kombu's registry.
        (self.content_type,
         self.content_encoding,
         self.encoder) = serializer_registry._encoders[self.serializer]
        cmax = max_cached_results or conf.result_cache_max
        # A limit of -1 disables caching: writes are silently discarded.
        self._cache = _nulldict() if cmax == -1 else LRUCache(limit=cmax)
        self.expires = self.prepare_expires(expires, expires_type)
        self.accept = prepare_accept_content(
            conf.accept_content if accept is None else accept)
        self._pending_results = pending_results_t({}, WeakValueDictionary())
        self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
        self.url = url

    def as_uri(self, include_password=False):
        """Return the backend as an URI, sanitizing the password or not."""
        # when using maybe_sanitize_url(), "/" is added
        # we're stripping it for consistency
        if include_password:
            return self.url
        url = maybe_sanitize_url(self.url or '')
        return url[:-1] if url.endswith(':///') else url

    def mark_as_started(self, task_id, **meta):
        """Mark a task as started."""
        return self.store_result(task_id, meta, states.STARTED)

    def mark_as_done(self, task_id, result,
                     request=None, store_result=True, state=states.SUCCESS):
        """Mark task as successfully executed."""
        if store_result:
            self.store_result(task_id, result, state, request=request)
        # If the task is part of a chord, count this part as returned.
        if request and request.chord:
            self.on_chord_part_return(request, state, result)

    def mark_as_failure(self, task_id, exc,
                        traceback=None, request=None,
                        store_result=True, call_errbacks=True,
                        state=states.FAILURE):
        """Mark task as executed with failure.  Stores the exception."""
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=traceback, request=request)
        if request:
            if request.chord:
                self.on_chord_part_return(request, state, exc)
            if call_errbacks and request.errbacks:
                self._call_task_errbacks(request, exc, traceback)

    def _call_task_errbacks(self, request, exc, traceback):
        """Invoke the task's error callbacks, supporting both signatures."""
        old_signature = []
        for errback in request.errbacks:
            errback = self.app.signature(errback)
            # New-style errbacks accept (request, exc, traceback);
            # old-style ones take only the task id.
            if arity_greater(errback.type.__header__, 1):
                errback(request, exc, traceback)
            else:
                old_signature.append(errback)
        if old_signature:
            # Previously errback was called as a task so we still
            # need to do so if the errback only takes a single task_id arg.
            task_id = request.id
            root_id = request.root_id or task_id
            group(old_signature, app=self.app).apply_async(
                (task_id,), parent_id=task_id, root_id=root_id
            )

    def mark_as_revoked(self, task_id, reason='',
                        request=None, store_result=True, state=states.REVOKED):
        """Mark task as revoked, storing a :exc:`TaskRevokedError`."""
        exc = TaskRevokedError(reason)
        if store_result:
            self.store_result(task_id, exc, state,
                              traceback=None, request=request)
        if request and request.chord:
            self.on_chord_part_return(request, state, exc)

    def mark_as_retry(self, task_id, exc, traceback=None,
                      request=None, store_result=True, state=states.RETRY):
        """Mark task as being retried.  Stores the current
        exception (if any)."""
        return self.store_result(task_id, exc, state,
                                 traceback=traceback, request=request)

    def chord_error_from_stack(self, callback, exc=None):
        """Run the chord's error callbacks, then fail the chord callback."""
        from celery import group  # note: shadows the module-level import
        app = self.app
        backend = app._tasks[callback.task].backend
        try:
            group(
                [app.signature(errback)
                 for errback in callback.options.get('link_error') or []],
                app=app,
            ).apply_async((callback.id,))
        except Exception as eb_exc:
            # The error callbacks themselves failed: record that instead.
            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
        else:
            return backend.fail_from_current_stack(callback.id, exc=exc)

    def fail_from_current_stack(self, task_id, exc=None):
        """Mark the task failed using the exception currently being handled.

        If *exc* is :const:`None`, the exception from :func:`sys.exc_info`
        is used.  Returns the :class:`~billiard.einfo.ExceptionInfo`.
        """
        type_, real_exc, tb = sys.exc_info()
        try:
            exc = real_exc if exc is None else exc
            ei = ExceptionInfo((type_, exc, tb))
            self.mark_as_failure(task_id, exc, ei.traceback)
            return ei
        finally:
            # break the traceback reference cycle
            del(tb)

    def prepare_exception(self, exc, serializer=None):
        """Prepare exception for serialization."""
        serializer = self.serializer if serializer is None else serializer
        if serializer in EXCEPTION_ABLE_CODECS:
            # pickle can serialize the exception object directly.
            return get_pickleable_exception(exc)
        # Other serializers get a plain dict description instead.
        return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}

    def exception_to_python(self, exc):
        """Convert serialized exception to Python exception."""
        if exc:
            if not isinstance(exc, BaseException):
                # Rebuild an exception class from the stored type name.
                exc = create_exception_cls(
                    from_utf8(exc['exc_type']), __name__)(exc['exc_message'])
            if self.serializer in EXCEPTION_ABLE_CODECS:
                exc = get_pickled_exception(exc)
        return exc

    def prepare_value(self, result):
        """Prepare value for storage."""
        # Result objects can't survive non-pickle serializers,
        # so store them as their tuple representation instead.
        if self.serializer != 'pickle' and isinstance(result, ResultBase):
            return result.as_tuple()
        return result

    def encode(self, data):
        """Serialize *data* using the configured serializer."""
        _, _, payload = dumps(data, serializer=self.serializer)
        return payload

    def meta_from_decoded(self, meta):
        """Post-process decoded meta, rebuilding exceptions when needed."""
        if meta['status'] in self.EXCEPTION_STATES:
            meta['result'] = self.exception_to_python(meta['result'])
        return meta

    def decode_result(self, payload):
        """Decode a stored result payload into a meta mapping."""
        return self.meta_from_decoded(self.decode(payload))

    def decode(self, payload):
        """Deserialize *payload* using the configured content settings."""
        return loads(payload,
                     content_type=self.content_type,
                     content_encoding=self.content_encoding,
                     accept=self.accept)

    def prepare_expires(self, value, type=None):
        """Normalize an expiry value (seconds or timedelta) to *type*."""
        if value is None:
            value = self.app.conf.result_expires
        if isinstance(value, timedelta):
            value = value.total_seconds()
        if value is not None and type:
            return type(value)
        return value

    def prepare_persistent(self, enabled=None):
        """Resolve the persistence flag: argument, config, then default."""
        if enabled is not None:
            return enabled
        p = self.app.conf.result_persistent
        return self.persistent if p is None else p

    def encode_result(self, result, state):
        """Encode *result* for storage, wrapping exceptions when needed."""
        if state in self.EXCEPTION_STATES and isinstance(result, Exception):
            return self.prepare_exception(result)
        else:
            return self.prepare_value(result)

    def is_cached(self, task_id):
        """Return true if the task result is in the local cache."""
        return task_id in self._cache

    def store_result(self, task_id, result, state,
                     traceback=None, request=None, **kwargs):
        """Update task state and result."""
        result = self.encode_result(result, state)
        self._store_result(task_id, result, state, traceback,
                           request=request, **kwargs)
        return result

    def forget(self, task_id):
        """Forget a task's result, both cached and stored."""
        self._cache.pop(task_id, None)
        self._forget(task_id)

    def _forget(self, task_id):
        raise NotImplementedError('backend does not implement forget.')

    def get_state(self, task_id):
        """Get the state of a task."""
        return self.get_task_meta(task_id)['status']
    get_status = get_state  # XXX compat

    def get_traceback(self, task_id):
        """Get the traceback for a failed task."""
        return self.get_task_meta(task_id).get('traceback')

    def get_result(self, task_id):
        """Get the result of a task."""
        return self.get_task_meta(task_id).get('result')

    def get_children(self, task_id):
        """Get the list of subtasks sent by a task."""
        try:
            return self.get_task_meta(task_id)['children']
        except KeyError:
            pass

    def _ensure_not_eager(self):
        # Results are never stored in eager mode, so retrieval would hang.
        if self.app.conf.task_always_eager:
            raise RuntimeError(
                "Cannot retrieve result with task_always_eager enabled")

    def get_task_meta(self, task_id, cache=True):
        """Return the task meta mapping, using the cache when allowed."""
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[task_id]
            except KeyError:
                pass
        meta = self._get_task_meta_for(task_id)
        # Only terminal SUCCESS results are safe to cache.
        if cache and meta.get('status') == states.SUCCESS:
            self._cache[task_id] = meta
        return meta

    def reload_task_result(self, task_id):
        """Reload task result, even if it has been previously fetched."""
        self._cache[task_id] = self.get_task_meta(task_id, cache=False)

    def reload_group_result(self, group_id):
        """Reload group result, even if it has been previously fetched."""
        self._cache[group_id] = self.get_group_meta(group_id, cache=False)

    def get_group_meta(self, group_id, cache=True):
        """Return the stored group meta, using the cache when allowed."""
        self._ensure_not_eager()
        if cache:
            try:
                return self._cache[group_id]
            except KeyError:
                pass
        meta = self._restore_group(group_id)
        if cache and meta is not None:
            self._cache[group_id] = meta
        return meta

    def restore_group(self, group_id, cache=True):
        """Get the result for a group."""
        meta = self.get_group_meta(group_id, cache=cache)
        if meta:
            return meta['result']

    def save_group(self, group_id, result):
        """Store the result of an executed group."""
        return self._save_group(group_id, result)

    def delete_group(self, group_id):
        """Delete a group's result, both cached and stored."""
        self._cache.pop(group_id, None)
        return self._delete_group(group_id)

    def cleanup(self):
        """Backend cleanup.  Is run by
        :class:`celery.task.DeleteExpiredTaskMetaTask`."""
        pass

    def process_cleanup(self):
        """Cleanup actions to do at the end of a task worker process."""
        pass

    def on_task_call(self, producer, task_id):
        # Hook called before publishing a task; may return extra options.
        return {}

    def add_to_chord(self, chord_id, result):
        raise NotImplementedError('Backend does not support add_to_chord')

    def on_chord_part_return(self, request, state, result, **kwargs):
        # Default: chord counting not supported; subclasses override.
        pass

    def fallback_chord_unlock(self, group_id, body, result=None,
                              countdown=1, **kwargs):
        """Schedule the polling ``celery.chord_unlock`` task for a chord."""
        kwargs['result'] = [r.as_tuple() for r in result]
        self.app.tasks['celery.chord_unlock'].apply_async(
            (group_id, body,), kwargs, countdown=countdown,
        )

    def apply_chord(self, header, partial_args, group_id, body,
                    options={}, **kwargs):
        """Apply the chord header, then schedule the fallback unlock task.

        NOTE(review): the mutable default ``options={}`` is only read
        here, never mutated, so it is harmless — confirm before copying
        this pattern elsewhere.
        """
        # task_id must be removed so the group id is used instead.
        fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
        result = header(*partial_args, task_id=group_id, **fixed_options or {})
        self.fallback_chord_unlock(group_id, body, **kwargs)
        return result

    def current_task_children(self, request=None):
        """Return the current task's children as result tuples."""
        request = request or getattr(current_task(), 'request', None)
        if request:
            return [r.as_tuple() for r in getattr(request, 'children', [])]

    def __reduce__(self, args=(), kwargs={}):
        # Pickle as (class, args, kwargs); the app is re-resolved from
        # the current app on unpickling (see unpickle_backend).
        return (unpickle_backend, (self.__class__, args, kwargs))
class SyncBackendMixin:
    """Mixin adding blocking (synchronous) result retrieval primitives."""

    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
                    on_message=None, on_interval=None):
        """Iterate over ``(task_id, meta)`` pairs for a result set.

        Requires the backend to implement :meth:`get_many`.
        """
        self._ensure_not_eager()
        results = result.results
        if not results:
            return iter([])
        return self.get_many(
            {r.id for r in results},
            timeout=timeout, interval=interval, no_ack=no_ack,
            on_message=on_message, on_interval=on_interval,
        )

    def wait_for_pending(self, result, timeout=None, interval=0.5,
                         no_ack=True, on_interval=None, callback=None,
                         propagate=True):
        """Block until *result* is ready and return its value.

        Re-raises the task's exception if ``propagate`` is enabled.
        """
        self._ensure_not_eager()
        meta = self.wait_for(
            result.id, timeout=timeout,
            interval=interval,
            on_interval=on_interval,
            no_ack=no_ack,
        )
        if meta:
            result._maybe_set_cache(meta)
            return result.maybe_throw(propagate=propagate, callback=callback)

    def wait_for(self, task_id,
                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
        """Wait for task and return its result meta.

        If the task raises an exception, this exception
        will be re-raised by :func:`wait_for`.

        Raises:
            celery.exceptions.TimeoutError:
                If `timeout` is not :const:`None`, and the operation
                takes longer than `timeout` seconds.
        """
        self._ensure_not_eager()
        time_elapsed = 0.0
        while 1:
            meta = self.get_task_meta(task_id)
            if meta['status'] in states.READY_STATES:
                return meta
            if on_interval:
                on_interval()
            # avoid hammering the CPU checking status.
            time.sleep(interval)
            # NOTE: elapsed time is approximated by summing the sleep
            # interval, not measured with a clock, so the timeout is
            # only enforced approximately.
            time_elapsed += interval
            if timeout and time_elapsed >= timeout:
                raise TimeoutError('The operation timed out.')

    def add_pending_result(self, result, weak=False):
        # No-op for synchronous backends: nothing to register.
        return result

    def remove_pending_result(self, result):
        # No-op for synchronous backends.
        return result

    @property
    def is_async(self):
        """Synchronous backends never drive an async event loop."""
        return False
class BaseBackend(Backend, SyncBackendMixin):
    """Base (synchronous) result backend."""


BaseDictBackend = BaseBackend  # XXX compat
class BaseKeyValueStoreBackend(Backend):
    """Common base for backends with key/value store semantics.

    Subclasses implement :meth:`get`, :meth:`set`, :meth:`delete` and
    optionally :meth:`mget`/:meth:`incr`/:meth:`expire`.
    """

    #: Converts keys to the store's native key type (bytes by default).
    key_t = ensure_bytes
    task_keyprefix = 'celery-task-meta-'
    group_keyprefix = 'celery-taskset-meta-'
    chord_keyprefix = 'chord-unlock-'

    #: Set when the store supports atomic increment (enables the
    #: counter-based chord implementation).
    implements_incr = False

    def __init__(self, *args, **kwargs):
        if hasattr(self.key_t, '__func__'):  # pragma: no cover
            self.key_t = self.key_t.__func__  # remove binding
        self._encode_prefixes()
        super(BaseKeyValueStoreBackend, self).__init__(*args, **kwargs)
        if self.implements_incr:
            # Use the atomic-counter chord implementation instead of
            # the polling fallback.
            self.apply_chord = self._apply_chord_incr

    def _encode_prefixes(self):
        # Convert the string prefixes to the store's key type once.
        self.task_keyprefix = self.key_t(self.task_keyprefix)
        self.group_keyprefix = self.key_t(self.group_keyprefix)
        self.chord_keyprefix = self.key_t(self.chord_keyprefix)

    def get(self, key):
        raise NotImplementedError('Must implement the get method.')

    def mget(self, keys):
        raise NotImplementedError('Does not support get_many')

    def set(self, key, value):
        raise NotImplementedError('Must implement the set method.')

    def delete(self, key):
        raise NotImplementedError('Must implement the delete method')

    def incr(self, key):
        raise NotImplementedError('Does not implement incr')

    def expire(self, key, value):
        # Optional: set a TTL on *key*; no-op by default.
        pass

    def get_key_for_task(self, task_id, key=''):
        """Get the cache key for a task by id."""
        key_t = self.key_t
        return key_t('').join([
            self.task_keyprefix, key_t(task_id), key_t(key),
        ])

    def get_key_for_group(self, group_id, key=''):
        """Get the cache key for a group by id."""
        key_t = self.key_t
        return key_t('').join([
            self.group_keyprefix, key_t(group_id), key_t(key),
        ])

    def get_key_for_chord(self, group_id, key=''):
        """Get the cache key for the chord waiting on group with given id."""
        key_t = self.key_t
        return key_t('').join([
            self.chord_keyprefix, key_t(group_id), key_t(key),
        ])

    def _strip_prefix(self, key):
        """Take bytes, emit string: remove a known key prefix if present."""
        key = self.key_t(key)
        for prefix in self.task_keyprefix, self.group_keyprefix:
            if key.startswith(prefix):
                return bytes_to_str(key[len(prefix):])
        return bytes_to_str(key)

    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
        # Yield only decoded results that have reached a ready state.
        for k, v in values:
            if v is not None:
                v = self.decode_result(v)
                if v['status'] in READY_STATES:
                    yield k, v

    def _mget_to_results(self, values, keys):
        """Normalize an ``mget`` reply (dict or list) to {task_id: meta}."""
        if hasattr(values, 'items'):
            # client returns dict so mapping preserved.
            return {
                self._strip_prefix(k): v
                for k, v in self._filter_ready(values.items())
            }
        else:
            # client returns list so need to recreate mapping.
            return {
                bytes_to_str(keys[i]): v
                for i, v in self._filter_ready(enumerate(values))
            }

    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
                 on_message=None, on_interval=None, max_iterations=None,
                 READY_STATES=states.READY_STATES):
        """Yield ``(task_id, meta)`` as results become ready, polling mget.

        Raises:
            TimeoutError: if ``timeout`` elapses before all ids are ready.
        """
        interval = 0.5 if interval is None else interval
        ids = task_ids if isinstance(task_ids, set) else set(task_ids)
        cached_ids = set()
        cache = self._cache
        # First serve whatever is already in the local cache.
        for task_id in ids:
            try:
                cached = cache[task_id]
            except KeyError:
                pass
            else:
                if cached['status'] in READY_STATES:
                    yield bytes_to_str(task_id), cached
                    cached_ids.add(task_id)
        ids.difference_update(cached_ids)
        iterations = 0
        # Poll the store for the remaining ids until all are ready.
        while ids:
            keys = list(ids)
            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                 for k in keys]), keys)
            cache.update(r)
            ids.difference_update({bytes_to_str(v) for v in r})
            for key, value in r.items():
                if on_message is not None:
                    on_message(value)
                yield bytes_to_str(key), value
            # Timeout is approximated from iteration count, not a clock.
            if timeout and iterations * interval >= timeout:
                raise TimeoutError('Operation timed out ({0})'.format(timeout))
            if on_interval:
                on_interval()
            time.sleep(interval)  # don't busy loop.
            iterations += 1
            if max_iterations and iterations >= max_iterations:
                break

    def _forget(self, task_id):
        self.delete(self.get_key_for_task(task_id))

    def _store_result(self, task_id, result, state,
                      traceback=None, request=None, **kwargs):
        """Serialize and store the task meta under the task key."""
        meta = {'status': state, 'result': result, 'traceback': traceback,
                'children': self.current_task_children(request),
                'task_id': bytes_to_str(task_id)}
        self.set(self.get_key_for_task(task_id), self.encode(meta))
        return result

    def _save_group(self, group_id, result):
        """Store a group result (as its tuple form) under the group key."""
        self.set(self.get_key_for_group(group_id),
                 self.encode({'result': result.as_tuple()}))
        return result

    def _delete_group(self, group_id):
        self.delete(self.get_key_for_group(group_id))

    def _get_task_meta_for(self, task_id):
        """Get task meta-data for a task by id."""
        meta = self.get(self.get_key_for_task(task_id))
        if not meta:
            # Unknown task ids are reported as PENDING.
            return {'status': states.PENDING, 'result': None}
        return self.decode_result(meta)

    def _restore_group(self, group_id):
        """Get group meta-data for a group by id."""
        meta = self.get(self.get_key_for_group(group_id))
        # previously this was always pickled, but later this
        # was extended to support other serializers, so the
        # structure is kind of weird.
        if meta:
            meta = self.decode(meta)
            result = meta['result']
            meta['result'] = result_from_tuple(result, self.app)
            return meta

    def _apply_chord_incr(self, header, partial_args, group_id, body,
                          result=None, options={}, **kwargs):
        """Apply a chord using the atomic-counter strategy.

        Saves the group result first so :meth:`on_chord_part_return`
        can restore it, then applies the header.
        """
        self.save_group(group_id, self.app.GroupResult(group_id, result))
        # task_id must be removed so the group id is used instead.
        fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
        return header(*partial_args, task_id=group_id, **fixed_options or {})

    def on_chord_part_return(self, request, state, result, **kwargs):
        """Count a finished chord part; fire the callback when all done."""
        if not self.implements_incr:
            return
        app = self.app
        gid = request.group
        if not gid:
            return
        key = self.get_key_for_chord(gid)
        try:
            deps = GroupResult.restore(gid, backend=self)
        except Exception as exc:
            callback = maybe_signature(request.chord, app=app)
            logger.error('Chord %r raised: %r', gid, exc, exc_info=1)
            return self.chord_error_from_stack(
                callback,
                ChordError('Cannot restore group: {0!r}'.format(exc)),
            )
        if deps is None:
            try:
                # raise-and-catch so the error has a traceback attached.
                raise ValueError(gid)
            except ValueError as exc:
                callback = maybe_signature(request.chord, app=app)
                logger.error('Chord callback %r raised: %r', gid, exc,
                             exc_info=1)
                return self.chord_error_from_stack(
                    callback,
                    ChordError('GroupResult {0} no longer exists'.format(gid)),
                )
        val = self.incr(key)
        size = len(deps)
        if val > size:  # pragma: no cover
            logger.warning('Chord counter incremented too many times for %r',
                           gid)
        elif val == size:
            # All parts have returned: join and invoke the callback.
            callback = maybe_signature(request.chord, app=app)
            j = deps.join_native if deps.supports_native_join else deps.join
            try:
                with allow_join_result():
                    # NOTE: join timeout is hard-coded to 3 seconds here.
                    ret = j(timeout=3.0, propagate=True)
            except Exception as exc:
                try:
                    culprit = next(deps._failed_join_report())
                    reason = 'Dependency {0.id} raised {1!r}'.format(
                        culprit, exc,
                    )
                except StopIteration:
                    reason = repr(exc)
                logger.error('Chord %r raised: %r', gid, reason, exc_info=1)
                self.chord_error_from_stack(callback, ChordError(reason))
            else:
                try:
                    callback.delay(ret)
                except Exception as exc:
                    logger.error('Chord %r raised: %r', gid, exc, exc_info=1)
                    self.chord_error_from_stack(
                        callback,
                        ChordError('Callback error: {0!r}'.format(exc)),
                    )
            finally:
                deps.delete()
                # ``self.client`` is presumably provided by concrete
                # subclasses (e.g. the Redis backend) — TODO confirm.
                self.client.delete(key)
        else:
            # Still waiting on other parts: keep the counter for a day.
            self.expire(key, 86400)
class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
    """Synchronous result backend with key/value store semantics."""
class DisabledBackend(BaseBackend):
    """Backend used when no result backend is configured.

    Storing a result is a silent no-op; any attempt to *retrieve*
    state or results raises :exc:`NotImplementedError`.
    """

    _cache = {}  # need this attribute to reset cache in tests.

    def store_result(self, *args, **kwargs):
        """Discard the result silently."""

    def as_uri(self, *args, **kwargs):
        return 'disabled://'

    def _is_disabled(self, *args, **kwargs):
        raise NotImplementedError(
            'No result backend configured.  '
            'Please see the documentation for more information.'[:-1]
            if False else
            'No result backend configured. '
            'Please see the documentation for more information.')

    # Every retrieval entry point maps to the same error.
    get_state = _is_disabled
    get_status = _is_disabled
    get_result = _is_disabled
    get_traceback = _is_disabled
    wait_for = _is_disabled
    get_many = _is_disabled
|