# test_consumer.py (18 KB)


from __future__ import absolute_import, unicode_literals

import errno
import socket
from collections import deque

import pytest
from billiard.exceptions import RestartFreqExceeded
from case import ContextMock, Mock, call, patch, skip

from celery.utils.collections import LimitedSet
from celery.worker.consumer.agent import Agent
from celery.worker.consumer.consumer import CLOSE, Consumer, dump_body
from celery.worker.consumer.gossip import Gossip
from celery.worker.consumer.heart import Heart
from celery.worker.consumer.mingle import Mingle
from celery.worker.consumer.tasks import Tasks
  15. class test_Consumer:
  16. def get_consumer(self, no_hub=False, **kwargs):
  17. consumer = Consumer(
  18. on_task_request=Mock(),
  19. init_callback=Mock(),
  20. pool=Mock(),
  21. app=self.app,
  22. timer=Mock(),
  23. controller=Mock(),
  24. hub=None if no_hub else Mock(),
  25. **kwargs
  26. )
  27. consumer.blueprint = Mock(name='blueprint')
  28. consumer._restart_state = Mock(name='_restart_state')
  29. consumer.connection = _amqp_connection()
  30. consumer.connection_errors = (socket.error, OSError,)
  31. consumer.conninfo = consumer.connection
  32. return consumer
  33. def test_repr(self):
  34. assert repr(self.get_consumer())
  35. def test_taskbuckets_defaultdict(self):
  36. c = self.get_consumer()
  37. assert c.task_buckets['fooxasdwx.wewe'] is None
  38. @skip.if_python3(reason='buffer type not available')
  39. def test_dump_body_buffer(self):
  40. msg = Mock()
  41. msg.body = 'str'
  42. assert dump_body(msg, buffer(msg.body))
  43. def test_sets_heartbeat(self):
  44. c = self.get_consumer(amqheartbeat=10)
  45. assert c.amqheartbeat == 10
  46. self.app.conf.broker_heartbeat = 20
  47. c = self.get_consumer(amqheartbeat=None)
  48. assert c.amqheartbeat == 20
  49. def test_gevent_bug_disables_connection_timeout(self):
  50. with patch('celery.worker.consumer.consumer._detect_environment') as d:
  51. d.return_value = 'gevent'
  52. self.app.conf.broker_connection_timeout = 33.33
  53. self.get_consumer()
  54. assert self.app.conf.broker_connection_timeout is None
  55. def test_limit_moved_to_pool(self):
  56. with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
  57. c = self.get_consumer()
  58. c.on_task_request = Mock(name='on_task_request')
  59. request = Mock(name='request')
  60. c._limit_move_to_pool(request)
  61. reserv.assert_called_with(request)
  62. c.on_task_request.assert_called_with(request)
  63. def test_update_prefetch_count(self):
  64. c = self.get_consumer()
  65. c._update_qos_eventually = Mock(name='update_qos')
  66. c.initial_prefetch_count = None
  67. c.pool.num_processes = None
  68. c.prefetch_multiplier = 10
  69. assert c._update_prefetch_count(1) is None
  70. c.initial_prefetch_count = 10
  71. c.pool.num_processes = 10
  72. c._update_prefetch_count(8)
  73. c._update_qos_eventually.assert_called_with(8)
  74. assert c.initial_prefetch_count == 10 * 10
  75. def test_flush_events(self):
  76. c = self.get_consumer()
  77. c.event_dispatcher = None
  78. c._flush_events()
  79. c.event_dispatcher = Mock(name='evd')
  80. c._flush_events()
  81. c.event_dispatcher.flush.assert_called_with()
  82. def test_on_send_event_buffered(self):
  83. c = self.get_consumer()
  84. c.hub = None
  85. c.on_send_event_buffered()
  86. c.hub = Mock(name='hub')
  87. c.on_send_event_buffered()
  88. c.hub._ready.add.assert_called_with(c._flush_events)
  89. def test_limit_task(self):
  90. c = self.get_consumer()
  91. c.timer = Mock()
  92. bucket = Mock()
  93. request = Mock()
  94. bucket.can_consume.return_value = True
  95. bucket.contents = deque()
  96. c._limit_task(request, bucket, 3)
  97. bucket.can_consume.assert_called_with(3)
  98. bucket.expected_time.assert_called_with(3)
  99. c.timer.call_after.assert_called_with(
  100. bucket.expected_time(), c._on_bucket_wakeup, (bucket, 3),
  101. priority=c._limit_order,
  102. )
  103. bucket.can_consume.return_value = False
  104. bucket.expected_time.return_value = 3.33
  105. limit_order = c._limit_order
  106. c._limit_task(request, bucket, 4)
  107. assert c._limit_order == limit_order + 1
  108. bucket.can_consume.assert_called_with(4)
  109. c.timer.call_after.assert_called_with(
  110. 3.33, c._on_bucket_wakeup, (bucket, 4),
  111. priority=c._limit_order,
  112. )
  113. bucket.expected_time.assert_called_with(4)
  114. def test_start_blueprint_raises_EMFILE(self):
  115. c = self.get_consumer()
  116. exc = c.blueprint.start.side_effect = OSError()
  117. exc.errno = errno.EMFILE
  118. with pytest.raises(OSError):
  119. c.start()
  120. def test_max_restarts_exceeded(self):
  121. c = self.get_consumer()
  122. def se(*args, **kwargs):
  123. c.blueprint.state = CLOSE
  124. raise RestartFreqExceeded()
  125. c._restart_state.step.side_effect = se
  126. c.blueprint.start.side_effect = socket.error()
  127. with patch('celery.worker.consumer.consumer.sleep') as sleep:
  128. c.start()
  129. sleep.assert_called_with(1)
  130. def test_no_retry_raises_error(self):
  131. self.app.conf.broker_connection_retry = False
  132. c = self.get_consumer()
  133. c.blueprint.start.side_effect = socket.error()
  134. with pytest.raises(socket.error):
  135. c.start()
  136. def _closer(self, c):
  137. def se(*args, **kwargs):
  138. c.blueprint.state = CLOSE
  139. return se
  140. def test_collects_at_restart(self):
  141. c = self.get_consumer()
  142. c.connection.collect.side_effect = MemoryError()
  143. c.blueprint.start.side_effect = socket.error()
  144. c.blueprint.restart.side_effect = self._closer(c)
  145. c.start()
  146. c.connection.collect.assert_called_with()
  147. def test_register_with_event_loop(self):
  148. c = self.get_consumer()
  149. c.register_with_event_loop(Mock(name='loop'))
  150. def test_on_close_clears_semaphore_timer_and_reqs(self):
  151. with patch('celery.worker.consumer.consumer.reserved_requests') as res:
  152. c = self.get_consumer()
  153. c.on_close()
  154. c.controller.semaphore.clear.assert_called_with()
  155. c.timer.clear.assert_called_with()
  156. res.clear.assert_called_with()
  157. c.pool.flush.assert_called_with()
  158. c.controller = None
  159. c.timer = None
  160. c.pool = None
  161. c.on_close()
  162. def test_connect_error_handler(self):
  163. self.app._connection = _amqp_connection()
  164. conn = self.app._connection.return_value
  165. c = self.get_consumer()
  166. assert c.connect()
  167. conn.ensure_connection.assert_called()
  168. errback = conn.ensure_connection.call_args[0][0]
  169. errback(Mock(), 0)
  170. class test_Heart:
  171. def test_start(self):
  172. c = Mock()
  173. c.timer = Mock()
  174. c.event_dispatcher = Mock()
  175. with patch('celery.worker.heartbeat.Heart') as hcls:
  176. h = Heart(c)
  177. assert h.enabled
  178. assert h.heartbeat_interval is None
  179. assert c.heart is None
  180. h.start(c)
  181. assert c.heart
  182. hcls.assert_called_with(c.timer, c.event_dispatcher,
  183. h.heartbeat_interval)
  184. c.heart.start.assert_called_with()
  185. def test_start_heartbeat_interval(self):
  186. c = Mock()
  187. c.timer = Mock()
  188. c.event_dispatcher = Mock()
  189. with patch('celery.worker.heartbeat.Heart') as hcls:
  190. h = Heart(c, False, 20)
  191. assert h.enabled
  192. assert h.heartbeat_interval == 20
  193. assert c.heart is None
  194. h.start(c)
  195. assert c.heart
  196. hcls.assert_called_with(c.timer, c.event_dispatcher,
  197. h.heartbeat_interval)
  198. c.heart.start.assert_called_with()
  199. class test_Tasks:
  200. def test_stop(self):
  201. c = Mock()
  202. tasks = Tasks(c)
  203. assert c.task_consumer is None
  204. assert c.qos is None
  205. c.task_consumer = Mock()
  206. tasks.stop(c)
  207. def test_stop_already_stopped(self):
  208. c = Mock()
  209. tasks = Tasks(c)
  210. tasks.stop(c)
  211. class test_Agent:
  212. def test_start(self):
  213. c = Mock()
  214. agent = Agent(c)
  215. agent.instantiate = Mock()
  216. agent.agent_cls = 'foo:Agent'
  217. assert agent.create(c) is not None
  218. agent.instantiate.assert_called_with(agent.agent_cls, c.connection)
  219. class test_Mingle:
  220. def test_start_no_replies(self):
  221. c = Mock()
  222. c.app.connection_for_read = _amqp_connection()
  223. mingle = Mingle(c)
  224. I = c.app.control.inspect.return_value = Mock()
  225. I.hello.return_value = {}
  226. mingle.start(c)
  227. def test_start(self):
  228. c = Mock()
  229. c.app.connection_for_read = _amqp_connection()
  230. mingle = Mingle(c)
  231. assert mingle.enabled
  232. Aig = LimitedSet()
  233. Big = LimitedSet()
  234. Aig.add('Aig-1')
  235. Aig.add('Aig-2')
  236. Big.add('Big-1')
  237. I = c.app.control.inspect.return_value = Mock()
  238. I.hello.return_value = {
  239. 'A@example.com': {
  240. 'clock': 312,
  241. 'revoked': Aig._data,
  242. },
  243. 'B@example.com': {
  244. 'clock': 29,
  245. 'revoked': Big._data,
  246. },
  247. 'C@example.com': {
  248. 'error': 'unknown method',
  249. },
  250. }
  251. our_revoked = c.controller.state.revoked = LimitedSet()
  252. mingle.start(c)
  253. I.hello.assert_called_with(c.hostname, our_revoked._data)
  254. c.app.clock.adjust.assert_has_calls([
  255. call(312), call(29),
  256. ], any_order=True)
  257. assert 'Aig-1' in our_revoked
  258. assert 'Aig-2' in our_revoked
  259. assert 'Big-1' in our_revoked
  260. def _amqp_connection():
  261. connection = ContextMock(name='Connection')
  262. connection.return_value = ContextMock(name='connection')
  263. connection.return_value.transport.driver_type = 'amqp'
  264. return connection
  265. class test_Gossip:
  266. def test_init(self):
  267. c = self.Consumer()
  268. c.app.connection_for_read = _amqp_connection()
  269. g = Gossip(c)
  270. assert g.enabled
  271. assert c.gossip is g
  272. def test_callbacks(self):
  273. c = self.Consumer()
  274. c.app.connection_for_read = _amqp_connection()
  275. g = Gossip(c)
  276. on_node_join = Mock(name='on_node_join')
  277. on_node_join2 = Mock(name='on_node_join2')
  278. on_node_leave = Mock(name='on_node_leave')
  279. on_node_lost = Mock(name='on.node_lost')
  280. g.on.node_join.add(on_node_join)
  281. g.on.node_join.add(on_node_join2)
  282. g.on.node_leave.add(on_node_leave)
  283. g.on.node_lost.add(on_node_lost)
  284. worker = Mock(name='worker')
  285. g.on_node_join(worker)
  286. on_node_join.assert_called_with(worker)
  287. on_node_join2.assert_called_with(worker)
  288. g.on_node_leave(worker)
  289. on_node_leave.assert_called_with(worker)
  290. g.on_node_lost(worker)
  291. on_node_lost.assert_called_with(worker)
  292. def test_election(self):
  293. c = self.Consumer()
  294. c.app.connection_for_read = _amqp_connection()
  295. g = Gossip(c)
  296. g.start(c)
  297. g.election('id', 'topic', 'action')
  298. assert g.consensus_replies['id'] == []
  299. g.dispatcher.send.assert_called_with(
  300. 'worker-elect', id='id', topic='topic', cver=1, action='action',
  301. )
  302. def test_call_task(self):
  303. c = self.Consumer()
  304. c.app.connection_for_read = _amqp_connection()
  305. g = Gossip(c)
  306. g.start(c)
  307. signature = g.app.signature = Mock(name='app.signature')
  308. task = Mock()
  309. g.call_task(task)
  310. signature.assert_called_with(task)
  311. signature.return_value.apply_async.assert_called_with()
  312. signature.return_value.apply_async.side_effect = MemoryError()
  313. with patch('celery.worker.consumer.gossip.logger') as logger:
  314. g.call_task(task)
  315. logger.exception.assert_called()
  316. def Event(self, id='id', clock=312,
  317. hostname='foo@example.com', pid=4312,
  318. topic='topic', action='action', cver=1):
  319. return {
  320. 'id': id,
  321. 'clock': clock,
  322. 'hostname': hostname,
  323. 'pid': pid,
  324. 'topic': topic,
  325. 'action': action,
  326. 'cver': cver,
  327. }
  328. def test_on_elect(self):
  329. c = self.Consumer()
  330. c.app.connection_for_read = _amqp_connection()
  331. g = Gossip(c)
  332. g.start(c)
  333. event = self.Event('id1')
  334. g.on_elect(event)
  335. in_heap = g.consensus_requests['id1']
  336. assert in_heap
  337. g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
  338. event.pop('clock')
  339. with patch('celery.worker.consumer.gossip.logger') as logger:
  340. g.on_elect(event)
  341. logger.exception.assert_called()
  342. def Consumer(self, hostname='foo@x.com', pid=4312):
  343. c = Mock()
  344. c.app.connection = _amqp_connection()
  345. c.hostname = hostname
  346. c.pid = pid
  347. return c
  348. def setup_election(self, g, c):
  349. g.start(c)
  350. g.clock = self.app.clock
  351. assert 'idx' not in g.consensus_replies
  352. assert g.on_elect_ack({'id': 'idx'}) is None
  353. g.state.alive_workers.return_value = [
  354. 'foo@x.com', 'bar@x.com', 'baz@x.com',
  355. ]
  356. g.consensus_replies['id1'] = []
  357. g.consensus_requests['id1'] = []
  358. e1 = self.Event('id1', 1, 'foo@x.com')
  359. e2 = self.Event('id1', 2, 'bar@x.com')
  360. e3 = self.Event('id1', 3, 'baz@x.com')
  361. g.on_elect(e1)
  362. g.on_elect(e2)
  363. g.on_elect(e3)
  364. assert len(g.consensus_requests['id1']) == 3
  365. with patch('celery.worker.consumer.gossip.info'):
  366. g.on_elect_ack(e1)
  367. assert len(g.consensus_replies['id1']) == 1
  368. g.on_elect_ack(e2)
  369. assert len(g.consensus_replies['id1']) == 2
  370. g.on_elect_ack(e3)
  371. with pytest.raises(KeyError):
  372. g.consensus_replies['id1']
  373. def test_on_elect_ack_win(self):
  374. c = self.Consumer(hostname='foo@x.com') # I will win
  375. c.app.connection_for_read = _amqp_connection()
  376. g = Gossip(c)
  377. handler = g.election_handlers['topic'] = Mock()
  378. self.setup_election(g, c)
  379. handler.assert_called_with('action')
  380. def test_on_elect_ack_lose(self):
  381. c = self.Consumer(hostname='bar@x.com') # I will lose
  382. c.app.connection_for_read = _amqp_connection()
  383. g = Gossip(c)
  384. handler = g.election_handlers['topic'] = Mock()
  385. self.setup_election(g, c)
  386. handler.assert_not_called()
  387. def test_on_elect_ack_win_but_no_action(self):
  388. c = self.Consumer(hostname='foo@x.com') # I will win
  389. c.app.connection_for_read = _amqp_connection()
  390. g = Gossip(c)
  391. g.election_handlers = {}
  392. with patch('celery.worker.consumer.gossip.logger') as logger:
  393. self.setup_election(g, c)
  394. logger.exception.assert_called()
  395. def test_on_node_join(self):
  396. c = self.Consumer()
  397. c.app.connection_for_read = _amqp_connection()
  398. g = Gossip(c)
  399. with patch('celery.worker.consumer.gossip.debug') as debug:
  400. g.on_node_join(c)
  401. debug.assert_called_with('%s joined the party', 'foo@x.com')
  402. def test_on_node_leave(self):
  403. c = self.Consumer()
  404. c.app.connection_for_read = _amqp_connection()
  405. g = Gossip(c)
  406. with patch('celery.worker.consumer.gossip.debug') as debug:
  407. g.on_node_leave(c)
  408. debug.assert_called_with('%s left', 'foo@x.com')
  409. def test_on_node_lost(self):
  410. c = self.Consumer()
  411. c.app.connection_for_read = _amqp_connection()
  412. g = Gossip(c)
  413. with patch('celery.worker.consumer.gossip.info') as info:
  414. g.on_node_lost(c)
  415. info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
  416. def test_register_timer(self):
  417. c = self.Consumer()
  418. c.app.connection_for_read = _amqp_connection()
  419. g = Gossip(c)
  420. g.register_timer()
  421. c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
  422. tref = g._tref
  423. g.register_timer()
  424. tref.cancel.assert_called_with()
  425. def test_periodic(self):
  426. c = self.Consumer()
  427. c.app.connection_for_read = _amqp_connection()
  428. g = Gossip(c)
  429. g.on_node_lost = Mock()
  430. state = g.state = Mock()
  431. worker = Mock()
  432. state.workers = {'foo': worker}
  433. worker.alive = True
  434. worker.hostname = 'foo'
  435. g.periodic()
  436. worker.alive = False
  437. g.periodic()
  438. g.on_node_lost.assert_called_with(worker)
  439. with pytest.raises(KeyError):
  440. state.workers['foo']
  441. def test_on_message__task(self):
  442. c = self.Consumer()
  443. c.app.connection_for_read = _amqp_connection()
  444. g = Gossip(c)
  445. assert g.enabled
  446. message = Mock(name='message')
  447. message.delivery_info = {'routing_key': 'task.failed'}
  448. g.on_message(Mock(name='prepare'), message)
  449. def test_on_message(self):
  450. c = self.Consumer()
  451. c.app.connection_for_read = _amqp_connection()
  452. g = Gossip(c)
  453. assert g.enabled
  454. prepare = Mock()
  455. prepare.return_value = 'worker-online', {}
  456. c.app.events.State.assert_called_with(
  457. on_node_join=g.on_node_join,
  458. on_node_leave=g.on_node_leave,
  459. max_tasks_in_memory=1,
  460. )
  461. g.update_state = Mock()
  462. worker = Mock()
  463. g.on_node_join = Mock()
  464. g.on_node_leave = Mock()
  465. g.update_state.return_value = worker, 1
  466. message = Mock()
  467. message.delivery_info = {'routing_key': 'worker-online'}
  468. message.headers = {'hostname': 'other'}
  469. handler = g.event_handlers['worker-online'] = Mock()
  470. g.on_message(prepare, message)
  471. handler.assert_called_with(message.payload)
  472. g.event_handlers = {}
  473. g.on_message(prepare, message)
  474. message.delivery_info = {'routing_key': 'worker-offline'}
  475. prepare.return_value = 'worker-offline', {}
  476. g.on_message(prepare, message)
  477. message.delivery_info = {'routing_key': 'worker-baz'}
  478. prepare.return_value = 'worker-baz', {}
  479. g.update_state.return_value = worker, 0
  480. g.on_message(prepare, message)
  481. message.headers = {'hostname': g.hostname}
  482. g.on_message(prepare, message)
  483. g.clock.forward.assert_called_with()