test_worker.py 38 KB


  1. from __future__ import absolute_import
  2. import socket
  3. from collections import deque
  4. from datetime import datetime, timedelta
  5. from threading import Event
  6. from billiard.exceptions import WorkerLostError
  7. from kombu import Connection
  8. from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
  9. from kombu.exceptions import StdChannelError
  10. from kombu.transport.base import Message
  11. from mock import Mock, patch
  12. from nose import SkipTest
  13. from celery import current_app
  14. from celery.app.defaults import DEFAULTS
  15. from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
  16. from celery.concurrency.base import BasePool
  17. from celery.datastructures import AttributeDict
  18. from celery.exceptions import SystemTerminate
  19. from celery.five import Empty, range
  20. from celery.task import task as task_dec
  21. from celery.task import periodic_task as periodic_task_dec
  22. from celery.utils import uuid
  23. from celery.worker import WorkController
  24. from celery.worker import components
  25. from celery.worker.buckets import FastQueue
  26. from celery.worker.job import Request
  27. from celery.worker import consumer
  28. from celery.worker.consumer import Consumer as __Consumer
  29. from celery.utils.serialization import pickle
  30. from celery.utils.timer2 import Timer
  31. from celery.tests.utils import AppCase, Case
  def MockStep(step=None):
      """Return ``step`` (or a fresh ``Mock``) labelled like a bootstep.

      The returned object has ``name`` set to ``'MockStep'`` and a mocked
      ``namespace`` attribute whose ``name`` is ``'MockNS'``.
      """
      if step is None:
          step = Mock()
      step.namespace = Mock()
      step.namespace.name = 'MockNS'
      step.name = 'MockStep'
      return step
  class PlaceHolder(object):
      """Empty class with no attributes or behaviour of its own."""
  def find_step(obj, typ):
      """Return the entry of ``obj.namespace.steps`` keyed by ``typ.name``."""
      registered = obj.namespace.steps
      return registered[typ.name]
  class Consumer(__Consumer):
      """Consumer with the Mingle, Gossip and Heart steps off by default."""

      def __init__(self, *args, **kwargs):
          # Only defaults: a caller can still pass any of these explicitly.
          for flag in ('enable_mingle', 'enable_gossip', 'enable_heartbeat'):
              kwargs.setdefault(flag, False)
          super(Consumer, self).__init__(*args, **kwargs)
  class _MyKombuConsumer(Consumer):
      """Consumer with mocked kombu consumers and a default ``BasePool(2)``."""

      # Class-level mocks, shared by instances unless they override them.
      task_consumer = Mock()
      broadcast_consumer = Mock()

      def __init__(self, *args, **kwargs):
          # Supply a pool only when the caller did not pass one.
          kwargs.setdefault('pool', BasePool(2))
          super(_MyKombuConsumer, self).__init__(*args, **kwargs)

      def restart_heartbeat(self):
          """Clear ``self.heart`` rather than starting anything."""
          self.heart = None
  class MyKombuConsumer(Consumer):
      """Consumer whose ``loop`` is a no-op."""

      def loop(self, *args, **kwargs):
          """Accept any arguments and return ``None`` immediately."""
          return None
  class MockNode(object):
      """Stand-in node that records the ``command`` of each handled message.

      ``commands`` is created per instance.  It used to be a class-level
      list, which meant every ``MockNode`` appended to the same object and
      commands recorded by one test were visible to the next.
      """

      def __init__(self):
          # One list per instance so recorded commands never leak.
          self.commands = []

      def handle_message(self, body, message):
          """Pop ``'command'`` from ``body`` and record it.

          ``None`` is recorded when ``body`` has no ``'command'`` key.
          ``message`` is accepted but not used.
          """
          self.commands.append(body.pop('command', None))
  class MockEventDispatcher(object):
      """Stand-in event dispatcher that records what was sent.

      ``sent`` and ``_outbound_buffer`` are created per instance.  They used
      to be class-level lists shared by every dispatcher, so events sent
      through one instance showed up on all the others.
      """

      # Immutable defaults are safe to keep on the class; ``close`` and
      # ``flush`` shadow them on the instance.
      closed = False
      flushed = False

      def __init__(self):
          self.sent = []
          self._outbound_buffer = []

      def send(self, event, *args, **kwargs):
          """Record ``event``; extra positional/keyword args are ignored."""
          self.sent.append(event)

      def close(self):
          """Mark this dispatcher as closed."""
          self.closed = True

      def flush(self):
          """Mark this dispatcher as flushed."""
          self.flushed = True
  class MockHeart(object):
      """Heart stand-in that remembers whether ``stop`` was called."""

      closed = False  # class default; ``stop`` shadows it per instance

      def stop(self):
          """Flag this instance as closed."""
          self.closed = True
  @task_dec()
  def foo_task(x, y, z, **kwargs):
      """Return ``x * y * z``; any extra keyword arguments are ignored."""
      partial_product = x * y
      return partial_product * z
  @periodic_task_dec(run_every=60)
  def foo_periodic_task():
      """Periodic task (``run_every=60``) that returns ``'foo'``."""
      outcome = 'foo'
      return outcome
  def create_message(channel, **data):
      """Build a ``Message`` on ``channel`` whose body is pickled ``data``.

      An ``id`` entry is added via ``uuid()`` when the caller gave none, and
      ``channel.no_ack_consumers`` is replaced with an empty set.
      """
      data.setdefault('id', uuid())
      channel.no_ack_consumers = set()
      payload = pickle.dumps(dict(**data))
      delivery = {'consumer_tag': 'mock'}
      return Message(channel,
                     body=payload,
                     content_type='application/x-python-serialize',
                     content_encoding='binary',
                     delivery_info=delivery)
  class test_QoS(Case):
      """Checks of ``QoS`` prefetch-count bookkeeping."""

      class _QoS(QoS):
          """``QoS`` with no callback whose ``set`` simply echoes the value."""

          def __init__(self, value):
              self.value = value
              QoS.__init__(self, None, value)

          def set(self, value):
              return value

      def test_qos_increment_decrement(self):
          """Each eventual increment/decrement returns the expected value."""
          prefetch = self._QoS(10)
          steps = (
              (prefetch.increment_eventually, (), 11),
              (prefetch.increment_eventually, (3, ), 14),
              (prefetch.increment_eventually, (-30, ), 14),
              (prefetch.decrement_eventually, (7, ), 7),
              (prefetch.decrement_eventually, (), 6),
          )
          for operation, args, expected in steps:
              self.assertEqual(operation(*args), expected)

      def test_qos_disabled_increment_decrement(self):
          """A value of 0 stays 0 whatever is added or removed."""
          prefetch = self._QoS(0)
          steps = (
              (prefetch.increment_eventually, ()),
              (prefetch.increment_eventually, (3, )),
              (prefetch.increment_eventually, (-30, )),
              (prefetch.decrement_eventually, (7, )),
              (prefetch.decrement_eventually, ()),
              (prefetch.decrement_eventually, (10, )),
          )
          for operation, args in steps:
              self.assertEqual(operation(*args), 0)

      def test_qos_thread_safe(self):
          """Concurrent increments and decrements must not lose updates."""
          from threading import Thread

          prefetch = self._QoS(10)

          def add():
              for _ in range(1000):
                  prefetch.increment_eventually()

          def sub():
              for _ in range(1000):
                  prefetch.decrement_eventually()

          def threaded(funs):
              # Start every thread first, then wait for all of them.
              workers = [Thread(target=fun) for fun in funs]
              for worker in workers:
                  worker.start()
              for worker in workers:
                  worker.join()

          threaded([add, add])
          self.assertEqual(prefetch.value, 2010)

          prefetch.value = 1000
          threaded([add, sub])  # n = 2
          self.assertEqual(prefetch.value, 1000)

      def test_exceeds_short(self):
          """The value may move across ``PREFETCH_COUNT_MAX`` and back."""
          prefetch = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
          prefetch.update()
          self.assertEqual(prefetch.value, PREFETCH_COUNT_MAX - 1)
          transitions = (
              (prefetch.increment_eventually, PREFETCH_COUNT_MAX),
              (prefetch.increment_eventually, PREFETCH_COUNT_MAX + 1),
              (prefetch.decrement_eventually, PREFETCH_COUNT_MAX),
              (prefetch.decrement_eventually, PREFETCH_COUNT_MAX - 1),
          )
          for operation, expected in transitions:
              operation()
              self.assertEqual(prefetch.value, expected)

      def test_consumer_increment_decrement(self):
          """The callback is only invoked by ``update``, not by the changes."""
          channel_consumer = Mock()
          prefetch = QoS(channel_consumer.qos, 10)
          prefetch.update()
          self.assertEqual(prefetch.value, 10)
          channel_consumer.qos.assert_called_with(prefetch_count=10)

          prefetch.decrement_eventually()
          prefetch.update()
          self.assertEqual(prefetch.value, 9)
          channel_consumer.qos.assert_called_with(prefetch_count=9)

          # No update() here, so the last callback call is still the 9 one.
          prefetch.decrement_eventually()
          self.assertEqual(prefetch.value, 8)
          channel_consumer.qos.assert_called_with(prefetch_count=9)
          self.assertIn({'prefetch_count': 9},
                        channel_consumer.qos.call_args)

          # Does not decrement 0 value
          prefetch.value = 0
          prefetch.decrement_eventually()
          self.assertEqual(prefetch.value, 0)
          prefetch.increment_eventually()
          self.assertEqual(prefetch.value, 0)

      def test_consumer_decrement_eventually(self):
          """Decrementing lowers the value, except when it is already 0."""
          channel_consumer = Mock()
          prefetch = QoS(channel_consumer.qos, 10)
          prefetch.decrement_eventually()
          self.assertEqual(prefetch.value, 9)

          prefetch.value = 0
          prefetch.decrement_eventually()
          self.assertEqual(prefetch.value, 0)

      def test_set(self):
          """``set`` stores the new count in ``prev``."""
          channel_consumer = Mock()
          prefetch = QoS(channel_consumer.qos, 10)
          prefetch.set(12)
          self.assertEqual(prefetch.prev, 12)
          # Setting the same value again must simply not raise.
          prefetch.set(prefetch.prev)
  class test_Consumer(Case):
      """Tests for the worker ``Consumer`` and its bootsteps.

      Every test gets a fresh ``FastQueue`` ready queue and ``Timer``; the
      timer is stopped again in ``tearDown``.
      """

      def setUp(self):
          self.ready_queue = FastQueue()
          self.timer = Timer()

      def tearDown(self):
          self.timer.stop()

      def test_info(self):
          """``info`` reports the prefetch count and the broker, once set."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.task_consumer = Mock()
          cons.qos = QoS(cons.task_consumer.qos, 10)
          info = cons.info
          self.assertEqual(info['prefetch_count'], 10)
          self.assertFalse(info['broker'])

          cons.connection = current_app.connection()
          info = cons.info
          self.assertTrue(info['broker'])

      def test_start_when_closed(self):
          """Starting a consumer whose namespace is CLOSE must not raise."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.namespace.state = CLOSE
          cons.start()

      def test_connection(self):
          """Start/restart give a connection; shutdown clears it again."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)

          cons.namespace.start(cons)
          self.assertIsInstance(cons.connection, Connection)

          cons.namespace.state = RUN
          cons.event_dispatcher = None
          cons.namespace.restart(cons)
          self.assertTrue(cons.connection)

          cons.namespace.state = RUN
          cons.shutdown()
          self.assertIsNone(cons.connection)
          self.assertIsNone(cons.task_consumer)

          cons.namespace.start(cons)
          self.assertIsInstance(cons.connection, Connection)
          cons.namespace.restart(cons)

          cons.stop()
          cons.shutdown()
          self.assertIsNone(cons.connection)
          self.assertIsNone(cons.task_consumer)

      def test_close_connection(self):
          """The Connection, Events and Heart steps release their resources."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.namespace.state = RUN
          step = find_step(cons, consumer.Connection)
          conn = cons.connection = Mock()
          step.shutdown(cons)
          self.assertTrue(conn.close.called)
          self.assertIsNone(cons.connection)

          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          eventer = cons.event_dispatcher = Mock()
          eventer.enabled = True
          heart = cons.heart = MockHeart()
          cons.namespace.state = RUN
          events_step = find_step(cons, consumer.Events)
          events_step.shutdown(cons)
          heart_step = find_step(cons, consumer.Heart)
          heart_step.shutdown(cons)
          self.assertTrue(eventer.close.call_count)
          self.assertTrue(heart.closed)

      @patch('celery.worker.consumer.warn')
      def test_receive_message_unknown(self, warn):
          """A message without a ``task`` key triggers a warning."""
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          backend = Mock()
          message = create_message(backend, unknown={'baz': '!!!'})
          cons.event_dispatcher = Mock()
          cons.node = MockNode()

          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          self.assertTrue(warn.call_count)

      @patch('celery.worker.consumer.to_timestamp')
      def test_receive_message_eta_OverflowError(self, to_timestamp):
          """An ``OverflowError`` from ``to_timestamp`` still acks."""
          to_timestamp.side_effect = OverflowError()
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          message = create_message(Mock(), task=foo_task.name,
                                   args=('2, 2'),
                                   kwargs={},
                                   eta=datetime.now().isoformat())
          cons.event_dispatcher = Mock()
          cons.node = MockNode()
          cons.update_strategies()

          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          self.assertTrue(message.acknowledged)
          self.assertTrue(to_timestamp.call_count)

      @patch('celery.worker.consumer.error')
      def test_receive_message_InvalidTaskError(self, error):
          """Non-mapping ``kwargs`` are logged as an invalid task message."""
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          message = create_message(Mock(), task=foo_task.name,
                                   args=(1, 2), kwargs='foobarbaz', id=1)
          cons.update_strategies()
          cons.event_dispatcher = Mock()

          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          self.assertIn('Received invalid task message',
                        error.call_args[0][0])

      @patch('celery.worker.consumer.crit')
      def test_on_decode_error(self, crit):
          """Undecodable messages are acked and reported via ``crit``."""
          cons = Consumer(self.ready_queue, timer=self.timer)

          class MockMessage(Mock):
              content_type = 'application/x-msgpack'
              content_encoding = 'binary'
              body = 'foobarbaz'

          message = MockMessage()
          cons.on_decode_error(message, KeyError('foo'))
          self.assertTrue(message.ack.call_count)
          self.assertIn("Can't decode message body", crit.call_args[0][0])

      def _get_on_message(self, l):
          """Run ``l.loop`` once and return the callback it registered.

          The connection is a mock whose ``drain_events`` raises
          ``SystemExit``, so the loop stops right after its setup.  The
          returned callable is the first positional argument given to
          ``task_consumer.register_callback``.
          """
          if l.qos is None:
              l.qos = Mock()
          l.event_dispatcher = Mock()
          l.task_consumer = Mock()
          l.connection = Mock()
          l.connection.drain_events.side_effect = SystemExit()

          with self.assertRaises(SystemExit):
              l.loop(*l.loop_args())
          self.assertTrue(l.task_consumer.register_callback.called)
          return l.task_consumer.register_callback.call_args[0][0]

      def test_receieve_message(self):
          """A plain task message ends up in the ready queue as a Request."""
          cons = Consumer(self.ready_queue, timer=self.timer)
          message = create_message(Mock(), task=foo_task.name,
                                   args=[2, 4, 8], kwargs={})
          cons.update_strategies()
          callback = self._get_on_message(cons)
          callback(message.decode(), message)

          in_bucket = self.ready_queue.get_nowait()
          self.assertIsInstance(in_bucket, Request)
          self.assertEqual(in_bucket.name, foo_task.name)
          self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
          self.assertTrue(self.timer.empty())

      def test_start_channel_error(self):
          """With ``KeyError`` as a channel error, ``start`` raises it."""

          class MockConsumer(Consumer):
              iterations = 0

              def loop(self, *args, **kwargs):
                  # First call raises KeyError, every later one SyntaxError.
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError('foo')
                  raise SyntaxError('bar')

          cons = MockConsumer(self.ready_queue, timer=self.timer,
                              send_events=False, pool=BasePool())
          cons.channel_errors = (KeyError, )
          with self.assertRaises(KeyError):
              cons.start()
          cons.timer.stop()

      def test_start_connection_error(self):
          """With ``KeyError`` as a connection error, ``SyntaxError`` wins."""

          class MockConsumer(Consumer):
              iterations = 0

              def loop(self, *args, **kwargs):
                  # First call raises KeyError, every later one SyntaxError.
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError('foo')
                  raise SyntaxError('bar')

          cons = MockConsumer(self.ready_queue, timer=self.timer,
                              send_events=False, pool=BasePool())
          cons.connection_errors = (KeyError, )
          self.assertRaises(SyntaxError, cons.start)
          cons.timer.stop()

      def test_loop_ignores_socket_timeout(self):
          """``socket.timeout`` from ``drain_events`` must not propagate."""

          class Connection(current_app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  # Clear the owner's connection before raising.
                  self.obj.connection = None
                  raise socket.timeout(10)

          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.connection = Connection()
          cons.task_consumer = Mock()
          cons.connection.obj = cons
          cons.qos = QoS(cons.task_consumer.qos, 10)
          cons.loop(*cons.loop_args())

      def test_loop_when_socket_error(self):
          """``socket.error`` propagates in RUN but not in CLOSE state."""

          class Connection(current_app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None
                  raise socket.error('foo')

          cons = Consumer(self.ready_queue, timer=self.timer)
          cons.namespace.state = RUN
          conn = cons.connection = Connection()
          cons.connection.obj = cons
          cons.task_consumer = Mock()
          cons.qos = QoS(cons.task_consumer.qos, 10)
          with self.assertRaises(socket.error):
              cons.loop(*cons.loop_args())

          cons.namespace.state = CLOSE
          cons.connection = conn
          cons.loop(*cons.loop_args())

      def test_loop(self):
          """The loop starts consuming and applies the QoS value."""

          class Connection(current_app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None

          cons = Consumer(self.ready_queue, timer=self.timer)
          cons.connection = Connection()
          cons.connection.obj = cons
          cons.task_consumer = Mock()
          cons.qos = QoS(cons.task_consumer.qos, 10)

          cons.loop(*cons.loop_args())
          cons.loop(*cons.loop_args())
          self.assertTrue(cons.task_consumer.consume.call_count)
          cons.task_consumer.qos.assert_called_with(prefetch_count=10)
          self.assertEqual(cons.qos.value, 10)
          cons.qos.decrement_eventually()
          self.assertEqual(cons.qos.value, 9)
          cons.qos.update()
          self.assertEqual(cons.qos.value, 9)
          cons.task_consumer.qos.assert_called_with(prefetch_count=9)

      def test_ignore_errors(self):
          """Connection and channel errors are swallowed, others are not."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.connection_errors = (AttributeError, KeyError, )
          cons.channel_errors = (SyntaxError, )
          ignore_errors(cons, Mock(side_effect=AttributeError('foo')))
          ignore_errors(cons, Mock(side_effect=KeyError('foo')))
          ignore_errors(cons, Mock(side_effect=SyntaxError('foo')))
          with self.assertRaises(IndexError):
              ignore_errors(cons, Mock(side_effect=IndexError('foo')))

      def test_apply_eta_task(self):
          """Applying an ETA task reserves it, queues it and lowers QoS."""
          from celery.worker import state
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.qos = QoS(None, 10)

          task = object()
          before = cons.qos.value
          cons.apply_eta_task(task)
          self.assertIn(task, state.reserved_requests)
          self.assertEqual(cons.qos.value, before - 1)
          self.assertIs(self.ready_queue.get_nowait(), task)

      def test_receieve_message_eta_isoformat(self):
          """An ISO-format ETA schedules the task and raises the QoS value."""
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          message = create_message(Mock(), task=foo_task.name,
                                   eta=datetime.now().isoformat(),
                                   args=[2, 4, 8], kwargs={})

          cons.task_consumer = Mock()
          cons.qos = QoS(cons.task_consumer.qos, 1)
          current_pcount = cons.qos.value
          cons.event_dispatcher = Mock()
          cons.enabled = False
          cons.update_strategies()
          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          cons.timer.stop()
          cons.timer.join(1)

          # The third field of each timer queue entry is the scheduled
          # item; its first argument is the request.
          scheduled = [entry[2] for entry in self.timer.queue]
          self.assertTrue(any(item.args[0].name == foo_task.name
                              for item in scheduled))
          self.assertGreater(cons.qos.value, current_pcount)

      def test_pidbox_callback(self):
          """``on_message`` delegates to the node; ``ValueError`` resets."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          box = find_step(cons, consumer.Control).box
          box.node = Mock()
          box.reset = Mock()

          box.on_message('foo', 'bar')
          box.node.handle_message.assert_called_with('foo', 'bar')

          box.node = Mock()
          box.node.handle_message.side_effect = KeyError('foo')
          box.on_message('foo', 'bar')
          box.node.handle_message.assert_called_with('foo', 'bar')

          box.node = Mock()
          box.node.handle_message.side_effect = ValueError('foo')
          box.on_message('foo', 'bar')
          box.node.handle_message.assert_called_with('foo', 'bar')
          self.assertTrue(box.reset.called)

      def test_revoke(self):
          """A message whose id was revoked never reaches the ready queue."""
          ready_queue = FastQueue()
          cons = _MyKombuConsumer(ready_queue, timer=self.timer)
          cons.steps.pop()
          backend = Mock()
          task_id = uuid()
          message = create_message(backend, task=foo_task.name,
                                   args=[2, 4, 8], kwargs={}, id=task_id)
          from celery.worker.state import revoked
          revoked.add(task_id)

          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          self.assertTrue(ready_queue.empty())

      def test_receieve_message_not_registered(self):
          """Unregistered task names are dropped without being queued."""
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          backend = Mock()
          message = create_message(backend, task='x.X.31x',
                                   args=[2, 4, 8], kwargs={})

          cons.event_dispatcher = Mock()
          callback = self._get_on_message(cons)
          self.assertFalse(callback(message.decode(), message))
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()
          self.assertTrue(self.timer.empty())

      @patch('celery.worker.consumer.warn')
      @patch('celery.worker.consumer.logger')
      def test_receieve_message_ack_raises(self, logger, warn):
          """A failing ``reject`` is logged as critical, nothing is queued."""
          cons = Consumer(self.ready_queue, timer=self.timer)
          backend = Mock()
          message = create_message(backend, args=[2, 4, 8], kwargs={})

          cons.event_dispatcher = Mock()
          cons.connection_errors = (socket.error, )
          message.reject = Mock()
          message.reject.side_effect = socket.error('foo')
          callback = self._get_on_message(cons)
          self.assertFalse(callback(message.decode(), message))
          self.assertTrue(warn.call_count)
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()
          self.assertTrue(self.timer.empty())
          message.reject.assert_called_with()
          self.assertTrue(logger.critical.call_count)

      def test_receive_message_eta(self):
          """A task with a future ETA is held in the timer, not queued."""
          cons = _MyKombuConsumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          cons.event_dispatcher = Mock()
          cons.event_dispatcher._outbound_buffer = deque()
          backend = Mock()
          message = create_message(
              backend, task=foo_task.name,
              args=[2, 4, 8], kwargs={},
              eta=(datetime.now() + timedelta(days=1)).isoformat())

          cons.namespace.start(cons)
          # Start once more with connection retry disabled; the setting is
          # put back whatever happens.
          retry = cons.app.conf.BROKER_CONNECTION_RETRY
          cons.app.conf.BROKER_CONNECTION_RETRY = False
          try:
              cons.namespace.start(cons)
          finally:
              cons.app.conf.BROKER_CONNECTION_RETRY = retry
          cons.namespace.restart(cons)
          cons.event_dispatcher = Mock()
          callback = self._get_on_message(cons)
          callback(message.decode(), message)
          cons.timer.stop()
          in_hold = cons.timer.queue[0]
          self.assertEqual(len(in_hold), 3)
          eta, priority, entry = in_hold
          task = entry.args[0]
          self.assertIsInstance(task, Request)
          self.assertEqual(task.name, foo_task.name)
          self.assertEqual(task.execute(), 2 * 4 * 8)
          with self.assertRaises(Empty):
              self.ready_queue.get_nowait()

      def test_reset_pidbox_node(self):
          """``reset`` closes the node channel even when closing fails."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          box = find_step(cons, consumer.Control).box
          box.node = Mock()
          chan = box.node.channel = Mock()
          cons.connection = Mock()
          chan.close.side_effect = socket.error('foo')
          cons.connection_errors = (socket.error, )
          box.reset()
          chan.close.assert_called_with()

      def test_reset_pidbox_node_green(self):
          """A green pool gets a ``gPidbox`` whose loop is spawned."""
          from celery.worker.pidbox import gPidbox
          pool = Mock()
          pool.is_green = True
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer,
                                 pool=pool)
          control = find_step(cons, consumer.Control)
          self.assertIsInstance(control.box, gPidbox)
          control.start(cons)
          cons.pool.spawn_n.assert_called_with(
              control.box.loop, cons,
          )

      def test__green_pidbox_node(self):
          """The green pidbox loop consumes and closes its own connection."""
          pool = Mock()
          pool.is_green = True
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer,
                                 pool=pool)
          cons.node = Mock()
          controller = find_step(cons, consumer.Control)

          class BConsumer(Mock):

              def __enter__(self):
                  self.consume()
                  return self

              def __exit__(self, *exc_info):
                  self.cancel()

          controller.box.node.listen = BConsumer()
          connections = []

          class Connection(object):
              calls = 0

              def __init__(self, obj):
                  connections.append(self)
                  self.obj = obj
                  self.default_channel = self.channel()
                  self.closed = False

              def __enter__(self):
                  return self

              def __exit__(self, *exc_info):
                  self.close()

              def channel(self):
                  return Mock()

              def as_uri(self):
                  return 'dummy://'

              def drain_events(self, **kwargs):
                  # First call times out; the second ends the loop.
                  if not self.calls:
                      self.calls += 1
                      raise socket.timeout()
                  self.obj.connection = None
                  controller.box._node_shutdown.set()

              def close(self):
                  self.closed = True

          cons.connection = Mock()
          cons.connect = lambda: Connection(obj=cons)
          controller = find_step(cons, consumer.Control)
          controller.box.loop(cons)

          self.assertTrue(controller.box.node.listen.called)
          self.assertTrue(controller.box.consumer)
          controller.box.consumer.consume.assert_called_with()

          self.assertIsNone(cons.connection)
          self.assertTrue(connections[0].closed)

      @patch('kombu.connection.Connection._establish_connection')
      @patch('kombu.utils.sleep')
      def test_connect_errback(self, sleep, connect):
          """``connect`` retries after a first connection error."""
          from kombu.transport.memory import Transport
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)

          def effect():
              # Fail on the first attempt only.
              if connect.call_count > 1:
                  return
              raise StdChannelError()
          connect.side_effect = effect

          # ``Transport`` is shared by the whole process, so the override
          # is undone afterwards instead of leaking into later tests.
          original_errors = Transport.connection_errors
          Transport.connection_errors = (StdChannelError, )
          try:
              cons.connect()
          finally:
              Transport.connection_errors = original_errors
          connect.assert_called_with()

      def test_stop_pidbox_node(self):
          """Stopping the Control step with its events set must not raise."""
          cons = MyKombuConsumer(self.ready_queue, timer=self.timer)
          control = find_step(cons, consumer.Control)
          control._node_stopped = Event()
          control._node_shutdown = Event()
          control._node_stopped.set()
          control.stop(cons)

      def test_start__loop(self):
          """``start`` keeps calling ``loop`` until it raises."""

          class _QoS(object):
              prev = 3
              value = 4

              def update(self):
                  self.prev = self.value

          class _Consumer(MyKombuConsumer):
              iterations = 0

              def reset_connection(self):
                  if self.iterations >= 1:
                      raise KeyError('foo')

          init_callback = Mock()
          cons = _Consumer(self.ready_queue, timer=self.timer,
                           init_callback=init_callback)
          cons.task_consumer = Mock()
          cons.broadcast_consumer = Mock()
          cons.qos = _QoS()
          cons.connection = Connection()
          cons.iterations = 0

          def raises_KeyError(*args, **kwargs):
              # Sync the fake QoS, then fail from the second call on.
              cons.iterations += 1
              if cons.qos.prev != cons.qos.value:
                  cons.qos.update()
              if cons.iterations >= 2:
                  raise KeyError('foo')

          cons.loop = raises_KeyError
          with self.assertRaises(KeyError):
              cons.start()
          self.assertEqual(cons.iterations, 2)
          self.assertEqual(cons.qos.prev, cons.qos.value)

          init_callback.reset_mock()
          cons = _Consumer(self.ready_queue, timer=self.timer,
                           send_events=False, init_callback=init_callback)
          cons.qos = _QoS()
          cons.task_consumer = Mock()
          cons.broadcast_consumer = Mock()
          cons.connection = Connection()
          cons.loop = Mock(side_effect=socket.error('foo'))
          with self.assertRaises(socket.error):
              cons.start()
          self.assertTrue(cons.loop.call_count)

      def test_reset_connection_with_no_node(self):
          """A consumer without a pool can still start its namespace."""
          cons = Consumer(self.ready_queue, timer=self.timer)
          cons.steps.pop()
          self.assertEqual(None, cons.pool)
          cons.namespace.start(cons)

      def test_on_task_revoked(self):
          """``on_task`` accepts a task that reports itself revoked."""
          cons = Consumer(self.ready_queue, timer=self.timer)
          task = Mock()
          task.revoked.return_value = True
          cons.on_task(task)

      def test_on_task_no_events(self):
          """``on_task`` works with events disabled and no ETA."""
          cons = Consumer(self.ready_queue, timer=self.timer)
          task = Mock()
          task.revoked.return_value = False
          cons.event_dispatcher = Mock()
          cons.event_dispatcher.enabled = False
          task.eta = None
          cons._does_info = False
          cons.on_task(task)
  664. class test_WorkController(AppCase):
  def setup(self):
      """Create ``self.worker`` and replace both module loggers by mocks.

      The real ``celery.worker`` and ``components`` loggers are kept on the
      instance as ``_logger`` and ``_comp_logger``.
      """
      from celery import worker as worker_module

      self.worker = self.create_worker()
      self._logger = worker_module.logger
      self._comp_logger = components.logger

      mock_logger = Mock()
      self.logger = mock_logger
      worker_module.logger = mock_logger

      mock_comp_logger = Mock()
      self.comp_logger = mock_comp_logger
      components.logger = mock_comp_logger
  def teardown(self):
      """Put back the loggers stored as ``_logger`` / ``_comp_logger``."""
      from celery import worker as worker_module

      worker_module.logger = self._logger
      components.logger = self._comp_logger
  def create_worker(self, **kw):
      """Build a ``WorkController`` through ``self.app`` and return it.

      The controller is created with ``concurrency=1`` and ``loglevel=0``
      plus any ``kw`` overrides, and its namespace's ``shutdown_complete``
      event is set before it is returned.
      """
      controller = self.app.WorkController(concurrency=1, loglevel=0, **kw)
      shutdown_complete = controller.namespace.shutdown_complete
      shutdown_complete.set()
      return controller
  @patch('celery.platforms.create_pidlock')
  def test_use_pidfile(self, create_pidlock):
      """A ``pidfile`` worker takes the lock on start, frees it on stop."""
      create_pidlock.return_value = Mock()
      controller = self.create_worker(pidfile='pidfilelockfilepid')
      controller.steps = []

      controller.start()
      self.assertTrue(create_pidlock.called)

      controller.stop()
      self.assertTrue(controller.pidlock.release.called)
  @patch('celery.platforms.signals')
  @patch('celery.platforms.set_mp_process_title')
  def test_process_initializer(self, set_mp_process_title, _signals):
      """``process_initializer`` sets up signals, loader, app and title."""
      from celery import Celery
      from celery import signals
      from celery._state import _tls
      from celery.concurrency.processes import process_initializer
      from celery.concurrency.processes import (WORKER_SIGRESET,
                                                WORKER_SIGIGNORE)

      def on_worker_process_init(**kwargs):
          # Record on the function itself that the signal fired.
          on_worker_process_init.called = True
      on_worker_process_init.called = False
      signals.worker_process_init.connect(on_worker_process_init)

      mock_loader = Mock()
      mock_loader.override_backends = {}
      app = Celery(loader=mock_loader, set_as_current=False)
      app.loader = mock_loader
      app.conf = AttributeDict(DEFAULTS)

      process_initializer(app, 'awesome.worker.com')

      _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
      _signals.reset.assert_any_call(*WORKER_SIGRESET)
      self.assertTrue(app.loader.init_worker.call_count)
      self.assertTrue(on_worker_process_init.called)
      self.assertIs(_tls.current_app, app)
      set_mp_process_title.assert_called_with(
          'celeryd', hostname='awesome.worker.com')
  def test_with_rate_limits_disabled(self):
      """With rate limits disabled the ready queue still has ``put``."""
      controller = WorkController(concurrency=1, loglevel=0,
                                  disable_rate_limits=True)
      ready_queue = controller.ready_queue
      self.assertTrue(hasattr(ready_queue, 'put'))
  def test_attrs(self):
      """The worker from ``setup`` exposes its main components."""
      controller = self.worker
      self.assertIsInstance(controller.timer, Timer)
      # Same order as before: each component must be truthy.
      for component in ('timer', 'pool', 'consumer', 'mediator', 'steps'):
          self.assertTrue(getattr(controller, component))
  def test_with_embedded_celerybeat(self):
      """``beat=True`` adds a beat object that backs one of the steps."""
      controller = WorkController(concurrency=1, loglevel=0, beat=True)
      self.assertTrue(controller.beat)
      step_objects = [step.obj for step in controller.steps]
      self.assertIn(controller.beat, step_objects)
  def test_with_autoscaler(self):
      """Passing ``autoscale`` gives the worker an autoscaler."""
      options = dict(autoscale=[10, 3], send_events=False,
                     timer_cls='celery.utils.timer2.Timer')
      controller = self.create_worker(**options)
      self.assertTrue(controller.autoscaler)
  def test_dont_stop_or_terminate(self):
      """``stop``/``terminate`` must not move the namespace to CLOSE here."""
      controller = WorkController(concurrency=1, loglevel=0)
      controller.stop()
      self.assertNotEqual(controller.namespace.state, CLOSE)
      controller.terminate()
      self.assertNotEqual(controller.namespace.state, CLOSE)

      # Pretend the pool is not signal safe; restored in ``finally``.
      sigsafe = controller.pool.signal_safe
      controller.pool.signal_safe = False
      try:
          controller.namespace.state = RUN
          controller.stop(in_sighandler=True)
          self.assertNotEqual(controller.namespace.state, CLOSE)
          controller.terminate(in_sighandler=True)
          self.assertNotEqual(controller.namespace.state, CLOSE)
      finally:
          controller.pool.signal_safe = sigsafe
  def test_on_timer_error(self):
      """``on_timer_error`` logs an error that mentions the exception."""
      controller = WorkController(concurrency=1, loglevel=0)
      timer_component = components.Timer(controller)

      try:
          raise KeyError('foo')
      except KeyError as exc:
          timer_component.on_timer_error(exc)
          logged = self.comp_logger.error.call_args[0]
          msg, args = logged
          self.assertIn('KeyError', msg % args)
  def test_on_timer_tick(self):
      """``on_timer_tick`` debug-logs the delay it was given."""
      controller = WorkController(concurrency=1, loglevel=10)

      components.Timer(controller).on_timer_tick(30.0)
      logged = self.comp_logger.debug.call_args[0]
      fmt = logged[0]
      arg = logged[1]
      self.assertEqual(30.0, arg)
      self.assertIn('Next eta %s secs', fmt)
  def test_process_task(self):
      """``process_task`` hands the request to the pool exactly once."""
      controller = self.worker
      controller.pool = Mock()
      backend = Mock()
      message = create_message(backend, task=foo_task.name,
                               args=[4, 8, 10], kwargs={})
      request = Request.from_message(message, message.decode())
      controller.process_task(request)
      self.assertEqual(controller.pool.apply_async.call_count, 1)
      controller.pool.stop()
  def test_process_task_raise_base(self):
      """``KeyboardInterrupt`` from the pool propagates and terminates."""
      controller = self.worker
      controller.pool = Mock()
      controller.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
      backend = Mock()
      message = create_message(backend, task=foo_task.name,
                               args=[4, 8, 10], kwargs={})
      request = Request.from_message(message, message.decode())
      controller.steps = []
      controller.namespace.state = RUN
      with self.assertRaises(KeyboardInterrupt):
          controller.process_task(request)
      self.assertEqual(controller.namespace.state, TERMINATE)
  def test_process_task_raise_SystemTerminate(self):
      """``SystemTerminate`` from the pool ends in ``SystemExit``."""
      controller = self.worker
      controller.pool = Mock()
      controller.pool.apply_async.side_effect = SystemTerminate()
      backend = Mock()
      message = create_message(backend, task=foo_task.name,
                               args=[4, 8, 10], kwargs={})
      request = Request.from_message(message, message.decode())
      controller.steps = []
      controller.namespace.state = RUN
      with self.assertRaises(SystemExit):
          controller.process_task(request)
      self.assertEqual(controller.namespace.state, TERMINATE)
  def test_process_task_raise_regular(self):
      """An ordinary exception from the pool does not escape."""
      controller = self.worker
      controller.pool = Mock()
      controller.pool.apply_async.side_effect = KeyError('some exception')
      backend = Mock()
      message = create_message(backend, task=foo_task.name,
                               args=[4, 8, 10], kwargs={})
      request = Request.from_message(message, message.decode())
      controller.process_task(request)
      controller.pool.stop()
  def test_start_catches_base_exceptions(self):
      """``start`` swallows ``SystemTerminate``/``SystemExit`` from a step."""
      terminating = self.create_worker()
      failing_step = MockStep()
      failing_step.start.side_effect = SystemTerminate()
      terminating.steps = [failing_step]
      terminating.start()
      failing_step.start.assert_called_with(terminating)
      self.assertTrue(failing_step.terminate.call_count)

      exiting = self.create_worker()
      exiting_step = MockStep()
      exiting_step.start.side_effect = SystemExit()
      # No ``terminate`` on this step, so ``stop`` is expected instead.
      exiting_step.terminate = None
      exiting.steps = [exiting_step]
      exiting.start()
      self.assertTrue(exiting_step.stop.call_count)
  826. def test_state_db(self):
  827. from celery.worker import state
  828. Persistent = state.Persistent
  829. state.Persistent = Mock()
  830. try:
  831. worker = self.create_worker(state_db='statefilename')
  832. self.assertTrue(worker._persistence)
  833. finally:
  834. state.Persistent = Persistent
  835. def test_disable_rate_limits_solo(self):
  836. worker = self.create_worker(disable_rate_limits=True,
  837. pool_cls='solo')
  838. self.assertIsInstance(worker.ready_queue, FastQueue)
  839. self.assertIsNone(worker.mediator)
  840. self.assertEqual(worker.ready_queue.put, worker.process_task)
  841. def test_disable_rate_limits_processes(self):
  842. try:
  843. worker = self.create_worker(disable_rate_limits=True,
  844. pool_cls='processes')
  845. except ImportError:
  846. raise SkipTest('multiprocessing not supported')
  847. self.assertIsInstance(worker.ready_queue, FastQueue)
  848. self.assertTrue(worker.mediator)
  849. self.assertNotEqual(worker.ready_queue.put, worker.process_task)
  850. def test_process_task_sem(self):
  851. worker = self.worker
  852. worker._quick_acquire = Mock()
  853. req = Mock()
  854. worker.process_task_sem(req)
  855. worker._quick_acquire.assert_called_with(worker.process_task, req)
  856. def test_signal_consumer_close(self):
  857. worker = self.worker
  858. worker.consumer = Mock()
  859. worker.signal_consumer_close()
  860. worker.consumer.close.assert_called_with()
  861. worker.consumer.close.side_effect = AttributeError()
  862. worker.signal_consumer_close()
  863. def test_start__stop(self):
  864. worker = self.worker
  865. worker.namespace.shutdown_complete.set()
  866. worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
  867. worker.namespace.state = RUN
  868. worker.namespace.started = 4
  869. for w in worker.steps:
  870. w.start = Mock()
  871. w.stop = Mock()
  872. worker.start()
  873. for w in worker.steps:
  874. self.assertTrue(w.start.call_count)
  875. worker.stop()
  876. for w in worker.steps:
  877. self.assertTrue(w.stop.call_count)
  878. # Doesn't close pool if no pool.
  879. worker.start()
  880. worker.pool = None
  881. worker.stop()
  882. # test that stop of None is not attempted
  883. worker.steps[-1] = None
  884. worker.start()
  885. worker.stop()
  886. def test_step_raises(self):
  887. worker = self.worker
  888. step = Mock()
  889. worker.steps = [step]
  890. step.start.side_effect = TypeError()
  891. worker.stop = Mock()
  892. worker.start()
  893. worker.stop.assert_called_with()
  894. def test_state(self):
  895. self.assertTrue(self.worker.state)
  896. def test_start__terminate(self):
  897. worker = self.worker
  898. worker.namespace.shutdown_complete.set()
  899. worker.namespace.started = 5
  900. worker.namespace.state = RUN
  901. worker.steps = [MockStep() for _ in range(5)]
  902. worker.start()
  903. for w in worker.steps[:3]:
  904. self.assertTrue(w.start.call_count)
  905. self.assertTrue(worker.namespace.started, len(worker.steps))
  906. self.assertEqual(worker.namespace.state, RUN)
  907. worker.terminate()
  908. for step in worker.steps:
  909. self.assertTrue(step.terminate.call_count)
  910. def test_Queues_pool_not_rlimit_safe(self):
  911. w = Mock()
  912. w.pool_cls.rlimit_safe = False
  913. components.Queues(w).create(w)
  914. self.assertTrue(w.disable_rate_limits)
  915. def test_Queues_pool_no_sem(self):
  916. w = Mock()
  917. w.pool_cls.uses_semaphore = False
  918. components.Queues(w).create(w)
  919. self.assertIs(w.ready_queue.put, w.process_task)
  920. def test_Hub_crate(self):
  921. w = Mock()
  922. x = components.Hub(w)
  923. hub = x.create(w)
  924. self.assertTrue(w.timer.max_interval)
  925. self.assertIs(w.hub, hub)
  926. def test_Pool_crate_threaded(self):
  927. w = Mock()
  928. w.pool_cls = Mock()
  929. w.use_eventloop = False
  930. pool = components.Pool(w)
  931. pool.create(w)
def test_Pool_create(self):
    # Exercises components.Pool.create() with the event loop enabled and
    # then drives the callbacks it registers on the hub and on the pool.
    from celery.worker.hub import BoundedSemaphore
    # Worker stub: mock hub with an empty on_init list, a mock pool class
    # whose instance (P) exposes a ``timers`` mapping.
    w = Mock()
    w.hub = Mock()
    w.hub.on_init = []
    w.pool_cls = Mock()
    P = w.pool_cls.return_value = Mock()
    P.timers = {Mock(): 30}
    w.use_eventloop = True
    pool = components.Pool(w)
    pool.create(w)
    # create() is expected to set a BoundedSemaphore on the worker and to
    # register at least one hub on_init callback.
    self.assertIsInstance(w.semaphore, BoundedSemaphore)
    self.assertTrue(w.hub.on_init)

    # Invoke the first registered on_init callback with a fresh mock hub.
    hub = Mock()
    w.hub.on_init[0](hub)

    # Keyword arguments of the most recent pool.init_callbacks() call.
    cbs = w.pool.init_callbacks.call_args[1]
    # NOTE: ``w`` is rebound here; from this point it is the argument
    # passed to the process up/down callbacks, not the worker stub.
    w = Mock()
    cbs['on_process_up'](w)
    hub.add_reader.assert_called_with(w.sentinel, P.maintain_pool)

    cbs['on_process_down'](w)
    hub.remove.assert_called_with(w.sentinel)

    result = Mock()
    # Capture the timer reference before cancelling so it can be checked.
    tref = result._tref

    cbs['on_timeout_cancel'](result)
    tref.cancel.assert_called_with()
    cbs['on_timeout_cancel'](result)  # no more tref

    # Positional args of the latest hub.timer.apply_after() call are
    # unpacked as (tsoft, callback); the callback is then invoked directly.
    cbs['on_timeout_set'](result, 10, 20)
    tsoft, callback = hub.timer.apply_after.call_args[0]
    callback()

    cbs['on_timeout_set'](result, 10, None)
    tsoft, callback = hub.timer.apply_after.call_args[0]
    callback()
    # Remaining argument combinations are only checked for not raising.
    cbs['on_timeout_set'](result, None, 10)
    cbs['on_timeout_set'](result, None, None)

    # If the pool reports it did not start OK, on_poll_init must raise
    # WorkerLostError.
    P.did_start_ok.return_value = False
    with self.assertRaises(WorkerLostError):
        pool.on_poll_init(P, hub)