test_loops.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. from __future__ import absolute_import
  2. import errno
  3. import socket
  4. from kombu.async import Hub, READ, WRITE, ERR
  5. from celery.bootsteps import CLOSE, RUN
  6. from celery.exceptions import (
  7. InvalidTaskError, WorkerLostError, WorkerShutdown, WorkerTerminate,
  8. )
  9. from celery.five import Empty
  10. from celery.platforms import EX_FAILURE
  11. from celery.worker import state
  12. from celery.worker.consumer import Consumer
  13. from celery.worker.loops import _quick_drain, asynloop, synloop
  14. from celery.tests.case import AppCase, Mock, task_message_from_sig
  class X(object):
      """Mock bundle standing in for the arguments of the worker loops.

      ``self.args`` is the positional argument list the tests pass to
      ``asynloop``/``synloop``: obj, connection, consumer, blueprint, hub,
      qos, heartbeat and clock.  Each element is also exposed as an
      attribute of the same name.  The helper methods arm mocks so that a
      loop eventually stops (blueprint switched to CLOSE) or fails (an
      exception is raised).
      """

      def __init__(self, app, heartbeat=None, on_task_message=None,
                   transport_driver_type=None):
          hub = Hub()
          self.args = [
              Mock(name='obj'),
              Mock(name='connection'),
              Mock(name='consumer'),
              Mock(name='blueprint'),
              hub,
              Mock(name='qos'),
              heartbeat,
              Mock(name='clock'),
          ]
          (self.obj, self.connection, self.consumer, self.blueprint,
           self.hub, self.qos, self.heartbeat, self.clock) = self.args

          # Connection: heartbeat-capable, reports whatever heartbeat this
          # fixture was built with, and treats socket.error as a
          # connection error.
          conn = self.connection
          conn.supports_heartbeats = True
          conn.get_heartbeat_interval.side_effect = lambda: self.heartbeat
          conn.connection_errors = (socket.error,)
          if transport_driver_type:
              conn.transport.driver_type = transport_driver_type
          self.consumer.callbacks = []
          self.obj.strategies = {}

          # A real Hub instance, with its readers/writers, timer, poller
          # and close replaced by plain containers and mocks.
          hub.readers = {}
          hub.writers = {}
          hub.consolidate = set()
          hub.timer = Mock(name='hub.timer')
          hub.timer._queue = [Mock()]
          hub.fire_timers = Mock(name='hub.fire_timers')
          hub.fire_timers.return_value = 1.7
          hub.poller = Mock(name='hub.poller')
          hub.close = Mock(name='hub.close()')  # asynloop calls hub.close
          self.Hub = hub
          self.blueprint.state = RUN

          # A real Consumer is needed for create_task_handler.  It shares
          # the unknown/invalid handler mocks and the strategies dict with
          # self.obj.
          consumer = Consumer(Mock(), timer=Mock(), app=app)
          consumer.on_task_message = on_task_message or []
          self.obj.create_task_handler = consumer.create_task_handler
          for handler_name in ('on_unknown_message', 'on_unknown_task',
                               'on_invalid_task'):
              handler = Mock(name=handler_name)
              setattr(self, handler_name, handler)
              setattr(self.obj, handler_name, handler)
              setattr(consumer, handler_name, handler)
          consumer.strategies = self.obj.strategies

      def timeout_then_error(self, mock):
          """Arm *mock* to raise ``socket.timeout`` once, then ``socket.error``.

          The first call also sets ``connection.more_to_read`` to False.
          """
          def _first_call(*args, **kwargs):
              mock.side_effect = socket.error()
              self.connection.more_to_read = False
              raise socket.timeout()
          mock.side_effect = _first_call

      def close_then_error(self, mock=None, mod=0, exc=None):
          """Arm *mock* (a fresh Mock if None) to close the loop and raise.

          With ``mod`` falsy every call triggers.  Otherwise only calls made
          once ``mock.call_count`` exceeds ``mod`` trigger.  A triggering
          call sets the blueprint to CLOSE, clears
          ``connection.more_to_read`` and raises *exc* (``socket.error()``
          when *exc* is None).  Returns the armed mock.
          """
          if mock is None:
              mock = Mock()

          def _maybe_fail(*args, **kwargs):
              if mod and mock.call_count <= mod:
                  return None
              self.close()
              self.connection.more_to_read = False
              raise (socket.error() if exc is None else exc)
          mock.side_effect = _maybe_fail
          return mock

      def close(self, *args, **kwargs):
          """Switch the blueprint to CLOSE; any arguments are ignored."""
          self.blueprint.state = CLOSE

      def closer(self, mock=None, mod=0):
          """Arm *mock* (a fresh Mock if None) to close the loop, no raise.

          With ``mod`` falsy every call closes.  Otherwise only calls made
          once ``mock.call_count`` reaches ``mod`` close.  Returns the
          armed mock.
          """
          if mock is None:
              mock = Mock()

          def _maybe_close(*args, **kwargs):
              if mod and mock.call_count < mod:
                  return None
              self.close()
          mock.side_effect = _maybe_close
          return mock
  def get_task_callback(*args, **kwargs):
      """Run one setup pass of ``asynloop`` and return its message handler.

      All arguments are forwarded to :class:`X`.  The blueprint is set to
      CLOSE before ``asynloop`` runs.  Returns the fixture together with
      ``fixture.consumer.on_message`` as it stands afterwards; the tests
      use that callable as the task-message handler.
      """
      fixture = X(*args, **kwargs)
      fixture.blueprint.state = CLOSE
      asynloop(*fixture.args)
      return fixture, fixture.consumer.on_message
  class test_asynloop(AppCase):
      """Tests for ``asynloop`` driven through the :class:`X` fixture.

      Each test arms hub/connection mocks so the loop either exits
      (blueprint switched to CLOSE) or fails with a known exception.
      """

      def setup(self):
          @self.app.task(shared=False)
          def add(x, y):
              return x + y
          self.add = add

      def test_drain_after_consume(self):
          x, _ = get_task_callback(self.app, transport_driver_type='amqp')
          self.assertIn(
              _quick_drain, [p.fun for p in x.hub._ready],
          )

      def test_pool_did_not_start_at_startup(self):
          x = X(self.app)
          x.obj.restart_count = 0
          x.obj.pool.did_start_ok.return_value = False
          with self.assertRaises(WorkerLostError):
              asynloop(*x.args)

      def test_setup_heartbeat(self):
          x = X(self.app, heartbeat=10)
          x.hub.call_repeatedly = Mock(name='x.hub.call_repeatedly()')
          x.blueprint.state = CLOSE
          asynloop(*x.args)
          x.consumer.consume.assert_called_with()
          x.obj.on_ready.assert_called_with()
          x.hub.call_repeatedly.assert_called_with(
              10 / 2.0, x.connection.heartbeat_check, 2.0,
          )

      def task_context(self, sig, **kwargs):
          # Returns (fixture, on_message handler, message built from sig,
          # strategy mock registered under sig.task).
          x, on_task = get_task_callback(self.app, **kwargs)
          message = task_message_from_sig(self.app, sig)
          strategy = x.obj.strategies[sig.task] = Mock(name='strategy')
          return x, on_task, message, strategy

      def test_on_task_received(self):
          _, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
          on_task(msg)
          strategy.assert_called_with(
              msg, None, msg.ack_log_error, msg.reject_log_error, [],
          )

      def test_on_task_received_executes_on_task_message(self):
          cbs = [Mock(), Mock(), Mock()]
          _, on_task, msg, strategy = self.task_context(
              self.add.s(2, 2), on_task_message=cbs,
          )
          on_task(msg)
          strategy.assert_called_with(
              msg, None, msg.ack_log_error, msg.reject_log_error, cbs,
          )

      def test_on_task_message_missing_name(self):
          x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
          msg.headers.pop('task')
          on_task(msg)
          x.on_unknown_message.assert_called_with(msg.decode(), msg)

      def test_on_task_not_registered(self):
          x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
          exc = strategy.side_effect = KeyError(self.add.name)
          on_task(msg)
          x.on_invalid_task.assert_called_with(None, msg, exc)

      def test_on_task_InvalidTaskError(self):
          x, on_task, msg, strategy = self.task_context(self.add.s(2, 2))
          exc = strategy.side_effect = InvalidTaskError()
          on_task(msg)
          x.on_invalid_task.assert_called_with(None, msg, exc)

      def test_should_terminate(self):
          x = X(self.app)
          # XXX why aren't the errors propagated?!?
          state.should_terminate = True
          try:
              with self.assertRaises(WorkerTerminate):
                  asynloop(*x.args)
          finally:
              # Reset the shared worker state so other tests are unaffected.
              state.should_terminate = None

      def test_should_terminate_hub_close_raises(self):
          x = X(self.app)
          # XXX why aren't the errors propagated?!?
          state.should_terminate = EX_FAILURE
          x.hub.close.side_effect = MemoryError()
          try:
              with self.assertRaises(WorkerTerminate):
                  asynloop(*x.args)
          finally:
              state.should_terminate = None

      def test_should_stop(self):
          x = X(self.app)
          state.should_stop = 303
          try:
              with self.assertRaises(WorkerShutdown):
                  asynloop(*x.args)
          finally:
              state.should_stop = None

      def test_updates_qos(self):
          # Unchanged prefetch value: no QoS update.
          x = X(self.app)
          x.qos.prev = 3
          x.qos.value = 3
          x.hub.on_tick.add(x.closer(mod=2))
          x.hub.timer._queue = [1]
          asynloop(*x.args)
          self.assertFalse(x.qos.update.called)

          # Changed prefetch value: QoS is updated.
          x = X(self.app)
          x.qos.prev = 1
          x.qos.value = 6
          x.hub.on_tick.add(x.closer(mod=2))
          asynloop(*x.args)
          x.qos.update.assert_called_with()
          x.hub.fire_timers.assert_called_with(propagate=(socket.error,))

      def test_poll_empty(self):
          x = X(self.app)
          x.hub.readers = {6: Mock()}
          x.hub.timer._queue = [1]
          x.close_then_error(x.hub.poller.poll)
          x.hub.fire_timers.return_value = 33.37
          poller = x.hub.poller
          poller.poll.return_value = []
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          poller.poll.assert_called_with(33.37)

      def test_poll_readable(self):
          x = X(self.app)
          reader = Mock(name='reader')
          x.hub.add_reader(6, reader, 6)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), mod=4))
          poller = x.hub.poller
          poller.poll.return_value = [(6, READ)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          reader.assert_called_with(6)
          self.assertTrue(poller.poll.called)

      def test_poll_readable_raises_Empty(self):
          x = X(self.app)
          reader = Mock(name='reader')
          x.hub.add_reader(6, reader, 6)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          poller = x.hub.poller
          poller.poll.return_value = [(6, READ)]
          reader.side_effect = Empty()
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          reader.assert_called_with(6)
          self.assertTrue(poller.poll.called)

      def test_poll_writable(self):
          x = X(self.app)
          writer = Mock(name='writer')
          x.hub.add_writer(6, writer, 6)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          poller = x.hub.poller
          poller.poll.return_value = [(6, WRITE)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          writer.assert_called_with(6)
          self.assertTrue(poller.poll.called)

      def test_poll_writable_none_registered(self):
          x = X(self.app)
          writer = Mock(name='writer')
          x.hub.add_writer(6, writer, 6)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          poller = x.hub.poller
          # The event is for fd 7, which has no registered writer.
          poller.poll.return_value = [(7, WRITE)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          self.assertTrue(poller.poll.called)

      def test_poll_unknown_event(self):
          x = X(self.app)
          writer = Mock(name='writer')
          x.hub.add_writer(6, writer, 6)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          poller = x.hub.poller
          # Event mask 0 is neither READ, WRITE nor ERR.
          poller.poll.return_value = [(6, 0)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          self.assertTrue(poller.poll.called)

      def test_poll_keep_draining_disabled(self):
          x = X(self.app)
          x.hub.writers = {6: Mock()}
          poll = x.hub.poller.poll

          def se(*args, **kwargs):
              # First poll succeeds; every later poll raises.
              poll.side_effect = socket.error()
          poll.side_effect = se

          poller = x.hub.poller
          poll.return_value = [(6, 0)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          self.assertTrue(poller.poll.called)

      def test_poll_err_writable(self):
          x = X(self.app)
          writer = Mock(name='writer')
          x.hub.add_writer(6, writer, 6, 48)
          x.hub.on_tick.add(x.close_then_error(Mock(), 2))
          poller = x.hub.poller
          poller.poll.return_value = [(6, ERR)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          writer.assert_called_with(6, 48)
          self.assertTrue(poller.poll.called)

      def test_poll_write_generator(self):
          x = X(self.app)
          x.hub.remove = Mock(name='hub.remove()')

          def Gen():
              yield 1
              yield 2
          gen = Gen()

          x.hub.add_writer(6, gen)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          x.hub.poller.poll.return_value = [(6, WRITE)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          # The generator was advanced at least once and not yet exhausted.
          self.assertNotEqual(gen.gi_frame.f_lasti, -1)
          self.assertFalse(x.hub.remove.called)

      def test_poll_write_generator_stopped(self):
          x = X(self.app)

          def Gen():
              # End the generator on its first next() with a plain
              # ``return``.  Raising StopIteration here would be converted
              # into RuntimeError on Python 3.7+ (PEP 479) and never reach
              # the "stopped generator" path under test.
              return
              yield  # unreachable; only makes Gen a generator function
          gen = Gen()

          x.hub.add_writer(6, gen)
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          x.hub.poller.poll.return_value = [(6, WRITE)]
          x.hub.remove = Mock(name='hub.remove()')
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          self.assertIsNone(gen.gi_frame)

      def test_poll_write_generator_raises(self):
          x = X(self.app)

          def Gen():
              raise ValueError('foo')
              yield  # unreachable; only makes Gen a generator function
          gen = Gen()

          x.hub.add_writer(6, gen)
          x.hub.remove = Mock(name='hub.remove()')
          x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
          x.hub.poller.poll.return_value = [(6, WRITE)]
          with self.assertRaises(ValueError):
              asynloop(*x.args)
          self.assertIsNone(gen.gi_frame)
          x.hub.remove.assert_called_with(6)

      def test_poll_err_readable(self):
          x = X(self.app)
          reader = Mock(name='reader')
          x.hub.add_reader(6, reader, 6, 24)
          x.hub.on_tick.add(x.close_then_error(Mock(), 2))
          poller = x.hub.poller
          poller.poll.return_value = [(6, ERR)]
          with self.assertRaises(socket.error):
              asynloop(*x.args)
          reader.assert_called_with(6, 24)
          self.assertTrue(poller.poll.called)

      def test_poll_raises_ValueError(self):
          x = X(self.app)
          x.hub.readers = {6: Mock()}
          poller = x.hub.poller
          x.close_then_error(poller.poll, exc=ValueError)
          # No assertRaises: the loop is expected to exit without the
          # ValueError escaping.
          asynloop(*x.args)
          self.assertTrue(poller.poll.called)
  class test_synloop(AppCase):
      """Tests for the blocking ``synloop`` worker loop."""

      def _expect_socket_error(self, fixture):
          # Arm drain_events to time out once and then fail, run synloop,
          # and require that the socket.error escapes.
          fixture.timeout_then_error(fixture.connection.drain_events)
          with self.assertRaises(socket.error):
              synloop(*fixture.args)

      def test_timeout_ignored(self):
          fixture = X(self.app)
          self._expect_socket_error(fixture)
          # The timeout did not end the loop: drain_events ran a second time.
          self.assertEqual(fixture.connection.drain_events.call_count, 2)

      def test_updates_qos_when_changed(self):
          fixture = X(self.app)
          fixture.qos.prev = fixture.qos.value = 2
          self._expect_socket_error(fixture)
          self.assertFalse(fixture.qos.update.called)

          fixture.qos.value = 4
          self._expect_socket_error(fixture)
          fixture.qos.update.assert_called_with()

      def test_ignores_socket_errors_when_closed(self):
          fixture = X(self.app)
          fixture.close_then_error(fixture.connection.drain_events)
          self.assertIsNone(synloop(*fixture.args))
  class test_quick_drain(AppCase):
      """Tests for ``_quick_drain``: errno-based error filtering."""

      def setup(self):
          self.connection = Mock(name='connection')

      def test_drain(self):
          _quick_drain(self.connection, timeout=33.3)
          self.connection.drain_events.assert_called_with(timeout=33.3)

      def test_drain_error(self):
          # An error whose errno is not EAGAIN must propagate.
          exc = KeyError()
          exc.errno = 313
          self.connection.drain_events.side_effect = exc
          with self.assertRaises(KeyError):
              _quick_drain(self.connection, timeout=33.3)

      def test_drain_error_EAGAIN(self):
          # An error carrying errno EAGAIN must not propagate.
          exc = KeyError()
          exc.errno = errno.EAGAIN
          self.connection.drain_events.side_effect = exc
          _quick_drain(self.connection, timeout=33.3)
          # Without this the test would also pass if drain_events were never
          # called at all; pin that the EAGAIN came from a real drain attempt.
          self.connection.drain_events.assert_called_with(timeout=33.3)