Ask Solem 8 years ago
parent
commit
4fdb2c0aa1
3 changed files with 10 additions and 6 deletions
  1. 1 1
      celery/worker/consumer/consumer.py
  2. 8 5
      celery/worker/loops.py
  3. 1 0
      t/unit/worker/test_worker.py

+ 1 - 1
celery/worker/consumer/consumer.py

@@ -194,7 +194,7 @@ class Consumer(object):
         self.reset_rate_limits()
 
         self.hub = hub
-        if self.hub:
+        if self.hub or getattr(self.pool, 'is_green', False):
             self.amqheartbeat = amqheartbeat
             if self.amqheartbeat is None:
                 self.amqheartbeat = self.app.conf.broker_heartbeat

+ 8 - 5
celery/worker/loops.py

@@ -28,10 +28,11 @@ def _quick_drain(connection, timeout=0.1):
 
 
 def _enable_amqheartbeats(timer, connection, rate=2.0):
-    tick = connection.heartbeat_check
-    heartbeat = connection.get_heartbeat_interval()  # negotiated
-    if heartbeat and connection.supports_heartbeats:
-        timer.call_repeatedly(heartbeat / rate, tick, rate)
+    if connection:
+        tick = connection.heartbeat_check
+        heartbeat = connection.get_heartbeat_interval()  # negotiated
+        if heartbeat and connection.supports_heartbeats:
+            timer.call_repeatedly(heartbeat / rate, tick, (rate,))
 
 
 def asynloop(obj, connection, consumer, blueprint, hub, qos,
@@ -43,7 +44,7 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
 
     on_task_received = obj.create_task_handler()
 
-    _enable_amqheartbeats(hub, connection, rate=hbrate)
+    _enable_amqheartbeats(hub.timer, connection, rate=hbrate)
 
     consumer.on_message = on_task_received
     consumer.consume()
@@ -104,6 +105,8 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
     RUN = bootsteps.RUN
     on_task_received = obj.create_task_handler()
     perform_pending_operations = obj.perform_pending_operations
+    if getattr(obj.pool, 'is_green', False):
+        _enable_amqheartbeats(obj.timer, connection, rate=hbrate)
     consumer.on_message = on_task_received
     consumer.consume()
 

+ 1 - 0
t/unit/worker/test_worker.py

@@ -251,6 +251,7 @@ class test_Consumer(ConsumerCase):
         c.task_consumer = Mock()
         c.event_dispatcher = mock_event_dispatcher()
         c.connection = Mock(name='.connection')
+        c.connection.get_heartbeat_interval.return_value = 0
         c.connection.drain_events.side_effect = WorkerShutdown()
 
         with pytest.raises(WorkerShutdown):