test_multi.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. from __future__ import absolute_import, unicode_literals
  2. import errno
  3. import signal
  4. import sys
  5. import pytest
  6. from case import Mock, call, patch, skip
  7. from celery.apps.multi import (Cluster, MultiParser, NamespacedOptionParser,
  8. Node, format_opt)
  class test_functions:
      """Tests for the module-level helpers of ``celery.apps.multi``."""

      def test_parse_ns_range(self):
          m = MultiParser()
          # With ranges enabled, "1-3" expands into each member of the range.
          # (Previously written as ``assert x, [...]``, which only checked
          # that the result was truthy and could never fail.)
          assert m._parse_ns_range('1-3', True) == ['1', '2', '3']
          # With ranges disabled the token is returned verbatim.
          assert m._parse_ns_range('1-3', False) == ['1-3']
          # Comma-separated items mix expanded ranges and single values.
          assert m._parse_ns_range('1-3,10,11,20', True) == [
              '1', '2', '3', '10', '11', '20',
          ]

      def test_format_opt(self):
          # A None value yields the bare flag; short options are joined to
          # their value with a space, long options with '='.
          assert format_opt('--foo', None) == '--foo'
          assert format_opt('-c', 1) == '-c 1'
          assert format_opt('--log', 'foo') == '--log=foo'
  class test_NamespacedOptionParser:
      """Tests for parsing namespaced ``-opt:ns value`` command lines."""

      def test_parse(self):
          # A namespaced option stores its value under the namespace key.
          x = NamespacedOptionParser(['-c:1,3', '4'])
          x.parse()
          assert x.namespaces.get('1,3') == {'-c': '4'}

          x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                      '--loglevel:kramer=DEBUG',
                                      '--flag',
                                      '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                      '--', '.disable_rate_limits=1'])
          x.parse()
          # Non-namespaced options land in ``options``; a bare flag maps to None.
          assert x.options == {
              '--logfile': 'foo',
              '-Q': 'bar',
              '--flag': None,
          }
          # Positional arguments are collected in order.  (Previously written
          # as ``assert x.values, [...]``, which could never fail.)
          assert x.values == ['a', 'b']
          assert x.namespaces.get('jerry,elaine') == {'-c': '5'}
          assert x.namespaces.get('kramer') == {'--loglevel': 'DEBUG'}
          # Everything after ``--`` is kept verbatim as a single string.
          assert x.passthrough == '-- .disable_rate_limits=1'
  def multi_args(p, *args, **kwargs):
      """Construct a ``MultiParser`` from *args*/*kwargs* and parse *p*.

      Returns whatever ``MultiParser.parse`` returns for the already-parsed
      ``NamespacedOptionParser`` *p* (the tests consume it as an iterable of
      nodes).
      """
      parser = MultiParser(*args, **kwargs)
      return parser.parse(p)
  class test_multi_args:
      """Tests for expanding parsed multi options into worker nodes."""

      @patch('celery.apps.multi.gethostname')
      def test_parse(self, gethostname):
          gethostname.return_value = 'example.com'
          parser = NamespacedOptionParser([
              '-c:jerry,elaine', '5',
              '--loglevel:kramer=DEBUG',
              '--flag',
              '--logfile=foo', '-Q', 'bar', 'jerry',
              'elaine', 'kramer',
              '--', '.disable_rate_limits=1',
          ])
          parser.parse()
          nodes = list(multi_args(parser, cmd='COMMAND', append='*AP*',
                                  prefix='*P*', suffix='*S*'))

          def assert_line_in(name, expected_args):
              # The node must exist; if several share the name, the last
              # one's argv is checked.
              assert name in {node.name for node in nodes}
              matches = [node.argv for node in nodes if node.name == name]
              argv = matches[-1] if matches else None
              assert argv
              for arg in expected_args:
                  assert arg in argv

          # Options every node receives regardless of namespace.
          shared = ['-Q bar', '--flag', '--logfile=foo',
                    '-- .disable_rate_limits=1', '*AP*']
          for who in ('jerry', 'elaine'):
              hostname = '*P*{0}@*S*'.format(who)
              assert_line_in(hostname,
                             ['COMMAND', '-n ' + hostname, '-c 5'] + shared)
          assert_line_in(
              '*P*kramer@*S*',
              ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*'] + shared,
          )

          expand = nodes[0].expander
          assert expand('%h') == '*P*jerry@*S*'
          assert expand('%n') == '*P*jerry'

          # With an empty append, the passthrough is the last argv entry.
          without_append = list(multi_args(parser, cmd='COMMAND', append='',
                                           prefix='*P*', suffix='*S*'))
          assert without_append[0].argv[-1] == '-- .disable_rate_limits=1'

          def _args(name, *extra):
              # Default trailing options generated for a node named *name*.
              return extra + (
                  '--pidfile={0}.pid'.format(name),
                  '--logfile={0}%I.log'.format(name),
                  '--executable={0}'.format(sys.executable),
                  '',
              )

          # A leading integer requests that many auto-named nodes.
          numbered = NamespacedOptionParser(['10', '-c:1', '5'])
          numbered.parse()
          numbered_nodes = list(multi_args(numbered, cmd='COMMAND'))
          assert len(numbered_nodes) == 10
          assert numbered_nodes[0].name == 'celery1@example.com'
          assert numbered_nodes[0].argv == (
              'COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1')
          for index, worker in enumerate(numbered_nodes[1:], start=2):
              node_name = 'celery{0}'.format(index)
              assert worker.name == '{0}@example.com'.format(node_name)
              assert worker.argv == (
                  'COMMAND', '-n {0}@example.com'.format(node_name),
              ) + _args(node_name)

          # A suffix of '""' leaves the host part empty.
          empty_suffix = list(multi_args(numbered, cmd='COMMAND', suffix='""'))
          assert len(empty_suffix) == 10
          assert empty_suffix[0].name == 'celery1@'
          assert empty_suffix[0].argv == (
              'COMMAND', '-c 5', '-n celery1@') + _args('celery1')

          # Single explicitly named nodes: (argv, expected name, options).
          cases = [
              (['foo@', '-c:foo', '5'], 'foo@', ('-c 5', '-n foo@')),
              (['foo', '-Q:1', 'test'], 'foo@', ('-Q test', '-n foo@')),
              (['foo@bar', '-Q:1', 'test'], 'foo@bar',
               ('-Q test', '-n foo@bar')),
          ]
          for argv, expected_name, expected_opts in cases:
              opts = NamespacedOptionParser(argv)
              opts.parse()
              head = list(multi_args(opts, cmd='COMMAND', suffix='""'))[0]
              assert head.name == expected_name
              assert head.argv == ('COMMAND',) + expected_opts + _args('foo')

          # Namespace '0' does not match the single node (the case above
          # addresses it as '1'), so expansion raises KeyError.
          bad_index = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
          bad_index.parse()
          with pytest.raises(KeyError):
              list(multi_args(bad_index))

      def test_optmerge(self):
          parser = NamespacedOptionParser(['foo', 'test'])
          parser.parse()
          parser.options = {'x': 'y'}
          # Global options are merged into the per-namespace result.
          merged = parser.optmerge('foo')
          assert merged['x'] == 'y'
  class test_Node:
      """Tests for a single ``Node`` (one worker of a multi cluster)."""

      def setup_method(self):
          # Native pytest xunit hook.  The nose-style ``setup`` name used
          # before is deprecated since pytest 7.2 and not called by pytest 8.
          # Nothing here depends on fixtures, so running it early is safe.
          # NOTE(review): self.p is not read by any test in this class.
          self.p = Mock(name='p')
          self.p.options = {
              '--executable': 'python',
              '--logfile': 'foo.log',
          }
          self.p.namespaces = {}
          self.node = Node('foo@bar.com', options={'-A': 'proj'})
          # Replace the node's expander so tests can inspect how it is called.
          self.expander = self.node.expander = Mock(name='expander')
          self.node.pid = 303

      def test_from_kwargs(self):
          # Keyword arguments are translated into worker command-line options.
          n = Node.from_kwargs(
              'foo@bar.com',
              max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
          )
          assert sorted(n.argv) == sorted([
              '-m celery worker --detach',
              '-A foo',
              '--executable={0}'.format(n.executable),
              '-O fair',
              '-n foo@bar.com',
              '--logfile=foo%I.log',
              '-Q q1,q2',
              '--max-tasks-per-child=30',
              '--pidfile=foo.pid',
              '',
          ])

      @patch('os.kill')
      def test_send(self, kill):
          assert self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      @patch('os.kill')
      def test_send__ESRCH(self, kill):
          # ESRCH (no such process) is reported as a falsy return value.
          kill.side_effect = OSError()
          kill.side_effect.errno = errno.ESRCH
          assert not self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      @patch('os.kill')
      def test_send__error(self, kill):
          # Any other OSError propagates to the caller.
          kill.side_effect = OSError()
          kill.side_effect.errno = errno.ENOENT
          with pytest.raises(OSError):
              self.node.send(9)
          kill.assert_called_with(self.node.pid, 9)

      def test_alive(self):
          # alive() is a probe that sends signal 0.
          self.node.send = Mock(name='send')
          assert self.node.alive() is self.node.send.return_value
          self.node.send.assert_called_with(0)

      def test_start(self):
          self.node._waitexec = Mock(name='_waitexec')
          self.node.start(env={'foo': 'bar'}, kw=2)
          self.node._waitexec.assert_called_with(
              self.node.argv, path=self.node.executable,
              env={'foo': 'bar'}, kw=2,
          )

      @patch('celery.apps.multi.Popen')
      def test_waitexec(self, Popen, argv=['A', 'B']):
          on_spawn = Mock(name='on_spawn')
          on_signalled = Mock(name='on_signalled')
          on_failure = Mock(name='on_failure')
          env = Mock(name='env')
          self.node.handle_process_exit = Mock(name='handle_process_exit')

          self.node._waitexec(
              argv,
              path='python',
              env=env,
              on_spawn=on_spawn,
              on_signalled=on_signalled,
              on_failure=on_failure,
          )

          Popen.assert_called_with(
              self.node.prepare_argv(argv, 'python'), env=env)
          # The child's exit status is handed to handle_process_exit.
          self.node.handle_process_exit.assert_called_with(
              Popen().wait(),
              on_signalled=on_signalled,
              on_failure=on_failure,
          )

      def test_handle_process_exit(self):
          assert self.node.handle_process_exit(0) == 0

      def test_handle_process_exit__failure(self):
          # A positive exit code is passed to on_failure and returned as is.
          on_failure = Mock(name='on_failure')
          assert self.node.handle_process_exit(9, on_failure=on_failure) == 9
          on_failure.assert_called_with(self.node, 9)

      def test_handle_process_exit__signalled(self):
          # A negative code is reported to on_signalled as a positive signum.
          on_signalled = Mock(name='on_signalled')
          assert self.node.handle_process_exit(
              -9, on_signalled=on_signalled) == 9
          on_signalled.assert_called_with(self.node, 9)

      def test_logfile(self):
          assert self.node.logfile == self.expander.return_value
          self.expander.assert_called_with('%n%I.log')
  class test_Cluster:
      """Tests for ``Cluster``, an ordered collection of ``Node`` objects."""

      def setup(self):
          # NOTE(review): nose-style ``setup`` is deprecated since pytest 7.2
          # and no longer called by pytest 8.  It reads ``self.patching``,
          # which this class does not define (presumably injected by a
          # conftest fixture), so renaming it to ``setup_method`` is only safe
          # where xunit setup runs after fixtures (pytest >= 4.2) -- confirm
          # the supported pytest range before migrating.
          self.Popen = self.patching('celery.apps.multi.Popen')
          self.kill = self.patching('os.kill')
          self.gethostname = self.patching('celery.apps.multi.gethostname')
          self.gethostname.return_value = 'example.com'
          self.Pidfile = self.patching('celery.apps.multi.Pidfile')
          self.cluster = Cluster(
              [Node('foo@example.com'),
               Node('bar@example.com'),
               Node('baz@example.com')],
              on_stopping_preamble=Mock(name='on_stopping_preamble'),
              on_send_signal=Mock(name='on_send_signal'),
              on_still_waiting_for=Mock(name='on_still_waiting_for'),
              on_still_waiting_progress=Mock(name='on_still_waiting_progress'),
              on_still_waiting_end=Mock(name='on_still_waiting_end'),
              on_node_start=Mock(name='on_node_start'),
              on_node_restart=Mock(name='on_node_restart'),
              on_node_shutdown_ok=Mock(name='on_node_shutdown_ok'),
              on_node_status=Mock(name='on_node_status'),
              on_node_signal=Mock(name='on_node_signal'),
              on_node_signal_dead=Mock(name='on_node_signal_dead'),
              on_node_down=Mock(name='on_node_down'),
              on_child_spawn=Mock(name='on_child_spawn'),
              on_child_signalled=Mock(name='on_child_signalled'),
              on_child_failure=Mock(name='on_child_failure'),
          )

      def test_len(self):
          assert len(self.cluster) == 3

      def test_getitem(self):
          assert self.cluster[0].name == 'foo@example.com'

      def test_start(self):
          self.cluster.start_node = Mock(name='start_node')
          self.cluster.start()
          # Every node is started, in cluster order.  A list (rather than a
          # one-shot generator) keeps the failure message readable.
          self.cluster.start_node.assert_has_calls(
              [call(node) for node in self.cluster]
          )

      def test_start_node(self):
          self.cluster._start_node = Mock(name='_start_node')
          node = self.cluster[0]
          assert (self.cluster.start_node(node) is
                  self.cluster._start_node.return_value)
          self.cluster.on_node_start.assert_called_with(node)
          self.cluster._start_node.assert_called_with(node)
          self.cluster.on_node_status.assert_called_with(
              node, self.cluster._start_node(),
          )

      def test__start_node(self):
          node = self.cluster[0]
          node.start = Mock(name='node.start')
          assert self.cluster._start_node(node) is node.start.return_value
          node.start.assert_called_with(
              self.cluster.env,
              on_spawn=self.cluster.on_child_spawn,
              on_signalled=self.cluster.on_child_signalled,
              on_failure=self.cluster.on_child_failure,
          )

      def test_send_all(self):
          nodes = [Mock(name='n1'), Mock(name='n2')]
          self.cluster.getpids = Mock(name='getpids')
          self.cluster.getpids.return_value = nodes
          self.cluster.send_all(15)
          # Signal 15 is reported by name for each node.
          self.cluster.on_node_signal.assert_has_calls(
              [call(node, 'TERM') for node in nodes]
          )
          for node in nodes:
              node.send.assert_called_with(15, self.cluster.on_node_signal_dead)

      @skip.if_win32()
      def test_kill(self):
          self.cluster.send_all = Mock(name='.send_all')
          self.cluster.kill()
          self.cluster.send_all.assert_called_with(signal.SIGKILL)

      def test_getpids(self):
          self.gethostname.return_value = 'e.com'
          self.prepare_pidfile_for_getpids(self.Pidfile)
          callback = Mock()
          p = Cluster([
              Node('foo@e.com'),
              Node('bar@e.com'),
              Node('baz@e.com'),
          ])
          nodes = p.getpids(on_down=callback)
          # Only foo and bar have readable pidfiles; unpacking into exactly
          # two names drains the generator completely.
          node_0, node_1 = nodes
          assert node_0.name == 'foo@e.com'
          assert sorted(node_0.argv) == sorted([
              '',
              '--executable={0}'.format(node_0.executable),
              '--logfile=foo%I.log',
              '--pidfile=foo.pid',
              '-m celery worker --detach',
              '-n foo@e.com',
          ])
          assert node_0.pid == 10

          assert node_1.name == 'bar@e.com'
          assert sorted(node_1.argv) == sorted([
              '',
              '--executable={0}'.format(node_1.executable),
              '--logfile=bar%I.log',
              '--pidfile=bar.pid',
              '-m celery worker --detach',
              '-n bar@e.com',
          ])
          assert node_1.pid == 11

          # The pid-less node (baz) was reported to the on_down callback.
          assert callback.call_count == 1
          assert callback.call_args[0][0].name == 'baz@e.com'

          # Without a callback, down nodes are skipped silently.  The result
          # must be consumed: getpids() is lazy, so merely calling it (as
          # this test previously did) exercised nothing.
          live = list(p.getpids())
          assert [node.name for node in live] == ['foo@e.com', 'bar@e.com']

      def prepare_pidfile_for_getpids(self, Pidfile):
          """Make the *Pidfile* mock build fake pidfile readers.

          ``foo.pid`` reads as pid 10 and ``bar.pid`` as pid 11; any other
          path raises ``ValueError`` from ``read_pid``.
          """
          class pids(object):

              def __init__(self, path):
                  self.path = path

              def read_pid(self):
                  try:
                      return {'foo.pid': 10,
                              'bar.pid': 11}[self.path]
                  except KeyError:
                      raise ValueError()

          # Configure the mock that was passed in (previously this ignored
          # the parameter and always patched ``self.Pidfile``).
          Pidfile.side_effect = pids