test_prefork.py 13 KB


  1. from __future__ import absolute_import
  2. import errno
  3. import os
  4. import socket
  5. import sys
  6. from itertools import cycle
  7. from celery.app.defaults import DEFAULTS
  8. from celery.datastructures import AttributeDict
  9. from celery.five import items, range
  10. from celery.utils.functional import noop
  11. from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging
  12. try:
  13. from celery.concurrency import prefork as mp
  14. from celery.concurrency import asynpool
  15. except ImportError:
  16. class _mp(object):
  17. RUN = 0x1
  18. class TaskPool(object):
  19. _pool = Mock()
  20. def __init__(self, *args, **kwargs):
  21. pass
  22. def start(self):
  23. pass
  24. def stop(self):
  25. pass
  26. def apply_async(self, *args, **kwargs):
  27. pass
  28. mp = _mp() # noqa
  29. asynpool = None # noqa
  30. class Object(object): # for writeable attributes.
  31. def __init__(self, **kwargs):
  32. [setattr(self, k, v) for k, v in items(kwargs)]
  33. class MockResult(object):
  34. def __init__(self, value, pid):
  35. self.value = value
  36. self.pid = pid
  37. def worker_pids(self):
  38. return [self.pid]
  39. def get(self):
  40. return self.value
class test_process_initializer(AppCase):
    """Tests for ``celery.concurrency.prefork.process_initializer``."""

    @patch('celery.platforms.signals')
    @patch('celery.platforms.set_mp_process_title')
    def test_process_initializer(self, set_mp_process_title, _signals):
        # Stacked patch decorators apply bottom-up, so the mock for
        # set_mp_process_title arrives first and the signals mock second.
        with restore_logging():
            from celery import signals
            from celery._state import _tls
            from celery.concurrency.prefork import (
                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
            )

            # Receiver that records whether worker_process_init was sent.
            def on_worker_process_init(**kwargs):
                on_worker_process_init.called = True
            on_worker_process_init.called = False
            signals.worker_process_init.connect(on_worker_process_init)

            # Loader factory: a Mock whose ``conf`` and ``override_backends``
            # are real (empty) dicts instead of auto-created Mock attributes.
            def Loader(*args, **kwargs):
                loader = Mock(*args, **kwargs)
                loader.conf = {}
                loader.override_backends = {}
                return loader

            with self.Celery(loader=Loader) as app:
                app.conf = AttributeDict(DEFAULTS)
                # First run checks several things:
                # - the worker signals are ignored/reset;
                # - the loader's init_worker was called;
                # - worker_process_init fired;
                # - the app became current;
                # - the process title was set.
                process_initializer(app, 'awesome.worker.com')
                _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                _signals.reset.assert_any_call(*WORKER_SIGRESET)
                self.assertTrue(app.loader.init_worker.call_count)
                self.assertTrue(on_worker_process_init.called)
                self.assertIs(_tls.current_app, app)
                set_mp_process_title.assert_called_with(
                    'celeryd', hostname='awesome.worker.com',
                )

                # With FORKED_BY_MULTIPROCESSING set, the initializer must
                # call setup_worker_optimizations(app, hostname). The
                # variable is always removed again in ``finally``.
                with patch('celery.app.trace.setup_worker_optimizations') as S:
                    os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
                    try:
                        process_initializer(app, 'luke.worker.com')
                        S.assert_called_with(app, 'luke.worker.com')
                    finally:
                        os.environ.pop('FORKED_BY_MULTIPROCESSING', None)

                # Log file name containing the ``%I`` placeholder.
                # app.log.setup is replaced by a Mock so that no real
                # logging setup happens.
                # NOTE(review): this step asserts nothing; it only checks
                # that process_initializer does not raise.
                os.environ['CELERY_LOG_FILE'] = 'worker%I.log'
                app.log.setup = Mock(name='log_setup')
                try:
                    process_initializer(app, 'luke.worker.com')
                finally:
                    os.environ.pop('CELERY_LOG_FILE', None)
  84. class test_process_destructor(AppCase):
  85. @patch('celery.concurrency.prefork.signals')
  86. def test_process_destructor(self, signals):
  87. mp.process_destructor(13, -3)
  88. signals.worker_process_shutdown.send.assert_called_with(
  89. sender=None, pid=13, exitcode=-3,
  90. )
  91. class MockPool(object):
  92. started = False
  93. closed = False
  94. joined = False
  95. terminated = False
  96. _state = None
  97. def __init__(self, *args, **kwargs):
  98. self.started = True
  99. self._timeout_handler = Mock()
  100. self._result_handler = Mock()
  101. self.maintain_pool = Mock()
  102. self._state = mp.RUN
  103. self._processes = kwargs.get('processes')
  104. self._pool = [Object(pid=i, inqW_fd=1, outqR_fd=2)
  105. for i in range(self._processes)]
  106. self._current_proc = cycle(range(self._processes))
  107. def close(self):
  108. self.closed = True
  109. self._state = 'CLOSE'
  110. def join(self):
  111. self.joined = True
  112. def terminate(self):
  113. self.terminated = True
  114. def terminate_job(self, *args, **kwargs):
  115. pass
  116. def restart(self, *args, **kwargs):
  117. pass
  118. def handle_result_event(self, *args, **kwargs):
  119. pass
  120. def flush(self):
  121. pass
  122. def grow(self, n=1):
  123. self._processes += n
  124. def shrink(self, n=1):
  125. self._processes -= n
  126. def apply_async(self, *args, **kwargs):
  127. pass
  128. def register_with_event_loop(self, loop):
  129. pass
  130. class ExeMockPool(MockPool):
  131. def apply_async(self, target, args=(), kwargs={}, callback=noop):
  132. from threading import Timer
  133. res = target(*args, **kwargs)
  134. Timer(0.1, callback, (res,)).start()
  135. return MockResult(res, next(self._current_proc))
  136. class TaskPool(mp.TaskPool):
  137. Pool = BlockingPool = MockPool
  138. class ExeMockTaskPool(mp.TaskPool):
  139. Pool = BlockingPool = ExeMockPool
  140. class PoolCase(AppCase):
  141. def setup(self):
  142. try:
  143. import multiprocessing # noqa
  144. except ImportError:
  145. raise SkipTest('multiprocessing not supported')
  146. class test_AsynPool(PoolCase):
  147. def setup(self):
  148. if sys.platform == 'win32':
  149. raise SkipTest('win32: skip')
  150. def test_gen_not_started(self):
  151. def gen():
  152. yield 1
  153. yield 2
  154. g = gen()
  155. self.assertTrue(asynpool.gen_not_started(g))
  156. next(g)
  157. self.assertFalse(asynpool.gen_not_started(g))
  158. list(g)
  159. self.assertFalse(asynpool.gen_not_started(g))
  160. @patch('select.select', create=True)
  161. def test_select(self, __select):
  162. ebadf = socket.error()
  163. ebadf.errno = errno.EBADF
  164. with patch('select.poll', create=True) as poller:
  165. poll = poller.return_value = Mock(name='poll.poll')
  166. poll.return_value = {3}, set(), 0
  167. self.assertEqual(
  168. asynpool._select({3}, poll=poll),
  169. ({3}, set(), 0),
  170. )
  171. poll.return_value = {3}, set(), 0
  172. self.assertEqual(
  173. asynpool._select({3}, None, {3}, poll=poll),
  174. ({3}, set(), 0),
  175. )
  176. eintr = socket.error()
  177. eintr.errno = errno.EINTR
  178. poll.side_effect = eintr
  179. readers = {3}
  180. self.assertEqual(
  181. asynpool._select(readers, poll=poll),
  182. (set(), set(), 1),
  183. )
  184. self.assertIn(3, readers)
  185. with patch('select.poll') as poller:
  186. poll = poller.return_value = Mock(name='poll.poll')
  187. poll.side_effect = ebadf
  188. with patch('select.select') as selcheck:
  189. selcheck.side_effect = ebadf
  190. readers = {3}
  191. self.assertEqual(
  192. asynpool._select(readers, poll=poll),
  193. (set(), set(), 1),
  194. )
  195. self.assertNotIn(3, readers)
  196. with patch('select.poll') as poller:
  197. poll = poller.return_value = Mock(name='poll.poll')
  198. poll.side_effect = MemoryError()
  199. with self.assertRaises(MemoryError):
  200. asynpool._select({1}, poll=poll)
  201. with patch('select.poll') as poller:
  202. poll = poller.return_value = Mock(name='poll.poll')
  203. with patch('select.select') as selcheck:
  204. def se(*args):
  205. selcheck.side_effect = MemoryError()
  206. raise ebadf
  207. poll.side_effect = se
  208. with self.assertRaises(MemoryError):
  209. asynpool._select({3}, poll=poll)
  210. with patch('select.poll') as poller:
  211. poll = poller.return_value = Mock(name='poll.poll')
  212. with patch('select.select') as selcheck:
  213. def se2(*args):
  214. selcheck.side_effect = socket.error()
  215. selcheck.side_effect.errno = 1321
  216. raise ebadf
  217. poll.side_effect = se2
  218. with self.assertRaises(socket.error):
  219. asynpool._select({3}, poll=poll)
  220. with patch('select.poll') as poller:
  221. poll = poller.return_value = Mock(name='poll.poll')
  222. poll.side_effect = socket.error()
  223. poll.side_effect.errno = 34134
  224. with self.assertRaises(socket.error):
  225. asynpool._select({3}, poll=poll)
  226. def test_promise(self):
  227. fun = Mock()
  228. x = asynpool.promise(fun, (1,), {'foo': 1})
  229. x()
  230. self.assertTrue(x.ready)
  231. fun.assert_called_with(1, foo=1)
  232. def test_Worker(self):
  233. w = asynpool.Worker(Mock(), Mock())
  234. w.on_loop_start(1234)
  235. w.outq.put.assert_called_with((asynpool.WORKER_UP, (1234,)))
  236. class test_ResultHandler(PoolCase):
  237. def setup(self):
  238. if sys.platform == 'win32':
  239. raise SkipTest('win32: skip')
  240. def test_process_result(self):
  241. x = asynpool.ResultHandler(
  242. Mock(), Mock(), {}, Mock(),
  243. Mock(), Mock(), Mock(), Mock(),
  244. fileno_to_outq={},
  245. on_process_alive=Mock(),
  246. on_job_ready=Mock(),
  247. )
  248. self.assertTrue(x)
  249. hub = Mock(name='hub')
  250. recv = x._recv_message = Mock(name='recv_message')
  251. recv.return_value = iter([])
  252. x.on_state_change = Mock()
  253. x.register_with_event_loop(hub)
  254. proc = x.fileno_to_outq[3] = Mock()
  255. reader = proc.outq._reader
  256. reader.poll.return_value = False
  257. x.handle_event(6) # KeyError
  258. x.handle_event(3)
  259. x._recv_message.assert_called_with(
  260. hub.add_reader, 3, x.on_state_change,
  261. )
  262. class test_TaskPool(PoolCase):
  263. def test_start(self):
  264. pool = TaskPool(10)
  265. pool.start()
  266. self.assertTrue(pool._pool.started)
  267. self.assertTrue(pool._pool._state == asynpool.RUN)
  268. _pool = pool._pool
  269. pool.stop()
  270. self.assertTrue(_pool.closed)
  271. self.assertTrue(_pool.joined)
  272. pool.stop()
  273. pool.start()
  274. _pool = pool._pool
  275. pool.terminate()
  276. pool.terminate()
  277. self.assertTrue(_pool.terminated)
  278. def test_restart(self):
  279. pool = TaskPool(10)
  280. pool._pool = Mock(name='pool')
  281. pool.restart()
  282. pool._pool.restart.assert_called_with()
  283. pool._pool.apply_async.assert_called_with(mp.noop)
  284. def test_did_start_ok(self):
  285. pool = TaskPool(10)
  286. pool._pool = Mock(name='pool')
  287. self.assertIs(pool.did_start_ok(), pool._pool.did_start_ok())
  288. def test_register_with_event_loop(self):
  289. pool = TaskPool(10)
  290. pool._pool = Mock(name='pool')
  291. loop = Mock(name='loop')
  292. pool.register_with_event_loop(loop)
  293. pool._pool.register_with_event_loop.assert_called_with(loop)
  294. def test_on_close(self):
  295. pool = TaskPool(10)
  296. pool._pool = Mock(name='pool')
  297. pool._pool._state = mp.RUN
  298. pool.on_close()
  299. pool._pool.close.assert_called_with()
  300. def test_on_close__pool_not_running(self):
  301. pool = TaskPool(10)
  302. pool._pool = Mock(name='pool')
  303. pool._pool._state = mp.CLOSE
  304. pool.on_close()
  305. self.assertFalse(pool._pool.close.called)
  306. def test_apply_async(self):
  307. pool = TaskPool(10)
  308. pool.start()
  309. pool.apply_async(lambda x: x, (2,), {})
  310. def test_grow_shrink(self):
  311. pool = TaskPool(10)
  312. pool.start()
  313. self.assertEqual(pool._pool._processes, 10)
  314. pool.grow()
  315. self.assertEqual(pool._pool._processes, 11)
  316. pool.shrink(2)
  317. self.assertEqual(pool._pool._processes, 9)
  318. def test_info(self):
  319. pool = TaskPool(10)
  320. procs = [Object(pid=i) for i in range(pool.limit)]
  321. class _Pool(object):
  322. _pool = procs
  323. _maxtasksperchild = None
  324. timeout = 10
  325. soft_timeout = 5
  326. def human_write_stats(self, *args, **kwargs):
  327. return {}
  328. pool._pool = _Pool()
  329. info = pool.info
  330. self.assertEqual(info['max-concurrency'], pool.limit)
  331. self.assertEqual(info['max-tasks-per-child'], 'N/A')
  332. self.assertEqual(info['timeouts'], (5, 10))
  333. def test_num_processes(self):
  334. pool = TaskPool(7)
  335. pool.start()
  336. self.assertEqual(pool.num_processes, 7)