test_prefork.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. from __future__ import absolute_import
  2. import errno
  3. import os
  4. import socket
  5. import sys
  6. from itertools import cycle
  7. from celery.app.defaults import DEFAULTS
  8. from celery.datastructures import AttributeDict
  9. from celery.five import range
  10. from celery.utils.functional import noop
  11. from celery.utils.objects import Bunch
  12. from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
  13. try:
  14. from celery.concurrency import prefork as mp
  15. from celery.concurrency import asynpool
  16. except ImportError:
  17. class _mp(object):
  18. RUN = 0x1
  19. class TaskPool(object):
  20. _pool = Mock()
  21. def __init__(self, *args, **kwargs):
  22. pass
  23. def start(self):
  24. pass
  25. def stop(self):
  26. pass
  27. def apply_async(self, *args, **kwargs):
  28. pass
  29. mp = _mp() # noqa
  30. asynpool = None # noqa
  31. class MockResult(object):
  32. def __init__(self, value, pid):
  33. self.value = value
  34. self.pid = pid
  35. def worker_pids(self):
  36. return [self.pid]
  37. def get(self):
  38. return self.value
  39. class test_process_initializer(AppCase):
  40. @patch('celery.platforms.signals')
  41. @patch('celery.platforms.set_mp_process_title')
  42. def test_process_initializer(self, set_mp_process_title, _signals):
  43. with restore_logging():
  44. from celery import signals
  45. from celery._state import _tls
  46. from celery.concurrency.prefork import (
  47. process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
  48. )
  49. def on_worker_process_init(**kwargs):
  50. on_worker_process_init.called = True
  51. on_worker_process_init.called = False
  52. signals.worker_process_init.connect(on_worker_process_init)
  53. def Loader(*args, **kwargs):
  54. loader = Mock(*args, **kwargs)
  55. loader.conf = {}
  56. loader.override_backends = {}
  57. return loader
  58. with self.Celery(loader=Loader) as app:
  59. app.conf = AttributeDict(DEFAULTS)
  60. process_initializer(app, 'awesome.worker.com')
  61. _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
  62. _signals.reset.assert_any_call(*WORKER_SIGRESET)
  63. self.assertTrue(app.loader.init_worker.call_count)
  64. self.assertTrue(on_worker_process_init.called)
  65. self.assertIs(_tls.current_app, app)
  66. set_mp_process_title.assert_called_with(
  67. 'celeryd', hostname='awesome.worker.com',
  68. )
  69. with patch('celery.app.trace.setup_worker_optimizations') as S:
  70. os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
  71. try:
  72. process_initializer(app, 'luke.worker.com')
  73. S.assert_called_with(app, 'luke.worker.com')
  74. finally:
  75. os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
  76. os.environ['CELERY_LOG_FILE'] = 'worker%I.log'
  77. app.log.setup = Mock(name='log_setup')
  78. try:
  79. process_initializer(app, 'luke.worker.com')
  80. finally:
  81. os.environ.pop('CELERY_LOG_FILE', None)
  82. class test_process_destructor(AppCase):
  83. @patch('celery.concurrency.prefork.signals')
  84. def test_process_destructor(self, signals):
  85. mp.process_destructor(13, -3)
  86. signals.worker_process_shutdown.send.assert_called_with(
  87. sender=None, pid=13, exitcode=-3,
  88. )
  89. class MockPool(object):
  90. started = False
  91. closed = False
  92. joined = False
  93. terminated = False
  94. _state = None
  95. def __init__(self, *args, **kwargs):
  96. self.started = True
  97. self._timeout_handler = Mock()
  98. self._result_handler = Mock()
  99. self.maintain_pool = Mock()
  100. self._state = mp.RUN
  101. self._processes = kwargs.get('processes')
  102. self._pool = [Bunch(pid=i, inqW_fd=1, outqR_fd=2)
  103. for i in range(self._processes)]
  104. self._current_proc = cycle(range(self._processes))
  105. def close(self):
  106. self.closed = True
  107. self._state = 'CLOSE'
  108. def join(self):
  109. self.joined = True
  110. def terminate(self):
  111. self.terminated = True
  112. def terminate_job(self, *args, **kwargs):
  113. pass
  114. def restart(self, *args, **kwargs):
  115. pass
  116. def handle_result_event(self, *args, **kwargs):
  117. pass
  118. def flush(self):
  119. pass
  120. def grow(self, n=1):
  121. self._processes += n
  122. def shrink(self, n=1):
  123. self._processes -= n
  124. def apply_async(self, *args, **kwargs):
  125. pass
  126. def register_with_event_loop(self, loop):
  127. pass
  128. class ExeMockPool(MockPool):
  129. def apply_async(self, target, args=(), kwargs={}, callback=noop):
  130. from threading import Timer
  131. res = target(*args, **kwargs)
  132. Timer(0.1, callback, (res,)).start()
  133. return MockResult(res, next(self._current_proc))
  134. class TaskPool(mp.TaskPool):
  135. Pool = BlockingPool = MockPool
  136. class ExeMockTaskPool(mp.TaskPool):
  137. Pool = BlockingPool = ExeMockPool
  138. class PoolCase(AppCase):
  139. def setup(self):
  140. try:
  141. import multiprocessing # noqa
  142. except ImportError:
  143. raise SkipTest('multiprocessing not supported')
  144. class test_AsynPool(PoolCase):
  145. def setup(self):
  146. if sys.platform == 'win32':
  147. raise SkipTest('win32: skip')
  148. def test_gen_not_started(self):
  149. def gen():
  150. yield 1
  151. yield 2
  152. g = gen()
  153. self.assertTrue(asynpool.gen_not_started(g))
  154. next(g)
  155. self.assertFalse(asynpool.gen_not_started(g))
  156. list(g)
  157. self.assertFalse(asynpool.gen_not_started(g))
  158. @patch('select.select', create=True)
  159. def test_select(self, __select):
  160. ebadf = socket.error()
  161. ebadf.errno = errno.EBADF
  162. with patch('select.poll', create=True) as poller:
  163. poll = poller.return_value = Mock(name='poll.poll')
  164. poll.return_value = {3}, set(), 0
  165. self.assertEqual(
  166. asynpool._select({3}, poll=poll),
  167. ({3}, set(), 0),
  168. )
  169. poll.return_value = {3}, set(), 0
  170. self.assertEqual(
  171. asynpool._select({3}, None, {3}, poll=poll),
  172. ({3}, set(), 0),
  173. )
  174. eintr = socket.error()
  175. eintr.errno = errno.EINTR
  176. poll.side_effect = eintr
  177. readers = {3}
  178. self.assertEqual(
  179. asynpool._select(readers, poll=poll),
  180. (set(), set(), 1),
  181. )
  182. self.assertIn(3, readers)
  183. with patch('select.poll') as poller:
  184. poll = poller.return_value = Mock(name='poll.poll')
  185. poll.side_effect = ebadf
  186. with patch('select.select') as selcheck:
  187. selcheck.side_effect = ebadf
  188. readers = {3}
  189. self.assertEqual(
  190. asynpool._select(readers, poll=poll),
  191. (set(), set(), 1),
  192. )
  193. self.assertNotIn(3, readers)
  194. with patch('select.poll') as poller:
  195. poll = poller.return_value = Mock(name='poll.poll')
  196. poll.side_effect = MemoryError()
  197. with self.assertRaises(MemoryError):
  198. asynpool._select({1}, poll=poll)
  199. with patch('select.poll') as poller:
  200. poll = poller.return_value = Mock(name='poll.poll')
  201. with patch('select.select') as selcheck:
  202. def se(*args):
  203. selcheck.side_effect = MemoryError()
  204. raise ebadf
  205. poll.side_effect = se
  206. with self.assertRaises(MemoryError):
  207. asynpool._select({3}, poll=poll)
  208. with patch('select.poll') as poller:
  209. poll = poller.return_value = Mock(name='poll.poll')
  210. with patch('select.select') as selcheck:
  211. def se2(*args):
  212. selcheck.side_effect = socket.error()
  213. selcheck.side_effect.errno = 1321
  214. raise ebadf
  215. poll.side_effect = se2
  216. with self.assertRaises(socket.error):
  217. asynpool._select({3}, poll=poll)
  218. with patch('select.poll') as poller:
  219. poll = poller.return_value = Mock(name='poll.poll')
  220. poll.side_effect = socket.error()
  221. poll.side_effect.errno = 34134
  222. with self.assertRaises(socket.error):
  223. asynpool._select({3}, poll=poll)
  224. def test_promise(self):
  225. fun = Mock()
  226. x = asynpool.promise(fun, (1,), {'foo': 1})
  227. x()
  228. self.assertTrue(x.ready)
  229. fun.assert_called_with(1, foo=1)
  230. def test_Worker(self):
  231. w = asynpool.Worker(Mock(), Mock())
  232. w.on_loop_start(1234)
  233. w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
  234. class test_ResultHandler(PoolCase):
  235. def setup(self):
  236. if sys.platform == 'win32':
  237. raise SkipTest('win32: skip')
  238. def test_process_result(self):
  239. x = asynpool.ResultHandler(
  240. Mock(), Mock(), {}, Mock(),
  241. Mock(), Mock(), Mock(), Mock(),
  242. fileno_to_outq={},
  243. on_process_alive=Mock(),
  244. on_job_ready=Mock(),
  245. )
  246. self.assertTrue(x)
  247. hub = Mock(name='hub')
  248. recv = x._recv_message = Mock(name='recv_message')
  249. recv.return_value = iter([])
  250. x.on_state_change = Mock()
  251. x.register_with_event_loop(hub)
  252. proc = x.fileno_to_outq[3] = Mock()
  253. reader = proc.outq._reader
  254. reader.poll.return_value = False
  255. x.handle_event(6) # KeyError
  256. x.handle_event(3)
  257. x._recv_message.assert_called_with(
  258. hub.add_reader, 3, x.on_state_change,
  259. )
  260. class test_TaskPool(PoolCase):
  261. def test_start(self):
  262. pool = TaskPool(10)
  263. pool.start()
  264. self.assertTrue(pool._pool.started)
  265. self.assertTrue(pool._pool._state == asynpool.RUN)
  266. _pool = pool._pool
  267. pool.stop()
  268. self.assertTrue(_pool.closed)
  269. self.assertTrue(_pool.joined)
  270. pool.stop()
  271. pool.start()
  272. _pool = pool._pool
  273. pool.terminate()
  274. pool.terminate()
  275. self.assertTrue(_pool.terminated)
  276. def test_restart(self):
  277. pool = TaskPool(10)
  278. pool._pool = Mock(name='pool')
  279. pool.restart()
  280. pool._pool.restart.assert_called_with()
  281. pool._pool.apply_async.assert_called_with(mp.noop)
  282. def test_did_start_ok(self):
  283. pool = TaskPool(10)
  284. pool._pool = Mock(name='pool')
  285. self.assertIs(pool.did_start_ok(), pool._pool.did_start_ok())
  286. def test_register_with_event_loop(self):
  287. pool = TaskPool(10)
  288. pool._pool = Mock(name='pool')
  289. loop = Mock(name='loop')
  290. pool.register_with_event_loop(loop)
  291. pool._pool.register_with_event_loop.assert_called_with(loop)
  292. def test_on_close(self):
  293. pool = TaskPool(10)
  294. pool._pool = Mock(name='pool')
  295. pool._pool._state = mp.RUN
  296. pool.on_close()
  297. pool._pool.close.assert_called_with()
  298. def test_on_close__pool_not_running(self):
  299. pool = TaskPool(10)
  300. pool._pool = Mock(name='pool')
  301. pool._pool._state = mp.CLOSE
  302. pool.on_close()
  303. self.assertFalse(pool._pool.close.called)
  304. def test_apply_async(self):
  305. pool = TaskPool(10)
  306. pool.start()
  307. pool.apply_async(lambda x: x, (2,), {})
  308. def test_grow_shrink(self):
  309. pool = TaskPool(10)
  310. pool.start()
  311. self.assertEqual(pool._pool._processes, 10)
  312. pool.grow()
  313. self.assertEqual(pool._pool._processes, 11)
  314. pool.shrink(2)
  315. self.assertEqual(pool._pool._processes, 9)
  316. def test_info(self):
  317. pool = TaskPool(10)
  318. procs = [Bunch(pid=i) for i in range(pool.limit)]
  319. class _Pool(object):
  320. _pool = procs
  321. _maxtasksperchild = None
  322. timeout = 10
  323. soft_timeout = 5
  324. def human_write_stats(self, *args, **kwargs):
  325. return {}
  326. pool._pool = _Pool()
  327. info = pool.info
  328. self.assertEqual(info['max-concurrency'], pool.limit)
  329. self.assertEqual(info['max-tasks-per-child'], 'N/A')
  330. self.assertEqual(info['timeouts'], (5, 10))
  331. def test_num_processes(self):
  332. pool = TaskPool(7)
  333. pool.start()
  334. self.assertEqual(pool.num_processes, 7)