multi.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. from __future__ import absolute_import, unicode_literals
  2. import errno
  3. import os
  4. import shlex
  5. import signal
  6. import sys
  7. from collections import OrderedDict, defaultdict
  8. from functools import partial
  9. from subprocess import Popen
  10. from time import sleep
  11. from kombu.utils.encoding import from_utf8
  12. from kombu.utils.objects import cached_property
  13. from celery.five import UserList, items
  14. from celery.platforms import IS_WINDOWS, Pidfile, signal_name, signals
  15. from celery.utils.nodenames import (
  16. gethostname, host_format, node_format, nodesplit,
  17. )
  18. from celery.utils.saferepr import saferepr
  19. __all__ = ['Cluster', 'Node']
  20. CELERY_EXE = 'celery'
  21. def celery_exe(*args):
  22. return ' '.join((CELERY_EXE,) + args)
  23. class NamespacedOptionParser(object):
  24. def __init__(self, args):
  25. self.args = args
  26. self.options = OrderedDict()
  27. self.values = []
  28. self.passthrough = ''
  29. self.namespaces = defaultdict(lambda: OrderedDict())
  30. self.parse()
  31. def parse(self):
  32. rargs = list(self.args)
  33. pos = 0
  34. while pos < len(rargs):
  35. arg = rargs[pos]
  36. if arg == '--':
  37. self.passthrough = ' '.join(rargs[pos:])
  38. break
  39. elif arg[0] == '-':
  40. if arg[1] == '-':
  41. self.process_long_opt(arg[2:])
  42. else:
  43. value = None
  44. if len(rargs) > pos + 1 and rargs[pos + 1][0] != '-':
  45. value = rargs[pos + 1]
  46. pos += 1
  47. self.process_short_opt(arg[1:], value)
  48. else:
  49. self.values.append(arg)
  50. pos += 1
  51. def process_long_opt(self, arg, value=None):
  52. if '=' in arg:
  53. arg, value = arg.split('=', 1)
  54. self.add_option(arg, value, short=False)
  55. def process_short_opt(self, arg, value=None):
  56. self.add_option(arg, value, short=True)
  57. def optmerge(self, ns, defaults=None):
  58. if defaults is None:
  59. defaults = self.options
  60. return OrderedDict(defaults, **self.namespaces[ns])
  61. def add_option(self, name, value, short=False, ns=None):
  62. prefix = short and '-' or '--'
  63. dest = self.options
  64. if ':' in name:
  65. name, ns = name.split(':')
  66. dest = self.namespaces[ns]
  67. dest[prefix + name] = value
  68. class Node(object):
  69. def __init__(self, name, argv, expander, namespace, p):
  70. self.p = p
  71. self.name = name
  72. self.argv = tuple(argv)
  73. self.expander = expander
  74. self.namespace = namespace
  75. self._pid = None
  76. def alive(self):
  77. return self.send(0)
  78. def send(self, sig, on_error=None):
  79. pid = self.pid
  80. if pid:
  81. try:
  82. os.kill(pid, sig)
  83. except OSError as exc:
  84. if exc.errno != errno.ESRCH:
  85. raise
  86. maybe_call(on_error, self)
  87. return False
  88. return True
  89. maybe_call(on_error, self)
  90. def start(self, env=None, **kwargs):
  91. return self._waitexec(
  92. self.argv, path=self.executable, env=env, **kwargs)
  93. def _waitexec(self, argv, path=sys.executable, env=None,
  94. on_spawn=None, on_signalled=None, on_failure=None):
  95. argstr = self.prepare_argv(argv, path)
  96. maybe_call(on_spawn, self, argstr=' '.join(argstr), env=env)
  97. pipe = Popen(argstr, env=env)
  98. return self.handle_process_exit(
  99. pipe.wait(),
  100. on_signalled=on_signalled,
  101. on_failure=on_failure,
  102. )
  103. def handle_process_exit(self, retcode, on_signalled=None, on_failure=None):
  104. if retcode < 0:
  105. maybe_call(on_signalled, self, -retcode)
  106. return -retcode
  107. elif retcode > 0:
  108. maybe_call(on_failure, self, retcode)
  109. return retcode
  110. def prepare_argv(self, argv, path):
  111. args = ' '.join([path] + list(argv))
  112. return shlex.split(from_utf8(args), posix=not IS_WINDOWS)
  113. def getopt(self, *alt):
  114. try:
  115. return self._getnsopt(*alt)
  116. except KeyError:
  117. return self._getoptopt(*alt)
  118. def _getnsopt(self, *alt):
  119. return self._getopt(self.p.namespaces[self.namespace], list(alt))
  120. def _getoptopt(self, *alt):
  121. return self._getopt(self.p.options, list(alt))
  122. def _getopt(self, d, alt):
  123. for opt in alt:
  124. try:
  125. return d[opt]
  126. except KeyError:
  127. pass
  128. raise KeyError(alt[0])
  129. def __repr__(self):
  130. return '<{name}: {0.name}>'.format(self, name=type(self).__name__)
  131. @cached_property
  132. def pidfile(self):
  133. return self.expander(self.getopt('--pidfile', '-p'))
  134. @cached_property
  135. def logfile(self):
  136. return self.expander(self.getopt('--logfile', '-f'))
  137. @property
  138. def pid(self):
  139. if self._pid is not None:
  140. return self._pid
  141. try:
  142. return Pidfile(self.pidfile).read_pid()
  143. except ValueError:
  144. pass
  145. @pid.setter
  146. def pid(self, value):
  147. self._pid = value
  148. @cached_property
  149. def executable(self):
  150. return self.p.options['--executable']
  151. @cached_property
  152. def argv_with_executable(self):
  153. return (self.executable,) + self.argv
  154. def maybe_call(fun, *args, **kwargs):
  155. if fun is not None:
  156. fun(*args, **kwargs)
  157. class MultiParser(object):
  158. Node = Node
  159. def __init__(self, cmd='celery worker',
  160. append='', prefix='', suffix='',
  161. range_prefix='celery'):
  162. self.cmd = cmd
  163. self.append = append
  164. self.prefix = prefix
  165. self.suffix = suffix
  166. self.range_prefix = range_prefix
  167. def parse(self, p):
  168. names = p.values
  169. options = dict(p.options)
  170. ranges = len(names) == 1
  171. prefix = self.prefix
  172. if ranges:
  173. try:
  174. names, prefix = self._get_ranges(names), self.range_prefix
  175. except ValueError:
  176. pass
  177. cmd = options.pop('--cmd', self.cmd)
  178. append = options.pop('--append', self.append)
  179. hostname = options.pop('--hostname', options.pop('-n', gethostname()))
  180. prefix = options.pop('--prefix', prefix) or ''
  181. suffix = options.pop('--suffix', self.suffix) or hostname
  182. suffix = '' if suffix in ('""', "''") else suffix
  183. self._update_ns_opts(p, names)
  184. self._update_ns_ranges(p, ranges)
  185. return (
  186. self._args_for_node(p, name, prefix, suffix, cmd, append, options)
  187. for name in names
  188. )
  189. def _get_ranges(self, names):
  190. noderange = int(names[0])
  191. return [str(n) for n in range(1, noderange + 1)]
  192. def _update_ns_opts(self, p, names):
  193. # Numbers in args always refers to the index in the list of names.
  194. # (e.g. `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
  195. for ns_name, ns_opts in list(items(p.namespaces)):
  196. if ns_name.isdigit():
  197. ns_index = int(ns_name) - 1
  198. if ns_index < 0:
  199. raise KeyError('Indexes start at 1 got: %r' % (ns_name,))
  200. try:
  201. p.namespaces[names[ns_index]].update(ns_opts)
  202. except IndexError:
  203. raise KeyError('No node at index %r' % (ns_name,))
  204. def _update_ns_ranges(self, p, ranges):
  205. for ns_name, ns_opts in list(items(p.namespaces)):
  206. if ',' in ns_name or (ranges and '-' in ns_name):
  207. for subns in self._parse_ns_range(ns_name, ranges):
  208. p.namespaces[subns].update(ns_opts)
  209. p.namespaces.pop(ns_name)
  210. def _parse_ns_range(self, ns, ranges=False):
  211. ret = []
  212. for space in ',' in ns and ns.split(',') or [ns]:
  213. if ranges and '-' in space:
  214. start, stop = space.split('-')
  215. ret.extend(
  216. str(n) for n in range(int(start), int(stop) + 1)
  217. )
  218. else:
  219. ret.append(space)
  220. return ret
  221. def _args_for_node(self, p, name, prefix, suffix, cmd, append, options):
  222. name, nodename, expand = self._get_nodename(
  223. name, prefix, suffix, options)
  224. if nodename in p.namespaces:
  225. ns = nodename
  226. else:
  227. ns = name
  228. argv = (
  229. [expand(cmd)] +
  230. [self.format_opt(opt, expand(value))
  231. for opt, value in items(p.optmerge(ns, options))] +
  232. [p.passthrough]
  233. )
  234. if append:
  235. argv.append(expand(append))
  236. return self.Node(nodename, argv, expand, name, p)
  237. def _get_nodename(self, name, prefix, suffix, options):
  238. hostname = suffix
  239. if '@' in name:
  240. nodename = options['-n'] = host_format(name)
  241. shortname, hostname = nodesplit(nodename)
  242. name = shortname
  243. else:
  244. shortname = '%s%s' % (prefix, name)
  245. nodename = options['-n'] = host_format(
  246. '{0}@{1}'.format(shortname, hostname),
  247. )
  248. expand = partial(
  249. node_format, nodename=nodename, N=shortname, d=hostname,
  250. h=nodename, i='%i', I='%I',
  251. )
  252. return name, nodename, expand
  253. def format_opt(self, opt, value):
  254. if not value:
  255. return opt
  256. if opt.startswith('--'):
  257. return '{0}={1}'.format(opt, value)
  258. return '{0} {1}'.format(opt, value)
  259. class Cluster(UserList):
  260. MultiParser = MultiParser
  261. OptionParser = NamespacedOptionParser
  262. def __init__(self, argv, cmd=None, env=None,
  263. on_stopping_preamble=None,
  264. on_send_signal=None,
  265. on_still_waiting_for=None,
  266. on_still_waiting_progress=None,
  267. on_still_waiting_end=None,
  268. on_node_start=None,
  269. on_node_restart=None,
  270. on_node_shutdown_ok=None,
  271. on_node_status=None,
  272. on_node_signal=None,
  273. on_node_signal_dead=None,
  274. on_node_down=None,
  275. on_child_spawn=None,
  276. on_child_signalled=None,
  277. on_child_failure=None):
  278. self.argv = argv
  279. self.cmd = cmd or celery_exe('worker')
  280. self.env = env
  281. self.p = self.OptionParser(argv)
  282. self.with_detacher_default_options(self.p)
  283. self.on_stopping_preamble = on_stopping_preamble
  284. self.on_send_signal = on_send_signal
  285. self.on_still_waiting_for = on_still_waiting_for
  286. self.on_still_waiting_progress = on_still_waiting_progress
  287. self.on_still_waiting_end = on_still_waiting_end
  288. self.on_node_start = on_node_start
  289. self.on_node_restart = on_node_restart
  290. self.on_node_shutdown_ok = on_node_shutdown_ok
  291. self.on_node_status = on_node_status
  292. self.on_node_signal = on_node_signal
  293. self.on_node_signal_dead = on_node_signal_dead
  294. self.on_node_down = on_node_down
  295. self.on_child_spawn = on_child_spawn
  296. self.on_child_signalled = on_child_signalled
  297. self.on_child_failure = on_child_failure
  298. def start(self):
  299. return [self.start_node(node) for node in self]
  300. def start_node(self, node):
  301. maybe_call(self.on_node_start, node)
  302. retcode = self._start_node(node)
  303. maybe_call(self.on_node_status, node, retcode)
  304. return retcode
  305. def _start_node(self, node):
  306. return node.start(
  307. self.env,
  308. on_spawn=self.on_child_spawn,
  309. on_signalled=self.on_child_signalled,
  310. on_failure=self.on_child_failure,
  311. )
  312. def send_all(self, sig):
  313. for node in self.getpids(on_down=self.on_node_down):
  314. maybe_call(self.on_node_signal, node, signal_name(sig))
  315. node.send(sig, self.on_node_signal_dead)
  316. def kill(self):
  317. return self.send_all(signal.SIGKILL)
  318. def restart(self):
  319. retvals = []
  320. def restart_on_down(node):
  321. maybe_call(self.on_node_restart, node)
  322. retval = self._start_node(node)
  323. maybe_call(self.on_node_status, node, retval)
  324. retvals.append(retval)
  325. self._stop_nodes(retry=2, on_down=restart_on_down)
  326. return retvals
  327. def stop(self, retry=None, callback=None):
  328. return self._stop_nodes(retry=retry, on_down=callback)
  329. def stopwait(self, retry=2, callback=None):
  330. return self._stop_nodes(retry=retry, on_down=callback)
  331. def _stop_nodes(self, retry=None, on_down=None):
  332. on_down = on_down if on_down is not None else self.on_node_down
  333. restargs = self.p.args[len(self.p.values):]
  334. nodes = list(self.getpids(on_down=on_down))
  335. if nodes:
  336. for node in self.shutdown_nodes(
  337. nodes,
  338. sig=self._find_sig_argument(restargs),
  339. retry=retry):
  340. maybe_call(on_down, node)
  341. def _find_sig_argument(self, args, default=signal.SIGTERM):
  342. for arg in reversed(args):
  343. if len(arg) == 2 and arg[0] == '-':
  344. try:
  345. return int(arg[1])
  346. except ValueError:
  347. pass
  348. if arg[0] == '-':
  349. try:
  350. return signals.signum(arg[1:])
  351. except (AttributeError, TypeError):
  352. pass
  353. return default
  354. def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None):
  355. P = set(nodes)
  356. maybe_call(self.on_stopping_preamble, nodes)
  357. to_remove = set()
  358. for node in P:
  359. maybe_call(self.on_send_signal, node, signal_name(sig))
  360. if not node.send(sig, self.on_node_signal_dead):
  361. to_remove.add(node)
  362. yield node
  363. P -= to_remove
  364. if retry:
  365. maybe_call(self.on_still_waiting_for, P)
  366. its = 0
  367. while P:
  368. to_remove = set()
  369. for node in P:
  370. its += 1
  371. maybe_call(self.on_still_waiting_progress, P)
  372. if not node.alive():
  373. maybe_call(self.on_node_shutdown_ok, node)
  374. to_remove.add(node)
  375. yield node
  376. maybe_call(self.on_still_waiting_for, P)
  377. break
  378. P -= to_remove
  379. if P and not its % len(P):
  380. sleep(float(retry))
  381. maybe_call(self.on_still_waiting_end)
  382. def find(self, name):
  383. for node in self:
  384. if node.name == name:
  385. return node
  386. raise KeyError(name)
  387. def with_detacher_default_options(self, p):
  388. self._setdefaultopt(p.options, ['--pidfile', '-p'], '%n.pid')
  389. self._setdefaultopt(p.options, ['--logfile', '-f'], '%n%I.log')
  390. self._setdefaultopt(p.options, ['--executable'], sys.executable)
  391. p.options.setdefault(
  392. '--cmd',
  393. '-m {0}'.format(celery_exe('worker', '--detach')),
  394. )
  395. def _setdefaultopt(self, d, alt, value):
  396. for opt in alt[1:]:
  397. try:
  398. return d[opt]
  399. except KeyError:
  400. pass
  401. return d.setdefault(alt[0], value)
  402. def getpids(self, on_down=None):
  403. for node in self:
  404. if node.pid:
  405. yield node
  406. else:
  407. maybe_call(on_down, node)
  408. def __repr__(self):
  409. return '<{name}({0}): {1}>'.format(
  410. len(self), saferepr([n.name for n in self]),
  411. name=type(self).__name__,
  412. )
  413. @cached_property
  414. def data(self):
  415. return list(self.MultiParser(cmd=self.cmd).parse(self.p))