# test_prefork.py

import errno
import os
import socket
from itertools import cycle

import pytest
from case import Mock, mock, patch, skip

from celery.app.defaults import DEFAULTS
from celery.utils.collections import AttributeDict
from celery.utils.functional import noop
from celery.utils.objects import Bunch
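
# The real prefork pool needs multiprocessing; fall back to inert stubs so
# this module can still be imported on platforms where it is unavailable.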
try:
    from celery.concurrency import prefork as mp
    from celery.concurrency import asynpool
except ImportError:

    class _mp:
        RUN = 0x1

        class TaskPool:
            _pool = Mock()

            def __init__(self, *args, **kwargs):
                pass

            def start(self):
                pass

            def stop(self):
                pass

            def apply_async(self, *args, **kwargs):
                pass
    mp = _mp()  # noqa
    asynpool = None  # noqa
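

# Stand-in for an async result: records the return value and the pid of the
# worker that "executed" the task.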
class MockResult:

    def __init__(self, value, pid):
        self.value = value
        self.pid = pid

    def worker_pids(self):
        return [self.pid]

    def get(self):
        return self.value


class test_process_initializer:

    @patch('celery.platforms.signals')
    @patch('celery.platforms.set_mp_process_title')
    def test_process_initializer(self, set_mp_process_title, _signals):
        with mock.restore_logging():
            from celery import signals
            from celery._state import _tls
            from celery.concurrency.prefork import (
                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
            )

            on_worker_process_init = Mock()
            signals.worker_process_init.connect(on_worker_process_init)

            def Loader(*args, **kwargs):
                loader = Mock(*args, **kwargs)
                loader.conf = {}
                loader.override_backends = {}
                return loader

            with self.Celery(loader=Loader) as app:
                app.conf = AttributeDict(DEFAULTS)
                process_initializer(app, 'awesome.worker.com')
                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                _signals.reset.assert_any_call(*WORKER_SIGRESET)
                assert app.loader.init_worker.call_count
                on_worker_process_init.assert_called()
                assert _tls.current_app is app
                set_mp_process_title.assert_called_with(
                    'celeryd', hostname='awesome.worker.com',
                )

                with patch('celery.app.trace.setup_worker_optimizations') as S:
                    os.environ['FORKED_BY_MULTIPROCESSING'] = '1'
                    try:
                        process_initializer(app, 'luke.worker.com')
                        S.assert_called_with(app, 'luke.worker.com')
                    finally:
                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)

                os.environ['CELERY_LOG_FILE'] = 'worker%I.log'
                app.log.setup = Mock(name='log_setup')
                try:
                    process_initializer(app, 'luke.worker.com')
                finally:
                    os.environ.pop('CELERY_LOG_FILE', None)


class test_process_destructor:

    @patch('celery.concurrency.prefork.signals')
    def test_process_destructor(self, signals):
        mp.process_destructor(13, -3)
        signals.worker_process_shutdown.send.assert_called_with(
            sender=None, pid=13, exitcode=-3,
        )
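

# Fake pool that records lifecycle transitions (started/closed/joined/
# terminated) and exposes just the attributes the TaskPool tests touch.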
class MockPool:
    started = False
    closed = False
    joined = False
    terminated = False
    _state = None

    def __init__(self, *args, **kwargs):
        self.started = True
        self._timeout_handler = Mock()
        self._result_handler = Mock()
        self.maintain_pool = Mock()
        self._state = mp.RUN
        self._processes = kwargs.get('processes')
        self._pool = [Bunch(pid=i, inqW_fd=1, outqR_fd=2)
                      for i in range(self._processes)]
        self._current_proc = cycle(range(self._processes))

    def close(self):
        self.closed = True
        self._state = 'CLOSE'

    def join(self):
        self.joined = True

    def terminate(self):
        self.terminated = True

    def terminate_job(self, *args, **kwargs):
        pass

    def restart(self, *args, **kwargs):
        pass

    def handle_result_event(self, *args, **kwargs):
        pass

    def flush(self):
        pass

    def grow(self, n=1):
        self._processes += n

    def shrink(self, n=1):
        self._processes -= n

    def apply_async(self, *args, **kwargs):
        pass

    def register_with_event_loop(self, loop):
        pass
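

# Like MockPool, but runs the task inline and fires the callback shortly
# afterwards from a timer thread, returning a MockResult.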
class ExeMockPool(MockPool):

    def apply_async(self, target, args=(), kwargs={}, callback=noop):
        from threading import Timer
        res = target(*args, **kwargs)
        Timer(0.1, callback, (res,)).start()
        return MockResult(res, next(self._current_proc))


class TaskPool(mp.TaskPool):
    Pool = BlockingPool = MockPool


class ExeMockTaskPool(mp.TaskPool):
    Pool = BlockingPool = ExeMockPool
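

# Tests for the non-blocking pool helpers in celery.concurrency.asynpool;
# skipped on Windows and when multiprocessing is unavailable.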
@skip.if_win32()
@skip.unless_module('multiprocessing')
class test_AsynPool:

    def test_gen_not_started(self):

        def gen():
            yield 1
            yield 2
        g = gen()
        assert asynpool.gen_not_started(g)
        next(g)
        assert not asynpool.gen_not_started(g)
        list(g)
        assert not asynpool.gen_not_started(g)

    @patch('select.select', create=True)
    def test_select(self, __select):
        ebadf = socket.error()
        ebadf.errno = errno.EBADF
        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            poll.return_value = {3}, set(), 0
            assert asynpool._select({3}, poll=poll) == ({3}, set(), 0)

            poll.return_value = {3}, set(), 0
            assert asynpool._select({3}, None, {3}, poll=poll) == (
                {3}, set(), 0,
            )

            eintr = socket.error()
            eintr.errno = errno.EINTR
            poll.side_effect = eintr

            readers = {3}
            assert asynpool._select(readers, poll=poll) == (set(), set(), 1)
            assert 3 in readers

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            poll.side_effect = ebadf
            with patch('select.select') as selcheck:
                selcheck.side_effect = ebadf
                readers = {3}
                assert asynpool._select(readers, poll=poll) == (
                    set(), set(), 1,
                )
                assert 3 not in readers

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            poll.side_effect = MemoryError()
            with pytest.raises(MemoryError):
                asynpool._select({1}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            with patch('select.select') as selcheck:

                def se(*args):
                    selcheck.side_effect = MemoryError()
                    raise ebadf
                poll.side_effect = se
                with pytest.raises(MemoryError):
                    asynpool._select({3}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            with patch('select.select') as selcheck:

                def se2(*args):
                    selcheck.side_effect = socket.error()
                    selcheck.side_effect.errno = 1321
                    raise ebadf
                poll.side_effect = se2
                with pytest.raises(socket.error):
                    asynpool._select({3}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            poll.side_effect = socket.error()
            poll.side_effect.errno = 34134
            with pytest.raises(socket.error):
                asynpool._select({3}, poll=poll)

    def test_promise(self):
        fun = Mock()
        x = asynpool.promise(fun, (1,), {'foo': 1})
        x()
        assert x.ready
        fun.assert_called_with(1, foo=1)

    def test_Worker(self):
        w = asynpool.Worker(Mock(), Mock())
        w.on_loop_start(1234)
        w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
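

# Drives ResultHandler.handle_event() with both an unregistered and a
# registered file descriptor.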
@skip.if_win32()
@skip.unless_module('multiprocessing')
class test_ResultHandler:

    def test_process_result(self):
        x = asynpool.ResultHandler(
            Mock(), Mock(), {}, Mock(),
            Mock(), Mock(), Mock(), Mock(),
            fileno_to_outq={},
            on_process_alive=Mock(),
            on_job_ready=Mock(),
        )
        assert x
        hub = Mock(name='hub')
        recv = x._recv_message = Mock(name='recv_message')
        recv.return_value = iter([])
        x.on_state_change = Mock()
        x.register_with_event_loop(hub)
        proc = x.fileno_to_outq[3] = Mock()
        reader = proc.outq._reader
        reader.poll.return_value = False
        x.handle_event(6)  # KeyError
        x.handle_event(3)
        x._recv_message.assert_called_with(
            hub.add_reader, 3, x.on_state_change,
        )
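

# These tests run against MockPool or plain Mock objects rather than real
# worker processes, so they only verify delegation and bookkeeping.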
class test_TaskPool:

    def test_start(self):
        pool = TaskPool(10)
        pool.start()
        assert pool._pool.started
        assert pool._pool._state == asynpool.RUN

        _pool = pool._pool
        pool.stop()
        assert _pool.closed
        assert _pool.joined
        pool.stop()

        pool.start()
        _pool = pool._pool
        pool.terminate()
        pool.terminate()
        assert _pool.terminated

    def test_restart(self):
        pool = TaskPool(10)
        pool._pool = Mock(name='pool')
        pool.restart()
        pool._pool.restart.assert_called_with()
        pool._pool.apply_async.assert_called_with(mp.noop)

    def test_did_start_ok(self):
        pool = TaskPool(10)
        pool._pool = Mock(name='pool')
        assert pool.did_start_ok() is pool._pool.did_start_ok()

    def test_register_with_event_loop(self):
        pool = TaskPool(10)
        pool._pool = Mock(name='pool')
        loop = Mock(name='loop')
        pool.register_with_event_loop(loop)
        pool._pool.register_with_event_loop.assert_called_with(loop)

    def test_on_close(self):
        pool = TaskPool(10)
        pool._pool = Mock(name='pool')
        pool._pool._state = mp.RUN
        pool.on_close()
        pool._pool.close.assert_called_with()

    def test_on_close__pool_not_running(self):
        pool = TaskPool(10)
        pool._pool = Mock(name='pool')
        pool._pool._state = mp.CLOSE
        pool.on_close()
        pool._pool.close.assert_not_called()

    def test_apply_async(self):
        pool = TaskPool(10)
        pool.start()
        pool.apply_async(lambda x: x, (2,), {})

    def test_grow_shrink(self):
        pool = TaskPool(10)
        pool.start()
        assert pool._pool._processes == 10
        pool.grow()
        assert pool._pool._processes == 11
        pool.shrink(2)
        assert pool._pool._processes == 9

    def test_info(self):
        pool = TaskPool(10)
        procs = [Bunch(pid=i) for i in range(pool.limit)]

        class _Pool:
            _pool = procs
            _maxtasksperchild = None
            timeout = 10
            soft_timeout = 5

            def human_write_stats(self, *args, **kwargs):
                return {}

        pool._pool = _Pool()
        info = pool.info
        assert info['max-concurrency'] == pool.limit
        assert info['max-tasks-per-child'] == 'N/A'
        assert info['timeouts'] == (5, 10)

    def test_num_processes(self):
        pool = TaskPool(7)
        pool.start()
        assert pool.num_processes == 7