test_worker.py 37 KB


  1. from __future__ import absolute_import
  2. import os
  3. import socket
  4. from collections import deque
  5. from datetime import datetime, timedelta
  6. from threading import Event
  7. from amqp import ChannelError
  8. from kombu import Connection
  9. from kombu.common import QoS, ignore_errors
  10. from kombu.transport.base import Message
  11. from celery.app.defaults import DEFAULTS
  12. from celery.bootsteps import RUN, CLOSE, StartStopStep
  13. from celery.concurrency.base import BasePool
  14. from celery.datastructures import AttributeDict
  15. from celery.exceptions import (
  16. WorkerShutdown, WorkerTerminate, TaskRevokedError,
  17. )
  18. from celery.five import Empty, range, Queue as FastQueue
  19. from celery.utils import uuid
  20. from celery.worker import components
  21. from celery.worker import consumer
  22. from celery.worker.consumer import Consumer as __Consumer
  23. from celery.worker.job import Request
  24. from celery.utils import worker_direct
  25. from celery.utils.serialization import pickle
  26. from celery.utils.timer2 import Timer
  27. from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
  def MockStep(step=None):
      """Return *step* (a fresh Mock when omitted) dressed up as a bootstep.

      The step gets a mock ``blueprint`` whose ``name`` is ``'MockNS'`` and
      a ``name`` that embeds the step's ``id()`` so each step is distinct.
      """
      if step is None:
          step = Mock()
      step.blueprint = Mock()
      step.blueprint.name = 'MockNS'
      step.name = 'MockStep(%s)' % (id(step), )
      return step
  def mock_event_dispatcher():
      """Build a Mock event dispatcher in the ``'worker'`` group.

      Each call returns a new dispatcher with its own empty
      ``_outbound_buffer`` deque.
      """
      dispatcher = Mock(name='event_dispatcher')
      dispatcher.groups = ['worker']
      dispatcher._outbound_buffer = deque()
      return dispatcher
  class PlaceHolder(object):
      """Empty object for tests that just need somewhere to hang attributes."""
  def find_step(obj, typ):
      """Look up the bootstep registered under ``typ.name`` in *obj*'s blueprint."""
      steps_by_name = obj.blueprint.steps
      return steps_by_name[typ.name]
  class Consumer(__Consumer):
      """Worker consumer for tests, with the Mingle, Gossip and Heart steps
      disabled unless the caller passes the corresponding ``without_*``
      keyword explicitly."""

      def __init__(self, *args, **kwargs):
          # without_mingle / without_gossip / without_heartbeat default to True.
          for disabled_step in ('without_mingle', 'without_gossip',
                                'without_heartbeat'):
              kwargs.setdefault(disabled_step, True)
          super(Consumer, self).__init__(*args, **kwargs)
  class _MyKombuConsumer(Consumer):
      """Test consumer whose broadcast/task consumers are class-level Mocks
      and whose pool defaults to ``BasePool(2)``."""

      broadcast_consumer = Mock()
      task_consumer = Mock()

      def __init__(self, *args, **kwargs):
          if 'pool' not in kwargs:
              kwargs['pool'] = BasePool(2)
          super(_MyKombuConsumer, self).__init__(*args, **kwargs)

      def restart_heartbeat(self):
          """Drop the heart instead of restarting it."""
          self.heart = None
  class MyKombuConsumer(Consumer):
      """Test consumer whose ``loop`` does nothing."""

      def loop(self, *args, **kwargs):
          """Accept any arguments and return immediately."""
  class MockNode(object):
      """Fake pidbox node that records the ``command`` of every message."""

      def __init__(self):
          # Per-instance list: a class-level ``commands = []`` would be one
          # list shared by every MockNode, accumulating across tests.
          self.commands = []

      def handle_message(self, body, message):
          """Pop ``'command'`` out of *body* (None if absent) and record it.

          Note that *body* is mutated: the ``'command'`` key is removed.
          """
          self.commands.append(body.pop('command', None))
  class MockEventDispatcher(object):
      """Minimal event-dispatcher stand-in recording the events it is sent."""

      closed = False
      flushed = False

      def __init__(self):
          # Per-instance containers: class-level lists would be shared by all
          # instances and leak recorded events between tests.
          self.sent = []
          self._outbound_buffer = []

      def send(self, event, *args, **kwargs):
          """Record the event type; extra arguments are ignored."""
          self.sent.append(event)

      def close(self):
          self.closed = True

      def flush(self):
          self.flushed = True
  class MockHeart(object):
      """Fake heartbeat whose ``stop()`` just flips ``closed`` to True."""

      closed = False

      def stop(self):
          setattr(self, 'closed', True)
  def create_message(channel, **data):
      """Build a pickle-serialized kombu ``Message`` on *channel*.

      *data* becomes the message body; an ``id`` is generated with ``uuid()``
      if none is supplied.  The channel's ``no_ack_consumers`` is reset to an
      empty set, and the message accepts only the pickle content type.
      """
      data.setdefault('id', uuid())
      channel.no_ack_consumers = set()
      payload = pickle.dumps(dict(**data))
      message = Message(
          channel,
          body=payload,
          content_type='application/x-python-serialize',
          content_encoding='binary',
          delivery_info={'consumer_tag': 'mock'},
      )
      message.accept = ['application/x-python-serialize']
      return message
  class test_Consumer(AppCase):
      """Tests for the worker Consumer and its bootsteps.

      Consumers under test receive ``self.buffer.put`` as their first
      positional argument; requests they hand off are read back from
      ``self.buffer``, and ETA entries are inspected on ``self.timer``.
      Most message tests capture the on-message callback through
      ``_get_on_message`` and feed it a pickled message built by
      ``create_message``.
      """

      def setup(self):
          self.buffer = FastQueue()
          self.timer = Timer()

          @self.app.task(shared=False)
          def foo_task(x, y, z):
              return x * y * z
          self.foo_task = foo_task

      def teardown(self):
          self.timer.stop()

      # WorkController.stats() reports the consumer's prefetch count and
      # broker info.
      def test_info(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.task_consumer = Mock()
          l.qos = QoS(l.task_consumer.qos, 10)
          l.connection = Mock()
          l.connection.info.return_value = {'foo': 'bar'}
          l.controller = l.app.WorkController()
          l.controller.pool = Mock()
          l.controller.pool.info.return_value = [Mock(), Mock()]
          l.controller.consumer = l
          info = l.controller.stats()
          self.assertEqual(info['prefetch_count'], 10)
          self.assertTrue(info['broker'])

      # Smoke test: start() when the blueprint is already in CLOSE state.
      def test_start_when_closed(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = CLOSE
          l.start()

      # Connection lifecycle through blueprint start/restart, shutdown, stop.
      def test_connection(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

          l.blueprint.start(l)
          self.assertIsInstance(l.connection, Connection)

          l.blueprint.state = RUN
          l.event_dispatcher = None
          l.blueprint.restart(l)
          self.assertTrue(l.connection)

          l.blueprint.state = RUN
          l.shutdown()
          self.assertIsNone(l.connection)
          self.assertIsNone(l.task_consumer)

          l.blueprint.start(l)
          self.assertIsInstance(l.connection, Connection)
          l.blueprint.restart(l)

          l.stop()
          l.shutdown()
          self.assertIsNone(l.connection)
          self.assertIsNone(l.task_consumer)

      # Shutting down the Connection step closes and clears the connection;
      # shutting down Events/Heart closes the dispatcher and the heart.
      def test_close_connection(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          step = find_step(l, consumer.Connection)
          conn = l.connection = Mock()
          step.shutdown(l)
          self.assertTrue(conn.close.called)
          self.assertIsNone(l.connection)

          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          eventer = l.event_dispatcher = mock_event_dispatcher()
          eventer.enabled = True
          heart = l.heart = MockHeart()
          l.blueprint.state = RUN
          Events = find_step(l, consumer.Events)
          Events.shutdown(l)
          Heart = find_step(l, consumer.Heart)
          Heart.shutdown(l)
          self.assertTrue(eventer.close.call_count)
          self.assertTrue(heart.closed)

      # A message that carries no task name produces a warning.
      @patch('celery.worker.consumer.warn')
      def test_receive_message_unknown(self, warn):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.steps.pop()
          backend = Mock()
          m = create_message(backend, unknown={'baz': '!!!'})
          l.event_dispatcher = mock_event_dispatcher()
          l.node = MockNode()

          callback = self._get_on_message(l)
          callback(m.decode(), m)
          self.assertTrue(warn.call_count)

      # If the ETA cannot be converted (OverflowError), the message must
      # still be acknowledged.
      @patch('celery.worker.strategy.to_timestamp')
      def test_receive_message_eta_OverflowError(self, to_timestamp):
          to_timestamp.side_effect = OverflowError()
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.steps.pop()
          m = create_message(Mock(), task=self.foo_task.name,
                             args=(2, 2),
                             kwargs={},
                             eta=datetime.now().isoformat())
          l.event_dispatcher = mock_event_dispatcher()
          l.node = MockNode()
          l.update_strategies()
          l.qos = Mock()

          callback = self._get_on_message(l)
          callback(m.decode(), m)
          self.assertTrue(m.acknowledged)

      # Non-mapping kwargs are reported as an invalid task message.
      @patch('celery.worker.consumer.error')
      def test_receive_message_InvalidTaskError(self, error):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.event_dispatcher = mock_event_dispatcher()
          l.steps.pop()
          m = create_message(Mock(), task=self.foo_task.name,
                             args=(1, 2), kwargs='foobarbaz', id=1)
          l.update_strategies()
          l.event_dispatcher = mock_event_dispatcher()

          callback = self._get_on_message(l)
          callback(m.decode(), m)
          self.assertIn('Received invalid task message', error.call_args[0][0])

      # An undecodable body is acked and logged as critical.
      @patch('celery.worker.consumer.crit')
      def test_on_decode_error(self, crit):
          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)

          class MockMessage(Mock):
              content_type = 'application/x-msgpack'
              content_encoding = 'binary'
              body = 'foobarbaz'

          message = MockMessage()
          l.on_decode_error(message, KeyError('foo'))
          self.assertTrue(message.ack.call_count)
          self.assertIn("Can't decode message body", crit.call_args[0][0])

      def _get_on_message(self, l):
          """Run one loop() iteration on *l* and return the on-message callback.

          Fills in a Mock qos if missing, installs a fresh mock event
          dispatcher, task consumer and connection, and makes
          ``drain_events`` raise WorkerShutdown so loop() exits after
          registering its callback on ``task_consumer``.
          """
          if l.qos is None:
              l.qos = Mock()
          l.event_dispatcher = mock_event_dispatcher()
          l.task_consumer = Mock()
          l.connection = Mock()
          l.connection.drain_events.side_effect = WorkerShutdown()

          with self.assertRaises(WorkerShutdown):
              l.loop(*l.loop_args())
          self.assertTrue(l.task_consumer.register_callback.called)
          return l.task_consumer.register_callback.call_args[0][0]

      # A valid message becomes a Request on the buffer; nothing is scheduled
      # on the timer.
      def test_receieve_message(self):
          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.event_dispatcher = mock_event_dispatcher()
          m = create_message(Mock(), task=self.foo_task.name,
                             args=[2, 4, 8], kwargs={})
          l.update_strategies()
          callback = self._get_on_message(l)
          callback(m.decode(), m)

          in_bucket = self.buffer.get_nowait()
          self.assertIsInstance(in_bucket, Request)
          self.assertEqual(in_bucket.name, self.foo_task.name)
          self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
          self.assertTrue(self.timer.empty())

      # A channel error raised from loop() propagates out of start().
      def test_start_channel_error(self):

          class MockConsumer(Consumer):
              iterations = 0

              def loop(self, *args, **kwargs):
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError('foo')
                  raise SyntaxError('bar')

          l = MockConsumer(self.buffer.put, timer=self.timer,
                           send_events=False, pool=BasePool(), app=self.app)
          l.channel_errors = (KeyError, )
          with self.assertRaises(KeyError):
              l.start()
          l.timer.stop()

      # A connection error from loop() is survived: start() calls loop()
      # again, and that call's SyntaxError is what escapes.
      def test_start_connection_error(self):

          class MockConsumer(Consumer):
              iterations = 0

              def loop(self, *args, **kwargs):
                  if not self.iterations:
                      self.iterations = 1
                      raise KeyError('foo')
                  raise SyntaxError('bar')

          l = MockConsumer(self.buffer.put, timer=self.timer,
                           send_events=False, pool=BasePool(), app=self.app)
          l.connection_errors = (KeyError, )
          self.assertRaises(SyntaxError, l.start)
          l.timer.stop()

      # socket.timeout from drain_events does not escape loop().
      def test_loop_ignores_socket_timeout(self):

          class Connection(self.app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None
                  raise socket.timeout(10)

          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.connection = Connection()
          l.task_consumer = Mock()
          l.connection.obj = l
          l.qos = QoS(l.task_consumer.qos, 10)
          l.loop(*l.loop_args())

      # socket.error propagates while RUNning; in CLOSE state loop() returns.
      def test_loop_when_socket_error(self):

          class Connection(self.app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None
                  raise socket.error('foo')

          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          c = l.connection = Connection()
          l.connection.obj = l
          l.task_consumer = Mock()
          l.qos = QoS(l.task_consumer.qos, 10)
          with self.assertRaises(socket.error):
              l.loop(*l.loop_args())

          l.blueprint.state = CLOSE
          l.connection = c
          l.loop(*l.loop_args())

      # loop() consumes and applies the prefetch count to task_consumer.qos;
      # decrement + update() re-applies the lowered value.
      def test_loop(self):

          class Connection(self.app.connection().__class__):
              obj = None

              def drain_events(self, **kwargs):
                  self.obj.connection = None

          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.connection = Connection()
          l.connection.obj = l
          l.task_consumer = Mock()
          l.qos = QoS(l.task_consumer.qos, 10)

          l.loop(*l.loop_args())
          l.loop(*l.loop_args())
          self.assertTrue(l.task_consumer.consume.call_count)
          l.task_consumer.qos.assert_called_with(prefetch_count=10)
          self.assertEqual(l.qos.value, 10)
          l.qos.decrement_eventually()
          self.assertEqual(l.qos.value, 9)
          l.qos.update()
          self.assertEqual(l.qos.value, 9)
          l.task_consumer.qos.assert_called_with(prefetch_count=9)

      # ignore_errors swallows the consumer's connection/channel errors only.
      def test_ignore_errors(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.connection_errors = (AttributeError, KeyError, )
          l.channel_errors = (SyntaxError, )
          ignore_errors(l, Mock(side_effect=AttributeError('foo')))
          ignore_errors(l, Mock(side_effect=KeyError('foo')))
          ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
          with self.assertRaises(IndexError):
              ignore_errors(l, Mock(side_effect=IndexError('foo')))

      # apply_eta_task reserves the request, decrements QoS and hands the
      # request to the buffer.
      def test_apply_eta_task(self):
          from celery.worker import state
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.qos = QoS(None, 10)

          task = object()
          qos = l.qos.value
          l.apply_eta_task(task)
          try:
              self.assertIn(task, state.reserved_requests)
              self.assertEqual(l.qos.value, qos - 1)
              self.assertIs(self.buffer.get_nowait(), task)
          finally:
              # reserved_requests lives on the state module; don't leak this
              # sentinel into later tests.
              state.reserved_requests.discard(task)

      # A future ETA puts the request on the timer and raises the prefetch
      # count.
      def test_receieve_message_eta_isoformat(self):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.steps.pop()
          m = create_message(
              Mock(), task=self.foo_task.name,
              eta=(datetime.now() + timedelta(days=1)).isoformat(),
              args=[2, 4, 8], kwargs={},
          )

          l.task_consumer = Mock()
          l.qos = QoS(l.task_consumer.qos, 1)
          current_pcount = l.qos.value
          l.event_dispatcher = mock_event_dispatcher()
          l.enabled = False
          l.update_strategies()
          callback = self._get_on_message(l)
          callback(m.decode(), m)
          l.timer.stop()
          l.timer.join(1)

          items = [entry[2] for entry in self.timer.queue]
          found = any(item.args[0].name == self.foo_task.name
                      for item in items)
          self.assertTrue(found)
          self.assertGreater(l.qos.value, current_pcount)

      # on_message forwards to node.handle_message; errors raised by the node
      # are swallowed and reset() ends up being called.
      def test_pidbox_callback(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          con = find_step(l, consumer.Control).box
          con.node = Mock()
          con.reset = Mock()

          con.on_message('foo', 'bar')
          con.node.handle_message.assert_called_with('foo', 'bar')

          con.node = Mock()
          con.node.handle_message.side_effect = KeyError('foo')
          con.on_message('foo', 'bar')
          con.node.handle_message.assert_called_with('foo', 'bar')

          con.node = Mock()
          con.node.handle_message.side_effect = ValueError('foo')
          con.on_message('foo', 'bar')
          con.node.handle_message.assert_called_with('foo', 'bar')
          self.assertTrue(con.reset.called)

      # A message whose id is in the revoked set is not delivered.
      def test_revoke(self):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.steps.pop()
          backend = Mock()
          task_id = uuid()
          t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
                             kwargs={}, id=task_id)
          from celery.worker.state import revoked
          revoked.add(task_id)

          callback = self._get_on_message(l)
          callback(t.decode(), t)
          self.assertTrue(self.buffer.empty())

      # Messages for unregistered task names are not delivered and the
      # callback returns a falsy value.
      def test_receieve_message_not_registered(self):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          l.steps.pop()
          backend = Mock()
          m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

          l.event_dispatcher = mock_event_dispatcher()
          callback = self._get_on_message(l)
          self.assertFalse(callback(m.decode(), m))
          with self.assertRaises(Empty):
              self.buffer.get_nowait()
          self.assertTrue(self.timer.empty())

      # reject(requeue=False) raising a connection error is logged rather than
      # raised, and nothing is delivered.
      @patch('celery.worker.consumer.warn')
      @patch('celery.worker.consumer.logger')
      def test_receieve_message_ack_raises(self, logger, warn):
          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
          l.blueprint.state = RUN
          backend = Mock()
          m = create_message(backend, args=[2, 4, 8], kwargs={})

          l.event_dispatcher = mock_event_dispatcher()
          l.connection_errors = (socket.error, )
          m.reject = Mock()
          m.reject.side_effect = socket.error('foo')
          callback = self._get_on_message(l)
          self.assertFalse(callback(m.decode(), m))
          self.assertTrue(warn.call_count)
          with self.assertRaises(Empty):
              self.buffer.get_nowait()
          self.assertTrue(self.timer.empty())
          m.reject.assert_called_with(requeue=False)
          self.assertTrue(logger.critical.call_count)

      # A future-ETA message is held on the timer, not put on the buffer.
      def test_receive_message_eta(self):
          l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          l.steps.pop()
          l.event_dispatcher = mock_event_dispatcher()
          backend = Mock()
          m = create_message(
              backend, task=self.foo_task.name,
              args=[2, 4, 8], kwargs={},
              eta=(datetime.now() + timedelta(days=1)).isoformat(),
          )

          try:
              l.blueprint.start(l)
              p = l.app.conf.BROKER_CONNECTION_RETRY
              l.app.conf.BROKER_CONNECTION_RETRY = False
              try:
                  l.blueprint.start(l)
              finally:
                  # Restore the setting even if the second start() fails.
                  l.app.conf.BROKER_CONNECTION_RETRY = p
              l.blueprint.restart(l)
              l.event_dispatcher = mock_event_dispatcher()
              callback = self._get_on_message(l)
              callback(m.decode(), m)
          finally:
              l.timer.stop()
              try:
                  l.timer.join()
              except RuntimeError:
                  pass

          in_hold = l.timer.queue[0]
          self.assertEqual(len(in_hold), 3)
          eta, priority, entry = in_hold
          task = entry.args[0]
          self.assertIsInstance(task, Request)
          self.assertEqual(task.name, self.foo_task.name)
          self.assertEqual(task.execute(), 2 * 4 * 8)
          with self.assertRaises(Empty):
              self.buffer.get_nowait()

      # reset() closes the node's channel even when close() raises a
      # connection error.
      def test_reset_pidbox_node(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          con = find_step(l, consumer.Control).box
          con.node = Mock()
          chan = con.node.channel = Mock()
          l.connection = Mock()
          chan.close.side_effect = socket.error('foo')
          l.connection_errors = (socket.error, )
          con.reset()
          chan.close.assert_called_with()

      # With a green pool the Control step uses gPidbox and spawns its loop
      # through pool.spawn_n.
      def test_reset_pidbox_node_green(self):
          from celery.worker.pidbox import gPidbox
          pool = Mock()
          pool.is_green = True
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                              app=self.app)
          con = find_step(l, consumer.Control)
          self.assertIsInstance(con.box, gPidbox)
          con.start(l)
          l.pool.spawn_n.assert_called_with(
              con.box.loop, l,
          )

      # gPidbox.loop listens on the node, tolerates one socket.timeout, and
      # closes the connection it opened.
      def test__green_pidbox_node(self):
          pool = Mock()
          pool.is_green = True
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                              app=self.app)
          l.node = Mock()
          controller = find_step(l, consumer.Control)

          class BConsumer(Mock):

              def __enter__(self):
                  self.consume()
                  return self

              def __exit__(self, *exc_info):
                  self.cancel()

          controller.box.node.listen = BConsumer()
          connections = []

          class Connection(object):
              calls = 0

              def __init__(self, obj):
                  connections.append(self)
                  self.obj = obj
                  self.default_channel = self.channel()
                  self.closed = False

              def __enter__(self):
                  return self

              def __exit__(self, *exc_info):
                  self.close()

              def channel(self):
                  return Mock()

              def as_uri(self):
                  return 'dummy://'

              def drain_events(self, **kwargs):
                  # First call times out; the second ends the pidbox loop.
                  if not self.calls:
                      self.calls += 1
                      raise socket.timeout()
                  self.obj.connection = None
                  controller.box._node_shutdown.set()

              def close(self):
                  self.closed = True

          l.connection = Mock()
          l.connect = lambda: Connection(obj=l)
          controller = find_step(l, consumer.Control)
          controller.box.loop(l)

          self.assertTrue(controller.box.node.listen.called)
          self.assertTrue(controller.box.consumer)
          controller.box.consumer.consume.assert_called_with()

          self.assertIsNone(l.connection)
          self.assertTrue(connections[0].closed)

      # connect() recovers from a ChannelError on the first attempt.
      @patch('kombu.connection.Connection._establish_connection')
      @patch('kombu.utils.sleep')
      def test_connect_errback(self, sleep, connect):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          from kombu.transport.memory import Transport

          def effect():
              if connect.call_count > 1:
                  return
              raise ChannelError('error')
          connect.side_effect = effect

          # patch.object restores Transport.connection_errors afterwards; a
          # bare assignment would leak into every later memory-transport test.
          with patch.object(Transport, 'connection_errors', (ChannelError, )):
              l.connect()
              connect.assert_called_with()

      # Smoke test: stop() with the node-stopped event already set.
      def test_stop_pidbox_node(self):
          l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
          cont = find_step(l, consumer.Control)
          cont._node_stopped = Event()
          cont._node_shutdown = Event()
          cont._node_stopped.set()
          cont.stop(l)

      # start() keeps calling loop() until a non-connection error escapes,
      # updating QoS on the way; a socket.error from loop() also propagates.
      def test_start__loop(self):

          class _QoS(object):
              prev = 3
              value = 4

              def update(self):
                  self.prev = self.value

          class _Consumer(MyKombuConsumer):
              iterations = 0

              def reset_connection(self):
                  if self.iterations >= 1:
                      raise KeyError('foo')

          init_callback = Mock()
          l = _Consumer(self.buffer.put, timer=self.timer,
                        init_callback=init_callback, app=self.app)
          l.task_consumer = Mock()
          l.broadcast_consumer = Mock()
          l.qos = _QoS()
          l.connection = Connection()
          l.iterations = 0

          def raises_KeyError(*args, **kwargs):
              l.iterations += 1
              if l.qos.prev != l.qos.value:
                  l.qos.update()
              if l.iterations >= 2:
                  raise KeyError('foo')

          l.loop = raises_KeyError
          with self.assertRaises(KeyError):
              l.start()
          self.assertEqual(l.iterations, 2)
          self.assertEqual(l.qos.prev, l.qos.value)

          init_callback.reset_mock()
          l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                        send_events=False, init_callback=init_callback)
          l.qos = _QoS()
          l.task_consumer = Mock()
          l.broadcast_consumer = Mock()
          l.connection = Connection()
          l.loop = Mock(side_effect=socket.error('foo'))
          with self.assertRaises(socket.error):
              l.start()
          self.assertTrue(l.loop.call_count)

      # The blueprint can be started when the consumer has no pool.
      def test_reset_connection_with_no_node(self):
          l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
          l.steps.pop()
          self.assertEqual(None, l.pool)
          l.blueprint.start(l)
  594. class test_WorkController(AppCase):
  def setup(self):
      # Create the shared worker while the real loggers are still installed.
      self.worker = self.create_worker()

      # Swap the module loggers for mocks; teardown() puts them back.
      from celery import worker as worker_module
      self._logger = worker_module.logger
      self._comp_logger = components.logger
      mocked_logger, mocked_comp_logger = Mock(), Mock()
      self.logger = worker_module.logger = mocked_logger
      self.comp_logger = components.logger = mocked_comp_logger

      @self.app.task(shared=False)
      def foo_task(x, y, z):
          return x * y * z

      self.foo_task = foo_task
  def teardown(self):
      # Reinstall the loggers saved by setup().
      from celery import worker as worker_module
      worker_module.logger, components.logger = (
          self._logger, self._comp_logger,
      )
  def create_worker(self, **kw):
      """Return a ``WorkController`` with ``concurrency=1`` and ``loglevel=0``.

      Extra keyword arguments are forwarded to ``app.WorkController``.  The
      blueprint's ``shutdown_complete`` event is set before returning
      (presumably so shutdown paths don't wait on it -- confirm).
      """
      controller = self.app.WorkController(concurrency=1, loglevel=0, **kw)
      shutdown_complete = controller.blueprint.shutdown_complete
      shutdown_complete.set()
      return controller
  def test_on_consumer_ready(self):
      # Smoke test: the hook accepts an arbitrary consumer object.
      ready_consumer = Mock()
      self.worker.on_consumer_ready(ready_consumer)
  def test_setup_queues_worker_direct(self):
      # With CELERY_WORKER_DIRECT the worker's direct queue is added.
      self.app.conf.CELERY_WORKER_DIRECT = True
      self.app.amqp.__dict__['queues'] = Mock()
      self.worker.setup_queues({})
      expected_queue = worker_direct(self.worker.hostname)
      self.app.amqp.queues.select_add.assert_called_with(expected_queue)
  def test_send_worker_shutdown(self):
      # _send_worker_shutdown() emits worker_shutdown with the worker as sender.
      with patch('celery.signals.worker_shutdown') as shutdown_signal:
          self.worker._send_worker_shutdown()
          shutdown_signal.send.assert_called_with(sender=self.worker)
  def test_process_shutdown_on_worker_shutdown(self):
      # Disabled: everything after the SkipTest below is unreachable until the
      # skip is removed.
      raise SkipTest('unstable test')
      from celery.concurrency.prefork import process_destructor
      from celery.concurrency.asynpool import Worker
      with patch('celery.signals.worker_process_shutdown') as ws:
          # patch.object restores Worker._make_shortcuts on exit; a bare
          # assignment would leave the Mock installed for every later test.
          with patch.object(Worker, '_make_shortcuts', Mock()):
              with patch('os._exit') as _exit:
                  worker = Worker(None, None, on_exit=process_destructor)
                  worker._do_exit(22, 3.1415926)
                  ws.send.assert_called_with(
                      sender=None, pid=22, exitcode=3.1415926,
                  )
                  _exit.assert_called_with(3.1415926)
  def test_process_task_revoked_release_semaphore(self):
      # A revoked request releases the pool semaphore via _quick_release.
      w = self.worker
      w._quick_release = Mock()
      request = Mock()
      request.execute_using_pool.side_effect = TaskRevokedError
      w._process_task(request)
      w._quick_release.assert_called_with()

      # Without _quick_release the revoked path must still not raise.
      delattr(w, '_quick_release')
      w._process_task(request)
  def test_shutdown_no_blueprint(self):
      # _shutdown() tolerates a worker whose blueprint is None.
      w = self.worker
      w.blueprint = None
      w._shutdown()
  @patch('celery.platforms.create_pidlock')
  def test_use_pidfile(self, create_pidlock):
      # With a pidfile, start() acquires a pidlock and stop() releases it.
      create_pidlock.return_value = Mock()
      w = self.create_worker(pidfile='pidfilelockfilepid')
      w.steps = []

      w.start()
      self.assertTrue(create_pidlock.called)

      w.stop()
      self.assertTrue(w.pidlock.release.called)
  @patch('celery.platforms.signals')
  @patch('celery.platforms.set_mp_process_title')
  def test_process_initializer(self, set_mp_process_title, _signals):
      # process_initializer() ignores/resets the worker signals, initializes
      # the loader, fires worker_process_init, sets the current app and the
      # process title, and calls setup_worker_optimizations(app) when
      # FORKED_BY_MULTIPROCESSING is set.
      with restore_logging():
          from celery import signals
          from celery._state import _tls
          from celery.concurrency.prefork import (
              process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
          )

          def on_worker_process_init(**kwargs):
              on_worker_process_init.called = True
          on_worker_process_init.called = False
          signals.worker_process_init.connect(on_worker_process_init)

          def Loader(*args, **kwargs):
              loader = Mock(*args, **kwargs)
              loader.conf = {}
              loader.override_backends = {}
              return loader

          try:
              with self.Celery(loader=Loader) as app:
                  app.conf = AttributeDict(DEFAULTS)
                  process_initializer(app, 'awesome.worker.com')
                  _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                  _signals.reset.assert_any_call(*WORKER_SIGRESET)
                  self.assertTrue(app.loader.init_worker.call_count)
                  self.assertTrue(on_worker_process_init.called)
                  self.assertIs(_tls.current_app, app)
                  set_mp_process_title.assert_called_with(
                      'celeryd', hostname='awesome.worker.com',
                  )

                  with patch('celery.app.trace.setup_worker_optimizations') as S:
                      # Restore any pre-existing value instead of
                      # unconditionally deleting the variable.
                      prev_forked = os.environ.get('FORKED_BY_MULTIPROCESSING')
                      os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
                      try:
                          process_initializer(app, 'luke.worker.com')
                          S.assert_called_with(app)
                      finally:
                          if prev_forked is None:
                              os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
                          else:
                              os.environ['FORKED_BY_MULTIPROCESSING'] = prev_forked
          finally:
              # Don't leave this test's receiver connected to the signal.
              signals.worker_process_init.disconnect(on_worker_process_init)
  def test_attrs(self):
      # A fresh worker exposes a Timer, a pool, a consumer and its steps.
      w = self.worker
      timer = w.timer
      self.assertIsNotNone(timer)
      self.assertIsInstance(timer, Timer)
      self.assertIsNotNone(w.pool)
      self.assertIsNotNone(w.consumer)
      self.assertTrue(w.steps)
  def test_with_embedded_beat(self):
      # beat=True attaches a beat service that is one of the step objects.
      w = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
      self.assertTrue(w.beat)
      step_objects = [step.obj for step in w.steps]
      self.assertIn(w.beat, step_objects)
  def test_with_autoscaler(self):
      # An autoscale setting gives the worker an autoscaler.
      options = dict(
          autoscale=[10, 3],
          send_events=False,
          timer_cls='celery.utils.timer2.Timer',
      )
      w = self.create_worker(**options)
      self.assertTrue(w.autoscaler)
  def test_dont_stop_or_terminate(self):
      # stop()/terminate() on a not-running worker leave the state != CLOSE;
      # so do their in_sighandler variants when the pool is not signal safe.
      w = self.app.WorkController(concurrency=1, loglevel=0)
      for halt in (w.stop, w.terminate):
          halt()
          self.assertNotEqual(w.blueprint.state, CLOSE)

      saved_signal_safe = w.pool.signal_safe
      w.pool.signal_safe = False
      try:
          w.blueprint.state = RUN
          for halt in (w.stop, w.terminate):
              halt(in_sighandler=True)
              self.assertNotEqual(w.blueprint.state, CLOSE)
      finally:
          w.pool.signal_safe = saved_signal_safe
  def test_on_timer_error(self):
      # Timer errors are logged via components.logger.error with the
      # exception type in the message.
      w = self.app.WorkController(concurrency=1, loglevel=0)

      try:
          raise KeyError('foo')
      except KeyError as exc:
          components.Timer(w).on_timer_error(exc)
          fmt, fmt_args = self.comp_logger.error.call_args[0]
          self.assertIn('KeyError', fmt % fmt_args)
  def test_on_timer_tick(self):
      # A timer tick logs the next ETA at debug level.
      w = self.app.WorkController(concurrency=1, loglevel=10)
      components.Timer(w).on_timer_tick(30.0)
      positional = self.comp_logger.debug.call_args[0]
      fmt = positional[0]
      arg = positional[1]
      self.assertEqual(30.0, arg)
      self.assertIn('Next eta %s secs', fmt)
  def test_process_task(self):
      # _process_task submits the request to the pool exactly once.
      w = self.worker
      w.pool = Mock()
      msg = create_message(Mock(), task=self.foo_task.name,
                           args=[4, 8, 10], kwargs={})
      request = Request(msg.decode(), message=msg, app=self.app)
      w._process_task(request)
      self.assertEqual(w.pool.apply_async.call_count, 1)
      w.pool.stop()
  def test_process_task_raise_base(self):
      # A KeyboardInterrupt from the pool propagates out of _process_task.
      w = self.worker
      w.pool = Mock()
      w.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
      msg = create_message(Mock(), task=self.foo_task.name,
                           args=[4, 8, 10], kwargs={})
      request = Request(msg.decode(), message=msg, app=self.app)
      w.steps = []
      w.blueprint.state = RUN
      with self.assertRaises(KeyboardInterrupt):
          w._process_task(request)
  def test_process_task_raise_WorkerTerminate(self):
      # WorkerTerminate from the pool surfaces as SystemExit.
      w = self.worker
      w.pool = Mock()
      w.pool.apply_async.side_effect = WorkerTerminate()
      msg = create_message(Mock(), task=self.foo_task.name,
                           args=[4, 8, 10], kwargs={})
      request = Request(msg.decode(), message=msg, app=self.app)
      w.steps = []
      w.blueprint.state = RUN
      with self.assertRaises(SystemExit):
          w._process_task(request)
  def test_process_task_raise_regular(self):
      # An ordinary exception from the pool does not escape _process_task.
      w = self.worker
      w.pool = Mock()
      w.pool.apply_async.side_effect = KeyError('some exception')
      msg = create_message(Mock(), task=self.foo_task.name,
                           args=[4, 8, 10], kwargs={})
      request = Request(msg.decode(), message=msg, app=self.app)
      w._process_task(request)
      w.pool.stop()
  def test_start_catches_base_exceptions(self):
      # WorkerTerminate raised by a step: start() returns and the step's
      # terminate() is called.
      terminating = self.create_worker()
      terminating.blueprint.state = RUN
      term_step = MockStep()
      term_step.start.side_effect = WorkerTerminate()
      terminating.steps = [term_step]
      terminating.start()
      term_step.start.assert_called_with(terminating)
      self.assertTrue(term_step.terminate.call_count)

      # WorkerShutdown raised by a step without terminate(): the step's
      # stop() is called instead.
      shutting_down = self.create_worker()
      shutting_down.blueprint.state = RUN
      stop_step = MockStep()
      stop_step.start.side_effect = WorkerShutdown()
      stop_step.terminate = None
      shutting_down.steps = [stop_step]
      shutting_down.start()
      self.assertTrue(stop_step.stop.call_count)
  def test_state_db(self):
      # A state_db option gives the worker a persistence object.  Persistent
      # is replaced by a Mock only for the duration of the block.
      from celery.worker import state
      with patch.object(state, 'Persistent', Mock()):
          w = self.create_worker(state_db='statefilename')
          self.assertTrue(w._persistence)
  813. def test_process_task_sem(self):
  814. worker = self.worker
  815. worker._quick_acquire = Mock()
  816. req = Mock()
  817. worker._process_task_sem(req)
  818. worker._quick_acquire.assert_called_with(worker._process_task, req)
  819. def test_signal_consumer_close(self):
  820. worker = self.worker
  821. worker.consumer = Mock()
  822. worker.signal_consumer_close()
  823. worker.consumer.close.assert_called_with()
  824. worker.consumer.close.side_effect = AttributeError()
  825. worker.signal_consumer_close()
  826. def test_start__stop(self):
  827. worker = self.worker
  828. worker.blueprint.shutdown_complete.set()
  829. worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
  830. worker.blueprint.state = RUN
  831. worker.blueprint.started = 4
  832. for w in worker.steps:
  833. w.start = Mock()
  834. w.close = Mock()
  835. w.stop = Mock()
  836. worker.start()
  837. for w in worker.steps:
  838. self.assertTrue(w.start.call_count)
  839. worker.consumer = Mock()
  840. worker.stop()
  841. for stopstep in worker.steps:
  842. self.assertTrue(stopstep.close.call_count)
  843. self.assertTrue(stopstep.stop.call_count)
  844. # Doesn't close pool if no pool.
  845. worker.start()
  846. worker.pool = None
  847. worker.stop()
  848. # test that stop of None is not attempted
  849. worker.steps[-1] = None
  850. worker.start()
  851. worker.stop()
  852. def test_step_raises(self):
  853. worker = self.worker
  854. step = Mock()
  855. worker.steps = [step]
  856. step.start.side_effect = TypeError()
  857. worker.stop = Mock()
  858. worker.start()
  859. worker.stop.assert_called_with()
  860. def test_state(self):
  861. self.assertTrue(self.worker.state)
  862. def test_start__terminate(self):
  863. worker = self.worker
  864. worker.blueprint.shutdown_complete.set()
  865. worker.blueprint.started = 5
  866. worker.blueprint.state = RUN
  867. worker.steps = [MockStep() for _ in range(5)]
  868. worker.start()
  869. for w in worker.steps[:3]:
  870. self.assertTrue(w.start.call_count)
  871. self.assertTrue(worker.blueprint.started, len(worker.steps))
  872. self.assertEqual(worker.blueprint.state, RUN)
  873. worker.terminate()
  874. for step in worker.steps:
  875. self.assertTrue(step.terminate.call_count)
  876. def test_Queues_pool_no_sem(self):
  877. w = Mock()
  878. w.pool_cls.uses_semaphore = False
  879. components.Queues(w).create(w)
  880. self.assertIs(w.process_task, w._process_task)
  881. def test_Hub_crate(self):
  882. w = Mock()
  883. x = components.Hub(w)
  884. x.create(w)
  885. self.assertTrue(w.timer.max_interval)
  886. def test_Pool_crate_threaded(self):
  887. w = Mock()
  888. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  889. w.pool_cls = Mock()
  890. w.use_eventloop = False
  891. pool = components.Pool(w)
  892. pool.create(w)
  893. def test_Pool_create(self):
  894. from kombu.async.semaphore import LaxBoundedSemaphore
  895. w = Mock()
  896. w._conninfo.connection_errors = w._conninfo.channel_errors = ()
  897. w.hub = Mock()
  898. PoolImp = Mock()
  899. poolimp = PoolImp.return_value = Mock()
  900. poolimp._pool = [Mock(), Mock()]
  901. poolimp._cache = {}
  902. poolimp._fileno_to_inq = {}
  903. poolimp._fileno_to_outq = {}
  904. from celery.concurrency.prefork import TaskPool as _TaskPool
  905. class MockTaskPool(_TaskPool):
  906. Pool = PoolImp
  907. @property
  908. def timers(self):
  909. return {Mock(): 30}
  910. w.pool_cls = MockTaskPool
  911. w.use_eventloop = True
  912. w.consumer.restart_count = -1
  913. pool = components.Pool(w)
  914. pool.create(w)
  915. pool.register_with_event_loop(w, w.hub)
  916. self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
  917. P = w.pool
  918. P.start()