Browse Source

Tests passing

Ask Solem 11 years ago
parent
commit
e289fd8b78

+ 1 - 4
celery/concurrency/base.py

@@ -70,10 +70,7 @@ class BasePool(object):
     def on_stop(self):
         pass
 
-    def register_with_event_loop(self, worker, hub):
-        pass
-
-    def on_poll_start(self, hub):
+    def register_with_event_loop(self, loop):
         pass
 
     def on_apply(self, *args, **kwargs):

+ 4 - 6
celery/concurrency/processes.py

@@ -347,13 +347,12 @@ class AsynPool(_pool.Pool):
         def on_timeout_set(R, soft, hard):
             if soft:
                 trefs[R._job] = call_later(
-                    soft * 1000.0,
-                    self._on_soft_timeout, (R._job, soft, hard, hub),
+                    soft * 1000.0, self._on_soft_timeout,
+                    R._job, soft, hard, hub,
                 )
             elif hard:
                 trefs[R._job] = call_later(
-                    hard * 1000.0,
-                    self._on_hard_timeout, (R._job, )
+                    hard * 1000.0, self._on_hard_timeout, R._job,
                 )
         self.on_timeout_set = on_timeout_set
 
@@ -374,8 +373,7 @@ class AsynPool(_pool.Pool):
         # only used by async pool.
         if hard:
             self._tref_for_id[job] = hub.call_at(
-                now() + (hard - soft),
-                self._on_hard_timeout, (job, ),
+                now() + (hard - soft), self._on_hard_timeout, job,
             )
         try:
             result = self._cache[job]

+ 1 - 4
celery/tests/concurrency/test_concurrency.py

@@ -92,12 +92,9 @@ class test_BasePool(AppCase):
 
     def test_interface_register_with_event_loop(self):
         self.assertIsNone(
-            BasePool(10).register_with_event_loop(Mock(), Mock()),
+            BasePool(10).register_with_event_loop(Mock()),
         )
 
-    def test_interface_on_poll_start(self):
-        self.assertIsNone(BasePool(10).on_poll_start(Mock()))
-
     def test_interface_on_soft_timeout(self):
         self.assertIsNone(BasePool(10).on_soft_timeout(Mock()))
 

+ 13 - 4
celery/tests/concurrency/test_processes.py

@@ -101,6 +101,9 @@ class MockPool(object):
     def apply_async(self, *args, **kwargs):
         pass
 
+    def register_with_event_loop(self, loop):
+        pass
+
 
 class ExeMockPool(MockPool):
 
@@ -303,10 +306,16 @@ class test_TaskPool(PoolCase):
     def test_info(self):
         pool = TaskPool(10)
         procs = [Object(pid=i) for i in range(pool.limit)]
-        pool._pool = Object(_pool=procs,
-                            _maxtasksperchild=None,
-                            timeout=10,
-                            soft_timeout=5)
+
+        class _Pool(object):
+            _pool = procs
+            _maxtasksperchild = None
+            timeout = 10
+            soft_timeout = 5
+
+            def human_write_stats(self, *args, **kwargs):
+                return {}
+        pool._pool = _Pool()
         info = pool.info
         self.assertEqual(info['max-concurrency'], pool.limit)
         self.assertEqual(info['max-tasks-per-child'], 'N/A')

+ 1 - 7
celery/tests/worker/test_consumer.py

@@ -130,13 +130,7 @@ class test_Consumer(AppCase):
 
     def test_register_with_event_loop(self):
         c = self.get_consumer()
-        c.connection = Mock()
-        c.connection.eventmap = {1: 2}
-        hub = Mock()
-        c.register_with_event_loop(hub)
-
-        hub.add_reader.assert_called_with(1, 2)
-        c.connection.transport.register_with_event_loop.assert_called_with(hub)
+        c.register_with_event_loop(Mock(name='loop'))
 
     def test_on_close_clears_semaphore_timer_and_reqs(self):
         with patch('celery.worker.consumer.reserved_requests') as reserved:

+ 70 - 44
celery/tests/worker/test_loops.py

@@ -5,18 +5,21 @@ import socket
 from collections import defaultdict
 from mock import Mock
 
+from kombu.async import Hub, READ, WRITE, ERR
+
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.five import Empty
 from celery.worker import state
 from celery.worker.consumer import Consumer
-from celery.worker.loops import asynloop, synloop, CLOSE, READ, WRITE, ERR
+from celery.worker.loops import asynloop, synloop, CLOSE
 
 from celery.tests.case import AppCase, body_from_sig
 
 
 class X(object):
 
-    def __init__(self, app, heartbeat=None, on_task=None):
+    def __init__(self, app, heartbeat=None, on_task_message=None):
+        hub = Hub()
         (
             self.obj,
             self.connection,
@@ -30,7 +33,7 @@ class X(object):
                          Mock(name='connection'),
                          Mock(name='consumer'),
                          Mock(name='blueprint'),
-                         Mock(name='Hub'),
+                         hub,
                          Mock(name='qos'),
                          heartbeat,
                          Mock(name='clock')]
@@ -38,16 +41,19 @@ class X(object):
         self.consumer.callbacks = []
         self.obj.strategies = {}
         self.connection.connection_errors = (socket.error, )
-        #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
-        #self.Hub.__exit__ = Mock(name='Hub.__exit__')
-        #self.hub = hent.return_value = Mock(name='hub_context')
         self.hub.readers = {}
         self.hub.writers = {}
         self.hub.consolidate = set()
+        self.hub.timer = Mock(name='hub.timer')
+        self.hub.timer._queue = [Mock()]
+        self.hub.fire_timers = Mock(name='hub.fire_timers')
         self.hub.fire_timers.return_value = 1.7
+        self.hub.poller = Mock(name='hub.poller')
+        self.hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
         self.Hub = self.hub
         # need this for create_task_handler
         _consumer = Consumer(Mock(), timer=Mock(), app=app)
+        _consumer.on_task_message = on_task_message or []
         self.obj.create_task_handler = _consumer.create_task_handler
         self.on_unknown_message = self.obj.on_unknown_message = Mock(
             name='on_unknown_message',
@@ -71,21 +77,27 @@ class X(object):
             raise socket.timeout()
         mock.side_effect = first
 
-    def close_then_error(self, mock, mod=0):
+    def close_then_error(self, mock=None, mod=0, exc=None):
+        mock = Mock() if mock is None else mock
 
         def first(*args, **kwargs):
             if not mod or mock.call_count > mod:
                 self.close()
                 self.connection.more_to_read = False
-                raise socket.error()
+                raise (socket.error() if exc is None else exc)
         mock.side_effect = first
+        return mock
 
     def close(self, *args, **kwargs):
         self.blueprint.state = CLOSE
 
-    def closer(self, mock=None):
+    def closer(self, mock=None, mod=0):
         mock = Mock() if mock is None else mock
-        mock.side_effect = self.close
+
+        def closing(*args, **kwargs):
+            if not mod or mock.call_count >= mod:
+                self.close()
+        mock.side_effect = closing
         return mock
 
 
@@ -107,12 +119,13 @@ class test_asynloop(AppCase):
 
     def test_setup_heartbeat(self):
         x = X(self.app, heartbeat=10)
+        x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
         x.blueprint.state = CLOSE
         asynloop(*x.args)
         x.consumer.consume.assert_called_with()
         x.obj.on_ready.assert_called_with()
         x.hub.call_repeatedly.assert_called_with(
-            10 / 2.0, x.connection.heartbeat_check, (2.0, ),
+            10 / 2.0, x.connection.heartbeat_check, 2.0,
         )
 
     def task_context(self, sig, **kwargs):
@@ -127,10 +140,10 @@ class test_asynloop(AppCase):
         on_task(body, msg)
         strategy.assert_called_with(msg, body, msg.ack_log_error)
 
-    def test_on_task_received_executes_hub_on_task(self):
+    def test_on_task_received_executes_on_task_message(self):
         cbs = [Mock(), Mock(), Mock()]
         _, on_task, body, msg, _ = self.task_context(
-            self.add.s(2, 2), on_task=cbs,
+            self.add.s(2, 2), on_task_message=cbs,
         )
         on_task(body, msg)
         [cb.assert_called_with() for cb in cbs]
@@ -187,20 +200,24 @@ class test_asynloop(AppCase):
         x = X(self.app)
         x.qos.prev = 3
         x.qos.value = 3
-        asynloop(*x.args, sleep=x.closer())
+        x.hub.on_tick.add(x.closer(mod=2))
+        x.hub.timer._queue = [1]
+        asynloop(*x.args)
         self.assertFalse(x.qos.update.called)
 
         x = X(self.app)
         x.qos.prev = 1
         x.qos.value = 6
-        asynloop(*x.args, sleep=x.closer())
+        x.hub.on_tick.add(x.closer(mod=2))
+        asynloop(*x.args)
         x.qos.update.assert_called_with()
         x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))
 
     def test_poll_empty(self):
         x = X(self.app)
         x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        x.hub.timer._queue = [1]
+        x.close_then_error(x.hub.poller.poll)
         x.hub.fire_timers.return_value = 33.37
         x.hub.poller.poll.return_value = []
         with self.assertRaises(socket.error):
@@ -209,39 +226,43 @@ class test_asynloop(AppCase):
 
     def test_poll_readable(self):
         x = X(self.app)
-        x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait, mod=4)
+        reader = Mock(name='reader')
+        x.hub.add_reader(6, reader)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
         x.hub.poller.poll.return_value = [(6, READ)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        x.hub.readers[6].assert_called_with(6, READ)
+        reader.assert_called_with(6, READ)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_readable_raises_Empty(self):
         x = X(self.app)
-        x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        reader = Mock(name='reader')
+        x.hub.add_reader(6, reader)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, READ)]
-        x.hub.readers[6].side_effect = Empty()
+        reader.side_effect = Empty()
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        x.hub.readers[6].assert_called_with(6, READ)
+        reader.assert_called_with(6, READ)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable(self):
         x = X(self.app)
-        x.hub.writers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        writer = Mock(name='writer')
+        x.hub.add_writer(6, writer)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, WRITE)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        x.hub.writers[6].assert_called_with(6, WRITE)
+        writer.assert_called_with(6, WRITE)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_writable_none_registered(self):
         x = X(self.app)
-        x.hub.writers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        writer = Mock(name='writer')
+        x.hub.add_writer(6, writer)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(7, WRITE)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
@@ -249,8 +270,9 @@ class test_asynloop(AppCase):
 
     def test_poll_unknown_event(self):
         x = X(self.app)
-        x.hub.writers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        writer = Mock(name='reader')
+        x.hub.add_writer(6, writer)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, 0)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
@@ -275,24 +297,26 @@ class test_asynloop(AppCase):
 
     def test_poll_err_writable(self):
         x = X(self.app)
-        x.hub.writers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        writer = Mock(name='writer')
+        x.hub.add_writer(6, writer, 48)
+        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
         x.hub.poller.poll.return_value = [(6, ERR)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        x.hub.writers[6].assert_called_with(6, ERR)
+        writer.assert_called_with(6, ERR, 48)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_write_generator(self):
         x = X(self.app)
+        x.hub.remove = Mock(name='hub.remove()')
 
         def Gen():
             yield 1
             yield 2
         gen = Gen()
 
-        x.hub.writers = {6: gen}
-        x.close_then_error(x.connection.drain_nowait)
+        x.hub.add_writer(6, gen)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, WRITE)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
@@ -306,9 +330,10 @@ class test_asynloop(AppCase):
             raise StopIteration()
             yield
         gen = Gen()
-        x.hub.writers = {6: gen}
-        x.close_then_error(x.connection.drain_nowait)
+        x.hub.add_writer(6, gen)
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, WRITE)]
+        x.hub.remove = Mock(name='hub.remove()')
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         self.assertIsNone(gen.gi_frame)
@@ -321,8 +346,9 @@ class test_asynloop(AppCase):
             raise ValueError('foo')
             yield
         gen = Gen()
-        x.hub.writers = {6: gen}
-        x.close_then_error(x.connection.drain_nowait)
+        x.hub.add_writer(6, gen)
+        x.hub.remove = Mock(name='hub.remove()')
+        x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
         x.hub.poller.poll.return_value = [(6, WRITE)]
         with self.assertRaises(ValueError):
             asynloop(*x.args)
@@ -331,19 +357,19 @@ class test_asynloop(AppCase):
 
     def test_poll_err_readable(self):
         x = X(self.app)
-        x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        reader = Mock(name='reader')
+        x.hub.add_reader(6, reader, 24)
+        x.hub.on_tick.add(x.close_then_error(Mock(), 2))
         x.hub.poller.poll.return_value = [(6, ERR)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        x.hub.readers[6].assert_called_with(6, ERR)
+        reader.assert_called_with(6, ERR, 24)
         self.assertTrue(x.hub.poller.poll.called)
 
     def test_poll_raises_ValueError(self):
         x = X(self.app)
         x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
-        x.hub.poller.poll.side_effect = ValueError()
+        x.close_then_error(x.hub.poller.poll, exc=ValueError)
         asynloop(*x.args)
         self.assertTrue(x.hub.poller.poll.called)
 

+ 1 - 26
celery/tests/worker/test_worker.py

@@ -1061,33 +1061,8 @@ class test_WorkController(AppCase):
         w.consumer.restart_count = -1
         pool = components.Pool(w)
         pool.create(w)
+        print(pool.register_with_event_loop)
         pool.register_with_event_loop(w, w.hub)
         self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
         P = w.pool
         P.start()
-
-        hub = Mock()
-
-        w = Mock()
-        poolimp.on_process_up(w)
-        hub.add_reader.assert_has_calls([
-            call(w.sentinel, P.maintain_pool),
-            call(w.outqR_fd, P.handle_result_event),
-        ])
-
-        poolimp.on_process_down(w)
-        hub.remove.assert_has_calls([
-            call(w.sentinel), call(w.outqR_fd),
-        ])
-
-        w.pool._tref_for_id = {}
-
-        result = Mock()
-        poolimp.on_timeout_cancel(result)
-        poolimp.on_timeout_cancel(result)  # no more tref
-
-        with self.assertRaises(WorkerLostError):
-            P._pool.did_start_ok = Mock()
-            P._pool.did_start_ok.return_value = False
-            w.consumer.restart_count = 0
-            P.register_with_event_loop(w, hub)

+ 2 - 2
celery/worker/loops.py

@@ -39,7 +39,7 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
     on_task_received = obj.create_task_handler()
 
     if heartbeat and connection.supports_heartbeats:
-        hub.call_repeatedly(heartbeat / hbrate, hbtick, (hbrate, ))
+        hub.call_repeatedly(heartbeat / hbrate, hbtick, hbrate)
 
     consumer.callbacks = [on_task_received]
     consumer.consume()
@@ -67,7 +67,7 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
             # control commands will be prioritized over task messages.
             if qos.prev != qos.value:
                 update_qos()
-            next(loop)
+            next(loop, None)
     finally:
         try:
             hub.close()