
Redis Result Consumer: unsubscribe on message success (#4666)

* Add manager assertion which checks AsyncResult state

* Redis Result Consumer: unsubscribe on message success

- Use on_after_fork in consumer to reset PubSub and connection pool
  internal states.
- Improve Canvas integration test.
George Psarakis · commit a035680a96

+ 14 - 0
celery/backends/redis.py

@@ -58,6 +58,20 @@ class ResultConsumer(async.BaseResultConsumer):
         self._decode_result = self.backend.decode_result
         self.subscribed_to = set()
 
+    def on_after_fork(self):
+        self.backend.client.connection_pool.reset()
+        if self._pubsub is not None:
+            self._pubsub.close()
+        super(ResultConsumer, self).on_after_fork()
+
+    def _maybe_cancel_ready_task(self, meta):
+        if meta['status'] in states.READY_STATES:
+            self.cancel_for(meta['task_id'])
+
+    def on_state_change(self, meta, message):
+        super(ResultConsumer, self).on_state_change(meta, message)
+        self._maybe_cancel_ready_task(meta)
+
     def start(self, initial_task_id, **kwargs):
         self._pubsub = self.backend.client.pubsub(
             ignore_subscribe_messages=True,

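For context: the on_state_change override above unsubscribes from a task's pubsub channel as soon as the task reaches a ready state, instead of keeping the channel open until the consumer is torn down. A minimal, self-contained sketch of that pattern (the FakePubSub and DemoConsumer classes are illustrative stand-ins, not Celery APIs; the 'celery-task-meta-' channel prefix is assumed from the backend's key naming):

    READY_STATES = frozenset({'SUCCESS', 'FAILURE', 'REVOKED'})

    class FakePubSub:
        """Tracks subscriptions in memory, like redis-py's PubSub."""
        def __init__(self):
            self.channels = set()

        def subscribe(self, channel):
            self.channels.add(channel)

        def unsubscribe(self, channel):
            self.channels.discard(channel)

    class DemoConsumer:
        def __init__(self):
            self._pubsub = FakePubSub()

        def cancel_for(self, task_id):
            # Drop the channel so a finished task no longer holds
            # a subscription open.
            self._pubsub.unsubscribe('celery-task-meta-' + task_id)

        def on_state_change(self, meta, message):
            # Mirrors _maybe_cancel_ready_task above: once the task
            # is in a ready state, its channel is no longer needed.
            if meta['status'] in READY_STATES:
                self.cancel_for(meta['task_id'])

    consumer = DemoConsumer()
    consumer._pubsub.subscribe('celery-task-meta-abc')
    consumer.on_state_change({'task_id': 'abc', 'status': 'SUCCESS'}, None)
    assert not consumer._pubsub.channels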
+ 26 - 0
celery/contrib/testing/manager.py

@@ -9,6 +9,7 @@ from itertools import count
 
 from kombu.utils.functional import retry_over_time
 
+from celery import states
 from celery.exceptions import TimeoutError
 from celery.five import items
 from celery.result import ResultSet
@@ -145,6 +146,31 @@ class ManagerMixin(object):
             self.is_accepted, ids, interval=interval, desc=desc, **policy
         )
 
+    def assert_result_tasks_in_progress_or_completed(
+        self,
+        async_results,
+        interval=0.5,
+        desc='waiting for tasks to be started or completed',
+        **policy
+    ):
+        return self.assert_task_state_from_result(
+            self.is_result_task_in_progress,
+            async_results,
+            interval=interval, desc=desc, **policy
+        )
+
+    def assert_task_state_from_result(self, fun, results,
+                                      interval=0.5, **policy):
+        return self.wait_for(
+            partial(self.true_or_raise, fun, results, timeout=interval),
+            (Sentinel,), **policy
+        )
+
+    @staticmethod
+    def is_result_task_in_progress(results, **kwargs):
+        possible_states = (states.STARTED, states.SUCCESS)
+        return all(result.state in possible_states for result in results)
+
     def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
         return self.wait_for(
             partial(self.true_or_raise, fun, ids, timeout=interval),

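A hedged sketch of how the new assertion might be used from an integration test (the manager fixture and the add task are assumptions borrowed from Celery's own test suite; note that STARTED is only reported when task_track_started is enabled, otherwise the assertion simply waits for SUCCESS):

    def test_tasks_make_progress(manager):
        async_results = [add.delay(2, 2) for _ in range(5)]
        # Polls every `interval` seconds and raises if the retry policy
        # is exhausted before every result is STARTED or SUCCESS.
        manager.assert_result_tasks_in_progress_or_completed(
            async_results, interval=0.5,
        )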
+ 42 - 12
t/integration/test_canvas.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
 from datetime import datetime, timedelta
-from time import sleep
 
 import pytest
 
@@ -257,23 +256,54 @@ def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
 
 class test_chord:
 
+    @staticmethod
+    def _get_active_redis_channels(client):
+        return client.execute_command('PUBSUB CHANNELS')
+
     @flaky
     def test_redis_subscribed_channels_leak(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):
             raise pytest.skip('Requires redis result backend.')
 
         redis_client = get_redis_connection()
-        async_result = chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
-        for _ in range(TIMEOUT):
-            if async_result.state == 'STARTED':
-                break
-            sleep(0.2)
-        channels_before = \
-            len(redis_client.execute_command('PUBSUB CHANNELS'))
-        assert async_result.get(timeout=TIMEOUT) == 24
-        channels_after = \
-            len(redis_client.execute_command('PUBSUB CHANNELS'))
-        assert channels_after < channels_before
+
+        manager.app.backend.result_consumer.on_after_fork()
+        initial_channels = self._get_active_redis_channels(redis_client)
+        initial_channels_count = len(initial_channels)
+
+        total_chords = 10
+        async_results = [
+            chord([add.s(5, 6), add.s(6, 7)])(delayed_sum.s())
+            for _ in range(total_chords)
+        ]
+
+        manager.assert_result_tasks_in_progress_or_completed(async_results)
+
+        channels_before = self._get_active_redis_channels(redis_client)
+        channels_before_count = len(channels_before)
+
+        assert set(channels_before) != set(initial_channels)
+        assert channels_before_count > initial_channels_count
+
+        # The total number of active Redis channels at this point
+        # is the number of chord header tasks multiplied by the
+        # number of chords, plus the initial channels
+        # (existing from previous tests).
+        chord_header_task_count = 2
+        assert channels_before_count == \
+            chord_header_task_count * total_chords + initial_channels_count
+
+        result_values = [
+            result.get(timeout=TIMEOUT)
+            for result in async_results
+        ]
+        assert result_values == [24] * total_chords
+
+        channels_after = self._get_active_redis_channels(redis_client)
+        channels_after_count = len(channels_after)
+
+        assert channels_after_count == initial_channels_count
+        assert set(channels_after) == set(initial_channels)
 
     @flaky
     def test_replaced_nested_chord(self, manager):

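The rewritten test counts channels with PUBSUB CHANNELS before and after collecting the chord results. A standalone redis-py sketch of that bookkeeping (assumes a Redis instance on the default local port; the channel name is made up for illustration):

    import redis

    client = redis.StrictRedis()
    pubsub = client.pubsub(ignore_subscribe_messages=True)
    pubsub.subscribe('celery-task-meta-demo')

    # PUBSUB CHANNELS lists every channel that currently has at least
    # one subscriber -- exactly what the leak test measures.
    channels = client.execute_command('PUBSUB CHANNELS')
    assert b'celery-task-meta-demo' in channels

    # Closing the PubSub drops its subscriptions, which is what
    # cancel_for() achieves per task once it is ready.
    pubsub.close()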
+ 48 - 0
t/unit/backends/test_redis.py

@@ -134,6 +134,54 @@ class sentinel(object):
     Sentinel = Sentinel
 
 
+class test_RedisResultConsumer:
+    def get_backend(self):
+        from celery.backends.redis import RedisBackend
+
+        class _RedisBackend(RedisBackend):
+            redis = redis
+
+        return _RedisBackend(app=self.app)
+
+    def get_consumer(self):
+        return self.get_backend().result_consumer
+
+    @patch('celery.backends.async.BaseResultConsumer.on_after_fork')
+    def test_on_after_fork(self, parent_method):
+        consumer = self.get_consumer()
+        consumer.start('none')
+        consumer.on_after_fork()
+        parent_method.assert_called_once()
+        consumer.backend.client.connection_pool.reset.assert_called_once()
+        consumer._pubsub.close.assert_called_once()
+        # _pubsub not initialized (None): on_after_fork must skip
+        # the .close() call, which would otherwise raise.
+        consumer._pubsub = None
+        parent_method.reset_mock()
+        consumer.backend.client.connection_pool.reset.reset_mock()
+        consumer.on_after_fork()
+        parent_method.assert_called_once()
+        consumer.backend.client.connection_pool.reset.assert_called_once()
+
+    @patch('celery.backends.redis.ResultConsumer.cancel_for')
+    @patch('celery.backends.async.BaseResultConsumer.on_state_change')
+    def test_on_state_change(self, parent_method, cancel_for):
+        consumer = self.get_consumer()
+        meta = {'task_id': 'testing', 'status': states.SUCCESS}
+        message = 'hello'
+        consumer.on_state_change(meta, message)
+        parent_method.assert_called_once_with(meta, message)
+        cancel_for.assert_called_once_with(meta['task_id'])
+
+        # Does not call cancel_for for other states
+        meta = {'task_id': 'testing2', 'status': states.PENDING}
+        parent_method.reset_mock()
+        cancel_for.reset_mock()
+        consumer.on_state_change(meta, message)
+        parent_method.assert_called_once_with(meta, message)
+        cancel_for.assert_not_called()
+
+
 class test_RedisBackend:
     def get_backend(self):
         from celery.backends.redis import RedisBackend
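
Why on_after_fork resets the connection pool, as the first unit test asserts: after a fork the child inherits the parent's open sockets, and concurrent use from both processes corrupts the protocol stream. A small sketch of the reset pattern (plain redis-py, local Redis assumed, Unix-only because of os.fork()):

    import os
    import redis

    client = redis.StrictRedis()
    client.ping()  # the pool now holds at least one live connection

    pid = os.fork()
    if pid == 0:
        # Child: drop the inherited connections before first use --
        # the same reason ResultConsumer.on_after_fork() calls
        # connection_pool.reset().
        client.connection_pool.reset()
        client.ping()  # served by a freshly created connection
        os._exit(0)
    os.waitpid(pid, 0)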