Przeglądaj źródła

Multi: Cluster/Node now gives programmatic access to Multi functionality

Ask Solem 8 lat temu
rodzic
commit
a2d4c655f7

+ 471 - 0
celery/apps/multi.py

@@ -0,0 +1,471 @@
+from __future__ import absolute_import, unicode_literals
+
+import errno
+import os
+import shlex
+import signal
+import sys
+
+from collections import OrderedDict, defaultdict
+from functools import partial
+from subprocess import Popen
+from time import sleep
+
+from kombu.utils.encoding import from_utf8
+from kombu.utils.objects import cached_property
+
+from celery.five import UserList, items
+from celery.platforms import IS_WINDOWS, Pidfile, signal_name, signals
+from celery.utils.nodenames import (
+    gethostname, host_format, node_format, nodesplit,
+)
+
+__all__ = ['Cluster', 'Node']
+
+CELERY_EXE = 'celery'
+
+
+def celery_exe(*args):
+    return ' '.join((CELERY_EXE,) + args)
+
+
+class NamespacedOptionParser(object):
+
+    def __init__(self, args):
+        self.args = args
+        self.options = OrderedDict()
+        self.values = []
+        self.passthrough = ''
+        self.namespaces = defaultdict(lambda: OrderedDict())
+
+        self.parse()
+
+    def parse(self):
+        rargs = list(self.args)
+        pos = 0
+        while pos < len(rargs):
+            arg = rargs[pos]
+            if arg == '--':
+                self.passthrough = ' '.join(rargs[pos:])
+                break
+            elif arg[0] == '-':
+                if arg[1] == '-':
+                    self.process_long_opt(arg[2:])
+                else:
+                    value = None
+                    if len(rargs) > pos + 1 and rargs[pos + 1][0] != '-':
+                        value = rargs[pos + 1]
+                        pos += 1
+                    self.process_short_opt(arg[1:], value)
+            else:
+                self.values.append(arg)
+            pos += 1
+
+    def process_long_opt(self, arg, value=None):
+        if '=' in arg:
+            arg, value = arg.split('=', 1)
+        self.add_option(arg, value, short=False)
+
+    def process_short_opt(self, arg, value=None):
+        self.add_option(arg, value, short=True)
+
+    def optmerge(self, ns, defaults=None):
+        if defaults is None:
+            defaults = self.options
+        return OrderedDict(defaults, **self.namespaces[ns])
+
+    def add_option(self, name, value, short=False, ns=None):
+        prefix = short and '-' or '--'
+        dest = self.options
+        if ':' in name:
+            name, ns = name.split(':')
+            dest = self.namespaces[ns]
+        dest[prefix + name] = value
+
+
+class Node(object):
+
+    def __init__(self, name, argv, expander, namespace, p):
+        self.p = p
+        self.name = name
+        self.argv = tuple(argv)
+        self.expander = expander
+        self.namespace = namespace
+
+    def alive(self):
+        return self.send(0)
+
+    def send(self, sig, on_error=None):
+        try:
+            os.kill(self.pid, sig)
+        except OSError as exc:
+            if exc.errno != errno.ESRCH:
+                raise
+            maybe_call(on_error, self)
+            return False
+        return True
+
+    def start(self, env=None, **kwargs):
+        return self._waitexec(
+            self.argv, path=self.executable, env=env, **kwargs)
+
+    def _waitexec(self, argv, path=sys.executable, env=None,
+                  on_spawn=None, on_signalled=None, on_failure=None):
+        argstr = self.prepare_argv(argv, path)
+        maybe_call(on_spawn, self, argstr=' '.join(argstr), env=env)
+        pipe = Popen(argstr, env=env)
+        return self.handle_process_exit(
+            pipe.wait(),
+            on_signalled=on_signalled,
+            on_failure=on_failure,
+        )
+
+    def handle_process_exit(self, retcode, on_signalled=None, on_failure=None):
+        if retcode < 0:
+            maybe_call(on_signalled, self, -retcode)
+            return -retcode
+        elif retcode > 0:
+            maybe_call(on_failure, self, retcode)
+        return retcode
+
+    def prepare_argv(self, argv, path):
+        args = ' '.join([path] + list(argv))
+        return shlex.split(from_utf8(args), posix=not IS_WINDOWS)
+
+    def getopt(self, *alt):
+        try:
+            return self._getnsopt(*alt)
+        except KeyError:
+            return self._getoptopt(*alt)
+
+    def _getnsopt(self, *alt):
+        return self._getopt(self.p.namespaces[self.namespace], list(alt))
+
+    def _getoptopt(self, *alt):
+        return self._getopt(self.p.options, list(alt))
+
+    def _getopt(self, d, alt):
+        for opt in alt:
+            try:
+                return d[opt]
+            except KeyError:
+                pass
+        raise KeyError(alt[0])
+
+    @cached_property
+    def pidfile(self):
+        return self.expander(self.getopt('--pidfile', '-p'))
+
+    @cached_property
+    def logfile(self):
+        return self.expander(self.getopt('--logfile', '-f'))
+
+    @cached_property
+    def pid(self):
+        try:
+            return Pidfile(self.pidfile).read_pid()
+        except ValueError:
+            pass
+
+    @cached_property
+    def executable(self):
+        return self.p.options['--executable']
+
+    @cached_property
+    def argv_with_executable(self):
+        return (self.executable,) + self.argv
+
+
+def maybe_call(fun, *args, **kwargs):
+    if fun is not None:
+        fun(*args, **kwargs)
+
+
+class MultiParser(object):
+    Node = Node
+
+    def __init__(self, cmd='celery worker',
+                 append='', prefix='', suffix='',
+                 range_prefix='celery'):
+        self.cmd = cmd
+        self.append = append
+        self.prefix = prefix
+        self.suffix = suffix
+        self.range_prefix = range_prefix
+
+    def parse(self, p):
+        names = p.values
+        options = dict(p.options)
+        ranges = len(names) == 1
+        prefix = self.prefix
+        if ranges:
+            try:
+                names, prefix = self._get_ranges(names), self.range_prefix
+            except ValueError:
+                pass
+        cmd = options.pop('--cmd', self.cmd)
+        append = options.pop('--append', self.append)
+        hostname = options.pop('--hostname', options.pop('-n', gethostname()))
+        prefix = options.pop('--prefix', prefix) or ''
+        suffix = options.pop('--suffix', self.suffix) or hostname
+        suffix = '' if suffix in ('""', "''") else suffix
+
+        self._update_ns_opts(p, names)
+        self._update_ns_ranges(p, ranges)
+        return (
+            self._args_for_node(p, name, prefix, suffix, cmd, append, options)
+            for name in names
+        )
+
+    def _get_ranges(self, names):
+        noderange = int(names[0])
+        return [str(n) for n in range(1, noderange + 1)]
+
+    def _update_ns_opts(self, p, names):
+        # Numbers in args always refers to the index in the list of names.
+        # (e.g. `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
+        for ns_name, ns_opts in list(items(p.namespaces)):
+            if ns_name.isdigit():
+                ns_index = int(ns_name) - 1
+                if ns_index < 0:
+                    raise KeyError('Indexes start at 1 got: %r' % (ns_name,))
+                try:
+                    p.namespaces[names[ns_index]].update(ns_opts)
+                except IndexError:
+                    raise KeyError('No node at index %r' % (ns_name,))
+
+    def _update_ns_ranges(self, p, ranges):
+        for ns_name, ns_opts in list(items(p.namespaces)):
+            if ',' in ns_name or (ranges and '-' in ns_name):
+                for subns in self._parse_ns_range(ns_name, ranges):
+                    p.namespaces[subns].update(ns_opts)
+                p.namespaces.pop(ns_name)
+
+    def _parse_ns_range(self, ns, ranges=False):
+        ret = []
+        for space in ',' in ns and ns.split(',') or [ns]:
+            if ranges and '-' in space:
+                start, stop = space.split('-')
+                ret.extend(
+                    str(n) for n in range(int(start), int(stop) + 1)
+                )
+            else:
+                ret.append(space)
+        return ret
+
+    def _args_for_node(self, p, name, prefix, suffix, cmd, append, options):
+        name, nodename, expand = self._get_nodename(
+            name, prefix, suffix, options)
+
+        if nodename in p.namespaces:
+            ns = nodename
+        else:
+            ns = name
+
+        argv = (
+            [expand(cmd)] +
+            [self.format_opt(opt, expand(value))
+                for opt, value in items(p.optmerge(ns, options))] +
+            [p.passthrough]
+        )
+        if append:
+            argv.append(expand(append))
+        return self.Node(nodename, argv, expand, name, p)
+
+    def _get_nodename(self, name, prefix, suffix, options):
+        hostname = suffix
+        if '@' in name:
+            nodename = options['-n'] = host_format(name)
+            shortname, hostname = nodesplit(nodename)
+            name = shortname
+        else:
+            shortname = '%s%s' % (prefix, name)
+            nodename = options['-n'] = host_format(
+                '{0}@{1}'.format(shortname, hostname),
+            )
+        expand = partial(
+            node_format, nodename=nodename, N=shortname, d=hostname,
+            h=nodename, i='%i', I='%I',
+        )
+        return name, nodename, expand
+
+    def format_opt(self, opt, value):
+        if not value:
+            return opt
+        if opt.startswith('--'):
+            return '{0}={1}'.format(opt, value)
+        return '{0} {1}'.format(opt, value)
+
+
+class Cluster(UserList):
+    MultiParser = MultiParser
+    OptionParser = NamespacedOptionParser
+
+    def __init__(self, argv, cmd=None, env=None,
+                 on_stopping_preamble=None,
+                 on_send_signal=None,
+                 on_still_waiting_for=None,
+                 on_still_waiting_progress=None,
+                 on_still_waiting_end=None,
+                 on_node_start=None,
+                 on_node_restart=None,
+                 on_node_shutdown_ok=None,
+                 on_node_status=None,
+                 on_node_signal=None,
+                 on_node_signal_dead=None,
+                 on_node_down=None,
+                 on_child_spawn=None,
+                 on_child_signalled=None,
+                 on_child_failure=None):
+        self.argv = argv
+        self.cmd = cmd or celery_exe('worker')
+        self.env = env
+        self.p = self.OptionParser(argv)
+        self.with_detacher_default_options(self.p)
+
+        self.on_stopping_preamble = on_stopping_preamble
+        self.on_send_signal = on_send_signal
+        self.on_still_waiting_for = on_still_waiting_for
+        self.on_still_waiting_progress = on_still_waiting_progress
+        self.on_still_waiting_end = on_still_waiting_end
+        self.on_node_start = on_node_start
+        self.on_node_restart = on_node_restart
+        self.on_node_shutdown_ok = on_node_shutdown_ok
+        self.on_node_status = on_node_status
+        self.on_node_signal = on_node_signal
+        self.on_node_signal_dead = on_node_signal_dead
+        self.on_node_down = on_node_down
+        self.on_child_spawn = on_child_spawn
+        self.on_child_signalled = on_child_signalled
+        self.on_child_failure = on_child_failure
+
+    def start(self):
+        return [self.start_node(node) for node in self]
+
+    def start_node(self, node):
+        maybe_call(self.on_node_start, node)
+        retcode = self._start_node(node)
+        maybe_call(self.on_node_status, node, retcode)
+        return retcode
+
+    def _start_node(self, node):
+        return node.start(
+            self.env,
+            on_spawn=self.on_child_spawn,
+            on_signalled=self.on_child_signalled,
+            on_failure=self.on_child_failure,
+        )
+
+    def send_all(self, sig):
+        for node in self.getpids(on_down=self.on_node_down):
+            maybe_call(self.on_node_signal, node, signal_name(sig))
+            node.send(sig, self.on_node_signal_dead)
+
+    def kill(self):
+        return self.send_all(signal.SIGKILL)
+
+    def restart(self):
+        retvals = []
+
+        def restart_on_down(node):
+            maybe_call(self.on_node_restart, node)
+            retval = self._start_node(node)
+            maybe_call(self.on_node_status, node, retval)
+            retvals.append(retval)
+
+        self._stop_nodes(retry=2, on_down=restart_on_down)
+        return retvals
+
+    def stop(self, retry=None, callback=None):
+        return self._stop_nodes(retry=retry, on_down=callback)
+
+    def stopwait(self, retry=2, callback=None):
+        return self._stop_nodes(retry=retry, on_down=callback)
+
+    def _stop_nodes(self, retry=None, on_down=None):
+        on_down = on_down if on_down is not None else self.on_node_down
+        restargs = self.p.args[len(self.p.values):]
+        nodes = list(self.getpids(on_down=on_down))
+        if nodes:
+            for node in self.shutdown_nodes(
+                    nodes,
+                    sig=self._find_sig_argument(restargs),
+                    retry=retry):
+                maybe_call(on_down, node)
+
+    def _find_sig_argument(self, args, default=signal.SIGTERM):
+        for arg in reversed(args):
+            if len(arg) == 2 and arg[0] == '-':
+                try:
+                    return int(arg[1])
+                except ValueError:
+                    pass
+            if arg[0] == '-':
+                try:
+                    return signals.signum(arg[1:])
+                except (AttributeError, TypeError):
+                    pass
+        return default
+
+    def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None):
+        P = set(nodes)
+        maybe_call(self.on_stopping_preamble, nodes)
+        to_remove = set()
+        for node in P:
+            maybe_call(self.on_send_signal, node, signal_name(sig))
+            if not node.send(sig, self.on_node_signal_dead):
+                to_remove.add(node)
+                yield node
+        P -= to_remove
+        if retry:
+            maybe_call(self.on_still_waiting_for, P)
+            its = 0
+            while P:
+                to_remove = set()
+                for node in P:
+                    its += 1
+                    maybe_call(self.on_still_waiting_progress, P)
+                    if not node.alive():
+                        maybe_call(self.on_node_shutdown_ok, node)
+                        to_remove.add(node)
+                        yield node
+                        maybe_call(self.on_still_waiting_for, P)
+                        break
+                P -= to_remove
+                if P and not its % len(P):
+                    sleep(float(retry))
+            maybe_call(self.on_still_waiting_end)
+
+    def find(self, name):
+        for node in self:
+            if node.name == name:
+                return node
+        raise KeyError(name)
+
+    def with_detacher_default_options(self, p):
+        self._setdefaultopt(p.options, ['--pidfile', '-p'], '%n.pid')
+        self._setdefaultopt(p.options, ['--logfile', '-f'], '%n%I.log')
+        self._setdefaultopt(p.options, ['--executable'], sys.executable)
+        p.options.setdefault(
+            '--cmd',
+            '-m {0}'.format(celery_exe('worker', '--detach')),
+        )
+
+    def _setdefaultopt(self, d, alt, value):
+        for opt in alt[1:]:
+            try:
+                return d[opt]
+            except KeyError:
+                pass
+        return d.setdefault(alt[0], value)
+
+    def getpids(self, on_down=None):
+        for node in self:
+            if node.pid:
+                yield node
+            else:
+                maybe_call(on_down, node)
+
+    @cached_property
+    def data(self):
+        return list(self.MultiParser(cmd=self.cmd).parse(self.p))

+ 197 - 458
celery/bin/multi.py

@@ -95,35 +95,21 @@ Examples
 """
 from __future__ import absolute_import, print_function, unicode_literals
 
-import errno
 import os
-import shlex
-import signal
 import sys
 
-from collections import OrderedDict, defaultdict, namedtuple
-from functools import partial
-from subprocess import Popen
-from time import sleep
+from functools import wraps
 
-from kombu.utils.encoding import from_utf8
 from kombu.utils.objects import cached_property
 
 from celery import VERSION_BANNER
-from celery.five import items
-from celery.platforms import Pidfile, IS_WINDOWS
+from celery.apps.multi import Cluster
+from celery.platforms import EX_FAILURE, EX_OK
 from celery.utils import term
-from celery.utils.nodenames import (
-    gethostname, host_format, node_format, nodesplit,
-)
 from celery.utils.text import pluralize
 
 __all__ = ['MultiTool']
 
-SIGNAMES = {sig for sig in dir(signal)
-            if sig.startswith('SIG') and '_' not in sig}
-SIGMAP = {getattr(signal, name): name for name in SIGNAMES}
-
 USAGE = """\
 usage: {prog_name} start <node1 node2 nodeN|range> [worker options]
        {prog_name} stop <n1 n2 nN|range> [-SIG (default: -TERM)]
@@ -143,34 +129,94 @@ additional options (must appear after command name):
     * --verbose:    Show more output.
     * --no-color:   Don't display colors.
 """
-CELERY_EXE = 'celery'
-
-multi_args_t = namedtuple(
-    'multi_args_t', ('name', 'argv', 'expander', 'namespace'),
-)
 
 
 def main():
     sys.exit(MultiTool().execute_from_commandline(sys.argv))
 
 
-def celery_exe(*args):
-    return ' '.join((CELERY_EXE,) + args)
+def splash(fun):
 
+    @wraps(fun)
+    def _inner(self, *args, **kwargs):
+        self.splash()
+        return fun(self, *args, **kwargs)
+    return _inner
 
-class MultiTool(object):
-    retcode = 0  # Final exit code.
 
-    def __init__(self, env=None, fh=None, quiet=False, verbose=False,
-                 no_color=False, nosplash=False, stdout=None, stderr=None):
-        """fh is an old alias to stdout."""
-        self.stdout = self.fh = stdout or fh or sys.stdout
+class TermLogger(object):
+
+    splash_text = 'celery multi v{version}'
+    splash_context = {'version': VERSION_BANNER}
+
+    #: Final exit code.
+    retcode = 0
+
+    def setup_terminal(self, stdout, stderr,
+                       nosplash=False, quiet=False, verbose=False,
+                       no_color=False, **kwargs):
+        self.stdout = stdout or sys.stdout
         self.stderr = stderr or sys.stderr
-        self.env = env
         self.nosplash = nosplash
         self.quiet = quiet
         self.verbose = verbose
         self.no_color = no_color
+
+    def ok(self, m, newline=True, file=None):
+        self.say(m, newline=newline, file=file)
+        return EX_OK
+
+    def say(self, m, newline=True, file=None):
+        print(m, file=file or self.stdout, end='\n' if newline else '')
+
+    def carp(self, m, newline=True, file=None):
+        return self.say(m, newline, file or self.stderr)
+
+    def error(self, msg=None):
+        if msg:
+            self.carp(msg)
+        self.usage()
+        return EX_FAILURE
+
+    def info(self, msg, newline=True):
+        if self.verbose:
+            self.note(msg, newline=newline)
+
+    def note(self, msg, newline=True):
+        if not self.quiet:
+            self.say(str(msg), newline=newline)
+
+    @splash
+    def usage(self):
+        self.say(USAGE.format(prog_name=self.prog_name))
+
+    def splash(self):
+        if not self.nosplash:
+            self.note(self.colored.cyan(
+                self.splash_text.format(**self.splash_context)))
+
+    @cached_property
+    def colored(self):
+        return term.colored(enabled=not self.no_color)
+
+
+class MultiTool(TermLogger):
+
+    reserved_options = [
+        ('--nosplash', 'nosplash'),
+        ('--quiet', 'quiet'),
+        ('-q', 'quiet'),
+        ('--verbose', 'verbose'),
+        ('--no-color', 'no_color'),
+    ]
+
+    def __init__(self, env=None, cmd=None,
+                 fh=None, stdout=None, stderr=None, **kwargs):
+        """fh is an old alias to stdout."""
+        self.env = env
+        self.cmd = cmd
+        self.setup_terminal(stdout or fh, stderr, **kwargs)
+        self.fh = self.stdout
         self.prog_name = 'celery multi'
         self.commands = {
             'start': self.start,
@@ -186,263 +232,156 @@ class MultiTool(object):
             'help': self.help,
         }
 
-    def execute_from_commandline(self, argv, cmd='celery worker'):
-        argv = list(argv)   # don't modify callers argv.
-
+    def execute_from_commandline(self, argv, cmd=None):
         # Reserve the --nosplash|--quiet|-q/--verbose options.
-        if '--nosplash' in argv:
-            self.nosplash = argv.pop(argv.index('--nosplash'))
-        if '--quiet' in argv:
-            self.quiet = argv.pop(argv.index('--quiet'))
-        if '-q' in argv:
-            self.quiet = argv.pop(argv.index('-q'))
-        if '--verbose' in argv:
-            self.verbose = argv.pop(argv.index('--verbose'))
-        if '--no-color' in argv:
-            self.no_color = argv.pop(argv.index('--no-color'))
-
+        argv = self._handle_reserved_options(argv)
+        self.cmd = cmd if cmd is not None else self.cmd
         self.prog_name = os.path.basename(argv.pop(0))
-        if not argv or argv[0][0] == '-':
+
+        if not self.validate_arguments(argv):
             return self.error()
 
+        return self.call_command(argv[0], argv[1:])
+
+    def validate_arguments(self, argv):
+        return argv and argv[0][0] != '-'
+
+    def call_command(self, command, argv):
         try:
-            self.commands[argv[0]](argv[1:], cmd)
+            return self.commands[command](*argv) or EX_OK
         except KeyError:
-            self.error('Invalid command: {0}'.format(argv[0]))
+            return self.error('Invalid command: {0}'.format(command))
 
-        return self.retcode
+    def _handle_reserved_options(self, argv):
+        argv = list(argv)  # don't modify callers argv.
+        for arg, attr in self.reserved_options:
+            if arg in argv:
+                setattr(self, attr, bool(argv.pop(argv.index(arg))))
+        return argv
 
-    def say(self, m, newline=True, file=None):
-        print(m, file=file or self.stdout, end='\n' if newline else '')
+    @splash
+    def start(self, *argv):
+        self.note('> Starting nodes...')
+        return int(any(self.Cluster(argv).start()))
 
-    def carp(self, m, newline=True, file=None):
-        return self.say(m, newline, file or self.stderr)
+    @splash
+    def stop(self, *argv, **kwargs):
+        return self.Cluster(argv).stop(**kwargs)
 
-    def names(self, argv, cmd):
-        p = NamespacedOptionParser(argv)
-        self.say('\n'.join(
-            n.name for n in multi_args(p, cmd)),
-        )
+    @splash
+    def stopwait(self, *argv, **kwargs):
+        return self.Cluster(argv).stopwait(**kwargs)
+    stop_verify = stopwait  # compat
 
-    def get(self, argv, cmd):
-        wanted = argv[0]
-        p = NamespacedOptionParser(argv[1:])
-        for node in multi_args(p, cmd):
-            if node.name == wanted:
-                self.say(' '.join(node.argv))
-                return
-
-    def show(self, argv, cmd):
-        p = NamespacedOptionParser(argv)
-        self.with_detacher_default_options(p)
-        self.say('\n'.join(
-            ' '.join([sys.executable] + n.argv) for n in multi_args(p, cmd)),
-        )
+    @splash
+    def restart(self, *argv, **kwargs):
+        return int(any(self.Cluster(argv).restart(**kwargs)))
 
-    def start(self, argv, cmd):
-        self.splash()
-        p = NamespacedOptionParser(argv)
-        self.with_detacher_default_options(p)
-        retcodes = []
-        self.note('> Starting nodes...')
-        for node in multi_args(p, cmd):
-            self.note('\t> {0}: '.format(node.name), newline=False)
-            retcode = self.waitexec(node.argv, path=p.options['--executable'])
-            self.note(retcode and self.FAILED or self.OK)
-            retcodes.append(retcode)
-        self.retcode = int(any(retcodes))
-
-    def with_detacher_default_options(self, p):
-        _setdefaultopt(p.options, ['--pidfile', '-p'], '%n.pid')
-        _setdefaultopt(p.options, ['--logfile', '-f'], '%n%I.log')
-        p.options.setdefault(
-            '--cmd',
-            '-m {0}'.format(celery_exe('worker', '--detach')),
-        )
-        _setdefaultopt(p.options, ['--executable'], sys.executable)
+    def names(self, *argv):
+        self.say('\n'.join(n.name for n in self.Cluster(argv)))
 
-    def signal_node(self, nodename, pid, sig):
-        try:
-            os.kill(pid, sig)
-        except OSError as exc:
-            if exc.errno != errno.ESRCH:
-                raise
-            self.note('Could not signal {0} ({1}): No such process'.format(
-                nodename, pid))
-            return False
-        return True
-
-    def node_alive(self, pid):
+    def get(self, wanted, *argv):
         try:
-            os.kill(pid, 0)
-        except OSError as exc:
-            if exc.errno == errno.ESRCH:
-                return False
-            raise
-        return True
-
-    def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None,
-                       callback=None):
-        if not nodes:
-            return
-        P = set(nodes)
-
-        def on_down(node):
-            P.discard(node)
-            if callback:
-                callback(*node)
+            node = self.Cluster(argv).find(wanted)
+        except KeyError:
+            return EX_FAILURE
+        else:
+            return self.ok(' '.join(node.argv))
 
-        self.note(self.colored.blue('> Stopping nodes...'))
-        for node in list(P):
-            if node in P:
-                nodename, _, pid = node
-                self.note('\t> {0}: {1} -> {2}'.format(
-                    nodename, SIGMAP[sig][3:], pid))
-                if not self.signal_node(nodename, pid, sig):
-                    on_down(node)
-
-        def note_waiting():
-            left = len(P)
-            if left:
-                pids = ', '.join(str(pid) for _, _, pid in P)
-                self.note(self.colored.blue(
-                    '> Waiting for {0} {1} -> {2}...'.format(
-                        left, pluralize(left, 'node'), pids)), newline=False)
-
-        if retry:
-            note_waiting()
-            its = 0
-            while P:
-                for node in P:
-                    its += 1
-                    self.note('.', newline=False)
-                    nodename, _, pid = node
-                    if not self.node_alive(pid):
-                        self.note('\n\t> {0}: {1}'.format(nodename, self.OK))
-                        on_down(node)
-                        note_waiting()
-                        break
-                if P and not its % len(P):
-                    sleep(float(retry))
-            self.note('')
-
-    def getpids(self, p, cmd, callback=None):
-        _setdefaultopt(p.options, ['--pidfile', '-p'], '%n.pid')
-
-        nodes = []
-        for node in multi_args(p, cmd):
-            try:
-                pidfile_template = _getopt(
-                    p.namespaces[node.namespace], ['--pidfile', '-p'],
-                )
-            except KeyError:
-                pidfile_template = _getopt(p.options, ['--pidfile', '-p'])
-            pid = None
-            pidfile = node.expander(pidfile_template)
-            try:
-                pid = Pidfile(pidfile).read_pid()
-            except ValueError:
-                pass
-            if pid:
-                nodes.append((node.name, tuple(node.argv), pid))
-            else:
-                self.note('> {0.name}: {1}'.format(node, self.DOWN))
-                if callback:
-                    callback(node.name, node.argv, pid)
-
-        return nodes
-
-    def kill(self, argv, cmd):
-        self.splash()
-        p = NamespacedOptionParser(argv)
-        for nodename, _, pid in self.getpids(p, cmd):
-            self.note('Killing node {0} ({1})'.format(nodename, pid))
-            self.signal_node(nodename, pid, signal.SIGKILL)
+    def show(self, *argv):
+        return self.ok('\n'.join(
+            ' '.join(node.argv_with_executable)
+            for node in self.Cluster(argv)
+        ))
 
-    def stop(self, argv, cmd, retry=None, callback=None):
-        self.splash()
-        p = NamespacedOptionParser(argv)
-        return self._stop_nodes(p, cmd, retry=retry, callback=callback)
+    @splash
+    def kill(self, *argv):
+        return self.Cluster(argv).kill()
 
-    def _stop_nodes(self, p, cmd, retry=None, callback=None):
-        restargs = p.args[len(p.values):]
-        self.shutdown_nodes(self.getpids(p, cmd, callback=callback),
-                            sig=findsig(restargs),
-                            retry=retry,
-                            callback=callback)
+    def expand(self, template, *argv):
+        return self.ok('\n'.join(
+            node.expander(template)
+            for node in self.Cluster(argv)
+        ))
 
-    def restart(self, argv, cmd):
-        self.splash()
-        p = NamespacedOptionParser(argv)
-        self.with_detacher_default_options(p)
-        retvals = []
+    def help(self, *argv):
+        self.say(__doc__)
+
+    def Cluster(self, argv, cmd=None):
+        return Cluster(
+            argv, cmd if cmd is not None else self.cmd,
+            env=self.env,
+            on_stopping_preamble=self.on_stopping_preamble,
+            on_send_signal=self.on_send_signal,
+            on_still_waiting_for=self.on_still_waiting_for,
+            on_still_waiting_progress=self.on_still_waiting_progress,
+            on_still_waiting_end=self.on_still_waiting_end,
+            on_node_start=self.on_node_start,
+            on_node_restart=self.on_node_restart,
+            on_node_shutdown_ok=self.on_node_shutdown_ok,
+            on_node_status=self.on_node_status,
+            on_node_signal_dead=self.on_node_signal_dead,
+            on_node_signal=self.on_node_signal,
+            on_node_down=self.on_node_down,
+            on_child_spawn=self.on_child_spawn,
+            on_child_signalled=self.on_child_signalled,
+            on_child_failure=self.on_child_failure,
+        )
+
+    def on_stopping_preamble(self, nodes):
+        self.note(self.colored.blue('> Stopping nodes...'))
+
+    def on_send_signal(self, node, sig):
+        self.note('\t> {0.name}: {1} -> {0.pid}'.format(node, sig))
 
-        def on_node_shutdown(nodename, argv, pid):
+    def on_still_waiting_for(self, nodes):
+        num_left = len(nodes)
+        if num_left:
             self.note(self.colored.blue(
-                '> Restarting node {0}: '.format(nodename)), newline=False)
-            retval = self.waitexec(argv, path=p.options['--executable'])
-            self.note(retval and self.FAILED or self.OK)
-            retvals.append(retval)
+                '> Waiting for {0} {1} -> {2}...'.format(
+                    num_left, pluralize(num_left, 'node'),
+                    ', '.join(str(node.pid) for node in nodes)),
+            ), newline=False)
 
-        self._stop_nodes(p, cmd, retry=2, callback=on_node_shutdown)
-        self.retval = int(any(retvals))
+    def on_still_waiting_progress(self, nodes):
+        self.note('.', newline=False)
 
-    def stopwait(self, argv, cmd):
-        self.splash()
-        p = NamespacedOptionParser(argv)
-        self.with_detacher_default_options(p)
-        return self._stop_nodes(p, cmd, retry=2)
-    stop_verify = stopwait  # compat
+    def on_still_waiting_end(self):
+        self.note('')
 
-    def expand(self, argv, cmd=None):
-        template = argv[0]
-        p = NamespacedOptionParser(argv[1:])
-        for node in multi_args(p, cmd):
-            self.say(node.expander(template))
+    def on_node_signal_dead(self, node):
+        self.note(
+            'Could not signal {0.name} ({0.pid}): No such process'.format(
+                node))
 
-    def help(self, argv, cmd=None):
-        self.say(__doc__)
+    def on_node_start(self, node):
+        self.note('\t> {0.name}: '.format(node), newline=False)
 
-    def usage(self):
-        self.splash()
-        self.say(USAGE.format(prog_name=self.prog_name))
+    def on_node_restart(self, node):
+        self.note(self.colored.blue(
+            '> Restarting node {0.name}: '.format(node)), newline=False)
 
-    def splash(self):
-        if not self.nosplash:
-            c = self.colored
-            self.note(c.cyan('celery multi v{0}'.format(VERSION_BANNER)))
-
-    def waitexec(self, argv, path=sys.executable):
-        args = ' '.join([path] + list(argv))
-        argstr = shlex.split(from_utf8(args), posix=not IS_WINDOWS)
-        pipe = Popen(argstr, env=self.env)
-        self.info('  {0}'.format(' '.join(argstr)))
-        retcode = pipe.wait()
-        if retcode < 0:
-            self.note('* Child was terminated by signal {0}'.format(-retcode))
-            return -retcode
-        elif retcode > 0:
-            self.note('* Child terminated with errorcode {0}'.format(retcode))
-        return retcode
+    def on_node_down(self, node):
+        self.note('> {0.name}: {1.DOWN}'.format(node, self))
 
-    def error(self, msg=None):
-        if msg:
-            self.carp(msg)
-        self.usage()
-        self.retcode = 1
-        return 1
+    def on_node_shutdown_ok(self, node):
+        self.note('\n\t> {0.name}: {1.OK}'.format(node, self))
 
-    def info(self, msg, newline=True):
-        if self.verbose:
-            self.note(msg, newline=newline)
+    def on_node_status(self, node, retval):
+        self.note(retval and self.FAILED or self.OK)
 
-    def note(self, msg, newline=True):
-        if not self.quiet:
-            self.say(str(msg), newline=newline)
+    def on_node_signal(self, node, sig):
+        self.note('Sending {sig} to node {0.name} ({0.pid})'.format(
+            node, sig=sig))
 
-    @cached_property
-    def colored(self):
-        return term.colored(enabled=not self.no_color)
+    def on_child_spawn(self, node, argstr, env):
+        self.info('  {0}'.format(argstr))
+
+    def on_child_signalled(self, node, signum):
+        self.note('* Child was terminated by signal {0}'.format(signum))
+
+    def on_child_failure(self, node, retcode):
+        self.note('* Child terminated with exit code {0}'.format(retcode))
 
     @cached_property
     def OK(self):
@@ -456,205 +395,5 @@ class MultiTool(object):
     def DOWN(self):
         return str(self.colored.magenta('DOWN'))
 
-
-def _args_for_node(p, name, prefix, suffix, cmd, append, options):
-    name, nodename, expand = _get_nodename(
-        name, prefix, suffix, options)
-
-    if nodename in p.namespaces:
-        ns = nodename
-    else:
-        ns = name
-
-    argv = ([expand(cmd)] +
-            [format_opt(opt, expand(value))
-                for opt, value in items(p.optmerge(ns, options))] +
-            [p.passthrough])
-    if append:
-        argv.append(expand(append))
-    return multi_args_t(nodename, argv, expand, name)
-
-
-def multi_args(p, cmd='celery worker', append='', prefix='', suffix=''):
-    names = p.values
-    options = dict(p.options)
-    ranges = len(names) == 1
-    if ranges:
-        try:
-            names, prefix = _get_ranges(names)
-        except ValueError:
-            pass
-    cmd = options.pop('--cmd', cmd)
-    append = options.pop('--append', append)
-    hostname = options.pop('--hostname',
-                           options.pop('-n', gethostname()))
-    prefix = options.pop('--prefix', prefix) or ''
-    suffix = options.pop('--suffix', suffix) or hostname
-    suffix = '' if suffix in ('""', "''") else suffix
-
-    _update_ns_opts(p, names)
-    _update_ns_ranges(p, ranges)
-    return (_args_for_node(p, name, prefix, suffix, cmd, append, options)
-            for name in names)
-
-
-def _get_ranges(names):
-    noderange = int(names[0])
-    names = [str(n) for n in range(1, noderange + 1)]
-    prefix = 'celery'
-    return names, prefix
-
-
-def _update_ns_opts(p, names):
-    # Numbers in args always refers to the index in the list of names.
-    # (e.g. `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
-    for ns_name, ns_opts in list(items(p.namespaces)):
-        if ns_name.isdigit():
-            ns_index = int(ns_name) - 1
-            if ns_index < 0:
-                raise KeyError('Indexes start at 1 got: %r' % (ns_name,))
-            try:
-                p.namespaces[names[ns_index]].update(ns_opts)
-            except IndexError:
-                raise KeyError('No node at index %r' % (ns_name,))
-
-
-def _update_ns_ranges(p, ranges):
-    for ns_name, ns_opts in list(items(p.namespaces)):
-        if ',' in ns_name or (ranges and '-' in ns_name):
-            for subns in parse_ns_range(ns_name, ranges):
-                p.namespaces[subns].update(ns_opts)
-            p.namespaces.pop(ns_name)
-
-
-def _get_nodename(name, prefix, suffix, options):
-        hostname = suffix
-        if '@' in name:
-            nodename = options['-n'] = host_format(name)
-            shortname, hostname = nodesplit(nodename)
-            name = shortname
-        else:
-            shortname = '%s%s' % (prefix, name)
-            nodename = options['-n'] = host_format(
-                '{0}@{1}'.format(shortname, hostname),
-            )
-        expand = partial(
-            node_format, nodename=nodename, N=shortname, d=hostname,
-            h=nodename, i='%i', I='%I',
-        )
-        return name, nodename, expand
-
-
-class NamespacedOptionParser(object):
-
-    def __init__(self, args):
-        self.args = args
-        self.options = OrderedDict()
-        self.values = []
-        self.passthrough = ''
-        self.namespaces = defaultdict(lambda: OrderedDict())
-
-        self.parse()
-
-    def parse(self):
-        rargs = list(self.args)
-        pos = 0
-        while pos < len(rargs):
-            arg = rargs[pos]
-            if arg == '--':
-                self.passthrough = ' '.join(rargs[pos:])
-                break
-            elif arg[0] == '-':
-                if arg[1] == '-':
-                    self.process_long_opt(arg[2:])
-                else:
-                    value = None
-                    if len(rargs) > pos + 1 and rargs[pos + 1][0] != '-':
-                        value = rargs[pos + 1]
-                        pos += 1
-                    self.process_short_opt(arg[1:], value)
-            else:
-                self.values.append(arg)
-            pos += 1
-
-    def process_long_opt(self, arg, value=None):
-        if '=' in arg:
-            arg, value = arg.split('=', 1)
-        self.add_option(arg, value, short=False)
-
-    def process_short_opt(self, arg, value=None):
-        self.add_option(arg, value, short=True)
-
-    def optmerge(self, ns, defaults=None):
-        if defaults is None:
-            defaults = self.options
-        return OrderedDict(defaults, **self.namespaces[ns])
-
-    def add_option(self, name, value, short=False, ns=None):
-        prefix = short and '-' or '--'
-        dest = self.options
-        if ':' in name:
-            name, ns = name.split(':')
-            dest = self.namespaces[ns]
-        dest[prefix + name] = value
-
-
-def quote(v):
-    return "\\'".join("'" + p + "'" for p in v.split("'"))
-
-
-def format_opt(opt, value):
-    if not value:
-        return opt
-    if opt.startswith('--'):
-        return '{0}={1}'.format(opt, value)
-    return '{0} {1}'.format(opt, value)
-
-
-def parse_ns_range(ns, ranges=False):
-    ret = []
-    for space in ',' in ns and ns.split(',') or [ns]:
-        if ranges and '-' in space:
-            start, stop = space.split('-')
-            ret.extend(
-                str(n) for n in range(int(start), int(stop) + 1)
-            )
-        else:
-            ret.append(space)
-    return ret
-
-
-def findsig(args, default=signal.SIGTERM):
-    for arg in reversed(args):
-        if len(arg) == 2 and arg[0] == '-':
-            try:
-                return int(arg[1])
-            except ValueError:
-                pass
-        if arg[0] == '-':
-            maybe_sig = 'SIG' + arg[1:]
-            if maybe_sig in SIGNAMES:
-                return getattr(signal, maybe_sig)
-    return default
-
-
-def _getopt(d, alt):
-    for opt in alt:
-        try:
-            return d[opt]
-        except KeyError:
-            pass
-    raise KeyError(alt[0])
-
-
-def _setdefaultopt(d, alt, value):
-    for opt in alt[1:]:
-        try:
-            return d[opt]
-        except KeyError:
-            pass
-    return d.setdefault(alt[0], value)
-
-
 if __name__ == '__main__':              # pragma: no cover
     main()

+ 16 - 7
celery/platforms.py

@@ -37,13 +37,12 @@ mputil = try_import('multiprocessing.util')
 
 __all__ = [
     'EX_OK', 'EX_FAILURE', 'EX_UNAVAILABLE', 'EX_USAGE', 'SYSTEM',
-    'IS_macOS', 'IS_WINDOWS', 'pyimplementation', 'LockFailed',
-    'get_fdmax', 'Pidfile', 'create_pidlock',
-    'close_open_fds', 'DaemonContext', 'detached', 'parse_uid',
-    'parse_gid', 'setgroups', 'initgroups', 'setgid', 'setuid',
-    'maybe_drop_privileges', 'signals', 'set_process_title',
-    'set_mp_process_title', 'get_errno_name', 'ignore_errno',
-    'fd_by_path', 'isatty',
+    'IS_macOS', 'IS_WINDOWS', 'SIGMAP', 'pyimplementation', 'LockFailed',
+    'get_fdmax', 'Pidfile', 'create_pidlock', 'close_open_fds',
+    'DaemonContext', 'detached', 'parse_uid', 'parse_gid', 'setgroups',
+    'initgroups', 'setgid', 'setuid', 'maybe_drop_privileges', 'signals',
+    'signal_name', 'set_process_title', 'set_mp_process_title',
+    'get_errno_name', 'ignore_errno', 'fd_by_path', 'isatty',
 ]
 
 # exitcodes
@@ -88,6 +87,12 @@ Please specify a different user using the -u option.
 User information: uid={uid} euid={euid} gid={gid} egid={egid}
 """
 
+SIGNAMES = {
+    sig for sig in dir(_signal)
+    if sig.startswith('SIG') and '_' not in sig
+}
+SIGMAP = {getattr(_signal, name): name for name in SIGNAMES}
+
 
 def isatty(fh):
     try:
@@ -668,6 +673,10 @@ reset_signal = signals.reset                  # compat
 ignore_signal = signals.ignore                # compat
 
 
+def signal_name(signum):
+    return SIGMAP[signum][3:]
+
+
 def strargv(argv):
     arg_start = 2 if 'manage' in argv[0] else 1
     if len(argv) > arg_start:

+ 0 - 0
celery/tests/apps/__init__.py


+ 407 - 0
celery/tests/apps/test_multi.py

@@ -0,0 +1,407 @@
+from __future__ import absolute_import, unicode_literals
+
+import errno
+import signal
+
+from celery.apps.multi import (
+    Cluster, MultiParser, NamespacedOptionParser, Node,
+)
+
+from celery.tests.case import AppCase, Mock, call, patch
+
+
+class test_functions(AppCase):
+
+    def test_findsig(self):
+        m = Cluster([])
+        self.assertEqual(m._find_sig_argument(['a', 'b', 'c', '-1']), 1)
+        self.assertEqual(m._find_sig_argument(['--foo=1', '-9']), 9)
+        self.assertEqual(m._find_sig_argument(['-INT']), signal.SIGINT)
+        self.assertEqual(m._find_sig_argument([]), signal.SIGTERM)
+        self.assertEqual(m._find_sig_argument(['-s']), signal.SIGTERM)
+        self.assertEqual(m._find_sig_argument(['-log']), signal.SIGTERM)
+
+    def test_parse_ns_range(self):
+        m = MultiParser()
+        self.assertEqual(m._parse_ns_range('1-3', True), ['1', '2', '3'])
+        self.assertEqual(m._parse_ns_range('1-3', False), ['1-3'])
+        self.assertEqual(m._parse_ns_range(
+            '1-3,10,11,20', True),
+            ['1', '2', '3', '10', '11', '20'],
+        )
+
+    def test_format_opt(self):
+        m = MultiParser()
+        self.assertEqual(m.format_opt('--foo', None), '--foo')
+        self.assertEqual(m.format_opt('-c', 1), '-c 1')
+        self.assertEqual(m.format_opt('--log', 'foo'), '--log=foo')
+
+
+class test_NamespacedOptionParser(AppCase):
+
+    def test_parse(self):
+        x = NamespacedOptionParser(['-c:1,3', '4'])
+        self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
+        x = NamespacedOptionParser(['-c:jerry,elaine', '5',
+                                    '--loglevel:kramer=DEBUG',
+                                    '--flag',
+                                    '--logfile=foo', '-Q', 'bar', 'a', 'b',
+                                    '--', '.disable_rate_limits=1'])
+        self.assertEqual(x.options, {'--logfile': 'foo',
+                                     '-Q': 'bar',
+                                     '--flag': None})
+        self.assertEqual(x.values, ['a', 'b'])
+        self.assertEqual(x.namespaces.get('jerry,elaine'), {'-c': '5'})
+        self.assertEqual(x.namespaces.get('kramer'), {'--loglevel': 'DEBUG'})
+        self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
+
+
+def multi_args(p, *args, **kwargs):
+    return MultiParser(*args, **kwargs).parse(p)
+
+
+class test_multi_args(AppCase):
+
+    @patch('celery.apps.multi.gethostname')
+    def test_parse(self, gethostname):
+        gethostname.return_value = 'example.com'
+        p = NamespacedOptionParser([
+            '-c:jerry,elaine', '5',
+            '--loglevel:kramer=DEBUG',
+            '--flag',
+            '--logfile=foo', '-Q', 'bar', 'jerry',
+            'elaine', 'kramer',
+            '--', '.disable_rate_limits=1',
+        ])
+        it = multi_args(p, cmd='COMMAND', append='*AP*',
+                        prefix='*P*', suffix='*S*')
+        nodes = list(it)
+
+        def assert_line_in(name, args):
+            self.assertIn(name, {n.name for n in nodes})
+            argv = None
+            for node in nodes:
+                if node.name == name:
+                    argv = node.argv
+            self.assertTrue(argv)
+            for arg in args:
+                self.assertIn(arg, argv)
+
+        assert_line_in(
+            '*P*jerry@*S*',
+            ['COMMAND', '-n *P*jerry@*S*', '-Q bar',
+             '-c 5', '--flag', '--logfile=foo',
+             '-- .disable_rate_limits=1', '*AP*'],
+        )
+        assert_line_in(
+            '*P*elaine@*S*',
+            ['COMMAND', '-n *P*elaine@*S*', '-Q bar',
+             '-c 5', '--flag', '--logfile=foo',
+             '-- .disable_rate_limits=1', '*AP*'],
+        )
+        assert_line_in(
+            '*P*kramer@*S*',
+            ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*',
+             '-Q bar', '--flag', '--logfile=foo',
+             '-- .disable_rate_limits=1', '*AP*'],
+        )
+        expand = nodes[0].expander
+        self.assertEqual(expand('%h'), '*P*jerry@*S*')
+        self.assertEqual(expand('%n'), '*P*jerry')
+        nodes2 = list(multi_args(p, cmd='COMMAND', append='',
+                      prefix='*P*', suffix='*S*'))
+        self.assertEqual(nodes2[0].argv[-1], '-- .disable_rate_limits=1')
+
+        p2 = NamespacedOptionParser(['10', '-c:1', '5'])
+        nodes3 = list(multi_args(p2, cmd='COMMAND'))
+        self.assertEqual(len(nodes3), 10)
+        self.assertEqual(nodes3[0].name, 'celery1@example.com')
+        self.assertTupleEqual(
+            nodes3[0].argv,
+            ('COMMAND', '-n celery1@example.com', '-c 5', ''),
+        )
+        for i, worker in enumerate(nodes3[1:]):
+            self.assertEqual(worker.name, 'celery%s@example.com' % (i + 2))
+            self.assertTupleEqual(
+                worker.argv,
+                ('COMMAND', '-n celery%s@example.com' % (i + 2), ''),
+            )
+
+        nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
+        self.assertEqual(len(nodes4), 10)
+        self.assertEqual(nodes4[0].name, 'celery1@')
+        self.assertTupleEqual(
+            nodes4[0].argv,
+            ('COMMAND', '-n celery1@', '-c 5', ''),
+        )
+
+        p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
+        nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
+        self.assertEqual(nodes5[0].name, 'foo@')
+        self.assertTupleEqual(
+            nodes5[0].argv,
+            ('COMMAND', '-n foo@', '-c 5', ''),
+        )
+
+        p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
+        nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
+        self.assertEqual(nodes6[0].name, 'foo@')
+        self.assertTupleEqual(
+            nodes6[0].argv,
+            ('COMMAND', '-n foo@', '-Q test', ''),
+        )
+
+        p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
+        nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
+        self.assertEqual(nodes7[0].name, 'foo@bar')
+        self.assertTupleEqual(
+            nodes7[0].argv,
+            ('COMMAND', '-n foo@bar', '-Q test', ''),
+        )
+
+        p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
+        with self.assertRaises(KeyError):
+            list(multi_args(p6))
+
+    def test_optmerge(self):
+        p = NamespacedOptionParser(['foo', 'test'])
+        p.options = {'x': 'y'}
+        r = p.optmerge('foo')
+        self.assertEqual(r['x'], 'y')
+
+
+class test_Node(AppCase):
+
+    def setup(self):
+        self.p = Mock(name='p')
+        self.p.options = {
+            '--executable': 'python',
+            '--logfile': 'foo.log',
+        }
+        self.p.namespaces = {}
+        self.expander = Mock(name='expander')
+        self.node = Node(
+            'foo@bar.com', ['-A', 'proj'], self.expander, 'foo', self.p,
+        )
+        self.node.pid = 303
+
+    @patch('os.kill')
+    def test_send(self, kill):
+        self.assertTrue(self.node.send(9))
+        kill.assert_called_with(self.node.pid, 9)
+
+    @patch('os.kill')
+    def test_send__ESRCH(self, kill):
+        kill.side_effect = OSError()
+        kill.side_effect.errno = errno.ESRCH
+        self.assertFalse(self.node.send(9))
+        kill.assert_called_with(self.node.pid, 9)
+
+    @patch('os.kill')
+    def test_send__error(self, kill):
+        kill.side_effect = OSError()
+        kill.side_effect.errno = errno.ENOENT
+        with self.assertRaises(OSError):
+            self.node.send(9)
+        kill.assert_called_with(self.node.pid, 9)
+
+    def test_alive(self):
+        self.node.send = Mock(name='send')
+        self.assertIs(self.node.alive(), self.node.send.return_value)
+        self.node.send.assert_called_with(0)
+
+    def test_start(self):
+        self.node._waitexec = Mock(name='_waitexec')
+        self.node.start(env={'foo': 'bar'}, kw=2)
+        self.node._waitexec.assert_called_with(
+            self.node.argv, path=self.node.executable,
+            env={'foo': 'bar'}, kw=2,
+        )
+
+    @patch('celery.apps.multi.Popen')
+    def test_waitexec(self, Popen, argv=['A', 'B']):
+        on_spawn = Mock(name='on_spawn')
+        on_signalled = Mock(name='on_signalled')
+        on_failure = Mock(name='on_failure')
+        env = Mock(name='env')
+        self.node.handle_process_exit = Mock(name='handle_process_exit')
+
+        self.node._waitexec(
+            argv,
+            path='python',
+            env=env,
+            on_spawn=on_spawn,
+            on_signalled=on_signalled,
+            on_failure=on_failure,
+        )
+
+        Popen.assert_called_with(
+            self.node.prepare_argv(argv, 'python'), env=env)
+        self.node.handle_process_exit.assert_called_with(
+            Popen().wait(),
+            on_signalled=on_signalled,
+            on_failure=on_failure,
+        )
+
+    def test_handle_process_exit(self):
+        self.assertEqual(
+            self.node.handle_process_exit(0),
+            0,
+        )
+
+    def test_handle_process_exit__failure(self):
+        on_failure = Mock(name='on_failure')
+        self.assertEqual(
+            self.node.handle_process_exit(9, on_failure=on_failure),
+            9,
+        )
+        on_failure.assert_called_with(self.node, 9)
+
+    def test_handle_process_exit__signalled(self):
+        on_signalled = Mock(name='on_signalled')
+        self.assertEqual(
+            self.node.handle_process_exit(-9, on_signalled=on_signalled),
+            9,
+        )
+        on_signalled.assert_called_with(self.node, 9)
+
+    def test_logfile(self):
+        self.assertEqual(self.node.logfile, self.expander.return_value)
+        self.expander.assert_called_with('foo.log')
+
+
+class test_Cluster(AppCase):
+
+    def setup(self):
+        self.Popen = self.patch('celery.apps.multi.Popen')
+        self.kill = self.patch('os.kill')
+        self.gethostname = self.patch('celery.apps.multi.gethostname')
+        self.gethostname.return_value = 'example.com'
+        self.Pidfile = self.patch('celery.apps.multi.Pidfile')
+        self.cluster = Cluster(
+            ['foo', 'bar', 'baz'],
+            on_stopping_preamble=Mock(name='on_stopping_preamble'),
+            on_send_signal=Mock(name='on_send_signal'),
+            on_still_waiting_for=Mock(name='on_still_waiting_for'),
+            on_still_waiting_progress=Mock(name='on_still_waiting_progress'),
+            on_still_waiting_end=Mock(name='on_still_waiting_end'),
+            on_node_start=Mock(name='on_node_start'),
+            on_node_restart=Mock(name='on_node_restart'),
+            on_node_shutdown_ok=Mock(name='on_node_shutdown_ok'),
+            on_node_status=Mock(name='on_node_status'),
+            on_node_signal=Mock(name='on_node_signal'),
+            on_node_signal_dead=Mock(name='on_node_signal_dead'),
+            on_node_down=Mock(name='on_node_down'),
+            on_child_spawn=Mock(name='on_child_spawn'),
+            on_child_signalled=Mock(name='on_child_signalled'),
+            on_child_failure=Mock(name='on_child_failure'),
+        )
+
+    def test_len(self):
+        self.assertEqual(len(self.cluster), 3)
+
+    def test_getitem(self):
+        self.assertEqual(self.cluster[0].name, 'foo@example.com')
+
+    def test_start(self):
+        self.cluster.start_node = Mock(name='start_node')
+        self.cluster.start()
+        self.cluster.start_node.assert_has_calls(
+            call(node) for node in self.cluster
+        )
+
+    def test_start_node(self):
+        self.cluster._start_node = Mock(name='_start_node')
+        node = self.cluster[0]
+        self.assertIs(
+            self.cluster.start_node(node),
+            self.cluster._start_node.return_value,
+        )
+        self.cluster.on_node_start.assert_called_with(node)
+        self.cluster._start_node.assert_called_with(node)
+        self.cluster.on_node_status.assert_called_with(
+            node, self.cluster._start_node(),
+        )
+
+    def test__start_node(self):
+        node = self.cluster[0]
+        node.start = Mock(name='node.start')
+        self.assertIs(
+            self.cluster._start_node(node),
+            node.start.return_value,
+        )
+        node.start.assert_called_with(
+            self.cluster.env,
+            on_spawn=self.cluster.on_child_spawn,
+            on_signalled=self.cluster.on_child_signalled,
+            on_failure=self.cluster.on_child_failure,
+        )
+
+    def test_send_all(self):
+        nodes = [Mock(name='n1'), Mock(name='n2')]
+        self.cluster.getpids = Mock(name='getpids')
+        self.cluster.getpids.return_value = nodes
+        self.cluster.send_all(15)
+        self.cluster.on_node_signal.assert_has_calls(
+            call(node, 'TERM') for node in nodes
+        )
+        for node in nodes:
+            node.send.assert_called_with(15, self.cluster.on_node_signal_dead)
+
+    def test_kill(self):
+        self.cluster.send_all = Mock(name='.send_all')
+        self.cluster.kill()
+        self.cluster.send_all.assert_called_with(signal.SIGKILL)
+
+    def test_getpids(self):
+        self.gethostname.return_value = 'e.com'
+        self.prepare_pidfile_for_getpids(self.Pidfile)
+        callback = Mock()
+
+        p = Cluster(['foo', 'bar', 'baz'])
+        nodes = p.getpids(on_down=callback)
+        node_0, node_1 = nodes
+        self.assertEqual(node_0.name, 'foo@e.com')
+        self.assertEqual(
+            sorted(node_0.argv),
+            sorted([
+                '',
+                '--executable={0}'.format(node_0.executable),
+                '--logfile=foo%I.log',
+                '--pidfile=foo.pid',
+                '-m celery worker --detach',
+                '-n foo@e.com',
+            ]),
+        )
+        self.assertEqual(node_0.pid, 10)
+
+        self.assertEqual(node_1.name, 'bar@e.com')
+        self.assertEqual(
+            sorted(node_1.argv),
+            sorted([
+                '',
+                '--executable={0}'.format(node_1.executable),
+                '--logfile=bar%I.log',
+                '--pidfile=bar.pid',
+                '-m celery worker --detach',
+                '-n bar@e.com',
+            ]),
+        )
+        self.assertEqual(node_1.pid, 11)
+
+        # without callback, should work
+        nodes = p.getpids('celery worker')
+
+    def prepare_pidfile_for_getpids(self, Pidfile):
+        class pids(object):
+
+            def __init__(self, path):
+                self.path = path
+
+            def read_pid(self):
+                try:
+                    return {'foo.pid': 10,
+                            'bar.pid': 11}[self.path]
+                except KeyError:
+                    raise ValueError()
+        self.Pidfile.side_effect = pids

+ 235 - 351
celery/tests/bin/test_multi.py

@@ -1,172 +1,260 @@
 from __future__ import absolute_import, unicode_literals
 
-import errno
-import signal
 import sys
 
 from celery.bin.multi import (
     main,
     MultiTool,
-    findsig,
-    parse_ns_range,
-    format_opt,
-    quote,
-    NamespacedOptionParser,
-    multi_args,
     __doc__ as doc,
 )
 from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, patch, skip
+from celery.tests.case import AppCase, Mock, patch
 
 
-class test_functions(AppCase):
-
-    def test_findsig(self):
-        self.assertEqual(findsig(['a', 'b', 'c', '-1']), 1)
-        self.assertEqual(findsig(['--foo=1', '-9']), 9)
-        self.assertEqual(findsig(['-INT']), signal.SIGINT)
-        self.assertEqual(findsig([]), signal.SIGTERM)
-        self.assertEqual(findsig(['-s']), signal.SIGTERM)
-        self.assertEqual(findsig(['-log']), signal.SIGTERM)
+class test_MultiTool(AppCase):
 
-    def test_parse_ns_range(self):
-        self.assertEqual(parse_ns_range('1-3', True), ['1', '2', '3'])
-        self.assertEqual(parse_ns_range('1-3', False), ['1-3'])
-        self.assertEqual(parse_ns_range(
-            '1-3,10,11,20', True),
-            ['1', '2', '3', '10', '11', '20'],
-        )
+    def setup(self):
+        self.fh = WhateverIO()
+        self.env = {}
+        self.t = MultiTool(env=self.env, fh=self.fh)
+        self.t.Cluster = Mock(name='Cluster')
+        self.t.carp = Mock(name='.carp')
+        self.t.usage = Mock(name='.usage')
+        self.t.splash = Mock(name='.splash')
+        self.t.say = Mock(name='.say')
+        self.t.ok = Mock(name='.ok')
+        self.cluster = self.t.Cluster.return_value
 
-    def test_format_opt(self):
-        self.assertEqual(format_opt('--foo', None), '--foo')
-        self.assertEqual(format_opt('-c', 1), '-c 1')
-        self.assertEqual(format_opt('--log', 'foo'), '--log=foo')
-
-    def test_quote(self):
-        self.assertEqual(quote("the 'quick"), "'the '\\''quick'")
-
-
-class test_NamespacedOptionParser(AppCase):
-
-    def test_parse(self):
-        x = NamespacedOptionParser(['-c:1,3', '4'])
-        self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
-        x = NamespacedOptionParser(['-c:jerry,elaine', '5',
-                                    '--loglevel:kramer=DEBUG',
-                                    '--flag',
-                                    '--logfile=foo', '-Q', 'bar', 'a', 'b',
-                                    '--', '.disable_rate_limits=1'])
-        self.assertEqual(x.options, {'--logfile': 'foo',
-                                     '-Q': 'bar',
-                                     '--flag': None})
-        self.assertEqual(x.values, ['a', 'b'])
-        self.assertEqual(x.namespaces.get('jerry,elaine'), {'-c': '5'})
-        self.assertEqual(x.namespaces.get('kramer'), {'--loglevel': 'DEBUG'})
-        self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
-
-
-class test_multi_args(AppCase):
-
-    @patch('celery.bin.multi.gethostname')
-    def test_parse(self, gethostname):
-        gethostname.return_value = 'example.com'
-        p = NamespacedOptionParser([
-            '-c:jerry,elaine', '5',
-            '--loglevel:kramer=DEBUG',
-            '--flag',
-            '--logfile=foo', '-Q', 'bar', 'jerry',
-            'elaine', 'kramer',
-            '--', '.disable_rate_limits=1',
-        ])
-        it = multi_args(p, cmd='COMMAND', append='*AP*',
-                        prefix='*P*', suffix='*S*')
-        names = list(it)
-
-        def assert_line_in(name, args):
-            self.assertIn(name, {tup[0] for tup in names})
-            argv = None
-            for item in names:
-                if item[0] == name:
-                    argv = item[1]
-            self.assertTrue(argv)
-            for arg in args:
-                self.assertIn(arg, argv)
-
-        assert_line_in(
-            '*P*jerry@*S*',
-            ['COMMAND', '-n *P*jerry@*S*', '-Q bar',
-             '-c 5', '--flag', '--logfile=foo',
-             '-- .disable_rate_limits=1', '*AP*'],
-        )
-        assert_line_in(
-            '*P*elaine@*S*',
-            ['COMMAND', '-n *P*elaine@*S*', '-Q bar',
-             '-c 5', '--flag', '--logfile=foo',
-             '-- .disable_rate_limits=1', '*AP*'],
+    def test_execute_from_commandline(self):
+        self.t.call_command = Mock(name='call_command')
+        self.t.execute_from_commandline(
+            'multi start --verbose 10 --foo'.split(),
+            cmd='X',
         )
-        assert_line_in(
-            '*P*kramer@*S*',
-            ['COMMAND', '--loglevel=DEBUG', '-n *P*kramer@*S*',
-             '-Q bar', '--flag', '--logfile=foo',
-             '-- .disable_rate_limits=1', '*AP*'],
+        self.assertEqual(self.t.cmd, 'X')
+        self.assertEqual(self.t.prog_name, 'multi')
+        self.t.call_command.assert_called_with('start', ['10', '--foo'])
+
+    def test_execute_from_commandline__arguments(self):
+        self.assertTrue(self.t.execute_from_commandline('multi'.split()))
+        self.assertTrue(self.t.execute_from_commandline('multi -bar'.split()))
+
+    def test_call_command(self):
+        cmd = self.t.commands['foo'] = Mock(name='foo')
+        self.t.retcode = 303
+        self.assertIs(
+            self.t.call_command('foo', ['1', '2', '--foo=3']),
+            cmd.return_value,
         )
-        expand = names[0][2]
-        self.assertEqual(expand('%h'), '*P*jerry@*S*')
-        self.assertEqual(expand('%n'), '*P*jerry')
-        names2 = list(multi_args(p, cmd='COMMAND', append='',
-                      prefix='*P*', suffix='*S*'))
-        self.assertEqual(names2[0][1][-1], '-- .disable_rate_limits=1')
-
-        p2 = NamespacedOptionParser(['10', '-c:1', '5'])
-        names3 = list(multi_args(p2, cmd='COMMAND'))
-        self.assertEqual(len(names3), 10)
+        cmd.assert_called_with('1', '2', '--foo=3')
+
+    def test_call_command__error(self):
         self.assertEqual(
-            names3[0][0:2],
-            ('celery1@example.com',
-             ['COMMAND', '-n celery1@example.com', '-c 5', '']),
+            self.t.call_command('asdqwewqe', ['1', '2']),
+            1,
         )
-        for i, worker in enumerate(names3[1:]):
-            self.assertEqual(
-                worker[0:2],
-                ('celery%s@example.com' % (i + 2),
-                 ['COMMAND', '-n celery%s@example.com' % (i + 2), '']),
-            )
-
-        names4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
-        self.assertEqual(len(names4), 10)
-        self.assertEqual(
-            names4[0][0:2],
-            ('celery1@',
-             ['COMMAND', '-n celery1@', '-c 5', '']),
+        self.t.carp.assert_called()
+
+    def test_handle_reserved_options(self):
+        self.assertListEqual(
+            self.t._handle_reserved_options(
+                ['a', '-q', 'b', '--no-color', 'c']),
+            ['a', 'b', 'c'],
         )
 
-        p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
-        names5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
-        self.assertEqual(
-            names5[0][0:2],
-            ('foo@',
-             ['COMMAND', '-n foo@', '-c 5', '']),
+    def test_start(self):
+        self.cluster.start.return_value = [0, 0, 1, 0]
+        self.assertTrue(self.t.start('10', '-A', 'proj'))
+        self.t.splash.assert_called_with()
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.start.assert_called_with()
+
+    def test_start__exitcodes(self):
+        self.cluster.start.return_value = [0, 0, 0]
+        self.assertFalse(self.t.start('foo', 'bar', 'baz'))
+        self.cluster.start.assert_called_with()
+
+        self.cluster.start.return_value = [0, 1, 0]
+        self.assertTrue(self.t.start('foo', 'bar', 'baz'))
+
+    def test_stop(self):
+        self.t.stop('10', '-A', 'proj', retry=3)
+        self.t.splash.assert_called_with()
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.stop.assert_called_with(retry=3)
+
+    def test_stopwait(self):
+        self.t.stopwait('10', '-A', 'proj', retry=3)
+        self.t.splash.assert_called_with()
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.stopwait.assert_called_with(retry=3)
+
+    def test_restart(self):
+        self.cluster.restart.return_value = [0, 0, 1, 0]
+        self.t.restart('10', '-A', 'proj')
+        self.t.splash.assert_called_with()
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.restart.assert_called_with()
+
+    def test_names(self):
+        self.t.Cluster.return_value = [Mock(), Mock()]
+        self.t.Cluster.return_value[0].name = 'x'
+        self.t.Cluster.return_value[1].name = 'y'
+        self.t.names('10', '-A', 'proj')
+        self.t.say.assert_called()
+
+    def test_get(self):
+        node = self.cluster.find.return_value = Mock(name='node')
+        node.argv = ['A', 'B', 'C']
+        self.assertIs(
+            self.t.get('wanted', '10', '-A', 'proj'),
+            self.t.ok.return_value,
         )
+        self.cluster.find.assert_called_with('wanted')
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.t.ok.assert_called_with(' '.join(node.argv))
 
-        p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
-        names6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
-        self.assertEqual(
-            names6[0][0:2],
-            ('foo@',
-             ['COMMAND', '-n foo@', '-Q test', '']),
+    def test_get__KeyError(self):
+        self.cluster.find.side_effect = KeyError()
+        self.assertTrue(self.t.get('wanted', '10', '-A', 'proj'))
+
+    def test_show(self):
+        nodes = self.t.Cluster.return_value = [
+            Mock(name='n1'),
+            Mock(name='n2'),
+        ]
+        nodes[0].argv_with_executable = ['python', 'foo', 'bar']
+        nodes[1].argv_with_executable = ['python', 'xuzzy', 'baz']
+
+        self.assertIs(
+            self.t.show('10', '-A', 'proj'),
+            self.t.ok.return_value,
         )
+        self.t.ok.assert_called_with(
+            '\n'.join(' '.join(node.argv_with_executable) for node in nodes))
+
+    def test_kill(self):
+        self.t.kill('10', '-A', 'proj')
+        self.t.splash.assert_called_with()
+        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.kill.assert_called_with()
 
-        p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
-        names7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
+    def test_expand(self):
+        node1 = Mock(name='n1')
+        node2 = Mock(name='n2')
+        node1.expander.return_value = 'A'
+        node2.expander.return_value = 'B'
+        nodes = self.t.Cluster.return_value = [node1, node2]
+        self.assertIs(self.t.expand('%p', '10'), self.t.ok.return_value)
+        self.t.Cluster.assert_called_with(('10',))
+        for node in nodes:
+            node.expander.assert_called_with('%p')
+        self.t.ok.assert_called_with('A\nB')
+
+    def test_note(self):
+        self.t.quiet = True
+        self.t.note('foo')
+        self.t.say.assert_not_called()
+        self.t.quiet = False
+        self.t.note('foo')
+        self.t.say.assert_called_with('foo', newline=True)
+
+    def test_splash(self):
+        x = MultiTool()
+        x.note = Mock()
+        x.nosplash = True
+        x.splash()
+        x.note.assert_not_called()
+        x.nosplash = False
+        x.splash()
+        x.note.assert_called()
+
+    def test_Cluster(self):
+        m = MultiTool()
+        c = m.Cluster(['A', 'B', 'C'])
+        self.assertListEqual(c.argv, ['A', 'B', 'C'])
+        self.assertIs(c.env, m.env)
+        self.assertEqual(c.cmd, 'celery worker')
+        self.assertEqual(c.on_stopping_preamble, m.on_stopping_preamble)
+        self.assertEqual(c.on_send_signal, m.on_send_signal)
+        self.assertEqual(c.on_still_waiting_for, m.on_still_waiting_for)
         self.assertEqual(
-            names7[0][0:2],
-            ('foo@bar',
-             ['COMMAND', '-n foo@bar', '-Q test', '']),
+            c.on_still_waiting_progress,
+            m.on_still_waiting_progress,
         )
+        self.assertEqual(c.on_still_waiting_end, m.on_still_waiting_end)
+        self.assertEqual(c.on_node_start, m.on_node_start)
+        self.assertEqual(c.on_node_restart, m.on_node_restart)
+        self.assertEqual(c.on_node_shutdown_ok, m.on_node_shutdown_ok)
+        self.assertEqual(c.on_node_status, m.on_node_status)
+        self.assertEqual(c.on_node_signal_dead, m.on_node_signal_dead)
+        self.assertEqual(c.on_node_signal, m.on_node_signal)
+        self.assertEqual(c.on_node_down, m.on_node_down)
+        self.assertEqual(c.on_child_spawn, m.on_child_spawn)
+        self.assertEqual(c.on_child_signalled, m.on_child_signalled)
+        self.assertEqual(c.on_child_failure, m.on_child_failure)
 
+    def test_on_stopping_preamble(self):
+        self.t.on_stopping_preamble([])
 
-class test_MultiTool(AppCase):
+    def test_on_send_signal(self):
+        self.t.on_send_signal(Mock(), Mock())
+
+    def test_on_still_waiting_for(self):
+        self.t.on_still_waiting_for([Mock(), Mock()])
+
+    def test_on_still_waiting_for__empty(self):
+        self.t.on_still_waiting_for([])
+
+    def test_on_still_waiting_progress(self):
+        self.t.on_still_waiting_progress([])
+
+    def test_on_still_waiting_end(self):
+        self.t.on_still_waiting_end()
+
+    def test_on_node_signal_dead(self):
+        self.t.on_node_signal_dead(Mock())
+
+    def test_on_node_start(self):
+        self.t.on_node_start(Mock())
+
+    def test_on_node_restart(self):
+        self.t.on_node_restart(Mock())
+
+    def test_on_node_down(self):
+        self.t.on_node_down(Mock())
+
+    def test_on_node_shutdown_ok(self):
+        self.t.on_node_shutdown_ok(Mock())
+
+    def test_on_node_status__FAIL(self):
+        self.t.on_node_status(Mock(), 1)
+        self.t.say.assert_called_with(self.t.FAILED, newline=True)
+
+    def test_on_node_status__OK(self):
+        self.t.on_node_status(Mock(), 0)
+        self.t.say.assert_called_with(self.t.OK, newline=True)
+
+    def test_on_node_signal(self):
+        self.t.on_node_signal(Mock(), Mock())
+
+    def test_on_child_spawn(self):
+        self.t.on_child_spawn(Mock(), Mock(), Mock())
+
+    def test_on_child_signalled(self):
+        self.t.on_child_signalled(Mock(), Mock())
+
+    def test_on_child_failure(self):
+        self.t.on_child_failure(Mock(), Mock())
+
+    def test_constant_strings(self):
+        self.assertTrue(self.t.OK)
+        self.assertTrue(self.t.DOWN)
+        self.assertTrue(self.t.FAILED)
+
+
+class test_MultiTool_functional(AppCase):
 
     def setup(self):
         self.fh = WhateverIO()
@@ -208,26 +296,6 @@ class test_MultiTool(AppCase):
         self.assertEqual(self.t.error(), 1)
         self.t.carp.assert_not_called()
 
-        self.assertEqual(self.t.retcode, 1)
-
-    @patch('celery.bin.multi.Popen')
-    def test_waitexec(self, Popen):
-        self.t.note = Mock()
-        pipe = Popen.return_value = Mock()
-        pipe.wait.return_value = -10
-        self.assertEqual(self.t.waitexec(['-m', 'foo'], 'path'), 10)
-        Popen.assert_called_with(['path', '-m', 'foo'], env=self.t.env)
-        self.t.note.assert_called_with('* Child was terminated by signal 10')
-
-        pipe.wait.return_value = 2
-        self.assertEqual(self.t.waitexec(['-m', 'foo'], 'path'), 2)
-        self.t.note.assert_called_with(
-            '* Child terminated with errorcode 2',
-        )
-
-        pipe.wait.return_value = 0
-        self.assertFalse(self.t.waitexec(['-m', 'foo', 'path']))
-
     def test_nosplash(self):
         self.t.nosplash = True
         self.t.splash()
@@ -247,202 +315,23 @@ class test_MultiTool(AppCase):
         self.assertIn(doc, self.fh.getvalue())
 
     def test_expand(self):
-        self.t.expand(['foo%n', 'ask', 'klask', 'dask'])
+        self.t.expand('foo%n', 'ask', 'klask', 'dask')
         self.assertEqual(
             self.fh.getvalue(), 'fooask\nfooklask\nfoodask\n',
         )
 
-    def test_restart(self):
-        stop = self.t._stop_nodes = Mock()
-        self.t.restart(['jerry', 'george'], 'celery worker')
-        waitexec = self.t.waitexec = Mock()
-        stop.assert_called()
-        callback = stop.call_args[1]['callback']
-        self.assertTrue(callback)
-
-        waitexec.return_value = 0
-        callback('jerry', ['arg'], 13)
-        waitexec.assert_called_with(['arg'], path=sys.executable)
-        self.assertIn('OK', self.fh.getvalue())
-        self.fh.seek(0)
-        self.fh.truncate()
-
-        waitexec.return_value = 1
-        callback('jerry', ['arg'], 13)
-        self.assertIn('FAILED', self.fh.getvalue())
-
-    def test_stop(self):
-        self.t.getpids = Mock()
-        self.t.getpids.return_value = [2, 3, 4]
-        self.t.shutdown_nodes = Mock()
-        self.t.stop(['a', 'b', '-INT'], 'celery worker')
-        self.t.shutdown_nodes.assert_called_with(
-            [2, 3, 4], sig=signal.SIGINT, retry=None, callback=None,
-
-        )
-
-    @skip.unless_symbol('signal.SIGKILL')
-    def test_kill(self):
-        self.t.getpids = Mock()
-        self.t.getpids.return_value = [
-            ('a', None, 10),
-            ('b', None, 11),
-            ('c', None, 12)
-        ]
-        sig = self.t.signal_node = Mock()
-
-        self.t.kill(['a', 'b', 'c'], 'celery worker')
-
-        sigs = sig.call_args_list
-        self.assertEqual(len(sigs), 3)
-        self.assertEqual(sigs[0][0], ('a', 10, signal.SIGKILL))
-        self.assertEqual(sigs[1][0], ('b', 11, signal.SIGKILL))
-        self.assertEqual(sigs[2][0], ('c', 12, signal.SIGKILL))
-
-    def prepare_pidfile_for_getpids(self, Pidfile):
-        class pids(object):
-
-            def __init__(self, path):
-                self.path = path
-
-            def read_pid(self):
-                try:
-                    return {'foo.pid': 10,
-                            'bar.pid': 11}[self.path]
-                except KeyError:
-                    raise ValueError()
-        Pidfile.side_effect = pids
-
-    @patch('celery.bin.multi.Pidfile')
-    @patch('celery.bin.multi.gethostname')
-    def test_getpids(self, gethostname, Pidfile):
-        gethostname.return_value = 'e.com'
-        self.prepare_pidfile_for_getpids(Pidfile)
-        callback = Mock()
-
-        p = NamespacedOptionParser(['foo', 'bar', 'baz'])
-        nodes = self.t.getpids(p, 'celery worker', callback=callback)
-        node_0, node_1 = nodes
-        self.assertEqual(node_0[0], 'foo@e.com')
-        self.assertEqual(
-            sorted(node_0[1]),
-            sorted(('celery worker', '--pidfile=foo.pid',
-                    '-n foo@e.com', '')),
-        )
-        self.assertEqual(node_0[2], 10)
-
-        self.assertEqual(node_1[0], 'bar@e.com')
-        self.assertEqual(
-            sorted(node_1[1]),
-            sorted(('celery worker', '--pidfile=bar.pid',
-                    '-n bar@e.com', '')),
-        )
-        self.assertEqual(node_1[2], 11)
-        callback.assert_called()
-        cargs, _ = callback.call_args
-        self.assertEqual(cargs[0], 'baz@e.com')
-        self.assertItemsEqual(
-            cargs[1],
-            ['celery worker', '--pidfile=baz.pid', '-n baz@e.com', ''],
-        )
-        self.assertIsNone(cargs[2])
-        self.assertIn('DOWN', self.fh.getvalue())
-
-        # without callback, should work
-        nodes = self.t.getpids(p, 'celery worker', callback=None)
-
-    @patch('celery.bin.multi.Pidfile')
-    @patch('celery.bin.multi.gethostname')
-    @patch('celery.bin.multi.sleep')
-    def test_shutdown_nodes(self, slepp, gethostname, Pidfile):
-        gethostname.return_value = 'e.com'
-        self.prepare_pidfile_for_getpids(Pidfile)
-        self.assertIsNone(self.t.shutdown_nodes([]))
-        self.t.signal_node = Mock()
-        node_alive = self.t.node_alive = Mock()
-        self.t.node_alive.return_value = False
-
-        callback = Mock()
-        self.t.stop(['foo', 'bar', 'baz'], 'celery worker', callback=callback)
-        sigs = sorted(self.t.signal_node.call_args_list)
-        self.assertEqual(len(sigs), 2)
-        self.assertIn(
-            ('foo@e.com', 10, signal.SIGTERM),
-            {tup[0] for tup in sigs},
-        )
-        self.assertIn(
-            ('bar@e.com', 11, signal.SIGTERM),
-            {tup[0] for tup in sigs},
-        )
-        self.t.signal_node.return_value = False
-        callback.assert_called()
-        self.t.stop(['foo', 'bar', 'baz'], 'celery worker', callback=None)
-
-        def on_node_alive(pid):
-            if node_alive.call_count > 4:
-                return True
-            return False
-        self.t.signal_node.return_value = True
-        self.t.node_alive.side_effect = on_node_alive
-        self.t.stop(['foo', 'bar', 'baz'], 'celery worker', retry=True)
-
-    @patch('os.kill')
-    def test_node_alive(self, kill):
-        kill.return_value = True
-        self.assertTrue(self.t.node_alive(13))
-        esrch = OSError()
-        esrch.errno = errno.ESRCH
-        kill.side_effect = esrch
-        self.assertFalse(self.t.node_alive(13))
-        kill.assert_called_with(13, 0)
-
-        enoent = OSError()
-        enoent.errno = errno.ENOENT
-        kill.side_effect = enoent
-        with self.assertRaises(OSError):
-            self.t.node_alive(13)
-
-    @patch('os.kill')
-    def test_signal_node(self, kill):
-        kill.return_value = True
-        self.assertTrue(self.t.signal_node('foo', 13, 9))
-        esrch = OSError()
-        esrch.errno = errno.ESRCH
-        kill.side_effect = esrch
-        self.assertFalse(self.t.signal_node('foo', 13, 9))
-        kill.assert_called_with(13, 9)
-        self.assertIn('Could not signal foo', self.fh.getvalue())
-
-        enoent = OSError()
-        enoent.errno = errno.ENOENT
-        kill.side_effect = enoent
-        with self.assertRaises(OSError):
-            self.t.signal_node('foo', 13, 9)
-
-    def test_start(self):
-        self.t.waitexec = Mock()
-        self.t.waitexec.return_value = 0
-        self.assertFalse(self.t.start(['foo', 'bar', 'baz'], 'celery worker'))
-
-        self.t.waitexec.return_value = 1
-        self.assertFalse(self.t.start(['foo', 'bar', 'baz'], 'celery worker'))
-
-    def test_show(self):
-        self.t.show(['foo', 'bar', 'baz'], 'celery worker')
-        self.assertTrue(self.fh.getvalue())
-
-    @patch('celery.bin.multi.gethostname')
+    @patch('celery.apps.multi.gethostname')
     def test_get(self, gethostname):
         gethostname.return_value = 'e.com'
-        self.t.get(['xuzzy@e.com', 'foo', 'bar', 'baz'], 'celery worker')
+        self.t.get('xuzzy@e.com', 'foo', 'bar', 'baz')
         self.assertFalse(self.fh.getvalue())
-        self.t.get(['foo@e.com', 'foo', 'bar', 'baz'], 'celery worker')
+        self.t.get('foo@e.com', 'foo', 'bar', 'baz')
         self.assertTrue(self.fh.getvalue())
 
-    @patch('celery.bin.multi.gethostname')
+    @patch('celery.apps.multi.gethostname')
     def test_names(self, gethostname):
         gethostname.return_value = 'e.com'
-        self.t.names(['foo', 'bar', 'baz'], 'celery worker')
+        self.t.names('foo', 'bar', 'baz')
         self.assertIn('foo@e.com\nbar@e.com\nbaz@e.com', self.fh.getvalue())
 
     def test_execute_from_commandline(self):
@@ -450,7 +339,7 @@ class test_MultiTool(AppCase):
         self.t.error = Mock()
         self.t.execute_from_commandline(['multi', 'start', 'foo', 'bar'])
         self.t.error.assert_not_called()
-        start.assert_called_with(['foo', 'bar'], 'celery worker')
+        start.assert_called_with('foo', 'bar')
 
         self.t.error = Mock()
         self.t.execute_from_commandline(['multi', 'frob', 'foo', 'bar'])
@@ -473,11 +362,6 @@ class test_MultiTool(AppCase):
         self.assertTrue(self.t.verbose)
         self.assertTrue(self.t.no_color)
 
-    def test_stopwait(self):
-        self.t._stop_nodes = Mock()
-        self.t.stopwait(['foo', 'bar', 'baz'], 'celery worker')
-        self.assertEqual(self.t._stop_nodes.call_args[1]['retry'], 2)
-
     @patch('celery.bin.multi.MultiTool')
     def test_main(self, MultiTool):
         m = MultiTool.return_value = Mock()

+ 5 - 0
celery/utils/functional.py

@@ -20,6 +20,7 @@ __all__ = [
     'LRUCache', 'is_list', 'maybe_list', 'memoize', 'mlazy', 'noop',
     'first', 'firstmethod', 'chunks', 'padlist', 'mattrgetter', 'uniq',
     'regen', 'dictfilter', 'lazy', 'maybe_evaluate', 'head_from_fun',
+    'maybe',
 ]
 
 IS_PY3 = sys.version_info[0] == 3
@@ -263,3 +264,7 @@ def fun_takes_argument(name, fun, position=None):
         spec.varkw or spec.varargs or
         (len(spec.args) >= position if position else name in spec.args)
     )
+
+
+def maybe(typ, val):
+    return typ(val) if val is not None else val

+ 3 - 3
celery/worker/request.py

@@ -22,7 +22,7 @@ from celery.exceptions import (
 )
 from celery.five import python_2_unicode_compatible, string
 from celery.platforms import signals as _signals
-from celery.utils.functional import noop
+from celery.utils.functional import maybe, noop
 from celery.utils.log import get_logger
 from celery.utils.nodenames import gethostname
 from celery.utils.timeutils import maybe_iso8601, timezone, maybe_make_aware
@@ -192,7 +192,7 @@ class Request(object):
             correlation_id=task_id,
         )
         # cannot create weakref to None
-        self._apply_result = ref(result) if result is not None else result
+        self._apply_result = maybe(ref, result)
         return result
 
     def execute(self, loglevel=None, logfile=None):
@@ -513,7 +513,7 @@ def create_request_cls(base, task, pool, hostname, eventer,
                 correlation_id=task_id,
             )
             # cannot create weakref to None
-            self._apply_result = ref(result) if result is not None else result
+            self._apply_result = maybe(ref, result)
             return result
 
         def on_success(self, failed__retval__runtime, **kwargs):

+ 11 - 0
docs/reference/celery.apps.multi.rst

@@ -0,0 +1,11 @@
+=======================================
+ ``celery.apps.multi``
+=======================================
+
+.. contents::
+    :local:
+.. currentmodule:: celery.apps.multi
+
+.. automodule:: celery.apps.multi
+    :members:
+    :undoc-members: