# test_consumer.py 20 KB


from __future__ import absolute_import, unicode_literals

import errno
import socket
from collections import deque

import pytest
from billiard.exceptions import RestartFreqExceeded
from case import ContextMock, Mock, call, patch, skip

from celery.utils.collections import LimitedSet
from celery.worker.consumer.agent import Agent
from celery.worker.consumer.consumer import (CLOSE, TERMINATE, Consumer,
                                             dump_body)
from celery.worker.consumer.gossip import Gossip
from celery.worker.consumer.heart import Heart
from celery.worker.consumer.mingle import Mingle
from celery.worker.consumer.tasks import Tasks
  16. class test_Consumer:
  17. def get_consumer(self, no_hub=False, **kwargs):
  18. consumer = Consumer(
  19. on_task_request=Mock(),
  20. init_callback=Mock(),
  21. pool=Mock(),
  22. app=self.app,
  23. timer=Mock(),
  24. controller=Mock(),
  25. hub=None if no_hub else Mock(),
  26. **kwargs
  27. )
  28. consumer.blueprint = Mock(name='blueprint')
  29. consumer._restart_state = Mock(name='_restart_state')
  30. consumer.connection = _amqp_connection()
  31. consumer.connection_errors = (socket.error, OSError,)
  32. consumer.conninfo = consumer.connection
  33. return consumer
  34. def test_repr(self):
  35. assert repr(self.get_consumer())
  36. def test_taskbuckets_defaultdict(self):
  37. c = self.get_consumer()
  38. assert c.task_buckets['fooxasdwx.wewe'] is None
  39. @skip.if_python3(reason='buffer type not available')
  40. def test_dump_body_buffer(self):
  41. msg = Mock()
  42. msg.body = 'str'
  43. assert dump_body(msg, buffer(msg.body)) # noqa: F821
  44. def test_sets_heartbeat(self):
  45. c = self.get_consumer(amqheartbeat=10)
  46. assert c.amqheartbeat == 10
  47. self.app.conf.broker_heartbeat = 20
  48. c = self.get_consumer(amqheartbeat=None)
  49. assert c.amqheartbeat == 20
  50. def test_gevent_bug_disables_connection_timeout(self):
  51. with patch('celery.worker.consumer.consumer._detect_environment') as d:
  52. d.return_value = 'gevent'
  53. self.app.conf.broker_connection_timeout = 33.33
  54. self.get_consumer()
  55. assert self.app.conf.broker_connection_timeout is None
  56. def test_limit_moved_to_pool(self):
  57. with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
  58. c = self.get_consumer()
  59. c.on_task_request = Mock(name='on_task_request')
  60. request = Mock(name='request')
  61. c._limit_move_to_pool(request)
  62. reserv.assert_called_with(request)
  63. c.on_task_request.assert_called_with(request)
  64. def test_update_prefetch_count(self):
  65. c = self.get_consumer()
  66. c._update_qos_eventually = Mock(name='update_qos')
  67. c.initial_prefetch_count = None
  68. c.pool.num_processes = None
  69. c.prefetch_multiplier = 10
  70. assert c._update_prefetch_count(1) is None
  71. c.initial_prefetch_count = 10
  72. c.pool.num_processes = 10
  73. c._update_prefetch_count(8)
  74. c._update_qos_eventually.assert_called_with(8)
  75. assert c.initial_prefetch_count == 10 * 10
  76. def test_flush_events(self):
  77. c = self.get_consumer()
  78. c.event_dispatcher = None
  79. c._flush_events()
  80. c.event_dispatcher = Mock(name='evd')
  81. c._flush_events()
  82. c.event_dispatcher.flush.assert_called_with()
  83. def test_on_send_event_buffered(self):
  84. c = self.get_consumer()
  85. c.hub = None
  86. c.on_send_event_buffered()
  87. c.hub = Mock(name='hub')
  88. c.on_send_event_buffered()
  89. c.hub._ready.add.assert_called_with(c._flush_events)
  90. def test_schedule_bucket_request(self):
  91. c = self.get_consumer()
  92. c.timer = Mock()
  93. bucket = Mock()
  94. request = Mock()
  95. bucket.pop = lambda: bucket.contents.popleft()
  96. bucket.can_consume.return_value = True
  97. bucket.contents = deque()
  98. with patch(
  99. 'celery.worker.consumer.consumer.Consumer._limit_move_to_pool'
  100. ) as reserv:
  101. bucket.contents.append((request, 3))
  102. c._schedule_bucket_request(bucket)
  103. bucket.can_consume.assert_called_with(3)
  104. reserv.assert_called_with(request)
  105. bucket.can_consume.return_value = False
  106. bucket.contents = deque()
  107. bucket.expected_time.return_value = 3.33
  108. bucket.contents.append((request, 4))
  109. limit_order = c._limit_order
  110. c._schedule_bucket_request(bucket)
  111. assert c._limit_order == limit_order + 1
  112. bucket.can_consume.assert_called_with(4)
  113. c.timer.call_after.assert_called_with(
  114. 3.33, c._schedule_bucket_request, (bucket,),
  115. priority=c._limit_order,
  116. )
  117. bucket.expected_time.assert_called_with(4)
  118. assert bucket.pop() == (request, 4)
  119. bucket.contents = deque()
  120. bucket.can_consume.reset_mock()
  121. c._schedule_bucket_request(bucket)
  122. bucket.can_consume.assert_not_called()
  123. def test_limit_task(self):
  124. c = self.get_consumer()
  125. bucket = Mock()
  126. request = Mock()
  127. with patch(
  128. 'celery.worker.consumer.consumer.Consumer._schedule_bucket_request'
  129. ) as reserv:
  130. c._limit_task(request, bucket, 1)
  131. bucket.add.assert_called_with((request, 1))
  132. reserv.assert_called_with(bucket)
  133. def test_post_eta(self):
  134. c = self.get_consumer()
  135. c.qos = Mock()
  136. bucket = Mock()
  137. request = Mock()
  138. with patch(
  139. 'celery.worker.consumer.consumer.Consumer._schedule_bucket_request'
  140. ) as reserv:
  141. c._limit_post_eta(request, bucket, 1)
  142. c.qos.decrement_eventually.assert_called_with()
  143. bucket.add.assert_called_with((request, 1))
  144. reserv.assert_called_with(bucket)
  145. def test_start_blueprint_raises_EMFILE(self):
  146. c = self.get_consumer()
  147. exc = c.blueprint.start.side_effect = OSError()
  148. exc.errno = errno.EMFILE
  149. with pytest.raises(OSError):
  150. c.start()
  151. def test_max_restarts_exceeded(self):
  152. c = self.get_consumer()
  153. def se(*args, **kwargs):
  154. c.blueprint.state = CLOSE
  155. raise RestartFreqExceeded()
  156. c._restart_state.step.side_effect = se
  157. c.blueprint.start.side_effect = socket.error()
  158. with patch('celery.worker.consumer.consumer.sleep') as sleep:
  159. c.start()
  160. sleep.assert_called_with(1)
  161. def test_do_not_restart_when_closed(self):
  162. c = self.get_consumer()
  163. c.blueprint.state = None
  164. def bp_start(*args, **kwargs):
  165. c.blueprint.state = CLOSE
  166. c.blueprint.start.side_effect = bp_start
  167. with patch('celery.worker.consumer.consumer.sleep'):
  168. c.start()
  169. c.blueprint.start.assert_called_once_with(c)
  170. def test_do_not_restart_when_terminated(self):
  171. c = self.get_consumer()
  172. c.blueprint.state = None
  173. def bp_start(*args, **kwargs):
  174. c.blueprint.state = TERMINATE
  175. c.blueprint.start.side_effect = bp_start
  176. with patch('celery.worker.consumer.consumer.sleep'):
  177. c.start()
  178. c.blueprint.start.assert_called_once_with(c)
  179. def test_no_retry_raises_error(self):
  180. self.app.conf.broker_connection_retry = False
  181. c = self.get_consumer()
  182. c.blueprint.start.side_effect = socket.error()
  183. with pytest.raises(socket.error):
  184. c.start()
  185. def _closer(self, c):
  186. def se(*args, **kwargs):
  187. c.blueprint.state = CLOSE
  188. return se
  189. def test_collects_at_restart(self):
  190. c = self.get_consumer()
  191. c.connection.collect.side_effect = MemoryError()
  192. c.blueprint.start.side_effect = socket.error()
  193. c.blueprint.restart.side_effect = self._closer(c)
  194. c.start()
  195. c.connection.collect.assert_called_with()
  196. def test_register_with_event_loop(self):
  197. c = self.get_consumer()
  198. c.register_with_event_loop(Mock(name='loop'))
  199. def test_on_close_clears_semaphore_timer_and_reqs(self):
  200. with patch('celery.worker.consumer.consumer.reserved_requests') as res:
  201. c = self.get_consumer()
  202. c.on_close()
  203. c.controller.semaphore.clear.assert_called_with()
  204. c.timer.clear.assert_called_with()
  205. res.clear.assert_called_with()
  206. c.pool.flush.assert_called_with()
  207. c.controller = None
  208. c.timer = None
  209. c.pool = None
  210. c.on_close()
  211. def test_connect_error_handler(self):
  212. self.app._connection = _amqp_connection()
  213. conn = self.app._connection.return_value
  214. c = self.get_consumer()
  215. assert c.connect()
  216. conn.ensure_connection.assert_called()
  217. errback = conn.ensure_connection.call_args[0][0]
  218. errback(Mock(), 0)
  219. class test_Heart:
  220. def test_start(self):
  221. c = Mock()
  222. c.timer = Mock()
  223. c.event_dispatcher = Mock()
  224. with patch('celery.worker.heartbeat.Heart') as hcls:
  225. h = Heart(c)
  226. assert h.enabled
  227. assert h.heartbeat_interval is None
  228. assert c.heart is None
  229. h.start(c)
  230. assert c.heart
  231. hcls.assert_called_with(c.timer, c.event_dispatcher,
  232. h.heartbeat_interval)
  233. c.heart.start.assert_called_with()
  234. def test_start_heartbeat_interval(self):
  235. c = Mock()
  236. c.timer = Mock()
  237. c.event_dispatcher = Mock()
  238. with patch('celery.worker.heartbeat.Heart') as hcls:
  239. h = Heart(c, False, 20)
  240. assert h.enabled
  241. assert h.heartbeat_interval == 20
  242. assert c.heart is None
  243. h.start(c)
  244. assert c.heart
  245. hcls.assert_called_with(c.timer, c.event_dispatcher,
  246. h.heartbeat_interval)
  247. c.heart.start.assert_called_with()
  248. class test_Tasks:
  249. def test_stop(self):
  250. c = Mock()
  251. tasks = Tasks(c)
  252. assert c.task_consumer is None
  253. assert c.qos is None
  254. c.task_consumer = Mock()
  255. tasks.stop(c)
  256. def test_stop_already_stopped(self):
  257. c = Mock()
  258. tasks = Tasks(c)
  259. tasks.stop(c)
  260. class test_Agent:
  261. def test_start(self):
  262. c = Mock()
  263. agent = Agent(c)
  264. agent.instantiate = Mock()
  265. agent.agent_cls = 'foo:Agent'
  266. assert agent.create(c) is not None
  267. agent.instantiate.assert_called_with(agent.agent_cls, c.connection)
  268. class test_Mingle:
  269. def test_start_no_replies(self):
  270. c = Mock()
  271. c.app.connection_for_read = _amqp_connection()
  272. mingle = Mingle(c)
  273. I = c.app.control.inspect.return_value = Mock()
  274. I.hello.return_value = {}
  275. mingle.start(c)
  276. def test_start(self):
  277. c = Mock()
  278. c.app.connection_for_read = _amqp_connection()
  279. mingle = Mingle(c)
  280. assert mingle.enabled
  281. Aig = LimitedSet()
  282. Big = LimitedSet()
  283. Aig.add('Aig-1')
  284. Aig.add('Aig-2')
  285. Big.add('Big-1')
  286. I = c.app.control.inspect.return_value = Mock()
  287. I.hello.return_value = {
  288. 'A@example.com': {
  289. 'clock': 312,
  290. 'revoked': Aig._data,
  291. },
  292. 'B@example.com': {
  293. 'clock': 29,
  294. 'revoked': Big._data,
  295. },
  296. 'C@example.com': {
  297. 'error': 'unknown method',
  298. },
  299. }
  300. our_revoked = c.controller.state.revoked = LimitedSet()
  301. mingle.start(c)
  302. I.hello.assert_called_with(c.hostname, our_revoked._data)
  303. c.app.clock.adjust.assert_has_calls([
  304. call(312), call(29),
  305. ], any_order=True)
  306. assert 'Aig-1' in our_revoked
  307. assert 'Aig-2' in our_revoked
  308. assert 'Big-1' in our_revoked
  309. def _amqp_connection():
  310. connection = ContextMock(name='Connection')
  311. connection.return_value = ContextMock(name='connection')
  312. connection.return_value.transport.driver_type = 'amqp'
  313. return connection
  314. class test_Gossip:
  315. def test_init(self):
  316. c = self.Consumer()
  317. c.app.connection_for_read = _amqp_connection()
  318. g = Gossip(c)
  319. assert g.enabled
  320. assert c.gossip is g
  321. def test_callbacks(self):
  322. c = self.Consumer()
  323. c.app.connection_for_read = _amqp_connection()
  324. g = Gossip(c)
  325. on_node_join = Mock(name='on_node_join')
  326. on_node_join2 = Mock(name='on_node_join2')
  327. on_node_leave = Mock(name='on_node_leave')
  328. on_node_lost = Mock(name='on.node_lost')
  329. g.on.node_join.add(on_node_join)
  330. g.on.node_join.add(on_node_join2)
  331. g.on.node_leave.add(on_node_leave)
  332. g.on.node_lost.add(on_node_lost)
  333. worker = Mock(name='worker')
  334. g.on_node_join(worker)
  335. on_node_join.assert_called_with(worker)
  336. on_node_join2.assert_called_with(worker)
  337. g.on_node_leave(worker)
  338. on_node_leave.assert_called_with(worker)
  339. g.on_node_lost(worker)
  340. on_node_lost.assert_called_with(worker)
  341. def test_election(self):
  342. c = self.Consumer()
  343. c.app.connection_for_read = _amqp_connection()
  344. g = Gossip(c)
  345. g.start(c)
  346. g.election('id', 'topic', 'action')
  347. assert g.consensus_replies['id'] == []
  348. g.dispatcher.send.assert_called_with(
  349. 'worker-elect', id='id', topic='topic', cver=1, action='action',
  350. )
  351. def test_call_task(self):
  352. c = self.Consumer()
  353. c.app.connection_for_read = _amqp_connection()
  354. g = Gossip(c)
  355. g.start(c)
  356. signature = g.app.signature = Mock(name='app.signature')
  357. task = Mock()
  358. g.call_task(task)
  359. signature.assert_called_with(task)
  360. signature.return_value.apply_async.assert_called_with()
  361. signature.return_value.apply_async.side_effect = MemoryError()
  362. with patch('celery.worker.consumer.gossip.logger') as logger:
  363. g.call_task(task)
  364. logger.exception.assert_called()
  365. def Event(self, id='id', clock=312,
  366. hostname='foo@example.com', pid=4312,
  367. topic='topic', action='action', cver=1):
  368. return {
  369. 'id': id,
  370. 'clock': clock,
  371. 'hostname': hostname,
  372. 'pid': pid,
  373. 'topic': topic,
  374. 'action': action,
  375. 'cver': cver,
  376. }
  377. def test_on_elect(self):
  378. c = self.Consumer()
  379. c.app.connection_for_read = _amqp_connection()
  380. g = Gossip(c)
  381. g.start(c)
  382. event = self.Event('id1')
  383. g.on_elect(event)
  384. in_heap = g.consensus_requests['id1']
  385. assert in_heap
  386. g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
  387. event.pop('clock')
  388. with patch('celery.worker.consumer.gossip.logger') as logger:
  389. g.on_elect(event)
  390. logger.exception.assert_called()
  391. def Consumer(self, hostname='foo@x.com', pid=4312):
  392. c = Mock()
  393. c.app.connection = _amqp_connection()
  394. c.hostname = hostname
  395. c.pid = pid
  396. return c
  397. def setup_election(self, g, c):
  398. g.start(c)
  399. g.clock = self.app.clock
  400. assert 'idx' not in g.consensus_replies
  401. assert g.on_elect_ack({'id': 'idx'}) is None
  402. g.state.alive_workers.return_value = [
  403. 'foo@x.com', 'bar@x.com', 'baz@x.com',
  404. ]
  405. g.consensus_replies['id1'] = []
  406. g.consensus_requests['id1'] = []
  407. e1 = self.Event('id1', 1, 'foo@x.com')
  408. e2 = self.Event('id1', 2, 'bar@x.com')
  409. e3 = self.Event('id1', 3, 'baz@x.com')
  410. g.on_elect(e1)
  411. g.on_elect(e2)
  412. g.on_elect(e3)
  413. assert len(g.consensus_requests['id1']) == 3
  414. with patch('celery.worker.consumer.gossip.info'):
  415. g.on_elect_ack(e1)
  416. assert len(g.consensus_replies['id1']) == 1
  417. g.on_elect_ack(e2)
  418. assert len(g.consensus_replies['id1']) == 2
  419. g.on_elect_ack(e3)
  420. with pytest.raises(KeyError):
  421. g.consensus_replies['id1']
  422. def test_on_elect_ack_win(self):
  423. c = self.Consumer(hostname='foo@x.com') # I will win
  424. c.app.connection_for_read = _amqp_connection()
  425. g = Gossip(c)
  426. handler = g.election_handlers['topic'] = Mock()
  427. self.setup_election(g, c)
  428. handler.assert_called_with('action')
  429. def test_on_elect_ack_lose(self):
  430. c = self.Consumer(hostname='bar@x.com') # I will lose
  431. c.app.connection_for_read = _amqp_connection()
  432. g = Gossip(c)
  433. handler = g.election_handlers['topic'] = Mock()
  434. self.setup_election(g, c)
  435. handler.assert_not_called()
  436. def test_on_elect_ack_win_but_no_action(self):
  437. c = self.Consumer(hostname='foo@x.com') # I will win
  438. c.app.connection_for_read = _amqp_connection()
  439. g = Gossip(c)
  440. g.election_handlers = {}
  441. with patch('celery.worker.consumer.gossip.logger') as logger:
  442. self.setup_election(g, c)
  443. logger.exception.assert_called()
  444. def test_on_node_join(self):
  445. c = self.Consumer()
  446. c.app.connection_for_read = _amqp_connection()
  447. g = Gossip(c)
  448. with patch('celery.worker.consumer.gossip.debug') as debug:
  449. g.on_node_join(c)
  450. debug.assert_called_with('%s joined the party', 'foo@x.com')
  451. def test_on_node_leave(self):
  452. c = self.Consumer()
  453. c.app.connection_for_read = _amqp_connection()
  454. g = Gossip(c)
  455. with patch('celery.worker.consumer.gossip.debug') as debug:
  456. g.on_node_leave(c)
  457. debug.assert_called_with('%s left', 'foo@x.com')
  458. def test_on_node_lost(self):
  459. c = self.Consumer()
  460. c.app.connection_for_read = _amqp_connection()
  461. g = Gossip(c)
  462. with patch('celery.worker.consumer.gossip.info') as info:
  463. g.on_node_lost(c)
  464. info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
  465. def test_register_timer(self):
  466. c = self.Consumer()
  467. c.app.connection_for_read = _amqp_connection()
  468. g = Gossip(c)
  469. g.register_timer()
  470. c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
  471. tref = g._tref
  472. g.register_timer()
  473. tref.cancel.assert_called_with()
  474. def test_periodic(self):
  475. c = self.Consumer()
  476. c.app.connection_for_read = _amqp_connection()
  477. g = Gossip(c)
  478. g.on_node_lost = Mock()
  479. state = g.state = Mock()
  480. worker = Mock()
  481. state.workers = {'foo': worker}
  482. worker.alive = True
  483. worker.hostname = 'foo'
  484. g.periodic()
  485. worker.alive = False
  486. g.periodic()
  487. g.on_node_lost.assert_called_with(worker)
  488. with pytest.raises(KeyError):
  489. state.workers['foo']
  490. def test_on_message__task(self):
  491. c = self.Consumer()
  492. c.app.connection_for_read = _amqp_connection()
  493. g = Gossip(c)
  494. assert g.enabled
  495. message = Mock(name='message')
  496. message.delivery_info = {'routing_key': 'task.failed'}
  497. g.on_message(Mock(name='prepare'), message)
  498. def test_on_message(self):
  499. c = self.Consumer()
  500. c.app.connection_for_read = _amqp_connection()
  501. g = Gossip(c)
  502. assert g.enabled
  503. prepare = Mock()
  504. prepare.return_value = 'worker-online', {}
  505. c.app.events.State.assert_called_with(
  506. on_node_join=g.on_node_join,
  507. on_node_leave=g.on_node_leave,
  508. max_tasks_in_memory=1,
  509. )
  510. g.update_state = Mock()
  511. worker = Mock()
  512. g.on_node_join = Mock()
  513. g.on_node_leave = Mock()
  514. g.update_state.return_value = worker, 1
  515. message = Mock()
  516. message.delivery_info = {'routing_key': 'worker-online'}
  517. message.headers = {'hostname': 'other'}
  518. handler = g.event_handlers['worker-online'] = Mock()
  519. g.on_message(prepare, message)
  520. handler.assert_called_with(message.payload)
  521. g.event_handlers = {}
  522. g.on_message(prepare, message)
  523. message.delivery_info = {'routing_key': 'worker-offline'}
  524. prepare.return_value = 'worker-offline', {}
  525. g.on_message(prepare, message)
  526. message.delivery_info = {'routing_key': 'worker-baz'}
  527. prepare.return_value = 'worker-baz', {}
  528. g.update_state.return_value = worker, 0
  529. g.on_message(prepare, message)
  530. message.headers = {'hostname': g.hostname}
  531. g.on_message(prepare, message)
  532. g.clock.forward.assert_called_with()