|
@@ -1,5 +1,5 @@
|
|
|
import socket
|
|
|
-from celery.tests.utils import unittest
|
|
|
+import sys
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
from Queue import Empty
|
|
@@ -10,6 +10,7 @@ from celery.utils.timer2 import Timer
|
|
|
|
|
|
from celery.app import app_or_default
|
|
|
from celery.concurrency.base import BasePool
|
|
|
+from celery.exceptions import SystemTerminate
|
|
|
from celery.task import task as task_dec
|
|
|
from celery.task import periodic_task as periodic_task_dec
|
|
|
from celery.utils import gen_unique_id
|
|
@@ -21,7 +22,8 @@ from celery.worker.consumer import QoS, RUN
|
|
|
from celery.utils.serialization import pickle
|
|
|
|
|
|
from celery.tests.compat import catch_warnings
|
|
|
-from celery.tests.utils import execute_context
|
|
|
+from celery.tests.utils import unittest
|
|
|
+from celery.tests.utils import execute_context, skip
|
|
|
|
|
|
|
|
|
class MockConsumer(object):
|
|
@@ -125,12 +127,16 @@ class MockPool(BasePool):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
self.raise_regular = kwargs.get("raise_regular", False)
|
|
|
self.raise_base = kwargs.get("raise_base", False)
|
|
|
+ self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
|
|
|
+ False)
|
|
|
|
|
|
def apply_async(self, *args, **kwargs):
|
|
|
if self.raise_regular:
|
|
|
raise KeyError("some exception")
|
|
|
if self.raise_base:
|
|
|
raise KeyboardInterrupt("Ctrl+c")
|
|
|
+ if self.raise_SystemTerminate:
|
|
|
+ raise SystemTerminate()
|
|
|
|
|
|
def start(self):
|
|
|
pass
|
|
@@ -584,9 +590,65 @@ class test_Consumer(unittest.TestCase):
|
|
|
|
|
|
class test_WorkController(unittest.TestCase):
|
|
|
|
|
|
+ def create_worker(self, **kw):
|
|
|
+ worker = WorkController(concurrency=1, loglevel=0, **kw)
|
|
|
+ worker.logger = MockLogger()
|
|
|
+ return worker
|
|
|
+
|
|
|
def setUp(self):
|
|
|
- self.worker = WorkController(concurrency=1, loglevel=0)
|
|
|
- self.worker.logger = MockLogger()
|
|
|
+ self.worker = self.create_worker()
|
|
|
+
|
|
|
+ def test_process_initializer(self):
|
|
|
+ from celery import Celery
|
|
|
+ from celery import platforms
|
|
|
+ from celery import signals
|
|
|
+ from celery.app import _tls
|
|
|
+ from celery.worker import process_initializer
|
|
|
+ from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
|
|
|
+
|
|
|
+ ignored_signals = []
|
|
|
+ reset_signals = []
|
|
|
+ worker_init = [False]
|
|
|
+ default_app = app_or_default()
|
|
|
+ app = Celery(loader="default")
|
|
|
+
|
|
|
+ class Loader(object):
|
|
|
+
|
|
|
+ def init_worker(self):
|
|
|
+ worker_init[0] = True
|
|
|
+ app.loader = Loader()
|
|
|
+
|
|
|
+ def on_worker_process_init(**kwargs):
|
|
|
+ on_worker_process_init.called = True
|
|
|
+ on_worker_process_init.called = False
|
|
|
+ signals.worker_process_init.connect(on_worker_process_init)
|
|
|
+
|
|
|
+ def set_mp_process_title(title, hostname=None):
|
|
|
+ set_mp_process_title.called = (title, hostname)
|
|
|
+ set_mp_process_title.called = ()
|
|
|
+
|
|
|
+ pignore_signal = platforms.ignore_signal
|
|
|
+ preset_signal = platforms.reset_signal
|
|
|
+ psetproctitle = platforms.set_mp_process_title
|
|
|
+ platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
|
|
|
+ platforms.reset_signal = lambda sig: reset_signals.append(sig)
|
|
|
+ platforms.set_mp_process_title = set_mp_process_title
|
|
|
+ try:
|
|
|
+ process_initializer(app, "awesome.worker.com")
|
|
|
+ self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
|
|
|
+ self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
|
|
|
+ self.assertTrue(worker_init[0])
|
|
|
+ self.assertTrue(on_worker_process_init.called)
|
|
|
+ self.assertIs(_tls.current_app, app)
|
|
|
+ self.assertTupleEqual(set_mp_process_title.called,
|
|
|
+ ("celeryd", "awesome.worker.com"))
|
|
|
+ finally:
|
|
|
+ platforms.ignore_signal = pignore_signal
|
|
|
+ platforms.reset_signal = preset_signal
|
|
|
+ platforms.set_mp_process_title = psetproctitle
|
|
|
+ default_app.set_current()
|
|
|
+
|
|
|
+
|
|
|
|
|
|
def test_with_rate_limits_disabled(self):
|
|
|
worker = WorkController(concurrency=1, loglevel=0,
|
|
@@ -608,6 +670,50 @@ class test_WorkController(unittest.TestCase):
|
|
|
self.assertTrue(worker.beat)
|
|
|
self.assertIn(worker.beat, worker.components)
|
|
|
|
|
|
+ def test_with_autoscaler(self):
|
|
|
+ worker = self.create_worker(autoscale=[10, 3], send_events=False,
|
|
|
+ eta_scheduler_cls="celery.utils.timer2.Timer")
|
|
|
+ self.assertTrue(worker.autoscaler)
|
|
|
+
|
|
|
+ def test_dont_stop_or_terminate(self):
|
|
|
+ worker = WorkController(concurrency=1, loglevel=0)
|
|
|
+ worker.stop()
|
|
|
+ self.assertNotEqual(worker._state, worker.CLOSE)
|
|
|
+ worker.terminate()
|
|
|
+ self.assertNotEqual(worker._state, worker.CLOSE)
|
|
|
+
|
|
|
+ sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
|
|
|
+ try:
|
|
|
+ worker._state = worker.RUN
|
|
|
+ worker.stop(in_sighandler=True)
|
|
|
+ self.assertNotEqual(worker._state, worker.CLOSE)
|
|
|
+ worker.terminate(in_sighandler=True)
|
|
|
+ self.assertNotEqual(worker._state, worker.CLOSE)
|
|
|
+ finally:
|
|
|
+ worker.pool.signal_safe = sigsafe
|
|
|
+
|
|
|
+ def test_on_timer_error(self):
|
|
|
+ worker = WorkController(concurrency=1, loglevel=0)
|
|
|
+ worker.logger = MockLogger()
|
|
|
+
|
|
|
+ try:
|
|
|
+ raise KeyError("foo")
|
|
|
+ except KeyError:
|
|
|
+ exc_info = sys.exc_info()
|
|
|
+
|
|
|
+ worker.on_timer_error(exc_info)
|
|
|
+ logged = worker.logger.logged[0]
|
|
|
+ self.assertIn("foo", logged)
|
|
|
+
|
|
|
+ def test_on_timer_tick(self):
|
|
|
+ worker = WorkController(concurrency=1, loglevel=10)
|
|
|
+ worker.logger = MockLogger()
|
|
|
+ worker.timer_debug = worker.logger.debug
|
|
|
+
|
|
|
+ worker.on_timer_tick(30.0)
|
|
|
+ logged = worker.logger.logged[0]
|
|
|
+ self.assertIn("30.0", logged)
|
|
|
+
|
|
|
def test_process_task(self):
|
|
|
worker = self.worker
|
|
|
worker.pool = MockPool()
|
|
@@ -630,6 +736,18 @@ class test_WorkController(unittest.TestCase):
|
|
|
self.assertRaises(KeyboardInterrupt, worker.process_task, task)
|
|
|
self.assertEqual(worker._state, worker.TERMINATE)
|
|
|
|
|
|
+ def test_process_task_raise_SystemTerminate(self):
|
|
|
+ worker = self.worker
|
|
|
+ worker.pool = MockPool(raise_SystemTerminate=True)
|
|
|
+ backend = MockBackend()
|
|
|
+ m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
|
|
|
+ kwargs={})
|
|
|
+ task = TaskRequest.from_message(m, m.decode())
|
|
|
+ worker.components = []
|
|
|
+ worker._state = worker.RUN
|
|
|
+ self.assertRaises(SystemExit, worker.process_task, task)
|
|
|
+ self.assertEqual(worker._state, worker.TERMINATE)
|
|
|
+
|
|
|
def test_process_task_raise_regular(self):
|
|
|
worker = self.worker
|
|
|
worker.pool = MockPool(raise_regular=True)
|
|
@@ -640,6 +758,59 @@ class test_WorkController(unittest.TestCase):
|
|
|
worker.process_task(task)
|
|
|
worker.pool.stop()
|
|
|
|
|
|
+ def test_start_catches_base_exceptions(self):
|
|
|
+
|
|
|
+ class Component(object):
|
|
|
+ stopped = False
|
|
|
+ terminated = False
|
|
|
+
|
|
|
+ def __init__(self, exc):
|
|
|
+ self.exc = exc
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ raise self.exc
|
|
|
+
|
|
|
+ def terminate(self):
|
|
|
+ self.terminated = True
|
|
|
+
|
|
|
+ def stop(self):
|
|
|
+ self.stopped = True
|
|
|
+
|
|
|
+ worker1 = self.create_worker()
|
|
|
+ worker1.components = [Component(SystemTerminate())]
|
|
|
+ self.assertRaises(SystemExit, worker1.start)
|
|
|
+ self.assertTrue(worker1.components[0].terminated)
|
|
|
+
|
|
|
+ worker2 = self.create_worker()
|
|
|
+ worker2.components = [Component(SystemExit())]
|
|
|
+ self.assertRaises(SystemExit, worker2.start)
|
|
|
+ self.assertTrue(worker2.components[0].stopped)
|
|
|
+
|
|
|
+ def test_state_db(self):
|
|
|
+ from celery.worker import state
|
|
|
+ Persistent = state.Persistent
|
|
|
+
|
|
|
+ class MockPersistent(Persistent):
|
|
|
+
|
|
|
+ def _load(self):
|
|
|
+ return {}
|
|
|
+
|
|
|
+ state.Persistent = MockPersistent
|
|
|
+ try:
|
|
|
+ worker = self.create_worker(db="statefilename")
|
|
|
+ self.assertTrue(worker._finalize_db)
|
|
|
+ worker._finalize_db.cancel()
|
|
|
+ finally:
|
|
|
+ state.Persistent = Persistent
|
|
|
+
|
|
|
+ @skip("Issue #264")
|
|
|
+ def test_disable_rate_limits(self):
|
|
|
+ from celery.worker.buckets import FastQueue
|
|
|
+ worker = self.create_worker(disable_rate_limits=True)
|
|
|
+ self.assertIsInstance(worker.ready_queue, FastQueue)
|
|
|
+ self.assertIsNone(worker.mediator)
|
|
|
+ self.assertEqual(worker.ready_queue.put, worker.process_task)
|
|
|
+
|
|
|
def test_start__stop(self):
|
|
|
worker = self.worker
|
|
|
w1 = {"started": False}
|