|
@@ -14,6 +14,7 @@ from kombu.utils import cached_property, retry_over_time
|
|
|
from kombu.utils.url import _parse_url
|
|
|
|
|
|
from celery import states
|
|
|
+from celery._state import task_join_will_block
|
|
|
from celery.canvas import maybe_signature
|
|
|
from celery.exceptions import ChordError, ImproperlyConfigured
|
|
|
from celery.five import string_t
|
|
@@ -22,7 +23,8 @@ from celery.utils.functional import dictfilter
|
|
|
from celery.utils.log import get_logger
|
|
|
from celery.utils.timeutils import humanize_seconds
|
|
|
|
|
|
-from .base import KeyValueStoreBackend
|
|
|
+from . import async
|
|
|
+from . import base
|
|
|
|
|
|
try:
|
|
|
import redis
|
|
@@ -47,9 +49,54 @@ logger = get_logger(__name__)
|
|
|
error = logger.error
|
|
|
|
|
|
|
|
|
-class RedisBackend(KeyValueStoreBackend):
|
|
|
+class ResultConsumer(async.BaseResultConsumer):
|
|
|
+
|
|
|
+ _pubsub = None
|
|
|
+
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ super(ResultConsumer, self).__init__(*args, **kwargs)
|
|
|
+ self._get_key_for_task = self.backend.get_key_for_task
|
|
|
+ self._decode_result = self.backend.decode_result
|
|
|
+ self.subscribed_to = set()
|
|
|
+
|
|
|
+ def start(self, initial_task_id):
|
|
|
+ self._pubsub = self.backend.client.pubsub(
|
|
|
+ ignore_subscribe_messages=True,
|
|
|
+ )
|
|
|
+ self._consume_from(initial_task_id)
|
|
|
+
|
|
|
+ def stop(self):
|
|
|
+ if self._pubsub is not None:
|
|
|
+ self._pubsub.close()
|
|
|
+
|
|
|
+ def drain_events(self, timeout=None):
|
|
|
+ m = self._pubsub.get_message(timeout=timeout)
|
|
|
+ if m and m['type'] == 'message':
|
|
|
+ self.on_state_change(self._decode_result(m['data']), m)
|
|
|
+
|
|
|
+ def consume_from(self, task_id):
|
|
|
+ if self._pubsub is None:
|
|
|
+ return self.start(task_id)
|
|
|
+ self._consume_from(task_id)
|
|
|
+
|
|
|
+ def _consume_from(self, task_id):
|
|
|
+ key = self._get_key_for_task(task_id)
|
|
|
+ if key not in self.subscribed_to:
|
|
|
+ self.subscribed_to.add(key)
|
|
|
+ self._pubsub.subscribe(key)
|
|
|
+
|
|
|
+ def cancel_for(self, task_id):
|
|
|
+ if self._pubsub:
|
|
|
+ key = self._get_key_for_task(task_id)
|
|
|
+ self.subscribed_to.discard(key)
|
|
|
+ self._pubsub.unsubscribe(key)
|
|
|
+
|
|
|
+
|
|
|
+class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
|
|
|
"""Redis task result store."""
|
|
|
|
|
|
+ ResultConsumer = ResultConsumer
|
|
|
+
|
|
|
#: redis-py client module.
|
|
|
redis = redis
|
|
|
|
|
@@ -93,6 +140,8 @@ class RedisBackend(KeyValueStoreBackend):
|
|
|
self.connection_errors, self.channel_errors = (
|
|
|
get_redis_error_classes() if get_redis_error_classes
|
|
|
else ((), ()))
|
|
|
+ self.result_consumer = self.ResultConsumer(
|
|
|
+ self, self.app, self.accept, self._pending_results)
|
|
|
|
|
|
def _params_from_url(self, url, defaults):
|
|
|
scheme, host, port, user, password, path, query = _parse_url(url)
|
|
@@ -124,6 +173,10 @@ class RedisBackend(KeyValueStoreBackend):
|
|
|
connparams.update(query)
|
|
|
return connparams
|
|
|
|
|
|
+ def on_task_call(self, producer, task_id):
|
|
|
+ if not task_join_will_block():
|
|
|
+ self.result_consumer.consume_from(task_id)
|
|
|
+
|
|
|
def get(self, key):
|
|
|
return self.client.get(key)
|
|
|
|