# test_consumer.py 18 KB


  1. import errno
  2. import pytest
  3. import socket
  4. from collections import deque
  5. from case import ContextMock, Mock, call, patch
  6. from billiard.exceptions import RestartFreqExceeded
  7. from celery.worker.consumer.agent import Agent
  8. from celery.worker.consumer.consumer import CLOSE, Consumer
  9. from celery.worker.consumer.gossip import Gossip
  10. from celery.worker.consumer.heart import Heart
  11. from celery.worker.consumer.mingle import Mingle
  12. from celery.worker.consumer.tasks import Tasks
  13. from celery.utils.collections import LimitedSet
  14. class test_Consumer:
  15. def get_consumer(self, no_hub=False, **kwargs):
  16. consumer = Consumer(
  17. on_task_request=Mock(),
  18. init_callback=Mock(),
  19. pool=Mock(),
  20. app=self.app,
  21. timer=Mock(),
  22. controller=Mock(),
  23. hub=None if no_hub else Mock(),
  24. **kwargs
  25. )
  26. consumer.blueprint = Mock(name='blueprint')
  27. consumer._restart_state = Mock(name='_restart_state')
  28. consumer.connection = _amqp_connection()
  29. consumer.connection_errors = (socket.error, OSError,)
  30. consumer.conninfo = consumer.connection
  31. return consumer
  32. def test_repr(self):
  33. assert repr(self.get_consumer())
  34. def test_taskbuckets_defaultdict(self):
  35. c = self.get_consumer()
  36. assert c.task_buckets['fooxasdwx.wewe'] is None
  37. def test_sets_heartbeat(self):
  38. c = self.get_consumer(amqheartbeat=10)
  39. assert c.amqheartbeat == 10
  40. self.app.conf.broker_heartbeat = 20
  41. c = self.get_consumer(amqheartbeat=None)
  42. assert c.amqheartbeat == 20
  43. def test_gevent_bug_disables_connection_timeout(self):
  44. with patch('celery.worker.consumer.consumer._detect_environment') as d:
  45. d.return_value = 'gevent'
  46. self.app.conf.broker_connection_timeout = 33.33
  47. self.get_consumer()
  48. assert self.app.conf.broker_connection_timeout is None
  49. def test_limit_moved_to_pool(self):
  50. with patch('celery.worker.consumer.consumer.task_reserved') as reserv:
  51. c = self.get_consumer()
  52. c.on_task_request = Mock(name='on_task_request')
  53. request = Mock(name='request')
  54. c._limit_move_to_pool(request)
  55. reserv.assert_called_with(request)
  56. c.on_task_request.assert_called_with(request)
  57. def test_update_prefetch_count(self):
  58. c = self.get_consumer()
  59. c._update_qos_eventually = Mock(name='update_qos')
  60. c.initial_prefetch_count = None
  61. c.pool.num_processes = None
  62. c.prefetch_multiplier = 10
  63. assert c._update_prefetch_count(1) is None
  64. c.initial_prefetch_count = 10
  65. c.pool.num_processes = 10
  66. c._update_prefetch_count(8)
  67. c._update_qos_eventually.assert_called_with(8)
  68. assert c.initial_prefetch_count == 10 * 10
  69. def test_flush_events(self):
  70. c = self.get_consumer()
  71. c.event_dispatcher = None
  72. c._flush_events()
  73. c.event_dispatcher = Mock(name='evd')
  74. c._flush_events()
  75. c.event_dispatcher.flush.assert_called_with()
  76. def test_on_send_event_buffered(self):
  77. c = self.get_consumer()
  78. c.hub = None
  79. c.on_send_event_buffered()
  80. c.hub = Mock(name='hub')
  81. c.on_send_event_buffered()
  82. c.hub._ready.add.assert_called_with(c._flush_events)
  83. def test_limit_task(self):
  84. c = self.get_consumer()
  85. c.timer = Mock()
  86. bucket = Mock()
  87. request = Mock()
  88. bucket.can_consume.return_value = True
  89. bucket.contents = deque()
  90. c._limit_task(request, bucket, 3)
  91. bucket.can_consume.assert_called_with(3)
  92. bucket.expected_time.assert_called_with(3)
  93. c.timer.call_after.assert_called_with(
  94. bucket.expected_time(), c._on_bucket_wakeup, (bucket, 3),
  95. priority=c._limit_order,
  96. )
  97. bucket.can_consume.return_value = False
  98. bucket.expected_time.return_value = 3.33
  99. limit_order = c._limit_order
  100. c._limit_task(request, bucket, 4)
  101. assert c._limit_order == limit_order + 1
  102. bucket.can_consume.assert_called_with(4)
  103. c.timer.call_after.assert_called_with(
  104. 3.33, c._on_bucket_wakeup, (bucket, 4),
  105. priority=c._limit_order,
  106. )
  107. bucket.expected_time.assert_called_with(4)
  108. def test_start_blueprint_raises_EMFILE(self):
  109. c = self.get_consumer()
  110. exc = c.blueprint.start.side_effect = OSError()
  111. exc.errno = errno.EMFILE
  112. with pytest.raises(OSError):
  113. c.start()
  114. def test_max_restarts_exceeded(self):
  115. c = self.get_consumer()
  116. def se(*args, **kwargs):
  117. c.blueprint.state = CLOSE
  118. raise RestartFreqExceeded()
  119. c._restart_state.step.side_effect = se
  120. c.blueprint.start.side_effect = socket.error()
  121. with patch('celery.worker.consumer.consumer.sleep') as sleep:
  122. c.start()
  123. sleep.assert_called_with(1)
  124. def test_no_retry_raises_error(self):
  125. self.app.conf.broker_connection_retry = False
  126. c = self.get_consumer()
  127. c.blueprint.start.side_effect = socket.error()
  128. with pytest.raises(socket.error):
  129. c.start()
  130. def _closer(self, c):
  131. def se(*args, **kwargs):
  132. c.blueprint.state = CLOSE
  133. return se
  134. def test_collects_at_restart(self):
  135. c = self.get_consumer()
  136. c.connection.collect.side_effect = MemoryError()
  137. c.blueprint.start.side_effect = socket.error()
  138. c.blueprint.restart.side_effect = self._closer(c)
  139. c.start()
  140. c.connection.collect.assert_called_with()
  141. def test_register_with_event_loop(self):
  142. c = self.get_consumer()
  143. c.register_with_event_loop(Mock(name='loop'))
  144. def test_on_close_clears_semaphore_timer_and_reqs(self):
  145. with patch('celery.worker.consumer.consumer.reserved_requests') as res:
  146. c = self.get_consumer()
  147. c.on_close()
  148. c.controller.semaphore.clear.assert_called_with()
  149. c.timer.clear.assert_called_with()
  150. res.clear.assert_called_with()
  151. c.pool.flush.assert_called_with()
  152. c.controller = None
  153. c.timer = None
  154. c.pool = None
  155. c.on_close()
  156. def test_connect_error_handler(self):
  157. self.app._connection = _amqp_connection()
  158. conn = self.app._connection.return_value
  159. c = self.get_consumer()
  160. assert c.connect()
  161. conn.ensure_connection.assert_called()
  162. errback = conn.ensure_connection.call_args[0][0]
  163. errback(Mock(), 0)
  164. class test_Heart:
  165. def test_start(self):
  166. c = Mock()
  167. c.timer = Mock()
  168. c.event_dispatcher = Mock()
  169. with patch('celery.worker.heartbeat.Heart') as hcls:
  170. h = Heart(c)
  171. assert h.enabled
  172. assert h.heartbeat_interval is None
  173. assert c.heart is None
  174. h.start(c)
  175. assert c.heart
  176. hcls.assert_called_with(c.timer, c.event_dispatcher,
  177. h.heartbeat_interval)
  178. c.heart.start.assert_called_with()
  179. def test_start_heartbeat_interval(self):
  180. c = Mock()
  181. c.timer = Mock()
  182. c.event_dispatcher = Mock()
  183. with patch('celery.worker.heartbeat.Heart') as hcls:
  184. h = Heart(c, False, 20)
  185. assert h.enabled
  186. assert h.heartbeat_interval == 20
  187. assert c.heart is None
  188. h.start(c)
  189. assert c.heart
  190. hcls.assert_called_with(c.timer, c.event_dispatcher,
  191. h.heartbeat_interval)
  192. c.heart.start.assert_called_with()
  193. class test_Tasks:
  194. def test_stop(self):
  195. c = Mock()
  196. tasks = Tasks(c)
  197. assert c.task_consumer is None
  198. assert c.qos is None
  199. c.task_consumer = Mock()
  200. tasks.stop(c)
  201. def test_stop_already_stopped(self):
  202. c = Mock()
  203. tasks = Tasks(c)
  204. tasks.stop(c)
  205. class test_Agent:
  206. def test_start(self):
  207. c = Mock()
  208. agent = Agent(c)
  209. agent.instantiate = Mock()
  210. agent.agent_cls = 'foo:Agent'
  211. assert agent.create(c) is not None
  212. agent.instantiate.assert_called_with(agent.agent_cls, c.connection)
  213. class test_Mingle:
  214. def test_start_no_replies(self):
  215. c = Mock()
  216. c.app.connection_for_read = _amqp_connection()
  217. mingle = Mingle(c)
  218. I = c.app.control.inspect.return_value = Mock()
  219. I.hello.return_value = {}
  220. mingle.start(c)
  221. def test_start(self):
  222. c = Mock()
  223. c.app.connection_for_read = _amqp_connection()
  224. mingle = Mingle(c)
  225. assert mingle.enabled
  226. Aig = LimitedSet()
  227. Big = LimitedSet()
  228. Aig.add('Aig-1')
  229. Aig.add('Aig-2')
  230. Big.add('Big-1')
  231. I = c.app.control.inspect.return_value = Mock()
  232. I.hello.return_value = {
  233. 'A@example.com': {
  234. 'clock': 312,
  235. 'revoked': Aig._data,
  236. },
  237. 'B@example.com': {
  238. 'clock': 29,
  239. 'revoked': Big._data,
  240. },
  241. 'C@example.com': {
  242. 'error': 'unknown method',
  243. },
  244. }
  245. our_revoked = c.controller.state.revoked = LimitedSet()
  246. mingle.start(c)
  247. I.hello.assert_called_with(c.hostname, our_revoked._data)
  248. c.app.clock.adjust.assert_has_calls([
  249. call(312), call(29),
  250. ], any_order=True)
  251. assert 'Aig-1' in our_revoked
  252. assert 'Aig-2' in our_revoked
  253. assert 'Big-1' in our_revoked
  254. def _amqp_connection():
  255. connection = ContextMock(name='Connection')
  256. connection.return_value = ContextMock(name='connection')
  257. connection.return_value.transport.driver_type = 'amqp'
  258. return connection
  259. class test_Gossip:
  260. def test_init(self):
  261. c = self.Consumer()
  262. c.app.connection_for_read = _amqp_connection()
  263. g = Gossip(c)
  264. assert g.enabled
  265. assert c.gossip is g
  266. def test_callbacks(self):
  267. c = self.Consumer()
  268. c.app.connection_for_read = _amqp_connection()
  269. g = Gossip(c)
  270. on_node_join = Mock(name='on_node_join')
  271. on_node_join2 = Mock(name='on_node_join2')
  272. on_node_leave = Mock(name='on_node_leave')
  273. on_node_lost = Mock(name='on.node_lost')
  274. g.on.node_join.add(on_node_join)
  275. g.on.node_join.add(on_node_join2)
  276. g.on.node_leave.add(on_node_leave)
  277. g.on.node_lost.add(on_node_lost)
  278. worker = Mock(name='worker')
  279. g.on_node_join(worker)
  280. on_node_join.assert_called_with(worker)
  281. on_node_join2.assert_called_with(worker)
  282. g.on_node_leave(worker)
  283. on_node_leave.assert_called_with(worker)
  284. g.on_node_lost(worker)
  285. on_node_lost.assert_called_with(worker)
  286. def test_election(self):
  287. c = self.Consumer()
  288. c.app.connection_for_read = _amqp_connection()
  289. g = Gossip(c)
  290. g.start(c)
  291. g.election('id', 'topic', 'action')
  292. assert g.consensus_replies['id'] == []
  293. g.dispatcher.send.assert_called_with(
  294. 'worker-elect', id='id', topic='topic', cver=1, action='action',
  295. )
  296. def test_call_task(self):
  297. c = self.Consumer()
  298. c.app.connection_for_read = _amqp_connection()
  299. g = Gossip(c)
  300. g.start(c)
  301. signature = g.app.signature = Mock(name='app.signature')
  302. task = Mock()
  303. g.call_task(task)
  304. signature.assert_called_with(task)
  305. signature.return_value.apply_async.assert_called_with()
  306. signature.return_value.apply_async.side_effect = MemoryError()
  307. with patch('celery.worker.consumer.gossip.logger') as logger:
  308. g.call_task(task)
  309. logger.exception.assert_called()
  310. def Event(self, id='id', clock=312,
  311. hostname='foo@example.com', pid=4312,
  312. topic='topic', action='action', cver=1):
  313. return {
  314. 'id': id,
  315. 'clock': clock,
  316. 'hostname': hostname,
  317. 'pid': pid,
  318. 'topic': topic,
  319. 'action': action,
  320. 'cver': cver,
  321. }
  322. def test_on_elect(self):
  323. c = self.Consumer()
  324. c.app.connection_for_read = _amqp_connection()
  325. g = Gossip(c)
  326. g.start(c)
  327. event = self.Event('id1')
  328. g.on_elect(event)
  329. in_heap = g.consensus_requests['id1']
  330. assert in_heap
  331. g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')
  332. event.pop('clock')
  333. with patch('celery.worker.consumer.gossip.logger') as logger:
  334. g.on_elect(event)
  335. logger.exception.assert_called()
  336. def Consumer(self, hostname='foo@x.com', pid=4312):
  337. c = Mock()
  338. c.app.connection = _amqp_connection()
  339. c.hostname = hostname
  340. c.pid = pid
  341. return c
  342. def setup_election(self, g, c):
  343. g.start(c)
  344. g.clock = self.app.clock
  345. assert 'idx' not in g.consensus_replies
  346. assert g.on_elect_ack({'id': 'idx'}) is None
  347. g.state.alive_workers.return_value = [
  348. 'foo@x.com', 'bar@x.com', 'baz@x.com',
  349. ]
  350. g.consensus_replies['id1'] = []
  351. g.consensus_requests['id1'] = []
  352. e1 = self.Event('id1', 1, 'foo@x.com')
  353. e2 = self.Event('id1', 2, 'bar@x.com')
  354. e3 = self.Event('id1', 3, 'baz@x.com')
  355. g.on_elect(e1)
  356. g.on_elect(e2)
  357. g.on_elect(e3)
  358. assert len(g.consensus_requests['id1']) == 3
  359. with patch('celery.worker.consumer.gossip.info'):
  360. g.on_elect_ack(e1)
  361. assert len(g.consensus_replies['id1']) == 1
  362. g.on_elect_ack(e2)
  363. assert len(g.consensus_replies['id1']) == 2
  364. g.on_elect_ack(e3)
  365. with pytest.raises(KeyError):
  366. g.consensus_replies['id1']
  367. def test_on_elect_ack_win(self):
  368. c = self.Consumer(hostname='foo@x.com') # I will win
  369. c.app.connection_for_read = _amqp_connection()
  370. g = Gossip(c)
  371. handler = g.election_handlers['topic'] = Mock()
  372. self.setup_election(g, c)
  373. handler.assert_called_with('action')
  374. def test_on_elect_ack_lose(self):
  375. c = self.Consumer(hostname='bar@x.com') # I will lose
  376. c.app.connection_for_read = _amqp_connection()
  377. g = Gossip(c)
  378. handler = g.election_handlers['topic'] = Mock()
  379. self.setup_election(g, c)
  380. handler.assert_not_called()
  381. def test_on_elect_ack_win_but_no_action(self):
  382. c = self.Consumer(hostname='foo@x.com') # I will win
  383. c.app.connection_for_read = _amqp_connection()
  384. g = Gossip(c)
  385. g.election_handlers = {}
  386. with patch('celery.worker.consumer.gossip.logger') as logger:
  387. self.setup_election(g, c)
  388. logger.exception.assert_called()
  389. def test_on_node_join(self):
  390. c = self.Consumer()
  391. c.app.connection_for_read = _amqp_connection()
  392. g = Gossip(c)
  393. with patch('celery.worker.consumer.gossip.debug') as debug:
  394. g.on_node_join(c)
  395. debug.assert_called_with('%s joined the party', 'foo@x.com')
  396. def test_on_node_leave(self):
  397. c = self.Consumer()
  398. c.app.connection_for_read = _amqp_connection()
  399. g = Gossip(c)
  400. with patch('celery.worker.consumer.gossip.debug') as debug:
  401. g.on_node_leave(c)
  402. debug.assert_called_with('%s left', 'foo@x.com')
  403. def test_on_node_lost(self):
  404. c = self.Consumer()
  405. c.app.connection_for_read = _amqp_connection()
  406. g = Gossip(c)
  407. with patch('celery.worker.consumer.gossip.info') as info:
  408. g.on_node_lost(c)
  409. info.assert_called_with('missed heartbeat from %s', 'foo@x.com')
  410. def test_register_timer(self):
  411. c = self.Consumer()
  412. c.app.connection_for_read = _amqp_connection()
  413. g = Gossip(c)
  414. g.register_timer()
  415. c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
  416. tref = g._tref
  417. g.register_timer()
  418. tref.cancel.assert_called_with()
  419. def test_periodic(self):
  420. c = self.Consumer()
  421. c.app.connection_for_read = _amqp_connection()
  422. g = Gossip(c)
  423. g.on_node_lost = Mock()
  424. state = g.state = Mock()
  425. worker = Mock()
  426. state.workers = {'foo': worker}
  427. worker.alive = True
  428. worker.hostname = 'foo'
  429. g.periodic()
  430. worker.alive = False
  431. g.periodic()
  432. g.on_node_lost.assert_called_with(worker)
  433. with pytest.raises(KeyError):
  434. state.workers['foo']
  435. def test_on_message__task(self):
  436. c = self.Consumer()
  437. c.app.connection_for_read = _amqp_connection()
  438. g = Gossip(c)
  439. assert g.enabled
  440. message = Mock(name='message')
  441. message.delivery_info = {'routing_key': 'task.failed'}
  442. g.on_message(Mock(name='prepare'), message)
  443. def test_on_message(self):
  444. c = self.Consumer()
  445. c.app.connection_for_read = _amqp_connection()
  446. g = Gossip(c)
  447. assert g.enabled
  448. prepare = Mock()
  449. prepare.return_value = 'worker-online', {}
  450. c.app.events.State.assert_called_with(
  451. on_node_join=g.on_node_join,
  452. on_node_leave=g.on_node_leave,
  453. max_tasks_in_memory=1,
  454. )
  455. g.update_state = Mock()
  456. worker = Mock()
  457. g.on_node_join = Mock()
  458. g.on_node_leave = Mock()
  459. g.update_state.return_value = worker, 1
  460. message = Mock()
  461. message.delivery_info = {'routing_key': 'worker-online'}
  462. message.headers = {'hostname': 'other'}
  463. handler = g.event_handlers['worker-online'] = Mock()
  464. g.on_message(prepare, message)
  465. handler.assert_called_with(message.payload)
  466. g.event_handlers = {}
  467. g.on_message(prepare, message)
  468. message.delivery_info = {'routing_key': 'worker-offline'}
  469. prepare.return_value = 'worker-offline', {}
  470. g.on_message(prepare, message)
  471. message.delivery_info = {'routing_key': 'worker-baz'}
  472. prepare.return_value = 'worker-baz', {}
  473. g.update_state.return_value = worker, 0
  474. g.on_message(prepare, message)
  475. message.headers = {'hostname': g.hostname}
  476. g.on_message(prepare, message)
  477. g.clock.forward.assert_called_with()