Преглед на файлове

Multi: Finalize Cluster and Node API

Ask Solem преди 8 години
родител
ревизия
bd01086ba7
променени са 5 файла, в които са добавени 293 реда и са изтрити 188 реда
  1. 117 123
      celery/apps/multi.py
  2. 78 20
      celery/bin/multi.py
  3. 59 27
      celery/tests/apps/test_multi.py
  4. 39 17
      celery/tests/bin/test_multi.py
  5. 0 1
      celery/utils/imports.py

+ 117 - 123
celery/apps/multi.py

@@ -15,7 +15,7 @@ from kombu.utils.encoding import from_utf8
 from kombu.utils.objects import cached_property
 
 from celery.five import UserList, items
-from celery.platforms import IS_WINDOWS, Pidfile, signal_name, signals
+from celery.platforms import IS_WINDOWS, Pidfile, signal_name
 from celery.utils.nodenames import (
     gethostname, host_format, node_format, nodesplit,
 )
@@ -30,6 +30,48 @@ def celery_exe(*args):
     return ' '.join((CELERY_EXE,) + args)
 
 
+def build_nodename(name, prefix, suffix):
+    hostname = suffix
+    if '@' in name:
+        nodename = host_format(name)
+        shortname, hostname = nodesplit(nodename)
+        name = shortname
+    else:
+        shortname = '%s%s' % (prefix, name)
+        nodename = host_format(
+            '{0}@{1}'.format(shortname, hostname),
+        )
+    return name, nodename, hostname
+
+
+def build_expander(nodename, shortname, hostname):
+    return partial(
+        node_format,
+        nodename=nodename,
+        N=shortname,
+        d=hostname,
+        h=nodename,
+        i='%i',
+        I='%I',
+    )
+
+
+def format_opt(opt, value):
+    if not value:
+        return opt
+    if opt.startswith('--'):
+        return '{0}={1}'.format(opt, value)
+    return '{0} {1}'.format(opt, value)
+
+
+def _kwargs_to_command_line(kwargs):
+    return {
+        ('--{0}'.format(k.replace('_', '-'))
+         if len(k) > 1 else '-{0}'.format(k)): '{0}'.format(v)
+        for k, v in items(kwargs)
+    }
+
+
 class NamespacedOptionParser(object):
 
     def __init__(self, args):
@@ -39,8 +81,6 @@ class NamespacedOptionParser(object):
         self.passthrough = ''
         self.namespaces = defaultdict(lambda: OrderedDict())
 
-        self.parse()
-
     def parse(self):
         rargs = list(self.args)
         pos = 0
@@ -86,14 +126,49 @@ class NamespacedOptionParser(object):
 
 class Node(object):
 
-    def __init__(self, name, argv, expander, namespace, p):
-        self.p = p
+    def __init__(self, name,
+                 cmd=None, append=None, options=None, extra_args=None):
         self.name = name
-        self.argv = tuple(argv)
-        self.expander = expander
-        self.namespace = namespace
+        self.cmd = cmd or '-m {0}'.format(celery_exe('worker', '--detach'))
+        self.append = append
+        self.extra_args = extra_args or ''
+        self.options = self._annotate_with_default_opts(
+            options or OrderedDict())
+        self.expander = self._prepare_expander()
+        self.argv = self._prepare_argv()
         self._pid = None
 
+    def _annotate_with_default_opts(self, options):
+        options['-n'] = self.name
+        self._setdefaultopt(options, ['--pidfile', '-p'], '%n.pid')
+        self._setdefaultopt(options, ['--logfile', '-f'], '%n%I.log')
+        self._setdefaultopt(options, ['--executable'], sys.executable)
+        return options
+
+    def _setdefaultopt(self, d, alt, value):
+        for opt in alt[1:]:
+            try:
+                return d[opt]
+            except KeyError:
+                pass
+        return d.setdefault(alt[0], value)
+
+    def _prepare_expander(self):
+        shortname, hostname = self.name.split('@', 1)
+        return build_expander(
+            self.name, shortname, hostname)
+
+    def _prepare_argv(self):
+        argv = tuple(
+            [self.expander(self.cmd)] +
+            [format_opt(opt, self.expander(value))
+                for opt, value in items(self.options)] +
+            [self.extra_args]
+        )
+        if self.append:
+            argv += (self.expander(self.append),)
+        return argv
+
     def alive(self):
         return self.send(0)
 
@@ -138,21 +213,9 @@ class Node(object):
         return shlex.split(from_utf8(args), posix=not IS_WINDOWS)
 
     def getopt(self, *alt):
-        try:
-            return self._getnsopt(*alt)
-        except KeyError:
-            return self._getoptopt(*alt)
-
-    def _getnsopt(self, *alt):
-        return self._getopt(self.p.namespaces[self.namespace], list(alt))
-
-    def _getoptopt(self, *alt):
-        return self._getopt(self.p.options, list(alt))
-
-    def _getopt(self, d, alt):
         for opt in alt:
             try:
-                return d[opt]
+                return self.options[opt]
             except KeyError:
                 pass
         raise KeyError(alt[0])
@@ -183,12 +246,16 @@ class Node(object):
 
     @cached_property
     def executable(self):
-        return self.p.options['--executable']
+        return self.options['--executable']
 
     @cached_property
     def argv_with_executable(self):
         return (self.executable,) + self.argv
 
+    @classmethod
+    def from_kwargs(cls, name, **kwargs):
+        return cls(name, options=_kwargs_to_command_line(kwargs))
+
 
 def maybe_call(fun, *args, **kwargs):
     if fun is not None:
@@ -212,11 +279,6 @@ class MultiParser(object):
         options = dict(p.options)
         ranges = len(names) == 1
         prefix = self.prefix
-        if ranges:
-            try:
-                names, prefix = self._get_ranges(names), self.range_prefix
-            except ValueError:
-                pass
         cmd = options.pop('--cmd', self.cmd)
         append = options.pop('--append', self.append)
         hostname = options.pop('--hostname', options.pop('-n', gethostname()))
@@ -224,20 +286,34 @@ class MultiParser(object):
         suffix = options.pop('--suffix', self.suffix) or hostname
         suffix = '' if suffix in ('""', "''") else suffix
 
+        if ranges:
+            try:
+                names, prefix = self._get_ranges(names), self.range_prefix
+            except ValueError:
+                pass
         self._update_ns_opts(p, names)
         self._update_ns_ranges(p, ranges)
+
         return (
-            self._args_for_node(p, name, prefix, suffix, cmd, append, options)
+            self._node_from_options(
+                p, name, prefix, suffix, cmd, append, options)
             for name in names
         )
 
+    def _node_from_options(self, p, name, prefix,
+                           suffix, cmd, append, options):
+        namespace, nodename, _ = build_nodename(name, prefix, suffix)
+        namespace = nodename if nodename in p.namespaces else namespace
+        return Node(nodename, cmd, append,
+                    p.optmerge(namespace, options), p.passthrough)
+
     def _get_ranges(self, names):
         noderange = int(names[0])
         return [str(n) for n in range(1, noderange + 1)]
 
     def _update_ns_opts(self, p, names):
         # Numbers in args always refers to the index in the list of names.
-        # (e.g. `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
+        # (e.g., `start foo bar baz -c:1` where 1 is foo, 2 is bar, and so on).
         for ns_name, ns_opts in list(items(p.namespaces)):
             if ns_name.isdigit():
                 ns_index = int(ns_name) - 1
@@ -267,55 +343,10 @@ class MultiParser(object):
                 ret.append(space)
         return ret
 
-    def _args_for_node(self, p, name, prefix, suffix, cmd, append, options):
-        name, nodename, expand = self._get_nodename(
-            name, prefix, suffix, options)
-
-        if nodename in p.namespaces:
-            ns = nodename
-        else:
-            ns = name
-
-        argv = (
-            [expand(cmd)] +
-            [self.format_opt(opt, expand(value))
-                for opt, value in items(p.optmerge(ns, options))] +
-            [p.passthrough]
-        )
-        if append:
-            argv.append(expand(append))
-        return self.Node(nodename, argv, expand, name, p)
-
-    def _get_nodename(self, name, prefix, suffix, options):
-        hostname = suffix
-        if '@' in name:
-            nodename = options['-n'] = host_format(name)
-            shortname, hostname = nodesplit(nodename)
-            name = shortname
-        else:
-            shortname = '%s%s' % (prefix, name)
-            nodename = options['-n'] = host_format(
-                '{0}@{1}'.format(shortname, hostname),
-            )
-        expand = partial(
-            node_format, nodename=nodename, N=shortname, d=hostname,
-            h=nodename, i='%i', I='%I',
-        )
-        return name, nodename, expand
-
-    def format_opt(self, opt, value):
-        if not value:
-            return opt
-        if opt.startswith('--'):
-            return '{0}={1}'.format(opt, value)
-        return '{0} {1}'.format(opt, value)
-
 
 class Cluster(UserList):
-    MultiParser = MultiParser
-    OptionParser = NamespacedOptionParser
 
-    def __init__(self, argv, cmd=None, env=None,
+    def __init__(self, nodes, cmd=None, env=None,
                  on_stopping_preamble=None,
                  on_send_signal=None,
                  on_still_waiting_for=None,
@@ -331,11 +362,9 @@ class Cluster(UserList):
                  on_child_spawn=None,
                  on_child_signalled=None,
                  on_child_failure=None):
-        self.argv = argv
+        self.nodes = nodes
         self.cmd = cmd or celery_exe('worker')
         self.env = env
-        self.p = self.OptionParser(argv)
-        self.with_detacher_default_options(self.p)
 
         self.on_stopping_preamble = on_stopping_preamble
         self.on_send_signal = on_send_signal
@@ -378,7 +407,7 @@ class Cluster(UserList):
     def kill(self):
         return self.send_all(signal.SIGKILL)
 
-    def restart(self):
+    def restart(self, sig=signal.SIGTERM):
         retvals = []
 
         def restart_on_down(node):
@@ -387,40 +416,22 @@ class Cluster(UserList):
             maybe_call(self.on_node_status, node, retval)
             retvals.append(retval)
 
-        self._stop_nodes(retry=2, on_down=restart_on_down)
+        self._stop_nodes(retry=2, on_down=restart_on_down, sig=sig)
         return retvals
 
-    def stop(self, retry=None, callback=None):
-        return self._stop_nodes(retry=retry, on_down=callback)
+    def stop(self, retry=None, callback=None, sig=signal.SIGTERM):
+        return self._stop_nodes(retry=retry, on_down=callback, sig=sig)
 
-    def stopwait(self, retry=2, callback=None):
-        return self._stop_nodes(retry=retry, on_down=callback)
+    def stopwait(self, retry=2, callback=None, sig=signal.SIGTERM):
+        return self._stop_nodes(retry=retry, on_down=callback, sig=sig)
 
-    def _stop_nodes(self, retry=None, on_down=None):
+    def _stop_nodes(self, retry=None, on_down=None, sig=signal.SIGTERM):
         on_down = on_down if on_down is not None else self.on_node_down
-        restargs = self.p.args[len(self.p.values):]
         nodes = list(self.getpids(on_down=on_down))
         if nodes:
-            for node in self.shutdown_nodes(
-                    nodes,
-                    sig=self._find_sig_argument(restargs),
-                    retry=retry):
+            for node in self.shutdown_nodes(nodes, sig=sig, retry=retry):
                 maybe_call(on_down, node)
 
-    def _find_sig_argument(self, args, default=signal.SIGTERM):
-        for arg in reversed(args):
-            if len(arg) == 2 and arg[0] == '-':
-                try:
-                    return int(arg[1])
-                except ValueError:
-                    pass
-            if arg[0] == '-':
-                try:
-                    return signals.signum(arg[1:])
-                except (AttributeError, TypeError):
-                    pass
-        return default
-
     def shutdown_nodes(self, nodes, sig=signal.SIGTERM, retry=None):
         P = set(nodes)
         maybe_call(self.on_stopping_preamble, nodes)
@@ -456,23 +467,6 @@ class Cluster(UserList):
                 return node
         raise KeyError(name)
 
-    def with_detacher_default_options(self, p):
-        self._setdefaultopt(p.options, ['--pidfile', '-p'], '%n.pid')
-        self._setdefaultopt(p.options, ['--logfile', '-f'], '%n%I.log')
-        self._setdefaultopt(p.options, ['--executable'], sys.executable)
-        p.options.setdefault(
-            '--cmd',
-            '-m {0}'.format(celery_exe('worker', '--detach')),
-        )
-
-    def _setdefaultopt(self, d, alt, value):
-        for opt in alt[1:]:
-            try:
-                return d[opt]
-            except KeyError:
-                pass
-        return d.setdefault(alt[0], value)
-
     def getpids(self, on_down=None):
         for node in self:
             if node.pid:
@@ -486,6 +480,6 @@ class Cluster(UserList):
             name=type(self).__name__,
         )
 
-    @cached_property
+    @property
     def data(self):
-        return list(self.MultiParser(cmd=self.cmd).parse(self.p))
+        return self.nodes

+ 78 - 20
celery/bin/multi.py

@@ -96,6 +96,7 @@ Examples
 from __future__ import absolute_import, print_function, unicode_literals
 
 import os
+import signal
 import sys
 
 from functools import wraps
@@ -103,8 +104,8 @@ from functools import wraps
 from kombu.utils.objects import cached_property
 
 from celery import VERSION_BANNER
-from celery.apps.multi import Cluster
-from celery.platforms import EX_FAILURE, EX_OK
+from celery.apps.multi import Cluster, MultiParser, NamespacedOptionParser
+from celery.platforms import EX_FAILURE, EX_OK, signals
 from celery.utils import term
 from celery.utils.text import pluralize
 
@@ -144,6 +145,24 @@ def splash(fun):
     return _inner
 
 
+def using_cluster(fun):
+
+    @wraps(fun)
+    def _inner(self, *argv, **kwargs):
+        return fun(self, self.cluster_from_argv(argv), **kwargs)
+    return _inner
+
+
+def using_cluster_and_sig(fun):
+
+    @wraps(fun)
+    def _inner(self, *argv, **kwargs):
+        p, cluster = self._cluster_from_argv(argv)
+        sig = self._find_sig_argument(p)
+        return fun(self, cluster, sig, **kwargs)
+    return _inner
+
+
 class TermLogger(object):
 
     splash_text = 'celery multi v{version}'
@@ -201,6 +220,8 @@ class TermLogger(object):
 
 
 class MultiTool(TermLogger):
+    MultiParser = MultiParser
+    OptionParser = NamespacedOptionParser
 
     reserved_options = [
         ('--nosplash', 'nosplash'),
@@ -260,56 +281,93 @@ class MultiTool(TermLogger):
         return argv
 
     @splash
-    def start(self, *argv):
+    @using_cluster
+    def start(self, cluster):
         self.note('> Starting nodes...')
-        return int(any(self.Cluster(argv).start()))
+        return int(any(cluster.start()))
 
     @splash
-    def stop(self, *argv, **kwargs):
-        return self.Cluster(argv).stop(**kwargs)
+    @using_cluster_and_sig
+    def stop(self, cluster, sig, **kwargs):
+        return cluster.stop(sig=sig, **kwargs)
 
     @splash
-    def stopwait(self, *argv, **kwargs):
-        return self.Cluster(argv).stopwait(**kwargs)
+    @using_cluster_and_sig
+    def stopwait(self, cluster, sig, **kwargs):
+        return cluster.stopwait(sig=sig, **kwargs)
     stop_verify = stopwait  # compat
 
     @splash
-    def restart(self, *argv, **kwargs):
-        return int(any(self.Cluster(argv).restart(**kwargs)))
+    @using_cluster_and_sig
+    def restart(self, cluster, sig, **kwargs):
+        return int(any(cluster.restart(sig=sig, **kwargs)))
 
-    def names(self, *argv):
-        self.say('\n'.join(n.name for n in self.Cluster(argv)))
+    @using_cluster
+    def names(self, cluster):
+        self.say('\n'.join(n.name for n in cluster))
 
     def get(self, wanted, *argv):
         try:
-            node = self.Cluster(argv).find(wanted)
+            node = self.cluster_from_argv(argv).find(wanted)
         except KeyError:
             return EX_FAILURE
         else:
             return self.ok(' '.join(node.argv))
 
-    def show(self, *argv):
+    @using_cluster
+    def show(self, cluster):
         return self.ok('\n'.join(
             ' '.join(node.argv_with_executable)
-            for node in self.Cluster(argv)
+            for node in cluster
         ))
 
     @splash
-    def kill(self, *argv):
-        return self.Cluster(argv).kill()
+    @using_cluster
+    def kill(self, cluster):
+        return cluster.kill()
 
     def expand(self, template, *argv):
         return self.ok('\n'.join(
             node.expander(template)
-            for node in self.Cluster(argv)
+            for node in self.cluster_from_argv(argv)
         ))
 
     def help(self, *argv):
         self.say(__doc__)
 
-    def Cluster(self, argv, cmd=None):
+    def _find_sig_argument(self, p, default=signal.SIGTERM):
+        args = p.args[len(p.values):]
+        for arg in reversed(args):
+            if len(arg) == 2 and arg[0] == '-':
+                try:
+                    return int(arg[1])
+                except ValueError:
+                    pass
+            if arg[0] == '-':
+                try:
+                    return signals.signum(arg[1:])
+                except (AttributeError, TypeError):
+                    pass
+        return default
+
+    def _nodes_from_argv(self, argv, cmd=None):
+        cmd = cmd if cmd is not None else self.cmd
+        p = self.OptionParser(argv)
+        p.parse()
+        return p, self.MultiParser(cmd=cmd).parse(p)
+
+    def cluster_from_argv(self, argv, cmd=None):
+        _, cluster = self._cluster_from_argv(argv, cmd=cmd)
+        return cluster
+
+    def _cluster_from_argv(self, argv, cmd=None):
+        p, nodes = self._nodes_from_argv(argv, cmd=cmd)
+        return p, self.Cluster(list(nodes), cmd=cmd)
+
+    def Cluster(self, nodes, cmd=None):
         return Cluster(
-            argv, cmd if cmd is not None else self.cmd,
+            nodes,
+            cmd=cmd,
             env=self.env,
             on_stopping_preamble=self.on_stopping_preamble,
             on_send_signal=self.on_send_signal,

+ 59 - 27
celery/tests/apps/test_multi.py

@@ -2,9 +2,10 @@ from __future__ import absolute_import, unicode_literals
 
 import errno
 import signal
+import sys
 
 from celery.apps.multi import (
-    Cluster, MultiParser, NamespacedOptionParser, Node,
+    Cluster, MultiParser, NamespacedOptionParser, Node, format_opt,
 )
 
 from celery.tests.case import AppCase, Mock, call, patch
@@ -12,15 +13,6 @@ from celery.tests.case import AppCase, Mock, call, patch
 
 class test_functions(AppCase):
 
-    def test_findsig(self):
-        m = Cluster([])
-        self.assertEqual(m._find_sig_argument(['a', 'b', 'c', '-1']), 1)
-        self.assertEqual(m._find_sig_argument(['--foo=1', '-9']), 9)
-        self.assertEqual(m._find_sig_argument(['-INT']), signal.SIGINT)
-        self.assertEqual(m._find_sig_argument([]), signal.SIGTERM)
-        self.assertEqual(m._find_sig_argument(['-s']), signal.SIGTERM)
-        self.assertEqual(m._find_sig_argument(['-log']), signal.SIGTERM)
-
     def test_parse_ns_range(self):
         m = MultiParser()
         self.assertEqual(m._parse_ns_range('1-3', True), ['1', '2', '3'])
@@ -31,22 +23,23 @@ class test_functions(AppCase):
         )
 
     def test_format_opt(self):
-        m = MultiParser()
-        self.assertEqual(m.format_opt('--foo', None), '--foo')
-        self.assertEqual(m.format_opt('-c', 1), '-c 1')
-        self.assertEqual(m.format_opt('--log', 'foo'), '--log=foo')
+        self.assertEqual(format_opt('--foo', None), '--foo')
+        self.assertEqual(format_opt('-c', 1), '-c 1')
+        self.assertEqual(format_opt('--log', 'foo'), '--log=foo')
 
 
 class test_NamespacedOptionParser(AppCase):
 
     def test_parse(self):
         x = NamespacedOptionParser(['-c:1,3', '4'])
+        x.parse()
         self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
         x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                     '--loglevel:kramer=DEBUG',
                                     '--flag',
                                     '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                     '--', '.disable_rate_limits=1'])
+        x.parse()
         self.assertEqual(x.options, {'--logfile': 'foo',
                                      '-Q': 'bar',
                                      '--flag': None})
@@ -73,6 +66,7 @@ class test_multi_args(AppCase):
             'elaine', 'kramer',
             '--', '.disable_rate_limits=1',
         ])
+        p.parse()
         it = multi_args(p, cmd='COMMAND', append='*AP*',
                         prefix='*P*', suffix='*S*')
         nodes = list(it)
@@ -113,18 +107,29 @@ class test_multi_args(AppCase):
         self.assertEqual(nodes2[0].argv[-1], '-- .disable_rate_limits=1')
 
         p2 = NamespacedOptionParser(['10', '-c:1', '5'])
+        p2.parse()
         nodes3 = list(multi_args(p2, cmd='COMMAND'))
+
+        def _args(name, *args):
+            return args + (
+                '--pidfile={0}.pid'.format(name),
+                '--logfile={0}%I.log'.format(name),
+                '--executable={0}'.format(sys.executable),
+                '',
+            )
+
         self.assertEqual(len(nodes3), 10)
         self.assertEqual(nodes3[0].name, 'celery1@example.com')
         self.assertTupleEqual(
             nodes3[0].argv,
-            ('COMMAND', '-n celery1@example.com', '-c 5', ''),
+            ('COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1'),
         )
         for i, worker in enumerate(nodes3[1:]):
             self.assertEqual(worker.name, 'celery%s@example.com' % (i + 2))
             self.assertTupleEqual(
                 worker.argv,
-                ('COMMAND', '-n celery%s@example.com' % (i + 2), ''),
+                (('COMMAND', '-n celery%s@example.com' % (i + 2)) +
+                 _args('celery%s' % (i + 2))),
             )
 
         nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
@@ -132,39 +137,44 @@ class test_multi_args(AppCase):
         self.assertEqual(nodes4[0].name, 'celery1@')
         self.assertTupleEqual(
             nodes4[0].argv,
-            ('COMMAND', '-n celery1@', '-c 5', ''),
+            ('COMMAND', '-c 5', '-n celery1@') + _args('celery1'),
         )
 
         p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
+        p3.parse()
         nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
         self.assertEqual(nodes5[0].name, 'foo@')
         self.assertTupleEqual(
             nodes5[0].argv,
-            ('COMMAND', '-n foo@', '-c 5', ''),
+            ('COMMAND', '-c 5', '-n foo@') + _args('foo'),
         )
 
         p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
+        p4.parse()
         nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
         self.assertEqual(nodes6[0].name, 'foo@')
         self.assertTupleEqual(
             nodes6[0].argv,
-            ('COMMAND', '-n foo@', '-Q test', ''),
+            ('COMMAND', '-Q test', '-n foo@') + _args('foo'),
         )
 
         p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
+        p5.parse()
         nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
         self.assertEqual(nodes7[0].name, 'foo@bar')
         self.assertTupleEqual(
             nodes7[0].argv,
-            ('COMMAND', '-n foo@bar', '-Q test', ''),
+            ('COMMAND', '-Q test', '-n foo@bar') + _args('foo'),
         )
 
         p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
+        p6.parse()
         with self.assertRaises(KeyError):
             list(multi_args(p6))
 
     def test_optmerge(self):
         p = NamespacedOptionParser(['foo', 'test'])
+        p.parse()
         p.options = {'x': 'y'}
         r = p.optmerge('foo')
         self.assertEqual(r['x'], 'y')
@@ -179,12 +189,28 @@ class test_Node(AppCase):
             '--logfile': 'foo.log',
         }
         self.p.namespaces = {}
-        self.expander = Mock(name='expander')
-        self.node = Node(
-            'foo@bar.com', ['-A', 'proj'], self.expander, 'foo', self.p,
-        )
+        self.node = Node('foo@bar.com', options={'-A': 'proj'})
+        self.expander = self.node.expander = Mock(name='expander')
         self.node.pid = 303
 
+    def test_from_kwargs(self):
+        n = Node.from_kwargs(
+            'foo@bar.com',
+            max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
+        )
+        self.assertTupleEqual(n.argv, (
+            '-m celery worker --detach',
+            '-A foo',
+            '--executable={0}'.format(n.executable),
+            '-O fair',
+            '-n foo@bar.com',
+            '--logfile=foo%I.log',
+            '-Q q1,q2',
+            '--max-tasks-per-child=30',
+            '--pidfile=foo.pid',
+            '',
+        ))
+
     @patch('os.kill')
     def test_send(self, kill):
         self.assertTrue(self.node.send(9))
@@ -267,7 +293,7 @@ class test_Node(AppCase):
 
     def test_logfile(self):
         self.assertEqual(self.node.logfile, self.expander.return_value)
-        self.expander.assert_called_with('foo.log')
+        self.expander.assert_called_with('%n%I.log')
 
 
 class test_Cluster(AppCase):
@@ -279,7 +305,9 @@ class test_Cluster(AppCase):
         self.gethostname.return_value = 'example.com'
         self.Pidfile = self.patch('celery.apps.multi.Pidfile')
         self.cluster = Cluster(
-            ['foo', 'bar', 'baz'],
+            [Node('foo@example.com'),
+             Node('bar@example.com'),
+             Node('baz@example.com')],
             on_stopping_preamble=Mock(name='on_stopping_preamble'),
             on_send_signal=Mock(name='on_send_signal'),
             on_still_waiting_for=Mock(name='on_still_waiting_for'),
@@ -358,7 +386,11 @@ class test_Cluster(AppCase):
         self.prepare_pidfile_for_getpids(self.Pidfile)
         callback = Mock()
 
-        p = Cluster(['foo', 'bar', 'baz'])
+        p = Cluster([
+            Node('foo@e.com'),
+            Node('bar@e.com'),
+            Node('baz@e.com'),
+        ])
         nodes = p.getpids(on_down=callback)
         node_0, node_1 = nodes
         self.assertEqual(node_0.name, 'foo@e.com')

+ 39 - 17
celery/tests/bin/test_multi.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 
+import signal
 import sys
 
 from celery.bin.multi import (
@@ -18,6 +19,8 @@ class test_MultiTool(AppCase):
         self.fh = WhateverIO()
         self.env = {}
         self.t = MultiTool(env=self.env, fh=self.fh)
+        self.t.cluster_from_argv = Mock(name='cluster_from_argv')
+        self.t._cluster_from_argv = Mock(name='cluster_from_argv')
         self.t.Cluster = Mock(name='Cluster')
         self.t.carp = Mock(name='.carp')
         self.t.usage = Mock(name='.usage')
@@ -26,6 +29,26 @@ class test_MultiTool(AppCase):
         self.t.ok = Mock(name='.ok')
         self.cluster = self.t.Cluster.return_value
 
+        def _cluster_from_argv(argv):
+            p = self.t.OptionParser(argv)
+            p.parse()
+            return p, self.cluster
+        self.t.cluster_from_argv.return_value = self.cluster
+        self.t._cluster_from_argv.side_effect = _cluster_from_argv
+
+    def test_findsig(self):
+        self.assert_sig_argument(['a', 'b', 'c', '-1'], 1)
+        self.assert_sig_argument(['--foo=1', '-9'], 9)
+        self.assert_sig_argument(['-INT'], signal.SIGINT)
+        self.assert_sig_argument([], signal.SIGTERM)
+        self.assert_sig_argument(['-s'], signal.SIGTERM)
+        self.assert_sig_argument(['-log'], signal.SIGTERM)
+
+    def assert_sig_argument(self, args, expected):
+        p = self.t.OptionParser(args)
+        p.parse()
+        self.assertEqual(self.t._find_sig_argument(p), expected)
+
     def test_execute_from_commandline(self):
         self.t.call_command = Mock(name='call_command')
         self.t.execute_from_commandline(
@@ -67,7 +90,7 @@ class test_MultiTool(AppCase):
         self.cluster.start.return_value = [0, 0, 1, 0]
         self.assertTrue(self.t.start('10', '-A', 'proj'))
         self.t.splash.assert_called_with()
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.cluster.start.assert_called_with()
 
     def test_start__exitcodes(self):
@@ -81,26 +104,26 @@ class test_MultiTool(AppCase):
     def test_stop(self):
         self.t.stop('10', '-A', 'proj', retry=3)
         self.t.splash.assert_called_with()
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
-        self.cluster.stop.assert_called_with(retry=3)
+        self.t._cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.stop.assert_called_with(retry=3, sig=signal.SIGTERM)
 
     def test_stopwait(self):
         self.t.stopwait('10', '-A', 'proj', retry=3)
         self.t.splash.assert_called_with()
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
-        self.cluster.stopwait.assert_called_with(retry=3)
+        self.t._cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.stopwait.assert_called_with(retry=3, sig=signal.SIGTERM)
 
     def test_restart(self):
         self.cluster.restart.return_value = [0, 0, 1, 0]
         self.t.restart('10', '-A', 'proj')
         self.t.splash.assert_called_with()
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
-        self.cluster.restart.assert_called_with()
+        self.t._cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
+        self.cluster.restart.assert_called_with(sig=signal.SIGTERM)
 
     def test_names(self):
-        self.t.Cluster.return_value = [Mock(), Mock()]
-        self.t.Cluster.return_value[0].name = 'x'
-        self.t.Cluster.return_value[1].name = 'y'
+        self.t.cluster_from_argv.return_value = [Mock(), Mock()]
+        self.t.cluster_from_argv.return_value[0].name = 'x'
+        self.t.cluster_from_argv.return_value[1].name = 'y'
         self.t.names('10', '-A', 'proj')
         self.t.say.assert_called()
 
@@ -112,7 +135,7 @@ class test_MultiTool(AppCase):
             self.t.ok.return_value,
         )
         self.cluster.find.assert_called_with('wanted')
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.t.ok.assert_called_with(' '.join(node.argv))
 
     def test_get__KeyError(self):
@@ -120,7 +143,7 @@ class test_MultiTool(AppCase):
         self.assertTrue(self.t.get('wanted', '10', '-A', 'proj'))
 
     def test_show(self):
-        nodes = self.t.Cluster.return_value = [
+        nodes = self.t.cluster_from_argv.return_value = [
             Mock(name='n1'),
             Mock(name='n2'),
         ]
@@ -137,7 +160,7 @@ class test_MultiTool(AppCase):
     def test_kill(self):
         self.t.kill('10', '-A', 'proj')
         self.t.splash.assert_called_with()
-        self.t.Cluster.assert_called_with(('10', '-A', 'proj'))
+        self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.cluster.kill.assert_called_with()
 
     def test_expand(self):
@@ -145,9 +168,9 @@ class test_MultiTool(AppCase):
         node2 = Mock(name='n2')
         node1.expander.return_value = 'A'
         node2.expander.return_value = 'B'
-        nodes = self.t.Cluster.return_value = [node1, node2]
+        nodes = self.t.cluster_from_argv.return_value = [node1, node2]
         self.assertIs(self.t.expand('%p', '10'), self.t.ok.return_value)
-        self.t.Cluster.assert_called_with(('10',))
+        self.t.cluster_from_argv.assert_called_with(('10',))
         for node in nodes:
             node.expander.assert_called_with('%p')
         self.t.ok.assert_called_with('A\nB')
@@ -172,8 +195,7 @@ class test_MultiTool(AppCase):
 
     def test_Cluster(self):
         m = MultiTool()
-        c = m.Cluster(['A', 'B', 'C'])
-        self.assertListEqual(c.argv, ['A', 'B', 'C'])
+        c = m.cluster_from_argv(['A', 'B', 'C'])
         self.assertIs(c.env, m.env)
         self.assertEqual(c.cmd, 'celery worker')
         self.assertEqual(c.on_stopping_preamble, m.on_stopping_preamble)

+ 0 - 1
celery/utils/imports.py

@@ -158,4 +158,3 @@ def load_extension_classes(namespace):
                     namespace, class_name, exc))
         else:
             yield name, cls
-