Explorar el Código

Tests passing

Ask Solem hace 12 años
padre
commit
609e402949

+ 5 - 16
celery/tests/worker/test_bootsteps.py

@@ -87,12 +87,12 @@ class test_StartStopComponent(Case):
         # it to the parent.components list.
         x.include(self)
         self.assertTrue(self.components)
-        self.assertIs(self.components[0], x.obj)
+        self.assertIs(self.components[0], x)
 
-        x.start()
+        x.start(self)
         x.obj.start.assert_called_with()
 
-        x.stop()
+        x.stop(self)
         x.obj.stop.assert_called_with()
 
     def test_include_when_disabled(self):
@@ -101,25 +101,14 @@ class test_StartStopComponent(Case):
         x.include(self)
         self.assertFalse(self.components)
 
-    def test_terminate_when_terminable(self):
-        x = self.Def(self)
-        x.terminable = True
-        x.create = Mock()
-
-        x.include(self)
-        x.terminate()
-        x.obj.terminate.assert_called_with()
-        self.assertFalse(x.obj.stop.call_count)
-
-    def test_terminate_calls_stop_when_not_terminable(self):
+    def test_terminate(self):
         x = self.Def(self)
         x.terminable = False
         x.create = Mock()
 
         x.include(self)
-        x.terminate()
+        x.terminate(self)
         x.obj.stop.assert_called_with()
-        self.assertFalse(x.obj.terminate.call_count)
 
 
 class test_Namespace(AppCase):

+ 14 - 10
celery/tests/worker/test_worker.py

@@ -9,6 +9,7 @@ from Queue import Empty
 
 from billiard.exceptions import WorkerLostError
 from kombu import Connection
+from kombu.common import QoS, PREFETCH_COUNT_MAX
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from mock import Mock, patch
@@ -24,11 +25,10 @@ from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker.components import Queues, Timers, EvLoop, Pool
-from celery.worker.bootsteps import RUN, CLOSE, TERMINATE
+from celery.worker.bootsteps import RUN, CLOSE, TERMINATE, StartStopComponent
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker.consumer import BlockingConsumer
-from celery.worker.consumer import QoS, PREFETCH_COUNT_MAX
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
@@ -830,7 +830,7 @@ class test_WorkController(AppCase):
     def test_with_embedded_celerybeat(self):
         worker = WorkController(concurrency=1, loglevel=0, beat=True)
         self.assertTrue(worker.beat)
-        self.assertIn(worker.beat, worker.components)
+        self.assertIn(worker.beat, [w.obj for w in worker.components])
 
     def test_with_autoscaler(self):
         worker = self.create_worker(autoscale=[10, 3], send_events=False,
@@ -988,13 +988,18 @@ class test_WorkController(AppCase):
     def test_start__stop(self):
         worker = self.worker
         worker.namespace.shutdown_complete.set()
-        worker.components = [Mock(), Mock(), Mock(), Mock()]
+        worker.components = [StartStopComponent(self) for _ in range(4)]
+        worker.namespace.state = RUN
+        worker.namespace.started = 4
+        for w in worker.components:
+            w.start = Mock()
+            w.stop = Mock()
 
         worker.start()
         for w in worker.components:
             self.assertTrue(w.start.call_count)
         worker.stop()
-        for component in worker.components:
+        for w in worker.components:
             self.assertTrue(w.stop.call_count)
 
         # Doesn't close pool if no pool.
@@ -1022,9 +1027,9 @@ class test_WorkController(AppCase):
     def test_start__terminate(self):
         worker = self.worker
         worker.namespace.shutdown_complete.set()
+        worker.namespace.started = 5
+        worker.namespace.state = RUN
         worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
-        for component in worker.components[:3]:
-            component.terminate = None
 
         worker.start()
         for w in worker.components[:3]:
@@ -1032,9 +1037,8 @@ class test_WorkController(AppCase):
         self.assertTrue(worker.namespace.started, len(worker.components))
         self.assertEqual(worker.namespace.state, RUN)
         worker.terminate()
-        for component in worker.components[:3]:
-            self.assertTrue(component.stop.call_count)
-        self.assertTrue(worker.components[4].terminate.call_count)
+        for component in worker.components:
+            self.assertTrue(component.terminate.call_count)
 
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()

+ 0 - 1
celery/worker/__init__.py

@@ -129,7 +129,6 @@ class WorkController(configurated):
                                    on_stopped=self.on_stopped)
         self.namespace.apply(self, **kwargs)
 
-
     def on_before_init(self, **kwargs):
         pass
 

+ 6 - 12
celery/worker/bootsteps.py

@@ -85,7 +85,6 @@ class Namespace(object):
                 pass
             else:
                 close(parent)
-        self.state = CLOSE
 
     def stop(self, parent, terminate=False):
         what = 'Terminating' if terminate else 'Stopping'
@@ -95,21 +94,19 @@ class Namespace(object):
         if self.state in (CLOSE, TERMINATE):
             return
 
-        self.close()
+        self.close(parent)
 
         if self.state != RUN or self.started != len(parent.components):
             # Not fully started, can safely exit.
             self.state = TERMINATE
             self.shutdown_complete.set()
             return
+        self.state = CLOSE
 
         for component in reversed(parent.components):
             if component:
                 logger.debug('%s %s...', what, qualname(component))
-                stop = component.stop
-                if terminate:
-                    stop = getattr(component, 'terminate', None) or stop
-                stop(parent)
+                (component.terminate if terminate else component.stop)(parent)
 
         if self.on_stopped:
             self.on_stopped()
@@ -277,7 +274,6 @@ class Component(object):
 
 class StartStopComponent(Component):
     abstract = True
-    terminable = False
 
     def start(self, parent):
         return self.obj.start()
@@ -288,11 +284,9 @@ class StartStopComponent(Component):
     def close(self, parent):
         pass
 
-    def terminate(self):
-        if self.terminable:
-            return self.obj.terminate()
-        return self.obj.stop()
+    def terminate(self, parent):
+        self.stop(parent)
 
     def include(self, parent):
         if super(StartStopComponent, self).include(parent):
-            parent.components.append(self.obj)
+            parent.components.append(self)

+ 4 - 0
celery/worker/components.py

@@ -58,6 +58,10 @@ class Pool(bootsteps.StartStopComponent):
         if w.pool:
             w.pool.close()
 
+    def terminate(self, w):
+        if w.pool:
+            w.pool.terminate()
+
     def on_poll_init(self, pool, hub):
         apply_after = hub.timer.apply_after
         apply_at = hub.timer.apply_at