|  | @@ -1,5 +1,5 @@
 | 
	
		
			
				|  |  |  import socket
 | 
	
		
			
				|  |  | -from celery.tests.utils import unittest
 | 
	
		
			
				|  |  | +import sys
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from datetime import datetime, timedelta
 | 
	
		
			
				|  |  |  from Queue import Empty
 | 
	
	
		
			
				|  | @@ -10,6 +10,7 @@ from celery.utils.timer2 import Timer
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery.app import app_or_default
 | 
	
		
			
				|  |  |  from celery.concurrency.base import BasePool
 | 
	
		
			
				|  |  | +from celery.exceptions import SystemTerminate
 | 
	
		
			
				|  |  |  from celery.task import task as task_dec
 | 
	
		
			
				|  |  |  from celery.task import periodic_task as periodic_task_dec
 | 
	
		
			
				|  |  |  from celery.utils import gen_unique_id
 | 
	
	
		
			
				|  | @@ -21,7 +22,8 @@ from celery.worker.consumer import QoS, RUN
 | 
	
		
			
				|  |  |  from celery.utils.serialization import pickle
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery.tests.compat import catch_warnings
 | 
	
		
			
				|  |  | -from celery.tests.utils import execute_context
 | 
	
		
			
				|  |  | +from celery.tests.utils import unittest
 | 
	
		
			
				|  |  | +from celery.tests.utils import execute_context, skip
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class MockConsumer(object):
 | 
	
	
		
			
				|  | @@ -125,12 +127,16 @@ class MockPool(BasePool):
 | 
	
		
			
				|  |  |      def __init__(self, *args, **kwargs):
 | 
	
		
			
				|  |  |          self.raise_regular = kwargs.get("raise_regular", False)
 | 
	
		
			
				|  |  |          self.raise_base = kwargs.get("raise_base", False)
 | 
	
		
			
				|  |  | +        self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
 | 
	
		
			
				|  |  | +                                                False)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def apply_async(self, *args, **kwargs):
 | 
	
		
			
				|  |  |          if self.raise_regular:
 | 
	
		
			
				|  |  |              raise KeyError("some exception")
 | 
	
		
			
				|  |  |          if self.raise_base:
 | 
	
		
			
				|  |  |              raise KeyboardInterrupt("Ctrl+c")
 | 
	
		
			
				|  |  | +        if self.raise_SystemTerminate:
 | 
	
		
			
				|  |  | +            raise SystemTerminate()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def start(self):
 | 
	
		
			
				|  |  |          pass
 | 
	
	
		
			
				|  | @@ -584,9 +590,65 @@ class test_Consumer(unittest.TestCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class test_WorkController(unittest.TestCase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def create_worker(self, **kw):
 | 
	
		
			
				|  |  | +        worker = WorkController(concurrency=1, loglevel=0, **kw)
 | 
	
		
			
				|  |  | +        worker.logger = MockLogger()
 | 
	
		
			
				|  |  | +        return worker
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def setUp(self):
 | 
	
		
			
				|  |  | -        self.worker = WorkController(concurrency=1, loglevel=0)
 | 
	
		
			
				|  |  | -        self.worker.logger = MockLogger()
 | 
	
		
			
				|  |  | +        self.worker = self.create_worker()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_process_initializer(self):
 | 
	
		
			
				|  |  | +        from celery import Celery
 | 
	
		
			
				|  |  | +        from celery import platforms
 | 
	
		
			
				|  |  | +        from celery import signals
 | 
	
		
			
				|  |  | +        from celery.app import _tls
 | 
	
		
			
				|  |  | +        from celery.worker import process_initializer
 | 
	
		
			
				|  |  | +        from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ignored_signals = []
 | 
	
		
			
				|  |  | +        reset_signals = []
 | 
	
		
			
				|  |  | +        worker_init = [False]
 | 
	
		
			
				|  |  | +        default_app = app_or_default()
 | 
	
		
			
				|  |  | +        app = Celery(loader="default")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Loader(object):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def init_worker(self):
 | 
	
		
			
				|  |  | +                worker_init[0] = True
 | 
	
		
			
				|  |  | +        app.loader = Loader()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def on_worker_process_init(**kwargs):
 | 
	
		
			
				|  |  | +            on_worker_process_init.called = True
 | 
	
		
			
				|  |  | +        on_worker_process_init.called = False
 | 
	
		
			
				|  |  | +        signals.worker_process_init.connect(on_worker_process_init)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def set_mp_process_title(title, hostname=None):
 | 
	
		
			
				|  |  | +            set_mp_process_title.called = (title, hostname)
 | 
	
		
			
				|  |  | +        set_mp_process_title.called = ()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        pignore_signal = platforms.ignore_signal
 | 
	
		
			
				|  |  | +        preset_signal = platforms.reset_signal
 | 
	
		
			
				|  |  | +        psetproctitle = platforms.set_mp_process_title
 | 
	
		
			
				|  |  | +        platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
 | 
	
		
			
				|  |  | +        platforms.reset_signal = lambda sig: reset_signals.append(sig)
 | 
	
		
			
				|  |  | +        platforms.set_mp_process_title = set_mp_process_title
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            process_initializer(app, "awesome.worker.com")
 | 
	
		
			
				|  |  | +            self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
 | 
	
		
			
				|  |  | +            self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
 | 
	
		
			
				|  |  | +            self.assertTrue(worker_init[0])
 | 
	
		
			
				|  |  | +            self.assertTrue(on_worker_process_init.called)
 | 
	
		
			
				|  |  | +            self.assertIs(_tls.current_app, app)
 | 
	
		
			
				|  |  | +            self.assertTupleEqual(set_mp_process_title.called,
 | 
	
		
			
				|  |  | +                                  ("celeryd", "awesome.worker.com"))
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            platforms.ignore_signal = pignore_signal
 | 
	
		
			
				|  |  | +            platforms.reset_signal = preset_signal
 | 
	
		
			
				|  |  | +            platforms.set_mp_process_title = psetproctitle
 | 
	
		
			
				|  |  | +            default_app.set_current()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def test_with_rate_limits_disabled(self):
 | 
	
		
			
				|  |  |          worker = WorkController(concurrency=1, loglevel=0,
 | 
	
	
		
			
				|  | @@ -608,6 +670,50 @@ class test_WorkController(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertTrue(worker.beat)
 | 
	
		
			
				|  |  |          self.assertIn(worker.beat, worker.components)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_with_autoscaler(self):
 | 
	
		
			
				|  |  | +        worker = self.create_worker(autoscale=[10, 3], send_events=False,
 | 
	
		
			
				|  |  | +                                eta_scheduler_cls="celery.utils.timer2.Timer")
 | 
	
		
			
				|  |  | +        self.assertTrue(worker.autoscaler)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_dont_stop_or_terminate(self):
 | 
	
		
			
				|  |  | +        worker = WorkController(concurrency=1, loglevel=0)
 | 
	
		
			
				|  |  | +        worker.stop()
 | 
	
		
			
				|  |  | +        self.assertNotEqual(worker._state, worker.CLOSE)
 | 
	
		
			
				|  |  | +        worker.terminate()
 | 
	
		
			
				|  |  | +        self.assertNotEqual(worker._state, worker.CLOSE)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            worker._state = worker.RUN
 | 
	
		
			
				|  |  | +            worker.stop(in_sighandler=True)
 | 
	
		
			
				|  |  | +            self.assertNotEqual(worker._state, worker.CLOSE)
 | 
	
		
			
				|  |  | +            worker.terminate(in_sighandler=True)
 | 
	
		
			
				|  |  | +            self.assertNotEqual(worker._state, worker.CLOSE)
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            worker.pool.signal_safe = sigsafe
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_on_timer_error(self):
 | 
	
		
			
				|  |  | +        worker = WorkController(concurrency=1, loglevel=0)
 | 
	
		
			
				|  |  | +        worker.logger = MockLogger()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            raise KeyError("foo")
 | 
	
		
			
				|  |  | +        except KeyError:
 | 
	
		
			
				|  |  | +            exc_info = sys.exc_info()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        worker.on_timer_error(exc_info)
 | 
	
		
			
				|  |  | +        logged = worker.logger.logged[0]
 | 
	
		
			
				|  |  | +        self.assertIn("foo", logged)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_on_timer_tick(self):
 | 
	
		
			
				|  |  | +        worker = WorkController(concurrency=1, loglevel=10)
 | 
	
		
			
				|  |  | +        worker.logger = MockLogger()
 | 
	
		
			
				|  |  | +        worker.timer_debug = worker.logger.debug
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        worker.on_timer_tick(30.0)
 | 
	
		
			
				|  |  | +        logged = worker.logger.logged[0]
 | 
	
		
			
				|  |  | +        self.assertIn("30.0", logged)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_process_task(self):
 | 
	
		
			
				|  |  |          worker = self.worker
 | 
	
		
			
				|  |  |          worker.pool = MockPool()
 | 
	
	
		
			
				|  | @@ -630,6 +736,18 @@ class test_WorkController(unittest.TestCase):
 | 
	
		
			
				|  |  |          self.assertRaises(KeyboardInterrupt, worker.process_task, task)
 | 
	
		
			
				|  |  |          self.assertEqual(worker._state, worker.TERMINATE)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_process_task_raise_SystemTerminate(self):
 | 
	
		
			
				|  |  | +        worker = self.worker
 | 
	
		
			
				|  |  | +        worker.pool = MockPool(raise_SystemTerminate=True)
 | 
	
		
			
				|  |  | +        backend = MockBackend()
 | 
	
		
			
				|  |  | +        m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
 | 
	
		
			
				|  |  | +                           kwargs={})
 | 
	
		
			
				|  |  | +        task = TaskRequest.from_message(m, m.decode())
 | 
	
		
			
				|  |  | +        worker.components = []
 | 
	
		
			
				|  |  | +        worker._state = worker.RUN
 | 
	
		
			
				|  |  | +        self.assertRaises(SystemExit, worker.process_task, task)
 | 
	
		
			
				|  |  | +        self.assertEqual(worker._state, worker.TERMINATE)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_process_task_raise_regular(self):
 | 
	
		
			
				|  |  |          worker = self.worker
 | 
	
		
			
				|  |  |          worker.pool = MockPool(raise_regular=True)
 | 
	
	
		
			
				|  | @@ -640,6 +758,59 @@ class test_WorkController(unittest.TestCase):
 | 
	
		
			
				|  |  |          worker.process_task(task)
 | 
	
		
			
				|  |  |          worker.pool.stop()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def test_start_catches_base_exceptions(self):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class Component(object):
 | 
	
		
			
				|  |  | +            stopped = False
 | 
	
		
			
				|  |  | +            terminated = False
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def __init__(self, exc):
 | 
	
		
			
				|  |  | +                self.exc = exc
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def start(self):
 | 
	
		
			
				|  |  | +                raise self.exc
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def terminate(self):
 | 
	
		
			
				|  |  | +                self.terminated = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def stop(self):
 | 
	
		
			
				|  |  | +                self.stopped = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        worker1 = self.create_worker()
 | 
	
		
			
				|  |  | +        worker1.components = [Component(SystemTerminate())]
 | 
	
		
			
				|  |  | +        self.assertRaises(SystemExit, worker1.start)
 | 
	
		
			
				|  |  | +        self.assertTrue(worker1.components[0].terminated)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        worker2 = self.create_worker()
 | 
	
		
			
				|  |  | +        worker2.components = [Component(SystemExit())]
 | 
	
		
			
				|  |  | +        self.assertRaises(SystemExit, worker2.start)
 | 
	
		
			
				|  |  | +        self.assertTrue(worker2.components[0].stopped)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_state_db(self):
 | 
	
		
			
				|  |  | +        from celery.worker import state
 | 
	
		
			
				|  |  | +        Persistent = state.Persistent
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        class MockPersistent(Persistent):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def _load(self):
 | 
	
		
			
				|  |  | +                return {}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        state.Persistent = MockPersistent
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            worker = self.create_worker(db="statefilename")
 | 
	
		
			
				|  |  | +            self.assertTrue(worker._finalize_db)
 | 
	
		
			
				|  |  | +            worker._finalize_db.cancel()
 | 
	
		
			
				|  |  | +        finally:
 | 
	
		
			
				|  |  | +            state.Persistent = Persistent
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @skip("Issue #264")
 | 
	
		
			
				|  |  | +    def test_disable_rate_limits(self):
 | 
	
		
			
				|  |  | +        from celery.worker.buckets import FastQueue
 | 
	
		
			
				|  |  | +        worker = self.create_worker(disable_rate_limits=True)
 | 
	
		
			
				|  |  | +        self.assertIsInstance(worker.ready_queue, FastQueue)
 | 
	
		
			
				|  |  | +        self.assertIsNone(worker.mediator)
 | 
	
		
			
				|  |  | +        self.assertEqual(worker.ready_queue.put, worker.process_task)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def test_start__stop(self):
 | 
	
		
			
				|  |  |          worker = self.worker
 | 
	
		
			
				|  |  |          w1 = {"started": False}
 |