Browse Source

Ensure threads/greenlets do not use the broker connection at the same time. Closes #2755

Ask Solem 9 years ago
parent
commit
7ec89a6bf0

+ 5 - 0
celery/tests/worker/test_control.py

@@ -48,6 +48,10 @@ class Consumer(consumer.Consumer):
         from celery.concurrency.base import BasePool
         self.pool = BasePool(10)
         self.task_buckets = defaultdict(lambda: None)
+        self.hub = None
+
+    def call_soon(self, p, *args, **kwargs):
+        return p(*args, **kwargs)
 
 
 class test_Pidbox(AppCase):
@@ -345,6 +349,7 @@ class test_ControlPanel(AppCase):
             queues = []
             cancelled = []
             consuming = False
+            hub = Mock(name='hub')
 
             def add_queue(self, queue):
                 self.queues.append(queue.name)

+ 28 - 5
celery/tests/worker/test_loops.py

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import errno
 import socket
 
+from amqp import promise
 from kombu.async import Hub, READ, WRITE, ERR
 
 from celery.bootsteps import CLOSE, RUN
@@ -18,6 +19,22 @@ from celery.worker.loops import _quick_drain, asynloop, synloop
 from celery.tests.case import AppCase, Mock, task_message_from_sig
 
 
+class PromiseEqual(object):
+
+    def __init__(self, fun, *args, **kwargs):
+        self.fun = fun
+        self.args = args
+        self.kwargs = kwargs
+
+    def __eq__(self, other):
+        return (other.fun == self.fun and
+                other.args == self.args and
+                other.kwargs == self.kwargs)
+
+    def __repr__(self):
+        return '<promise: {0.fun!r} {0.args!r} {0.kwargs!r}>'.format(self)
+
+
 class X(object):
 
     def __init__(self, app, heartbeat=None, on_task_message=None,
@@ -61,7 +78,8 @@ class X(object):
         self.Hub = self.hub
         self.blueprint.state = RUN
         # need this for create_task_handler
-        _consumer = Consumer(Mock(), timer=Mock(), controller=Mock(), app=app)
+        self._consumer = _consumer = Consumer(
+            Mock(), timer=Mock(), controller=Mock(), app=app)
         _consumer.on_task_message = on_task_message or []
         self.obj.create_task_handler = _consumer.create_task_handler
         self.on_unknown_message = self.obj.on_unknown_message = Mock(
@@ -157,20 +175,25 @@ class test_asynloop(AppCase):
         return x, on_task, message, strategy
 
     def test_on_task_received(self):
-        _, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
+        x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
         on_task(msg)
         strategy.assert_called_with(
-            msg, None, msg.ack_log_error, msg.reject_log_error, [],
+            msg, None,
+            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
+            PromiseEqual(x._consumer.call_soon, msg.reject_log_error), [],
         )
 
     def test_on_task_received_executes_on_task_message(self):
         cbs = [Mock(), Mock(), Mock()]
-        _, on_task, msg, strategy = self.task_context(
+        x, on_task, msg, strategy = self.task_context(
             self.add.s(2, 2), on_task_message=cbs,
         )
         on_task(msg)
         strategy.assert_called_with(
-            msg, None, msg.ack_log_error, msg.reject_log_error, cbs,
+            msg, None,
+            PromiseEqual(x._consumer.call_soon, msg.ack_log_error),
+            PromiseEqual(x._consumer.call_soon, msg.reject_log_error),
+            cbs,
         )
 
     def test_on_task_message_missing_name(self):

+ 24 - 3
celery/worker/consumer.py

@@ -21,6 +21,7 @@ from heapq import heappush
 from operator import itemgetter
 from time import sleep
 
+from amqp.promise import ppartial, promise
 from billiard.common import restart_state
 from billiard.exceptions import RestartFreqExceeded
 from kombu.async.semaphore import DummyLock
@@ -213,12 +214,29 @@ class Consumer(object):
             # connect again.
             self.app.conf.broker_connection_timeout = None
 
+        self._pending_operations = []
+
         self.steps = []
         self.blueprint = self.Blueprint(
             app=self.app, on_close=self.on_close,
         )
         self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
 
+    def call_soon(self, p, *args, **kwargs):
+        p = ppartial(p, *args, **kwargs)
+        if self.hub:
+            return self.hub.call_soon(p)
+        self._pending_operations.append(p)
+        return p
+
+    def perform_pending_operations(self):
+        if not self.hub:
+            while self._pending_operations:
+                try:
+                    self._pending_operations.pop()()
+                except Exception as exc:
+                    error('Pending callback raised: %r', exc, exc_info=1)
+
     def bucket_for_task(self, type):
         limit = rate(getattr(type, 'rate_limit', None))
         return TokenBucket(limit, capacity=1) if limit else None
@@ -466,12 +484,13 @@ class Consumer(object):
             task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                           app=self.app)
 
-    def create_task_handler(self):
+    def create_task_handler(self, promise=promise):
         strategies = self.strategies
         on_unknown_message = self.on_unknown_message
         on_unknown_task = self.on_unknown_task
         on_invalid_task = self.on_invalid_task
         callbacks = self.on_task_message
+        call_soon = self.call_soon
 
         def on_task_received(message):
             # payload will only be set for v1 protocol, since v2
@@ -497,8 +516,10 @@ class Consumer(object):
             else:
                 try:
                     strategy(
-                        message, payload, message.ack_log_error,
-                        message.reject_log_error, callbacks,
+                        message, payload,
+                        promise(call_soon, (message.ack_log_error,)),
+                        promise(call_soon, (message.reject_log_error,)),
+                        callbacks,
                     )
                 except InvalidTaskError as exc:
                     return on_invalid_task(payload, message, exc)

+ 7 - 3
celery/worker/control.py

@@ -345,14 +345,18 @@ def shutdown(state, msg='Got shutdown from remote', **kwargs):
 @Panel.register
 def add_consumer(state, queue, exchange=None, exchange_type=None,
                  routing_key=None, **options):
-    state.consumer.add_task_queue(queue, exchange, exchange_type,
-                                  routing_key, **options)
+    state.consumer.call_soon(
+        state.consumer.add_task_queue,
+        queue, exchange, exchange_type, routing_key, **options
+    )
     return {'ok': 'add consumer {0}'.format(queue)}
 
 
 @Panel.register
 def cancel_consumer(state, queue=None, **_):
-    state.consumer.cancel_task_queue(queue)
+    state.consumer.call_soon(
+        state.consumer.cancel_task_queue, queue,
+    )
     return {'ok': 'no longer consuming from {0}'.format(queue)}
 
 

+ 2 - 0
celery/worker/loops.py

@@ -104,6 +104,7 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
     """Fallback blocking event loop for transports that doesn't support AIO."""
 
     on_task_received = obj.create_task_handler()
+    perform_pending_operations = obj.perform_pending_operations
     consumer.on_message = on_task_received
     consumer.consume()
 
@@ -114,6 +115,7 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
         if qos.prev != qos.value:
             qos.update()
         try:
+            perform_pending_operations()
             connection.drain_events(timeout=2.0)
         except socket.timeout:
             pass