test_worker.py 39 KB


  1. from __future__ import absolute_import, print_function
  2. import os
  3. import socket
  4. import sys
  5. from collections import deque
  6. from datetime import datetime, timedelta
  7. from threading import Event
  8. from amqp import ChannelError
  9. from kombu import Connection
  10. from kombu.common import QoS, ignore_errors
  11. from kombu.transport.base import Message
  12. from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
  13. from celery.concurrency.base import BasePool
  14. from celery.exceptions import (
  15. WorkerShutdown, WorkerTerminate, TaskRevokedError,
  16. InvalidTaskError, ImproperlyConfigured,
  17. )
  18. from celery.five import Empty, range, Queue as FastQueue
  19. from celery.platforms import EX_FAILURE
  20. from celery.utils import uuid
  21. from celery.worker import components
  22. from celery.worker import consumer
  23. from celery.worker.consumer import Consumer as __Consumer
  24. from celery.worker.request import Request
  25. from celery.utils import worker_direct
  26. from celery.utils.serialization import pickle
  27. from celery.utils.timer2 import Timer
  28. from celery.tests.case import AppCase, Mock, SkipTest, TaskMessage, patch
  29. def MockStep(step=None):
  30. step = Mock() if step is None else step
  31. step.blueprint = Mock()
  32. step.blueprint.name = 'MockNS'
  33. step.name = 'MockStep(%s)' % (id(step),)
  34. return step
  35. def mock_event_dispatcher():
  36. evd = Mock(name='event_dispatcher')
  37. evd.groups = ['worker']
  38. evd._outbound_buffer = deque()
  39. return evd
  40. class PlaceHolder(object):
  41. pass
  42. def find_step(obj, typ):
  43. return obj.blueprint.steps[typ.name]
  44. class Consumer(__Consumer):
  45. def __init__(self, *args, **kwargs):
  46. kwargs.setdefault('without_mingle', True) # disable Mingle step
  47. kwargs.setdefault('without_gossip', True) # disable Gossip step
  48. kwargs.setdefault('without_heartbeat', True) # disable Heart step
  49. super(Consumer, self).__init__(*args, **kwargs)
  50. class _MyKombuConsumer(Consumer):
  51. broadcast_consumer = Mock()
  52. task_consumer = Mock()
  53. def __init__(self, *args, **kwargs):
  54. kwargs.setdefault('pool', BasePool(2))
  55. super(_MyKombuConsumer, self).__init__(*args, **kwargs)
  56. def restart_heartbeat(self):
  57. self.heart = None
  58. class MyKombuConsumer(Consumer):
  59. def loop(self, *args, **kwargs):
  60. pass
  61. class MockNode(object):
  62. commands = []
  63. def handle_message(self, body, message):
  64. self.commands.append(body.pop('command', None))
  65. class MockEventDispatcher(object):
  66. sent = []
  67. closed = False
  68. flushed = False
  69. _outbound_buffer = []
  70. def send(self, event, *args, **kwargs):
  71. self.sent.append(event)
  72. def close(self):
  73. self.closed = True
  74. def flush(self):
  75. self.flushed = True
  76. class MockHeart(object):
  77. closed = False
  78. def stop(self):
  79. self.closed = True
  80. def create_message(channel, **data):
  81. data.setdefault('id', uuid())
  82. channel.no_ack_consumers = set()
  83. m = Message(channel, body=pickle.dumps(dict(**data)),
  84. content_type='application/x-python-serialize',
  85. content_encoding='binary',
  86. delivery_info={'consumer_tag': 'mock'})
  87. m.accept = ['application/x-python-serialize']
  88. return m
  89. def create_task_message(channel, *args, **kwargs):
  90. m = TaskMessage(*args, **kwargs)
  91. m.channel = channel
  92. m.delivery_info = {'consumer_tag': 'mock'}
  93. return m
  94. class test_Consumer(AppCase):
  95. def setup(self):
  96. self.buffer = FastQueue()
  97. self.timer = Timer()
  98. @self.app.task(shared=False)
  99. def foo_task(x, y, z):
  100. return x * y * z
  101. self.foo_task = foo_task
  102. def teardown(self):
  103. self.timer.stop()
  104. def test_info(self):
  105. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  106. l.task_consumer = Mock()
  107. l.qos = QoS(l.task_consumer.qos, 10)
  108. l.connection = Mock()
  109. l.connection.info.return_value = {'foo': 'bar'}
  110. l.controller = l.app.WorkController()
  111. l.pool = l.controller.pool = Mock()
  112. l.controller.pool.info.return_value = [Mock(), Mock()]
  113. l.controller.consumer = l
  114. info = l.controller.stats()
  115. self.assertEqual(info['prefetch_count'], 10)
  116. self.assertTrue(info['broker'])
  117. def test_start_when_closed(self):
  118. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  119. l.blueprint.state = CLOSE
  120. l.start()
  121. def test_connection(self):
  122. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  123. l.controller = l.app.WorkController()
  124. l.pool = l.controller.pool = Mock()
  125. l.blueprint.start(l)
  126. self.assertIsInstance(l.connection, Connection)
  127. l.blueprint.state = RUN
  128. l.event_dispatcher = None
  129. l.blueprint.restart(l)
  130. self.assertTrue(l.connection)
  131. l.blueprint.state = RUN
  132. l.shutdown()
  133. self.assertIsNone(l.connection)
  134. self.assertIsNone(l.task_consumer)
  135. l.blueprint.start(l)
  136. self.assertIsInstance(l.connection, Connection)
  137. l.blueprint.restart(l)
  138. l.stop()
  139. l.shutdown()
  140. self.assertIsNone(l.connection)
  141. self.assertIsNone(l.task_consumer)
  142. def test_close_connection(self):
  143. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  144. l.blueprint.state = RUN
  145. step = find_step(l, consumer.Connection)
  146. conn = l.connection = Mock()
  147. step.shutdown(l)
  148. self.assertTrue(conn.close.called)
  149. self.assertIsNone(l.connection)
  150. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  151. eventer = l.event_dispatcher = mock_event_dispatcher()
  152. eventer.enabled = True
  153. heart = l.heart = MockHeart()
  154. l.blueprint.state = RUN
  155. Events = find_step(l, consumer.Events)
  156. Events.shutdown(l)
  157. Heart = find_step(l, consumer.Heart)
  158. Heart.shutdown(l)
  159. self.assertTrue(eventer.close.call_count)
  160. self.assertTrue(heart.closed)
  161. @patch('celery.worker.consumer.warn')
  162. def test_receive_message_unknown(self, warn):
  163. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  164. l.blueprint.state = RUN
  165. l.steps.pop()
  166. channel = Mock()
  167. m = create_message(channel, unknown={'baz': '!!!'})
  168. l.event_dispatcher = mock_event_dispatcher()
  169. l.node = MockNode()
  170. callback = self._get_on_message(l)
  171. callback(m)
  172. self.assertTrue(warn.call_count)
  173. @patch('celery.worker.strategy.to_timestamp')
  174. def test_receive_message_eta_OverflowError(self, to_timestamp):
  175. to_timestamp.side_effect = OverflowError()
  176. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  177. l.controller = l.app.WorkController()
  178. l.pool = l.controller.pool = Mock()
  179. l.blueprint.state = RUN
  180. l.steps.pop()
  181. m = create_task_message(
  182. Mock(), self.foo_task.name,
  183. args=('2, 2'), kwargs={},
  184. eta=datetime.now().isoformat(),
  185. )
  186. l.event_dispatcher = mock_event_dispatcher()
  187. l.node = MockNode()
  188. l.update_strategies()
  189. l.qos = Mock()
  190. callback = self._get_on_message(l)
  191. callback(m)
  192. self.assertTrue(m.acknowledged)
  193. @patch('celery.worker.consumer.error')
  194. def test_receive_message_InvalidTaskError(self, error):
  195. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  196. l.blueprint.state = RUN
  197. l.event_dispatcher = mock_event_dispatcher()
  198. l.steps.pop()
  199. l.controller = l.app.WorkController()
  200. l.pool = l.controller.pool = Mock()
  201. m = create_task_message(
  202. Mock(), self.foo_task.name,
  203. args=(1, 2), kwargs='foobarbaz', id=1)
  204. l.update_strategies()
  205. l.event_dispatcher = mock_event_dispatcher()
  206. strat = l.strategies[self.foo_task.name] = Mock(name='strategy')
  207. strat.side_effect = InvalidTaskError()
  208. callback = self._get_on_message(l)
  209. callback(m)
  210. self.assertTrue(error.called)
  211. self.assertIn('Received invalid task message', error.call_args[0][0])
  212. @patch('celery.worker.consumer.crit')
  213. def test_on_decode_error(self, crit):
  214. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  215. class MockMessage(Mock):
  216. content_type = 'application/x-msgpack'
  217. content_encoding = 'binary'
  218. body = 'foobarbaz'
  219. message = MockMessage()
  220. l.on_decode_error(message, KeyError('foo'))
  221. self.assertTrue(message.ack.call_count)
  222. self.assertIn("Can't decode message body", crit.call_args[0][0])
  223. def _get_on_message(self, l):
  224. if l.qos is None:
  225. l.qos = Mock()
  226. l.event_dispatcher = mock_event_dispatcher()
  227. l.task_consumer = Mock()
  228. l.connection = Mock()
  229. l.connection.drain_events.side_effect = WorkerShutdown()
  230. with self.assertRaises(WorkerShutdown):
  231. l.loop(*l.loop_args())
  232. self.assertTrue(l.task_consumer.on_message)
  233. return l.task_consumer.on_message
  234. def test_receieve_message(self):
  235. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  236. l.controller = l.app.WorkController()
  237. l.pool = l.controller.pool = Mock()
  238. l.blueprint.state = RUN
  239. l.event_dispatcher = mock_event_dispatcher()
  240. m = create_task_message(
  241. Mock(), self.foo_task.name,
  242. args=[2, 4, 8], kwargs={},
  243. )
  244. l.update_strategies()
  245. callback = self._get_on_message(l)
  246. callback(m)
  247. in_bucket = self.buffer.get_nowait()
  248. self.assertIsInstance(in_bucket, Request)
  249. self.assertEqual(in_bucket.name, self.foo_task.name)
  250. self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
  251. self.assertTrue(self.timer.empty())
  252. def test_start_channel_error(self):
  253. class MockConsumer(Consumer):
  254. iterations = 0
  255. def loop(self, *args, **kwargs):
  256. if not self.iterations:
  257. self.iterations = 1
  258. raise KeyError('foo')
  259. raise SyntaxError('bar')
  260. l = MockConsumer(self.buffer.put, timer=self.timer,
  261. send_events=False, pool=BasePool(), app=self.app)
  262. l.controller = l.app.WorkController()
  263. l.pool = l.controller.pool = Mock()
  264. l.channel_errors = (KeyError,)
  265. with self.assertRaises(KeyError):
  266. l.start()
  267. l.timer.stop()
  268. def test_start_connection_error(self):
  269. class MockConsumer(Consumer):
  270. iterations = 0
  271. def loop(self, *args, **kwargs):
  272. if not self.iterations:
  273. self.iterations = 1
  274. raise KeyError('foo')
  275. raise SyntaxError('bar')
  276. l = MockConsumer(self.buffer.put, timer=self.timer,
  277. send_events=False, pool=BasePool(), app=self.app)
  278. l.controller = l.app.WorkController()
  279. l.pool = l.controller.pool = Mock()
  280. l.connection_errors = (KeyError,)
  281. self.assertRaises(SyntaxError, l.start)
  282. l.timer.stop()
  283. def test_loop_ignores_socket_timeout(self):
  284. class Connection(self.app.connection().__class__):
  285. obj = None
  286. def drain_events(self, **kwargs):
  287. self.obj.connection = None
  288. raise socket.timeout(10)
  289. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  290. l.connection = Connection()
  291. l.task_consumer = Mock()
  292. l.connection.obj = l
  293. l.qos = QoS(l.task_consumer.qos, 10)
  294. l.loop(*l.loop_args())
  295. def test_loop_when_socket_error(self):
  296. class Connection(self.app.connection().__class__):
  297. obj = None
  298. def drain_events(self, **kwargs):
  299. self.obj.connection = None
  300. raise socket.error('foo')
  301. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  302. l.blueprint.state = RUN
  303. c = l.connection = Connection()
  304. l.connection.obj = l
  305. l.task_consumer = Mock()
  306. l.qos = QoS(l.task_consumer.qos, 10)
  307. with self.assertRaises(socket.error):
  308. l.loop(*l.loop_args())
  309. l.blueprint.state = CLOSE
  310. l.connection = c
  311. l.loop(*l.loop_args())
  312. def test_loop(self):
  313. class Connection(self.app.connection().__class__):
  314. obj = None
  315. def drain_events(self, **kwargs):
  316. self.obj.connection = None
  317. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  318. l.blueprint.state = RUN
  319. l.connection = Connection()
  320. l.connection.obj = l
  321. l.task_consumer = Mock()
  322. l.qos = QoS(l.task_consumer.qos, 10)
  323. l.loop(*l.loop_args())
  324. l.loop(*l.loop_args())
  325. self.assertTrue(l.task_consumer.consume.call_count)
  326. l.task_consumer.qos.assert_called_with(prefetch_count=10)
  327. self.assertEqual(l.qos.value, 10)
  328. l.qos.decrement_eventually()
  329. self.assertEqual(l.qos.value, 9)
  330. l.qos.update()
  331. self.assertEqual(l.qos.value, 9)
  332. l.task_consumer.qos.assert_called_with(prefetch_count=9)
  333. def test_ignore_errors(self):
  334. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  335. l.connection_errors = (AttributeError, KeyError,)
  336. l.channel_errors = (SyntaxError,)
  337. ignore_errors(l, Mock(side_effect=AttributeError('foo')))
  338. ignore_errors(l, Mock(side_effect=KeyError('foo')))
  339. ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
  340. with self.assertRaises(IndexError):
  341. ignore_errors(l, Mock(side_effect=IndexError('foo')))
  342. def test_apply_eta_task(self):
  343. from celery.worker import state
  344. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  345. l.controller = l.app.WorkController()
  346. l.pool = l.controller.pool = Mock()
  347. l.qos = QoS(None, 10)
  348. task = object()
  349. qos = l.qos.value
  350. l.apply_eta_task(task)
  351. self.assertIn(task, state.reserved_requests)
  352. self.assertEqual(l.qos.value, qos - 1)
  353. self.assertIs(self.buffer.get_nowait(), task)
  354. def test_receieve_message_eta_isoformat(self):
  355. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  356. l.controller = l.app.WorkController()
  357. l.pool = l.controller.pool = Mock()
  358. l.blueprint.state = RUN
  359. l.steps.pop()
  360. m = create_task_message(
  361. Mock(), self.foo_task.name,
  362. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  363. args=[2, 4, 8], kwargs={},
  364. )
  365. l.task_consumer = Mock()
  366. l.qos = QoS(l.task_consumer.qos, 1)
  367. current_pcount = l.qos.value
  368. l.event_dispatcher = mock_event_dispatcher()
  369. l.enabled = False
  370. l.update_strategies()
  371. callback = self._get_on_message(l)
  372. callback(m)
  373. l.timer.stop()
  374. l.timer.join(1)
  375. items = [entry[2] for entry in self.timer.queue]
  376. found = 0
  377. for item in items:
  378. if item.args[0].name == self.foo_task.name:
  379. found = True
  380. self.assertTrue(found)
  381. self.assertGreater(l.qos.value, current_pcount)
  382. l.timer.stop()
  383. def test_pidbox_callback(self):
  384. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  385. con = find_step(l, consumer.Control).box
  386. con.node = Mock()
  387. con.reset = Mock()
  388. con.on_message('foo', 'bar')
  389. con.node.handle_message.assert_called_with('foo', 'bar')
  390. con.node = Mock()
  391. con.node.handle_message.side_effect = KeyError('foo')
  392. con.on_message('foo', 'bar')
  393. con.node.handle_message.assert_called_with('foo', 'bar')
  394. con.node = Mock()
  395. con.node.handle_message.side_effect = ValueError('foo')
  396. con.on_message('foo', 'bar')
  397. con.node.handle_message.assert_called_with('foo', 'bar')
  398. self.assertTrue(con.reset.called)
  399. def test_revoke(self):
  400. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  401. l.blueprint.state = RUN
  402. l.steps.pop()
  403. channel = Mock()
  404. id = uuid()
  405. t = create_task_message(
  406. channel, self.foo_task.name,
  407. args=[2, 4, 8], kwargs={}, id=id,
  408. )
  409. from celery.worker.state import revoked
  410. revoked.add(id)
  411. callback = self._get_on_message(l)
  412. callback(t)
  413. self.assertTrue(self.buffer.empty())
  414. def test_receieve_message_not_registered(self):
  415. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  416. l.blueprint.state = RUN
  417. l.steps.pop()
  418. channel = Mock(name='channel')
  419. m = create_task_message(
  420. channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
  421. )
  422. l.event_dispatcher = mock_event_dispatcher()
  423. callback = self._get_on_message(l)
  424. self.assertFalse(callback(m))
  425. with self.assertRaises(Empty):
  426. self.buffer.get_nowait()
  427. self.assertTrue(self.timer.empty())
  428. @patch('celery.worker.consumer.warn')
  429. @patch('celery.worker.consumer.logger')
  430. def test_receieve_message_ack_raises(self, logger, warn):
  431. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  432. l.controller = l.app.WorkController()
  433. l.pool = l.controller.pool = Mock()
  434. l.blueprint.state = RUN
  435. channel = Mock()
  436. m = create_task_message(
  437. channel, self.foo_task.name,
  438. args=[2, 4, 8], kwargs={},
  439. )
  440. m.headers = None
  441. l.event_dispatcher = mock_event_dispatcher()
  442. l.update_strategies()
  443. l.connection_errors = (socket.error,)
  444. m.reject = Mock()
  445. m.reject.side_effect = socket.error('foo')
  446. callback = self._get_on_message(l)
  447. self.assertFalse(callback(m))
  448. self.assertTrue(warn.call_count)
  449. with self.assertRaises(Empty):
  450. self.buffer.get_nowait()
  451. self.assertTrue(self.timer.empty())
  452. m.reject_log_error.assert_called_with(logger, l.connection_errors)
  453. def test_receive_message_eta(self):
  454. import sys
  455. from functools import partial
  456. if os.environ.get('C_DEBUG_TEST'):
  457. pp = partial(print, file=sys.__stderr__)
  458. else:
  459. def pp(*args, **kwargs):
  460. pass
  461. pp('TEST RECEIVE MESSAGE ETA')
  462. pp('+CREATE MYKOMBUCONSUMER')
  463. l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  464. l.controller = l.app.WorkController()
  465. l.pool = l.controller.pool = Mock()
  466. pp('-CREATE MYKOMBUCONSUMER')
  467. l.steps.pop()
  468. l.event_dispatcher = mock_event_dispatcher()
  469. channel = Mock(name='channel')
  470. pp('+ CREATE MESSAGE')
  471. m = create_task_message(
  472. channel, self.foo_task.name,
  473. args=[2, 4, 8], kwargs={},
  474. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  475. )
  476. pp('- CREATE MESSAGE')
  477. try:
  478. pp('+ BLUEPRINT START 1')
  479. l.blueprint.start(l)
  480. pp('- BLUEPRINT START 1')
  481. p = l.app.conf.broker_connection_retry
  482. l.app.conf.broker_connection_retry = False
  483. pp('+ BLUEPRINT START 2')
  484. l.blueprint.start(l)
  485. pp('- BLUEPRINT START 2')
  486. l.app.conf.broker_connection_retry = p
  487. pp('+ BLUEPRINT RESTART')
  488. l.blueprint.restart(l)
  489. pp('- BLUEPRINT RESTART')
  490. l.event_dispatcher = mock_event_dispatcher()
  491. pp('+ GET ON MESSAGE')
  492. callback = self._get_on_message(l)
  493. pp('- GET ON MESSAGE')
  494. pp('+ CALLBACK')
  495. callback(m)
  496. pp('- CALLBACK')
  497. finally:
  498. pp('+ STOP TIMER')
  499. l.timer.stop()
  500. pp('- STOP TIMER')
  501. try:
  502. pp('+ JOIN TIMER')
  503. l.timer.join()
  504. pp('- JOIN TIMER')
  505. except RuntimeError:
  506. pass
  507. in_hold = l.timer.queue[0]
  508. self.assertEqual(len(in_hold), 3)
  509. eta, priority, entry = in_hold
  510. task = entry.args[0]
  511. self.assertIsInstance(task, Request)
  512. self.assertEqual(task.name, self.foo_task.name)
  513. self.assertEqual(task.execute(), 2 * 4 * 8)
  514. with self.assertRaises(Empty):
  515. self.buffer.get_nowait()
  516. def test_reset_pidbox_node(self):
  517. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  518. con = find_step(l, consumer.Control).box
  519. con.node = Mock()
  520. chan = con.node.channel = Mock()
  521. l.connection = Mock()
  522. chan.close.side_effect = socket.error('foo')
  523. l.connection_errors = (socket.error,)
  524. con.reset()
  525. chan.close.assert_called_with()
  526. def test_reset_pidbox_node_green(self):
  527. from celery.worker.pidbox import gPidbox
  528. pool = Mock()
  529. pool.is_green = True
  530. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
  531. app=self.app)
  532. con = find_step(l, consumer.Control)
  533. self.assertIsInstance(con.box, gPidbox)
  534. con.start(l)
  535. l.pool.spawn_n.assert_called_with(
  536. con.box.loop, l,
  537. )
  538. def test__green_pidbox_node(self):
  539. pool = Mock()
  540. pool.is_green = True
  541. l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
  542. app=self.app)
  543. l.node = Mock()
  544. controller = find_step(l, consumer.Control)
  545. class BConsumer(Mock):
  546. def __enter__(self):
  547. self.consume()
  548. return self
  549. def __exit__(self, *exc_info):
  550. self.cancel()
  551. controller.box.node.listen = BConsumer()
  552. connections = []
  553. class Connection(object):
  554. calls = 0
  555. def __init__(self, obj):
  556. connections.append(self)
  557. self.obj = obj
  558. self.default_channel = self.channel()
  559. self.closed = False
  560. def __enter__(self):
  561. return self
  562. def __exit__(self, *exc_info):
  563. self.close()
  564. def channel(self):
  565. return Mock()
  566. def as_uri(self):
  567. return 'dummy://'
  568. def drain_events(self, **kwargs):
  569. if not self.calls:
  570. self.calls += 1
  571. raise socket.timeout()
  572. self.obj.connection = None
  573. controller.box._node_shutdown.set()
  574. def close(self):
  575. self.closed = True
  576. l.connection = Mock()
  577. l.connect = lambda: Connection(obj=l)
  578. controller = find_step(l, consumer.Control)
  579. controller.box.loop(l)
  580. self.assertTrue(controller.box.node.listen.called)
  581. self.assertTrue(controller.box.consumer)
  582. controller.box.consumer.consume.assert_called_with()
  583. self.assertIsNone(l.connection)
  584. self.assertTrue(connections[0].closed)
  585. @patch('kombu.connection.Connection._establish_connection')
  586. @patch('kombu.utils.sleep')
  587. def test_connect_errback(self, sleep, connect):
  588. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  589. from kombu.transport.memory import Transport
  590. Transport.connection_errors = (ChannelError,)
  591. def effect():
  592. if connect.call_count > 1:
  593. return
  594. raise ChannelError('error')
  595. connect.side_effect = effect
  596. l.connect()
  597. connect.assert_called_with()
  598. def test_stop_pidbox_node(self):
  599. l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
  600. cont = find_step(l, consumer.Control)
  601. cont._node_stopped = Event()
  602. cont._node_shutdown = Event()
  603. cont._node_stopped.set()
  604. cont.stop(l)
  605. def test_start__loop(self):
  606. class _QoS(object):
  607. prev = 3
  608. value = 4
  609. def update(self):
  610. self.prev = self.value
  611. class _Consumer(MyKombuConsumer):
  612. iterations = 0
  613. def reset_connection(self):
  614. if self.iterations >= 1:
  615. raise KeyError('foo')
  616. init_callback = Mock()
  617. l = _Consumer(self.buffer.put, timer=self.timer,
  618. init_callback=init_callback, app=self.app)
  619. l.controller = l.app.WorkController()
  620. l.pool = l.controller.pool = Mock()
  621. l.task_consumer = Mock()
  622. l.broadcast_consumer = Mock()
  623. l.qos = _QoS()
  624. l.connection = Connection()
  625. l.iterations = 0
  626. def raises_KeyError(*args, **kwargs):
  627. l.iterations += 1
  628. if l.qos.prev != l.qos.value:
  629. l.qos.update()
  630. if l.iterations >= 2:
  631. raise KeyError('foo')
  632. l.loop = raises_KeyError
  633. with self.assertRaises(KeyError):
  634. l.start()
  635. self.assertEqual(l.iterations, 2)
  636. self.assertEqual(l.qos.prev, l.qos.value)
  637. init_callback.reset_mock()
  638. l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
  639. send_events=False, init_callback=init_callback)
  640. l.controller = l.app.WorkController()
  641. l.pool = l.controller.pool = Mock()
  642. l.qos = _QoS()
  643. l.task_consumer = Mock()
  644. l.broadcast_consumer = Mock()
  645. l.connection = Connection()
  646. l.loop = Mock(side_effect=socket.error('foo'))
  647. with self.assertRaises(socket.error):
  648. l.start()
  649. self.assertTrue(l.loop.call_count)
  650. def test_reset_connection_with_no_node(self):
  651. l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
  652. l.controller = l.app.WorkController()
  653. l.pool = l.controller.pool = Mock()
  654. l.steps.pop()
  655. l.blueprint.start(l)
  656. class test_WorkController(AppCase):
  657. def setup(self):
  658. self.worker = self.create_worker()
  659. from celery import worker
  660. self._logger = worker.logger
  661. self._comp_logger = components.logger
  662. self.logger = worker.logger = Mock()
  663. self.comp_logger = components.logger = Mock()
  664. @self.app.task(shared=False)
  665. def foo_task(x, y, z):
  666. return x * y * z
  667. self.foo_task = foo_task
  668. def teardown(self):
  669. from celery import worker
  670. worker.logger = self._logger
  671. components.logger = self._comp_logger
  672. def create_worker(self, **kw):
  673. worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
  674. worker.blueprint.shutdown_complete.set()
  675. return worker
  676. def test_on_consumer_ready(self):
  677. self.worker.on_consumer_ready(Mock())
  678. def test_setup_queues_worker_direct(self):
  679. self.app.conf.worker_direct = True
  680. self.app.amqp.__dict__['queues'] = Mock()
  681. self.worker.setup_queues({})
  682. self.app.amqp.queues.select_add.assert_called_with(
  683. worker_direct(self.worker.hostname),
  684. )
  685. def test_setup_queues__missing_queue(self):
  686. self.app.amqp.queues.select = Mock(name='select')
  687. self.app.amqp.queues.deselect = Mock(name='deselect')
  688. self.app.amqp.queues.select.side_effect = KeyError()
  689. self.app.amqp.queues.deselect.side_effect = KeyError()
  690. with self.assertRaises(ImproperlyConfigured):
  691. self.worker.setup_queues("x,y", exclude="foo,bar")
  692. self.app.amqp.queues.select = Mock(name='select')
  693. with self.assertRaises(ImproperlyConfigured):
  694. self.worker.setup_queues("x,y", exclude="foo,bar")
  695. def test_send_worker_shutdown(self):
  696. with patch('celery.signals.worker_shutdown') as ws:
  697. self.worker._send_worker_shutdown()
  698. ws.send.assert_called_with(sender=self.worker)
  699. def test_process_shutdown_on_worker_shutdown(self):
  700. raise SkipTest('unstable test')
  701. from celery.concurrency.prefork import process_destructor
  702. from celery.concurrency.asynpool import Worker
  703. with patch('celery.signals.worker_process_shutdown') as ws:
  704. Worker._make_shortcuts = Mock()
  705. with patch('os._exit') as _exit:
  706. worker = Worker(None, None, on_exit=process_destructor)
  707. worker._do_exit(22, 3.1415926)
  708. ws.send.assert_called_with(
  709. sender=None, pid=22, exitcode=3.1415926,
  710. )
  711. _exit.assert_called_with(3.1415926)
  712. def test_process_task_revoked_release_semaphore(self):
  713. self.worker._quick_release = Mock()
  714. req = Mock()
  715. req.execute_using_pool.side_effect = TaskRevokedError
  716. self.worker._process_task(req)
  717. self.worker._quick_release.assert_called_with()
  718. delattr(self.worker, '_quick_release')
  719. self.worker._process_task(req)
  720. def test_shutdown_no_blueprint(self):
  721. self.worker.blueprint = None
  722. self.worker._shutdown()
  723. @patch('celery.worker.create_pidlock')
  724. def test_use_pidfile(self, create_pidlock):
  725. create_pidlock.return_value = Mock()
  726. worker = self.create_worker(pidfile='pidfilelockfilepid')
  727. worker.steps = []
  728. worker.start()
  729. self.assertTrue(create_pidlock.called)
  730. worker.stop()
  731. self.assertTrue(worker.pidlock.release.called)
  732. def test_attrs(self):
  733. worker = self.worker
  734. self.assertIsNotNone(worker.timer)
  735. self.assertIsInstance(worker.timer, Timer)
  736. self.assertIsNotNone(worker.pool)
  737. self.assertIsNotNone(worker.consumer)
  738. self.assertTrue(worker.steps)
  739. def test_with_embedded_beat(self):
  740. worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
  741. self.assertTrue(worker.beat)
  742. self.assertIn(worker.beat, [w.obj for w in worker.steps])
  743. def test_with_autoscaler(self):
  744. worker = self.create_worker(
  745. autoscale=[10, 3], send_events=False,
  746. timer_cls='celery.utils.timer2.Timer',
  747. )
  748. self.assertTrue(worker.autoscaler)
  749. def test_dont_stop_or_terminate(self):
  750. worker = self.app.WorkController(concurrency=1, loglevel=0)
  751. worker.stop()
  752. self.assertNotEqual(worker.blueprint.state, CLOSE)
  753. worker.terminate()
  754. self.assertNotEqual(worker.blueprint.state, CLOSE)
  755. sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
  756. try:
  757. worker.blueprint.state = RUN
  758. worker.stop(in_sighandler=True)
  759. self.assertNotEqual(worker.blueprint.state, CLOSE)
  760. worker.terminate(in_sighandler=True)
  761. self.assertNotEqual(worker.blueprint.state, CLOSE)
  762. finally:
  763. worker.pool.signal_safe = sigsafe
  764. def test_on_timer_error(self):
  765. worker = self.app.WorkController(concurrency=1, loglevel=0)
  766. try:
  767. raise KeyError('foo')
  768. except KeyError as exc:
  769. components.Timer(worker).on_timer_error(exc)
  770. msg, args = self.comp_logger.error.call_args[0]
  771. self.assertIn('KeyError', msg % args)
  772. def test_on_timer_tick(self):
  773. worker = self.app.WorkController(concurrency=1, loglevel=10)
  774. components.Timer(worker).on_timer_tick(30.0)
  775. xargs = self.comp_logger.debug.call_args[0]
  776. fmt, arg = xargs[0], xargs[1]
  777. self.assertEqual(30.0, arg)
  778. self.assertIn('Next eta %s secs', fmt)
  779. def test_process_task(self):
  780. worker = self.worker
  781. worker.pool = Mock()
  782. channel = Mock()
  783. m = create_task_message(
  784. channel, self.foo_task.name,
  785. args=[4, 8, 10], kwargs={},
  786. )
  787. task = Request(m, app=self.app)
  788. worker._process_task(task)
  789. self.assertEqual(worker.pool.apply_async.call_count, 1)
  790. worker.pool.stop()
  791. def test_process_task_raise_base(self):
  792. worker = self.worker
  793. worker.pool = Mock()
  794. worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
  795. channel = Mock()
  796. m = create_task_message(
  797. channel, self.foo_task.name,
  798. args=[4, 8, 10], kwargs={},
  799. )
  800. task = Request(m, app=self.app)
  801. worker.steps = []
  802. worker.blueprint.state = RUN
  803. with self.assertRaises(KeyboardInterrupt):
  804. worker._process_task(task)
  805. def test_process_task_raise_WorkerTerminate(self):
  806. worker = self.worker
  807. worker.pool = Mock()
  808. worker.pool.apply_async.side_effect = WorkerTerminate()
  809. channel = Mock()
  810. m = create_task_message(
  811. channel, self.foo_task.name,
  812. args=[4, 8, 10], kwargs={},
  813. )
  814. task = Request(m, app=self.app)
  815. worker.steps = []
  816. worker.blueprint.state = RUN
  817. with self.assertRaises(SystemExit):
  818. worker._process_task(task)
  def test_process_task_raise_regular(self):
      # An ordinary exception (KeyError) from the pool must not escape
      # _process_task; the test passes as long as the call returns.
      controller = self.worker
      controller.pool = Mock()
      controller.pool.apply_async.side_effect = KeyError('some exception')
      message = create_task_message(
          Mock(), self.foo_task.name,
          args=[4, 8, 10], kwargs={},
      )
      request = Request(message, app=self.app)

      controller._process_task(request)
      controller.pool.stop()
  def test_start_catches_base_exceptions(self):
      # Case 1: a step raises WorkerTerminate while starting; start() must
      # return and the step's terminate() must have been called.
      terminated = self.create_worker()
      terminated.blueprint.state = RUN
      term_step = MockStep()
      term_step.start.side_effect = WorkerTerminate()
      terminated.steps = [term_step]
      terminated.start()
      term_step.start.assert_called_with(terminated)
      self.assertTrue(term_step.terminate.call_count)

      # Case 2: a step without a terminate hook raises WorkerShutdown;
      # start() must return and the step's stop() must have been called.
      shut_down = self.create_worker()
      shut_down.blueprint.state = RUN
      stop_step = MockStep()
      stop_step.start.side_effect = WorkerShutdown()
      stop_step.terminate = None
      shut_down.steps = [stop_step]
      shut_down.start()
      self.assertTrue(stop_step.stop.call_count)
  def test_state_db(self):
      # Creating a worker with a state_db filename must leave it with a
      # truthy _persistence attribute.
      from celery.worker import state

      # Persistent is swapped for a mock while the worker is built;
      # patch.object restores the real class on exit, even on failure.
      with patch.object(state, 'Persistent', Mock()):
          controller = self.create_worker(state_db='statefilename')
          self.assertTrue(controller._persistence)
  def test_process_task_sem(self):
      # _process_task_sem must route the request through _quick_acquire,
      # passing _process_task as the callback.
      controller = self.worker
      acquire = Mock()
      controller._quick_acquire = acquire
      request = Mock()

      controller._process_task_sem(request)

      controller._quick_acquire.assert_called_with(
          controller._process_task, request,
      )
  def test_signal_consumer_close(self):
      # signal_consumer_close() must call close() on the consumer with no
      # arguments.
      controller = self.worker
      fake_consumer = Mock()
      controller.consumer = fake_consumer
      controller.signal_consumer_close()
      fake_consumer.close.assert_called_with()

      # An AttributeError raised by close() must not escape.
      fake_consumer.close.side_effect = AttributeError()
      controller.signal_consumer_close()
  def test_rusage__no_resource(self):
      # With celery.worker.resource set to None, rusage() must raise
      # NotImplementedError, while stats() must still complete.
      from celery import worker as worker_module

      saved = worker_module.resource
      worker_module.resource = None
      try:
          self.worker.pool = Mock(name='pool')
          self.assertRaises(NotImplementedError, self.worker.rusage)
          self.worker.stats()
      finally:
          # Always put the real module attribute back for later tests.
          worker_module.resource = saved
  def test_repr(self):
      # repr() of the worker must produce a non-empty value.
      text = repr(self.worker)
      self.assertTrue(text)
  def test_str(self):
      # str() of the worker is expected to equal its hostname.
      rendered = str(self.worker)
      expected = self.worker.hostname
      self.assertEqual(rendered, expected)
  def test_start__stop(self):
      # Exercise start()/stop() against four start-stop steps whose
      # start/close/stop hooks are replaced by mocks.
      controller = self.worker
      blueprint = controller.blueprint
      blueprint.shutdown_complete.set()
      controller.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
      blueprint.state = RUN
      blueprint.started = 4
      for step in controller.steps:
          step.start = Mock()
          step.close = Mock()
          step.stop = Mock()

      # Every step must be started by start().
      controller.start()
      for step in controller.steps:
          self.assertTrue(step.start.call_count)

      # Every step must be closed and stopped by stop().
      controller.consumer = Mock()
      controller.stop(exitcode=3)
      for step in controller.steps:
          self.assertTrue(step.close.call_count)
          self.assertTrue(step.stop.call_count)

      # Doesn't close pool if no pool.
      controller.start()
      controller.pool = None
      controller.stop()

      # test that stop of None is not attempted
      controller.steps[-1] = None
      controller.start()
      controller.stop()
  def test_start__KeyboardInterrupt(self):
      # A KeyboardInterrupt raised while the blueprint starts must make
      # start() call stop() with the failure exit code.
      controller = self.worker
      blueprint = Mock(name='blueprint')
      blueprint.start.side_effect = KeyboardInterrupt()
      controller.blueprint = blueprint
      stop = Mock(name='stop')
      controller.stop = stop

      controller.start()

      controller.stop.assert_called_with(exitcode=EX_FAILURE)
  def test_register_with_event_loop(self):
      # register_with_event_loop(hub) must be forwarded to every step via
      # blueprint.send_all with the hub as the sole positional argument.
      controller = self.worker
      event_hub = Mock(name='hub')
      controller.blueprint = Mock(name='blueprint')

      controller.register_with_event_loop(event_hub)

      send_all = controller.blueprint.send_all
      send_all.assert_called_with(
          controller, 'register_with_event_loop', args=(event_hub,),
          description='hub.register',
      )
  def test_step_raises(self):
      # If a step's start() raises an ordinary exception (TypeError here),
      # start() must call stop() with the failure exit code.
      controller = self.worker
      failing_step = Mock()
      controller.steps = [failing_step]
      failing_step.start.side_effect = TypeError()
      controller.stop = Mock()

      controller.start()

      controller.stop.assert_called_with(exitcode=EX_FAILURE)
  def test_state(self):
      # The worker's state attribute must be truthy.
      current = self.worker.state
      self.assertTrue(current)
  def test_start__terminate(self):
      # Start five mock steps, then terminate the worker twice: once while
      # the blueprint is in the RUN state and once after it has been marked
      # TERMINATE.
      worker = self.worker
      worker.blueprint.shutdown_complete.set()
      worker.blueprint.started = 5
      worker.blueprint.state = RUN
      worker.steps = [MockStep() for _ in range(5)]
      worker.start()
      for w in worker.steps[:3]:
          self.assertTrue(w.start.call_count)
      # assertEqual, not assertTrue: assertTrue treats its second argument
      # as the failure message, so the step count was never compared.
      self.assertEqual(worker.blueprint.started, len(worker.steps))
      self.assertEqual(worker.blueprint.state, RUN)
      worker.terminate()
      for step in worker.steps:
          self.assertTrue(step.terminate.call_count)
      # The second terminate() only has to complete without raising.
      worker.blueprint.state = TERMINATE
      worker.terminate()
  def test_Hub_crate(self):
      # Hub.create() is run against a mock parent, after which the
      # parent's timer must expose a truthy max_interval.
      # NOTE(review): attributes of a Mock are truthy by default, so this
      # assertion may pass even if create() never sets max_interval.
      parent = Mock()
      hub_step = components.Hub(parent)
      hub_step.create(parent)
      self.assertTrue(parent.timer.max_interval)
  def test_Pool_crate_threaded(self):
      # Pool.create() must complete for a parent that does not use the
      # event loop; there is no assertion beyond "does not raise".
      parent = Mock()
      no_errors = ()
      parent._conninfo.channel_errors = no_errors
      parent._conninfo.connection_errors = no_errors
      parent.pool_cls = Mock()
      parent.use_eventloop = False

      pool_step = components.Pool(parent)
      pool_step.create(parent)
  def test_Pool_pool_no_sem(self):
      # When the pool class does not use a semaphore, Pool.create() must
      # leave process_task bound to the plain _process_task.
      parent = Mock()
      parent.pool_cls.uses_semaphore = False

      pool_step = components.Pool(parent)
      pool_step.create(parent)

      self.assertIs(parent.process_task, parent._process_task)
  969. def test_Pool_create(self):
  970. from kombu.async.semaphore import LaxBoundedSemaphore
  971. w = Mock()
  972. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  973. w.hub = Mock()
  974. PoolImp = Mock()
  975. poolimp = PoolImp.return_value = Mock()
  976. poolimp._pool = [Mock(), Mock()]
  977. poolimp._cache = {}
  978. poolimp._fileno_to_inq = {}
  979. poolimp._fileno_to_outq = {}
  980. from celery.concurrency.prefork import TaskPool as _TaskPool
  981. class MockTaskPool(_TaskPool):
  982. Pool = PoolImp
  983. @property
  984. def timers(self):
  985. return {Mock(): 30}
  986. w.pool_cls = MockTaskPool
  987. w.use_eventloop = True
  988. w.consumer.restart_count = -1
  989. pool = components.Pool(w)
  990. pool.create(w)
  991. pool.register_with_event_loop(w, w.hub)
  992. if sys.platform != 'win32':
  993. self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
  994. P = w.pool
  995. P.start()