test_worker.py 38 KB


  1. from __future__ import absolute_import
  2. import os
  3. import socket
  4. from collections import deque
  5. from datetime import datetime, timedelta
  6. from threading import Event
  7. from billiard.exceptions import WorkerLostError
  8. from kombu import Connection
  9. from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
  10. from kombu.exceptions import StdChannelError
  11. from kombu.transport.base import Message
  12. from mock import call, Mock, patch
  13. from celery.app.defaults import DEFAULTS
  14. from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
  15. from celery.concurrency.base import BasePool
  16. from celery.datastructures import AttributeDict
  17. from celery.exceptions import SystemTerminate
  18. from celery.five import Empty, range, Queue as FastQueue
  19. from celery.task import task as task_dec
  20. from celery.task import periodic_task as periodic_task_dec
  21. from celery.utils import uuid
  22. from celery.worker import WorkController
  23. from celery.worker import components
  24. from celery.worker import consumer
  25. from celery.worker.consumer import Consumer as __Consumer
  26. from celery.worker.hub import READ, ERR
  27. from celery.worker.job import Request
  28. from celery.utils.serialization import pickle
  29. from celery.utils.timer2 import Timer
  30. from celery.tests.utils import AppCase, Case
  31. def MockStep(step=None):
  32. step = Mock() if step is None else step
  33. step.blueprint = Mock()
  34. step.blueprint.name = 'MockNS'
  35. step.name = 'MockStep(%s)' % (id(step), )
  36. return step
  37. class PlaceHolder(object):
  38. pass
  39. def find_step(obj, typ):
  40. return obj.blueprint.steps[typ.name]
  41. class Consumer(__Consumer):
  42. def __init__(self, *args, **kwargs):
  43. kwargs.setdefault('enable_mingle', False) # disable Mingle step
  44. kwargs.setdefault('enable_gossip', False) # disable Gossip step
  45. kwargs.setdefault('enable_heartbeat', False) # disable Heart step
  46. super(Consumer, self).__init__(*args, **kwargs)
  47. class _MyKombuConsumer(Consumer):
  48. broadcast_consumer = Mock()
  49. task_consumer = Mock()
  50. def __init__(self, *args, **kwargs):
  51. kwargs.setdefault('pool', BasePool(2))
  52. super(_MyKombuConsumer, self).__init__(*args, **kwargs)
  53. def restart_heartbeat(self):
  54. self.heart = None
  55. class MyKombuConsumer(Consumer):
  56. def loop(self, *args, **kwargs):
  57. pass
  58. class MockNode(object):
  59. commands = []
  60. def handle_message(self, body, message):
  61. self.commands.append(body.pop('command', None))
  62. class MockEventDispatcher(object):
  63. sent = []
  64. closed = False
  65. flushed = False
  66. _outbound_buffer = []
  67. def send(self, event, *args, **kwargs):
  68. self.sent.append(event)
  69. def close(self):
  70. self.closed = True
  71. def flush(self):
  72. self.flushed = True
  73. class MockHeart(object):
  74. closed = False
  75. def stop(self):
  76. self.closed = True
  77. @task_dec()
  78. def foo_task(x, y, z, **kwargs):
  79. return x * y * z
  80. @periodic_task_dec(run_every=60)
  81. def foo_periodic_task():
  82. return 'foo'
  83. def create_message(channel, **data):
  84. data.setdefault('id', uuid())
  85. channel.no_ack_consumers = set()
  86. return Message(channel, body=pickle.dumps(dict(**data)),
  87. content_type='application/x-python-serialize',
  88. content_encoding='binary',
  89. delivery_info={'consumer_tag': 'mock'})
  90. class test_QoS(Case):
  91. class _QoS(QoS):
  92. def __init__(self, value):
  93. self.value = value
  94. QoS.__init__(self, None, value)
  95. def set(self, value):
  96. return value
  97. def test_qos_increment_decrement(self):
  98. qos = self._QoS(10)
  99. self.assertEqual(qos.increment_eventually(), 11)
  100. self.assertEqual(qos.increment_eventually(3), 14)
  101. self.assertEqual(qos.increment_eventually(-30), 14)
  102. self.assertEqual(qos.decrement_eventually(7), 7)
  103. self.assertEqual(qos.decrement_eventually(), 6)
  104. def test_qos_disabled_increment_decrement(self):
  105. qos = self._QoS(0)
  106. self.assertEqual(qos.increment_eventually(), 0)
  107. self.assertEqual(qos.increment_eventually(3), 0)
  108. self.assertEqual(qos.increment_eventually(-30), 0)
  109. self.assertEqual(qos.decrement_eventually(7), 0)
  110. self.assertEqual(qos.decrement_eventually(), 0)
  111. self.assertEqual(qos.decrement_eventually(10), 0)
  112. def test_qos_thread_safe(self):
  113. qos = self._QoS(10)
  114. def add():
  115. for i in range(1000):
  116. qos.increment_eventually()
  117. def sub():
  118. for i in range(1000):
  119. qos.decrement_eventually()
  120. def threaded(funs):
  121. from threading import Thread
  122. threads = [Thread(target=fun) for fun in funs]
  123. for thread in threads:
  124. thread.start()
  125. for thread in threads:
  126. thread.join()
  127. threaded([add, add])
  128. self.assertEqual(qos.value, 2010)
  129. qos.value = 1000
  130. threaded([add, sub]) # n = 2
  131. self.assertEqual(qos.value, 1000)
  132. def test_exceeds_short(self):
  133. qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
  134. qos.update()
  135. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  136. qos.increment_eventually()
  137. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  138. qos.increment_eventually()
  139. self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
  140. qos.decrement_eventually()
  141. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  142. qos.decrement_eventually()
  143. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  144. def test_consumer_increment_decrement(self):
  145. mconsumer = Mock()
  146. qos = QoS(mconsumer.qos, 10)
  147. qos.update()
  148. self.assertEqual(qos.value, 10)
  149. mconsumer.qos.assert_called_with(prefetch_count=10)
  150. qos.decrement_eventually()
  151. qos.update()
  152. self.assertEqual(qos.value, 9)
  153. mconsumer.qos.assert_called_with(prefetch_count=9)
  154. qos.decrement_eventually()
  155. self.assertEqual(qos.value, 8)
  156. mconsumer.qos.assert_called_with(prefetch_count=9)
  157. self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)
  158. # Does not decrement 0 value
  159. qos.value = 0
  160. qos.decrement_eventually()
  161. self.assertEqual(qos.value, 0)
  162. qos.increment_eventually()
  163. self.assertEqual(qos.value, 0)
  164. def test_consumer_decrement_eventually(self):
  165. mconsumer = Mock()
  166. qos = QoS(mconsumer.qos, 10)
  167. qos.decrement_eventually()
  168. self.assertEqual(qos.value, 9)
  169. qos.value = 0
  170. qos.decrement_eventually()
  171. self.assertEqual(qos.value, 0)
  172. def test_set(self):
  173. mconsumer = Mock()
  174. qos = QoS(mconsumer.qos, 10)
  175. qos.set(12)
  176. self.assertEqual(qos.prev, 12)
  177. qos.set(qos.prev)
  178. class test_Consumer(AppCase):
  179. def setup(self):
  180. self.buffer = FastQueue()
  181. self.timer = Timer()
  182. def teardown(self):
  183. self.timer.stop()
  184. def test_info(self):
  185. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  186. l.task_consumer = Mock()
  187. l.qos = QoS(l.task_consumer.qos, 10)
  188. l.connection = Mock()
  189. l.connection.info.return_value = {'foo': 'bar'}
  190. l.controller = l.app.WorkController()
  191. l.controller.pool = Mock()
  192. l.controller.pool.info.return_value = [Mock(), Mock()]
  193. l.controller.consumer = l
  194. info = l.controller.stats()
  195. self.assertEqual(info['prefetch_count'], 10)
  196. self.assertTrue(info['broker'])
  197. def test_start_when_closed(self):
  198. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  199. l.blueprint.state = CLOSE
  200. l.start()
  201. def test_connection(self):
  202. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  203. l.blueprint.start(l)
  204. self.assertIsInstance(l.connection, Connection)
  205. l.blueprint.state = RUN
  206. l.event_dispatcher = None
  207. l.blueprint.restart(l)
  208. self.assertTrue(l.connection)
  209. l.blueprint.state = RUN
  210. l.shutdown()
  211. self.assertIsNone(l.connection)
  212. self.assertIsNone(l.task_consumer)
  213. l.blueprint.start(l)
  214. self.assertIsInstance(l.connection, Connection)
  215. l.blueprint.restart(l)
  216. l.stop()
  217. l.shutdown()
  218. self.assertIsNone(l.connection)
  219. self.assertIsNone(l.task_consumer)
  220. def test_close_connection(self):
  221. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  222. l.blueprint.state = RUN
  223. step = find_step(l, consumer.Connection)
  224. conn = l.connection = Mock()
  225. step.shutdown(l)
  226. self.assertTrue(conn.close.called)
  227. self.assertIsNone(l.connection)
  228. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  229. eventer = l.event_dispatcher = Mock()
  230. eventer.enabled = True
  231. heart = l.heart = MockHeart()
  232. l.blueprint.state = RUN
  233. Events = find_step(l, consumer.Events)
  234. Events.shutdown(l)
  235. Heart = find_step(l, consumer.Heart)
  236. Heart.shutdown(l)
  237. self.assertTrue(eventer.close.call_count)
  238. self.assertTrue(heart.closed)
  239. @patch('celery.worker.consumer.warn')
  240. def test_receive_message_unknown(self, warn):
  241. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  242. l.steps.pop()
  243. backend = Mock()
  244. m = create_message(backend, unknown={'baz': '!!!'})
  245. l.event_dispatcher = Mock()
  246. l.node = MockNode()
  247. callback = self._get_on_message(l)
  248. callback(m.decode(), m)
  249. self.assertTrue(warn.call_count)
  250. @patch('celery.worker.strategy.to_timestamp')
  251. def test_receive_message_eta_OverflowError(self, to_timestamp):
  252. to_timestamp.side_effect = OverflowError()
  253. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  254. l.steps.pop()
  255. m = create_message(Mock(), task=foo_task.name,
  256. args=('2, 2'),
  257. kwargs={},
  258. eta=datetime.now().isoformat())
  259. l.event_dispatcher = Mock()
  260. l.node = MockNode()
  261. l.update_strategies()
  262. l.qos = Mock()
  263. callback = self._get_on_message(l)
  264. callback(m.decode(), m)
  265. self.assertTrue(m.acknowledged)
  266. @patch('celery.worker.consumer.error')
  267. def test_receive_message_InvalidTaskError(self, error):
  268. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  269. l.event_dispatcher = Mock()
  270. l.steps.pop()
  271. m = create_message(Mock(), task=foo_task.name,
  272. args=(1, 2), kwargs='foobarbaz', id=1)
  273. l.update_strategies()
  274. l.event_dispatcher = Mock()
  275. callback = self._get_on_message(l)
  276. callback(m.decode(), m)
  277. self.assertIn('Received invalid task message', error.call_args[0][0])
  278. @patch('celery.worker.consumer.crit')
  279. def test_on_decode_error(self, crit):
  280. l = Consumer(self.buffer.put, timer=self.timer)
  281. class MockMessage(Mock):
  282. content_type = 'application/x-msgpack'
  283. content_encoding = 'binary'
  284. body = 'foobarbaz'
  285. message = MockMessage()
  286. l.on_decode_error(message, KeyError('foo'))
  287. self.assertTrue(message.ack.call_count)
  288. self.assertIn("Can't decode message body", crit.call_args[0][0])
  289. def _get_on_message(self, l):
  290. if l.qos is None:
  291. l.qos = Mock()
  292. l.event_dispatcher = Mock()
  293. l.task_consumer = Mock()
  294. l.connection = Mock()
  295. l.connection.drain_events.side_effect = SystemExit()
  296. with self.assertRaises(SystemExit):
  297. l.loop(*l.loop_args())
  298. self.assertTrue(l.task_consumer.register_callback.called)
  299. return l.task_consumer.register_callback.call_args[0][0]
  300. def test_receieve_message(self):
  301. l = Consumer(self.buffer.put, timer=self.timer)
  302. l.event_dispatcher = Mock()
  303. m = create_message(Mock(), task=foo_task.name,
  304. args=[2, 4, 8], kwargs={})
  305. l.update_strategies()
  306. callback = self._get_on_message(l)
  307. callback(m.decode(), m)
  308. in_bucket = self.buffer.get_nowait()
  309. self.assertIsInstance(in_bucket, Request)
  310. self.assertEqual(in_bucket.name, foo_task.name)
  311. self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
  312. self.assertTrue(self.timer.empty())
  313. def test_start_channel_error(self):
  314. class MockConsumer(Consumer):
  315. iterations = 0
  316. def loop(self, *args, **kwargs):
  317. if not self.iterations:
  318. self.iterations = 1
  319. raise KeyError('foo')
  320. raise SyntaxError('bar')
  321. l = MockConsumer(self.buffer.put, timer=self.timer,
  322. send_events=False, pool=BasePool())
  323. l.channel_errors = (KeyError, )
  324. with self.assertRaises(KeyError):
  325. l.start()
  326. l.timer.stop()
  327. def test_start_connection_error(self):
  328. class MockConsumer(Consumer):
  329. iterations = 0
  330. def loop(self, *args, **kwargs):
  331. if not self.iterations:
  332. self.iterations = 1
  333. raise KeyError('foo')
  334. raise SyntaxError('bar')
  335. l = MockConsumer(self.buffer.put, timer=self.timer,
  336. send_events=False, pool=BasePool())
  337. l.connection_errors = (KeyError, )
  338. self.assertRaises(SyntaxError, l.start)
  339. l.timer.stop()
  340. def test_loop_ignores_socket_timeout(self):
  341. class Connection(self.app.connection().__class__):
  342. obj = None
  343. def drain_events(self, **kwargs):
  344. self.obj.connection = None
  345. raise socket.timeout(10)
  346. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  347. l.connection = Connection()
  348. l.task_consumer = Mock()
  349. l.connection.obj = l
  350. l.qos = QoS(l.task_consumer.qos, 10)
  351. l.loop(*l.loop_args())
  352. def test_loop_when_socket_error(self):
  353. class Connection(self.app.connection().__class__):
  354. obj = None
  355. def drain_events(self, **kwargs):
  356. self.obj.connection = None
  357. raise socket.error('foo')
  358. l = Consumer(self.buffer.put, timer=self.timer)
  359. l.blueprint.state = RUN
  360. c = l.connection = Connection()
  361. l.connection.obj = l
  362. l.task_consumer = Mock()
  363. l.qos = QoS(l.task_consumer.qos, 10)
  364. with self.assertRaises(socket.error):
  365. l.loop(*l.loop_args())
  366. l.blueprint.state = CLOSE
  367. l.connection = c
  368. l.loop(*l.loop_args())
  369. def test_loop(self):
  370. class Connection(self.app.connection().__class__):
  371. obj = None
  372. def drain_events(self, **kwargs):
  373. self.obj.connection = None
  374. l = Consumer(self.buffer.put, timer=self.timer)
  375. l.connection = Connection()
  376. l.connection.obj = l
  377. l.task_consumer = Mock()
  378. l.qos = QoS(l.task_consumer.qos, 10)
  379. l.loop(*l.loop_args())
  380. l.loop(*l.loop_args())
  381. self.assertTrue(l.task_consumer.consume.call_count)
  382. l.task_consumer.qos.assert_called_with(prefetch_count=10)
  383. self.assertEqual(l.qos.value, 10)
  384. l.qos.decrement_eventually()
  385. self.assertEqual(l.qos.value, 9)
  386. l.qos.update()
  387. self.assertEqual(l.qos.value, 9)
  388. l.task_consumer.qos.assert_called_with(prefetch_count=9)
  389. def test_ignore_errors(self):
  390. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  391. l.connection_errors = (AttributeError, KeyError, )
  392. l.channel_errors = (SyntaxError, )
  393. ignore_errors(l, Mock(side_effect=AttributeError('foo')))
  394. ignore_errors(l, Mock(side_effect=KeyError('foo')))
  395. ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
  396. with self.assertRaises(IndexError):
  397. ignore_errors(l, Mock(side_effect=IndexError('foo')))
  398. def test_apply_eta_task(self):
  399. from celery.worker import state
  400. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  401. l.qos = QoS(None, 10)
  402. task = object()
  403. qos = l.qos.value
  404. l.apply_eta_task(task)
  405. self.assertIn(task, state.reserved_requests)
  406. self.assertEqual(l.qos.value, qos - 1)
  407. self.assertIs(self.buffer.get_nowait(), task)
  408. def test_receieve_message_eta_isoformat(self):
  409. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  410. l.steps.pop()
  411. m = create_message(
  412. Mock(), task=foo_task.name,
  413. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  414. args=[2, 4, 8], kwargs={},
  415. )
  416. l.task_consumer = Mock()
  417. l.qos = QoS(l.task_consumer.qos, 1)
  418. current_pcount = l.qos.value
  419. l.event_dispatcher = Mock()
  420. l.enabled = False
  421. l.update_strategies()
  422. callback = self._get_on_message(l)
  423. callback(m.decode(), m)
  424. l.timer.stop()
  425. l.timer.join(1)
  426. items = [entry[2] for entry in self.timer.queue]
  427. found = 0
  428. for item in items:
  429. if item.args[0].name == foo_task.name:
  430. found = True
  431. self.assertTrue(found)
  432. self.assertGreater(l.qos.value, current_pcount)
  433. l.timer.stop()
  434. def test_pidbox_callback(self):
  435. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  436. con = find_step(l, consumer.Control).box
  437. con.node = Mock()
  438. con.reset = Mock()
  439. con.on_message('foo', 'bar')
  440. con.node.handle_message.assert_called_with('foo', 'bar')
  441. con.node = Mock()
  442. con.node.handle_message.side_effect = KeyError('foo')
  443. con.on_message('foo', 'bar')
  444. con.node.handle_message.assert_called_with('foo', 'bar')
  445. con.node = Mock()
  446. con.node.handle_message.side_effect = ValueError('foo')
  447. con.on_message('foo', 'bar')
  448. con.node.handle_message.assert_called_with('foo', 'bar')
  449. self.assertTrue(con.reset.called)
  450. def test_revoke(self):
  451. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  452. l.steps.pop()
  453. backend = Mock()
  454. id = uuid()
  455. t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
  456. kwargs={}, id=id)
  457. from celery.worker.state import revoked
  458. revoked.add(id)
  459. callback = self._get_on_message(l)
  460. callback(t.decode(), t)
  461. self.assertTrue(self.buffer.empty())
  462. def test_receieve_message_not_registered(self):
  463. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  464. l.steps.pop()
  465. backend = Mock()
  466. m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
  467. l.event_dispatcher = Mock()
  468. callback = self._get_on_message(l)
  469. self.assertFalse(callback(m.decode(), m))
  470. with self.assertRaises(Empty):
  471. self.buffer.get_nowait()
  472. self.assertTrue(self.timer.empty())
  473. @patch('celery.worker.consumer.warn')
  474. @patch('celery.worker.consumer.logger')
  475. def test_receieve_message_ack_raises(self, logger, warn):
  476. l = Consumer(self.buffer.put, timer=self.timer)
  477. backend = Mock()
  478. m = create_message(backend, args=[2, 4, 8], kwargs={})
  479. l.event_dispatcher = Mock()
  480. l.connection_errors = (socket.error, )
  481. m.reject = Mock()
  482. m.reject.side_effect = socket.error('foo')
  483. callback = self._get_on_message(l)
  484. self.assertFalse(callback(m.decode(), m))
  485. self.assertTrue(warn.call_count)
  486. with self.assertRaises(Empty):
  487. self.buffer.get_nowait()
  488. self.assertTrue(self.timer.empty())
  489. m.reject.assert_called_with()
  490. self.assertTrue(logger.critical.call_count)
  491. def test_receive_message_eta(self):
  492. l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
  493. l.steps.pop()
  494. l.event_dispatcher = Mock()
  495. l.event_dispatcher._outbound_buffer = deque()
  496. backend = Mock()
  497. m = create_message(
  498. backend, task=foo_task.name,
  499. args=[2, 4, 8], kwargs={},
  500. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  501. )
  502. l.blueprint.start(l)
  503. p = l.app.conf.BROKER_CONNECTION_RETRY
  504. l.app.conf.BROKER_CONNECTION_RETRY = False
  505. try:
  506. l.blueprint.start(l)
  507. finally:
  508. l.app.conf.BROKER_CONNECTION_RETRY = p
  509. l.blueprint.restart(l)
  510. l.event_dispatcher = Mock()
  511. callback = self._get_on_message(l)
  512. callback(m.decode(), m)
  513. l.timer.stop()
  514. in_hold = l.timer.queue[0]
  515. self.assertEqual(len(in_hold), 3)
  516. eta, priority, entry = in_hold
  517. task = entry.args[0]
  518. self.assertIsInstance(task, Request)
  519. self.assertEqual(task.name, foo_task.name)
  520. self.assertEqual(task.execute(), 2 * 4 * 8)
  521. with self.assertRaises(Empty):
  522. self.buffer.get_nowait()
  523. def test_reset_pidbox_node(self):
  524. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  525. con = find_step(l, consumer.Control).box
  526. con.node = Mock()
  527. chan = con.node.channel = Mock()
  528. l.connection = Mock()
  529. chan.close.side_effect = socket.error('foo')
  530. l.connection_errors = (socket.error, )
  531. con.reset()
  532. chan.close.assert_called_with()
  533. def test_reset_pidbox_node_green(self):
  534. from celery.worker.pidbox import gPidbox
  535. pool = Mock()
  536. pool.is_green = True
  537. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
  538. con = find_step(l, consumer.Control)
  539. self.assertIsInstance(con.box, gPidbox)
  540. con.start(l)
  541. l.pool.spawn_n.assert_called_with(
  542. con.box.loop, l,
  543. )
  544. def test__green_pidbox_node(self):
  545. pool = Mock()
  546. pool.is_green = True
  547. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
  548. l.node = Mock()
  549. controller = find_step(l, consumer.Control)
  550. class BConsumer(Mock):
  551. def __enter__(self):
  552. self.consume()
  553. return self
  554. def __exit__(self, *exc_info):
  555. self.cancel()
  556. controller.box.node.listen = BConsumer()
  557. connections = []
  558. class Connection(object):
  559. calls = 0
  560. def __init__(self, obj):
  561. connections.append(self)
  562. self.obj = obj
  563. self.default_channel = self.channel()
  564. self.closed = False
  565. def __enter__(self):
  566. return self
  567. def __exit__(self, *exc_info):
  568. self.close()
  569. def channel(self):
  570. return Mock()
  571. def as_uri(self):
  572. return 'dummy://'
  573. def drain_events(self, **kwargs):
  574. if not self.calls:
  575. self.calls += 1
  576. raise socket.timeout()
  577. self.obj.connection = None
  578. controller.box._node_shutdown.set()
  579. def close(self):
  580. self.closed = True
  581. l.connection = Mock()
  582. l.connect = lambda: Connection(obj=l)
  583. controller = find_step(l, consumer.Control)
  584. controller.box.loop(l)
  585. self.assertTrue(controller.box.node.listen.called)
  586. self.assertTrue(controller.box.consumer)
  587. controller.box.consumer.consume.assert_called_with()
  588. self.assertIsNone(l.connection)
  589. self.assertTrue(connections[0].closed)
  590. @patch('kombu.connection.Connection._establish_connection')
  591. @patch('kombu.utils.sleep')
  592. def test_connect_errback(self, sleep, connect):
  593. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  594. from kombu.transport.memory import Transport
  595. Transport.connection_errors = (StdChannelError, )
  596. def effect():
  597. if connect.call_count > 1:
  598. return
  599. raise StdChannelError()
  600. connect.side_effect = effect
  601. l.connect()
  602. connect.assert_called_with()
  603. def test_stop_pidbox_node(self):
  604. l = MyKombuConsumer(self.buffer.put, timer=self.timer)
  605. cont = find_step(l, consumer.Control)
  606. cont._node_stopped = Event()
  607. cont._node_shutdown = Event()
  608. cont._node_stopped.set()
  609. cont.stop(l)
  610. def test_start__loop(self):
  611. class _QoS(object):
  612. prev = 3
  613. value = 4
  614. def update(self):
  615. self.prev = self.value
  616. class _Consumer(MyKombuConsumer):
  617. iterations = 0
  618. def reset_connection(self):
  619. if self.iterations >= 1:
  620. raise KeyError('foo')
  621. init_callback = Mock()
  622. l = _Consumer(self.buffer.put, timer=self.timer,
  623. init_callback=init_callback)
  624. l.task_consumer = Mock()
  625. l.broadcast_consumer = Mock()
  626. l.qos = _QoS()
  627. l.connection = Connection()
  628. l.iterations = 0
  629. def raises_KeyError(*args, **kwargs):
  630. l.iterations += 1
  631. if l.qos.prev != l.qos.value:
  632. l.qos.update()
  633. if l.iterations >= 2:
  634. raise KeyError('foo')
  635. l.loop = raises_KeyError
  636. with self.assertRaises(KeyError):
  637. l.start()
  638. self.assertEqual(l.iterations, 2)
  639. self.assertEqual(l.qos.prev, l.qos.value)
  640. init_callback.reset_mock()
  641. l = _Consumer(self.buffer.put, timer=self.timer,
  642. send_events=False, init_callback=init_callback)
  643. l.qos = _QoS()
  644. l.task_consumer = Mock()
  645. l.broadcast_consumer = Mock()
  646. l.connection = Connection()
  647. l.loop = Mock(side_effect=socket.error('foo'))
  648. with self.assertRaises(socket.error):
  649. l.start()
  650. self.assertTrue(l.loop.call_count)
  651. def test_reset_connection_with_no_node(self):
  652. l = Consumer(self.buffer.put, timer=self.timer)
  653. l.steps.pop()
  654. self.assertEqual(None, l.pool)
  655. l.blueprint.start(l)
  656. class test_WorkController(AppCase):
  657. def setup(self):
  658. self.worker = self.create_worker()
  659. from celery import worker
  660. self._logger = worker.logger
  661. self._comp_logger = components.logger
  662. self.logger = worker.logger = Mock()
  663. self.comp_logger = components.logger = Mock()
  664. def teardown(self):
  665. from celery import worker
  666. worker.logger = self._logger
  667. components.logger = self._comp_logger
  668. def create_worker(self, **kw):
  669. worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
  670. worker.blueprint.shutdown_complete.set()
  671. return worker
  672. @patch('celery.platforms.create_pidlock')
  673. def test_use_pidfile(self, create_pidlock):
  674. create_pidlock.return_value = Mock()
  675. worker = self.create_worker(pidfile='pidfilelockfilepid')
  676. worker.steps = []
  677. worker.start()
  678. self.assertTrue(create_pidlock.called)
  679. worker.stop()
  680. self.assertTrue(worker.pidlock.release.called)
  681. @patch('celery.platforms.signals')
  682. @patch('celery.platforms.set_mp_process_title')
  683. def test_process_initializer(self, set_mp_process_title, _signals):
  684. from celery import Celery
  685. from celery import signals
  686. from celery._state import _tls
  687. from celery.concurrency.processes import process_initializer
  688. from celery.concurrency.processes import (WORKER_SIGRESET,
  689. WORKER_SIGIGNORE)
  690. def on_worker_process_init(**kwargs):
  691. on_worker_process_init.called = True
  692. on_worker_process_init.called = False
  693. signals.worker_process_init.connect(on_worker_process_init)
  694. loader = Mock()
  695. loader.override_backends = {}
  696. app = Celery(loader=loader, set_as_current=False)
  697. app.loader = loader
  698. app.conf = AttributeDict(DEFAULTS)
  699. process_initializer(app, 'awesome.worker.com')
  700. _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
  701. _signals.reset.assert_any_call(*WORKER_SIGRESET)
  702. self.assertTrue(app.loader.init_worker.call_count)
  703. self.assertTrue(on_worker_process_init.called)
  704. self.assertIs(_tls.current_app, app)
  705. set_mp_process_title.assert_called_with(
  706. 'celeryd', hostname='awesome.worker.com',
  707. )
  708. with patch('celery.task.trace.setup_worker_optimizations') as swo:
  709. os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
  710. try:
  711. process_initializer(app, 'luke.worker.com')
  712. swo.assert_called_with(app)
  713. finally:
  714. os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
  715. def test_attrs(self):
  716. worker = self.worker
  717. self.assertIsInstance(worker.timer, Timer)
  718. self.assertTrue(worker.timer)
  719. self.assertTrue(worker.pool)
  720. self.assertTrue(worker.consumer)
  721. self.assertTrue(worker.steps)
  722. def test_with_embedded_beat(self):
  723. worker = WorkController(concurrency=1, loglevel=0, beat=True)
  724. self.assertTrue(worker.beat)
  725. self.assertIn(worker.beat, [w.obj for w in worker.steps])
  726. def test_with_autoscaler(self):
  727. worker = self.create_worker(
  728. autoscale=[10, 3], send_events=False,
  729. timer_cls='celery.utils.timer2.Timer',
  730. )
  731. self.assertTrue(worker.autoscaler)
  732. def test_dont_stop_or_terminate(self):
  733. worker = WorkController(concurrency=1, loglevel=0)
  734. worker.stop()
  735. self.assertNotEqual(worker.blueprint.state, CLOSE)
  736. worker.terminate()
  737. self.assertNotEqual(worker.blueprint.state, CLOSE)
  738. sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
  739. try:
  740. worker.blueprint.state = RUN
  741. worker.stop(in_sighandler=True)
  742. self.assertNotEqual(worker.blueprint.state, CLOSE)
  743. worker.terminate(in_sighandler=True)
  744. self.assertNotEqual(worker.blueprint.state, CLOSE)
  745. finally:
  746. worker.pool.signal_safe = sigsafe
  747. def test_on_timer_error(self):
  748. worker = WorkController(concurrency=1, loglevel=0)
  749. try:
  750. raise KeyError('foo')
  751. except KeyError as exc:
  752. components.Timer(worker).on_timer_error(exc)
  753. msg, args = self.comp_logger.error.call_args[0]
  754. self.assertIn('KeyError', msg % args)
  755. def test_on_timer_tick(self):
  756. worker = WorkController(concurrency=1, loglevel=10)
  757. components.Timer(worker).on_timer_tick(30.0)
  758. xargs = self.comp_logger.debug.call_args[0]
  759. fmt, arg = xargs[0], xargs[1]
  760. self.assertEqual(30.0, arg)
  761. self.assertIn('Next eta %s secs', fmt)
  762. def test_process_task(self):
  763. worker = self.worker
  764. worker.pool = Mock()
  765. backend = Mock()
  766. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  767. kwargs={})
  768. task = Request.from_message(m, m.decode())
  769. worker._process_task(task)
  770. self.assertEqual(worker.pool.apply_async.call_count, 1)
  771. worker.pool.stop()
  772. def test_process_task_raise_base(self):
  773. worker = self.worker
  774. worker.pool = Mock()
  775. worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
  776. backend = Mock()
  777. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  778. kwargs={})
  779. task = Request.from_message(m, m.decode())
  780. worker.steps = []
  781. worker.blueprint.state = RUN
  782. with self.assertRaises(KeyboardInterrupt):
  783. worker._process_task(task)
  784. self.assertEqual(worker.blueprint.state, TERMINATE)
  785. def test_process_task_raise_SystemTerminate(self):
  786. worker = self.worker
  787. worker.pool = Mock()
  788. worker.pool.apply_async.side_effect = SystemTerminate()
  789. backend = Mock()
  790. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  791. kwargs={})
  792. task = Request.from_message(m, m.decode())
  793. worker.steps = []
  794. worker.blueprint.state = RUN
  795. with self.assertRaises(SystemExit):
  796. worker._process_task(task)
  797. self.assertEqual(worker.blueprint.state, TERMINATE)
  798. def test_process_task_raise_regular(self):
  799. worker = self.worker
  800. worker.pool = Mock()
  801. worker.pool.apply_async.side_effect = KeyError('some exception')
  802. backend = Mock()
  803. m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
  804. kwargs={})
  805. task = Request.from_message(m, m.decode())
  806. worker._process_task(task)
  807. worker.pool.stop()
  808. def test_start_catches_base_exceptions(self):
  809. worker1 = self.create_worker()
  810. stc = MockStep()
  811. stc.start.side_effect = SystemTerminate()
  812. worker1.steps = [stc]
  813. worker1.start()
  814. stc.start.assert_called_with(worker1)
  815. self.assertTrue(stc.terminate.call_count)
  816. worker2 = self.create_worker()
  817. sec = MockStep()
  818. sec.start.side_effect = SystemExit()
  819. sec.terminate = None
  820. worker2.steps = [sec]
  821. worker2.start()
  822. self.assertTrue(sec.stop.call_count)
  823. def test_state_db(self):
  824. from celery.worker import state
  825. Persistent = state.Persistent
  826. state.Persistent = Mock()
  827. try:
  828. worker = self.create_worker(state_db='statefilename')
  829. self.assertTrue(worker._persistence)
  830. finally:
  831. state.Persistent = Persistent
  832. def test_process_task_sem(self):
  833. worker = self.worker
  834. worker._quick_acquire = Mock()
  835. req = Mock()
  836. worker._process_task_sem(req)
  837. worker._quick_acquire.assert_called_with(worker._process_task, req)
  838. def test_signal_consumer_close(self):
  839. worker = self.worker
  840. worker.consumer = Mock()
  841. worker.signal_consumer_close()
  842. worker.consumer.close.assert_called_with()
  843. worker.consumer.close.side_effect = AttributeError()
  844. worker.signal_consumer_close()
  845. def test_start__stop(self):
  846. worker = self.worker
  847. worker.blueprint.shutdown_complete.set()
  848. worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
  849. worker.blueprint.state = RUN
  850. worker.blueprint.started = 4
  851. for w in worker.steps:
  852. w.start = Mock()
  853. w.close = Mock()
  854. w.stop = Mock()
  855. worker.start()
  856. for w in worker.steps:
  857. self.assertTrue(w.start.call_count)
  858. worker.consumer = Mock()
  859. worker.stop()
  860. for stopstep in worker.steps:
  861. self.assertTrue(stopstep.close.call_count)
  862. self.assertTrue(stopstep.stop.call_count)
  863. # Doesn't close pool if no pool.
  864. worker.start()
  865. worker.pool = None
  866. worker.stop()
  867. # test that stop of None is not attempted
  868. worker.steps[-1] = None
  869. worker.start()
  870. worker.stop()
  871. def test_step_raises(self):
  872. worker = self.worker
  873. step = Mock()
  874. worker.steps = [step]
  875. step.start.side_effect = TypeError()
  876. worker.stop = Mock()
  877. worker.start()
  878. worker.stop.assert_called_with()
  879. def test_state(self):
  880. self.assertTrue(self.worker.state)
  881. def test_start__terminate(self):
  882. worker = self.worker
  883. worker.blueprint.shutdown_complete.set()
  884. worker.blueprint.started = 5
  885. worker.blueprint.state = RUN
  886. worker.steps = [MockStep() for _ in range(5)]
  887. worker.start()
  888. for w in worker.steps[:3]:
  889. self.assertTrue(w.start.call_count)
  890. self.assertTrue(worker.blueprint.started, len(worker.steps))
  891. self.assertEqual(worker.blueprint.state, RUN)
  892. worker.terminate()
  893. for step in worker.steps:
  894. self.assertTrue(step.terminate.call_count)
  895. def test_Queues_pool_no_sem(self):
  896. w = Mock()
  897. w.pool_cls.uses_semaphore = False
  898. components.Queues(w).create(w)
  899. self.assertIs(w.process_task, w._process_task)
  900. def test_Hub_crate(self):
  901. w = Mock()
  902. x = components.Hub(w)
  903. hub = x.create(w)
  904. self.assertTrue(w.timer.max_interval)
  905. self.assertIs(w.hub, hub)
  906. def test_Pool_crate_threaded(self):
  907. w = Mock()
  908. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  909. w.pool_cls = Mock()
  910. w.use_eventloop = False
  911. pool = components.Pool(w)
  912. pool.create(w)
    def test_Pool_create(self):
        """Pool.create() with the event loop enabled hooks the pool to the hub.

        Covers the BoundedSemaphore set-up, the hub ``on_init`` callback,
        fd (de)registration on process up/down, the timeout callbacks and
        the WorkerLostError raised when the pool did not start OK.
        """
        from celery.worker.hub import BoundedSemaphore
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.hub = Mock()
        # Real list so callbacks appended by Pool.create() can be inspected.
        w.hub.on_init = []

        # Stand-in pool implementation (installed as MockTaskPool.Pool);
        # its instance gets two mock entries in ``_pool`` and empty dicts.
        PoolImp = Mock()
        poolimp = PoolImp.return_value = Mock()
        poolimp._pool = [Mock(), Mock()]
        poolimp._cache = {}
        poolimp._fileno_to_inq = {}
        poolimp._fileno_to_outq = {}

        from celery.concurrency.processes import TaskPool as _TaskPool

        # Real TaskPool subclass with ``Pool`` swapped for PoolImp and a
        # fixed ``timers`` mapping.
        class MockTaskPool(_TaskPool):
            Pool = PoolImp

            @property
            def timers(self):
                return {Mock(): 30}

        w.pool_cls = MockTaskPool
        w.use_eventloop = True
        w.consumer.restart_count = -1
        pool = components.Pool(w)
        pool.create(w)
        # A BoundedSemaphore is installed and an on_init callback registered.
        self.assertIsInstance(w.semaphore, BoundedSemaphore)
        self.assertTrue(w.hub.on_init)
        P = w.pool
        P.start()

        # Fire the registered on_init callback against a fresh mock hub.
        hub = Mock()
        w.hub.on_init[0](hub)

        # ``w`` is rebound here; presumably it now stands in for a pool
        # process (it is passed to on_process_up/down) -- not the worker.
        w = Mock()
        poolimp.on_process_up(w)
        hub.add.assert_has_calls([
            call(w.sentinel, P.maintain_pool, READ | ERR),
            call(w.outqR_fd, P.handle_result_event, READ | ERR),
        ])

        poolimp.on_process_down(w)
        hub.remove.assert_has_calls([
            call(w.sentinel), call(w.outqR_fd),
        ])

        result = Mock()
        tref = result._tref

        poolimp.on_timeout_cancel(result)
        tref.cancel.assert_called_with()
        poolimp.on_timeout_cancel(result)  # no more tref

        # Set timeouts, then invoke the callback captured from the last
        # hub.timer.apply_after call.
        poolimp.on_timeout_set(result, 10, 20)
        tsoft, callback = hub.timer.apply_after.call_args[0]
        callback()

        poolimp.on_timeout_set(result, 10, None)
        tsoft, callback = hub.timer.apply_after.call_args[0]
        callback()
        poolimp.on_timeout_set(result, None, 10)
        poolimp.on_timeout_set(result, None, None)

        # A pool whose did_start_ok() is False must raise WorkerLostError.
        with self.assertRaises(WorkerLostError):
            P._pool.did_start_ok = Mock()
            P._pool.did_start_ok.return_value = False
            w.consumer.restart_count = 0
            P.on_poll_init(w, hub)