Browse Source

Event loop fixes

Ask Solem 11 years ago
parent
commit
ff42a7ff2d

+ 3 - 3
celery/concurrency/processes.py

@@ -219,7 +219,7 @@ class ResultHandler(_pool.ResultHandler):
                     raise CoroStop()
                 on_state_change(task)
 
-    def handle_event(self, fileno=None, event=None):
+    def handle_event(self, fileno):
         if self._state == RUN:
             it = self._it
             if it is None:
@@ -328,7 +328,7 @@ class AsynPool(_pool.Pool):
          for fd in self.process_sentinels]
         # Handle_result_event is called whenever one of the
         # result queues are readable.
-        [hub.add_reader(fd, self.handle_result_event)
+        [hub.add_reader(fd, self.handle_result_event, fd)
          for fd in self._fileno_to_outq]
 
         # Timers include calling maintain_pool at a regular interval
@@ -428,7 +428,7 @@ class AsynPool(_pool.Pool):
             add_reader(proc.sentinel, maintain_pool)
             # handle_result_event is called when the processes outqueue is
             # readable.
-            add_reader(proc.outqR_fd, handle_result_event)
+            add_reader(proc.outqR_fd, handle_result_event, proc.outqR_fd)
         self.on_process_up = on_process_up
 
         def on_process_down(proc):

+ 0 - 1
celery/tests/worker/test_autoscale.py

@@ -2,7 +2,6 @@ from __future__ import absolute_import
 
 import sys
 
-from collections import defaultdict
 from time import time
 
 from mock import Mock, patch

+ 2 - 2
celery/tests/worker/test_bootsteps.py

@@ -187,10 +187,10 @@ class test_StartStopStep(AppCase):
 
     def test_terminate(self):
         x = self.Def(self)
-        x.terminable = False
         x.create = Mock()
 
         x.include(self)
+        delattr(x.obj, 'terminate')
         x.terminate(self)
         x.obj.stop.assert_called_with()
 
@@ -230,7 +230,7 @@ class test_Blueprint(AppCase):
         blueprint.send_all = Mock()
         blueprint.close(1)
         blueprint.send_all.assert_called_with(
-            1, 'close', 'Closing', reverse=False,
+            1, 'close', 'closing', reverse=False,
         )
 
     def test_send_all_with_None_steps(self):

+ 4 - 9
celery/tests/worker/test_hub.py

@@ -246,16 +246,16 @@ class test_Hub(Case):
 
         read_A = Mock()
         read_B = Mock()
-        hub.add_reader(10, read_A)
-        hub.add_reader(File(11), read_B)
+        hub.add_reader(10, read_A, 10)
+        hub.add_reader(File(11), read_B, 11)
 
         P.register.assert_has_calls([
             call(10, hub.READ | hub.ERR),
             call(File(11), hub.READ | hub.ERR),
         ], any_order=True)
 
-        self.assertEqual(hub.readers[10], (read_A, ()))
-        self.assertEqual(hub.readers[11], (read_B, ()))
+        self.assertEqual(hub.readers[10], (read_A, (10, )))
+        self.assertEqual(hub.readers[11], (read_B, (11, )))
 
         hub.remove(10)
         self.assertNotIn(10, hub.readers)
@@ -306,15 +306,10 @@ class test_Hub(Case):
     def test_enter__exit(self):
         hub = Hub()
         P = hub.poller = Mock()
-        hub.init = Mock()
-
         on_close = Mock()
         hub.on_close.add(on_close)
 
-        hub.init()
         try:
-            hub.init.assert_called_with()
-
             read_A = Mock()
             read_B = Mock()
             hub.add_reader(10, read_A)

+ 15 - 17
celery/tests/worker/test_loops.py

@@ -2,16 +2,16 @@ from __future__ import absolute_import
 
 import socket
 
-from collections import defaultdict
 from mock import Mock
 
 from kombu.async import Hub, READ, WRITE, ERR
 
+from celery.bootsteps import CLOSE, RUN
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.five import Empty
 from celery.worker import state
 from celery.worker.consumer import Consumer
-from celery.worker.loops import asynloop, synloop, CLOSE
+from celery.worker.loops import asynloop, synloop
 
 from celery.tests.case import AppCase, body_from_sig
 
@@ -51,6 +51,7 @@ class X(object):
         self.hub.poller = Mock(name='hub.poller')
         self.hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
         self.Hub = self.hub
+        self.blueprint.state = RUN
         # need this for create_task_handler
         _consumer = Consumer(Mock(), timer=Mock(), app=app)
         _consumer.on_task_message = on_task_message or []
@@ -227,41 +228,41 @@ class test_asynloop(AppCase):
     def test_poll_readable(self):
         x = X(self.app)
         reader = Mock(name='reader')
-        x.hub.add_reader(6, reader)
+        x.hub.add_reader(6, reader, 6)
         x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
         x.hub.poller.poll.return_value = [(6, READ)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        reader.assert_called_with(6, READ)
+        reader.assert_called_with(6)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_readable_raises_Empty(self):
         x = X(self.app)
         reader = Mock(name='reader')
-        x.hub.add_reader(6, reader)
+        x.hub.add_reader(6, reader, 6)
         x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, READ)]
         reader.side_effect = Empty()
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        reader.assert_called_with(6, READ)
+        reader.assert_called_with(6)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable(self):
         x = X(self.app)
         writer = Mock(name='writer')
-        x.hub.add_writer(6, writer)
+        x.hub.add_writer(6, writer, 6)
         x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, WRITE)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        writer.assert_called_with(6, WRITE)
+        writer.assert_called_with(6)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable_none_registered(self):
         x = X(self.app)
         writer = Mock(name='writer')
-        x.hub.add_writer(6, writer)
+        x.hub.add_writer(6, writer, 6)
         x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(7, WRITE)]
         with self.assertRaises(socket.error):
@@ -271,7 +272,7 @@ class test_asynloop(AppCase):
     def test_poll_unknown_event(self):
         x = X(self.app)
         writer = Mock(name='reader')
-        x.hub.add_writer(6, writer)
+        x.hub.add_writer(6, writer, 6)
         x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, 0)]
         with self.assertRaises(socket.error):
@@ -287,23 +288,20 @@ class test_asynloop(AppCase):
             poll.side_effect = socket.error()
         poll.side_effect = se
 
-        x.connection.transport.nb_keep_draining = False
-        x.close_then_error(x.connection.drain_nowait)
         x.hub.poller.poll.return_value = [(6, 0)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         self.assertTrue(x.hub.poller.poll.called)
-        self.assertFalse(x.connection.drain_nowait.called)
 
     def test_poll_err_writable(self):
         x = X(self.app)
         writer = Mock(name='writer')
-        x.hub.add_writer(6, writer, 48)
+        x.hub.add_writer(6, writer, 6, 48)
         x.hub.on_tick.add(x.close_then_error(Mock(), 2))
         x.hub.poller.poll.return_value = [(6, ERR)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        writer.assert_called_with(6, ERR, 48)
+        writer.assert_called_with(6, 48)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_write_generator(self):
@@ -358,12 +356,12 @@ class test_asynloop(AppCase):
     def test_poll_err_readable(self):
         x = X(self.app)
         reader = Mock(name='reader')
-        x.hub.add_reader(6, reader, 24)
+        x.hub.add_reader(6, reader, 6, 24)
         x.hub.on_tick.add(x.close_then_error(Mock(), 2))
         x.hub.poller.poll.return_value = [(6, ERR)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        reader.assert_called_with(6, ERR, 24)
+        reader.assert_called_with(6, 24)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_raises_ValueError(self):

+ 14 - 7
celery/tests/worker/test_worker.py

@@ -8,15 +8,13 @@ from datetime import datetime, timedelta
 from threading import Event
 
 from amqp import ChannelError
-from billiard.exceptions import WorkerLostError
 from kombu import Connection
-from kombu.async import READ, ERR
 from kombu.common import QoS, ignore_errors
 from kombu.transport.base import Message
-from mock import call, Mock, patch
+from mock import Mock, patch
 
 from celery.app.defaults import DEFAULTS
-from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
+from celery.bootsteps import RUN, CLOSE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate, TaskRevokedError
@@ -199,6 +197,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
@@ -213,6 +212,7 @@ class test_Consumer(AppCase):
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.steps.pop()
         m = create_message(Mock(), task=self.foo_task.name,
                            args=('2, 2'),
@@ -230,6 +230,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.event_dispatcher = Mock()
         l.steps.pop()
         m = create_message(Mock(), task=self.foo_task.name,
@@ -270,6 +271,7 @@ class test_Consumer(AppCase):
 
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.event_dispatcher = Mock()
         m = create_message(Mock(), task=self.foo_task.name,
                            args=[2, 4, 8], kwargs={})
@@ -366,6 +368,7 @@ class test_Consumer(AppCase):
                 self.obj.connection = None
 
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
@@ -406,6 +409,7 @@ class test_Consumer(AppCase):
 
     def test_receieve_message_eta_isoformat(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.steps.pop()
         m = create_message(
             Mock(), task=self.foo_task.name,
@@ -455,6 +459,7 @@ class test_Consumer(AppCase):
 
     def test_revoke(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.steps.pop()
         backend = Mock()
         id = uuid()
@@ -469,6 +474,7 @@ class test_Consumer(AppCase):
 
     def test_receieve_message_not_registered(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
@@ -484,6 +490,7 @@ class test_Consumer(AppCase):
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
+        l.blueprint.state = RUN
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
@@ -884,7 +891,6 @@ class test_WorkController(AppCase):
         worker.blueprint.state = RUN
         with self.assertRaises(KeyboardInterrupt):
             worker._process_task(task)
-        self.assertEqual(worker.blueprint.state, TERMINATE)
 
     def test_process_task_raise_SystemTerminate(self):
         worker = self.worker
@@ -898,7 +904,6 @@ class test_WorkController(AppCase):
         worker.blueprint.state = RUN
         with self.assertRaises(SystemExit):
             worker._process_task(task)
-        self.assertEqual(worker.blueprint.state, TERMINATE)
 
     def test_process_task_raise_regular(self):
         worker = self.worker
@@ -913,6 +918,7 @@ class test_WorkController(AppCase):
 
     def test_start_catches_base_exceptions(self):
         worker1 = self.create_worker()
+        worker1.blueprint.state = RUN
         stc = MockStep()
         stc.start.side_effect = SystemTerminate()
         worker1.steps = [stc]
@@ -921,6 +927,7 @@ class test_WorkController(AppCase):
         self.assertTrue(stc.terminate.call_count)
 
         worker2 = self.create_worker()
+        worker2.blueprint.state = RUN
         sec = MockStep()
         sec.start.side_effect = SystemExit()
         sec.terminate = None
@@ -1023,7 +1030,7 @@ class test_WorkController(AppCase):
     def test_Hub_crate(self):
         w = Mock()
         x = components.Hub(w)
-        hub = x.create(w)
+        x.create(w)
         self.assertTrue(w.timer.max_interval)
 
     def test_Pool_crate_threaded(self):

+ 0 - 2
celery/worker/autoscale.py

@@ -16,7 +16,6 @@ from __future__ import absolute_import
 import os
 import threading
 
-from functools import partial
 from time import sleep, time
 
 from kombu.async.semaphore import DummyLock
@@ -51,7 +50,6 @@ class WorkerComponent(bootsteps.StartStopStep):
             w.pool, w.max_concurrency, w.min_concurrency,
             mutex=DummyLock() if w.use_eventloop else None,
         )
-        print('HELLO')
         return scaler if not w.use_eventloop else None
 
     def register_with_event_loop(self, w, hub):

+ 1 - 1
celery/worker/components.py

@@ -71,7 +71,7 @@ class Hub(bootsteps.StartStopStep):
         pass
 
     def stop(self, w):
-        w.hub.close()
+        pass
 
     def _patch_thread_primitives(self, w):
         # make clock use dummy lock

+ 1 - 1
celery/worker/consumer.py

@@ -704,7 +704,7 @@ class Mingle(bootsteps.StartStopStep):
     def start(self, c):
         info('mingle: searching for neighbors')
         I = c.app.control.inspect(timeout=1.0, connection=c.connection)
-        replies = I.hello(c.hostname, revoked._data)
+        replies = I.hello(c.hostname, revoked._data) or {}
         replies.pop(c.hostname, None)
         if replies:
             info('mingle: hello %s! sync with me',

+ 10 - 10
celery/worker/loops.py

@@ -9,11 +9,8 @@ from __future__ import absolute_import
 
 import socket
 
-from time import sleep
-
-from celery.bootsteps import CLOSE
+from celery.bootsteps import RUN
 from celery.exceptions import SystemTerminate, WorkerLostError
-from celery.five import Empty
 from celery.utils.log import get_logger
 
 from . import state
@@ -25,8 +22,7 @@ error = logger.error
 
 
 def asynloop(obj, connection, consumer, blueprint, hub, qos,
-             heartbeat, clock, hbrate=2.0,
-             sleep=sleep, min=min, Empty=Empty):
+             heartbeat, clock, hbrate=2.0, RUN=RUN):
     """Non-blocking event loop consuming messages until connection is lost,
     or shutdown is requested."""
 
@@ -55,7 +51,7 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
     loop = hub._loop(propagate=errors)
 
     try:
-        while blueprint.state != CLOSE and obj.connection:
+        while blueprint.state == RUN and obj.connection:
             # shutdown if signal handlers told us to.
             if state.should_stop:
                 raise SystemExit()
@@ -67,7 +63,11 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
             # control commands will be prioritized over task messages.
             if qos.prev != qos.value:
                 update_qos()
-            next(loop, None)
+
+            try:
+                next(loop)
+            except StopIteration:
+                loop = hub._loop(propagate=errors)
     finally:
         try:
             hub.close()
@@ -87,7 +87,7 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
 
     obj.on_ready()
 
-    while blueprint.state != CLOSE and obj.connection:
+    while blueprint.state == RUN and obj.connection:
         state.maybe_shutdown()
         if qos.prev != qos.value:
             qos.update()
@@ -96,5 +96,5 @@ def synloop(obj, connection, consumer, blueprint, hub, qos,
         except socket.timeout:
             pass
         except socket.error:
-            if blueprint.state != CLOSE:
+            if blueprint.state == RUN:
                 raise

+ 7 - 22
docs/userguide/extending.rst

@@ -641,11 +641,10 @@ will take some time so other transports still use a threading-based solution.
 
 .. method:: hub.add(fd, callback, flags)
 
-    Add callback for fd with custom flags, which can be any combination of
-    :data:`~kombu.utils.eventio.READ`, :data:`~kombu.utils.eventio.WRITE`,
-    and :data:`~kombu.utils.eventio.ERR`, the callback will then be called
-    whenever the condition specified in flags is true (readable,
-    writeable, or error).
+
+.. method:: hub.add_reader(fd, callback, \*args)
+
+    Add callback to be called when ``fd`` is readable.
 
     The callback will stay registered until explictly removed using
     :meth:`hub.remove(fd) <hub.remove>`, or the fd is automatically discarded
@@ -655,32 +654,18 @@ will take some time so other transports still use a threading-based solution.
     so calling ``add`` a second time will remove any callback that
     was previously registered for that fd.
 
-    ``fd`` may also be a list of file descriptors, in this case the
-    callback will be registered for all of the fds in this list.
-
     A file descriptor is any file-like object that supports the ``fileno``
     method, or it can be the file descriptor number (int).
 
-.. method:: hub.add_reader(fd, callback)
-
-    Shortcut to ``hub.add(fd, callback, READ | ERR)``.
-
-.. method:: hub.add_writer(fd, callback)
+.. method:: hub.add_writer(fd, callback, \*args)
 
-    Shortcut to ``hub.add(fd, callback, WRITE)``.
+    Add callback to be called when ``fd`` is writable.
+    See also notes for :meth:`hub.add_reader` above.
 
 .. method:: hub.remove(fd)
 
     Remove all callbacks for ``fd`` from the loop.
 
-.. method:: hub.update_readers(fd, mapping)
-
-    Shortcut to add callbacks from a map of ``{fd: callback}`` items.
-
-.. method:: hub.update_writers(fd, mapping)
-
-    Shortcut to add callbacks from a map of ``{fd: callback}`` items.
-
 Timer - Scheduling events
 -------------------------