Browse Source

98% Coverage for celery.worker

Ask Solem 14 years ago
parent
commit
96f8bb51a5
2 changed files with 178 additions and 5 deletions
  1. 175 4
      celery/tests/test_worker.py
  2. 3 1
      celery/worker/__init__.py

+ 175 - 4
celery/tests/test_worker.py

@@ -1,5 +1,5 @@
 import socket
-from celery.tests.utils import unittest
+import sys
 
 from datetime import datetime, timedelta
 from Queue import Empty
@@ -10,6 +10,7 @@ from celery.utils.timer2 import Timer
 
 from celery.app import app_or_default
 from celery.concurrency.base import BasePool
+from celery.exceptions import SystemTerminate
 from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import gen_unique_id
@@ -21,7 +22,8 @@ from celery.worker.consumer import QoS, RUN
 from celery.utils.serialization import pickle
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import execute_context
+from celery.tests.utils import unittest
+from celery.tests.utils import execute_context, skip
 
 
 class MockConsumer(object):
@@ -125,12 +127,16 @@ class MockPool(BasePool):
     def __init__(self, *args, **kwargs):
         self.raise_regular = kwargs.get("raise_regular", False)
         self.raise_base = kwargs.get("raise_base", False)
+        self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
+                                                False)
 
     def apply_async(self, *args, **kwargs):
         if self.raise_regular:
             raise KeyError("some exception")
         if self.raise_base:
             raise KeyboardInterrupt("Ctrl+c")
+        if self.raise_SystemTerminate:
+            raise SystemTerminate()
 
     def start(self):
         pass
@@ -584,9 +590,65 @@ class test_Consumer(unittest.TestCase):
 
 class test_WorkController(unittest.TestCase):
 
+    def create_worker(self, **kw):
+        worker = WorkController(concurrency=1, loglevel=0, **kw)
+        worker.logger = MockLogger()
+        return worker
+
     def setUp(self):
-        self.worker = WorkController(concurrency=1, loglevel=0)
-        self.worker.logger = MockLogger()
+        self.worker = self.create_worker()
+
+    def test_process_initializer(self):
+        from celery import Celery
+        from celery import platforms
+        from celery import signals
+        from celery.app import _tls
+        from celery.worker import process_initializer
+        from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
+
+        ignored_signals = []
+        reset_signals = []
+        worker_init = [False]
+        default_app = app_or_default()
+        app = Celery(loader="default")
+
+        class Loader(object):
+
+            def init_worker(self):
+                worker_init[0] = True
+        app.loader = Loader()
+
+        def on_worker_process_init(**kwargs):
+            on_worker_process_init.called = True
+        on_worker_process_init.called = False
+        signals.worker_process_init.connect(on_worker_process_init)
+
+        def set_mp_process_title(title, hostname=None):
+            set_mp_process_title.called = (title, hostname)
+        set_mp_process_title.called = ()
+
+        pignore_signal = platforms.ignore_signal
+        preset_signal = platforms.reset_signal
+        psetproctitle = platforms.set_mp_process_title
+        platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
+        platforms.reset_signal = lambda sig: reset_signals.append(sig)
+        platforms.set_mp_process_title = set_mp_process_title
+        try:
+            process_initializer(app, "awesome.worker.com")
+            self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
+            self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
+            self.assertTrue(worker_init[0])
+            self.assertTrue(on_worker_process_init.called)
+            self.assertIs(_tls.current_app, app)
+            self.assertTupleEqual(set_mp_process_title.called,
+                                  ("celeryd", "awesome.worker.com"))
+        finally:
+            platforms.ignore_signal = pignore_signal
+            platforms.reset_signal = preset_signal
+            platforms.set_mp_process_title = psetproctitle
+            default_app.set_current()
+
+
 
     def test_with_rate_limits_disabled(self):
         worker = WorkController(concurrency=1, loglevel=0,
@@ -608,6 +670,50 @@ class test_WorkController(unittest.TestCase):
         self.assertTrue(worker.beat)
         self.assertIn(worker.beat, worker.components)
 
+    def test_with_autoscaler(self):
+        worker = self.create_worker(autoscale=[10, 3], send_events=False,
+                                eta_scheduler_cls="celery.utils.timer2.Timer")
+        self.assertTrue(worker.autoscaler)
+
+    def test_dont_stop_or_terminate(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+        worker.stop()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+        worker.terminate()
+        self.assertNotEqual(worker._state, worker.CLOSE)
+
+        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
+        try:
+            worker._state = worker.RUN
+            worker.stop(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+            worker.terminate(in_sighandler=True)
+            self.assertNotEqual(worker._state, worker.CLOSE)
+        finally:
+            worker.pool.signal_safe = sigsafe
+
+    def test_on_timer_error(self):
+        worker = WorkController(concurrency=1, loglevel=0)
+        worker.logger = MockLogger()
+
+        try:
+            raise KeyError("foo")
+        except KeyError:
+            exc_info = sys.exc_info()
+
+        worker.on_timer_error(exc_info)
+        logged = worker.logger.logged[0]
+        self.assertIn("foo", logged)
+
+    def test_on_timer_tick(self):
+        worker = WorkController(concurrency=1, loglevel=10)
+        worker.logger = MockLogger()
+        worker.timer_debug = worker.logger.debug
+
+        worker.on_timer_tick(30.0)
+        logged = worker.logger.logged[0]
+        self.assertIn("30.0", logged)
+
     def test_process_task(self):
         worker = self.worker
         worker.pool = MockPool()
@@ -630,6 +736,18 @@ class test_WorkController(unittest.TestCase):
         self.assertRaises(KeyboardInterrupt, worker.process_task, task)
         self.assertEqual(worker._state, worker.TERMINATE)
 
+    def test_process_task_raise_SystemTerminate(self):
+        worker = self.worker
+        worker.pool = MockPool(raise_SystemTerminate=True)
+        backend = MockBackend()
+        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
+                           kwargs={})
+        task = TaskRequest.from_message(m, m.decode())
+        worker.components = []
+        worker._state = worker.RUN
+        self.assertRaises(SystemExit, worker.process_task, task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
     def test_process_task_raise_regular(self):
         worker = self.worker
         worker.pool = MockPool(raise_regular=True)
@@ -640,6 +758,59 @@ class test_WorkController(unittest.TestCase):
         worker.process_task(task)
         worker.pool.stop()
 
+    def test_start_catches_base_exceptions(self):
+
+        class Component(object):
+            stopped = False
+            terminated = False
+
+            def __init__(self, exc):
+                self.exc = exc
+
+            def start(self):
+                raise self.exc
+
+            def terminate(self):
+                self.terminated = True
+
+            def stop(self):
+                self.stopped = True
+
+        worker1 = self.create_worker()
+        worker1.components = [Component(SystemTerminate())]
+        self.assertRaises(SystemExit, worker1.start)
+        self.assertTrue(worker1.components[0].terminated)
+
+        worker2 = self.create_worker()
+        worker2.components = [Component(SystemExit())]
+        self.assertRaises(SystemExit, worker2.start)
+        self.assertTrue(worker2.components[0].stopped)
+
+    def test_state_db(self):
+        from celery.worker import state
+        Persistent = state.Persistent
+
+        class MockPersistent(Persistent):
+
+            def _load(self):
+                return {}
+
+        state.Persistent = MockPersistent
+        try:
+            worker = self.create_worker(db="statefilename")
+            self.assertTrue(worker._finalize_db)
+            worker._finalize_db.cancel()
+        finally:
+            state.Persistent = Persistent
+
+    @skip("Issue #264")
+    def test_disable_rate_limits(self):
+        from celery.worker.buckets import FastQueue
+        worker = self.create_worker(disable_rate_limits=True)
+        self.assertIsInstance(worker.ready_queue, FastQueue)
+        self.assertIsNone(worker.mediator)
+        self.assertEqual(worker.ready_queue.put, worker.process_task)
+
     def test_start__stop(self):
         worker = self.worker
         w1 = {"started": False}

+ 3 - 1
celery/worker/__init__.py

@@ -158,10 +158,12 @@ class WorkController(object):
         self.queues = queues
 
         self._finalize = Finalize(self, self.stop, exitpriority=1)
+        self._finalize_db = None
 
         if self.db:
             persistence = state.Persistent(self.db)
-            Finalize(persistence, persistence.save, exitpriority=5)
+            self._finalize_db = Finalize(persistence, persistence.save,
+                                         exitpriority=5)
 
         # Queues
         if self.disable_rate_limits: