Browse Source

[result][redis] Use pubsub for consuming results, and use the new async backend interface

Incorporates ideas taken from Yaroslav Zhavoronkov's diff in #2511

Closes Issue #2511
Ask Solem 9 years ago
parent
commit
886faf6754
5 changed files with 92 additions and 19 deletions
  1. 11 10
      celery/backends/amqp.py
  2. 15 3
      celery/backends/async.py
  3. 8 3
      celery/backends/base.py
  4. 55 2
      celery/backends/redis.py
  5. 3 1
      celery/backends/rpc.py

+ 11 - 10
celery/backends/amqp.py

@@ -49,13 +49,16 @@ class NoCacheQueue(Queue):
 class ResultConsumer(BaseResultConsumer):
     Consumer = Consumer
 
+    _connection = None
+    _consumer = None
+
     def __init__(self, *args, **kwargs):
         super(ResultConsumer, self).__init__(*args, **kwargs)
-        self._connection = None
-        self._consumer = None
+        self._create_binding = self.backend._create_binding
 
-    def start(self, initial_queue, no_ack=True):
+    def start(self, initial_task_id, no_ack=True):
         self._connection = self.app.connection()
+        initial_queue = self._create_binding(initial_task_id)
         self._consumer = self.Consumer(
             self._connection.default_channel, [initial_queue],
             callbacks=[self.on_state_change], no_ack=no_ack,
@@ -77,16 +80,17 @@ class ResultConsumer(BaseResultConsumer):
             self._connection.collect()
             self._connection = None
 
-    def consume_from(self, queue):
+    def consume_from(self, task_id):
         if self._consumer is None:
-            return self.start(queue)
+            return self.start(task_id)
+        queue = self._create_binding(task_id)
         if not self._consumer.consuming_from(queue):
             self._consumer.add_queue(queue)
             self._consumer.consume()
 
-    def cancel_for(self, queue):
+    def cancel_for(self, task_id):
         if self._consumer:
-            self._consumer.cancel_by_queue(queue.name)
+            self._consumer.cancel_by_queue(self._create_binding(task_id).name)
 
 
 class AMQPBackend(base.Backend, AsyncBackendMixin):
@@ -138,9 +142,6 @@ class AMQPBackend(base.Backend, AsyncBackendMixin):
         self._pending_results.clear()
         self.result_consumer._after_fork()
 
-    def on_result_fulfilled(self, result):
-        self.result_consumer.cancel_for(self._create_binding(result.id))
-
     def _create_exchange(self, name, type='direct', delivery_mode=2):
         return self.Exchange(name=name,
                              type=type,

+ 15 - 3
celery/backends/async.py

@@ -135,7 +135,7 @@ class AsyncBackendMixin(object):
     def add_pending_result(self, result):
         if result.id not in self._pending_results:
             self._pending_results[result.id] = result
-            self.result_consumer.consume_from(self._create_binding(result.id))
+            self.result_consumer.consume_from(result.id)
         return result
 
     def remove_pending_result(self, result):
@@ -144,7 +144,7 @@ class AsyncBackendMixin(object):
         return result
 
     def on_result_fulfilled(self, result):
-        pass
+        self.result_consumer.cancel_for(result.id)
 
     def wait_for_pending(self, result,
                          callback=None, propagate=True, **kwargs):
@@ -177,8 +177,20 @@ class BaseResultConsumer(object):
         self.buckets = WeakKeyDictionary()
         self.drainer = drainers[detect_environment()](self)
 
+    def start(self):
+        raise NotImplementedError()
+
+    def stop(self):
+        pass
+
     def drain_events(self, timeout=None):
-        raise NotImplementedError('subclass responsibility')
+        raise NotImplementedError()
+
+    def consume_from(self, task_id):
+        raise NotImplementedError()
+
+    def cancel_for(self, task_id):
+        raise NotImplementedError()
 
     def _after_fork(self):
         self.bucket.clear()

+ 8 - 3
celery/backends/base.py

@@ -448,7 +448,7 @@ class BaseBackend(Backend, SyncBackendMixin):
 BaseDictBackend = BaseBackend  # XXX compat
 
 
-class KeyValueStoreBackend(BaseBackend):
+class BaseKeyValueStoreBackend(Backend):
     key_t = ensure_bytes
     task_keyprefix = 'celery-task-meta-'
     group_keyprefix = 'celery-taskset-meta-'
@@ -459,7 +459,7 @@ class KeyValueStoreBackend(BaseBackend):
         if hasattr(self.key_t, '__func__'):  # pragma: no cover
             self.key_t = self.key_t.__func__  # remove binding
         self._encode_prefixes()
-        super(KeyValueStoreBackend, self).__init__(*args, **kwargs)
+        super(BaseKeyValueStoreBackend, self).__init__(*args, **kwargs)
         if self.implements_incr:
             self.apply_chord = self._apply_chord_incr
 
@@ -578,7 +578,8 @@ class KeyValueStoreBackend(BaseBackend):
     def _store_result(self, task_id, result, state,
                       traceback=None, request=None, **kwargs):
         meta = {'status': state, 'result': result, 'traceback': traceback,
-                'children': self.current_task_children(request)}
+                'children': self.current_task_children(request),
+                'task_id': task_id}
         self.set(self.get_key_for_task(task_id), self.encode(meta))
         return result
 
@@ -683,6 +684,10 @@ class KeyValueStoreBackend(BaseBackend):
             self.expire(key, 86400)
 
 
+class KeyValueStoreBackend(BaseKeyValueStoreBackend, SyncBackendMixin):
+    pass
+
+
 class DisabledBackend(BaseBackend):
     _cache = {}   # need this attribute to reset cache in tests.
 

+ 55 - 2
celery/backends/redis.py

@@ -14,6 +14,7 @@ from kombu.utils import cached_property, retry_over_time
 from kombu.utils.url import _parse_url
 
 from celery import states
+from celery._state import task_join_will_block
 from celery.canvas import maybe_signature
 from celery.exceptions import ChordError, ImproperlyConfigured
 from celery.five import string_t
@@ -22,7 +23,8 @@ from celery.utils.functional import dictfilter
 from celery.utils.log import get_logger
 from celery.utils.timeutils import humanize_seconds
 
-from .base import KeyValueStoreBackend
+from . import async
+from . import base
 
 try:
     import redis
@@ -47,9 +49,54 @@ logger = get_logger(__name__)
 error = logger.error
 
 
-class RedisBackend(KeyValueStoreBackend):
+class ResultConsumer(async.BaseResultConsumer):
+
+    _pubsub = None
+
+    def __init__(self, *args, **kwargs):
+        super(ResultConsumer, self).__init__(*args, **kwargs)
+        self._get_key_for_task = self.backend.get_key_for_task
+        self._decode_result = self.backend.decode_result
+        self.subscribed_to = set()
+
+    def start(self, initial_task_id):
+        self._pubsub = self.backend.client.pubsub(
+            ignore_subscribe_messages=True,
+        )
+        self._consume_from(initial_task_id)
+
+    def stop(self):
+        if self._pubsub is not None:
+            self._pubsub.close()
+
+    def drain_events(self, timeout=None):
+        m = self._pubsub.get_message(timeout=timeout)
+        if m and m['type'] == 'message':
+            self.on_state_change(self._decode_result(m['data']), m)
+
+    def consume_from(self, task_id):
+        if self._pubsub is None:
+            return self.start(task_id)
+        self._consume_from(task_id)
+
+    def _consume_from(self, task_id):
+        key = self._get_key_for_task(task_id)
+        if key not in self.subscribed_to:
+            self.subscribed_to.add(key)
+            self._pubsub.subscribe(key)
+
+    def cancel_for(self, task_id):
+        if self._pubsub:
+            key = self._get_key_for_task(task_id)
+            self.subscribed_to.discard(key)
+            self._pubsub.unsubscribe(key)
+
+
+class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
     """Redis task result store."""
 
+    ResultConsumer = ResultConsumer
+
     #: redis-py client module.
     redis = redis
 
@@ -93,6 +140,8 @@ class RedisBackend(KeyValueStoreBackend):
         self.connection_errors, self.channel_errors = (
             get_redis_error_classes() if get_redis_error_classes
             else ((), ()))
+        self.result_consumer = self.ResultConsumer(
+            self, self.app, self.accept, self._pending_results)
 
     def _params_from_url(self, url, defaults):
         scheme, host, port, user, password, path, query = _parse_url(url)
@@ -124,6 +173,10 @@ class RedisBackend(KeyValueStoreBackend):
         connparams.update(query)
         return connparams
 
+    def on_task_call(self, producer, task_id):
+        if not task_join_will_block():
+            self.result_consumer.consume_from(task_id)
+
     def get(self, key):
         return self.client.get(key)
 

+ 3 - 1
celery/backends/rpc.py

@@ -13,6 +13,7 @@ from kombu.common import maybe_declare
 from kombu.utils import cached_property
 
 from celery import current_task
+from celery._state import task_join_will_block
 from celery.backends import amqp
 
 __all__ = ['RPCBackend']
@@ -29,7 +30,8 @@ class RPCBackend(amqp.AMQPBackend):
         return Exchange(None)
 
     def on_task_call(self, producer, task_id):
-        maybe_declare(self.binding(producer.channel), retry=True)
+        if not task_join_will_block():
+            maybe_declare(self.binding(producer.channel), retry=True)
 
     def _create_binding(self, task_id):
         return self.binding