浏览代码

Tests passing

Ask Solem 13 年之前
父节点
当前提交
609e402949

+ 5 - 16
celery/tests/worker/test_bootsteps.py

@@ -87,12 +87,12 @@ class test_StartStopComponent(Case):
         # it to the parent.components list.
         # it to the parent.components list.
         x.include(self)
         x.include(self)
         self.assertTrue(self.components)
         self.assertTrue(self.components)
-        self.assertIs(self.components[0], x.obj)
+        self.assertIs(self.components[0], x)
 
 
-        x.start()
+        x.start(self)
         x.obj.start.assert_called_with()
         x.obj.start.assert_called_with()
 
 
-        x.stop()
+        x.stop(self)
         x.obj.stop.assert_called_with()
         x.obj.stop.assert_called_with()
 
 
     def test_include_when_disabled(self):
     def test_include_when_disabled(self):
@@ -101,25 +101,14 @@ class test_StartStopComponent(Case):
         x.include(self)
         x.include(self)
         self.assertFalse(self.components)
         self.assertFalse(self.components)
 
 
-    def test_terminate_when_terminable(self):
-        x = self.Def(self)
-        x.terminable = True
-        x.create = Mock()
-
-        x.include(self)
-        x.terminate()
-        x.obj.terminate.assert_called_with()
-        self.assertFalse(x.obj.stop.call_count)
-
-    def test_terminate_calls_stop_when_not_terminable(self):
+    def test_terminate(self):
         x = self.Def(self)
         x = self.Def(self)
         x.terminable = False
         x.terminable = False
         x.create = Mock()
         x.create = Mock()
 
 
         x.include(self)
         x.include(self)
-        x.terminate()
+        x.terminate(self)
         x.obj.stop.assert_called_with()
         x.obj.stop.assert_called_with()
-        self.assertFalse(x.obj.terminate.call_count)
 
 
 
 
 class test_Namespace(AppCase):
 class test_Namespace(AppCase):

+ 14 - 10
celery/tests/worker/test_worker.py

@@ -9,6 +9,7 @@ from Queue import Empty
 
 
 from billiard.exceptions import WorkerLostError
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
 from kombu import Connection
+from kombu.common import QoS, PREFETCH_COUNT_MAX
 from kombu.exceptions import StdChannelError
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from kombu.transport.base import Message
 from mock import Mock, patch
 from mock import Mock, patch
@@ -24,11 +25,10 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import WorkController
 from celery.worker.components import Queues, Timers, EvLoop, Pool
 from celery.worker.components import Queues, Timers, EvLoop, Pool
-from celery.worker.bootsteps import RUN, CLOSE, TERMINATE
+from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.job import Request
 from celery.worker.consumer import BlockingConsumer
 from celery.worker.consumer import BlockingConsumer
-from celery.worker.consumer import QoS, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
@@ -830,7 +830,7 @@ class test_WorkController(AppCase):
     def test_with_embedded_celerybeat(self):
     def test_with_embedded_celerybeat(self):
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
         self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, worker.components)
+        self.assertIn(worker.beat, [w.obj for w in worker.components])
 
 
     def test_with_autoscaler(self):
     def test_with_autoscaler(self):
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
@@ -988,13 +988,18 @@ class test_WorkController(AppCase):
     def test_start__stop(self):
     def test_start__stop(self):
         worker = self.worker
         worker = self.worker
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
-        worker.components = [Mock(), Mock(), Mock(), Mock()]
+        worker.components = [StartStopComponent(self) for _ in range(4)]
+        worker.namespace.state = RUN
+        worker.namespace.started = 4
+        for w in worker.components:
+            w.start = Mock()
+            w.stop = Mock()
 
 
         worker.start()
         worker.start()
         for w in worker.components:
         for w in worker.components:
             self.assertTrue(w.start.call_count)
             self.assertTrue(w.start.call_count)
         worker.stop()
         worker.stop()
-        for component in worker.components:
+        for w in worker.components:
             self.assertTrue(w.stop.call_count)
             self.assertTrue(w.stop.call_count)
 
 
         # Doesn't close pool if no pool.
         # Doesn't close pool if no pool.
@@ -1022,9 +1027,9 @@ class test_WorkController(AppCase):
     def test_start__terminate(self):
     def test_start__terminate(self):
         worker = self.worker
         worker = self.worker
         worker.namespace.shutdown_complete.set()
         worker.namespace.shutdown_complete.set()
+        worker.namespace.started = 5
+        worker.namespace.state = RUN
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-        for component in worker.components[:3]:
-            component.terminate = None
 
 
         worker.start()
         worker.start()
         for w in worker.components[:3]:
         for w in worker.components[:3]:
@@ -1032,9 +1037,8 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.namespace.started, len(worker.components))
         self.assertTrue(worker.namespace.started, len(worker.components))
         self.assertEqual(worker.namespace.state, RUN)
         self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
         worker.terminate()
-        for component in worker.components[:3]:
-            self.assertTrue(component.stop.call_count)
-        self.assertTrue(worker.components[4].terminate.call_count)
+        for component in worker.components:
+            self.assertTrue(component.terminate.call_count)
 
 
     def test_Queues_pool_not_rlimit_safe(self):
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()
         w = Mock()

+ 0 - 1
celery/worker/__init__.py

@@ -129,7 +129,6 @@ class WorkController(configurated):
                                    on_stopped=self.on_stopped)
                                    on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
         self.namespace.apply(self, **kwargs)
 
 
-
     def on_before_init(self, **kwargs):
     def on_before_init(self, **kwargs):
         pass
         pass
 
 

+ 6 - 12
celery/worker/bootsteps.py

@@ -85,7 +85,6 @@ class Namespace(object):
                 pass
                 pass
             else:
             else:
                 close(parent)
                 close(parent)
-        self.state = CLOSE
 
 
     def stop(self, parent, terminate=False):
     def stop(self, parent, terminate=False):
         what = 'Terminating' if terminate else 'Stopping'
         what = 'Terminating' if terminate else 'Stopping'
@@ -95,21 +94,19 @@ class Namespace(object):
         if self.state in (CLOSE, TERMINATE):
         if self.state in (CLOSE, TERMINATE):
             return
             return
 
 
-        self.close()
+        self.close(parent)
 
 
         if self.state != RUN or self.started != len(parent.components):
         if self.state != RUN or self.started != len(parent.components):
             # Not fully started, can safely exit.
             # Not fully started, can safely exit.
             self.state = TERMINATE
             self.state = TERMINATE
             self.shutdown_complete.set()
             self.shutdown_complete.set()
             return
             return
+        self.state = CLOSE
 
 
         for component in reversed(parent.components):
         for component in reversed(parent.components):
             if component:
             if component:
                 logger.debug('%s %s...', what, qualname(component))
                 logger.debug('%s %s...', what, qualname(component))
-                stop = component.stop
-                if terminate:
-                    stop = getattr(component, 'terminate', None) or stop
-                stop(parent)
+                (component.terminate if terminate else component.stop)(parent)
 
 
         if self.on_stopped:
         if self.on_stopped:
             self.on_stopped()
             self.on_stopped()
@@ -277,7 +274,6 @@ class Component(object):
 
 
 class StartStopComponent(Component):
 class StartStopComponent(Component):
     abstract = True
     abstract = True
-    terminable = False
 
 
     def start(self, parent):
     def start(self, parent):
         return self.obj.start()
         return self.obj.start()
@@ -288,11 +284,9 @@ class StartStopComponent(Component):
     def close(self, parent):
     def close(self, parent):
         pass
         pass
 
 
-    def terminate(self):
-        if self.terminable:
-            return self.obj.terminate()
-        return self.obj.stop()
+    def terminate(self, parent):
+        self.stop(parent)
 
 
     def include(self, parent):
     def include(self, parent):
         if super(StartStopComponent, self).include(parent):
         if super(StartStopComponent, self).include(parent):
-            parent.components.append(self.obj)
+            parent.components.append(self)

+ 4 - 0
celery/worker/components.py

@@ -58,6 +58,10 @@ class Pool(bootsteps.StartStopComponent):
         if w.pool:
         if w.pool:
             w.pool.close()
             w.pool.close()
 
 
+    def terminate(self, w):
+        if w.pool:
+            w.pool.terminate()
+
     def on_poll_init(self, pool, hub):
     def on_poll_init(self, pool, hub):
         apply_after = hub.timer.apply_after
         apply_after = hub.timer.apply_after
         apply_at = hub.timer.apply_at
         apply_at = hub.timer.apply_at