test_worker.py 39 KB


  1. from __future__ import absolute_import, print_function, unicode_literals
  2. import os
  3. import socket
  4. import sys
  5. from collections import deque
  6. from datetime import datetime, timedelta
  7. from threading import Event
  8. from amqp import ChannelError
  9. from kombu import Connection
  10. from kombu.common import QoS, ignore_errors
  11. from kombu.transport.base import Message
  12. from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
  13. from celery.concurrency.base import BasePool
  14. from celery.exceptions import (
  15. WorkerShutdown, WorkerTerminate, TaskRevokedError,
  16. InvalidTaskError, ImproperlyConfigured,
  17. )
  18. from celery.five import Empty, range, Queue as FastQueue
  19. from celery.platforms import EX_FAILURE
  20. from celery.utils import uuid
  21. from celery.worker import components
  22. from celery.worker import consumer
  23. from celery.worker.consumer import Consumer as __Consumer
  24. from celery.worker.request import Request
  25. from celery.utils import worker_direct
  26. from celery.utils.serialization import pickle
  27. from celery.utils.timer2 import Timer
  28. from celery.tests.case import AppCase, Mock, TaskMessage, patch, todo
  29. def MockStep(step=None):
  30. step = Mock() if step is None else step
  31. step.blueprint = Mock()
  32. step.blueprint.name = 'MockNS'
  33. step.name = 'MockStep(%s)' % (id(step),)
  34. return step
  35. def mock_event_dispatcher():
  36. evd = Mock(name='event_dispatcher')
  37. evd.groups = ['worker']
  38. evd._outbound_buffer = deque()
  39. return evd
  40. class PlaceHolder(object):
  41. pass
  42. def find_step(obj, typ):
  43. return obj.blueprint.steps[typ.name]
  44. class Consumer(__Consumer):
  45. def __init__(self, *args, **kwargs):
  46. kwargs.setdefault('without_mingle', True) # disable Mingle step
  47. kwargs.setdefault('without_gossip', True) # disable Gossip step
  48. kwargs.setdefault('without_heartbeat', True) # disable Heart step
  49. kwargs.setdefault('controller', Mock())
  50. super(Consumer, self).__init__(*args, **kwargs)
  51. class _MyKombuConsumer(Consumer):
  52. broadcast_consumer = Mock()
  53. task_consumer = Mock()
  54. def __init__(self, *args, **kwargs):
  55. kwargs.setdefault('pool', BasePool(2))
  56. kwargs.setdefault('controller', Mock())
  57. super(_MyKombuConsumer, self).__init__(*args, **kwargs)
  58. def restart_heartbeat(self):
  59. self.heart = None
  60. class MyKombuConsumer(Consumer):
  61. def loop(self, *args, **kwargs):
  62. pass
  63. class MockNode(object):
  64. commands = []
  65. def handle_message(self, body, message):
  66. self.commands.append(body.pop('command', None))
  67. class MockEventDispatcher(object):
  68. sent = []
  69. closed = False
  70. flushed = False
  71. _outbound_buffer = []
  72. def send(self, event, *args, **kwargs):
  73. self.sent.append(event)
  74. def close(self):
  75. self.closed = True
  76. def flush(self):
  77. self.flushed = True
  78. class MockHeart(object):
  79. closed = False
  80. def stop(self):
  81. self.closed = True
  82. def create_message(channel, **data):
  83. data.setdefault('id', uuid())
  84. channel.no_ack_consumers = set()
  85. m = Message(channel, body=pickle.dumps(dict(**data)),
  86. content_type='application/x-python-serialize',
  87. content_encoding='binary',
  88. delivery_info={'consumer_tag': 'mock'})
  89. m.accept = ['application/x-python-serialize']
  90. return m
  91. def create_task_message(channel, *args, **kwargs):
  92. m = TaskMessage(*args, **kwargs)
  93. m.channel = channel
  94. m.delivery_info = {'consumer_tag': 'mock'}
  95. return m
  96. class test_Consumer(AppCase):
  97. def setup(self):
  98. self.buffer = FastQueue()
  99. self.timer = Timer()
  100. @self.app.task(shared=False)
  101. def foo_task(x, y, z):
  102. return x * y * z
  103. self.foo_task = foo_task
  104. def teardown(self):
  105. self.timer.stop()
  106. def test_info(self):
  107. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  108. l.task_consumer = Mock()
  109. l.qos = QoS(l.task_consumer.qos, 10)
  110. l.connection = Mock()
  111. l.connection.info.return_value = {'foo': 'bar'}
  112. l.controller = l.app.WorkController()
  113. l.pool = l.controller.pool = Mock()
  114. l.controller.pool.info.return_value = [Mock(), Mock()]
  115. l.controller.consumer = l
  116. info = l.controller.stats()
  117. self.assertEqual(info['prefetch_count'], 10)
  118. self.assertTrue(info['broker'])
  119. def test_start_when_closed(self):
  120. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  121. l.blueprint.state = CLOSE
  122. l.start()
  123. def test_connection(self):
  124. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  125. l.controller = l.app.WorkController()
  126. l.pool = l.controller.pool = Mock()
  127. l.blueprint.start(l)
  128. self.assertIsInstance(l.connection, Connection)
  129. l.blueprint.state = RUN
  130. l.event_dispatcher = None
  131. l.blueprint.restart(l)
  132. self.assertTrue(l.connection)
  133. l.blueprint.state = RUN
  134. l.shutdown()
  135. self.assertIsNone(l.connection)
  136. self.assertIsNone(l.task_consumer)
  137. l.blueprint.start(l)
  138. self.assertIsInstance(l.connection, Connection)
  139. l.blueprint.restart(l)
  140. l.stop()
  141. l.shutdown()
  142. self.assertIsNone(l.connection)
  143. self.assertIsNone(l.task_consumer)
  144. def test_close_connection(self):
  145. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  146. l.blueprint.state = RUN
  147. step = find_step(l, consumer.Connection)
  148. conn = l.connection = Mock()
  149. step.shutdown(l)
  150. self.assertTrue(conn.close.called)
  151. self.assertIsNone(l.connection)
  152. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  153. eventer = l.event_dispatcher = mock_event_dispatcher()
  154. eventer.enabled = True
  155. heart = l.heart = MockHeart()
  156. l.blueprint.state = RUN
  157. Events = find_step(l, consumer.Events)
  158. Events.shutdown(l)
  159. Heart = find_step(l, consumer.Heart)
  160. Heart.shutdown(l)
  161. self.assertTrue(eventer.close.call_count)
  162. self.assertTrue(heart.closed)
  163. @patch('celery.worker.consumer.consumer.warn')
  164. def test_receive_message_unknown(self, warn):
  165. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  166. l.blueprint.state = RUN
  167. l.steps.pop()
  168. channel = Mock()
  169. m = create_message(channel, unknown={'baz': '!!!'})
  170. l.event_dispatcher = mock_event_dispatcher()
  171. l.node = MockNode()
  172. callback = self._get_on_message(l)
  173. callback(m)
  174. self.assertTrue(warn.call_count)
  175. @patch('celery.worker.strategy.to_timestamp')
  176. def test_receive_message_eta_OverflowError(self, to_timestamp):
  177. to_timestamp.side_effect = OverflowError()
  178. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  179. l.controller = l.app.WorkController()
  180. l.pool = l.controller.pool = Mock()
  181. l.blueprint.state = RUN
  182. l.steps.pop()
  183. m = create_task_message(
  184. Mock(), self.foo_task.name,
  185. args=('2, 2'), kwargs={},
  186. eta=datetime.now().isoformat(),
  187. )
  188. l.event_dispatcher = mock_event_dispatcher()
  189. l.node = MockNode()
  190. l.update_strategies()
  191. l.qos = Mock()
  192. callback = self._get_on_message(l)
  193. callback(m)
  194. self.assertTrue(m.acknowledged)
  195. @patch('celery.worker.consumer.consumer.error')
  196. def test_receive_message_InvalidTaskError(self, error):
  197. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  198. l.blueprint.state = RUN
  199. l.event_dispatcher = mock_event_dispatcher()
  200. l.steps.pop()
  201. l.controller = l.app.WorkController()
  202. l.pool = l.controller.pool = Mock()
  203. m = create_task_message(
  204. Mock(), self.foo_task.name,
  205. args=(1, 2), kwargs='foobarbaz', id=1)
  206. l.update_strategies()
  207. l.event_dispatcher = mock_event_dispatcher()
  208. strat = l.strategies[self.foo_task.name] = Mock(name='strategy')
  209. strat.side_effect = InvalidTaskError()
  210. callback = self._get_on_message(l)
  211. callback(m)
  212. self.assertTrue(error.called)
  213. self.assertIn('Received invalid task message', error.call_args[0][0])
  214. @patch('celery.worker.consumer.consumer.crit')
  215. def test_on_decode_error(self, crit):
  216. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  217. class MockMessage(Mock):
  218. content_type = 'application/x-msgpack'
  219. content_encoding = 'binary'
  220. body = 'foobarbaz'
  221. message = MockMessage()
  222. l.on_decode_error(message, KeyError('foo'))
  223. self.assertTrue(message.ack.call_count)
  224. self.assertIn("Can't decode message body", crit.call_args[0][0])
  225. def _get_on_message(self, l):
  226. if l.qos is None:
  227. l.qos = Mock()
  228. l.event_dispatcher = mock_event_dispatcher()
  229. l.task_consumer = Mock()
  230. l.connection = Mock()
  231. l.connection.drain_events.side_effect = WorkerShutdown()
  232. with self.assertRaises(WorkerShutdown):
  233. l.loop(*l.loop_args())
  234. self.assertTrue(l.task_consumer.on_message)
  235. return l.task_consumer.on_message
  236. def test_receieve_message(self):
  237. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  238. l.controller = l.app.WorkController()
  239. l.pool = l.controller.pool = Mock()
  240. l.blueprint.state = RUN
  241. l.event_dispatcher = mock_event_dispatcher()
  242. m = create_task_message(
  243. Mock(), self.foo_task.name,
  244. args=[2, 4, 8], kwargs={},
  245. )
  246. l.update_strategies()
  247. callback = self._get_on_message(l)
  248. callback(m)
  249. in_bucket = self.buffer.get_nowait()
  250. self.assertIsInstance(in_bucket, Request)
  251. self.assertEqual(in_bucket.name, self.foo_task.name)
  252. self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
  253. self.assertTrue(self.timer.empty())
  254. def test_start_channel_error(self):
  255. class MockConsumer(Consumer):
  256. iterations = 0
  257. def loop(self, *args, **kwargs):
  258. if not self.iterations:
  259. self.iterations = 1
  260. raise KeyError('foo')
  261. raise SyntaxError('bar')
  262. l = MockConsumer(self.buffer.put, timer=self.timer,
  263. send_events=False, pool=BasePool(), app=self.app)
  264. l.controller = l.app.WorkController()
  265. l.pool = l.controller.pool = Mock()
  266. l.channel_errors = (KeyError,)
  267. with self.assertRaises(KeyError):
  268. l.start()
  269. l.timer.stop()
  270. def test_start_connection_error(self):
  271. class MockConsumer(Consumer):
  272. iterations = 0
  273. def loop(self, *args, **kwargs):
  274. if not self.iterations:
  275. self.iterations = 1
  276. raise KeyError('foo')
  277. raise SyntaxError('bar')
  278. l = MockConsumer(self.buffer.put, timer=self.timer,
  279. send_events=False, pool=BasePool(), app=self.app)
  280. l.controller = l.app.WorkController()
  281. l.pool = l.controller.pool = Mock()
  282. l.connection_errors = (KeyError,)
  283. with self.assertRaises(SyntaxError):
  284. l.start()
  285. l.timer.stop()
  286. def test_loop_ignores_socket_timeout(self):
  287. class Connection(self.app.connection_for_read().__class__):
  288. obj = None
  289. def drain_events(self, **kwargs):
  290. self.obj.connection = None
  291. raise socket.timeout(10)
  292. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  293. l.connection = Connection()
  294. l.task_consumer = Mock()
  295. l.connection.obj = l
  296. l.qos = QoS(l.task_consumer.qos, 10)
  297. l.loop(*l.loop_args())
  298. def test_loop_when_socket_error(self):
  299. class Connection(self.app.connection_for_read().__class__):
  300. obj = None
  301. def drain_events(self, **kwargs):
  302. self.obj.connection = None
  303. raise socket.error('foo')
  304. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  305. l.blueprint.state = RUN
  306. c = l.connection = Connection()
  307. l.connection.obj = l
  308. l.task_consumer = Mock()
  309. l.qos = QoS(l.task_consumer.qos, 10)
  310. with self.assertRaises(socket.error):
  311. l.loop(*l.loop_args())
  312. l.blueprint.state = CLOSE
  313. l.connection = c
  314. l.loop(*l.loop_args())
  315. def test_loop(self):
  316. class Connection(self.app.connection_for_read().__class__):
  317. obj = None
  318. def drain_events(self, **kwargs):
  319. self.obj.connection = None
  320. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  321. l.blueprint.state = RUN
  322. l.connection = Connection()
  323. l.connection.obj = l
  324. l.task_consumer = Mock()
  325. l.qos = QoS(l.task_consumer.qos, 10)
  326. l.loop(*l.loop_args())
  327. l.loop(*l.loop_args())
  328. self.assertTrue(l.task_consumer.consume.call_count)
  329. l.task_consumer.qos.assert_called_with(prefetch_count=10)
  330. self.assertEqual(l.qos.value, 10)
  331. l.qos.decrement_eventually()
  332. self.assertEqual(l.qos.value, 9)
  333. l.qos.update()
  334. self.assertEqual(l.qos.value, 9)
  335. l.task_consumer.qos.assert_called_with(prefetch_count=9)
  336. def test_ignore_errors(self):
  337. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  338. l.connection_errors = (AttributeError, KeyError,)
  339. l.channel_errors = (SyntaxError,)
  340. ignore_errors(l, Mock(side_effect=AttributeError('foo')))
  341. ignore_errors(l, Mock(side_effect=KeyError('foo')))
  342. ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
  343. with self.assertRaises(IndexError):
  344. ignore_errors(l, Mock(side_effect=IndexError('foo')))
  345. def test_apply_eta_task(self):
  346. from celery.worker import state
  347. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  348. l.controller = l.app.WorkController()
  349. l.pool = l.controller.pool = Mock()
  350. l.qos = QoS(None, 10)
  351. task = object()
  352. qos = l.qos.value
  353. l.apply_eta_task(task)
  354. self.assertIn(task, state.reserved_requests)
  355. self.assertEqual(l.qos.value, qos - 1)
  356. self.assertIs(self.buffer.get_nowait(), task)
  357. def test_receieve_message_eta_isoformat(self):
  358. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  359. l.controller = l.app.WorkController()
  360. l.pool = l.controller.pool = Mock()
  361. l.blueprint.state = RUN
  362. l.steps.pop()
  363. m = create_task_message(
  364. Mock(), self.foo_task.name,
  365. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  366. args=[2, 4, 8], kwargs={},
  367. )
  368. l.task_consumer = Mock()
  369. l.qos = QoS(l.task_consumer.qos, 1)
  370. current_pcount = l.qos.value
  371. l.event_dispatcher = mock_event_dispatcher()
  372. l.enabled = False
  373. l.update_strategies()
  374. callback = self._get_on_message(l)
  375. callback(m)
  376. l.timer.stop()
  377. l.timer.join(1)
  378. items = [entry[2] for entry in self.timer.queue]
  379. found = 0
  380. for item in items:
  381. if item.args[0].name == self.foo_task.name:
  382. found = True
  383. self.assertTrue(found)
  384. self.assertGreater(l.qos.value, current_pcount)
  385. l.timer.stop()
  386. def test_pidbox_callback(self):
  387. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  388. con = find_step(l, consumer.Control).box
  389. con.node = Mock()
  390. con.reset = Mock()
  391. con.on_message('foo', 'bar')
  392. con.node.handle_message.assert_called_with('foo', 'bar')
  393. con.node = Mock()
  394. con.node.handle_message.side_effect = KeyError('foo')
  395. con.on_message('foo', 'bar')
  396. con.node.handle_message.assert_called_with('foo', 'bar')
  397. con.node = Mock()
  398. con.node.handle_message.side_effect = ValueError('foo')
  399. con.on_message('foo', 'bar')
  400. con.node.handle_message.assert_called_with('foo', 'bar')
  401. self.assertTrue(con.reset.called)
  402. def test_revoke(self):
  403. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  404. l.blueprint.state = RUN
  405. l.steps.pop()
  406. channel = Mock()
  407. id = uuid()
  408. t = create_task_message(
  409. channel, self.foo_task.name,
  410. args=[2, 4, 8], kwargs={}, id=id,
  411. )
  412. from celery.worker.state import revoked
  413. revoked.add(id)
  414. callback = self._get_on_message(l)
  415. callback(t)
  416. self.assertTrue(self.buffer.empty())
  417. def test_receieve_message_not_registered(self):
  418. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  419. l.blueprint.state = RUN
  420. l.steps.pop()
  421. channel = Mock(name='channel')
  422. m = create_task_message(
  423. channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
  424. )
  425. l.event_dispatcher = mock_event_dispatcher()
  426. callback = self._get_on_message(l)
  427. self.assertFalse(callback(m))
  428. with self.assertRaises(Empty):
  429. self.buffer.get_nowait()
  430. self.assertTrue(self.timer.empty())
  431. @patch('celery.worker.consumer.consumer.warn')
  432. @patch('celery.worker.consumer.consumer.logger')
  433. def test_receieve_message_ack_raises(self, logger, warn):
  434. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  435. l.controller = l.app.WorkController()
  436. l.pool = l.controller.pool = Mock()
  437. l.blueprint.state = RUN
  438. channel = Mock()
  439. m = create_task_message(
  440. channel, self.foo_task.name,
  441. args=[2, 4, 8], kwargs={},
  442. )
  443. m.headers = None
  444. l.event_dispatcher = mock_event_dispatcher()
  445. l.update_strategies()
  446. l.connection_errors = (socket.error,)
  447. m.reject = Mock()
  448. m.reject.side_effect = socket.error('foo')
  449. callback = self._get_on_message(l)
  450. self.assertFalse(callback(m))
  451. self.assertTrue(warn.call_count)
  452. with self.assertRaises(Empty):
  453. self.buffer.get_nowait()
  454. self.assertTrue(self.timer.empty())
  455. m.reject_log_error.assert_called_with(logger, l.connection_errors)
  456. def test_receive_message_eta(self):
  457. import sys
  458. from functools import partial
  459. if os.environ.get('C_DEBUG_TEST'):
  460. pp = partial(print, file=sys.__stderr__)
  461. else:
  462. def pp(*args, **kwargs):
  463. pass
  464. pp('TEST RECEIVE MESSAGE ETA')
  465. pp('+CREATE MYKOMBUCONSUMER')
  466. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  467. l.controller = l.app.WorkController()
  468. l.pool = l.controller.pool = Mock()
  469. pp('-CREATE MYKOMBUCONSUMER')
  470. l.steps.pop()
  471. l.event_dispatcher = mock_event_dispatcher()
  472. channel = Mock(name='channel')
  473. pp('+ CREATE MESSAGE')
  474. m = create_task_message(
  475. channel, self.foo_task.name,
  476. args=[2, 4, 8], kwargs={},
  477. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  478. )
  479. pp('- CREATE MESSAGE')
  480. try:
  481. pp('+ BLUEPRINT START 1')
  482. l.blueprint.start(l)
  483. pp('- BLUEPRINT START 1')
  484. p = l.app.conf.broker_connection_retry
  485. l.app.conf.broker_connection_retry = False
  486. pp('+ BLUEPRINT START 2')
  487. l.blueprint.start(l)
  488. pp('- BLUEPRINT START 2')
  489. l.app.conf.broker_connection_retry = p
  490. pp('+ BLUEPRINT RESTART')
  491. l.blueprint.restart(l)
  492. pp('- BLUEPRINT RESTART')
  493. l.event_dispatcher = mock_event_dispatcher()
  494. pp('+ GET ON MESSAGE')
  495. callback = self._get_on_message(l)
  496. pp('- GET ON MESSAGE')
  497. pp('+ CALLBACK')
  498. callback(m)
  499. pp('- CALLBACK')
  500. finally:
  501. pp('+ STOP TIMER')
  502. l.timer.stop()
  503. pp('- STOP TIMER')
  504. try:
  505. pp('+ JOIN TIMER')
  506. l.timer.join()
  507. pp('- JOIN TIMER')
  508. except RuntimeError:
  509. pass
  510. in_hold = l.timer.queue[0]
  511. self.assertEqual(len(in_hold), 3)
  512. eta, priority, entry = in_hold
  513. task = entry.args[0]
  514. self.assertIsInstance(task, Request)
  515. self.assertEqual(task.name, self.foo_task.name)
  516. self.assertEqual(task.execute(), 2 * 4 * 8)
  517. with self.assertRaises(Empty):
  518. self.buffer.get_nowait()
  519. def test_reset_pidbox_node(self):
  520. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  521. con = find_step(l, consumer.Control).box
  522. con.node = Mock()
  523. chan = con.node.channel = Mock()
  524. l.connection = Mock()
  525. chan.close.side_effect = socket.error('foo')
  526. l.connection_errors = (socket.error,)
  527. con.reset()
  528. chan.close.assert_called_with()
  529. def test_reset_pidbox_node_green(self):
  530. from celery.worker.pidbox import gPidbox
  531. pool = Mock()
  532. pool.is_green = True
  533. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
  534. app=self.app)
  535. con = find_step(l, consumer.Control)
  536. self.assertIsInstance(con.box, gPidbox)
  537. con.start(l)
  538. l.pool.spawn_n.assert_called_with(
  539. con.box.loop, l,
  540. )
  541. def test__green_pidbox_node(self):
  542. pool = Mock()
  543. pool.is_green = True
  544. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
  545. app=self.app)
  546. l.node = Mock()
  547. controller = find_step(l, consumer.Control)
  548. class BConsumer(Mock):
  549. def __enter__(self):
  550. self.consume()
  551. return self
  552. def __exit__(self, *exc_info):
  553. self.cancel()
  554. controller.box.node.listen = BConsumer()
  555. connections = []
  556. class Connection(object):
  557. calls = 0
  558. def __init__(self, obj):
  559. connections.append(self)
  560. self.obj = obj
  561. self.default_channel = self.channel()
  562. self.closed = False
  563. def __enter__(self):
  564. return self
  565. def __exit__(self, *exc_info):
  566. self.close()
  567. def channel(self):
  568. return Mock()
  569. def as_uri(self):
  570. return 'dummy://'
  571. def drain_events(self, **kwargs):
  572. if not self.calls:
  573. self.calls += 1
  574. raise socket.timeout()
  575. self.obj.connection = None
  576. controller.box._node_shutdown.set()
  577. def close(self):
  578. self.closed = True
  579. l.connection = Mock()
  580. l.connect = lambda: Connection(obj=l)
  581. controller = find_step(l, consumer.Control)
  582. controller.box.loop(l)
  583. self.assertTrue(controller.box.node.listen.called)
  584. self.assertTrue(controller.box.consumer)
  585. controller.box.consumer.consume.assert_called_with()
  586. self.assertIsNone(l.connection)
  587. self.assertTrue(connections[0].closed)
  588. @patch('kombu.connection.Connection._establish_connection')
  589. @patch('kombu.utils.sleep')
  590. def test_connect_errback(self, sleep, connect):
  591. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  592. from kombu.transport.memory import Transport
  593. Transport.connection_errors = (ChannelError,)
  594. def effect():
  595. if connect.call_count > 1:
  596. return
  597. raise ChannelError('error')
  598. connect.side_effect = effect
  599. l.connect()
  600. connect.assert_called_with()
  601. def test_stop_pidbox_node(self):
  602. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  603. cont = find_step(l, consumer.Control)
  604. cont._node_stopped = Event()
  605. cont._node_shutdown = Event()
  606. cont._node_stopped.set()
  607. cont.stop(l)
  608. def test_start__loop(self):
  609. class _QoS(object):
  610. prev = 3
  611. value = 4
  612. def update(self):
  613. self.prev = self.value
  614. class _Consumer(MyKombuConsumer):
  615. iterations = 0
  616. def reset_connection(self):
  617. if self.iterations >= 1:
  618. raise KeyError('foo')
  619. init_callback = Mock()
  620. l = _Consumer(self.buffer.put, timer=self.timer,
  621. init_callback=init_callback, app=self.app)
  622. l.controller = l.app.WorkController()
  623. l.pool = l.controller.pool = Mock()
  624. l.task_consumer = Mock()
  625. l.broadcast_consumer = Mock()
  626. l.qos = _QoS()
  627. l.connection = Connection()
  628. l.iterations = 0
  629. def raises_KeyError(*args, **kwargs):
  630. l.iterations += 1
  631. if l.qos.prev != l.qos.value:
  632. l.qos.update()
  633. if l.iterations >= 2:
  634. raise KeyError('foo')
  635. l.loop = raises_KeyError
  636. with self.assertRaises(KeyError):
  637. l.start()
  638. self.assertEqual(l.iterations, 2)
  639. self.assertEqual(l.qos.prev, l.qos.value)
  640. init_callback.reset_mock()
  641. l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
  642. send_events=False, init_callback=init_callback)
  643. l.controller = l.app.WorkController()
  644. l.pool = l.controller.pool = Mock()
  645. l.qos = _QoS()
  646. l.task_consumer = Mock()
  647. l.broadcast_consumer = Mock()
  648. l.connection = Connection()
  649. l.loop = Mock(side_effect=socket.error('foo'))
  650. with self.assertRaises(socket.error):
  651. l.start()
  652. self.assertTrue(l.loop.call_count)
  653. def test_reset_connection_with_no_node(self):
  654. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  655. l.controller = l.app.WorkController()
  656. l.pool = l.controller.pool = Mock()
  657. l.steps.pop()
  658. l.blueprint.start(l)
  659. class test_WorkController(AppCase):
  660. def setup(self):
  661. self.worker = self.create_worker()
  662. from celery import worker
  663. self._logger = worker.logger
  664. self._comp_logger = components.logger
  665. self.logger = worker.logger = Mock()
  666. self.comp_logger = components.logger = Mock()
  667. @self.app.task(shared=False)
  668. def foo_task(x, y, z):
  669. return x * y * z
  670. self.foo_task = foo_task
  671. def teardown(self):
  672. from celery import worker
  673. worker.logger = self._logger
  674. components.logger = self._comp_logger
  675. def create_worker(self, **kw):
  676. worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
  677. worker.blueprint.shutdown_complete.set()
  678. return worker
  679. def test_on_consumer_ready(self):
  680. self.worker.on_consumer_ready(Mock())
  681. def test_setup_queues_worker_direct(self):
  682. self.app.conf.worker_direct = True
  683. self.app.amqp.__dict__['queues'] = Mock()
  684. self.worker.setup_queues({})
  685. self.app.amqp.queues.select_add.assert_called_with(
  686. worker_direct(self.worker.hostname),
  687. )
  688. def test_setup_queues__missing_queue(self):
  689. self.app.amqp.queues.select = Mock(name='select')
  690. self.app.amqp.queues.deselect = Mock(name='deselect')
  691. self.app.amqp.queues.select.side_effect = KeyError()
  692. self.app.amqp.queues.deselect.side_effect = KeyError()
  693. with self.assertRaises(ImproperlyConfigured):
  694. self.worker.setup_queues('x,y', exclude='foo,bar')
  695. self.app.amqp.queues.select = Mock(name='select')
  696. with self.assertRaises(ImproperlyConfigured):
  697. self.worker.setup_queues('x,y', exclude='foo,bar')
  698. def test_send_worker_shutdown(self):
  699. with patch('celery.signals.worker_shutdown') as ws:
  700. self.worker._send_worker_shutdown()
  701. ws.send.assert_called_with(sender=self.worker)
  702. @todo('unstable test')
  703. def test_process_shutdown_on_worker_shutdown(self):
  704. from celery.concurrency.prefork import process_destructor
  705. from celery.concurrency.asynpool import Worker
  706. with patch('celery.signals.worker_process_shutdown') as ws:
  707. Worker._make_shortcuts = Mock()
  708. with patch('os._exit') as _exit:
  709. worker = Worker(None, None, on_exit=process_destructor)
  710. worker._do_exit(22, 3.1415926)
  711. ws.send.assert_called_with(
  712. sender=None, pid=22, exitcode=3.1415926,
  713. )
  714. _exit.assert_called_with(3.1415926)
  715. def test_process_task_revoked_release_semaphore(self):
  716. self.worker._quick_release = Mock()
  717. req = Mock()
  718. req.execute_using_pool.side_effect = TaskRevokedError
  719. self.worker._process_task(req)
  720. self.worker._quick_release.assert_called_with()
  721. delattr(self.worker, '_quick_release')
  722. self.worker._process_task(req)
  723. def test_shutdown_no_blueprint(self):
  724. self.worker.blueprint = None
  725. self.worker._shutdown()
  726. @patch('celery.worker.create_pidlock')
  727. def test_use_pidfile(self, create_pidlock):
  728. create_pidlock.return_value = Mock()
  729. worker = self.create_worker(pidfile='pidfilelockfilepid')
  730. worker.steps = []
  731. worker.start()
  732. self.assertTrue(create_pidlock.called)
  733. worker.stop()
  734. self.assertTrue(worker.pidlock.release.called)
  735. def test_attrs(self):
  736. worker = self.worker
  737. self.assertIsNotNone(worker.timer)
  738. self.assertIsInstance(worker.timer, Timer)
  739. self.assertIsNotNone(worker.pool)
  740. self.assertIsNotNone(worker.consumer)
  741. self.assertTrue(worker.steps)
  742. def test_with_embedded_beat(self):
  743. worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
  744. self.assertTrue(worker.beat)
  745. self.assertIn(worker.beat, [w.obj for w in worker.steps])
  746. def test_with_autoscaler(self):
  747. worker = self.create_worker(
  748. autoscale=[10, 3], send_events=False,
  749. timer_cls='celery.utils.timer2.Timer',
  750. )
  751. self.assertTrue(worker.autoscaler)
  752. def test_dont_stop_or_terminate(self):
  753. worker = self.app.WorkController(concurrency=1, loglevel=0)
  754. worker.stop()
  755. self.assertNotEqual(worker.blueprint.state, CLOSE)
  756. worker.terminate()
  757. self.assertNotEqual(worker.blueprint.state, CLOSE)
  758. sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
  759. try:
  760. worker.blueprint.state = RUN
  761. worker.stop(in_sighandler=True)
  762. self.assertNotEqual(worker.blueprint.state, CLOSE)
  763. worker.terminate(in_sighandler=True)
  764. self.assertNotEqual(worker.blueprint.state, CLOSE)
  765. finally:
  766. worker.pool.signal_safe = sigsafe
  767. def test_on_timer_error(self):
  768. worker = self.app.WorkController(concurrency=1, loglevel=0)
  769. try:
  770. raise KeyError('foo')
  771. except KeyError as exc:
  772. components.Timer(worker).on_timer_error(exc)
  773. msg, args = self.comp_logger.error.call_args[0]
  774. self.assertIn('KeyError', msg % args)
  775. def test_on_timer_tick(self):
  776. worker = self.app.WorkController(concurrency=1, loglevel=10)
  777. components.Timer(worker).on_timer_tick(30.0)
  778. xargs = self.comp_logger.debug.call_args[0]
  779. fmt, arg = xargs[0], xargs[1]
  780. self.assertEqual(30.0, arg)
  781. self.assertIn('Next eta %s secs', fmt)
  782. def test_process_task(self):
  783. worker = self.worker
  784. worker.pool = Mock()
  785. channel = Mock()
  786. m = create_task_message(
  787. channel, self.foo_task.name,
  788. args=[4, 8, 10], kwargs={},
  789. )
  790. task = Request(m, app=self.app)
  791. worker._process_task(task)
  792. self.assertEqual(worker.pool.apply_async.call_count, 1)
  793. worker.pool.stop()
  794. def test_process_task_raise_base(self):
  795. worker = self.worker
  796. worker.pool = Mock()
  797. worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
  798. channel = Mock()
  799. m = create_task_message(
  800. channel, self.foo_task.name,
  801. args=[4, 8, 10], kwargs={},
  802. )
  803. task = Request(m, app=self.app)
  804. worker.steps = []
  805. worker.blueprint.state = RUN
  806. with self.assertRaises(KeyboardInterrupt):
  807. worker._process_task(task)
  808. def test_process_task_raise_WorkerTerminate(self):
  809. worker = self.worker
  810. worker.pool = Mock()
  811. worker.pool.apply_async.side_effect = WorkerTerminate()
  812. channel = Mock()
  813. m = create_task_message(
  814. channel, self.foo_task.name,
  815. args=[4, 8, 10], kwargs={},
  816. )
  817. task = Request(m, app=self.app)
  818. worker.steps = []
  819. worker.blueprint.state = RUN
  820. with self.assertRaises(SystemExit):
  821. worker._process_task(task)
  822. def test_process_task_raise_regular(self):
  823. worker = self.worker
  824. worker.pool = Mock()
  825. worker.pool.apply_async.side_effect = KeyError('some exception')
  826. channel = Mock()
  827. m = create_task_message(
  828. channel, self.foo_task.name,
  829. args=[4, 8, 10], kwargs={},
  830. )
  831. task = Request(m, app=self.app)
  832. worker._process_task(task)
  833. worker.pool.stop()
  834. def test_start_catches_base_exceptions(self):
  835. worker1 = self.create_worker()
  836. worker1.blueprint.state = RUN
  837. stc = MockStep()
  838. stc.start.side_effect = WorkerTerminate()
  839. worker1.steps = [stc]
  840. worker1.start()
  841. stc.start.assert_called_with(worker1)
  842. self.assertTrue(stc.terminate.call_count)
  843. worker2 = self.create_worker()
  844. worker2.blueprint.state = RUN
  845. sec = MockStep()
  846. sec.start.side_effect = WorkerShutdown()
  847. sec.terminate = None
  848. worker2.steps = [sec]
  849. worker2.start()
  850. self.assertTrue(sec.stop.call_count)
  851. def test_state_db(self):
  852. from celery.worker import state
  853. Persistent = state.Persistent
  854. state.Persistent = Mock()
  855. try:
  856. worker = self.create_worker(state_db='statefilename')
  857. self.assertTrue(worker._persistence)
  858. finally:
  859. state.Persistent = Persistent
  860. def test_process_task_sem(self):
  861. worker = self.worker
  862. worker._quick_acquire = Mock()
  863. req = Mock()
  864. worker._process_task_sem(req)
  865. worker._quick_acquire.assert_called_with(worker._process_task, req)
  866. def test_signal_consumer_close(self):
  867. worker = self.worker
  868. worker.consumer = Mock()
  869. worker.signal_consumer_close()
  870. worker.consumer.close.assert_called_with()
  871. worker.consumer.close.side_effect = AttributeError()
  872. worker.signal_consumer_close()
    def test_rusage__no_resource(self):
        # With the worker module's ``resource`` attribute set to None,
        # rusage() must raise NotImplementedError while stats() must still
        # complete without raising.
        from celery import worker
        # Temporarily replace ``worker.resource``; restored in ``finally``.
        prev, worker.resource = worker.resource, None
        try:
            self.worker.pool = Mock(name='pool')
            with self.assertRaises(NotImplementedError):
                self.worker.rusage()
            # NOTE(review): indentation was lost in this copy of the file;
            # stats() is placed after the assertRaises block because any
            # statement inside it following rusage() would never run.
            self.worker.stats()
        finally:
            worker.resource = prev
  883. def test_repr(self):
  884. self.assertTrue(repr(self.worker))
  885. def test_str(self):
  886. self.assertEqual(str(self.worker), self.worker.hostname)
    def test_start__stop(self):
        # start() must start every step, and stop() must close and stop
        # every step; stop() must also cope with a missing pool and with a
        # None entry in the step list.
        worker = self.worker
        # NOTE(review): presumably pre-set so stop() does not wait for a
        # shutdown that never happens in this test — confirm.
        worker.blueprint.shutdown_complete.set()
        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
        worker.blueprint.state = RUN
        worker.blueprint.started = 4
        # Replace the lifecycle hooks with Mocks so calls can be counted.
        for w in worker.steps:
            w.start = Mock()
            w.close = Mock()
            w.stop = Mock()

        worker.start()
        for w in worker.steps:
            self.assertTrue(w.start.call_count)
        worker.consumer = Mock()
        worker.stop(exitcode=3)
        for stopstep in worker.steps:
            self.assertTrue(stopstep.close.call_count)
            self.assertTrue(stopstep.stop.call_count)

        # Doesn't close pool if no pool.
        worker.start()
        worker.pool = None
        worker.stop()

        # test that stop of None is not attempted
        worker.steps[-1] = None
        worker.start()
        worker.stop()
  913. def test_start__KeyboardInterrupt(self):
  914. worker = self.worker
  915. worker.blueprint = Mock(name='blueprint')
  916. worker.blueprint.start.side_effect = KeyboardInterrupt()
  917. worker.stop = Mock(name='stop')
  918. worker.start()
  919. worker.stop.assert_called_with(exitcode=EX_FAILURE)
  920. def test_register_with_event_loop(self):
  921. worker = self.worker
  922. hub = Mock(name='hub')
  923. worker.blueprint = Mock(name='blueprint')
  924. worker.register_with_event_loop(hub)
  925. worker.blueprint.send_all.assert_called_with(
  926. worker, 'register_with_event_loop', args=(hub,),
  927. description='hub.register',
  928. )
  929. def test_step_raises(self):
  930. worker = self.worker
  931. step = Mock()
  932. worker.steps = [step]
  933. step.start.side_effect = TypeError()
  934. worker.stop = Mock()
  935. worker.start()
  936. worker.stop.assert_called_with(exitcode=EX_FAILURE)
  937. def test_state(self):
  938. self.assertTrue(self.worker.state)
  939. def test_start__terminate(self):
  940. worker = self.worker
  941. worker.blueprint.shutdown_complete.set()
  942. worker.blueprint.started = 5
  943. worker.blueprint.state = RUN
  944. worker.steps = [MockStep() for _ in range(5)]
  945. worker.start()
  946. for w in worker.steps[:3]:
  947. self.assertTrue(w.start.call_count)
  948. self.assertTrue(worker.blueprint.started, len(worker.steps))
  949. self.assertEqual(worker.blueprint.state, RUN)
  950. worker.terminate()
  951. for step in worker.steps:
  952. self.assertTrue(step.terminate.call_count)
  953. worker.blueprint.state = TERMINATE
  954. worker.terminate()
  955. def test_Hub_crate(self):
  956. w = Mock()
  957. x = components.Hub(w)
  958. x.create(w)
  959. self.assertTrue(w.timer.max_interval)
  960. def test_Pool_crate_threaded(self):
  961. w = Mock()
  962. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  963. w.pool_cls = Mock()
  964. w.use_eventloop = False
  965. pool = components.Pool(w)
  966. pool.create(w)
  967. def test_Pool_pool_no_sem(self):
  968. w = Mock()
  969. w.pool_cls.uses_semaphore = False
  970. components.Pool(w).create(w)
  971. self.assertIs(w.process_task, w._process_task)
  972. def test_Pool_create(self):
  973. from kombu.async.semaphore import LaxBoundedSemaphore
  974. w = Mock()
  975. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  976. w.hub = Mock()
  977. PoolImp = Mock()
  978. poolimp = PoolImp.return_value = Mock()
  979. poolimp._pool = [Mock(), Mock()]
  980. poolimp._cache = {}
  981. poolimp._fileno_to_inq = {}
  982. poolimp._fileno_to_outq = {}
  983. from celery.concurrency.prefork import TaskPool as _TaskPool
  984. class MockTaskPool(_TaskPool):
  985. Pool = PoolImp
  986. @property
  987. def timers(self):
  988. return {Mock(): 30}
  989. w.pool_cls = MockTaskPool
  990. w.use_eventloop = True
  991. w.consumer.restart_count = -1
  992. pool = components.Pool(w)
  993. pool.create(w)
  994. pool.register_with_event_loop(w, w.hub)
  995. if sys.platform != 'win32':
  996. self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
  997. P = w.pool
  998. P.start()