test_multi.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. from __future__ import absolute_import, unicode_literals
  2. import errno
  3. import pytest
  4. import signal
  5. import sys
  6. from case import Mock, call, patch, skip
  7. from celery.apps.multi import (
  8. Cluster, MultiParser, NamespacedOptionParser, Node, format_opt,
  9. )
  class test_functions:
      """Tests for small parsing/formatting helpers from ``celery.apps.multi``."""

      def test_parse_ns_range(self):
          m = MultiParser()
          # With ``ranges=True`` a ``start-stop`` item is expanded into each
          # number of the inclusive range; with ``ranges=False`` it is kept
          # verbatim.
          assert m._parse_ns_range('1-3', True) == ['1', '2', '3']
          assert m._parse_ns_range('1-3', False) == ['1-3']
          # Comma-separated items are split, expanding any ranges among them.
          assert m._parse_ns_range('1-3,10,11,20', True) == [
              '1', '2', '3', '10', '11', '20',
          ]

      def test_format_opt(self):
          # A ``None`` value yields the bare flag, short options are joined
          # with a space and long options with ``=``.
          assert format_opt('--foo', None) == '--foo'
          assert format_opt('-c', 1) == '-c 1'
          assert format_opt('--log', 'foo') == '--log=foo'
  class test_NamespacedOptionParser:
      """Tests for ``NamespacedOptionParser.parse()``."""

      def test_parse(self):
          # ``-c:1,3 4`` stores ``-c 4`` under the ``1,3`` namespace.
          x = NamespacedOptionParser(['-c:1,3', '4'])
          x.parse()
          assert x.namespaces.get('1,3') == {'-c': '4'}

          x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                      '--loglevel:kramer=DEBUG',
                                      '--flag',
                                      '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                      '--', '.disable_rate_limits=1'])
          x.parse()
          # Options without a namespace end up in ``options``; a flag that is
          # followed by another option gets the value ``None``.
          assert x.options == {
              '--logfile': 'foo',
              '-Q': 'bar',
              '--flag': None,
          }
          # Remaining positional arguments are collected in order.
          assert x.values == ['a', 'b']
          # Namespaced options are keyed by their namespace string.
          assert x.namespaces.get('jerry,elaine') == {'-c': '5'}
          assert x.namespaces.get('kramer') == {'--loglevel': 'DEBUG'}
          # Everything from ``--`` onwards is kept as one joined string.
          assert x.passthrough == '-- .disable_rate_limits=1'
  def multi_args(p, *args, **kwargs):
      """Expand the already-parsed options ``p`` into nodes.

      Extra positional and keyword arguments configure the ``MultiParser``
      (e.g. ``cmd``, ``append``, ``prefix``, ``suffix``); the result of its
      ``parse()`` is returned unchanged.
      """
      parser = MultiParser(*args, **kwargs)
      return parser.parse(p)
  class test_multi_args:
      """Tests for turning parsed ``celery multi`` arguments into nodes."""

      @patch('celery.apps.multi.gethostname')
      def test_parse(self, gethostname):
          gethostname.return_value = 'example.com'

          def _parsed(argv):
              # Build a NamespacedOptionParser and run parse() on it.
              opts = NamespacedOptionParser(argv)
              opts.parse()
              return opts

          named = _parsed([
              '-c:jerry,elaine', '5',
              '--loglevel:kramer=DEBUG',
              '--flag',
              '--logfile=foo', '-Q', 'bar', 'jerry',
              'elaine', 'kramer',
              '--', '.disable_rate_limits=1',
          ])
          nodes = list(multi_args(named, cmd='COMMAND', append='*AP*',
                                  prefix='*P*', suffix='*S*'))

          def assert_line_in(name, args):
              # A node called ``name`` must exist and its argv (taken from
              # the last node with that name) must contain every fragment.
              assert name in {n.name for n in nodes}
              matching = [n.argv for n in nodes if n.name == name]
              argv = matching[-1] if matching else None
              assert argv
              for fragment in args:
                  assert fragment in argv

          assert_line_in(
              '*P*jerry@*S*',
              ['COMMAND', '-n *P*jerry@*S*', '-Q bar',
               '-c 5', '--flag', '--logfile=foo',
               '-- .disable_rate_limits=1', '*AP*'],
          )
          assert_line_in(
              '*P*elaine@*S*',
              ['COMMAND', '-n *P*elaine@*S*', '-Q bar',
               '-c 5', '--flag', '--logfile=foo',
               '-- .disable_rate_limits=1', '*AP*'],
          )
          assert_line_in(
              '*P*kramer@*S*',
              ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*',
               '-Q bar', '--flag', '--logfile=foo',
               '-- .disable_rate_limits=1', '*AP*'],
          )

          # The first node's expander knows its full and short name.
          expander = nodes[0].expander
          assert expander('%h') == '*P*jerry@*S*'
          assert expander('%n') == '*P*jerry'

          # With an empty ``append`` the passthrough is the last argument.
          without_append = list(multi_args(named, cmd='COMMAND', append='',
                                           prefix='*P*', suffix='*S*'))
          assert without_append[0].argv[-1] == '-- .disable_rate_limits=1'

          # A bare count generates celery1..celery10 on the patched host.
          counted = _parsed(['10', '-c:1', '5'])
          generated = list(multi_args(counted, cmd='COMMAND'))

          def _args(name, *args):
              # Trailing arguments every generated node's argv ends with.
              return args + (
                  '--pidfile={0}.pid'.format(name),
                  '--logfile={0}%I.log'.format(name),
                  '--executable={0}'.format(sys.executable),
                  '',
              )

          assert len(generated) == 10
          assert generated[0].name == 'celery1@example.com'
          assert generated[0].argv == (
              'COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1')
          for number, worker in enumerate(generated[1:], 2):
              assert worker.name == 'celery%s@example.com' % number
              short = 'celery%s' % (number,)
              assert worker.argv == (
                  'COMMAND',
                  '-n %s@example.com' % (short,)) + _args(short)

          # A suffix of ``""`` leaves the host part empty.
          hostless = list(multi_args(counted, cmd='COMMAND', suffix='""'))
          assert len(hostless) == 10
          assert hostless[0].name == 'celery1@'
          assert hostless[0].argv == (
              'COMMAND', '-c 5', '-n celery1@') + _args('celery1')

          # Single explicitly-named nodes: (argv, expected name, argv head).
          single_node_cases = [
              (['foo@', '-c:foo', '5'],
               'foo@', ('COMMAND', '-c 5', '-n foo@')),
              (['foo', '-Q:1', 'test'],
               'foo@', ('COMMAND', '-Q test', '-n foo@')),
              (['foo@bar', '-Q:1', 'test'],
               'foo@bar', ('COMMAND', '-Q test', '-n foo@bar')),
          ]
          for argv, expected_name, head in single_node_cases:
              first = list(multi_args(_parsed(argv), cmd='COMMAND',
                                      suffix='""'))[0]
              assert first.name == expected_name
              assert first.argv == head + _args('foo')

          # Namespace index 0 does not address any node.
          out_of_range = _parsed(['foo@bar', '-Q:0', 'test'])
          with pytest.raises(KeyError):
              list(multi_args(out_of_range))

      def test_optmerge(self):
          parser = NamespacedOptionParser(['foo', 'test'])
          parser.parse()
          parser.options = {'x': 'y'}
          # Global options are included in the merged options of a node.
          assert parser.optmerge('foo')['x'] == 'y'
  class test_Node:
      """Tests for ``Node``, a single worker managed by ``celery multi``."""

      def setup(self):
          # NOTE(review): ``self.p`` is not referenced by any test in this
          # class; kept as-is.
          self.p = Mock(name='p')
          self.p.options = {
              '--executable': 'python',
              '--logfile': 'foo.log',
          }
          self.p.namespaces = {}
          self.node = Node('foo@bar.com', options={'-A': 'proj'})
          # Swap in a mock expander so tests can inspect which templates
          # the node expands.
          self.expander = self.node.expander = Mock(name='expander')
          self.node.pid = 303

      def test_from_kwargs(self):
          # Single-letter kwargs become short options; underscored names
          # become dashed long options.
          n = Node.from_kwargs(
              'foo@bar.com',
              max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
          )
          assert sorted(n.argv) == sorted([
              '-m celery worker --detach',
              '-A foo',
              '--executable={0}'.format(n.executable),
              '-O fair',
              '-n foo@bar.com',
              '--logfile=foo%I.log',
              '-Q q1,q2',
              '--max-tasks-per-child=30',
              '--pidfile=foo.pid',
              '',
          ])

      @patch('os.kill')
      def test_send(self, kill):
          # send() signals the node's pid and reports success.
          assert self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      @patch('os.kill')
      def test_send__ESRCH(self, kill):
          # "No such process" is reported as a falsy result, not raised.
          kill.side_effect = OSError()
          kill.side_effect.errno = errno.ESRCH
          assert not self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      @patch('os.kill')
      def test_send__error(self, kill):
          # Any other OSError propagates to the caller.
          kill.side_effect = OSError()
          kill.side_effect.errno = errno.ENOENT
          with pytest.raises(OSError):
              self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      def test_alive(self):
          # alive() is send(0), i.e. a signal-0 existence check.
          self.node.send = Mock(name='send')
          assert self.node.alive() is self.node.send.return_value
          self.node.send.assert_called_with(0)

      def test_start(self):
          # start() forwards argv, the executable and all kwargs to
          # _waitexec().
          self.node._waitexec = Mock(name='_waitexec')
          self.node.start(env={'foo': 'bar'}, kw=2)
          self.node._waitexec.assert_called_with(
              self.node.argv, path=self.node.executable,
              env={'foo': 'bar'}, kw=2,
          )

      @patch('celery.apps.multi.Popen')
      def test_waitexec(self, Popen):
          # Built per call instead of being a (shared) mutable default.
          argv = ['A', 'B']
          on_spawn = Mock(name='on_spawn')
          on_signalled = Mock(name='on_signalled')
          on_failure = Mock(name='on_failure')
          env = Mock(name='env')
          self.node.handle_process_exit = Mock(name='handle_process_exit')
          self.node._waitexec(
              argv,
              path='python',
              env=env,
              on_spawn=on_spawn,
              on_signalled=on_signalled,
              on_failure=on_failure,
          )
          Popen.assert_called_with(
              self.node.prepare_argv(argv, 'python'), env=env)
          # The exit status is whatever waiting on the spawned process
          # returned.  Refer to it via return_value: calling ``Popen()`` and
          # ``.wait()`` here would record extra calls on the mocks.
          self.node.handle_process_exit.assert_called_with(
              Popen.return_value.wait.return_value,
              on_signalled=on_signalled,
              on_failure=on_failure,
          )

      def test_handle_process_exit(self):
          # A zero exit code is returned unchanged.
          assert self.node.handle_process_exit(0) == 0

      def test_handle_process_exit__failure(self):
          # A positive non-zero code triggers on_failure(node, code).
          on_failure = Mock(name='on_failure')
          assert self.node.handle_process_exit(9, on_failure=on_failure) == 9
          on_failure.assert_called_with(self.node, 9)

      def test_handle_process_exit__signalled(self):
          # A negative code -N means "killed by signal N": on_signalled is
          # called with N and N is returned.
          on_signalled = Mock(name='on_signalled')
          assert self.node.handle_process_exit(
              -9, on_signalled=on_signalled) == 9
          on_signalled.assert_called_with(self.node, 9)

      def test_logfile(self):
          # The logfile path is the expansion of the '%n%I.log' template.
          assert self.node.logfile == self.expander.return_value
          self.expander.assert_called_with('%n%I.log')
  238. class test_Cluster:
  def setup(self):
      # Patch process spawning, signalling, hostname lookup and pidfile
      # access with mocks that the tests can inspect.
      self.Popen = self.patching('celery.apps.multi.Popen')
      self.kill = self.patching('os.kill')
      self.gethostname = self.patching('celery.apps.multi.gethostname')
      self.gethostname.return_value = 'example.com'
      self.Pidfile = self.patching('celery.apps.multi.Pidfile')

      # Every lifecycle hook is a Mock named after itself, so tests can
      # assert on ``self.cluster.<hook>``.
      hook_names = (
          'on_stopping_preamble',
          'on_send_signal',
          'on_still_waiting_for',
          'on_still_waiting_progress',
          'on_still_waiting_end',
          'on_node_start',
          'on_node_restart',
          'on_node_shutdown_ok',
          'on_node_status',
          'on_node_signal',
          'on_node_signal_dead',
          'on_node_down',
          'on_child_spawn',
          'on_child_signalled',
          'on_child_failure',
      )
      self.cluster = Cluster(
          [Node('foo@example.com'),
           Node('bar@example.com'),
           Node('baz@example.com')],
          **{hook: Mock(name=hook) for hook in hook_names}
      )
  def test_len(self):
      # The cluster fixture is built from three nodes.
      expected = 3
      assert expected == len(self.cluster)
  def test_getitem(self):
      # Indexing returns nodes in construction order.
      first = self.cluster[0]
      assert first.name == 'foo@example.com'
  def test_start(self):
      # start() calls start_node() once per node, in order.
      self.cluster.start_node = Mock(name='start_node')
      self.cluster.start()
      # assert_has_calls() takes a list: a generator is consumed while
      # matching and would show up empty in the failure message.
      self.cluster.start_node.assert_has_calls(
          [call(node) for node in self.cluster]
      )
  def test_start_node(self):
      # start_node() announces the node, delegates to _start_node() and
      # reports the resulting status.
      self.cluster._start_node = Mock(name='_start_node')
      node = self.cluster[0]
      status = self.cluster._start_node.return_value
      assert self.cluster.start_node(node) is status
      self.cluster.on_node_start.assert_called_with(node)
      self.cluster._start_node.assert_called_with(node)
      # Use return_value rather than calling the mock again, which would
      # record an extra, argument-less call on _start_node.
      self.cluster.on_node_status.assert_called_with(node, status)
  def test__start_node(self):
      # _start_node() starts the node with the cluster env and wires the
      # cluster's child-process hooks into it.
      first = self.cluster[0]
      first.start = Mock(name='node.start')
      outcome = self.cluster._start_node(first)
      assert outcome is first.start.return_value
      first.start.assert_called_with(
          self.cluster.env,
          on_spawn=self.cluster.on_child_spawn,
          on_failure=self.cluster.on_child_failure,
          on_signalled=self.cluster.on_child_signalled,
      )
  def test_send_all(self):
      # send_all() signals every node returned by getpids().
      nodes = [Mock(name='n1'), Mock(name='n2')]
      self.cluster.getpids = Mock(name='getpids')
      self.cluster.getpids.return_value = nodes
      self.cluster.send_all(15)
      # 15 is SIGTERM on common platforms; the hook receives the signal's
      # short name.  Pass a list: assert_has_calls() consumes a generator
      # and would then print it as empty on failure.
      self.cluster.on_node_signal.assert_has_calls(
          [call(node, 'TERM') for node in nodes]
      )
      for node in nodes:
          node.send.assert_called_with(15, self.cluster.on_node_signal_dead)
  @skip.if_win32()
  def test_kill(self):
      # kill() is send_all(SIGKILL); SIGKILL is unavailable on Windows.
      send_all = self.cluster.send_all = Mock(name='.send_all')
      self.cluster.kill()
      send_all.assert_called_with(signal.SIGKILL)
  def test_getpids(self):
      self.gethostname.return_value = 'e.com'
      # foo.pid -> 10 and bar.pid -> 11; baz has no readable pidfile.
      self.prepare_pidfile_for_getpids(self.Pidfile)
      callback = Mock()
      p = Cluster([
          Node('foo@e.com'),
          Node('bar@e.com'),
          Node('baz@e.com'),
      ])
      nodes = p.getpids(on_down=callback)
      node_0, node_1 = nodes
      assert node_0.name == 'foo@e.com'
      assert sorted(node_0.argv) == sorted([
          '',
          '--executable={0}'.format(node_0.executable),
          '--logfile=foo%I.log',
          '--pidfile=foo.pid',
          '-m celery worker --detach',
          '-n foo@e.com',
      ])
      assert node_0.pid == 10
      assert node_1.name == 'bar@e.com'
      assert sorted(node_1.argv) == sorted([
          '',
          '--executable={0}'.format(node_1.executable),
          '--logfile=bar%I.log',
          '--pidfile=bar.pid',
          '-m celery worker --detach',
          '-n bar@e.com',
      ])
      assert node_1.pid == 11
      # Without a callback it should work too.  Use the default on_down
      # (the first positional parameter is on_down, so a string must not
      # be passed) and consume the result with list(): if getpids() is a
      # generator, merely calling it executes nothing.
      remaining = list(p.getpids())
      assert [node.name for node in remaining] == ['foo@e.com', 'bar@e.com']
  def prepare_pidfile_for_getpids(self, Pidfile):
      """Make the ``Pidfile`` mock construct fake pidfile readers.

      ``read_pid()`` returns 10 for ``foo.pid`` and 11 for ``bar.pid``;
      any other path raises :exc:`ValueError`.
      """

      class pids(object):

          def __init__(self, path):
              self.path = path

          def read_pid(self):
              try:
                  return {'foo.pid': 10,
                          'bar.pid': 11}[self.path]
              except KeyError:
                  raise ValueError()

      # Configure the mock passed in instead of ignoring the argument and
      # reaching for ``self.Pidfile``.
      Pidfile.side_effect = pids