test_worker.py 27 KB


  1. import socket
  2. import sys
  3. from datetime import datetime, timedelta
  4. from Queue import Empty
  5. from kombu.transport.base import Message
  6. from kombu.connection import BrokerConnection
  7. from celery.utils.timer2 import Timer
  8. from celery.app import app_or_default
  9. from celery.concurrency.base import BasePool
  10. from celery.exceptions import SystemTerminate
  11. from celery.task import task as task_dec
  12. from celery.task import periodic_task as periodic_task_dec
  13. from celery.utils import gen_unique_id
  14. from celery.worker import WorkController
  15. from celery.worker.buckets import FastQueue
  16. from celery.worker.job import TaskRequest
  17. from celery.worker.consumer import Consumer as MainConsumer
  18. from celery.worker.consumer import QoS, RUN
  19. from celery.utils.serialization import pickle
  20. from celery.tests.compat import catch_warnings
  21. from celery.tests.utils import unittest
  22. from celery.tests.utils import execute_context, skip
  23. class MockConsumer(object):
  24. class Channel(object):
  25. def close(self):
  26. pass
  27. def register_callback(self, cb):
  28. pass
  29. def consume(self):
  30. pass
  31. @property
  32. def channel(self):
  33. return self.Channel()
  34. class PlaceHolder(object):
  35. pass
  36. class MyKombuConsumer(MainConsumer):
  37. broadcast_consumer = MockConsumer()
  38. task_consumer = MockConsumer()
  39. def restart_heartbeat(self):
  40. self.heart = None
  41. class MockNode(object):
  42. commands = []
  43. def handle_message(self, body, message):
  44. self.commands.append(body.pop("command", None))
  45. class MockEventDispatcher(object):
  46. sent = []
  47. closed = False
  48. flushed = False
  49. def send(self, event, *args, **kwargs):
  50. self.sent.append(event)
  51. def close(self):
  52. self.closed = True
  53. def flush(self):
  54. self.flushed = True
  55. class MockHeart(object):
  56. closed = False
  57. def stop(self):
  58. self.closed = True
  59. @task_dec()
  60. def foo_task(x, y, z, **kwargs):
  61. return x * y * z
  62. @periodic_task_dec(run_every=60)
  63. def foo_periodic_task():
  64. return "foo"
  65. class MockLogger(object):
  66. def __init__(self):
  67. self.logged = []
  68. def critical(self, msg, *args, **kwargs):
  69. self.logged.append(msg)
  70. def info(self, msg, *args, **kwargs):
  71. self.logged.append(msg)
  72. def error(self, msg, *args, **kwargs):
  73. self.logged.append(msg)
  74. def debug(self, msg, *args, **kwargs):
  75. self.logged.append(msg)
  76. class MockBackend(object):
  77. _acked = False
  78. def basic_ack(self, delivery_tag):
  79. self._acked = True
  80. class MockPool(BasePool):
  81. _terminated = False
  82. _stopped = False
  83. def __init__(self, *args, **kwargs):
  84. self.raise_regular = kwargs.get("raise_regular", False)
  85. self.raise_base = kwargs.get("raise_base", False)
  86. self.raise_SystemTerminate = kwargs.get("raise_SystemTerminate",
  87. False)
  88. def apply_async(self, *args, **kwargs):
  89. if self.raise_regular:
  90. raise KeyError("some exception")
  91. if self.raise_base:
  92. raise KeyboardInterrupt("Ctrl+c")
  93. if self.raise_SystemTerminate:
  94. raise SystemTerminate()
  95. def start(self):
  96. pass
  97. def stop(self):
  98. self._stopped = True
  99. return True
  100. def terminate(self):
  101. self._terminated = True
  102. self.stop()
  103. class MockController(object):
  104. def __init__(self, w, *args, **kwargs):
  105. self._w = w
  106. self._stopped = False
  107. def start(self):
  108. self._w["started"] = True
  109. self._stopped = False
  110. def stop(self):
  111. self._stopped = True
  112. def create_message(backend, **data):
  113. data.setdefault("id", gen_unique_id())
  114. return Message(backend, body=pickle.dumps(dict(**data)),
  115. content_type="application/x-python-serialize",
  116. content_encoding="binary")
class test_QoS(unittest.TestCase):
    """Tests for the consumer QoS (prefetch count) helper."""

    class MockConsumer(object):
        # Records the prefetch_count last passed to qos().
        prefetch_count = 0

        def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
            self.prefetch_count = prefetch_count

    def test_increment_decrement(self):
        # update() pushes the value to the consumer; decrement() changes the
        # value and pushes it; decrement_eventually() changes the value only.
        consumer = self.MockConsumer()
        qos = QoS(consumer, 10, app_or_default().log.get_default_logger())
        qos.update()
        self.assertEqual(int(qos.value), 10)
        self.assertEqual(consumer.prefetch_count, 10)
        qos.decrement()
        self.assertEqual(int(qos.value), 9)
        self.assertEqual(consumer.prefetch_count, 9)
        qos.decrement_eventually()
        self.assertEqual(int(qos.value), 8)
        # The consumer still sees the previous value.
        self.assertEqual(consumer.prefetch_count, 9)

        # Does not decrement 0 value (nor increment it, per the last assert).
        qos.value._value = 0
        qos.decrement()
        self.assertEqual(int(qos.value), 0)
        qos.increment()
        self.assertEqual(int(qos.value), 0)
class test_Consumer(unittest.TestCase):
    """Tests for the worker's main consumer (celery.worker.consumer.Consumer).

    Each test builds a consumer over a fresh FastQueue (ready queue) and
    Timer (ETA schedule). The timer is stopped again in tearDown.
    """

    def setUp(self):
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = app_or_default().log.get_default_logger()
        self.logger.setLevel(0)

    def tearDown(self):
        self.eta_schedule.stop()

    def test_info(self):
        # info reports the QoS prefetch count; "broker" is falsy until a
        # connection is attached.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = app_or_default().broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_connection(self):
        # reset_connection() creates a BrokerConnection.
        # stop_consumers(close=False) keeps it; stop_consumers() and
        # close_connection() clear the connection and the task consumer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()

        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() on a fresh consumer must not raise.
        # stop_consumers() closes the event dispatcher and the heart.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = RUN
        l.close_connection()

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        eventer = l.event_dispatcher = MockEventDispatcher()
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.closed)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A body with neither a task nor a command emits an
        # "unknown message" warning.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        def with_catch_warnings(log):
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

        context = catch_warnings(record=True)
        execute_context(context, with_catch_warnings)

    def test_receive_message_eta_OverflowError(self):
        # If converting the ETA raises OverflowError, the message is still
        # acknowledged. consumer.to_timestamp is patched for the duration.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        called = [False]

        def to_timestamp(d):
            called[0] = True
            raise OverflowError()

        # NOTE(review): args=("2, 2") is a str, not a tuple -- probably meant
        # (2, 2). Harmless here because only the ETA path is exercised.
        m = create_message(backend, task=foo_task.name,
                           args=("2, 2"),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        from celery.worker import consumer
        prev, consumer.to_timestamp = consumer.to_timestamp, to_timestamp
        try:
            l.receive_message(m.decode(), m)
            self.assertTrue(m.acknowledged)
            self.assertTrue(called[0])
        finally:
            consumer.to_timestamp = prev

    def test_receive_message_InvalidTaskError(self):
        # A non-mapping kwargs payload is logged as "Invalid task ignored".
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = MockEventDispatcher()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Invalid task ignored", logger.logged[0])

    def test_on_decode_error(self):
        # A message that cannot be decoded is acked and the error is logged.
        logger = MockLogger()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)

        class MockMessage(object):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"
            acked = False

            def ack(self):
                self.acked = True

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.acked)
        self.assertIn("Can't decode message body", logger.logged[0])

    def test_receieve_message(self):
        # A task message without an ETA goes straight to the ready queue as a
        # TaskRequest; the ETA schedule stays empty.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        # The first consume_messages() raises KeyError, which is configured as
        # a connection error. The second raises SyntaxError, which escapes
        # start().

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
                         send_events=False)
        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        l.heart.stop()

    def test_consume_messages(self):
        # consume_messages() starts both consumers and applies the QoS
        # prefetch count. drain_events clears l.connection, so each call
        # returns after one drain.
        app = app_or_default()

        class Connection(app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        class Consumer(object):
            consuming = False
            prefetch_count = 0

            def consume(self):
                self.consuming = True

            def qos(self, prefetch_size=0, prefetch_count=0,
                    apply_global=False):
                self.prefetch_count = prefetch_count

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Consumer()
        l.broadcast_consumer = Consumer()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consuming)
        self.assertTrue(l.broadcast_consumer.consuming)
        self.assertEqual(l.task_consumer.prefetch_count, 10)

        l.qos.decrement()
        l.consume_messages()
        self.assertEqual(l.task_consumer.prefetch_count, 9)

    def test_maybe_conn_error(self):
        # Connection errors, channel errors and AttributeError are swallowed;
        # any other exception propagates.

        def raises(error):

            def fun():
                raise error

            return fun

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(raises(AttributeError("foo")))
        l.maybe_conn_error(raises(KeyError("foo")))
        l.maybe_conn_error(raises(SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error,
                raises(IndexError("foo")))

    def test_apply_eta_task(self):
        # apply_eta_task() registers the task as reserved, lowers qos.next by
        # one and puts the task on the ready queue.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.next
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.next, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-format ETA is scheduled on the ETA timer, and
        # the task consumer's qos() is called.

        class MockConsumer(object):
            prefetch_count_incremented = False

            def qos(self, **kwargs):
                self.prefetch_count_incremented = True

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = MockConsumer()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        # Each timer queue entry is a tuple whose third item is the scheduled
        # entry; its first arg is the TaskRequest.
        items = [entry[2] for entry in self.eta_schedule.queue]
        # NOTE(review): found is initialised to 0 but set to True -- a
        # bool/any() would be clearer.
        found = 0
        for item in items:
            if item.args[0].task_name == foo_task.name:
                found = True
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.prefetch_count_incremented)
        # NOTE(review): second stop() of the same timer (also done above and
        # in tearDown).
        l.eta_schedule.stop()

    def test_revoke(self):
        # A message whose id is already revoked is not queued.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=id)
        from celery.worker.state import revoked
        revoked.add(id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # An unregistered task name returns a falsy result and queues nothing.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = MockBackend()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = MockEventDispatcher()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_eta(self):
        # reset_connection() works with BROKER_CONNECTION_RETRY disabled, and
        # stop_consumers() flushes the event dispatcher. A task with an ETA
        # one day ahead lands on the ETA schedule, not the ready queue.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        dispatcher = l.event_dispatcher = MockEventDispatcher()
        backend = MockBackend()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                               timedelta(days=1)).isoformat())

        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        self.assertTrue(dispatcher.flushed)
        l.event_dispatcher = MockEventDispatcher()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_start__consume_messages(self):
        # start() runs init_callback and updates QoS when prev != next. A
        # KeyError from reset_connection (second loop) or a socket.error from
        # consume_messages escapes start().

        class _QoS(object):
            prev = 3
            next = 4

            def update(self):
                self.prev = self.next

        class _Consumer(MyKombuConsumer):
            iterations = 0
            wait_method = None

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

        called_back = [False]

        def init_callback(consumer):
            called_back[0] = True

        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.next:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.next)

        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = MockConsumer()
        l.broadcast_consumer = MockConsumer()
        l.connection = BrokerConnection()

        def raises_socket_error(limit=None):
            l.iterations = 1
            raise socket.error("foo")

        l.consume_messages = raises_socket_error
        self.assertRaises(socket.error, l.start)
        self.assertTrue(called_back[0])
        self.assertEqual(l.iterations, 1)
  452. class test_WorkController(unittest.TestCase):
  453. def create_worker(self, **kw):
  454. worker = WorkController(concurrency=1, loglevel=0, **kw)
  455. worker.logger = MockLogger()
  456. return worker
  457. def setUp(self):
  458. self.worker = self.create_worker()
  459. def test_process_initializer(self):
  460. from celery import Celery
  461. from celery import platforms
  462. from celery import signals
  463. from celery.app import _tls
  464. from celery.worker import process_initializer
  465. from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
  466. ignored_signals = []
  467. reset_signals = []
  468. worker_init = [False]
  469. default_app = app_or_default()
  470. app = Celery(loader="default")
  471. class Loader(object):
  472. def init_worker(self):
  473. worker_init[0] = True
  474. app.loader = Loader()
  475. def on_worker_process_init(**kwargs):
  476. on_worker_process_init.called = True
  477. on_worker_process_init.called = False
  478. signals.worker_process_init.connect(on_worker_process_init)
  479. def set_mp_process_title(title, hostname=None):
  480. set_mp_process_title.called = (title, hostname)
  481. set_mp_process_title.called = ()
  482. pignore_signal = platforms.ignore_signal
  483. preset_signal = platforms.reset_signal
  484. psetproctitle = platforms.set_mp_process_title
  485. platforms.ignore_signal = lambda sig: ignored_signals.append(sig)
  486. platforms.reset_signal = lambda sig: reset_signals.append(sig)
  487. platforms.set_mp_process_title = set_mp_process_title
  488. try:
  489. process_initializer(app, "awesome.worker.com")
  490. self.assertItemsEqual(ignored_signals, WORKER_SIGIGNORE)
  491. self.assertItemsEqual(reset_signals, WORKER_SIGRESET)
  492. self.assertTrue(worker_init[0])
  493. self.assertTrue(on_worker_process_init.called)
  494. self.assertIs(_tls.current_app, app)
  495. self.assertTupleEqual(set_mp_process_title.called,
  496. ("celeryd", "awesome.worker.com"))
  497. finally:
  498. platforms.ignore_signal = pignore_signal
  499. platforms.reset_signal = preset_signal
  500. platforms.set_mp_process_title = psetproctitle
  501. default_app.set_current()
  502. def test_with_rate_limits_disabled(self):
  503. worker = WorkController(concurrency=1, loglevel=0,
  504. disable_rate_limits=True)
  505. self.assertTrue(hasattr(worker.ready_queue, "put"))
  506. def test_attrs(self):
  507. worker = self.worker
  508. self.assertIsInstance(worker.scheduler, Timer)
  509. self.assertTrue(worker.scheduler)
  510. self.assertTrue(worker.pool)
  511. self.assertTrue(worker.consumer)
  512. self.assertTrue(worker.mediator)
  513. self.assertTrue(worker.components)
  514. def test_with_embedded_celerybeat(self):
  515. worker = WorkController(concurrency=1, loglevel=0,
  516. embed_clockservice=True)
  517. self.assertTrue(worker.beat)
  518. self.assertIn(worker.beat, worker.components)
  519. def test_with_autoscaler(self):
  520. worker = self.create_worker(autoscale=[10, 3], send_events=False,
  521. eta_scheduler_cls="celery.utils.timer2.Timer")
  522. self.assertTrue(worker.autoscaler)
  523. def test_dont_stop_or_terminate(self):
  524. worker = WorkController(concurrency=1, loglevel=0)
  525. worker.stop()
  526. self.assertNotEqual(worker._state, worker.CLOSE)
  527. worker.terminate()
  528. self.assertNotEqual(worker._state, worker.CLOSE)
  529. sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
  530. try:
  531. worker._state = worker.RUN
  532. worker.stop(in_sighandler=True)
  533. self.assertNotEqual(worker._state, worker.CLOSE)
  534. worker.terminate(in_sighandler=True)
  535. self.assertNotEqual(worker._state, worker.CLOSE)
  536. finally:
  537. worker.pool.signal_safe = sigsafe
  538. def test_on_timer_error(self):
  539. worker = WorkController(concurrency=1, loglevel=0)
  540. worker.logger = MockLogger()
  541. try:
  542. raise KeyError("foo")
  543. except KeyError:
  544. exc_info = sys.exc_info()
  545. worker.on_timer_error(exc_info)
  546. logged = worker.logger.logged[0]
  547. self.assertIn("KeyError", logged)
  548. def test_on_timer_tick(self):
  549. worker = WorkController(concurrency=1, loglevel=10)
  550. worker.logger = MockLogger()
  551. worker.timer_debug = worker.logger.debug
  552. worker.on_timer_tick(30.0)
  553. logged = worker.logger.logged[0]
  554. self.assertIn("30.0", logged)
  555. def test_process_task(self):
  556. worker = self.worker
  557. worker.pool = MockPool()
  558. backend = MockBackend()
  559. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  560. kwargs={})
  561. task = TaskRequest.from_message(m, m.decode())
  562. worker.process_task(task)
  563. worker.pool.stop()
  564. def test_process_task_raise_base(self):
  565. worker = self.worker
  566. worker.pool = MockPool(raise_base=True)
  567. backend = MockBackend()
  568. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  569. kwargs={})
  570. task = TaskRequest.from_message(m, m.decode())
  571. worker.components = []
  572. worker._state = worker.RUN
  573. self.assertRaises(KeyboardInterrupt, worker.process_task, task)
  574. self.assertEqual(worker._state, worker.TERMINATE)
  575. def test_process_task_raise_SystemTerminate(self):
  576. worker = self.worker
  577. worker.pool = MockPool(raise_SystemTerminate=True)
  578. backend = MockBackend()
  579. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  580. kwargs={})
  581. task = TaskRequest.from_message(m, m.decode())
  582. worker.components = []
  583. worker._state = worker.RUN
  584. self.assertRaises(SystemExit, worker.process_task, task)
  585. self.assertEqual(worker._state, worker.TERMINATE)
  586. def test_process_task_raise_regular(self):
  587. worker = self.worker
  588. worker.pool = MockPool(raise_regular=True)
  589. backend = MockBackend()
  590. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  591. kwargs={})
  592. task = TaskRequest.from_message(m, m.decode())
  593. worker.process_task(task)
  594. worker.pool.stop()
  595. def test_start_catches_base_exceptions(self):
  596. class Component(object):
  597. stopped = False
  598. terminated = False
  599. def __init__(self, exc):
  600. self.exc = exc
  601. def start(self):
  602. raise self.exc
  603. def terminate(self):
  604. self.terminated = True
  605. def stop(self):
  606. self.stopped = True
  607. worker1 = self.create_worker()
  608. worker1.components = [Component(SystemTerminate())]
  609. self.assertRaises(SystemExit, worker1.start)
  610. self.assertTrue(worker1.components[0].terminated)
  611. worker2 = self.create_worker()
  612. worker2.components = [Component(SystemExit())]
  613. self.assertRaises(SystemExit, worker2.start)
  614. self.assertTrue(worker2.components[0].stopped)
  615. def test_state_db(self):
  616. from celery.worker import state
  617. Persistent = state.Persistent
  618. class MockPersistent(Persistent):
  619. def _load(self):
  620. return {}
  621. state.Persistent = MockPersistent
  622. try:
  623. worker = self.create_worker(db="statefilename")
  624. self.assertTrue(worker._finalize_db)
  625. worker._finalize_db.cancel()
  626. finally:
  627. state.Persistent = Persistent
  628. @skip("Issue #264")
  629. def test_disable_rate_limits(self):
  630. from celery.worker.buckets import FastQueue
  631. worker = self.create_worker(disable_rate_limits=True)
  632. self.assertIsInstance(worker.ready_queue, FastQueue)
  633. self.assertIsNone(worker.mediator)
  634. self.assertEqual(worker.ready_queue.put, worker.process_task)
  635. def test_start__stop(self):
  636. worker = self.worker
  637. w1 = {"started": False}
  638. w2 = {"started": False}
  639. w3 = {"started": False}
  640. w4 = {"started": False}
  641. worker.components = [MockController(w1), MockController(w2),
  642. MockController(w3), MockController(w4)]
  643. worker.start()
  644. for w in (w1, w2, w3, w4):
  645. self.assertTrue(w["started"])
  646. self.assertTrue(worker._running, len(worker.components))
  647. worker.stop()
  648. for component in worker.components:
  649. self.assertTrue(component._stopped)
  650. def test_start__terminate(self):
  651. worker = self.worker
  652. w1 = {"started": False}
  653. w2 = {"started": False}
  654. w3 = {"started": False}
  655. w4 = {"started": False}
  656. worker.components = [MockController(w1), MockController(w2),
  657. MockController(w3), MockController(w4),
  658. MockPool()]
  659. worker.start()
  660. for w in (w1, w2, w3, w4):
  661. self.assertTrue(w["started"])
  662. self.assertTrue(worker._running, len(worker.components))
  663. self.assertEqual(worker._state, RUN)
  664. worker.terminate()
  665. for component in worker.components:
  666. self.assertTrue(component._stopped)
  667. self.assertTrue(worker.components[4]._terminated)