test_consumer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. from __future__ import absolute_import
  2. import errno
  3. import socket
  4. from billiard.exceptions import RestartFreqExceeded
  5. from celery.datastructures import LimitedSet
  6. from celery.worker import state as worker_state
  7. from celery.worker.consumer import (
  8. Consumer,
  9. Heart,
  10. Tasks,
  11. Agent,
  12. Mingle,
  13. Gossip,
  14. dump_body,
  15. CLOSE,
  16. )
  17. from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
  18. class test_Consumer(AppCase):
  19. def get_consumer(self, no_hub=False, **kwargs):
  20. consumer = Consumer(
  21. on_task_request=Mock(),
  22. init_callback=Mock(),
  23. pool=Mock(),
  24. app=self.app,
  25. timer=Mock(),
  26. controller=Mock(),
  27. hub=None if no_hub else Mock(),
  28. **kwargs
  29. )
  30. consumer.blueprint = Mock(name='blueprint')
  31. consumer._restart_state = Mock(name='_restart_state')
  32. consumer.connection = _amqp_connection()
  33. consumer.connection_errors = (socket.error, OSError,)
  34. consumer.conninfo = consumer.connection
  35. return consumer
  36. def test_repr(self):
  37. self.assertTrue(repr(self.get_consumer()))
  38. def test_taskbuckets_defaultdict(self):
  39. c = self.get_consumer()
  40. self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])
  41. def test_dump_body_buffer(self):
  42. msg = Mock()
  43. msg.body = 'str'
  44. try:
  45. buf = buffer(msg.body)
  46. except NameError:
  47. raise SkipTest('buffer type not available')
  48. self.assertTrue(dump_body(msg, buf))
  49. def test_sets_heartbeat(self):
  50. c = self.get_consumer(amqheartbeat=10)
  51. self.assertEqual(c.amqheartbeat, 10)
  52. self.app.conf.broker_heartbeat = 20
  53. c = self.get_consumer(amqheartbeat=None)
  54. self.assertEqual(c.amqheartbeat, 20)
  55. def test_gevent_bug_disables_connection_timeout(self):
  56. with patch('celery.worker.consumer._detect_environment') as de:
  57. de.return_value = 'gevent'
  58. self.app.conf.broker_connection_timeout = 33.33
  59. self.get_consumer()
  60. self.assertIsNone(self.app.conf.broker_connection_timeout)
  61. def test_limit_moved_to_pool(self):
  62. with patch('celery.worker.consumer.task_reserved') as reserved:
  63. c = self.get_consumer()
  64. c.on_task_request = Mock(name='on_task_request')
  65. request = Mock(name='request')
  66. c._limit_move_to_pool(request)
  67. reserved.assert_called_with(request)
  68. c.on_task_request.assert_called_with(request)
  69. def test_update_prefetch_count(self):
  70. c = self.get_consumer()
  71. c._update_qos_eventually = Mock(name='update_qos')
  72. c.initial_prefetch_count = None
  73. c.pool.num_processes = None
  74. c.prefetch_multiplier = 10
  75. self.assertIsNone(c._update_prefetch_count(1))
  76. c.initial_prefetch_count = 10
  77. c.pool.num_processes = 10
  78. c._update_prefetch_count(8)
  79. c._update_qos_eventually.assert_called_with(8)
  80. self.assertEqual(c.initial_prefetch_count, 10 * 10)
  81. def test_flush_events(self):
  82. c = self.get_consumer()
  83. c.event_dispatcher = None
  84. c._flush_events()
  85. c.event_dispatcher = Mock(name='evd')
  86. c._flush_events()
  87. c.event_dispatcher.flush.assert_called_with()
  88. def test_on_send_event_buffered(self):
  89. c = self.get_consumer()
  90. c.hub = None
  91. c.on_send_event_buffered()
  92. c.hub = Mock(name='hub')
  93. c.on_send_event_buffered()
  94. c.hub._ready.add.assert_called_with(c._flush_events)
  95. def test_limit_task(self):
  96. c = self.get_consumer()
  97. with patch('celery.worker.consumer.task_reserved') as reserved:
  98. bucket = Mock()
  99. request = Mock()
  100. bucket.can_consume.return_value = True
  101. c._limit_task(request, bucket, 3)
  102. bucket.can_consume.assert_called_with(3)
  103. reserved.assert_called_with(request)
  104. c.on_task_request.assert_called_with(request)
  105. with patch('celery.worker.consumer.task_reserved') as reserved:
  106. bucket.can_consume.return_value = False
  107. bucket.expected_time.return_value = 3.33
  108. limit_order = c._limit_order
  109. c._limit_task(request, bucket, 4)
  110. self.assertEqual(c._limit_order, limit_order + 1)
  111. bucket.can_consume.assert_called_with(4)
  112. c.timer.call_after.assert_called_with(
  113. 3.33, c._limit_move_to_pool, (request,),
  114. priority=c._limit_order,
  115. )
  116. bucket.expected_time.assert_called_with(4)
  117. self.assertFalse(reserved.called)
  118. def test_start_blueprint_raises_EMFILE(self):
  119. c = self.get_consumer()
  120. exc = c.blueprint.start.side_effect = OSError()
  121. exc.errno = errno.EMFILE
  122. with self.assertRaises(OSError):
  123. c.start()
  124. def test_max_restarts_exceeded(self):
  125. c = self.get_consumer()
  126. def se(*args, **kwargs):
  127. c.blueprint.state = CLOSE
  128. raise RestartFreqExceeded()
  129. c._restart_state.step.side_effect = se
  130. c.blueprint.start.side_effect = socket.error()
  131. with patch('celery.worker.consumer.sleep') as sleep:
  132. c.start()
  133. sleep.assert_called_with(1)
  134. def test_no_retry_raises_error(self):
  135. self.app.conf.broker_connection_retry = False
  136. c = self.get_consumer()
  137. c.blueprint.start.side_effect = socket.error()
  138. with self.assertRaises(socket.error):
  139. c.start()
  140. def _closer(self, c):
  141. def se(*args, **kwargs):
  142. c.blueprint.state = CLOSE
  143. return se
  144. def test_collects_at_restart(self):
  145. c = self.get_consumer()
  146. c.connection.collect.side_effect = MemoryError()
  147. c.blueprint.start.side_effect = socket.error()
  148. c.blueprint.restart.side_effect = self._closer(c)
  149. c.start()
  150. c.connection.collect.assert_called_with()
  151. def test_register_with_event_loop(self):
  152. c = self.get_consumer()
  153. c.register_with_event_loop(Mock(name='loop'))
  154. def test_on_close_clears_semaphore_timer_and_reqs(self):
  155. with patch('celery.worker.consumer.reserved_requests') as reserved:
  156. c = self.get_consumer()
  157. c.on_close()
  158. c.controller.semaphore.clear.assert_called_with()
  159. c.timer.clear.assert_called_with()
  160. reserved.clear.assert_called_with()
  161. c.pool.flush.assert_called_with()
  162. c.controller = None
  163. c.timer = None
  164. c.pool = None
  165. c.on_close()
  166. def test_connect_error_handler(self):
  167. self.app._connection = _amqp_connection()
  168. conn = self.app._connection.return_value
  169. c = self.get_consumer()
  170. self.assertTrue(c.connect())
  171. self.assertTrue(conn.ensure_connection.called)
  172. errback = conn.ensure_connection.call_args[0][0]
  173. conn.alt = [(1, 2, 3)]
  174. errback(Mock(), 0)
  175. class test_Heart(AppCase):
  176. def test_start(self):
  177. c = Mock()
  178. c.timer = Mock()
  179. c.event_dispatcher = Mock()
  180. with patch('celery.worker.heartbeat.Heart') as hcls:
  181. h = Heart(c)
  182. self.assertTrue(h.enabled)
  183. self.assertEqual(h.heartbeat_interval, None)
  184. self.assertIsNone(c.heart)
  185. h.start(c)
  186. self.assertTrue(c.heart)
  187. hcls.assert_called_with(c.timer, c.event_dispatcher,
  188. h.heartbeat_interval)
  189. c.heart.start.assert_called_with()
  190. def test_start_heartbeat_interval(self):
  191. c = Mock()
  192. c.timer = Mock()
  193. c.event_dispatcher = Mock()
  194. with patch('celery.worker.heartbeat.Heart') as hcls:
  195. h = Heart(c, False, 20)
  196. self.assertTrue(h.enabled)
  197. self.assertEqual(h.heartbeat_interval, 20)
  198. self.assertIsNone(c.heart)
  199. h.start(c)
  200. self.assertTrue(c.heart)
  201. hcls.assert_called_with(c.timer, c.event_dispatcher,
  202. h.heartbeat_interval)
  203. c.heart.start.assert_called_with()
  204. class test_Tasks(AppCase):
  205. def test_stop(self):
  206. c = Mock()
  207. tasks = Tasks(c)
  208. self.assertIsNone(c.task_consumer)
  209. self.assertIsNone(c.qos)
  210. c.task_consumer = Mock()
  211. tasks.stop(c)
  212. def test_stop_already_stopped(self):
  213. c = Mock()
  214. tasks = Tasks(c)
  215. tasks.stop(c)
  216. class test_Agent(AppCase):
  217. def test_start(self):
  218. c = Mock()
  219. agent = Agent(c)
  220. agent.instantiate = Mock()
  221. agent.agent_cls = 'foo:Agent'
  222. self.assertIsNotNone(agent.create(c))
  223. agent.instantiate.assert_called_with(agent.agent_cls, c.connection)
  224. class test_Mingle(AppCase):
  225. def test_start_no_replies(self):
  226. c = Mock()
  227. c.app.connection_for_read = _amqp_connection()
  228. mingle = Mingle(c)
  229. I = c.app.control.inspect.return_value = Mock()
  230. I.hello.return_value = {}
  231. mingle.start(c)
  232. def test_start(self):
  233. try:
  234. c = Mock()
  235. c.app.connection_for_read = _amqp_connection()
  236. mingle = Mingle(c)
  237. self.assertTrue(mingle.enabled)
  238. Aig = LimitedSet()
  239. Big = LimitedSet()
  240. Aig.add('Aig-1')
  241. Aig.add('Aig-2')
  242. Big.add('Big-1')
  243. I = c.app.control.inspect.return_value = Mock()
  244. I.hello.return_value = {
  245. 'A@example.com': {
  246. 'clock': 312,
  247. 'revoked': Aig._data,
  248. },
  249. 'B@example.com': {
  250. 'clock': 29,
  251. 'revoked': Big._data,
  252. },
  253. 'C@example.com': {
  254. 'error': 'unknown method',
  255. },
  256. }
  257. mingle.start(c)
  258. I.hello.assert_called_with(c.hostname, worker_state.revoked._data)
  259. c.app.clock.adjust.assert_has_calls([
  260. call(312), call(29),
  261. ], any_order=True)
  262. self.assertIn('Aig-1', worker_state.revoked)
  263. self.assertIn('Aig-2', worker_state.revoked)
  264. self.assertIn('Big-1', worker_state.revoked)
  265. finally:
  266. worker_state.revoked.clear()
  267. def _amqp_connection():
  268. connection = ContextMock(name='Connection')
  269. connection.return_value = ContextMock(name='connection')
  270. connection.return_value.transport.driver_type = 'amqp'
  271. return connection
  272. class test_Gossip(AppCase):
  273. def test_init(self):
  274. c = self.Consumer()
  275. c.app.connection_for_read = _amqp_connection()
  276. g = Gossip(c)
  277. self.assertTrue(g.enabled)
  278. self.assertIs(c.gossip, g)
  279. def test_callbacks(self):
  280. c = self.Consumer()
  281. c.app.connection_for_read = _amqp_connection()
  282. g = Gossip(c)
  283. on_node_join = Mock(name='on_node_join')
  284. on_node_join2 = Mock(name='on_node_join2')
  285. on_node_leave = Mock(name='on_node_leave')
  286. on_node_lost = Mock(name='on.node_lost')
  287. g.on.node_join.add(on_node_join)
  288. g.on.node_join.add(on_node_join2)
  289. g.on.node_leave.add(on_node_leave)
  290. g.on.node_lost.add(on_node_lost)
  291. worker = Mock(name='worker')
  292. g.on_node_join(worker)
  293. on_node_join.assert_called_with(worker)
  294. on_node_join2.assert_called_with(worker)
  295. g.on_node_leave(worker)
  296. on_node_leave.assert_called_with(worker)
  297. g.on_node_lost(worker)
  298. on_node_lost.assert_called_with(worker)
  299. def test_election(self):
  300. c = self.Consumer()
  301. c.app.connection_for_read = _amqp_connection()
  302. g = Gossip(c)
  303. g.start(c)
  304. g.election('id', 'topic', 'action')
  305. self.assertListEqual(g.consensus_replies['id'], [])
  306. g.dispatcher.send.assert_called_with(
  307. 'worker-elect', id='id', topic='topic', cver=1, action='action',
  308. )
  309. def test_call_task(self):
  310. c = self.Consumer()
  311. c.app.connection_for_read = _amqp_connection()
  312. g = Gossip(c)
  313. g.start(c)
  314. with patch('celery.worker.consumer.signature') as signature:
  315. sig = signature.return_value = Mock()
  316. task = Mock()
  317. g.call_task(task)
  318. signature.assert_called_with(task, app=c.app)
  319. sig.apply_async.assert_called_with()
  320. sig.apply_async.side_effect = MemoryError()
  321. with patch('celery.worker.consumer.error') as error:
  322. g.call_task(task)
  323. self.assertTrue(error.called)
  324. def Event(self, id='id', clock=312,
  325. hostname='foo@example.com', pid=4312,
  326. topic='topic', action='action', cver=1):
  327. return {
  328. 'id': id,
  329. 'clock': clock,
  330. 'hostname': hostname,
  331. 'pid': pid,
  332. 'topic': topic,
  333. 'action': action,
  334. 'cver': cver,
  335. }
  336. def test_on_elect(self):
  337. c = self.Consumer()
  338. c.app.connection_for_read = _amqp_connection()
  339. g = Gossip(c)
  340. g.start(c)
  341. event = self.Event('id1')
  342. g.on_elect(event)
  343. in_heap = g.consensus_requests['id1']
  344. self.assertTrue(in_heap)
  345. g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
  346. event.pop('clock')
  347. with patch('celery.worker.consumer.error') as error:
  348. g.on_elect(event)
  349. self.assertTrue(error.called)
  350. def Consumer(self, hostname='foo@x.com', pid=4312):
  351. c = Mock()
  352. c.app.connection = _amqp_connection()
  353. c.hostname = hostname
  354. c.pid = pid
  355. return c
  356. def setup_election(self, g, c):
  357. g.start(c)
  358. g.clock = self.app.clock
  359. self.assertNotIn('idx', g.consensus_replies)
  360. self.assertIsNone(g.on_elect_ack({'id': 'idx'}))
  361. g.state.alive_workers.return_value = [
  362. 'foo@x.com', 'bar@x.com', 'baz@x.com',
  363. ]
  364. g.consensus_replies['id1'] = []
  365. g.consensus_requests['id1'] = []
  366. e1 = self.Event('id1', 1, 'foo@x.com')
  367. e2 = self.Event('id1', 2, 'bar@x.com')
  368. e3 = self.Event('id1', 3, 'baz@x.com')
  369. g.on_elect(e1)
  370. g.on_elect(e2)
  371. g.on_elect(e3)
  372. self.assertEqual(len(g.consensus_requests['id1']), 3)
  373. with patch('celery.worker.consumer.info'):
  374. g.on_elect_ack(e1)
  375. self.assertEqual(len(g.consensus_replies['id1']), 1)
  376. g.on_elect_ack(e2)
  377. self.assertEqual(len(g.consensus_replies['id1']), 2)
  378. g.on_elect_ack(e3)
  379. with self.assertRaises(KeyError):
  380. g.consensus_replies['id1']
  381. def test_on_elect_ack_win(self):
  382. c = self.Consumer(hostname='foo@x.com') # I will win
  383. c.app.connection_for_read = _amqp_connection()
  384. g = Gossip(c)
  385. handler = g.election_handlers['topic'] = Mock()
  386. self.setup_election(g, c)
  387. handler.assert_called_with('action')
  388. def test_on_elect_ack_lose(self):
  389. c = self.Consumer(hostname='bar@x.com') # I will lose
  390. c.app.connection_for_read = _amqp_connection()
  391. g = Gossip(c)
  392. handler = g.election_handlers['topic'] = Mock()
  393. self.setup_election(g, c)
  394. self.assertFalse(handler.called)
  395. def test_on_elect_ack_win_but_no_action(self):
  396. c = self.Consumer(hostname='foo@x.com') # I will win
  397. c.app.connection_for_read = _amqp_connection()
  398. g = Gossip(c)
  399. g.election_handlers = {}
  400. with patch('celery.worker.consumer.error') as error:
  401. self.setup_election(g, c)
  402. self.assertTrue(error.called)
  403. def test_on_node_join(self):
  404. c = self.Consumer()
  405. c.app.connection_for_read = _amqp_connection()
  406. g = Gossip(c)
  407. with patch('celery.worker.consumer.debug') as debug:
  408. g.on_node_join(c)
  409. debug.assert_called_with('%s joined the party', 'foo@x.com')
  410. def test_on_node_leave(self):
  411. c = self.Consumer()
  412. c.app.connection_for_read = _amqp_connection()
  413. g = Gossip(c)
  414. with patch('celery.worker.consumer.debug') as debug:
  415. g.on_node_leave(c)
  416. debug.assert_called_with('%s left', 'foo@x.com')
  417. def test_on_node_lost(self):
  418. c = self.Consumer()
  419. c.app.connection_for_read = _amqp_connection()
  420. g = Gossip(c)
  421. with patch('celery.worker.consumer.info') as info:
  422. g.on_node_lost(c)
  423. info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
  424. def test_register_timer(self):
  425. c = self.Consumer()
  426. c.app.connection_for_read = _amqp_connection()
  427. g = Gossip(c)
  428. g.register_timer()
  429. c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
  430. tref = g._tref
  431. g.register_timer()
  432. tref.cancel.assert_called_with()
  433. def test_periodic(self):
  434. c = self.Consumer()
  435. c.app.connection_for_read = _amqp_connection()
  436. g = Gossip(c)
  437. g.on_node_lost = Mock()
  438. state = g.state = Mock()
  439. worker = Mock()
  440. state.workers = {'foo': worker}
  441. worker.alive = True
  442. worker.hostname = 'foo'
  443. g.periodic()
  444. worker.alive = False
  445. g.periodic()
  446. g.on_node_lost.assert_called_with(worker)
  447. with self.assertRaises(KeyError):
  448. state.workers['foo']
  449. def test_on_message__task(self):
  450. c = self.Consumer()
  451. c.app.connection_for_read = _amqp_connection()
  452. g = Gossip(c)
  453. self.assertTrue(g.enabled)
  454. message = Mock(name='message')
  455. message.delivery_info = {'routing_key': 'task.failed'}
  456. g.on_message(Mock(name='prepare'), message)
  457. def test_on_message(self):
  458. c = self.Consumer()
  459. c.app.connection_for_read = _amqp_connection()
  460. g = Gossip(c)
  461. self.assertTrue(g.enabled)
  462. prepare = Mock()
  463. prepare.return_value = 'worker-online', {}
  464. c.app.events.State.assert_called_with(
  465. on_node_join=g.on_node_join,
  466. on_node_leave=g.on_node_leave,
  467. max_tasks_in_memory=1,
  468. )
  469. g.update_state = Mock()
  470. worker = Mock()
  471. g.on_node_join = Mock()
  472. g.on_node_leave = Mock()
  473. g.update_state.return_value = worker, 1
  474. message = Mock()
  475. message.delivery_info = {'routing_key': 'worker-online'}
  476. message.headers = {'hostname': 'other'}
  477. handler = g.event_handlers['worker-online'] = Mock()
  478. g.on_message(prepare, message)
  479. handler.assert_called_with(message.payload)
  480. g.event_handlers = {}
  481. g.on_message(prepare, message)
  482. message.delivery_info = {'routing_key': 'worker-offline'}
  483. prepare.return_value = 'worker-offline', {}
  484. g.on_message(prepare, message)
  485. message.delivery_info = {'routing_key': 'worker-baz'}
  486. prepare.return_value = 'worker-baz', {}
  487. g.update_state.return_value = worker, 0
  488. g.on_message(prepare, message)
  489. message.headers = {'hostname': g.hostname}
  490. g.on_message(prepare, message)
  491. g.clock.forward.assert_called_with()