test_worker.py 35 KB


  1. from __future__ import absolute_import, print_function, unicode_literals
  2. import os
  3. import socket
  4. import sys
  5. from collections import deque
  6. from datetime import datetime, timedelta
  7. from functools import partial
  8. from threading import Event
  9. import pytest
  10. from amqp import ChannelError
  11. from case import Mock, patch, skip
  12. from kombu import Connection
  13. from kombu.common import QoS, ignore_errors
  14. from kombu.transport.base import Message
  15. from kombu.transport.memory import Transport
  16. from kombu.utils.uuid import uuid
  17. from celery.bootsteps import CLOSE, RUN, TERMINATE, StartStopStep
  18. from celery.concurrency.base import BasePool
  19. from celery.exceptions import (ImproperlyConfigured, InvalidTaskError,
  20. TaskRevokedError, WorkerShutdown,
  21. WorkerTerminate)
  22. from celery.five import Empty
  23. from celery.five import Queue as FastQueue
  24. from celery.five import range
  25. from celery.platforms import EX_FAILURE
  26. from celery.utils.nodenames import worker_direct
  27. from celery.utils.serialization import pickle
  28. from celery.utils.timer2 import Timer
  29. from celery.worker import components, consumer, state
  30. from celery.worker import worker as worker_module
  31. from celery.worker.consumer import Consumer
  32. from celery.worker.pidbox import gPidbox
  33. from celery.worker.request import Request
  34. def MockStep(step=None):
  35. if step is None:
  36. step = Mock(name='step')
  37. else:
  38. step.blueprint = Mock(name='step.blueprint')
  39. step.blueprint.name = 'MockNS'
  40. step.name = 'MockStep(%s)' % (id(step),)
  41. return step
  42. def mock_event_dispatcher():
  43. evd = Mock(name='event_dispatcher')
  44. evd.groups = ['worker']
  45. evd._outbound_buffer = deque()
  46. return evd
  47. def find_step(obj, typ):
  48. return obj.blueprint.steps[typ.name]
  49. def create_message(channel, **data):
  50. data.setdefault('id', uuid())
  51. m = Message(body=pickle.dumps(dict(**data)),
  52. channel=channel,
  53. content_type='application/x-python-serialize',
  54. content_encoding='binary',
  55. delivery_info={'consumer_tag': 'mock'})
  56. m.accept = ['application/x-python-serialize']
  57. return m
  58. class ConsumerCase:
  59. def create_task_message(self, channel, *args, **kwargs):
  60. m = self.TaskMessage(*args, **kwargs)
  61. m.channel = channel
  62. m.delivery_info = {'consumer_tag': 'mock'}
  63. return m
  64. class test_Consumer(ConsumerCase):
  65. def setup(self):
  66. self.buffer = FastQueue()
  67. self.timer = Timer()
  68. @self.app.task(shared=False)
  69. def foo_task(x, y, z):
  70. return x * y * z
  71. self.foo_task = foo_task
  72. def teardown(self):
  73. self.timer.stop()
  74. def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
  75. without_mingle=True, without_gossip=True,
  76. without_heartbeat=True, **kwargs):
  77. if controller is None:
  78. controller = Mock(name='.controller')
  79. buffer = buffer if buffer is not None else self.buffer.put
  80. timer = timer if timer is not None else self.timer
  81. app = app if app is not None else self.app
  82. c = Consumer(
  83. buffer,
  84. timer=timer,
  85. app=app,
  86. controller=controller,
  87. without_mingle=without_mingle,
  88. without_gossip=without_gossip,
  89. without_heartbeat=without_heartbeat,
  90. **kwargs
  91. )
  92. c.task_consumer = Mock(name='.task_consumer')
  93. c.qos = QoS(c.task_consumer.qos, 10)
  94. c.connection = Mock(name='.connection')
  95. c.controller = c.app.WorkController()
  96. c.heart = Mock(name='.heart')
  97. c.controller.consumer = c
  98. c.pool = c.controller.pool = Mock(name='.controller.pool')
  99. c.node = Mock(name='.node')
  100. c.event_dispatcher = mock_event_dispatcher()
  101. return c
  102. def NoopConsumer(self, *args, **kwargs):
  103. c = self.LoopConsumer(*args, **kwargs)
  104. c.loop = Mock(name='.loop')
  105. return c
  106. def test_info(self):
  107. c = self.NoopConsumer()
  108. c.connection.info.return_value = {'foo': 'bar'}
  109. c.controller.pool.info.return_value = [Mock(), Mock()]
  110. info = c.controller.stats()
  111. assert info['prefetch_count'] == 10
  112. assert info['broker']
  113. def test_start_when_closed(self):
  114. c = self.NoopConsumer()
  115. c.blueprint.state = CLOSE
  116. c.start()
  117. def test_connection(self):
  118. c = self.NoopConsumer()
  119. c.blueprint.start(c)
  120. assert isinstance(c.connection, Connection)
  121. c.blueprint.state = RUN
  122. c.event_dispatcher = None
  123. c.blueprint.restart(c)
  124. assert c.connection
  125. c.blueprint.state = RUN
  126. c.shutdown()
  127. assert c.connection is None
  128. assert c.task_consumer is None
  129. c.blueprint.start(c)
  130. assert isinstance(c.connection, Connection)
  131. c.blueprint.restart(c)
  132. c.stop()
  133. c.shutdown()
  134. assert c.connection is None
  135. assert c.task_consumer is None
  136. def test_close_connection(self):
  137. c = self.NoopConsumer()
  138. c.blueprint.state = RUN
  139. step = find_step(c, consumer.Connection)
  140. connection = c.connection
  141. step.shutdown(c)
  142. connection.close.assert_called()
  143. assert c.connection is None
  144. def test_close_connection__heart_shutdown(self):
  145. c = self.NoopConsumer()
  146. event_dispatcher = c.event_dispatcher
  147. heart = c.heart
  148. c.event_dispatcher.enabled = True
  149. c.blueprint.state = RUN
  150. Events = find_step(c, consumer.Events)
  151. Events.shutdown(c)
  152. Heart = find_step(c, consumer.Heart)
  153. Heart.shutdown(c)
  154. event_dispatcher.close.assert_called()
  155. heart.stop.assert_called_with()
  156. @patch('celery.worker.consumer.consumer.warn')
  157. def test_receive_message_unknown(self, warn):
  158. c = self.LoopConsumer()
  159. c.blueprint.state = RUN
  160. c.steps.pop()
  161. channel = Mock(name='.channeol')
  162. m = create_message(channel, unknown={'baz': '!!!'})
  163. callback = self._get_on_message(c)
  164. callback(m)
  165. warn.assert_called()
  166. @patch('celery.worker.strategy.to_timestamp')
  167. def test_receive_message_eta_OverflowError(self, to_timestamp):
  168. to_timestamp.side_effect = OverflowError()
  169. c = self.LoopConsumer()
  170. c.blueprint.state = RUN
  171. c.steps.pop()
  172. m = self.create_task_message(
  173. Mock(), self.foo_task.name,
  174. args=('2, 2'), kwargs={},
  175. eta=datetime.now().isoformat(),
  176. )
  177. c.update_strategies()
  178. callback = self._get_on_message(c)
  179. callback(m)
  180. assert m.acknowledged
  181. @patch('celery.worker.consumer.consumer.error')
  182. def test_receive_message_InvalidTaskError(self, error):
  183. c = self.LoopConsumer()
  184. c.blueprint.state = RUN
  185. c.steps.pop()
  186. m = self.create_task_message(
  187. Mock(), self.foo_task.name,
  188. args=(1, 2), kwargs='foobarbaz', id=1)
  189. c.update_strategies()
  190. strat = c.strategies[self.foo_task.name] = Mock(name='strategy')
  191. strat.side_effect = InvalidTaskError()
  192. callback = self._get_on_message(c)
  193. callback(m)
  194. error.assert_called()
  195. assert 'Received invalid task message' in error.call_args[0][0]
  196. @patch('celery.worker.consumer.consumer.crit')
  197. def test_on_decode_error(self, crit):
  198. c = self.LoopConsumer()
  199. class MockMessage(Mock):
  200. content_type = 'application/x-msgpack'
  201. content_encoding = 'binary'
  202. body = 'foobarbaz'
  203. message = MockMessage()
  204. c.on_decode_error(message, KeyError('foo'))
  205. assert message.ack.call_count
  206. assert "Can't decode message body" in crit.call_args[0][0]
  207. def _get_on_message(self, c):
  208. if c.qos is None:
  209. c.qos = Mock()
  210. c.task_consumer = Mock()
  211. c.event_dispatcher = mock_event_dispatcher()
  212. c.connection = Mock(name='.connection')
  213. c.connection.get_heartbeat_interval.return_value = 0
  214. c.connection.drain_events.side_effect = WorkerShutdown()
  215. with pytest.raises(WorkerShutdown):
  216. c.loop(*c.loop_args())
  217. assert c.task_consumer.on_message
  218. return c.task_consumer.on_message
  219. def test_receieve_message(self):
  220. c = self.LoopConsumer()
  221. c.blueprint.state = RUN
  222. m = self.create_task_message(
  223. Mock(), self.foo_task.name,
  224. args=[2, 4, 8], kwargs={},
  225. )
  226. c.update_strategies()
  227. callback = self._get_on_message(c)
  228. callback(m)
  229. in_bucket = self.buffer.get_nowait()
  230. assert isinstance(in_bucket, Request)
  231. assert in_bucket.name == self.foo_task.name
  232. assert in_bucket.execute() == 2 * 4 * 8
  233. assert self.timer.empty()
  234. def test_start_channel_error(self):
  235. c = self.NoopConsumer(task_events=False, pool=BasePool())
  236. c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
  237. c.channel_errors = (KeyError,)
  238. try:
  239. with pytest.raises(KeyError):
  240. c.start()
  241. finally:
  242. c.timer and c.timer.stop()
  243. def test_start_connection_error(self):
  244. c = self.NoopConsumer(task_events=False, pool=BasePool())
  245. c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
  246. c.connection_errors = (KeyError,)
  247. try:
  248. with pytest.raises(SyntaxError):
  249. c.start()
  250. finally:
  251. c.timer and c.timer.stop()
  252. def test_loop_ignores_socket_timeout(self):
  253. class Connection(self.app.connection_for_read().__class__):
  254. obj = None
  255. def drain_events(self, **kwargs):
  256. self.obj.connection = None
  257. raise socket.timeout(10)
  258. c = self.NoopConsumer()
  259. c.connection = Connection(self.app.conf.broker_url)
  260. c.connection.obj = c
  261. c.qos = QoS(c.task_consumer.qos, 10)
  262. c.loop(*c.loop_args())
  263. def test_loop_when_socket_error(self):
  264. class Connection(self.app.connection_for_read().__class__):
  265. obj = None
  266. def drain_events(self, **kwargs):
  267. self.obj.connection = None
  268. raise socket.error('foo')
  269. c = self.LoopConsumer()
  270. c.blueprint.state = RUN
  271. conn = c.connection = Connection(self.app.conf.broker_url)
  272. c.connection.obj = c
  273. c.qos = QoS(c.task_consumer.qos, 10)
  274. with pytest.raises(socket.error):
  275. c.loop(*c.loop_args())
  276. c.blueprint.state = CLOSE
  277. c.connection = conn
  278. c.loop(*c.loop_args())
  279. def test_loop(self):
  280. class Connection(self.app.connection_for_read().__class__):
  281. obj = None
  282. def drain_events(self, **kwargs):
  283. self.obj.connection = None
  284. @property
  285. def supports_heartbeats(self):
  286. return False
  287. c = self.LoopConsumer()
  288. c.blueprint.state = RUN
  289. c.connection = Connection(self.app.conf.broker_url)
  290. c.connection.obj = c
  291. c.connection.get_heartbeat_interval = Mock(return_value=None)
  292. c.qos = QoS(c.task_consumer.qos, 10)
  293. c.loop(*c.loop_args())
  294. c.loop(*c.loop_args())
  295. assert c.task_consumer.consume.call_count
  296. c.task_consumer.qos.assert_called_with(prefetch_count=10)
  297. assert c.qos.value == 10
  298. c.qos.decrement_eventually()
  299. assert c.qos.value == 9
  300. c.qos.update()
  301. assert c.qos.value == 9
  302. c.task_consumer.qos.assert_called_with(prefetch_count=9)
  303. def test_ignore_errors(self):
  304. c = self.NoopConsumer()
  305. c.connection_errors = (AttributeError, KeyError,)
  306. c.channel_errors = (SyntaxError,)
  307. ignore_errors(c, Mock(side_effect=AttributeError('foo')))
  308. ignore_errors(c, Mock(side_effect=KeyError('foo')))
  309. ignore_errors(c, Mock(side_effect=SyntaxError('foo')))
  310. with pytest.raises(IndexError):
  311. ignore_errors(c, Mock(side_effect=IndexError('foo')))
  312. def test_apply_eta_task(self):
  313. c = self.NoopConsumer()
  314. c.qos = QoS(None, 10)
  315. task = Mock(name='task', id='1234213')
  316. qos = c.qos.value
  317. c.apply_eta_task(task)
  318. assert task in state.reserved_requests
  319. assert c.qos.value == qos - 1
  320. assert self.buffer.get_nowait() is task
  321. def test_receieve_message_eta_isoformat(self):
  322. c = self.LoopConsumer()
  323. c.blueprint.state = RUN
  324. c.steps.pop()
  325. m = self.create_task_message(
  326. Mock(), self.foo_task.name,
  327. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  328. args=[2, 4, 8], kwargs={},
  329. )
  330. c.qos = QoS(c.task_consumer.qos, 1)
  331. current_pcount = c.qos.value
  332. c.event_dispatcher.enabled = False
  333. c.update_strategies()
  334. callback = self._get_on_message(c)
  335. callback(m)
  336. c.timer.stop()
  337. c.timer.join(1)
  338. items = [entry[2] for entry in self.timer.queue]
  339. found = 0
  340. for item in items:
  341. if item.args[0].name == self.foo_task.name:
  342. found = True
  343. assert found
  344. assert c.qos.value > current_pcount
  345. c.timer.stop()
  346. def test_pidbox_callback(self):
  347. c = self.NoopConsumer()
  348. con = find_step(c, consumer.Control).box
  349. con.node = Mock()
  350. con.reset = Mock()
  351. con.on_message('foo', 'bar')
  352. con.node.handle_message.assert_called_with('foo', 'bar')
  353. con.node = Mock()
  354. con.node.handle_message.side_effect = KeyError('foo')
  355. con.on_message('foo', 'bar')
  356. con.node.handle_message.assert_called_with('foo', 'bar')
  357. con.node = Mock()
  358. con.node.handle_message.side_effect = ValueError('foo')
  359. con.on_message('foo', 'bar')
  360. con.node.handle_message.assert_called_with('foo', 'bar')
  361. con.reset.assert_called()
  362. def test_revoke(self):
  363. c = self.LoopConsumer()
  364. c.blueprint.state = RUN
  365. c.steps.pop()
  366. channel = Mock(name='channel')
  367. id = uuid()
  368. t = self.create_task_message(
  369. channel, self.foo_task.name,
  370. args=[2, 4, 8], kwargs={}, id=id,
  371. )
  372. state.revoked.add(id)
  373. callback = self._get_on_message(c)
  374. callback(t)
  375. assert self.buffer.empty()
  376. def test_receieve_message_not_registered(self):
  377. c = self.LoopConsumer()
  378. c.blueprint.state = RUN
  379. c.steps.pop()
  380. channel = Mock(name='channel')
  381. m = self.create_task_message(
  382. channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
  383. )
  384. callback = self._get_on_message(c)
  385. assert not callback(m)
  386. with pytest.raises(Empty):
  387. self.buffer.get_nowait()
  388. assert self.timer.empty()
  389. @patch('celery.worker.consumer.consumer.warn')
  390. @patch('celery.worker.consumer.consumer.logger')
  391. def test_receieve_message_ack_raises(self, logger, warn):
  392. c = self.LoopConsumer()
  393. c.blueprint.state = RUN
  394. channel = Mock(name='channel')
  395. m = self.create_task_message(
  396. channel, self.foo_task.name,
  397. args=[2, 4, 8], kwargs={},
  398. )
  399. m.headers = None
  400. c.update_strategies()
  401. c.connection_errors = (socket.error,)
  402. m.reject = Mock()
  403. m.reject.side_effect = socket.error('foo')
  404. callback = self._get_on_message(c)
  405. assert not callback(m)
  406. warn.assert_called()
  407. with pytest.raises(Empty):
  408. self.buffer.get_nowait()
  409. assert self.timer.empty()
  410. m.reject_log_error.assert_called_with(logger, c.connection_errors)
  411. def test_receive_message_eta(self):
  412. if os.environ.get('C_DEBUG_TEST'):
  413. pp = partial(print, file=sys.__stderr__)
  414. else:
  415. def pp(*args, **kwargs):
  416. pass
  417. pp('TEST RECEIVE MESSAGE ETA')
  418. pp('+CREATE MYKOMBUCONSUMER')
  419. c = self.LoopConsumer()
  420. pp('-CREATE MYKOMBUCONSUMER')
  421. c.steps.pop()
  422. channel = Mock(name='channel')
  423. pp('+ CREATE MESSAGE')
  424. m = self.create_task_message(
  425. channel, self.foo_task.name,
  426. args=[2, 4, 8], kwargs={},
  427. eta=(datetime.now() + timedelta(days=1)).isoformat(),
  428. )
  429. pp('- CREATE MESSAGE')
  430. try:
  431. pp('+ BLUEPRINT START 1')
  432. c.blueprint.start(c)
  433. pp('- BLUEPRINT START 1')
  434. p = c.app.conf.broker_connection_retry
  435. c.app.conf.broker_connection_retry = False
  436. pp('+ BLUEPRINT START 2')
  437. c.blueprint.start(c)
  438. pp('- BLUEPRINT START 2')
  439. c.app.conf.broker_connection_retry = p
  440. pp('+ BLUEPRINT RESTART')
  441. c.blueprint.restart(c)
  442. pp('- BLUEPRINT RESTART')
  443. pp('+ GET ON MESSAGE')
  444. callback = self._get_on_message(c)
  445. pp('- GET ON MESSAGE')
  446. pp('+ CALLBACK')
  447. callback(m)
  448. pp('- CALLBACK')
  449. finally:
  450. pp('+ STOP TIMER')
  451. c.timer.stop()
  452. pp('- STOP TIMER')
  453. try:
  454. pp('+ JOIN TIMER')
  455. c.timer.join()
  456. pp('- JOIN TIMER')
  457. except RuntimeError:
  458. pass
  459. in_hold = c.timer.queue[0]
  460. assert len(in_hold) == 3
  461. eta, priority, entry = in_hold
  462. task = entry.args[0]
  463. assert isinstance(task, Request)
  464. assert task.name == self.foo_task.name
  465. assert task.execute() == 2 * 4 * 8
  466. with pytest.raises(Empty):
  467. self.buffer.get_nowait()
  468. def test_reset_pidbox_node(self):
  469. c = self.NoopConsumer()
  470. con = find_step(c, consumer.Control).box
  471. con.node = Mock()
  472. chan = con.node.channel = Mock()
  473. chan.close.side_effect = socket.error('foo')
  474. c.connection_errors = (socket.error,)
  475. con.reset()
  476. chan.close.assert_called_with()
  477. def test_reset_pidbox_node_green(self):
  478. c = self.NoopConsumer(pool=Mock(is_green=True))
  479. con = find_step(c, consumer.Control)
  480. assert isinstance(con.box, gPidbox)
  481. con.start(c)
  482. c.pool.spawn_n.assert_called_with(con.box.loop, c)
  483. def test_green_pidbox_node(self):
  484. pool = Mock()
  485. pool.is_green = True
  486. c = self.NoopConsumer(pool=Mock(is_green=True))
  487. controller = find_step(c, consumer.Control)
  488. class BConsumer(Mock):
  489. def __enter__(self):
  490. self.consume()
  491. return self
  492. def __exit__(self, *exc_info):
  493. self.cancel()
  494. controller.box.node.listen = BConsumer()
  495. connections = []
  496. class Connection(object):
  497. calls = 0
  498. def __init__(self, obj):
  499. connections.append(self)
  500. self.obj = obj
  501. self.default_channel = self.channel()
  502. self.closed = False
  503. def __enter__(self):
  504. return self
  505. def __exit__(self, *exc_info):
  506. self.close()
  507. def channel(self):
  508. return Mock()
  509. def as_uri(self):
  510. return 'dummy://'
  511. def drain_events(self, **kwargs):
  512. if not self.calls:
  513. self.calls += 1
  514. raise socket.timeout()
  515. self.obj.connection = None
  516. controller.box._node_shutdown.set()
  517. def close(self):
  518. self.closed = True
  519. c.connection_for_read = lambda: Connection(obj=c)
  520. controller = find_step(c, consumer.Control)
  521. controller.box.loop(c)
  522. controller.box.node.listen.assert_called()
  523. assert controller.box.consumer
  524. controller.box.consumer.consume.assert_called_with()
  525. assert c.connection is None
  526. assert connections[0].closed
  527. @patch('kombu.connection.Connection._establish_connection')
  528. @patch('kombu.utils.functional.sleep')
  529. def test_connect_errback(self, sleep, connect):
  530. c = self.NoopConsumer()
  531. Transport.connection_errors = (ChannelError,)
  532. connect.on_nth_call_do(ChannelError('error'), n=1)
  533. c.connect()
  534. connect.assert_called_with()
  535. def test_stop_pidbox_node(self):
  536. c = self.NoopConsumer()
  537. cont = find_step(c, consumer.Control)
  538. cont._node_stopped = Event()
  539. cont._node_shutdown = Event()
  540. cont._node_stopped.set()
  541. cont.stop(c)
  542. def test_start__loop(self):
  543. class _QoS(object):
  544. prev = 3
  545. value = 4
  546. def update(self):
  547. self.prev = self.value
  548. init_callback = Mock(name='init_callback')
  549. c = self.NoopConsumer(init_callback=init_callback)
  550. c.qos = _QoS()
  551. c.connection = Connection(self.app.conf.broker_url)
  552. c.connection.get_heartbeat_interval = Mock(return_value=None)
  553. c.iterations = 0
  554. def raises_KeyError(*args, **kwargs):
  555. c.iterations += 1
  556. if c.qos.prev != c.qos.value:
  557. c.qos.update()
  558. if c.iterations >= 2:
  559. raise KeyError('foo')
  560. c.loop = raises_KeyError
  561. with pytest.raises(KeyError):
  562. c.start()
  563. assert c.iterations == 2
  564. assert c.qos.prev == c.qos.value
  565. init_callback.reset_mock()
  566. c = self.NoopConsumer(task_events=False, init_callback=init_callback)
  567. c.qos = _QoS()
  568. c.connection = Connection(self.app.conf.broker_url)
  569. c.connection.get_heartbeat_interval = Mock(return_value=None)
  570. c.loop = Mock(side_effect=socket.error('foo'))
  571. with pytest.raises(socket.error):
  572. c.start()
  573. c.loop.assert_called()
  574. def test_reset_connection_with_no_node(self):
  575. c = self.NoopConsumer()
  576. c.steps.pop()
  577. c.blueprint.start(c)
class test_WorkController(ConsumerCase):
    """Unit tests for the worker ``WorkController`` and its components.

    ``self.app`` is not defined in this class; it is presumably supplied by
    the test harness (TODO confirm in conftest).
    """

    def setup(self):
        """Create a worker, swap in mock loggers and define ``foo_task``."""
        self.worker = self.create_worker()
        # Keep the real loggers so teardown can restore them.
        self._logger = worker_module.logger
        self._comp_logger = components.logger
        self.logger = worker_module.logger = Mock()
        self.comp_logger = components.logger = Mock()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        """Restore the module loggers replaced in :meth:`setup`."""
        worker_module.logger = self._logger
        components.logger = self._comp_logger

    def create_worker(self, **kw):
        """Return a single-process WorkController with shutdown pre-completed."""
        worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
        worker.blueprint.shutdown_complete.set()
        return worker

    def test_on_consumer_ready(self):
        """``on_consumer_ready`` accepts a consumer without raising."""
        self.worker.on_consumer_ready(Mock())

    def test_setup_queues_worker_direct(self):
        """With ``worker_direct`` enabled the worker's direct queue is added."""
        self.app.conf.worker_direct = True
        self.app.amqp.__dict__['queues'] = Mock()
        self.worker.setup_queues({})
        self.app.amqp.queues.select_add.assert_called_with(
            worker_direct(self.worker.hostname),
        )

    def test_setup_queues__missing_queue(self):
        """Unknown queue names surface as ImproperlyConfigured."""
        self.app.amqp.queues.select = Mock(name='select')
        self.app.amqp.queues.deselect = Mock(name='deselect')
        self.app.amqp.queues.select.side_effect = KeyError()
        self.app.amqp.queues.deselect.side_effect = KeyError()
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')
        # Now only ``deselect`` fails.
        self.app.amqp.queues.select = Mock(name='select')
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')

    def test_send_worker_shutdown(self):
        """``_send_worker_shutdown`` sends the worker_shutdown signal."""
        with patch('celery.signals.worker_shutdown') as ws:
            self.worker._send_worker_shutdown()
            ws.send.assert_called_with(sender=self.worker)

    @skip.todo('unstable test')
    def test_process_shutdown_on_worker_shutdown(self):
        """Pool process exit sends worker_process_shutdown and calls os._exit."""
        from celery.concurrency.prefork import process_destructor
        from celery.concurrency.asynpool import Worker
        with patch('celery.signals.worker_process_shutdown') as ws:
            with patch('os._exit') as _exit:
                worker = Worker(None, None, on_exit=process_destructor)
                worker._do_exit(22, 3.1415926)
                ws.send.assert_called_with(
                    sender=None, pid=22, exitcode=3.1415926,
                )
                _exit.assert_called_with(3.1415926)

    def test_process_task_revoked_release_semaphore(self):
        """A revoked task releases the semaphore via ``_quick_release``."""
        self.worker._quick_release = Mock()
        req = Mock()
        req.execute_using_pool.side_effect = TaskRevokedError
        self.worker._process_task(req)
        self.worker._quick_release.assert_called_with()

        # Without ``_quick_release`` the call must still not raise.
        delattr(self.worker, '_quick_release')
        self.worker._process_task(req)

    def test_shutdown_no_blueprint(self):
        """``_shutdown`` tolerates a missing blueprint."""
        self.worker.blueprint = None
        self.worker._shutdown()

    @patch('celery.worker.worker.create_pidlock')
    def test_use_pidfile(self, create_pidlock):
        """With a pidfile, start() creates a pidlock and stop() releases it."""
        create_pidlock.return_value = Mock()
        worker = self.create_worker(pidfile='pidfilelockfilepid')
        worker.steps = []
        worker.start()
        create_pidlock.assert_called()
        worker.stop()
        worker.pidlock.release.assert_called()

    def test_attrs(self):
        """A new worker has a timer, pool, consumer and steps."""
        worker = self.worker
        assert worker.timer is not None
        assert isinstance(worker.timer, Timer)
        assert worker.pool is not None
        assert worker.consumer is not None
        assert worker.steps

    def test_with_embedded_beat(self):
        """``beat=True`` adds a beat service to the worker's steps."""
        worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
        assert worker.beat
        assert worker.beat in [w.obj for w in worker.steps]

    def test_with_autoscaler(self):
        """Passing ``autoscale`` gives the worker an autoscaler."""
        worker = self.create_worker(
            autoscale=[10, 3], send_events=False,
            timer_cls='celery.utils.timer2.Timer',
        )
        assert worker.autoscaler

    def test_dont_stop_or_terminate(self):
        """stop()/terminate() leave the blueprint out of CLOSE in these cases."""
        worker = self.app.WorkController(concurrency=1, loglevel=0)
        worker.stop()
        assert worker.blueprint.state != CLOSE
        worker.terminate()
        assert worker.blueprint.state != CLOSE

        # From a signal handler with a non-signal-safe pool.
        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
        try:
            worker.blueprint.state = RUN
            worker.stop(in_sighandler=True)
            assert worker.blueprint.state != CLOSE
            worker.terminate(in_sighandler=True)
            assert worker.blueprint.state != CLOSE
        finally:
            worker.pool.signal_safe = sigsafe

    def test_on_timer_error(self):
        """Timer errors are logged including the exception type name."""
        worker = self.app.WorkController(concurrency=1, loglevel=0)

        try:
            raise KeyError('foo')
        except KeyError as exc:
            components.Timer(worker).on_timer_error(exc)
            msg, args = self.comp_logger.error.call_args[0]
            assert 'KeyError' in msg % args

    def test_on_timer_tick(self):
        """Timer ticks are logged at debug level with the next ETA."""
        worker = self.app.WorkController(concurrency=1, loglevel=10)

        components.Timer(worker).on_timer_tick(30.0)
        xargs = self.comp_logger.debug.call_args[0]
        fmt, arg = xargs[0], xargs[1]
        assert arg == 30.0
        assert 'Next ETA %s secs' in fmt

    def test_process_task(self):
        """``_process_task`` hands the request to the pool once."""
        worker = self.worker
        worker.pool = Mock()
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker._process_task(task)
        assert worker.pool.apply_async.call_count == 1
        worker.pool.stop()

    def test_process_task_raise_base(self):
        """KeyboardInterrupt from the pool propagates."""
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker.steps = []
        worker.blueprint.state = RUN
        with pytest.raises(KeyboardInterrupt):
            worker._process_task(task)

    def test_process_task_raise_WorkerTerminate(self):
        """WorkerTerminate from the pool surfaces as SystemExit."""
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = WorkerTerminate()
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker.steps = []
        worker.blueprint.state = RUN
        with pytest.raises(SystemExit):
            worker._process_task(task)

    def test_process_task_raise_regular(self):
        """An ordinary exception from the pool propagates unchanged."""
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = KeyError('some exception')
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        with pytest.raises(KeyError):
            worker._process_task(task)
        worker.pool.stop()

    def test_start_catches_base_exceptions(self):
        """WorkerTerminate/WorkerShutdown from a step lead to terminate/stop."""
        worker1 = self.create_worker()
        worker1.blueprint.state = RUN
        stc = MockStep()
        stc.start.side_effect = WorkerTerminate()
        worker1.steps = [stc]
        worker1.start()
        stc.start.assert_called_with(worker1)
        assert stc.terminate.call_count

        worker2 = self.create_worker()
        worker2.blueprint.state = RUN
        sec = MockStep()
        sec.start.side_effect = WorkerShutdown()
        sec.terminate = None
        worker2.steps = [sec]
        worker2.start()
        assert sec.stop.call_count

    def test_statedb(self):
        """Passing ``statedb`` enables persistence (Persistent is mocked)."""
        from celery.worker import state
        Persistent = state.Persistent

        state.Persistent = Mock()
        try:
            worker = self.create_worker(statedb='statefilename')
            assert worker._persistence
        finally:
            state.Persistent = Persistent

    def test_process_task_sem(self):
        """``_process_task_sem`` acquires via ``_quick_acquire``."""
        worker = self.worker
        worker._quick_acquire = Mock()

        req = Mock()
        worker._process_task_sem(req)
        worker._quick_acquire.assert_called_with(worker._process_task, req)

    def test_signal_consumer_close(self):
        """``signal_consumer_close`` closes the consumer, ignoring AttributeError."""
        worker = self.worker
        worker.consumer = Mock()

        worker.signal_consumer_close()
        worker.consumer.close.assert_called_with()

        worker.consumer.close.side_effect = AttributeError()
        worker.signal_consumer_close()

    def test_rusage__no_resource(self):
        """Without ``resource``, rusage() raises but stats() still works."""
        from celery.worker import worker
        prev, worker.resource = worker.resource, None
        try:
            self.worker.pool = Mock(name='pool')
            with pytest.raises(NotImplementedError):
                self.worker.rusage()
            self.worker.stats()
        finally:
            worker.resource = prev

    def test_repr(self):
        """repr() of a worker is non-empty."""
        assert repr(self.worker)

    def test_str(self):
        """str() of a worker is its hostname."""
        assert str(self.worker) == self.worker.hostname

    def test_start__stop(self):
        """start() starts every step; stop() closes and stops them."""
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
        worker.blueprint.state = RUN
        worker.blueprint.started = 4
        for w in worker.steps:
            w.start = Mock()
            w.close = Mock()
            w.stop = Mock()

        worker.start()
        for w in worker.steps:
            w.start.assert_called()
        worker.consumer = Mock()
        worker.stop(exitcode=3)
        for stopstep in worker.steps:
            stopstep.close.assert_called()
            stopstep.stop.assert_called()

        # Doesn't close pool if no pool.
        worker.start()
        worker.pool = None
        worker.stop()

        # test that stop of None is not attempted
        worker.steps[-1] = None
        worker.start()
        worker.stop()

    def test_start__KeyboardInterrupt(self):
        """KeyboardInterrupt during start() stops with EX_FAILURE."""
        worker = self.worker
        worker.blueprint = Mock(name='blueprint')
        worker.blueprint.start.side_effect = KeyboardInterrupt()
        worker.stop = Mock(name='stop')
        worker.start()
        worker.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_register_with_event_loop(self):
        """``register_with_event_loop`` forwards the hub to all steps."""
        worker = self.worker
        hub = Mock(name='hub')
        worker.blueprint = Mock(name='blueprint')
        worker.register_with_event_loop(hub)
        worker.blueprint.send_all.assert_called_with(
            worker, 'register_with_event_loop', args=(hub,),
            description='hub.register',
        )

    def test_step_raises(self):
        """An ordinary exception from a step stops with EX_FAILURE."""
        worker = self.worker
        step = Mock()
        worker.steps = [step]
        step.start.side_effect = TypeError()
        worker.stop = Mock()
        worker.start()
        worker.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_state(self):
        """The worker exposes a truthy ``state``."""
        assert self.worker.state

    def test_start__terminate(self):
        """terminate() calls terminate on every step; repeating it is safe."""
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.blueprint.started = 5
        worker.blueprint.state = RUN
        worker.steps = [MockStep() for _ in range(5)]
        worker.start()
        for w in worker.steps[:3]:
            w.start.assert_called()
        assert worker.blueprint.started == len(worker.steps)
        assert worker.blueprint.state == RUN
        worker.terminate()
        for step in worker.steps:
            step.terminate.assert_called()
        worker.blueprint.state = TERMINATE
        worker.terminate()

    def test_Hub_create(self):
        """Creating the Hub component sets the timer's ``max_interval``."""
        w = Mock()
        x = components.Hub(w)
        x.create(w)
        assert w.timer.max_interval

    def test_Pool_create_threaded(self):
        """The Pool component can be created without an event loop."""
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.pool_cls = Mock()
        w.use_eventloop = False
        pool = components.Pool(w)
        pool.create(w)

    def test_Pool_pool_no_sem(self):
        """Without a semaphore, ``process_task`` is ``_process_task`` itself."""
        w = Mock()
        w.pool_cls.uses_semaphore = False
        components.Pool(w).create(w)
        assert w.process_task is w._process_task

    def test_Pool_create(self):
        """With an event loop, a prefork-based pool uses a LaxBoundedSemaphore."""
        from kombu.asynchronous.semaphore import LaxBoundedSemaphore
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.hub = Mock()

        PoolImp = Mock()
        poolimp = PoolImp.return_value = Mock()
        poolimp._pool = [Mock(), Mock()]
        poolimp._cache = {}
        poolimp._fileno_to_inq = {}
        poolimp._fileno_to_outq = {}

        from celery.concurrency.prefork import TaskPool as _TaskPool

        class MockTaskPool(_TaskPool):
            Pool = PoolImp

            @property
            def timers(self):
                return {Mock(): 30}

        w.pool_cls = MockTaskPool
        w.use_eventloop = True
        w.consumer.restart_count = -1
        pool = components.Pool(w)
        pool.create(w)
        pool.register_with_event_loop(w, w.hub)
        if sys.platform != 'win32':
            assert isinstance(w.semaphore, LaxBoundedSemaphore)
            P = w.pool
            P.start()