__init__.py 33 KB


  1. from __future__ import absolute_import
  2. from __future__ import with_statement
  3. import socket
  4. import sys
  5. from collections import deque
  6. from datetime import datetime, timedelta
  7. from Queue import Empty
  8. from kombu.transport.base import Message
  9. from kombu.connection import BrokerConnection
  10. from mock import Mock, patch
  11. from nose import SkipTest
  12. from celery import current_app
  13. from celery.app.defaults import DEFAULTS
  14. from celery.concurrency.base import BasePool
  15. from celery.datastructures import AttributeDict
  16. from celery.exceptions import SystemTerminate
  17. from celery.task import task as task_dec
  18. from celery.task import periodic_task as periodic_task_dec
  19. from celery.utils import uuid
  20. from celery.worker import WorkController
  21. from celery.worker.buckets import FastQueue
  22. from celery.worker.job import Request
  23. from celery.worker.consumer import Consumer as MainConsumer
  24. from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
  25. from celery.utils.serialization import pickle
  26. from celery.utils.timer2 import Timer
  27. from celery.tests.utils import AppCase, Case
  class PlaceHolder(object):
      """Bare ``object`` subclass that adds no attributes or methods."""
  class MyKombuConsumer(MainConsumer):
      """Consumer subclass used as the test subject in this module.

      Supplies ``BasePool(2)`` as the ``pool`` keyword when the caller
      gives none, and stubs out ``restart_heartbeat``.
      """

      # Class-level mocks: visible on every instance until an instance
      # attribute of the same name is assigned.
      broadcast_consumer = Mock()
      task_consumer = Mock()

      def __init__(self, *args, **kwargs):
          # The fallback pool is built unconditionally, then only used
          # when the caller did not pass ``pool`` explicitly.
          fallback_pool = BasePool(2)
          kwargs.setdefault("pool", fallback_pool)
          super(MyKombuConsumer, self).__init__(*args, **kwargs)

      def restart_heartbeat(self):
          # Just clear the heart attribute.
          self.heart = None
  class MockNode(object):
      """Stand-in node object that records the command of each message.

      Every call to :meth:`handle_message` appends the ``"command"`` entry
      of the message body (or ``None`` when absent) to ``self.commands``.
      """

      # Retained so ``MockNode.commands`` still exists as a class attribute;
      # instances no longer append to this shared list (see ``__init__``).
      commands = []

      def __init__(self):
          # Per-instance list: recorded commands must not leak between
          # instances created by different tests.
          self.commands = []

      def handle_message(self, body, message):
          """Record and remove the ``"command"`` key of *body*.

          :param body: mutable mapping; its ``"command"`` key is popped.
          :param message: accepted for interface compatibility, unused.
          """
          self.commands.append(body.pop("command", None))
  class MockEventDispatcher(object):
      """Minimal event dispatcher double.

      Remembers the names of sent events in ``sent`` and flips the
      ``closed`` / ``flushed`` flags when :meth:`close` / :meth:`flush`
      are called.
      """

      # Class-level defaults retained so the attributes still exist on the
      # class itself; the two lists are replaced per instance in
      # ``__init__`` so instances do not share mutable state.
      sent = []
      closed = False
      flushed = False
      _outbound_buffer = []

      def __init__(self):
          self.sent = []
          self._outbound_buffer = []

      def send(self, event, *args, **kwargs):
          """Record *event*; extra positional/keyword arguments are ignored."""
          self.sent.append(event)

      def close(self):
          """Mark this dispatcher as closed."""
          self.closed = True

      def flush(self):
          """Mark this dispatcher as flushed."""
          self.flushed = True
  class MockHeart(object):
      """Heartbeat double whose ``stop`` only sets ``closed`` to ``True``."""

      # Class-level default; ``stop`` shadows it on the instance.
      closed = False

      def stop(self):
          self.closed = True
  @task_dec()
  def foo_task(x, y, z, **kwargs):
      """Return the product ``x * y * z``; extra keyword args are ignored."""
      product = x * y
      return product * z
  @periodic_task_dec(run_every=60)
  def foo_periodic_task():
      """Periodic task (``run_every=60``) that returns the string "foo"."""
      result = "foo"
      return result
  def create_message(channel, **data):
      """Build a pickled ``Message`` on *channel* from the keyword data.

      An ``"id"`` entry is added from ``uuid()`` when the caller gave none,
      and ``channel.no_ack_consumers`` is reset to an empty set as a side
      effect.

      :param channel: channel object handed to ``Message``.
      :param data: fields pickled into the message body.
      :returns: the new ``Message``.
      """
      data.setdefault("id", uuid())
      channel.no_ack_consumers = set()
      payload = pickle.dumps(dict(data))
      return Message(
          channel,
          body=payload,
          content_type="application/x-python-serialize",
          content_encoding="binary",
          delivery_info={"consumer_tag": "mock"},
      )
  class test_QoS(Case):
      """Tests for the prefetch-count bookkeeping done by ``QoS``."""

      class _QoS(QoS):
          """``QoS`` built without consumer or logger; ``set`` is a stub."""

          def __init__(self, value):
              self.value = value
              QoS.__init__(self, None, value, None)

          def set(self, value):
              # Echo the value back instead of delegating to the parent.
              return value

      def test_qos_increment_decrement(self):
          qos = self._QoS(10)

          # Increments: default step, explicit step, negative step.
          self.assertEqual(qos.increment(), 11)
          self.assertEqual(qos.increment(3), 14)
          self.assertEqual(qos.increment(-30), 14)

          # Decrements: explicit step, then default step.
          self.assertEqual(qos.decrement(7), 7)
          self.assertEqual(qos.decrement(), 6)

          # Decrementing by more than the current value must assert.
          with self.assertRaises(AssertionError):
              qos.decrement(10)

      def test_qos_disabled_increment_decrement(self):
          # With an initial value of 0 every operation keeps returning 0.
          qos = self._QoS(0)

          self.assertEqual(qos.increment(), 0)
          self.assertEqual(qos.increment(3), 0)
          self.assertEqual(qos.increment(-30), 0)

          self.assertEqual(qos.decrement(7), 0)
          self.assertEqual(qos.decrement(), 0)
          self.assertEqual(qos.decrement(10), 0)

      def test_qos_thread_safe(self):
          qos = self._QoS(10)

          def add():
              for _ in xrange(1000):
                  qos.increment()

          def sub():
              for _ in xrange(1000):
                  qos.decrement_eventually()

          def threaded(funs):
              # Run every callable in its own thread and wait for all.
              from threading import Thread
              workers = []
              for fun in funs:
                  workers.append(Thread(target=fun))
              for worker in workers:
                  worker.start()
              for worker in workers:
                  worker.join()

          threaded([add, add])
          self.assertEqual(qos.value, 2010)

          qos.value = 1000
          threaded([add, sub])  # n = 2
          self.assertEqual(qos.value, 1000)

      def test_exceeds_short(self):
          start = PREFETCH_COUNT_MAX - 1
          qos = QoS(Mock(), start, current_app.log.get_default_logger())
          qos.update()
          self.assertEqual(qos.value, start)

          # Step over PREFETCH_COUNT_MAX ...
          qos.increment()
          self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
          qos.increment()
          self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)

          # ... and back down again.
          qos.decrement()
          self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
          qos.decrement()
          self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)

      def test_consumer_increment_decrement(self):
          consumer = Mock()
          logger = current_app.log.get_default_logger()
          qos = QoS(consumer, 10, logger)

          qos.update()
          self.assertEqual(qos.value, 10)
          self.assertIn({"prefetch_count": 10}, consumer.qos.call_args)

          qos.decrement()
          self.assertEqual(qos.value, 9)
          self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)

          # decrement_eventually changes the value, but the last call made
          # to the consumer is still the one with prefetch_count=9.
          qos.decrement_eventually()
          self.assertEqual(qos.value, 8)
          self.assertIn({"prefetch_count": 9}, consumer.qos.call_args)

          # Does not decrement 0 value
          qos.value = 0
          qos.decrement()
          self.assertEqual(qos.value, 0)
          qos.increment()
          self.assertEqual(qos.value, 0)

      def test_consumer_decrement_eventually(self):
          consumer = Mock()
          logger = current_app.log.get_default_logger()
          qos = QoS(consumer, 10, logger)

          qos.decrement_eventually()
          self.assertEqual(qos.value, 9)

          # A value of 0 stays at 0.
          qos.value = 0
          qos.decrement_eventually()
          self.assertEqual(qos.value, 0)

      def test_set(self):
          consumer = Mock()
          logger = current_app.log.get_default_logger()
          qos = QoS(consumer, 10, logger)

          qos.set(12)
          self.assertEqual(qos.prev, 12)
          # Setting the same value again must not raise.
          qos.set(qos.prev)
  class test_Consumer(Case):
      """Tests for the worker consumer.

      Most tests drive ``MyKombuConsumer``; a few define small ad-hoc
      ``MainConsumer`` subclasses to inject failures.
      """

      def setUp(self):
          # Fresh queue and timer per test; the app's default logger is
          # reused with its level set to 0.
          logger = current_app.log.get_default_logger()
          logger.setLevel(0)
          self.ready_queue = FastQueue()
          self.eta_schedule = Timer()
          self.logger = logger

      def tearDown(self):
          self.eta_schedule.stop()

      def test_info(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)

          # No connection assigned yet: the broker entry is falsy.
          before = consumer.info
          self.assertEqual(before["prefetch_count"], 10)
          self.assertFalse(before["broker"])

          consumer.connection = current_app.broker_connection()
          after = consumer.info
          self.assertTrue(after["broker"])

      def test_start_when_closed(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer._state = CLOSE
          consumer.start()

      def test_connection(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)

          consumer.reset_connection()
          self.assertIsInstance(consumer.connection, BrokerConnection)

          # Stopping consumers without closing keeps the connection.
          consumer._state = RUN
          consumer.event_dispatcher = None
          consumer.stop_consumers(close_connection=False)
          self.assertTrue(consumer.connection)

          # The default stop drops both connection and task consumer.
          consumer._state = RUN
          consumer.stop_consumers()
          self.assertIsNone(consumer.connection)
          self.assertIsNone(consumer.task_consumer)

          consumer.reset_connection()
          self.assertIsInstance(consumer.connection, BrokerConnection)
          consumer.stop_consumers()

          consumer.stop()
          consumer.close_connection()
          self.assertIsNone(consumer.connection)
          self.assertIsNone(consumer.task_consumer)

      def test_close_connection(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer._state = RUN
          consumer.close_connection()

          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          eventer = Mock()
          consumer.event_dispatcher = eventer
          eventer.enabled = True
          heart = MockHeart()
          consumer.heart = heart
          consumer._state = RUN
          consumer.stop_consumers()
          self.assertTrue(eventer.close.call_count)
          self.assertTrue(heart.closed)

      def test_receive_message_unknown(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          backend = Mock()
          message = create_message(backend, unknown={"baz": "!!!"})
          consumer.event_dispatcher = Mock()
          consumer.pidbox_node = MockNode()

          with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
              consumer.receive_message(message.decode(), message)

      @patch("celery.utils.timer2.to_timestamp")
      def test_receive_message_eta_OverflowError(self, to_timestamp):
          to_timestamp.side_effect = OverflowError()
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          # NOTE(review): ``("2, 2")`` is a parenthesised string, not a
          # tuple; kept exactly as in the original test.
          message = create_message(Mock(), task=foo_task.name,
                                   args=("2, 2"),
                                   kwargs={},
                                   eta=datetime.now().isoformat())
          consumer.event_dispatcher = Mock()
          consumer.pidbox_node = MockNode()
          consumer.update_strategies()

          consumer.receive_message(message.decode(), message)
          self.assertTrue(message.acknowledged)
          self.assertTrue(to_timestamp.call_count)

      def test_receive_message_InvalidTaskError(self):
          logger = Mock()
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     logger, send_events=False)
          # kwargs is deliberately a string instead of a mapping.
          message = create_message(Mock(), task=foo_task.name,
                                   args=(1, 2), kwargs="foobarbaz", id=1)
          consumer.update_strategies()
          consumer.event_dispatcher = Mock()
          consumer.pidbox_node = MockNode()

          consumer.receive_message(message.decode(), message)
          logged = logger.error.call_args[0][0]
          self.assertIn("Received invalid task message", logged)

      def test_on_decode_error(self):
          logger = Mock()
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     logger, send_events=False)

          class MockMessage(Mock):
              content_type = "application/x-msgpack"
              content_encoding = "binary"
              body = "foobarbaz"

          message = MockMessage()
          consumer.on_decode_error(message, KeyError("foo"))
          self.assertTrue(message.ack.call_count)
          logged = logger.critical.call_args[0][0]
          self.assertIn("Can't decode message body", logged)

      def test_receieve_message(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          message = create_message(Mock(), task=foo_task.name,
                                   args=[2, 4, 8], kwargs={})
          consumer.update_strategies()

          consumer.event_dispatcher = Mock()
          consumer.receive_message(message.decode(), message)

          # The request lands on the ready queue, not the eta schedule.
          in_bucket = self.ready_queue.get_nowait()
          self.assertIsInstance(in_bucket, Request)
          self.assertEqual(in_bucket.name, foo_task.name)
          self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
          self.assertTrue(self.eta_schedule.empty())

      def test_start_connection_error(self):

          class MockConsumer(MainConsumer):
              iterations = 0

              def consume_messages(self):
                  # First call raises KeyError, later calls SyntaxError.
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError("foo")
                  raise SyntaxError("bar")

          consumer = MockConsumer(self.ready_queue, self.eta_schedule,
                                  self.logger, send_events=False,
                                  pool=BasePool())
          consumer.connection_errors = (KeyError, )
          with self.assertRaises(SyntaxError):
              consumer.start()
          consumer.heart.stop()
          consumer.priority_timer.stop()

      def test_start_channel_error(self):
          # Regression test for AMQPChannelExceptions that can occur within
          # the consumer. (i.e. 404 errors)

          class MockConsumer(MainConsumer):
              iterations = 0

              def consume_messages(self):
                  # First call raises KeyError, later calls SyntaxError.
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError("foo")
                  raise SyntaxError("bar")

          consumer = MockConsumer(self.ready_queue, self.eta_schedule,
                                  self.logger, send_events=False,
                                  pool=BasePool())
          consumer.channel_errors = (KeyError, )
          with self.assertRaises(SyntaxError):
              consumer.start()
          consumer.heart.stop()
          consumer.priority_timer.stop()

      def test_consume_messages_ignores_socket_timeout(self):

          class Connection(current_app.broker_connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  # Clear the owner's connection, then time out.
                  self.obj.connection = None
                  raise socket.timeout(10)

          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.connection = Connection()
          consumer.task_consumer = Mock()
          consumer.connection.obj = consumer
          consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)
          consumer.consume_messages()

      def test_consume_messages_when_socket_error(self):

          class Connection(current_app.broker_connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  # Clear the owner's connection, then fail.
                  self.obj.connection = None
                  raise socket.error("foo")

          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer._state = RUN
          connection = Connection()
          consumer.connection = connection
          consumer.connection.obj = consumer
          consumer.task_consumer = Mock()
          consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)

          # While running, the socket error propagates ...
          with self.assertRaises(socket.error):
              consumer.consume_messages()

          # ... but in the CLOSE state the same call must not raise.
          consumer._state = CLOSE
          consumer.connection = connection
          consumer.consume_messages()

      def test_consume_messages(self):

          class Connection(current_app.broker_connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None

          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.connection = Connection()
          consumer.connection.obj = consumer
          consumer.task_consumer = Mock()
          consumer.qos = QoS(consumer.task_consumer, 10, consumer.logger)

          consumer.consume_messages()
          consumer.consume_messages()
          self.assertTrue(consumer.task_consumer.consume.call_count)
          consumer.task_consumer.qos.assert_called_with(prefetch_count=10)

          consumer.qos.decrement()
          consumer.consume_messages()
          consumer.task_consumer.qos.assert_called_with(prefetch_count=9)

      def test_maybe_conn_error(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.connection_errors = (KeyError, )
          consumer.channel_errors = (SyntaxError, )

          # These three exception types must not propagate.
          for swallowed in (AttributeError, KeyError, SyntaxError):
              consumer.maybe_conn_error(Mock(side_effect=swallowed("foo")))

          # Anything else propagates.
          with self.assertRaises(IndexError):
              consumer.maybe_conn_error(Mock(side_effect=IndexError("foo")))

      def test_apply_eta_task(self):
          from celery.worker import state
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.qos = QoS(None, 10, consumer.logger)

          task = object()
          value_before = consumer.qos.value
          consumer.apply_eta_task(task)
          self.assertIn(task, state.reserved_requests)
          self.assertEqual(consumer.qos.value, value_before - 1)
          self.assertIs(self.ready_queue.get_nowait(), task)

      def test_receieve_message_eta_isoformat(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          message = create_message(Mock(), task=foo_task.name,
                                   eta=datetime.now().isoformat(),
                                   args=[2, 4, 8], kwargs={})

          consumer.task_consumer = Mock()
          consumer.qos = QoS(consumer.task_consumer,
                             consumer.initial_prefetch_count,
                             consumer.logger)
          consumer.event_dispatcher = Mock()
          consumer.enabled = False
          consumer.update_strategies()
          consumer.receive_message(message.decode(), message)
          consumer.eta_schedule.stop()

          # Look for an entry on the eta schedule carrying our task.
          items = [entry[2] for entry in self.eta_schedule.queue]
          matches = [item.args[0].name == foo_task.name for item in items]
          self.assertTrue(any(matches))
          self.assertTrue(consumer.task_consumer.qos.call_count)
          consumer.eta_schedule.stop()

      def test_on_control(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.pidbox_node = Mock()
          consumer.reset_pidbox_node = Mock()

          # Plain dispatch.
          consumer.on_control("foo", "bar")
          consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")

          # Handler raising KeyError.
          consumer.pidbox_node = Mock()
          consumer.pidbox_node.handle_message.side_effect = KeyError("foo")
          consumer.on_control("foo", "bar")
          consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")

          # Handler raising ValueError: the pidbox node must be reset.
          consumer.pidbox_node = Mock()
          consumer.pidbox_node.handle_message.side_effect = ValueError("foo")
          consumer.on_control("foo", "bar")
          consumer.pidbox_node.handle_message.assert_called_with("foo", "bar")
          consumer.reset_pidbox_node.assert_called_with()

      def test_revoke(self):
          ready_queue = FastQueue()
          consumer = MyKombuConsumer(ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          backend = Mock()
          task_id = uuid()
          message = create_message(backend, task=foo_task.name,
                                   args=[2, 4, 8], kwargs={}, id=task_id)
          from celery.worker.state import revoked
          revoked.add(task_id)

          # A revoked task must never reach the ready queue.
          consumer.receive_message(message.decode(), message)
          self.assertTrue(ready_queue.empty())

      def test_receieve_message_not_registered(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          backend = Mock()
          message = create_message(backend, task="x.X.31x",
                                   args=[2, 4, 8], kwargs={})

          consumer.event_dispatcher = Mock()
          received = consumer.receive_message(message.decode(), message)
          self.assertFalse(received)
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()
          self.assertTrue(self.eta_schedule.empty())

      def test_receieve_message_ack_raises(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          backend = Mock()
          # No task name is given for this message.
          message = create_message(backend, args=[2, 4, 8], kwargs={})

          consumer.event_dispatcher = Mock()
          consumer.connection_errors = (socket.error, )
          consumer.logger = Mock()
          message.reject = Mock()
          message.reject.side_effect = socket.error("foo")

          with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
              self.assertFalse(
                  consumer.receive_message(message.decode(), message))
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()
          self.assertTrue(self.eta_schedule.empty())
          message.reject.assert_called_with()
          self.assertTrue(consumer.logger.critical.call_count)

      def test_receieve_message_eta(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.event_dispatcher = Mock()
          consumer.event_dispatcher._outbound_buffer = deque()
          backend = Mock()
          tomorrow = datetime.now() + timedelta(days=1)
          message = create_message(backend, task=foo_task.name,
                                   args=[2, 4, 8], kwargs={},
                                   eta=tomorrow.isoformat())

          consumer.reset_connection()
          # Reset once more with BROKER_CONNECTION_RETRY disabled, then
          # restore the previous setting.
          conf = consumer.app.conf
          retry_setting = conf.BROKER_CONNECTION_RETRY
          conf.BROKER_CONNECTION_RETRY = False
          try:
              consumer.reset_connection()
          finally:
              conf.BROKER_CONNECTION_RETRY = retry_setting
          consumer.stop_consumers()
          consumer.event_dispatcher = Mock()
          consumer.receive_message(message.decode(), message)
          consumer.eta_schedule.stop()

          # The request waits on the eta schedule, not the ready queue.
          in_hold = self.eta_schedule.queue[0]
          self.assertEqual(len(in_hold), 3)
          eta, priority, entry = in_hold
          task = entry.args[0]
          self.assertIsInstance(task, Request)
          self.assertEqual(task.name, foo_task.name)
          self.assertEqual(task.execute(), 2 * 4 * 8)
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()

      def test_reset_pidbox_node(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.pidbox_node = Mock()
          chan = Mock()
          consumer.pidbox_node.channel = chan
          consumer.connection = Mock()

          # Closing the old channel fails with a connection error.
          chan.close.side_effect = socket.error("foo")
          consumer.connection_errors = (socket.error, )

          consumer.reset_pidbox_node()
          chan.close.assert_called_with()

      def test_reset_pidbox_node_green(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.pool = Mock()
          consumer.pool.is_green = True
          consumer.reset_pidbox_node()
          consumer.pool.spawn_n.assert_called_with(
              consumer._green_pidbox_node)

      def test__green_pidbox_node(self):
          consumer = MyKombuConsumer(self.ready_queue, self.eta_schedule,
                                     self.logger, send_events=False)
          consumer.pidbox_node = Mock()

          class BConsumer(Mock):
              # Context manager: consume on enter, cancel on exit.

              def __enter__(self):
                  self.consume()
                  return self

              def __exit__(self, *exc_info):
                  self.cancel()

          consumer.pidbox_node.listen = BConsumer()
          connections = []

          class Connection(object):
              # Fake connection that registers itself in ``connections``.

              def __init__(self, obj):
                  connections.append(self)
                  self.obj = obj
                  self.default_channel = self.channel()
                  self.closed = False

              def __enter__(self):
                  return self

              def __exit__(self, *exc_info):
                  self.close()

              def channel(self):
                  return Mock()

              def drain_events(self, **kwargs):
                  # Clear the owner's connection and set its shutdown
                  # event.
                  self.obj.connection = None
                  self.obj._pidbox_node_shutdown.set()

              def close(self):
                  self.closed = True

          consumer.connection = Mock()
          consumer._open_connection = lambda: Connection(obj=consumer)
          consumer._green_pidbox_node()

          consumer.pidbox_node.listen.assert_called_with(
              callback=consumer.on_control)
          self.assertTrue(consumer.broadcast_consumer)
          consumer.broadcast_consumer.consume.assert_called_with()

          self.assertIsNone(consumer.connection)
          self.assertTrue(connections[0].closed)

      def test_start__consume_messages(self):

          class _QoS(object):
              prev = 3
              value = 4

              def update(self):
                  self.prev = self.value

          class _Consumer(MyKombuConsumer):
              iterations = 0

              def reset_connection(self):
                  if self.iterations >= 1:
                      raise KeyError("foo")

          init_callback = Mock()
          consumer = _Consumer(self.ready_queue, self.eta_schedule,
                               self.logger, send_events=False,
                               init_callback=init_callback)
          consumer.task_consumer = Mock()
          consumer.broadcast_consumer = Mock()
          consumer.qos = _QoS()
          consumer.connection = BrokerConnection()
          consumer.iterations = 0

          def raises_KeyError(limit=None):
              # Count the call, sync the QoS if needed, and raise from the
              # second call on.
              consumer.iterations += 1
              if consumer.qos.prev != consumer.qos.value:
                  consumer.qos.update()
              if consumer.iterations >= 2:
                  raise KeyError("foo")

          consumer.consume_messages = raises_KeyError
          with self.assertRaises(KeyError):
              consumer.start()
          self.assertTrue(init_callback.call_count)
          self.assertEqual(consumer.iterations, 1)
          self.assertEqual(consumer.qos.prev, consumer.qos.value)

          # Second scenario: consume_messages raises socket.error.
          init_callback.reset_mock()
          consumer = _Consumer(self.ready_queue, self.eta_schedule,
                               self.logger, send_events=False,
                               init_callback=init_callback)
          consumer.qos = _QoS()
          consumer.task_consumer = Mock()
          consumer.broadcast_consumer = Mock()
          consumer.connection = BrokerConnection()
          consumer.consume_messages = Mock(side_effect=socket.error("foo"))
          with self.assertRaises(socket.error):
              consumer.start()
          self.assertTrue(init_callback.call_count)
          self.assertTrue(consumer.consume_messages.call_count)

      def test_reset_connection_with_no_node(self):
          # A plain MainConsumer built without a pool argument.
          consumer = MainConsumer(self.ready_queue, self.eta_schedule,
                                  self.logger)
          self.assertEqual(None, consumer.pool)
          consumer.reset_connection()
  class test_WorkController(AppCase):
      """Tests for ``WorkController`` start/stop and task processing."""

      def setup(self):
          # NOTE(review): lowercase ``setup`` is presumably the hook that
          # ``AppCase`` calls from its own ``setUp`` -- confirm.
          self.worker = self.create_worker()

      def create_worker(self, **kw):
          """Return a ``WorkController`` with a mock logger.

          The worker is created with ``concurrency=1, loglevel=0`` plus
          *kw*, and its ``_shutdown_complete`` event is set up front.
          """
          worker = WorkController(concurrency=1, loglevel=0, **kw)
          worker._shutdown_complete.set()
          worker.logger = Mock()
          return worker

      @patch("celery.platforms.signals")
      @patch("celery.platforms.set_mp_process_title")
      def test_process_initializer(self, set_mp_process_title, _signals):
          from celery import Celery
          from celery import signals
          from celery.app.state import _tls
          from celery.concurrency.processes import process_initializer
          from celery.concurrency.processes import (WORKER_SIGRESET,
                                                    WORKER_SIGIGNORE)

          def on_worker_process_init(**kwargs):
              on_worker_process_init.called = True
          on_worker_process_init.called = False
          signals.worker_process_init.connect(on_worker_process_init)

          loader = Mock()
          app = Celery(loader=loader, set_as_current=False)
          app.conf = AttributeDict(DEFAULTS)
          process_initializer(app, "awesome.worker.com")

          self.assertIn((tuple(WORKER_SIGIGNORE), {}),
                        _signals.ignore.call_args_list)
          self.assertIn((tuple(WORKER_SIGRESET), {}),
                        _signals.reset.call_args_list)
          self.assertTrue(app.loader.init_worker.call_count)
          self.assertTrue(on_worker_process_init.called)
          self.assertIs(_tls.current_app, app)
          set_mp_process_title.assert_called_with(
              "celeryd", hostname="awesome.worker.com")

      def test_with_rate_limits_disabled(self):
          worker = WorkController(concurrency=1, loglevel=0,
                                  disable_rate_limits=True)
          self.assertTrue(hasattr(worker.ready_queue, "put"))

      def test_attrs(self):
          worker = self.worker
          self.assertIsInstance(worker.scheduler, Timer)
          self.assertTrue(worker.scheduler)
          self.assertTrue(worker.pool)
          self.assertTrue(worker.consumer)
          self.assertTrue(worker.mediator)
          self.assertTrue(worker.components)

      def test_with_embedded_celerybeat(self):
          worker = WorkController(concurrency=1, loglevel=0,
                                  embed_clockservice=True)
          self.assertTrue(worker.beat)
          self.assertIn(worker.beat, worker.components)

      def test_with_autoscaler(self):
          worker = self.create_worker(
              autoscale=[10, 3], send_events=False,
              eta_scheduler_cls="celery.utils.timer2.Timer")
          self.assertTrue(worker.autoscaler)

      def test_dont_stop_or_terminate(self):
          # A worker that was never started must not move to CLOSE.
          worker = WorkController(concurrency=1, loglevel=0)
          worker.stop()
          self.assertNotEqual(worker._state, worker.CLOSE)
          worker.terminate()
          self.assertNotEqual(worker._state, worker.CLOSE)

          # Same expectation inside a signal handler when the pool is not
          # signal safe; the pool flag is restored afterwards.
          sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
          try:
              worker._state = worker.RUN
              worker.stop(in_sighandler=True)
              self.assertNotEqual(worker._state, worker.CLOSE)
              worker.terminate(in_sighandler=True)
              self.assertNotEqual(worker._state, worker.CLOSE)
          finally:
              worker.pool.signal_safe = sigsafe

      def test_on_timer_error(self):
          worker = WorkController(concurrency=1, loglevel=0)
          worker.logger = Mock()

          # Build a real exc_info triple to pass to the handler.
          try:
              raise KeyError("foo")
          except KeyError:
              exc_info = sys.exc_info()

          worker.on_timer_error(exc_info)
          msg, args = worker.logger.error.call_args[0]
          self.assertIn("KeyError", msg % args)

      def test_on_timer_tick(self):
          worker = WorkController(concurrency=1, loglevel=10)
          worker.logger = Mock()

          worker.on_timer_tick(30.0)
          xargs = worker.logger.debug.call_args[0]
          fmt, arg = xargs[0], xargs[1]
          self.assertEqual(30.0, arg)
          self.assertIn("Next eta %s secs", fmt)

      def test_process_task(self):
          worker = self.worker
          worker.pool = Mock()
          backend = Mock()
          m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                             kwargs={})
          task = Request.from_message(m, m.decode())
          worker.process_task(task)
          self.assertEqual(worker.pool.apply_async.call_count, 1)
          worker.pool.stop()

      def test_process_task_raise_base(self):
          worker = self.worker
          worker.pool = Mock()
          worker.pool.apply_async.side_effect = KeyboardInterrupt("Ctrl+C")
          backend = Mock()
          m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                             kwargs={})
          task = Request.from_message(m, m.decode())
          worker.components = []
          worker._state = worker.RUN
          with self.assertRaises(KeyboardInterrupt):
              worker.process_task(task)
          self.assertEqual(worker._state, worker.TERMINATE)

      def test_process_task_raise_SystemTerminate(self):
          worker = self.worker
          worker.pool = Mock()
          worker.pool.apply_async.side_effect = SystemTerminate()
          backend = Mock()
          m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                             kwargs={})
          task = Request.from_message(m, m.decode())
          worker.components = []
          worker._state = worker.RUN
          # SystemTerminate is expected to surface as SystemExit.
          with self.assertRaises(SystemExit):
              worker.process_task(task)
          self.assertEqual(worker._state, worker.TERMINATE)

      def test_process_task_raise_regular(self):
          # An ordinary exception from the pool must not propagate.
          worker = self.worker
          worker.pool = Mock()
          worker.pool.apply_async.side_effect = KeyError("some exception")
          backend = Mock()
          m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                             kwargs={})
          task = Request.from_message(m, m.decode())
          worker.process_task(task)
          worker.pool.stop()

      def test_start_catches_base_exceptions(self):
          # SystemTerminate while starting: the component is terminated.
          worker1 = self.create_worker()
          stc = Mock()
          stc.start.side_effect = SystemTerminate()
          worker1.components = [stc]
          worker1.start()
          self.assertTrue(stc.terminate.call_count)

          # SystemExit while starting, with no ``terminate`` available:
          # the component is stopped instead.
          worker2 = self.create_worker()
          sec = Mock()
          sec.start.side_effect = SystemExit()
          sec.terminate = None
          worker2.components = [sec]
          worker2.start()
          self.assertTrue(sec.stop.call_count)

      def test_state_db(self):
          from celery.worker import state
          # Swap in a mock Persistent for the duration of the test.
          Persistent = state.Persistent
          state.Persistent = Mock()
          try:
              worker = self.create_worker(state_db="statefilename")
              self.assertTrue(worker._persistence)
          finally:
              state.Persistent = Persistent

      def test_disable_rate_limits_solo(self):
          worker = self.create_worker(disable_rate_limits=True,
                                      pool_cls="solo")
          self.assertIsInstance(worker.ready_queue, FastQueue)
          self.assertIsNone(worker.mediator)
          self.assertEqual(worker.ready_queue.put, worker.process_task)

      def test_disable_rate_limits_processes(self):
          try:
              worker = self.create_worker(disable_rate_limits=True,
                                          pool_cls="processes")
          except ImportError:
              raise SkipTest("multiprocessing not supported")
          self.assertIsInstance(worker.ready_queue, FastQueue)
          self.assertTrue(worker.mediator)
          self.assertNotEqual(worker.ready_queue.put, worker.process_task)

      def test_start__stop(self):
          worker = self.worker
          worker._shutdown_complete.set()
          worker.components = [Mock(), Mock(), Mock(), Mock()]

          worker.start()
          for component in worker.components:
              self.assertTrue(component.start.call_count)

          worker.stop()
          # Check every component, not just the leftover variable from
          # the previous loop (the original asserted on ``w`` here).
          for component in worker.components:
              self.assertTrue(component.stop.call_count)

      def test_start__terminate(self):
          worker = self.worker
          worker._shutdown_complete.set()
          worker.components = [Mock(), Mock(), Mock(), Mock(), Mock()]
          # The first three components have no ``terminate``.
          for component in worker.components[:3]:
              component.terminate = None

          worker.start()
          for w in worker.components[:3]:
              self.assertTrue(w.start.call_count)
          # NOTE(review): the second argument of assertTrue is only the
          # failure message; this looks like it was meant to be
          # assertEqual -- confirm before changing.
          self.assertTrue(worker._running, len(worker.components))
          self.assertEqual(worker._state, RUN)

          worker.terminate()
          # Components without ``terminate`` are stopped instead.
          for component in worker.components[:3]:
              self.assertTrue(component.stop.call_count)
          self.assertTrue(worker.components[4].terminate.call_count)