test_worker.py 31 KB


  1. from __future__ import with_statement
  2. import socket
  3. import sys
  4. from collections import deque
  5. from datetime import datetime, timedelta
  6. from Queue import Empty
  7. from kombu.transport.base import Message
  8. from kombu.connection import BrokerConnection
  9. from mock import Mock, patch
  10. from celery import current_app
  11. from celery.concurrency.base import BasePool
  12. from celery.exceptions import SystemTerminate
  13. from celery.task import task as task_dec
  14. from celery.task import periodic_task as periodic_task_dec
  15. from celery.utils import gen_unique_id
  16. from celery.worker import WorkController
  17. from celery.worker.buckets import FastQueue
  18. from celery.worker.job import TaskRequest
  19. from celery.worker.consumer import Consumer as MainConsumer
  20. from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
  21. from celery.utils.serialization import pickle
  22. from celery.utils.timer2 import Timer
  23. from celery.tests.compat import catch_warnings
  24. from celery.tests.utils import unittest
  25. from celery.tests.utils import AppCase, skip
  26. class PlaceHolder(object):
  27. pass
  28. class MyKombuConsumer(MainConsumer):
  29. broadcast_consumer = Mock()
  30. task_consumer = Mock()
  31. def __init__(self, *args, **kwargs):
  32. kwargs.setdefault("pool", BasePool(2))
  33. super(MyKombuConsumer, self).__init__(*args, **kwargs)
  34. def restart_heartbeat(self):
  35. self.heart = None
  36. class MockNode(object):
  37. commands = []
  38. def handle_message(self, body, message):
  39. self.commands.append(body.pop("command", None))
  40. class MockEventDispatcher(object):
  41. sent = []
  42. closed = False
  43. flushed = False
  44. _outbound_buffer = []
  45. def send(self, event, *args, **kwargs):
  46. self.sent.append(event)
  47. def close(self):
  48. self.closed = True
  49. def flush(self):
  50. self.flushed = True
  51. class MockHeart(object):
  52. closed = False
  53. def stop(self):
  54. self.closed = True
  55. @task_dec()
  56. def foo_task(x, y, z, **kwargs):
  57. return x * y * z
  58. @periodic_task_dec(run_every=60)
  59. def foo_periodic_task():
  60. return "foo"
  61. def create_message(channel, **data):
  62. data.setdefault("id", gen_unique_id())
  63. channel.no_ack_consumers = set()
  64. return Message(channel, body=pickle.dumps(dict(**data)),
  65. content_type="application/x-python-serialize",
  66. content_encoding="binary",
  67. delivery_info={"consumer_tag": "mock"})
  68. class test_QoS(unittest.TestCase):
  69. class _QoS(QoS):
  70. def __init__(self, value):
  71. self.value = value
  72. QoS.__init__(self, None, value, None)
  73. def set(self, value):
  74. return value
  75. def test_qos_increment_decrement(self):
  76. qos = self._QoS(10)
  77. self.assertEqual(qos.increment(), 11)
  78. self.assertEqual(qos.increment(3), 14)
  79. self.assertEqual(qos.increment(-30), 14)
  80. self.assertEqual(qos.decrement(7), 7)
  81. self.assertEqual(qos.decrement(), 6)
  82. self.assertRaises(AssertionError, qos.decrement, 10)
  83. def test_qos_disabled_increment_decrement(self):
  84. qos = self._QoS(0)
  85. self.assertEqual(qos.increment(), 0)
  86. self.assertEqual(qos.increment(3), 0)
  87. self.assertEqual(qos.increment(-30), 0)
  88. self.assertEqual(qos.decrement(7), 0)
  89. self.assertEqual(qos.decrement(), 0)
  90. self.assertEqual(qos.decrement(10), 0)
  91. def test_qos_thread_safe(self):
  92. qos = self._QoS(10)
  93. def add():
  94. for i in xrange(1000):
  95. qos.increment()
  96. def sub():
  97. for i in xrange(1000):
  98. qos.decrement_eventually()
  99. def threaded(funs):
  100. from threading import Thread
  101. threads = [Thread(target=fun) for fun in funs]
  102. for thread in threads:
  103. thread.start()
  104. for thread in threads:
  105. thread.join()
  106. threaded([add, add])
  107. self.assertEqual(qos.value, 2010)
  108. qos.value = 1000
  109. threaded([add, sub]) # n = 2
  110. self.assertEqual(qos.value, 1000)
  111. def test_exceeds_short(self):
  112. qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1,
  113. current_app.log.get_default_logger())
  114. qos.update()
  115. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  116. qos.increment()
  117. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  118. qos.increment()
  119. self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
  120. qos.decrement()
  121. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  122. qos.decrement()
  123. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  124. def test_consumer_increment_decrement(self):
  125. consumer = Mock()
  126. qos = QoS(consumer, 10, current_app.log.get_default_logger())
  127. qos.update()
  128. self.assertEqual(qos.value, 10)
  129. self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)
  130. qos.decrement()
  131. self.assertEqual(qos.value, 9)
  132. self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
  133. qos.decrement_eventually()
  134. self.assertEqual(qos.value, 8)
  135. self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)
  136. # Does not decrement 0 value
  137. qos.value = 0
  138. qos.decrement()
  139. self.assertEqual(qos.value, 0)
  140. qos.increment()
  141. self.assertEqual(qos.value, 0)
  142. def test_consumer_decrement_eventually(self):
  143. consumer = Mock()
  144. qos = QoS(consumer, 10, current_app.log.get_default_logger())
  145. qos.decrement_eventually()
  146. self.assertEqual(qos.value, 9)
  147. qos.value = 0
  148. qos.decrement_eventually()
  149. self.assertEqual(qos.value, 0)
  150. def test_set(self):
  151. consumer = Mock()
  152. qos = QoS(consumer, 10, current_app.log.get_default_logger())
  153. qos.set(12)
  154. self.assertEqual(qos.prev, 12)
  155. qos.set(qos.prev)
class test_Consumer(unittest.TestCase):
    """Tests for the worker ``Consumer``, mostly driven via ``MyKombuConsumer``.

    Each test builds a consumer around the ready queue and ETA timer created
    in ``setUp``. It swaps collaborators (event dispatcher, pidbox node,
    connection, QoS) for mocks and then drives a single consumer method.
    """

    def setUp(self):
        # Fresh ready queue and ETA timer for every test; tearDown stops
        # the timer again.
        self.ready_queue = FastQueue()
        self.eta_schedule = Timer()
        self.logger = current_app.log.get_default_logger()
        self.logger.setLevel(0)

    def tearDown(self):
        self.eta_schedule.stop()

    def test_info(self):
        # ``info`` reports the QoS prefetch count. Its "broker" entry is
        # falsy until a connection is attached.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(l.task_consumer, 10, l.logger)
        info = l.info
        self.assertEqual(info["prefetch_count"], 10)
        self.assertFalse(info["broker"])

        l.connection = current_app.broker_connection()
        info = l.info
        self.assertTrue(info["broker"])

    def test_start_when_closed(self):
        # start() on a consumer already in the CLOSE state; the only check
        # is that the call returns.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = CLOSE
        l.start()

    def test_connection(self):
        # reset_connection() opens a BrokerConnection.
        # stop_consumers(close_connection=False) keeps it; plain
        # stop_consumers() drops both the connection and the task consumer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)

        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)

        l._state = RUN
        l.event_dispatcher = None
        l.stop_consumers(close_connection=False)
        self.assertTrue(l.connection)

        l._state = RUN
        l.stop_consumers()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        # Reconnect, then tear everything down via stop()/close_connection().
        l.reset_connection()
        self.assertIsInstance(l.connection, BrokerConnection)
        l.stop_consumers()
        l.stop()
        l.close_connection()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        # close_connection() on a freshly built consumer (reset_connection
        # never called) must not raise.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = RUN
        l.close_connection()

        # stop_consumers() closes an enabled event dispatcher and stops the
        # heart.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        eventer = l.event_dispatcher = Mock()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l._state = RUN
        l.stop_consumers()
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    def test_receive_message_unknown(self):
        # A body with no task field should emit an "unknown message" warning.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = Mock()
        m = create_message(backend, unknown={"baz": "!!!"})
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        with catch_warnings(record=True) as log:
            l.receive_message(m.decode(), m)
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])

    @patch("celery.utils.timer2.to_timestamp")
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        # Force the ETA -> timestamp conversion to overflow; the message must
        # still end up acknowledged.
        to_timestamp.side_effect = OverflowError()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        # NOTE(review): ``args=("2, 2")`` is a str, not a tuple; (2, 2) was
        # probably intended. The assertions below do not depend on args.
        m = create_message(Mock(), task=foo_task.name,
                           args=("2, 2"),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertTrue(m.acknowledged)
        self.assertTrue(to_timestamp.call_count)

    def test_receive_message_InvalidTaskError(self):
        # kwargs is a string rather than a mapping; the consumer should log
        # "Received invalid task message" via logger.error.
        logger = Mock()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)
        m = create_message(Mock(), task=foo_task.name,
                           args=(1, 2), kwargs="foobarbaz", id=1)
        l.event_dispatcher = Mock()
        l.pidbox_node = MockNode()

        l.receive_message(m.decode(), m)
        self.assertIn("Received invalid task message",
                      logger.error.call_args[0][0])

    def test_on_decode_error(self):
        # A message that failed to decode is acked and reported via
        # logger.critical.
        logger = Mock()
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, logger,
                            send_events=False)

        class MockMessage(Mock):
            content_type = "application/x-msgpack"
            content_encoding = "binary"
            body = "foobarbaz"

        message = MockMessage()
        l.on_decode_error(message, KeyError("foo"))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body",
                      logger.critical.call_args[0][0])

    def test_receieve_message(self):
        # A valid task message without an ETA goes to the ready queue as a
        # TaskRequest, and nothing is put on the ETA timer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        m = create_message(Mock(), task=foo_task.name,
                           args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)

        in_bucket = self.ready_queue.get_nowait()
        self.assertIsInstance(in_bucket, TaskRequest)
        self.assertEqual(in_bucket.task_name, foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.eta_schedule.empty())

    def test_start_connection_error(self):
        # consume_messages raises KeyError on its first call, and KeyError is
        # registered below as a connection error. Later calls raise
        # SyntaxError, which start() is expected to propagate.

        class MockConsumer(MainConsumer):
            iterations = 0

            def consume_messages(self):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError("foo")
                raise SyntaxError("bar")

        l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
                         send_events=False, pool=BasePool())
        l.connection_errors = (KeyError, )
        self.assertRaises(SyntaxError, l.start)
        # Stop the heart and priority timer left running by start()
        # (presumably created there -- confirm).
        l.heart.stop()
        l.priority_timer.stop()

    def test_consume_messages_ignores_socket_timeout(self):
        # drain_events clears the consumer's connection (presumably ending
        # the drain loop) and raises socket.timeout. consume_messages() must
        # not propagate the timeout.

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer, 10, l.logger)
        l.consume_messages()

    def test_consume_messages_when_socket_error(self):
        # While in RUN, a socket.error from drain_events propagates. Once the
        # state is CLOSE, consume_messages() returns without raising.

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error("foo")

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l._state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10, l.logger)
        with self.assertRaises(socket.error):
            l.consume_messages()

        l._state = CLOSE
        l.connection = c
        l.consume_messages()

    def test_consume_messages(self):
        # drain_events clears the connection, so each consume_messages() call
        # presumably returns after one drain. The task consumer must be
        # consumed, and QoS prefetch updates (10, then 9 after a decrement)
        # must reach task_consumer.qos().

        class Connection(current_app.broker_connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, 10, l.logger)

        l.consume_messages()
        l.consume_messages()
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        l.qos.decrement()
        l.consume_messages()
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_maybe_conn_error(self):
        # maybe_conn_error() swallows AttributeError and the configured
        # connection/channel errors, but re-raises anything else.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.connection_errors = (KeyError, )
        l.channel_errors = (SyntaxError, )
        l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
        l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
        l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
        self.assertRaises(IndexError, l.maybe_conn_error,
                          Mock(side_effect=IndexError("foo")))

    def test_apply_eta_task(self):
        # apply_eta_task() records the request in state.reserved_requests,
        # lowers the prefetch count by one and puts the request on the ready
        # queue.
        from celery.worker import state
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.qos = QoS(None, 10, l.logger)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.ready_queue.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        # A message with an ISO-8601 ETA string is scheduled on the ETA timer,
        # whose queue entries carry the TaskRequest as ``args[0]``. The QoS
        # must also have been pushed to the task consumer.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        m = create_message(Mock(), task=foo_task.name,
                           eta=datetime.now().isoformat(),
                           args=[2, 4, 8], kwargs={})

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
        l.event_dispatcher = Mock()
        l.enabled = False
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()

        items = [entry[2] for entry in self.eta_schedule.queue]
        found = 0
        for item in items:
            if item.args[0].task_name == foo_task.name:
                found = True
        self.assertTrue(found)
        self.assertTrue(l.task_consumer.qos.call_count)
        l.eta_schedule.stop()

    def test_on_control(self):
        # on_control() forwards (body, message) to
        # pidbox_node.handle_message(). KeyError and ValueError raised by the
        # node do not escape. By the end, reset_pidbox_node() must have been
        # called.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.pidbox_node = Mock()
        l.reset_pidbox_node = Mock()

        l.on_control("foo", "bar")
        l.pidbox_node.handle_message.assert_called_with("foo", "bar")

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = KeyError("foo")
        l.on_control("foo", "bar")
        l.pidbox_node.handle_message.assert_called_with("foo", "bar")

        l.pidbox_node = Mock()
        l.pidbox_node.handle_message.side_effect = ValueError("foo")
        l.on_control("foo", "bar")
        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
        l.reset_pidbox_node.assert_called_with()

    def test_revoke(self):
        # A message whose id is in the revoked set is not put on the ready
        # queue. NOTE: the id is added to the module-level ``revoked`` set
        # and never removed here.
        ready_queue = FastQueue()
        l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = Mock()
        id = gen_unique_id()
        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=id)
        from celery.worker.state import revoked
        revoked.add(id)

        l.receive_message(t.decode(), t)
        self.assertTrue(ready_queue.empty())

    def test_receieve_message_not_registered(self):
        # For an unregistered task name, receive_message() returns a falsy
        # value and nothing is queued or scheduled.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = Mock()
        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        self.assertFalse(l.receive_message(m.decode(), m))
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())

    def test_receieve_message_ack_raises(self):
        # A message without a task name triggers the "unknown message"
        # warning. Its ack() raises socket.error, a configured connection
        # error, which must be reported via logger.critical.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = Mock()
        l.connection_errors = (socket.error, )
        l.logger = Mock()
        m.ack = Mock()
        m.ack.side_effect = socket.error("foo")
        with catch_warnings(record=True) as log:
            self.assertFalse(l.receive_message(m.decode(), m))
            self.assertTrue(log)
            self.assertIn("unknown message", log[0].message.args[0])
        self.assertRaises(Empty, self.ready_queue.get_nowait)
        self.assertTrue(self.eta_schedule.empty())
        m.ack.assert_called_with()
        self.assertTrue(l.logger.critical.call_count)

    def test_receieve_message_eta(self):
        # First exercise reset_connection() with BROKER_CONNECTION_RETRY
        # temporarily disabled (restored in ``finally``). Then check that a
        # message with an ETA one day ahead lands on the ETA timer, not the
        # ready queue, as a TaskRequest.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.event_dispatcher = Mock()
        l.event_dispatcher._outbound_buffer = deque()
        backend = Mock()
        m = create_message(backend, task=foo_task.name,
                           args=[2, 4, 8], kwargs={},
                           eta=(datetime.now() +
                               timedelta(days=1)).isoformat())

        l.reset_connection()
        p = l.app.conf.BROKER_CONNECTION_RETRY
        l.app.conf.BROKER_CONNECTION_RETRY = False
        try:
            l.reset_connection()
        finally:
            l.app.conf.BROKER_CONNECTION_RETRY = p
        l.stop_consumers()
        l.event_dispatcher = Mock()
        l.receive_message(m.decode(), m)
        l.eta_schedule.stop()
        # Timer queue entries are (eta, priority, entry) triples.
        in_hold = self.eta_schedule.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, TaskRequest)
        self.assertEqual(task.task_name, foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        self.assertRaises(Empty, self.ready_queue.get_nowait)

    def test_reset_pidbox_node(self):
        # A connection error raised while closing the old pidbox channel must
        # not escape reset_pidbox_node(), and close() must still be called.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.pidbox_node = Mock()
        chan = l.pidbox_node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error("foo")
        l.connection_errors = (socket.error, )
        l.reset_pidbox_node()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        # With a green pool, reset_pidbox_node() spawns _green_pidbox_node via
        # pool.spawn_n.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.pool = Mock()
        l.pool.is_green = True
        l.reset_pidbox_node()
        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)

    def test__green_pidbox_node(self):
        # _green_pidbox_node() opens a connection via _open_connection,
        # listens with on_control as the callback and starts the broadcast
        # consumer. The fake drain_events sets l.connection to None; by the
        # end the opened connection must be closed.
        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
        l.pidbox_node = Mock()

        connections = []

        class Connection(object):

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.closed = False

            def channel(self):
                return Mock()

            def drain_events(self):
                self.obj.connection = None

            def close(self):
                self.closed = True

        l.connection = Mock()
        l._open_connection = lambda: Connection(obj=l)
        l._green_pidbox_node()

        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
        self.assertTrue(l.broadcast_consumer)
        l.broadcast_consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    def test_start__consume_messages(self):
        # First half: reset_connection raises KeyError once ``iterations`` is
        # at least 1. The assertions expect start() to run consume_messages
        # exactly once, syncing the fake QoS (prev == value), before the
        # KeyError escapes. Second half: a socket.error from consume_messages
        # must also propagate out of start().

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError("foo")

        init_callback = Mock()
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = BrokerConnection()
        l.iterations = 0

        def raises_KeyError(limit=None):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError("foo")

        l.consume_messages = raises_KeyError
        self.assertRaises(KeyError, l.start)
        self.assertTrue(init_callback.call_count)
        self.assertEqual(l.iterations, 1)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = BrokerConnection()
        l.consume_messages = Mock(side_effect=socket.error("foo"))
        self.assertRaises(socket.error, l.start)
        self.assertTrue(init_callback.call_count)
        self.assertTrue(l.consume_messages.call_count)

    def test_reset_connection_with_no_node(self):
        # A plain MainConsumer built without a pool (``pool`` is None) can
        # still reset its connection.
        l = MainConsumer(self.ready_queue, self.eta_schedule, self.logger)
        self.assertEqual(None, l.pool)
        l.reset_connection()
  548. class test_WorkController(AppCase):
  549. def setup(self):
  550. self.worker = self.create_worker()
  551. def create_worker(self, **kw):
  552. worker = WorkController(concurrency=1, loglevel=0, **kw)
  553. worker.logger = Mock()
  554. return worker
  555. @patch("celery.platforms.signals")
  556. @patch("celery.platforms.set_mp_process_title")
  557. def test_process_initializer(self, set_mp_process_title, _signals):
  558. from celery import Celery
  559. from celery import signals
  560. from celery.app import _tls
  561. from celery.worker import process_initializer
  562. from celery.worker import WORKER_SIGRESET, WORKER_SIGIGNORE
  563. def on_worker_process_init(**kwargs):
  564. on_worker_process_init.called = True
  565. on_worker_process_init.called = False
  566. signals.worker_process_init.connect(on_worker_process_init)
  567. app = Celery(loader=Mock(), set_as_current=False)
  568. process_initializer(app, "awesome.worker.com")
  569. self.assertIn((tuple(WORKER_SIGIGNORE), {}),
  570. _signals.ignore.call_args_list)
  571. self.assertIn((tuple(WORKER_SIGRESET), {}),
  572. _signals.reset.call_args_list)
  573. self.assertTrue(app.loader.init_worker.call_count)
  574. self.assertTrue(on_worker_process_init.called)
  575. self.assertIs(_tls.current_app, app)
  576. set_mp_process_title.assert_called_with("celeryd",
  577. hostname="awesome.worker.com")
  578. def test_with_rate_limits_disabled(self):
  579. worker = WorkController(concurrency=1, loglevel=0,
  580. disable_rate_limits=True)
  581. self.assertTrue(hasattr(worker.ready_queue, "put"))
  582. def test_attrs(self):
  583. worker = self.worker
  584. self.assertIsInstance(worker.scheduler, Timer)
  585. self.assertTrue(worker.scheduler)
  586. self.assertTrue(worker.pool)
  587. self.assertTrue(worker.consumer)
  588. self.assertTrue(worker.mediator)
  589. self.assertTrue(worker.components)
  590. def test_with_embedded_celerybeat(self):
  591. worker = WorkController(concurrency=1, loglevel=0,
  592. embed_clockservice=True)
  593. self.assertTrue(worker.beat)
  594. self.assertIn(worker.beat, worker.components)
  595. def test_with_autoscaler(self):
  596. worker = self.create_worker(autoscale=[10, 3], send_events=False,
  597. eta_scheduler_cls="celery.utils.timer2.Timer")
  598. self.assertTrue(worker.autoscaler)
  599. def test_dont_stop_or_terminate(self):
  600. worker = WorkController(concurrency=1, loglevel=0)
  601. worker.stop()
  602. self.assertNotEqual(worker._state, worker.CLOSE)
  603. worker.terminate()
  604. self.assertNotEqual(worker._state, worker.CLOSE)
  605. sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
  606. try:
  607. worker._state = worker.RUN
  608. worker.stop(in_sighandler=True)
  609. self.assertNotEqual(worker._state, worker.CLOSE)
  610. worker.terminate(in_sighandler=True)
  611. self.assertNotEqual(worker._state, worker.CLOSE)
  612. finally:
  613. worker.pool.signal_safe = sigsafe
  614. def test_on_timer_error(self):
  615. worker = WorkController(concurrency=1, loglevel=0)
  616. worker.logger = Mock()
  617. try:
  618. raise KeyError("foo")
  619. except KeyError:
  620. exc_info = sys.exc_info()
  621. worker.on_timer_error(exc_info)
  622. logged = worker.logger.error.call_args[0][0]
  623. self.assertIn("KeyError", logged)
  624. def test_on_timer_tick(self):
  625. worker = WorkController(concurrency=1, loglevel=10)
  626. worker.logger = Mock()
  627. worker.timer_debug = worker.logger.debug
  628. worker.on_timer_tick(30.0)
  629. logged = worker.logger.debug.call_args[0][0]
  630. self.assertIn("30.0", logged)
  631. def test_process_task(self):
  632. worker = self.worker
  633. worker.pool = Mock()
  634. backend = Mock()
  635. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  636. kwargs={})
  637. task = TaskRequest.from_message(m, m.decode())
  638. worker.process_task(task)
  639. self.assertEqual(worker.pool.apply_async.call_count, 1)
  640. worker.pool.stop()
  641. def test_process_task_raise_base(self):
  642. worker = self.worker
  643. worker.pool = Mock()
  644. worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
  645. backend = Mock()
  646. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  647. kwargs={})
  648. task = TaskRequest.from_message(m, m.decode())
  649. worker.components = []
  650. worker._state = worker.RUN
  651. self.assertRaises(KeyboardInterrupt, worker.process_task, task)
  652. self.assertEqual(worker._state, worker.TERMINATE)
  653. def test_process_task_raise_SystemTerminate(self):
  654. worker = self.worker
  655. worker.pool = Mock()
  656. worker.pool.apply_async.side_effect = SystemTerminate()
  657. backend = Mock()
  658. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  659. kwargs={})
  660. task = TaskRequest.from_message(m, m.decode())
  661. worker.components = []
  662. worker._state = worker.RUN
  663. self.assertRaises(SystemExit, worker.process_task, task)
  664. self.assertEqual(worker._state, worker.TERMINATE)
  665. def test_process_task_raise_regular(self):
  666. worker = self.worker
  667. worker.pool = Mock()
  668. worker.pool.apply_async.side_effect = KeyError("some exception")
  669. backend = Mock()
  670. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  671. kwargs={})
  672. task = TaskRequest.from_message(m, m.decode())
  673. worker.process_task(task)
  674. worker.pool.stop()
  675. def test_start_catches_base_exceptions(self):
  676. worker1 = self.create_worker()
  677. stc = Mock()
  678. stc.start.side_effect = SystemTerminate()
  679. worker1.components = [stc]
  680. self.assertRaises(SystemExit, worker1.start)
  681. self.assertTrue(stc.terminate.call_count)
  682. worker2 = self.create_worker()
  683. sec = Mock()
  684. sec.start.side_effect = SystemExit()
  685. sec.terminate = None
  686. worker2.components = [sec]
  687. self.assertRaises(SystemExit, worker2.start)
  688. self.assertTrue(sec.stop.call_count)
  689. def test_state_db(self):
  690. from celery.worker import state
  691. Persistent = state.Persistent
  692. state.Persistent = Mock()
  693. try:
  694. worker = self.create_worker(db="statefilename")
  695. self.assertTrue(worker._finalize_db)
  696. worker._finalize_db.cancel()
  697. finally:
  698. state.Persistent = Persistent
  699. @skip("Issue #264")
  700. def test_disable_rate_limits(self):
  701. from celery.worker.buckets import FastQueue
  702. worker = self.create_worker(disable_rate_limits=True)
  703. self.assertIsInstance(worker.ready_queue, FastQueue)
  704. self.assertIsNone(worker.mediator)
  705. self.assertEqual(worker.ready_queue.put, worker.process_task)
  706. def test_start__stop(self):
  707. worker = self.worker
  708. worker.components = [Mock(), Mock(), Mock(), Mock()]
  709. worker.start()
  710. for w in worker.components:
  711. self.assertTrue(w.start.call_count)
  712. worker.stop()
  713. for component in worker.components:
  714. self.assertTrue(w.stop.call_count)
  715. def test_start__terminate(self):
  716. worker = self.worker
  717. worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
  718. for component in worker.components[:3]:
  719. component.terminate = None
  720. worker.start()
  721. for w in worker.components[:3]:
  722. self.assertTrue(w.start.call_count)
  723. self.assertTrue(worker._running, len(worker.components))
  724. self.assertEqual(worker._state, RUN)
  725. worker.terminate()
  726. for component in worker.components[:3]:
  727. self.assertTrue(component.stop.call_count)
  728. self.assertTrue(worker.components[4].terminate.call_count)