@@ -9,8 +9,12 @@
 import sys
 import time
 
-from collections import namedtuple
 from datetime import timedelta
+from numbers import Number
+from typing import (
+    Any, AnyStr, AsyncIterator, Callable, Dict, Iterable, Iterator,
+    Mapping, MutableMapping, NamedTuple, Optional, Set, Sequence, Tuple,
+)
 from weakref import WeakValueDictionary
 
 from billiard.einfo import ExceptionInfo
@@ -18,6 +22,7 @@ from kombu.serialization import (
     dumps, loads, prepare_accept_content,
     registry as serializer_registry,
 )
+from kombu.types import ProducerT
 from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 from kombu.utils.url import maybe_sanitize_url
 
@@ -30,6 +35,7 @@ from celery.exceptions import (
 from celery.result import (
     GroupResult, ResultBase, allow_join_result, result_from_tuple,
 )
+from celery.types import AppT, BackendT, ResultT, RequestT, SignatureT
 from celery.utils.collections import BufferMap
 from celery.utils.functional import LRUCache, arity_greater
 from celery.utils.log import get_logger
@@ -47,10 +53,6 @@ logger = get_logger(__name__)
 
 MESSAGE_BUFFER_MAX = 8192
 
-pending_results_t = namedtuple('pending_results_t', (
-    'concrete', 'weak',
-))
-
 E_NO_BACKEND = """
 No result backend is configured.
 Please see the documentation for more information.
@@ -66,15 +68,22 @@ Result backends that supports chords: Redis, Database, Memcached, and more.
 """
 
 
-def unpickle_backend(cls, args, kwargs):
+class pending_results_t(NamedTuple):
+    """Tuple of concrete and weak references to pending results."""
+
+    concrete: Mapping
+    weak: WeakValueDictionary
+
+
+def unpickle_backend(cls: type, args: Tuple, kwargs: Dict) -> BackendT:
     """Return an unpickled backend."""
     return cls(*args, app=current_app._get_current_object(), **kwargs)
 
 
 class _nulldict(dict):
 
-    def ignore(self, *a, **kw):
-        pass
+    def ignore(self, *a, **kw) -> None:
+        ...
     __setitem__ = update = setdefault = ignore
 
 
@@ -89,7 +98,7 @@ class Backend:
     #: Time to sleep between polling each individual item
     #: in `ResultSet.iterate`. as opposed to the `interval`
     #: argument which is for each pass.
-    subpolling_interval = None
+    subpolling_interval: float = None
 
     #: If true the backend must implement :meth:`get_many`.
     supports_native_join = False
@@ -109,9 +118,15 @@ class Backend:
         'interval_max': 1,
     }
 
-    def __init__(self, app,
-                 serializer=None, max_cached_results=None, accept=None,
-                 expires=None, expires_type=None, url=None, **kwargs):
+    def __init__(self, app: AppT,
+                 *,
+                 serializer: str = None,
+                 max_cached_results: int = None,
+                 accept: Set[str] = None,
+                 expires: float = None,
+                 expires_type: Callable = None,
+                 url: str = None,
+                 **kwargs) -> None:
         self.app = app
         conf = self.app.conf
         self.serializer = serializer or conf.result_serializer
@@ -128,7 +143,7 @@ class Backend:
         self._pending_messages = BufferMap(MESSAGE_BUFFER_MAX)
         self.url = url
 
-    def as_uri(self, include_password=False):
+    def as_uri(self, include_password: bool = False) -> str:
         """Return the backend as an URI, sanitizing the password or not."""
         # when using maybe_sanitize_url(), "/" is added
         # we're stripping it for consistency
@@ -137,33 +152,42 @@ class Backend:
         url = maybe_sanitize_url(self.url or '')
         return url[:-1] if url.endswith(':///') else url
 
-    def mark_as_started(self, task_id, **meta):
+    async def mark_as_started(self, task_id: str, **meta) -> Any:
         """Mark a task as started."""
-        return self.store_result(task_id, meta, states.STARTED)
+        return await self.store_result(task_id, meta, states.STARTED)
 
-    def mark_as_done(self, task_id, result,
-                     request=None, store_result=True, state=states.SUCCESS):
+    async def mark_as_done(self, task_id: str, result: Any,
+                           *,
+                           request: RequestT = None,
+                           store_result: bool = True,
+                           state: str = states.SUCCESS) -> None:
         """Mark task as successfully executed."""
         if store_result:
-            self.store_result(task_id, result, state, request=request)
+            await self.store_result(task_id, result, state, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, result)
-
-    def mark_as_failure(self, task_id, exc,
-                        traceback=None, request=None,
-                        store_result=True, call_errbacks=True,
-                        state=states.FAILURE):
+            await self.on_chord_part_return(request, state, result)
+
+    async def mark_as_failure(self, task_id: str, exc: Exception,
+                              *,
+                              traceback: str = None,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              call_errbacks: bool = True,
+                              state: str = states.FAILURE) -> None:
         """Mark task as executed with failure."""
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=traceback, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=traceback, request=request)
         if request:
             if request.chord:
-                self.on_chord_part_return(request, state, exc)
+                await self.on_chord_part_return(request, state, exc)
             if call_errbacks and request.errbacks:
-                self._call_task_errbacks(request, exc, traceback)
+                await self._call_task_errbacks(request, exc, traceback)
 
-    def _call_task_errbacks(self, request, exc, traceback):
+    async def _call_task_errbacks(self,
+                                  request: RequestT,
+                                  exc: Exception,
+                                  traceback: str) -> None:
         old_signature = []
         for errback in request.errbacks:
             errback = self.app.signature(errback)
@@ -176,30 +200,38 @@ class Backend:
             # need to do so if the errback only takes a single task_id arg.
             task_id = request.id
             root_id = request.root_id or task_id
-            group(old_signature, app=self.app).apply_async(
+            await group(old_signature, app=self.app).apply_async(
                 (task_id,), parent_id=task_id, root_id=root_id
             )
 
-    def mark_as_revoked(self, task_id, reason='',
-                        request=None, store_result=True, state=states.REVOKED):
+    async def mark_as_revoked(self, task_id: str, reason: str = '',
+                              *,
+                              request: RequestT = None,
+                              store_result: bool = True,
+                              state: str = states.REVOKED) -> None:
         exc = TaskRevokedError(reason)
         if store_result:
-            self.store_result(task_id, exc, state,
-                              traceback=None, request=request)
+            await self.store_result(task_id, exc, state,
+                                    traceback=None, request=request)
         if request and request.chord:
-            self.on_chord_part_return(request, state, exc)
-
-    def mark_as_retry(self, task_id, exc, traceback=None,
-                      request=None, store_result=True, state=states.RETRY):
+            await self.on_chord_part_return(request, state, exc)
+
+    async def mark_as_retry(self, task_id: str, exc: Exception,
+                            *,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            store_result: bool = True,
+                            state: str = states.RETRY) -> None:
         """Mark task as being retries.
 
         Note:
            Stores the current exception (if any).
        """
-        return self.store_result(task_id, exc, state,
-                                 traceback=traceback, request=request)
+        return await self.store_result(task_id, exc, state,
+                                       traceback=traceback, request=request)
 
-    def chord_error_from_stack(self, callback, exc=None):
+    async def chord_error_from_stack(self, callback: Callable,
+                                     exc: Exception = None) -> ExceptionInfo:
         # need below import for test for some crazy reason
         from celery import group  # pylint: disable
         app = self.app
@@ -208,34 +240,37 @@ class Backend:
         except KeyError:
             backend = self
         try:
-            group(
+            await group(
                 [app.signature(errback)
                  for errback in callback.options.get('link_error') or []],
                 app=app,
             ).apply_async((callback.id,))
         except Exception as eb_exc:  # pylint: disable=broad-except
-            return backend.fail_from_current_stack(callback.id, exc=eb_exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=eb_exc)
         else:
-            return backend.fail_from_current_stack(callback.id, exc=exc)
+            return await backend.fail_from_current_stack(
+                callback.id, exc=exc)
 
-    def fail_from_current_stack(self, task_id, exc=None):
+    async def fail_from_current_stack(self, task_id: str,
+                                      exc: Exception = None) -> ExceptionInfo:
         type_, real_exc, tb = sys.exc_info()
         try:
             exc = real_exc if exc is None else exc
             ei = ExceptionInfo((type_, exc, tb))
-            self.mark_as_failure(task_id, exc, ei.traceback)
+            await self.mark_as_failure(task_id, exc, traceback=ei.traceback)
             return ei
         finally:
             del tb
 
-    def prepare_exception(self, exc, serializer=None):
+    def prepare_exception(self, exc: Exception, serializer: str = None) -> Any:
         """Prepare exception for serialization."""
         serializer = self.serializer if serializer is None else serializer
         if serializer in EXCEPTION_ABLE_CODECS:
             return get_pickleable_exception(exc)
         return {'exc_type': type(exc).__name__, 'exc_message': str(exc)}
 
-    def exception_to_python(self, exc):
+    def exception_to_python(self, exc: Any) -> Exception:
         """Convert serialized exception to Python exception."""
         if exc:
             if not isinstance(exc, BaseException):
@@ -245,34 +280,35 @@ class Backend:
                 exc = get_pickled_exception(exc)
         return exc
 
-    def prepare_value(self, result):
+    def prepare_value(self, result: Any) -> Any:
         """Prepare value for storage."""
         if self.serializer != 'pickle' and isinstance(result, ResultBase):
             return result.as_tuple()
         return result
 
-    def encode(self, data):
+    def encode(self, data: Any) -> AnyStr:
         _, _, payload = self._encode(data)
         return payload
 
-    def _encode(self, data):
+    def _encode(self, data: Any) -> Tuple[str, str, AnyStr]:
|
|
|
return dumps(data, serializer=self.serializer)
|
|
|
|
|
|
- def meta_from_decoded(self, meta):
|
|
|
+ def meta_from_decoded(self, meta: MutableMapping) -> MutableMapping:
|
|
|
if meta['status'] in self.EXCEPTION_STATES:
|
|
|
meta['result'] = self.exception_to_python(meta['result'])
|
|
|
return meta
|
|
|
|
|
|
- def decode_result(self, payload):
|
|
|
+ def decode_result(self, payload: AnyStr) -> Mapping:
|
|
|
return self.meta_from_decoded(self.decode(payload))
|
|
|
|
|
|
- def decode(self, payload):
|
|
|
+ def decode(self, payload: AnyStr) -> Mapping:
|
|
|
return loads(payload,
|
|
|
content_type=self.content_type,
|
|
|
content_encoding=self.content_encoding,
|
|
|
accept=self.accept)
|
|
|
|
|
|
- def prepare_expires(self, value, type=None):
|
|
|
+ def prepare_expires(self, value: Optional[Number],
|
|
|
+ type: Callable = None) -> Optional[Number]:
|
|
|
if value is None:
|
|
|
value = self.app.conf.result_expires
|
|
|
if isinstance(value, timedelta):
|
|
@@ -281,61 +317,63 @@ class Backend:
             return type(value)
         return value
 
-    def prepare_persistent(self, enabled=None):
+    def prepare_persistent(self, enabled: bool = None) -> bool:
         if enabled is not None:
             return enabled
         p = self.app.conf.result_persistent
         return self.persistent if p is None else p
 
-    def encode_result(self, result, state):
+    def encode_result(self, result: Any, state: str) -> Any:
         if state in self.EXCEPTION_STATES and isinstance(result, Exception):
             return self.prepare_exception(result)
         else:
             return self.prepare_value(result)
 
-    def is_cached(self, task_id):
+    def is_cached(self, task_id: str) -> bool:
         return task_id in self._cache
 
-    def store_result(self, task_id, result, state,
-                     traceback=None, request=None, **kwargs):
+    async def store_result(self, task_id: str, result: Any, state: str,
+                           *,
+                           traceback: str = None,
+                           request: RequestT = None, **kwargs) -> Any:
         """Update task state and result."""
         result = self.encode_result(result, state)
-        self._store_result(task_id, result, state, traceback,
-                           request=request, **kwargs)
+        await self._store_result(task_id, result, state, traceback,
+                                 request=request, **kwargs)
         return result
 
-    def forget(self, task_id):
+    async def forget(self, task_id: str) -> None:
         self._cache.pop(task_id, None)
-        self._forget(task_id)
+        await self._forget(task_id)
 
-    def _forget(self, task_id):
+    async def _forget(self, task_id: str) -> None:
         raise NotImplementedError('backend does not implement forget.')
 
-    def get_state(self, task_id):
+    def get_state(self, task_id: str) -> str:
         """Get the state of a task."""
         return self.get_task_meta(task_id)['status']
 
-    def get_traceback(self, task_id):
+    def get_traceback(self, task_id: str) -> str:
         """Get the traceback for a failed task."""
         return self.get_task_meta(task_id).get('traceback')
 
-    def get_result(self, task_id):
+    def get_result(self, task_id: str) -> Any:
         """Get the result of a task."""
         return self.get_task_meta(task_id).get('result')
 
-    def get_children(self, task_id):
+    def get_children(self, task_id: str) -> Sequence[str]:
         """Get the list of subtasks sent by a task."""
         try:
             return self.get_task_meta(task_id)['children']
         except KeyError:
             pass
 
-    def _ensure_not_eager(self):
+    def _ensure_not_eager(self) -> None:
         if self.app.conf.task_always_eager:
             raise RuntimeError(
                 "Cannot retrieve result with task_always_eager enabled")
 
-    def get_task_meta(self, task_id, cache=True):
+    def get_task_meta(self, task_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -348,15 +386,15 @@ class Backend:
         self._cache[task_id] = meta
         return meta
 
-    def reload_task_result(self, task_id):
+    def reload_task_result(self, task_id: str) -> None:
         """Reload task result, even if it has been previously fetched."""
         self._cache[task_id] = self.get_task_meta(task_id, cache=False)
 
-    def reload_group_result(self, group_id):
+    def reload_group_result(self, group_id: str) -> None:
         """Reload group result, even if it has been previously fetched."""
         self._cache[group_id] = self.get_group_meta(group_id, cache=False)
 
-    def get_group_meta(self, group_id, cache=True):
+    def get_group_meta(self, group_id: str, cache: bool = True) -> Mapping:
         self._ensure_not_eager()
         if cache:
             try:
@@ -369,91 +407,118 @@ class Backend:
         self._cache[group_id] = meta
         return meta
 
-    def restore_group(self, group_id, cache=True):
+    def restore_group(self, group_id: str, cache: bool = True) -> GroupResult:
         """Get the result for a group."""
         meta = self.get_group_meta(group_id, cache=cache)
         if meta:
             return meta['result']
 
-    def save_group(self, group_id, result):
+    def save_group(self, group_id: str, result: GroupResult) -> GroupResult:
         """Store the result of an executed group."""
         return self._save_group(group_id, result)
 
-    def delete_group(self, group_id):
+    def _save_group(self, group_id: str, result: GroupResult) -> GroupResult:
+        raise NotImplementedError()
+
+    def delete_group(self, group_id: str) -> None:
         self._cache.pop(group_id, None)
-        return self._delete_group(group_id)
+        self._delete_group(group_id)
 
-    def cleanup(self):
+    def _delete_group(self, group_id: str) -> None:
+        raise NotImplementedError()
+
+    def cleanup(self) -> None:
         """Backend cleanup.
 
         Note:
             This is run by :class:`celery.task.DeleteExpiredTaskMetaTask`.
         """
-        pass
+        ...
 
-    def process_cleanup(self):
+    def process_cleanup(self) -> None:
         """Cleanup actions to do at the end of a task worker process."""
-        pass
+        ...
 
-    def on_task_call(self, producer, task_id):
+    async def on_task_call(self, producer: ProducerT, task_id: str) -> Mapping:
         return {}
 
-    def add_to_chord(self, chord_id, result):
+    async def add_to_chord(self, chord_id: str, result: ResultT) -> None:
         raise NotImplementedError('Backend does not support add_to_chord')
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
-        pass
+    async def on_chord_part_return(self, request: RequestT,
+                                   state: str,
+                                   result: Any,
+                                   **kwargs) -> None:
+        ...
 
-    def fallback_chord_unlock(self, group_id, body, result=None,
-                              countdown=1, **kwargs):
+    async def fallback_chord_unlock(self, group_id: str, body: SignatureT,
+                                    result: ResultT = None,
+                                    countdown: float = 1,
+                                    **kwargs) -> None:
         kwargs['result'] = [r.as_tuple() for r in result]
-        self.app.tasks['celery.chord_unlock'].apply_async(
+        await self.app.tasks['celery.chord_unlock'].apply_async(
            (group_id, body,), kwargs, countdown=countdown,
         )
 
-    def ensure_chords_allowed(self):
-        pass
+    def ensure_chords_allowed(self) -> None:
+        ...
 
-    def apply_chord(self, header, partial_args, group_id, body,
-                    options={}, **kwargs):
+    async def apply_chord(self, header: SignatureT, partial_args: Sequence,
+                          *,
+                          group_id: str, body: SignatureT,
+                          options: Mapping = {}, **kwargs) -> ResultT:
         self.ensure_chords_allowed()
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
-        result = header(*partial_args, task_id=group_id, **fixed_options or {})
-        self.fallback_chord_unlock(group_id, body, **kwargs)
+        result = await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
+        await self.fallback_chord_unlock(group_id, body, **kwargs)
         return result
 
-    def current_task_children(self, request=None):
+    def current_task_children(self,
+                              request: RequestT = None) -> Sequence[Tuple]:
         request = request or getattr(get_current_task(), 'request', None)
         if request:
             return [r.as_tuple() for r in getattr(request, 'children', [])]
 
-    def __reduce__(self, args=(), kwargs={}):
+    def __reduce__(self, args: Tuple = (), kwargs: Dict = {}) -> Tuple:
        return (unpickle_backend, (self.__class__, args, kwargs))
 
 
 class SyncBackendMixin:
 
-    def iter_native(self, result, timeout=None, interval=0.5, no_ack=True,
-                    on_message=None, on_interval=None):
+    async def iter_native(
+            self, result: ResultT,
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None) -> Iterable[Tuple[str, Mapping]]:
         self._ensure_not_eager()
         results = result.results
         if not results:
             return iter([])
         return self.get_many(
             {r.id for r in results},
             timeout=timeout, interval=interval, no_ack=no_ack,
             on_message=on_message, on_interval=on_interval,
         )
 
-    def wait_for_pending(self, result, timeout=None, interval=0.5,
-                         no_ack=True, on_message=None, on_interval=None,
-                         callback=None, propagate=True):
+    async def wait_for_pending(self, result: ResultT,
+                               *,
+                               timeout: float = None,
+                               interval: float = 0.5,
+                               no_ack: bool = True,
+                               on_message: Callable = None,
+                               on_interval: Callable = None,
+                               callback: Callable = None,
+                               propagate: bool = True) -> Any:
         self._ensure_not_eager()
         if on_message is not None:
             raise ImproperlyConfigured(
                 'Backend does not support on_message callback')
 
-        meta = self.wait_for(
+        meta = await self.wait_for(
             result.id, timeout=timeout,
             interval=interval,
             on_interval=on_interval,
@@ -463,8 +528,12 @@ class SyncBackendMixin:
         result._maybe_set_cache(meta)
         return result.maybe_throw(propagate=propagate, callback=callback)
 
-    def wait_for(self, task_id,
-                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
+    async def wait_for(self, task_id: str,
+                       *,
+                       timeout: float = None,
+                       interval: float = 0.5,
+                       no_ack: bool = True,
+                       on_interval: Callable = None) -> Mapping:
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -484,21 +553,23 @@ class SyncBackendMixin:
             if meta['status'] in states.READY_STATES:
                 return meta
             if on_interval:
-                on_interval()
+                await on_interval()
             # avoid hammering the CPU checking status.
             time.sleep(interval)
             time_elapsed += interval
             if timeout and time_elapsed >= timeout:
                 raise TimeoutError('The operation timed out.')
 
-    def add_pending_result(self, result, weak=False):
+    def add_pending_result(self, result: ResultT,
+                           *,
+                           weak: bool = False) -> ResultT:
         return result
 
-    def remove_pending_result(self, result):
+    def remove_pending_result(self, result: ResultT) -> ResultT:
         return result
 
     @property
-    def is_async(self):
+    def is_async(self) -> bool:
         return False
 
 
@@ -514,7 +585,7 @@ class BaseKeyValueStoreBackend(Backend):
     chord_keyprefix = 'chord-unlock-'
     implements_incr = False
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         if hasattr(self.key_t, '__func__'):  # pragma: no cover
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
@@ -522,51 +593,51 @@ class BaseKeyValueStoreBackend(Backend):
         if self.implements_incr:
             self.apply_chord = self._apply_chord_incr
 
-    def _encode_prefixes(self):
+    def _encode_prefixes(self) -> None:
         self.task_keyprefix = self.key_t(self.task_keyprefix)
         self.group_keyprefix = self.key_t(self.group_keyprefix)
         self.chord_keyprefix = self.key_t(self.chord_keyprefix)
 
-    def get(self, key):
+    async def get(self, key: AnyStr) -> AnyStr:
         raise NotImplementedError('Must implement the get method.')
 
-    def mget(self, keys):
+    async def mget(self, keys: Sequence[AnyStr]) -> Sequence[AnyStr]:
         raise NotImplementedError('Does not support get_many')
 
-    def set(self, key, value):
+    async def set(self, key: AnyStr, value: AnyStr) -> None:
         raise NotImplementedError('Must implement the set method.')
 
-    def delete(self, key):
+    async def delete(self, key: AnyStr) -> None:
         raise NotImplementedError('Must implement the delete method')
 
-    def incr(self, key):
+    async def incr(self, key: AnyStr) -> int:
         raise NotImplementedError('Does not implement incr')
 
-    def expire(self, key, value):
-        pass
+    async def expire(self, key: AnyStr, value: float) -> None:
+        ...
 
-    def get_key_for_task(self, task_id, key=''):
+    def get_key_for_task(self, task_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a task by id."""
         key_t = self.key_t
         return key_t('').join([
             self.task_keyprefix, key_t(task_id), key_t(key),
         ])
 
-    def get_key_for_group(self, group_id, key=''):
+    def get_key_for_group(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for a group by id."""
         key_t = self.key_t
         return key_t('').join([
             self.group_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def get_key_for_chord(self, group_id, key=''):
+    def get_key_for_chord(self, group_id: str, key: AnyStr = '') -> AnyStr:
         """Get the cache key for the chord waiting on group with given id."""
         key_t = self.key_t
         return key_t('').join([
             self.chord_keyprefix, key_t(group_id), key_t(key),
         ])
 
-    def _strip_prefix(self, key):
+    def _strip_prefix(self, key: AnyStr) -> AnyStr:
         """Take bytes: emit string."""
         key = self.key_t(key)
         for prefix in self.task_keyprefix, self.group_keyprefix:
@@ -574,14 +645,18 @@ class BaseKeyValueStoreBackend(Backend):
                 return bytes_to_str(key[len(prefix):])
         return bytes_to_str(key)
 
-    def _filter_ready(self, values, READY_STATES=states.READY_STATES):
+    def _filter_ready(
+            self,
+            values: Sequence[Tuple[Any, Any]],
+            *,
+            READY_STATES: Set[str] = states.READY_STATES) -> Iterable[Tuple]:
         for k, v in values:
             if v is not None:
                 v = self.decode_result(v)
                 if v['status'] in READY_STATES:
                     yield k, v
 
-    def _mget_to_results(self, values, keys):
+    def _mget_to_results(self, values: Any, keys: Sequence[AnyStr]) -> Mapping:
         if hasattr(values, 'items'):
             # client returns dict so mapping preserved.
             return {
@@ -595,9 +670,16 @@ class BaseKeyValueStoreBackend(Backend):
                 for i, v in self._filter_ready(enumerate(values))
             }
 
-    def get_many(self, task_ids, timeout=None, interval=0.5, no_ack=True,
-                 on_message=None, on_interval=None, max_iterations=None,
-                 READY_STATES=states.READY_STATES):
+    async def get_many(
+            self, task_ids: Sequence[str],
+            *,
+            timeout: float = None,
+            interval: float = 0.5,
+            no_ack: bool = True,
+            on_message: Callable = None,
+            on_interval: Callable = None,
+            max_iterations: int = None,
+            READY_STATES=states.READY_STATES) -> AsyncIterator[Tuple[str, Mapping]]:
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
         cached_ids = set()
@@ -616,54 +698,59 @@ class BaseKeyValueStoreBackend(Backend):
         iterations = 0
         while ids:
             keys = list(ids)
-            r = self._mget_to_results(self.mget([self.get_key_for_task(k)
-                                                 for k in keys]), keys)
+            payloads = await self.mget([
+                self.get_key_for_task(k) for k in keys
+            ])
+            r = self._mget_to_results(payloads, keys)
             cache.update(r)
             ids.difference_update({bytes_to_str(v) for v in r})
             for key, value in r.items():
                 if on_message is not None:
-                    on_message(value)
+                    await on_message(value)
                 yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout:
                 raise TimeoutError('Operation timed out ({0})'.format(timeout))
             if on_interval:
-                on_interval()
+                await on_interval()
             time.sleep(interval)  # don't busy loop.
             iterations += 1
             if max_iterations and iterations >= max_iterations:
                 break
 
-    def _forget(self, task_id):
-        self.delete(self.get_key_for_task(task_id))
+    async def _forget(self, task_id: str) -> None:
+        await self.delete(self.get_key_for_task(task_id))
 
-    def _store_result(self, task_id, result, state,
-                      traceback=None, request=None, **kwargs):
+    async def _store_result(self, task_id: str, result: Any, state: str,
+                            traceback: str = None,
+                            request: RequestT = None,
+                            **kwargs) -> Any:
         meta = {
             'status': state, 'result': result, 'traceback': traceback,
             'children': self.current_task_children(request),
             'task_id': bytes_to_str(task_id),
         }
-        self.set(self.get_key_for_task(task_id), self.encode(meta))
+        await self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 
-    def _save_group(self, group_id, result):
-        self.set(self.get_key_for_group(group_id),
-                 self.encode({'result': result.as_tuple()}))
+    async def _save_group(self,
+                          group_id: str, result: GroupResult) -> GroupResult:
+        await self.set(self.get_key_for_group(group_id),
+                       self.encode({'result': result.as_tuple()}))
         return result
 
-    def _delete_group(self, group_id):
-        self.delete(self.get_key_for_group(group_id))
+    async def _delete_group(self, group_id: str) -> None:
+        await self.delete(self.get_key_for_group(group_id))
 
-    def _get_task_meta_for(self, task_id):
+    async def _get_task_meta_for(self, task_id: str) -> Mapping:
         """Get task meta-data for a task by id."""
-        meta = self.get(self.get_key_for_task(task_id))
+        meta = await self.get(self.get_key_for_task(task_id))
         if not meta:
             return {'status': states.PENDING, 'result': None}
         return self.decode_result(meta)
 
-    def _restore_group(self, group_id):
+    async def _restore_group(self, group_id: str) -> GroupResult:
         """Get task meta-data for a task by id."""
-        meta = self.get(self.get_key_for_group(group_id))
+        meta = await self.get(self.get_key_for_group(group_id))
         # previously this was always pickled, but later this
         # was extended to support other serializers, so the
         # structure is kind of weird.
@@ -673,16 +760,25 @@ class BaseKeyValueStoreBackend(Backend):
             meta['result'] = result_from_tuple(result, self.app)
         return meta
 
-    def _apply_chord_incr(self, header, partial_args, group_id, body,
-                          result=None, options={}, **kwargs):
+    async def _apply_chord_incr(self, header: SignatureT,
+                                partial_args: Sequence,
+                                group_id: str,
+                                body: SignatureT,
+                                result: ResultT = None,
+                                options: Mapping = {},
+                                **kwargs) -> ResultT:
         self.ensure_chords_allowed()
-        self.save_group(group_id, self.app.GroupResult(group_id, result))
+        await self.save_group(group_id, self.app.GroupResult(group_id, result))
 
         fixed_options = {k: v for k, v in options.items() if k != 'task_id'}
 
-        return header(*partial_args, task_id=group_id, **fixed_options or {})
+        return await header(
+            *partial_args, task_id=group_id, **fixed_options or {})
 
-    def on_chord_part_return(self, request, state, result, **kwargs):
+    async def on_chord_part_return(
+            self,
+            request: RequestT, state: str, result: Any,
+            **kwargs) -> None:
         if not self.implements_incr:
             return
         app = self.app
@@ -691,25 +787,27 @@ class BaseKeyValueStoreBackend(Backend):
             return
         key = self.get_key_for_chord(gid)
         try:
-            deps = GroupResult.restore(gid, backend=self)
+            deps = await GroupResult.restore(gid, backend=self)
         except Exception as exc:  # pylint: disable=broad-except
             callback = maybe_signature(request.chord, app=app)
             logger.exception('Chord %r raised: %r', gid, exc)
-            return self.chord_error_from_stack(
+            await self.chord_error_from_stack(
                 callback,
                 ChordError('Cannot restore group: {0!r}'.format(exc)),
             )
+            return
         if deps is None:
             try:
                 raise ValueError(gid)
            except ValueError as exc:
                 callback = maybe_signature(request.chord, app=app)
                 logger.exception('Chord callback %r raised: %r', gid, exc)
-                return self.chord_error_from_stack(
+                await self.chord_error_from_stack(
                     callback,
                     ChordError('GroupResult {0} no longer exists'.format(gid)),
                 )
-        val = self.incr(key)
+                return
+        val = await self.incr(key)
         size = len(deps)
         if val > size:  # pragma: no cover
             logger.warning('Chord counter incremented too many times for %r',
@@ -719,7 +817,7 @@ class BaseKeyValueStoreBackend(Backend):
         j = deps.join_native if deps.supports_native_join else deps.join
         try:
             with allow_join_result():
-                ret = j(timeout=3.0, propagate=True)
+                ret = await j(timeout=3.0, propagate=True)
         except Exception as exc:  # pylint: disable=broad-except
             try:
                 culprit = next(deps._failed_join_report())
@@ -730,21 +828,21 @@ class BaseKeyValueStoreBackend(Backend):
                 reason = repr(exc)
 
             logger.exception('Chord %r raised: %r', gid, reason)
-            self.chord_error_from_stack(callback, ChordError(reason))
+            await self.chord_error_from_stack(callback, ChordError(reason))
         else:
             try:
-                callback.delay(ret)
+                await callback.delay(ret)
             except Exception as exc:  # pylint: disable=broad-except
                 logger.exception('Chord %r raised: %r', gid, exc)
-                self.chord_error_from_stack(
+                await self.chord_error_from_stack(
                     callback,
                     ChordError('Callback error: {0!r}'.format(exc)),
                 )
         finally:
             deps.delete()
-            self.client.delete(key)
+            await self.client.delete(key)
         else:
-            self.expire(key, self.expires)
+            await self.expire(key, self.expires)
 
 
 class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
@@ -756,16 +854,16 @@ class DisabledBackend(BaseBackend):
 
     _cache = {}   # need this attribute to reset cache in tests.
 
-    def store_result(self, *args, **kwargs):
-        pass
+    async def store_result(self, *args, **kwargs) -> Any:
+        ...
 
-    def ensure_chords_allowed(self):
+    def ensure_chords_allowed(self) -> None:
        raise NotImplementedError(E_CHORD_NO_BACKEND.strip())
 
-    def _is_disabled(self, *args, **kwargs):
+    def _is_disabled(self, *args, **kwargs) -> bool:
        raise NotImplementedError(E_NO_BACKEND.strip())
 
-    def as_uri(self, *args, **kwargs):
+    def as_uri(self, *args, **kwargs) -> str:
         return 'disabled://'
 
     get_state = get_result = get_traceback = _is_disabled
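
Below is a minimal usage sketch (not part of the patch above) illustrating the calling convention the diff introduces: backend methods become coroutines, and every option after the positional arguments is keyword-only. The InMemoryBackend stand-in and the event-loop setup are illustrative assumptions, not Celery code.

# Illustrative sketch only -- mirrors the shape of the async API proposed
# above, using a hypothetical in-memory stand-in, not a real Celery backend.
import asyncio


class InMemoryBackend:
    """Hypothetical stand-in mimicking the async Backend surface."""

    def __init__(self) -> None:
        self._store = {}

    async def store_result(self, task_id, result, state, *,
                           traceback=None, request=None, **kwargs):
        # A real backend would encode and persist the meta here.
        self._store[task_id] = {
            'status': state, 'result': result, 'traceback': traceback,
        }
        return result

    async def mark_as_done(self, task_id, result, *,
                           request=None, store_result=True,
                           state='SUCCESS') -> None:
        # Mirrors Backend.mark_as_done: persist first; a real
        # implementation would then notify any waiting chord.
        if store_result:
            await self.store_result(task_id, result, state, request=request)


async def main() -> None:
    backend = InMemoryBackend()
    # Options after the positional task_id/result are now keyword-only.
    await backend.mark_as_done('42e24ba2-7a4c-4af6-b731-b2bbbb0206e6', 1337)
    print(backend._store)


asyncio.run(main())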