Browse Source

Do not subscribe to Redis channels when results are ignored (#4709)

* Add ignored property to AsyncResult

* Set AsyncResult as ignored on task creation & bypass backend callback

* Cancel subscription when calling AsyncResult.forget

* Cancel pending result operations on AsyncResult destruction

* Integration tests for Redis channel unsubscriptions (#4707)

* Handle KeyError when resetting Redis client objects

https://travis-ci.org/celery/celery/jobs/375637202
George Psarakis 6 years ago
parent
commit
2636251a12

+ 10 - 1
celery/app/base.py

@@ -711,6 +711,8 @@ class Celery(object):
             warnings.warn(AlwaysEagerIgnored(
                 'task_always_eager has no effect on send_task',
             ), stacklevel=2)
+
+        ignored_result = options.pop('ignore_result', False)
         options = router.route(
             options, route_name or name, args, kwargs, task_type)
 
@@ -735,11 +737,18 @@ class Celery(object):
 
         if connection:
             producer = amqp.Producer(connection, auto_declare=False)
+
         with self.producer_or_acquire(producer) as P:
             with P.connection._reraise_as_library_errors():
-                self.backend.on_task_call(P, task_id)
+                if not ignored_result:
+                    self.backend.on_task_call(P, task_id)
                 amqp.send_task_message(P, name, message, **options)
         result = (result_cls or self.AsyncResult)(task_id)
+        # We avoid using the constructor since a custom result class
+        # can be used, in which case the constructor may still use
+        # the old signature.
+        result.ignored = ignored_result
+
         if add_to_parent:
             if not have_parent:
                 parent, have_parent = self.current_worker_task, True

+ 3 - 0
celery/app/task.py

@@ -525,6 +525,9 @@ class Task(object):
 
         preopts = self._get_exec_options()
         options = dict(preopts, **options) if options else preopts
+
+        options.setdefault('ignore_result', self.ignore_result)
+
         return app.send_task(
             self.name, args, kwargs, task_id=task_id, producer=producer,
             link=link, link_error=link_error, result_cls=self.AsyncResult,

+ 11 - 4
celery/backends/redis.py

@@ -13,7 +13,7 @@ from celery import states
 from celery._state import task_join_will_block
 from celery.canvas import maybe_signature
 from celery.exceptions import ChordError, ImproperlyConfigured
-from celery.five import string_t
+from celery.five import string_t, text_t
 from celery.utils import deprecated
 from celery.utils.functional import dictfilter
 from celery.utils.log import get_logger
@@ -83,9 +83,12 @@ class ResultConsumer(async.BaseResultConsumer):
         self.subscribed_to = set()
 
     def on_after_fork(self):
-        self.backend.client.connection_pool.reset()
-        if self._pubsub is not None:
-            self._pubsub.close()
+        try:
+            self.backend.client.connection_pool.reset()
+            if self._pubsub is not None:
+                self._pubsub.close()
+        except KeyError as e:
+            logger.warning(text_t(e))
         super(ResultConsumer, self).on_after_fork()
 
     def _maybe_cancel_ready_task(self, meta):
@@ -287,6 +290,10 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             pipe.publish(key, value)
             pipe.execute()
 
+    def forget(self, task_id):
+        super(RedisBackend, self).forget(task_id)
+        self.result_consumer.cancel_for(task_id)
+
     def delete(self, key):
         self.client.delete(key)
 

+ 21 - 0
celery/result.py

@@ -101,6 +101,19 @@ class AsyncResult(ResultBase):
         self.parent = parent
         self.on_ready = promise(self._on_fulfilled)
         self._cache = None
+        self._ignored = False
+
+    @property
+    def ignored(self):
+        """If True, task result retrieval is disabled."""
+        # A custom result class may override ``__init__`` and never set
+        # ``_ignored``; a missing flag means the result is not ignored.
+        return getattr(self, '_ignored', False)
+
+    @ignored.setter
+    def ignored(self, value):
+        """Enable/disable task result retrieval."""
+        self._ignored = value
 
     def then(self, callback, on_error=None, weak=False):
         self.backend.add_pending_result(self, weak=weak)
@@ -183,6 +196,9 @@ class AsyncResult(ResultBase):
             Exception: If the remote call raised an exception then that
                 exception will be re-raised in the caller process.
         """
+        if self.ignored:
+            return
+
         if disable_sync_subtasks:
             assert_will_not_block()
         _on_interval = promise()
@@ -363,6 +379,11 @@ class AsyncResult(ResultBase):
     def __reduce_args__(self):
         return self.id, self.backend, None, None, self.parent
 
+    def __del__(self):
+        """Cancel pending operations when the instance is destroyed."""
+        if self.backend is not None:
+            self.backend.remove_pending_result(self)
+
     @cached_property
     def graph(self):
         return self.build_graph()

+ 4 - 0
t/integration/conftest.py

@@ -29,6 +29,10 @@ def get_redis_connection():
     return StrictRedis(host=os.environ.get('REDIS_HOST'))
 
 
+def get_active_redis_channels():
+    return get_redis_connection().execute_command('PUBSUB CHANNELS')
+
+
 @pytest.fixture(scope='session')
 def celery_config():
     return {

+ 6 - 0
t/integration/tasks.py

@@ -23,6 +23,12 @@ def add(x, y):
     return x + y
 
 
+@shared_task(ignore_result=True)
+def add_ignore_result(x, y):
+    """Add two numbers."""
+    return x + y
+
+
 @shared_task
 def chain_add(x, y):
     (

+ 5 - 11
t/integration/test_canvas.py

@@ -8,7 +8,7 @@ from celery import chain, chord, group
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult, ResultSet
 
-from .conftest import flaky, get_redis_connection
+from .conftest import flaky, get_active_redis_channels, get_redis_connection
 from .tasks import (add, add_chord_to_chord, add_replaced, add_to_all,
                     add_to_all_to_chord, collect_ids, delayed_sum,
                     delayed_sum_with_soft_guard, identity, ids, print_unicode,
@@ -255,19 +255,13 @@ def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
 
 class test_chord:
 
-    @staticmethod
-    def _get_active_redis_channels(client):
-        return client.execute_command('PUBSUB CHANNELS')
-
     @flaky
     def test_redis_subscribed_channels_leak(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):
             raise pytest.skip('Requires redis result backend.')
 
-        redis_client = get_redis_connection()
-
         manager.app.backend.result_consumer.on_after_fork()
-        initial_channels = self._get_active_redis_channels(redis_client)
+        initial_channels = get_active_redis_channels()
         initial_channels_count = len(initial_channels)
 
         total_chords = 10
@@ -278,7 +272,7 @@ class test_chord:
 
         manager.assert_result_tasks_in_progress_or_completed(async_results)
 
-        channels_before = self._get_active_redis_channels(redis_client)
+        channels_before = get_active_redis_channels()
         channels_before_count = len(channels_before)
 
         assert set(channels_before) != set(initial_channels)
@@ -289,7 +283,7 @@ class test_chord:
         # total chord tasks, plus the initial channels
         # (existing from previous tests).
         chord_header_task_count = 2
-        assert channels_before_count == \
+        assert channels_before_count <= \
             chord_header_task_count * total_chords + initial_channels_count
 
         result_values = [
@@ -298,7 +292,7 @@ class test_chord:
         ]
         assert result_values == [24] * total_chords
 
-        channels_after = self._get_active_redis_channels(redis_client)
+        channels_after = get_active_redis_channels()
         channels_after_count = len(channels_after)
 
         assert channels_after_count == initial_channels_count

+ 32 - 2
t/integration/test_tasks.py

@@ -1,9 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from celery import group
 
-from .conftest import flaky
-from .tasks import print_unicode, retry_once, sleeping
+from .conftest import flaky, get_active_redis_channels
+from .tasks import add, add_ignore_result, print_unicode, retry_once, sleeping
 
 
 class test_tasks:
@@ -25,3 +27,31 @@ class test_tasks:
             group(print_unicode.s() for _ in range(5))(),
             timeout=10, propagate=True,
         )
+
+
+class tests_task_redis_result_backend:
+    def setup(self, manager):
+        if not manager.app.conf.result_backend.startswith('redis'):
+            raise pytest.skip('Requires redis result backend.')
+
+    def test_ignoring_result_no_subscriptions(self):
+        assert get_active_redis_channels() == []
+        result = add_ignore_result.delay(1, 2)
+        assert result.ignored is True
+        assert get_active_redis_channels() == []
+
+    def test_asyncresult_forget_cancels_subscription(self):
+        result = add.delay(1, 2)
+        assert get_active_redis_channels() == [
+            "celery-task-meta-{}".format(result.id)
+        ]
+        result.forget()
+        assert get_active_redis_channels() == []
+
+    def test_asyncresult_get_cancels_subscription(self):
+        result = add.delay(1, 2)
+        assert get_active_redis_channels() == [
+            "celery-task-meta-{}".format(result.id)
+        ]
+        assert result.get(timeout=3) == 3
+        assert get_active_redis_channels() == []

+ 8 - 0
t/unit/backends/test_redis.py

@@ -163,6 +163,14 @@ class test_RedisResultConsumer:
         parent_method.assert_called_once()
         consumer.backend.client.connection_pool.reset.assert_called_once()
 
+        # Continues on KeyError
+        consumer._pubsub = Mock()
+        consumer._pubsub.close = Mock(side_effect=KeyError)
+        parent_method.reset_mock()
+        consumer.backend.client.connection_pool.reset.reset_mock()
+        consumer.on_after_fork()
+        parent_method.assert_called_once()
+
     @patch('celery.backends.redis.ResultConsumer.cancel_for')
     @patch('celery.backends.async.BaseResultConsumer.on_state_change')
     def test_on_state_change(self, parent_method, cancel_for):

+ 26 - 0
t/unit/tasks/test_result.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
+import copy
 import traceback
 from contextlib import contextmanager
 
@@ -82,6 +83,12 @@ class test_AsyncResult:
             pass
         self.mytask = mytask
 
+    def test_ignored_getter(self):
+        result = self.app.AsyncResult(uuid())
+        assert result.ignored is False
+        result.__delattr__('_ignored')
+        assert result.ignored is False
+
     @patch('celery.result.task_join_will_block')
     def test_assert_will_not_block(self, task_join_will_block):
         task_join_will_block.return_value = True
@@ -324,6 +331,12 @@ class test_AsyncResult:
         assert isinstance(nok2_res.result, KeyError)
         assert ok_res.info == 'the'
 
+    def test_get_when_ignored(self):
+        result = self.app.AsyncResult(uuid())
+        result.ignored = True
+        # Does not block
+        assert result.get() is None
+
     def test_eq_ne(self):
         r1 = self.app.AsyncResult(self.task1['id'])
         r2 = self.app.AsyncResult(self.task1['id'])
@@ -366,6 +379,19 @@ class test_AsyncResult:
 
         assert not self.app.AsyncResult(uuid()).ready()
 
+    def test_del(self):
+        with patch('celery.result.AsyncResult.backend') as backend:
+            result = self.app.AsyncResult(self.task1['id'])
+            result.backend = backend
+            # Invoke the finalizer directly; ``del`` only drops a reference
+            # and does not guarantee that ``__del__`` runs immediately.
+            result.__del__()
+            backend.remove_pending_result.assert_called_once_with(result)
+
+        result = self.app.AsyncResult(self.task1['id'])
+        result.backend = None
+        result.__del__()
+
 
 class test_ResultSet:
 

+ 66 - 3
t/unit/tasks/test_tasks.py

@@ -153,7 +153,13 @@ class TasksCase:
 
         self.task_check_request_context = task_check_request_context
 
-        # memove all messages from memory-transport
+        @self.app.task(ignore_result=True)
+        def task_with_ignored_result():
+            pass
+
+        self.task_with_ignored_result = task_with_ignored_result
+
+        # Remove all messages from memory-transport
         from kombu.transport.memory import Channel
         Channel.queues.clear()
 
@@ -391,7 +397,8 @@ class test_tasks(TasksCase):
                                                    task_id=ANY,
                                                    task_type=ANY,
                                                    time_limit=ANY,
-                                                   shadow='fooxyz')
+                                                   shadow='fooxyz',
+                                                   ignore_result=False)
 
         self.app.send_task = old_send_task
 
@@ -427,7 +434,8 @@ class test_tasks(TasksCase):
                                                    task_id=ANY,
                                                    task_type=ANY,
                                                    time_limit=ANY,
-                                                   shadow='fooxyz')
+                                                   shadow='fooxyz',
+                                                   ignore_result=False)
 
         self.app.send_task = old_send_task
 
@@ -789,3 +797,58 @@ class test_apply_task(TasksCase):
         assert f.traceback
         with pytest.raises(KeyError):
             f.get()
+
+
+class test_apply_async(TasksCase):
+    def common_send_task_arguments(self):
+        return (ANY, ANY, ANY), dict(
+            compression=ANY,
+            delivery_mode=ANY,
+            exchange=ANY,
+            expires=ANY,
+            immediate=ANY,
+            link=ANY,
+            link_error=ANY,
+            mandatory=ANY,
+            priority=ANY,
+            producer=ANY,
+            queue=ANY,
+            result_cls=ANY,
+            routing_key=ANY,
+            serializer=ANY,
+            soft_time_limit=ANY,
+            task_id=ANY,
+            task_type=ANY,
+            time_limit=ANY,
+            shadow=None,
+            ignore_result=False
+        )
+
+    def test_task_with_ignored_result(self):
+        with patch.object(self.app, 'send_task') as send_task:
+            self.task_with_ignored_result.apply_async()
+            expected_args, expected_kwargs = self.common_send_task_arguments()
+            expected_kwargs['ignore_result'] = True
+            send_task.assert_called_once_with(
+                *expected_args,
+                **expected_kwargs
+            )
+
+    def test_task_with_result(self):
+        with patch.object(self.app, 'send_task') as send_task:
+            self.mytask.apply_async()
+            expected_args, expected_kwargs = self.common_send_task_arguments()
+            send_task.assert_called_once_with(
+                *expected_args,
+                **expected_kwargs
+            )
+
+    def test_task_with_result_ignoring_on_call(self):
+        with patch.object(self.app, 'send_task') as send_task:
+            self.mytask.apply_async(ignore_result=True)
+            expected_args, expected_kwargs = self.common_send_task_arguments()
+            expected_kwargs['ignore_result'] = True
+            send_task.assert_called_once_with(
+                *expected_args,
+                **expected_kwargs
+            )