# test_multi.py: tests for celery.bin.multi
  1. from __future__ import absolute_import
  2. import errno
  3. import signal
  4. import sys
  5. from celery.bin.multi import (
  6. main,
  7. MultiTool,
  8. findsig,
  9. parse_ns_range,
  10. format_opt,
  11. quote,
  12. NamespacedOptionParser,
  13. multi_args,
  14. __doc__ as doc,
  15. )
  16. from celery.tests.case import AppCase, Mock, WhateverIO, SkipTest, patch
  17. class test_functions(AppCase):
  18. def test_findsig(self):
  19. self.assertEqual(findsig(['a', 'b', 'c', '-1']), 1)
  20. self.assertEqual(findsig(['--foo=1', '-9']), 9)
  21. self.assertEqual(findsig(['-INT']), signal.SIGINT)
  22. self.assertEqual(findsig([]), signal.SIGTERM)
  23. self.assertEqual(findsig(['-s']), signal.SIGTERM)
  24. self.assertEqual(findsig(['-log']), signal.SIGTERM)
  25. def test_parse_ns_range(self):
  26. self.assertEqual(parse_ns_range('1-3', True), ['1', '2', '3'])
  27. self.assertEqual(parse_ns_range('1-3', False), ['1-3'])
  28. self.assertEqual(parse_ns_range(
  29. '1-3,10,11,20', True),
  30. ['1', '2', '3', '10', '11', '20'],
  31. )
  32. def test_format_opt(self):
  33. self.assertEqual(format_opt('--foo', None), '--foo')
  34. self.assertEqual(format_opt('-c', 1), '-c 1')
  35. self.assertEqual(format_opt('--log', 'foo'), '--log=foo')
  36. def test_quote(self):
  37. self.assertEqual(quote("the 'quick"), "'the '\\''quick'")
  38. class test_NamespacedOptionParser(AppCase):
  39. def test_parse(self):
  40. x = NamespacedOptionParser(['-c:1,3', '4'])
  41. self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
  42. x = NamespacedOptionParser(['-c:jerry,elaine', '5',
  43. '--loglevel:kramer=DEBUG',
  44. '--flag',
  45. '--logfile=foo', '-Q', 'bar', 'a', 'b',
  46. '--', '.disable_rate_limits=1'])
  47. self.assertEqual(x.options, {'--logfile': 'foo',
  48. '-Q': 'bar',
  49. '--flag': None})
  50. self.assertEqual(x.values, ['a', 'b'])
  51. self.assertEqual(x.namespaces.get('jerry,elaine'), {'-c': '5'})
  52. self.assertEqual(x.namespaces.get('kramer'), {'--loglevel': 'DEBUG'})
  53. self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
  54. class test_multi_args(AppCase):
  55. @patch('celery.bin.multi.gethostname')
  56. def test_parse(self, gethostname):
  57. gethostname.return_value = 'example.com'
  58. p = NamespacedOptionParser([
  59. '-c:jerry,elaine', '5',
  60. '--loglevel:kramer=DEBUG',
  61. '--flag',
  62. '--logfile=foo', '-Q', 'bar', 'jerry',
  63. 'elaine', 'kramer',
  64. '--', '.disable_rate_limits=1',
  65. ])
  66. it = multi_args(p, cmd='COMMAND', append='*AP*',
  67. prefix='*P*', suffix='*S*')
  68. names = list(it)
  69. def assert_line_in(name, args):
  70. self.assertIn(name, [tup[0] for tup in names])
  71. argv = None
  72. for item in names:
  73. if item[0] == name:
  74. argv = item[1]
  75. self.assertTrue(argv)
  76. for arg in args:
  77. self.assertIn(arg, argv)
  78. assert_line_in(
  79. '*P*jerry@*S*',
  80. ['COMMAND', '-n *P*jerry@*S*', '-Q bar',
  81. '-c 5', '--flag', '--logfile=foo',
  82. '-- .disable_rate_limits=1', '*AP*'],
  83. )
  84. assert_line_in(
  85. '*P*elaine@*S*',
  86. ['COMMAND', '-n *P*elaine@*S*', '-Q bar',
  87. '-c 5', '--flag', '--logfile=foo',
  88. '-- .disable_rate_limits=1', '*AP*'],
  89. )
  90. assert_line_in(
  91. '*P*kramer@*S*',
  92. ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*',
  93. '-Q bar', '--flag', '--logfile=foo',
  94. '-- .disable_rate_limits=1', '*AP*'],
  95. )
  96. expand = names[0][2]
  97. self.assertEqual(expand('%h'), '*P*jerry@*S*')
  98. self.assertEqual(expand('%n'), '*P*jerry')
  99. names2 = list(multi_args(p, cmd='COMMAND', append='',
  100. prefix='*P*', suffix='*S*'))
  101. self.assertEqual(names2[0][1][-1], '-- .disable_rate_limits=1')
  102. p2 = NamespacedOptionParser(['10', '-c:1', '5'])
  103. names3 = list(multi_args(p2, cmd='COMMAND'))
  104. self.assertEqual(len(names3), 10)
  105. self.assertEqual(
  106. names3[0][0:2],
  107. ('celery1@example.com',
  108. ['COMMAND', '-n celery1@example.com', '-c 5', '']),
  109. )
  110. for i, worker in enumerate(names3[1:]):
  111. self.assertEqual(
  112. worker[0:2],
  113. ('celery%s@example.com' % (i + 2),
  114. ['COMMAND', '-n celery%s@example.com' % (i + 2), '']),
  115. )
  116. names4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
  117. self.assertEqual(len(names4), 10)
  118. self.assertEqual(
  119. names4[0][0:2],
  120. ('celery1@',
  121. ['COMMAND', '-n celery1@', '-c 5', '']),
  122. )
  123. p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
  124. names5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
  125. self.assertEqual(
  126. names5[0][0:2],
  127. ('foo@',
  128. ['COMMAND', '-n foo@', '-c 5', '']),
  129. )
  130. class test_MultiTool(AppCase):
  131. def setup(self):
  132. self.fh = WhateverIO()
  133. self.env = {}
  134. self.t = MultiTool(env=self.env, fh=self.fh)
  135. def test_note(self):
  136. self.t.note('hello world')
  137. self.assertEqual(self.fh.getvalue(), 'hello world\n')
  138. def test_note_quiet(self):
  139. self.t.quiet = True
  140. self.t.note('hello world')
  141. self.assertFalse(self.fh.getvalue())
  142. def test_carp(self):
  143. self.t.say = Mock()
  144. self.t.carp('foo')
  145. self.t.say.assert_called_with('foo', True, self.t.stderr)
  146. def test_info(self):
  147. self.t.verbose = True
  148. self.t.info('hello info')
  149. self.assertEqual(self.fh.getvalue(), 'hello info\n')
  150. def test_info_not_verbose(self):
  151. self.t.verbose = False
  152. self.t.info('hello info')
  153. self.assertFalse(self.fh.getvalue())
  154. def test_error(self):
  155. self.t.carp = Mock()
  156. self.t.usage = Mock()
  157. self.assertEqual(self.t.error('foo'), 1)
  158. self.t.carp.assert_called_with('foo')
  159. self.t.usage.assert_called_with()
  160. self.t.carp = Mock()
  161. self.assertEqual(self.t.error(), 1)
  162. self.assertFalse(self.t.carp.called)
  163. self.assertEqual(self.t.retcode, 1)
  164. @patch('celery.bin.multi.Popen')
  165. def test_waitexec(self, Popen):
  166. self.t.note = Mock()
  167. pipe = Popen.return_value = Mock()
  168. pipe.wait.return_value = -10
  169. self.assertEqual(self.t.waitexec(['-m', 'foo'], 'path'), 10)
  170. Popen.assert_called_with(['path', '-m', 'foo'], env=self.t.env)
  171. self.t.note.assert_called_with('* Child was terminated by signal 10')
  172. pipe.wait.return_value = 2
  173. self.assertEqual(self.t.waitexec(['-m', 'foo'], 'path'), 2)
  174. self.t.note.assert_called_with(
  175. '* Child terminated with errorcode 2',
  176. )
  177. pipe.wait.return_value = 0
  178. self.assertFalse(self.t.waitexec(['-m', 'foo', 'path']))
  179. def test_nosplash(self):
  180. self.t.nosplash = True
  181. self.t.splash()
  182. self.assertFalse(self.fh.getvalue())
  183. def test_splash(self):
  184. self.t.nosplash = False
  185. self.t.splash()
  186. self.assertIn('celery multi', self.fh.getvalue())
  187. def test_usage(self):
  188. self.t.usage()
  189. self.assertTrue(self.fh.getvalue())
  190. def test_help(self):
  191. self.t.help([])
  192. self.assertIn(doc, self.fh.getvalue())
  193. def test_expand(self):
  194. self.t.expand(['foo%n', 'ask', 'klask', 'dask'])
  195. self.assertEqual(
  196. self.fh.getvalue(), 'fooask\nfooklask\nfoodask\n',
  197. )
  198. def test_restart(self):
  199. stop = self.t._stop_nodes = Mock()
  200. self.t.restart(['jerry', 'george'], 'celery worker')
  201. waitexec = self.t.waitexec = Mock()
  202. self.assertTrue(stop.called)
  203. callback = stop.call_args[1]['callback']
  204. self.assertTrue(callback)
  205. waitexec.return_value = 0
  206. callback('jerry', ['arg'], 13)
  207. waitexec.assert_called_with(['arg'], path=sys.executable)
  208. self.assertIn('OK', self.fh.getvalue())
  209. self.fh.seek(0)
  210. self.fh.truncate()
  211. waitexec.return_value = 1
  212. callback('jerry', ['arg'], 13)
  213. self.assertIn('FAILED', self.fh.getvalue())
  214. def test_stop(self):
  215. self.t.getpids = Mock()
  216. self.t.getpids.return_value = [2, 3, 4]
  217. self.t.shutdown_nodes = Mock()
  218. self.t.stop(['a', 'b', '-INT'], 'celery worker')
  219. self.t.shutdown_nodes.assert_called_with(
  220. [2, 3, 4], sig=signal.SIGINT, retry=None, callback=None,
  221. )
  222. def test_kill(self):
  223. if not hasattr(signal, 'SIGKILL'):
  224. raise SkipTest('SIGKILL not supported by this platform')
  225. self.t.getpids = Mock()
  226. self.t.getpids.return_value = [
  227. ('a', None, 10),
  228. ('b', None, 11),
  229. ('c', None, 12)
  230. ]
  231. sig = self.t.signal_node = Mock()
  232. self.t.kill(['a', 'b', 'c'], 'celery worker')
  233. sigs = sig.call_args_list
  234. self.assertEqual(len(sigs), 3)
  235. self.assertEqual(sigs[0][0], ('a', 10, signal.SIGKILL))
  236. self.assertEqual(sigs[1][0], ('b', 11, signal.SIGKILL))
  237. self.assertEqual(sigs[2][0], ('c', 12, signal.SIGKILL))
  238. def prepare_pidfile_for_getpids(self, Pidfile):
  239. class pids(object):
  240. def __init__(self, path):
  241. self.path = path
  242. def read_pid(self):
  243. try:
  244. return {'foo.pid': 10,
  245. 'bar.pid': 11}[self.path]
  246. except KeyError:
  247. raise ValueError()
  248. Pidfile.side_effect = pids
  249. @patch('celery.bin.multi.Pidfile')
  250. @patch('celery.bin.multi.gethostname')
  251. def test_getpids(self, gethostname, Pidfile):
  252. gethostname.return_value = 'e.com'
  253. self.prepare_pidfile_for_getpids(Pidfile)
  254. callback = Mock()
  255. p = NamespacedOptionParser(['foo', 'bar', 'baz'])
  256. nodes = self.t.getpids(p, 'celery worker', callback=callback)
  257. node_0, node_1 = nodes
  258. self.assertEqual(node_0[0], 'foo@e.com')
  259. self.assertEqual(
  260. sorted(node_0[1]),
  261. sorted(('celery worker', '--pidfile=foo.pid',
  262. '-n foo@e.com', '')),
  263. )
  264. self.assertEqual(node_0[2], 10)
  265. self.assertEqual(node_1[0], 'bar@e.com')
  266. self.assertEqual(
  267. sorted(node_1[1]),
  268. sorted(('celery worker', '--pidfile=bar.pid',
  269. '-n bar@e.com', '')),
  270. )
  271. self.assertEqual(node_1[2], 11)
  272. self.assertTrue(callback.called)
  273. cargs, _ = callback.call_args
  274. self.assertEqual(cargs[0], 'baz@e.com')
  275. self.assertItemsEqual(
  276. cargs[1],
  277. ['celery worker', '--pidfile=baz.pid', '-n baz@e.com', ''],
  278. )
  279. self.assertIsNone(cargs[2])
  280. self.assertIn('DOWN', self.fh.getvalue())
  281. # without callback, should work
  282. nodes = self.t.getpids(p, 'celery worker', callback=None)
  283. @patch('celery.bin.multi.Pidfile')
  284. @patch('celery.bin.multi.gethostname')
  285. @patch('celery.bin.multi.sleep')
  286. def test_shutdown_nodes(self, slepp, gethostname, Pidfile):
  287. gethostname.return_value = 'e.com'
  288. self.prepare_pidfile_for_getpids(Pidfile)
  289. self.assertIsNone(self.t.shutdown_nodes([]))
  290. self.t.signal_node = Mock()
  291. node_alive = self.t.node_alive = Mock()
  292. self.t.node_alive.return_value = False
  293. callback = Mock()
  294. self.t.stop(['foo', 'bar', 'baz'], 'celery worker', callback=callback)
  295. sigs = sorted(self.t.signal_node.call_args_list)
  296. self.assertEqual(len(sigs), 2)
  297. self.assertIn(
  298. ('foo@e.com', 10, signal.SIGTERM),
  299. [tup[0] for tup in sigs],
  300. )
  301. self.assertIn(
  302. ('bar@e.com', 11, signal.SIGTERM),
  303. [tup[0] for tup in sigs],
  304. )
  305. self.t.signal_node.return_value = False
  306. self.assertTrue(callback.called)
  307. self.t.stop(['foo', 'bar', 'baz'], 'celery worker', callback=None)
  308. def on_node_alive(pid):
  309. if node_alive.call_count > 4:
  310. return True
  311. return False
  312. self.t.signal_node.return_value = True
  313. self.t.node_alive.side_effect = on_node_alive
  314. self.t.stop(['foo', 'bar', 'baz'], 'celery worker', retry=True)
  315. @patch('os.kill')
  316. def test_node_alive(self, kill):
  317. kill.return_value = True
  318. self.assertTrue(self.t.node_alive(13))
  319. esrch = OSError()
  320. esrch.errno = errno.ESRCH
  321. kill.side_effect = esrch
  322. self.assertFalse(self.t.node_alive(13))
  323. kill.assert_called_with(13, 0)
  324. enoent = OSError()
  325. enoent.errno = errno.ENOENT
  326. kill.side_effect = enoent
  327. with self.assertRaises(OSError):
  328. self.t.node_alive(13)
  329. @patch('os.kill')
  330. def test_signal_node(self, kill):
  331. kill.return_value = True
  332. self.assertTrue(self.t.signal_node('foo', 13, 9))
  333. esrch = OSError()
  334. esrch.errno = errno.ESRCH
  335. kill.side_effect = esrch
  336. self.assertFalse(self.t.signal_node('foo', 13, 9))
  337. kill.assert_called_with(13, 9)
  338. self.assertIn('Could not signal foo', self.fh.getvalue())
  339. enoent = OSError()
  340. enoent.errno = errno.ENOENT
  341. kill.side_effect = enoent
  342. with self.assertRaises(OSError):
  343. self.t.signal_node('foo', 13, 9)
  344. def test_start(self):
  345. self.t.waitexec = Mock()
  346. self.t.waitexec.return_value = 0
  347. self.assertFalse(self.t.start(['foo', 'bar', 'baz'], 'celery worker'))
  348. self.t.waitexec.return_value = 1
  349. self.assertFalse(self.t.start(['foo', 'bar', 'baz'], 'celery worker'))
  350. def test_show(self):
  351. self.t.show(['foo', 'bar', 'baz'], 'celery worker')
  352. self.assertTrue(self.fh.getvalue())
  353. @patch('celery.bin.multi.gethostname')
  354. def test_get(self, gethostname):
  355. gethostname.return_value = 'e.com'
  356. self.t.get(['xuzzy@e.com', 'foo', 'bar', 'baz'], 'celery worker')
  357. self.assertFalse(self.fh.getvalue())
  358. self.t.get(['foo@e.com', 'foo', 'bar', 'baz'], 'celery worker')
  359. self.assertTrue(self.fh.getvalue())
  360. @patch('celery.bin.multi.gethostname')
  361. def test_names(self, gethostname):
  362. gethostname.return_value = 'e.com'
  363. self.t.names(['foo', 'bar', 'baz'], 'celery worker')
  364. self.assertIn('foo@e.com\nbar@e.com\nbaz@e.com', self.fh.getvalue())
  365. def test_execute_from_commandline(self):
  366. start = self.t.commands['start'] = Mock()
  367. self.t.error = Mock()
  368. self.t.execute_from_commandline(['multi', 'start', 'foo', 'bar'])
  369. self.assertFalse(self.t.error.called)
  370. start.assert_called_with(['foo', 'bar'], 'celery worker')
  371. self.t.error = Mock()
  372. self.t.execute_from_commandline(['multi', 'frob', 'foo', 'bar'])
  373. self.t.error.assert_called_with('Invalid command: frob')
  374. self.t.error = Mock()
  375. self.t.execute_from_commandline(['multi'])
  376. self.t.error.assert_called_with()
  377. self.t.error = Mock()
  378. self.t.execute_from_commandline(['multi', '-foo'])
  379. self.t.error.assert_called_with()
  380. self.t.execute_from_commandline(
  381. ['multi', 'start', 'foo',
  382. '--nosplash', '--quiet', '-q', '--verbose', '--no-color'],
  383. )
  384. self.assertTrue(self.t.nosplash)
  385. self.assertTrue(self.t.quiet)
  386. self.assertTrue(self.t.verbose)
  387. self.assertTrue(self.t.no_color)
  388. def test_stopwait(self):
  389. self.t._stop_nodes = Mock()
  390. self.t.stopwait(['foo', 'bar', 'baz'], 'celery worker')
  391. self.assertEqual(self.t._stop_nodes.call_args[1]['retry'], 2)
  392. @patch('celery.bin.multi.MultiTool')
  393. def test_main(self, MultiTool):
  394. m = MultiTool.return_value = Mock()
  395. with self.assertRaises(SystemExit):
  396. main()
  397. m.execute_from_commandline.assert_called_with(sys.argv)