test_prefork.py 12 KB


  1. from __future__ import absolute_import, unicode_literals
  2. import errno
  3. import os
  4. import pytest
  5. import socket
  6. from itertools import cycle
  7. from case import Mock, mock, patch, skip
  8. from celery.app.defaults import DEFAULTS
  9. from celery.five import range
  10. from celery.utils.collections import AttributeDict
  11. from celery.utils.functional import noop
  12. from celery.utils.objects import Bunch
# Import the real prefork / asynpool modules when available.  If either
# import fails, fall back to a minimal stand-in so the module-level helper
# classes below (which subclass ``mp.TaskPool`` and read ``mp.RUN``) can still
# be defined.
try:
    from celery.concurrency import prefork as mp
    from celery.concurrency import asynpool
except ImportError:
    class _mp(object):
        # Stand-in for ``celery.concurrency.prefork`` exposing only what this
        # module needs at import time.
        RUN = 0x1

        class TaskPool(object):
            # No-op replacement for ``prefork.TaskPool``.
            _pool = Mock()

            def __init__(self, *args, **kwargs):
                pass

            def start(self):
                pass

            def stop(self):
                pass

            def apply_async(self, *args, **kwargs):
                pass
    mp = _mp()  # noqa
    # NOTE(review): with this fallback active, tests that use ``asynpool`` or
    # attributes the stub does not define (e.g. ``mp.CLOSE``, ``mp.noop``,
    # ``mp.process_destructor``) cannot run and must be skipped.
    asynpool = None  # noqa
  31. class MockResult(object):
  32. def __init__(self, value, pid):
  33. self.value = value
  34. self.pid = pid
  35. def worker_pids(self):
  36. return [self.pid]
  37. def get(self):
  38. return self.value
  39. class test_process_initializer:
  40. @patch('celery.platforms.signals')
  41. @patch('celery.platforms.set_mp_process_title')
  42. def test_process_initializer(self, set_mp_process_title, _signals):
  43. with mock.restore_logging():
  44. from celery import signals
  45. from celery._state import _tls
  46. from celery.concurrency.prefork import (
  47. process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
  48. )
  49. on_worker_process_init = Mock()
  50. signals.worker_process_init.connect(on_worker_process_init)
  51. def Loader(*args, **kwargs):
  52. loader = Mock(*args, **kwargs)
  53. loader.conf = {}
  54. loader.override_backends = {}
  55. return loader
  56. with self.Celery(loader=Loader) as app:
  57. app.conf = AttributeDict(DEFAULTS)
  58. process_initializer(app, 'awesome.worker.com')
  59. _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
  60. _signals.reset.assert_any_call(*WORKER_SIGRESET)
  61. assert app.loader.init_worker.call_count
  62. on_worker_process_init.assert_called()
  63. assert _tls.current_app is app
  64. set_mp_process_title.assert_called_with(
  65. 'celeryd', hostname='awesome.worker.com',
  66. )
  67. with patch('celery.app.trace.setup_worker_optimizations') as S:
  68. os.environ['FORKED_BY_MULTIPROCESSING'] = '1'
  69. try:
  70. process_initializer(app, 'luke.worker.com')
  71. S.assert_called_with(app, 'luke.worker.com')
  72. finally:
  73. os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
  74. os.environ['CELERY_LOG_FILE'] = 'worker%I.log'
  75. app.log.setup = Mock(name='log_setup')
  76. try:
  77. process_initializer(app, 'luke.worker.com')
  78. finally:
  79. os.environ.pop('CELERY_LOG_FILE', None)
  80. class test_process_destructor:
  81. @patch('celery.concurrency.prefork.signals')
  82. def test_process_destructor(self, signals):
  83. mp.process_destructor(13, -3)
  84. signals.worker_process_shutdown.send.assert_called_with(
  85. sender=None, pid=13, exitcode=-3,
  86. )
  87. class MockPool(object):
  88. started = False
  89. closed = False
  90. joined = False
  91. terminated = False
  92. _state = None
  93. def __init__(self, *args, **kwargs):
  94. self.started = True
  95. self._timeout_handler = Mock()
  96. self._result_handler = Mock()
  97. self.maintain_pool = Mock()
  98. self._state = mp.RUN
  99. self._processes = kwargs.get('processes')
  100. self._pool = [Bunch(pid=i, inqW_fd=1, outqR_fd=2)
  101. for i in range(self._processes)]
  102. self._current_proc = cycle(range(self._processes))
  103. def close(self):
  104. self.closed = True
  105. self._state = 'CLOSE'
  106. def join(self):
  107. self.joined = True
  108. def terminate(self):
  109. self.terminated = True
  110. def terminate_job(self, *args, **kwargs):
  111. pass
  112. def restart(self, *args, **kwargs):
  113. pass
  114. def handle_result_event(self, *args, **kwargs):
  115. pass
  116. def flush(self):
  117. pass
  118. def grow(self, n=1):
  119. self._processes += n
  120. def shrink(self, n=1):
  121. self._processes -= n
  122. def apply_async(self, *args, **kwargs):
  123. pass
  124. def register_with_event_loop(self, loop):
  125. pass
  126. class ExeMockPool(MockPool):
  127. def apply_async(self, target, args=(), kwargs={}, callback=noop):
  128. from threading import Timer
  129. res = target(*args, **kwargs)
  130. Timer(0.1, callback, (res,)).start()
  131. return MockResult(res, next(self._current_proc))
  132. class TaskPool(mp.TaskPool):
  133. Pool = BlockingPool = MockPool
  134. class ExeMockTaskPool(mp.TaskPool):
  135. Pool = BlockingPool = ExeMockPool
@skip.if_win32()
@skip.unless_module('multiprocessing')
class test_AsynPool:
    """Unit tests for helpers in ``celery.concurrency.asynpool``."""

    def test_gen_not_started(self):
        # ``gen_not_started`` is true only before the first ``next()``.
        # A partially consumed generator and an exhausted one both count
        # as started.
        def gen():
            yield 1
            yield 2
        g = gen()
        assert asynpool.gen_not_started(g)
        next(g)
        assert not asynpool.gen_not_started(g)
        list(g)
        assert not asynpool.gen_not_started(g)

    @patch('select.select', create=True)
    def test_select(self, __select):
        # Exercises ``asynpool._select`` error handling with a mocked poll
        # object.  Each ``with`` section installs a different return value
        # or ``side_effect`` scenario; the sections run in order.
        # ``__select`` is the patched ``select.select`` and is not used
        # directly.
        ebadf = socket.error()
        ebadf.errno = errno.EBADF
        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            # A successful poll result is passed through unchanged.
            poll.return_value = {3}, set(), 0
            assert asynpool._select({3}, poll=poll) == ({3}, set(), 0)

            poll.return_value = {3}, set(), 0
            assert asynpool._select({3}, None, {3}, poll=poll) == (
                {3}, set(), 0,
            )

            # EINTR from poll: empty result with 1, and the reader is kept.
            eintr = socket.error()
            eintr.errno = errno.EINTR
            poll.side_effect = eintr

            readers = {3}
            assert asynpool._select(readers, poll=poll) == (set(), set(), 1)
            assert 3 in readers

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            poll.side_effect = ebadf
            with patch('select.select') as selcheck:
                # EBADF from poll and from the patched select.select: the
                # fd is removed from ``readers``.
                selcheck.side_effect = ebadf
                readers = {3}
                assert asynpool._select(readers, poll=poll) == (
                    set(), set(), 1,
                )
                assert 3 not in readers

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            # Non-socket errors from poll propagate.
            poll.side_effect = MemoryError()
            with pytest.raises(MemoryError):
                asynpool._select({1}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            with patch('select.select') as selcheck:

                # poll raises EBADF; the follow-up select.select then raises
                # MemoryError, which must propagate.
                def se(*args):
                    selcheck.side_effect = MemoryError()
                    raise ebadf
                poll.side_effect = se
                with pytest.raises(MemoryError):
                    asynpool._select({3}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')
            with patch('select.select') as selcheck:

                # poll raises EBADF; select.select then raises a socket error
                # with an unrecognised errno, which must propagate.
                def se2(*args):
                    selcheck.side_effect = socket.error()
                    selcheck.side_effect.errno = 1321
                    raise ebadf
                poll.side_effect = se2
                with pytest.raises(socket.error):
                    asynpool._select({3}, poll=poll)

        with patch('select.poll', create=True) as poller:
            poll = poller.return_value = Mock(name='poll.poll')

            # A socket error with an unrecognised errno from poll propagates.
            poll.side_effect = socket.error()
            poll.side_effect.errno = 34134
            with pytest.raises(socket.error):
                asynpool._select({3}, poll=poll)

    def test_promise(self):
        # Calling the promise invokes ``fun`` with the stored args/kwargs
        # and marks it ready.
        fun = Mock()
        x = asynpool.promise(fun, (1,), {'foo': 1})
        x()
        assert x.ready
        fun.assert_called_with(1, foo=1)

    def test_Worker(self):
        # ``on_loop_start(pid)`` puts ``(WORKER_UP, (pid,))`` on ``outq``.
        w = asynpool.Worker(Mock(), Mock())
        w.on_loop_start(1234)
        w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
  217. @skip.if_win32()
  218. @skip.unless_module('multiprocessing')
  219. class test_ResultHandler:
  220. def test_process_result(self):
  221. x = asynpool.ResultHandler(
  222. Mock(), Mock(), {}, Mock(),
  223. Mock(), Mock(), Mock(), Mock(),
  224. fileno_to_outq={},
  225. on_process_alive=Mock(),
  226. on_job_ready=Mock(),
  227. )
  228. assert x
  229. hub = Mock(name='hub')
  230. recv = x._recv_message = Mock(name='recv_message')
  231. recv.return_value = iter([])
  232. x.on_state_change = Mock()
  233. x.register_with_event_loop(hub)
  234. proc = x.fileno_to_outq[3] = Mock()
  235. reader = proc.outq._reader
  236. reader.poll.return_value = False
  237. x.handle_event(6) # KeyError
  238. x.handle_event(3)
  239. x._recv_message.assert_called_with(
  240. hub.add_reader, 3, x.on_state_change,
  241. )
  242. class test_TaskPool:
  243. def test_start(self):
  244. pool = TaskPool(10)
  245. pool.start()
  246. assert pool._pool.started
  247. assert pool._pool._state == asynpool.RUN
  248. _pool = pool._pool
  249. pool.stop()
  250. assert _pool.closed
  251. assert _pool.joined
  252. pool.stop()
  253. pool.start()
  254. _pool = pool._pool
  255. pool.terminate()
  256. pool.terminate()
  257. assert _pool.terminated
  258. def test_restart(self):
  259. pool = TaskPool(10)
  260. pool._pool = Mock(name='pool')
  261. pool.restart()
  262. pool._pool.restart.assert_called_with()
  263. pool._pool.apply_async.assert_called_with(mp.noop)
  264. def test_did_start_ok(self):
  265. pool = TaskPool(10)
  266. pool._pool = Mock(name='pool')
  267. assert pool.did_start_ok() is pool._pool.did_start_ok()
  268. def test_register_with_event_loop(self):
  269. pool = TaskPool(10)
  270. pool._pool = Mock(name='pool')
  271. loop = Mock(name='loop')
  272. pool.register_with_event_loop(loop)
  273. pool._pool.register_with_event_loop.assert_called_with(loop)
  274. def test_on_close(self):
  275. pool = TaskPool(10)
  276. pool._pool = Mock(name='pool')
  277. pool._pool._state = mp.RUN
  278. pool.on_close()
  279. pool._pool.close.assert_called_with()
  280. def test_on_close__pool_not_running(self):
  281. pool = TaskPool(10)
  282. pool._pool = Mock(name='pool')
  283. pool._pool._state = mp.CLOSE
  284. pool.on_close()
  285. pool._pool.close.assert_not_called()
  286. def test_apply_async(self):
  287. pool = TaskPool(10)
  288. pool.start()
  289. pool.apply_async(lambda x: x, (2,), {})
  290. def test_grow_shrink(self):
  291. pool = TaskPool(10)
  292. pool.start()
  293. assert pool._pool._processes == 10
  294. pool.grow()
  295. assert pool._pool._processes == 11
  296. pool.shrink(2)
  297. assert pool._pool._processes == 9
  298. def test_info(self):
  299. pool = TaskPool(10)
  300. procs = [Bunch(pid=i) for i in range(pool.limit)]
  301. class _Pool(object):
  302. _pool = procs
  303. _maxtasksperchild = None
  304. timeout = 10
  305. soft_timeout = 5
  306. def human_write_stats(self, *args, **kwargs):
  307. return {}
  308. pool._pool = _Pool()
  309. info = pool.info
  310. assert info['max-concurrency'] == pool.limit
  311. assert info['max-tasks-per-child'] == 'N/A'
  312. assert info['timeouts'] == (5, 10)
  313. def test_num_processes(self):
  314. pool = TaskPool(7)
  315. pool.start()
  316. assert pool.num_processes == 7