浏览代码

Tests passing

Ask Solem 12 年之前
父节点
当前提交
d47ae1389c

+ 10 - 4
celery/app/task.py

@@ -59,10 +59,16 @@ class Context(object):
     _protected = 0
 
     def __init__(self, *args, **kwargs):
-        self.get = self.__dict__.get
-        self.clear = self.__dict__.clear
-        update = self.update = self.__dict__.update
-        update(*args, **kwargs)
+        self.update(*args, **kwargs)
+
+    def update(self, *args, **kwargs):
+        return self.__dict__.update(*args, **kwargs)
+
+    def clear(self):
+        return self.__dict__.clear()
+
+    def get(self, key, default=None):
+        return getattr(self, key, default)
 
     def __repr__(self):
         return '<Context: {0!r}>'.format(vars(self))

+ 1 - 1
celery/bootsteps.py

@@ -155,7 +155,7 @@ class Namespace(object):
             return
         self.close(parent)
         self.state = CLOSE
-        self.restart(parent, what, 'terminate' if terminate else 'stop')
+        self.restart(parent, 'terminate' if terminate else 'stop', what)
 
         if self.on_stopped:
             self.on_stopped()

+ 4 - 2
celery/concurrency/processes.py

@@ -324,7 +324,8 @@ class AsynPool(_pool.Pool):
 
 class TaskPool(BasePool):
     """Multiprocessing Pool implementation."""
-    Pool = _pool.Pool
+    Pool = AsynPool
+    BlockingPool = _pool.Pool
 
     uses_semaphore = True
 
@@ -343,7 +344,8 @@ class TaskPool(BasePool):
                 warning(MAXTASKS_NO_BILLIARD)
 
         forking_enable(self.forking_enable)
-        Pool = self.Pool if self.options.get('threads', True) else AsynPool
+        Pool = (self.BlockingPool if self.options.get('threads', True)
+                else self.Pool)
         P = self._pool = Pool(processes=self.limit,
                               initializer=process_initializer,
                               **self.options)

+ 8 - 0
celery/datastructures.py

@@ -529,12 +529,20 @@ class LimitedSet(object):
         self.expires = expires
         self._data = {} if data is None else data
         self._heap = [] if heap is None else heap
+        # make shortcuts
+        self.__iter__ = self._data.__iter__
         self.__len__ = self._data.__len__
         self.__contains__ = self._data.__contains__
 
     def __iter__(self):
         return iter(self._data)
 
+    def __len__(self):
+        return len(self._data)
+
+    def __contains__(self, key):
+        return key in self._data
+
     def add(self, value):
         """Add a new member."""
         self.purge(1)

+ 13 - 15
celery/tests/concurrency/test_processes.py

@@ -67,7 +67,8 @@ class MockPool(object):
         self.maintain_pool = Mock()
         self._state = mp.RUN
         self._processes = kwargs.get('processes')
-        self._pool = [Object(pid=i) for i in range(self._processes)]
+        self._pool = [Object(pid=i, inqW_fd=1, outqR_fd=2)
+                      for i in range(self._processes)]
         self._current_proc = cycle(range(self._processes))
 
     def close(self):
@@ -80,6 +81,15 @@ class MockPool(object):
     def terminate(self):
         self.terminated = True
 
+    def terminate_job(self, *args, **kwargs):
+        pass
+
+    def restart(self, *args, **kwargs):
+        pass
+
+    def handle_result_event(self, *args, **kwargs):
+        pass
+
     def grow(self, n=1):
         self._processes += n
 
@@ -100,11 +110,11 @@ class ExeMockPool(MockPool):
 
 
 class TaskPool(mp.TaskPool):
-    Pool = MockPool
+    Pool = BlockingPool = MockPool
 
 
 class ExeMockTaskPool(mp.TaskPool):
-    Pool = ExeMockPool
+    Pool = BlockingPool = ExeMockPool
 
 
 class test_TaskPool(Case):
@@ -138,12 +148,6 @@ class test_TaskPool(Case):
         pool.start()
         pool.apply_async(lambda x: x, (2, ), {})
 
-    def test_terminate_job(self):
-        pool = TaskPool(10)
-        pool._pool = Mock()
-        pool.terminate_job(1341)
-        pool._pool.terminate_job.assert_called_with(1341, None)
-
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool.start()
@@ -170,12 +174,6 @@ class test_TaskPool(Case):
         pool.start()
         self.assertEqual(pool.num_processes, 7)
 
-    def test_restart_pool(self):
-        pool = TaskPool()
-        pool._pool = Mock()
-        pool.restart()
-        pool._pool.restart.assert_called_with()
-
     def test_restart(self):
         raise SkipTest('functional test')
 

+ 2 - 2
celery/tests/worker/test_hub.py

@@ -198,7 +198,7 @@ class test_Hub(Case):
         hub.remove(File(11))
         self.assertNotIn(11, hub.readers)
         P.unregister.assert_has_calls([
-            call(10), call(File(11)),
+            call(10), call(11),
         ])
 
     def test_can_remove_unknown_fds(self):
@@ -235,7 +235,7 @@ class test_Hub(Case):
         hub.remove(File(21))
         self.assertNotIn(21, hub.writers)
         P.unregister.assert_has_calls([
-            call(20), call(File(21)),
+            call(20), call(21),
         ])
 
     def test_enter__exit(self):

+ 47 - 21
celery/tests/worker/test_worker.py

@@ -11,7 +11,7 @@ from kombu import Connection
 from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
-from mock import Mock, patch
+from mock import call, Mock, patch
 
 from celery import current_app
 from celery.app.defaults import DEFAULTS
@@ -25,9 +25,10 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import components
-from celery.worker.job import Request
 from celery.worker import consumer
 from celery.worker.consumer import Consumer as __Consumer
+from celery.worker.hub import READ, ERR
+from celery.worker.job import Request
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
@@ -38,7 +39,7 @@ def MockStep(step=None):
     step = Mock() if step is None else step
     step.namespace = Mock()
     step.namespace.name = 'MockNS'
-    step.name = 'MockStep'
+    step.name = 'MockStep(%s)' % (id(step), )
     return step
 
 
@@ -1026,14 +1027,17 @@ class test_WorkController(AppCase):
         worker.namespace.started = 4
         for w in worker.steps:
             w.start = Mock()
+            w.close = Mock()
             w.stop = Mock()
 
         worker.start()
         for w in worker.steps:
             self.assertTrue(w.start.call_count)
+        worker.consumer = Mock()
         worker.stop()
-        for w in worker.steps:
-            self.assertTrue(w.stop.call_count)
+        for stopstep in worker.steps:
+            self.assertTrue(stopstep.close.call_count)
+            self.assertTrue(stopstep.stop.call_count)
 
         # Doesn't close pool if no pool.
         worker.start()
@@ -1099,45 +1103,67 @@ class test_WorkController(AppCase):
         w._conninfo.connection_errors = w._conninfo.channel_errors = ()
         w.hub = Mock()
         w.hub.on_init = []
-        w.pool_cls = Mock()
-        P = w.pool_cls.return_value = Mock()
-        P.timers = {Mock(): 30}
+
+        PoolImp = Mock()
+        poolimp = PoolImp.return_value = Mock()
+        poolimp._pool = [Mock(), Mock()]
+        poolimp._cache = {}
+        poolimp._fileno_to_inq = {}
+        poolimp._fileno_to_outq = {}
+
+        from celery.concurrency.processes import TaskPool as _TaskPool
+
+        class TaskPool(_TaskPool):
+            Pool = PoolImp
+
+            @property
+            def timers(self):
+                return {Mock(): 30}
+
+        w.pool_cls = TaskPool
         w.use_eventloop = True
         w.consumer.restart_count = -1
         pool = components.Pool(w)
         pool.create(w)
         self.assertIsInstance(w.semaphore, BoundedSemaphore)
         self.assertTrue(w.hub.on_init)
+        P = w.pool
+        P.start()
 
         hub = Mock()
         w.hub.on_init[0](hub)
 
-        cbs = w.pool.init_callbacks.call_args[1]
         w = Mock()
-        cbs['on_process_up'](w)
-        hub.add_reader.assert_called_with(w.sentinel, P.maintain_pool)
+        poolimp.on_process_up(w)
+        hub.add.assert_has_calls([
+            call(w.sentinel, P.maintain_pool, READ | ERR),
+            call(w.outqR_fd, P.handle_result_event, READ | ERR),
+        ])
 
-        cbs['on_process_down'](w)
-        hub.remove.assert_called_with(w.sentinel)
+        poolimp.on_process_down(w)
+        hub.remove.assert_has_calls([
+            call(w.sentinel), call(w.outqR_fd),
+        ])
 
         result = Mock()
         tref = result._tref
 
-        cbs['on_timeout_cancel'](result)
+        poolimp.on_timeout_cancel(result)
         tref.cancel.assert_called_with()
-        cbs['on_timeout_cancel'](result)  # no more tref
+        poolimp.on_timeout_cancel(result)  # no more tref
 
-        cbs['on_timeout_set'](result, 10, 20)
+        poolimp.on_timeout_set(result, 10, 20)
         tsoft, callback = hub.timer.apply_after.call_args[0]
         callback()
 
-        cbs['on_timeout_set'](result, 10, None)
+        poolimp.on_timeout_set(result, 10, None)
         tsoft, callback = hub.timer.apply_after.call_args[0]
         callback()
-        cbs['on_timeout_set'](result, None, 10)
-        cbs['on_timeout_set'](result, None, None)
+        poolimp.on_timeout_set(result, None, 10)
+        poolimp.on_timeout_set(result, None, None)
 
         with self.assertRaises(WorkerLostError):
-            P.did_start_ok.return_value = False
+            P._pool.did_start_ok = Mock()
+            P._pool.did_start_ok.return_value = False
             w.consumer.restart_count = 0
-            pool.on_poll_init(P, w, hub)
+            P.on_poll_init(w, hub)

+ 5 - 3
celery/worker/consumer.py

@@ -277,11 +277,13 @@ class Consumer(object):
         # Clear internal queues to get rid of old messages.
         # They can't be acked anyway, as a delivery tag is specific
         # to the current channel.
-        if self.controller.semaphore:
+        if self.controller and self.controller.semaphore:
             self.controller.semaphore.clear()
-        self.timer.clear()
+        if self.timer:
+            self.timer.clear()
         reserved_requests.clear()
-        self.pool.flush()
+        if self.pool:
+            self.pool.flush()
 
     def connect(self):
         """Establish the broker connection.