Ver código fonte

Tests passing

Ask Solem 12 anos atrás
pai
commit
ccad83ee21

+ 0 - 1
celery/__init__.py

@@ -74,4 +74,3 @@ old_module, new_module = recreate_module(__name__,  # pragma: no cover
     __homepage__=__homepage__, __docformat__=__docformat__,
     VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER,
 )
-

+ 0 - 2
celery/bin/base.py

@@ -306,10 +306,8 @@ class Command(object):
 
     def find_app(self, app):
         try:
-            print('sym by name: %r' % (app, ))
             sym = self.symbol_by_name(app)
         except AttributeError:
-            print('ATTRIBUTE ERROR')
             # last part was not an attribute, but a module
             sym = import_from_cwd(app)
         if isinstance(sym, ModuleType):

+ 194 - 195
celery/bin/celery.py

@@ -15,7 +15,6 @@ import sys
 import warnings
 
 from importlib import import_module
-from itertools import imap
 from operator import itemgetter
 from pprint import pformat
 
@@ -103,6 +102,28 @@ def load_extension_commands(namespace='celery.commands'):
             command(cls, name=ep.name)
 
 
+def determine_exit_status(ret):
+    if isinstance(ret, int):
+        return ret
+    return EX_OK if ret else EX_FAILURE
+
+
+def main(argv=None):
+    # Fix for setuptools generated scripts, so that it will
+    # work with multiprocessing fork emulation.
+    # (see multiprocessing.forking.get_preparation_data())
+    try:
+        if __name__ != '__main__':  # pragma: no cover
+            sys.modules['__main__'] = sys.modules[__name__]
+        cmd = CeleryCommand()
+        cmd.maybe_patch_concurrency()
+        from billiard import freeze_support
+        freeze_support()
+        cmd.execute_from_commandline(argv)
+    except KeyboardInterrupt:
+        pass
+
+
 class Command(BaseCommand):
     help = ''
     args = ''
@@ -136,7 +157,8 @@ class Command(BaseCommand):
 
     def exit_help(self, command):
         # this never exits due to OptionParser.parse_options
-        return self.run_from_argv(self.prog_name, [command, '--help'])
+        self.run_from_argv(self.prog_name, [command, '--help'])
+        sys.exit(EX_USAGE)
 
     def error(self, s):
         self.out(s, fh=self.stderr)
@@ -467,180 +489,6 @@ class result(Command):
         self.out(self.prettify(value)[1])
 
 
-@command
-class graph(Command):
-    args = """<TYPE> [arguments]
-            .....  bootsteps [worker] [consumer]
-            .....  workers   [enumerate]
-    """
-
-    def run(self, what=None, *args, **kwargs):
-        map = {'bootsteps': self.bootsteps, 'workers': self.workers}
-        not what and self.exit_help('graph')
-        if what not in map:
-            raise Error('no graph {0} in {1}'.format(what, '|'.join(map)))
-        return map[what](*args, **kwargs)
-
-    def bootsteps(self, *args, **kwargs):
-        worker = self.app.WorkController()
-        include = set(arg.lower() for arg in args or ['worker', 'consumer'])
-        if 'worker' in include:
-            graph = worker.namespace.graph
-            if 'consumer' in include:
-                worker.namespace.connect_with(worker.consumer.namespace)
-        else:
-            graph = worker.consumer.namespace.graph
-        graph.to_dot(self.stdout)
-
-    def workers(self, *args, **kwargs):
-
-        def simplearg(arg):
-            return maybe_list(itemgetter(0, 2)(arg.partition(':')))
-
-        def maybe_list(l, sep=','):
-            return (l[0], l[1].split(sep) if sep in l[1] else l[1])
-
-        args = dict(map(simplearg, args))
-        generic = 'generic' in args
-
-        def generic_label(node):
-            return '{0} ({1}://)'.format(type(node).__name__,
-                                         node._label.split('://')[0])
-
-        class Node(object):
-            force_label = None
-            scheme = {}
-
-            def __init__(self, label, pos=None):
-                self._label = label
-                self.pos = pos
-
-            def label(self):
-                return self._label
-
-            def __str__(self):
-                return self.label()
-
-        class Thread(Node):
-            scheme = {'fillcolor': 'lightcyan4', 'fontcolor': 'yellow',
-                      'shape':'oval', 'fontsize': 10, 'width': 0.3,
-                      'color': 'black'}
-
-            def __init__(self, label, **kwargs):
-                self._label = 'thr-{0}'.format(next(tids))
-                self.real_label = label
-                self.pos = 0
-
-        class Formatter(GraphFormatter):
-
-            def label(self, obj):
-                return obj and obj.label()
-
-            def node(self, obj):
-                scheme = dict(obj.scheme, sortv=obj.pos) if obj.pos else obj.scheme
-                if isinstance(obj, Thread):
-                    scheme['label'] = obj.real_label
-                return self.draw_node(
-                    obj, dict(self.node_scheme, **scheme),
-                )
-
-            def terminal_node(self, obj):
-                return self.draw_node(
-                    obj, dict(self.term_scheme, **obj.scheme),
-                )
-
-            def edge(self, a, b, **attrs):
-                if isinstance(a, Thread):
-                    attrs.update(arrowhead='none', arrowtail='tee')
-                return self.draw_edge(a, b, self.edge_scheme, attrs)
-
-        def subscript(n):
-            S = {'0': '₀', '1': '₁', '2': '₂', '3': '₃', '4': '₄',
-                 '5': '₅', '6': '₆', '7': '₇', '8': '₈', '9': '₉'}
-            return ''.join([S[i] for i in str(n)])
-
-
-        class Worker(Node):
-            pass
-
-        class Backend(Node):
-            scheme = {'shape': 'folder', 'width': 2,
-                      'height': 1, 'color': 'black',
-                      'fillcolor': 'peachpuff3', 'color': 'peachpuff4'}
-
-            def label(self):
-                return generic_label(self) if generic else self._label
-
-        class Broker(Node):
-            scheme = {'shape': 'circle', 'fillcolor': 'cadetblue3',
-                      'color': 'cadetblue4', 'height': 1}
-
-            def label(self):
-                return generic_label(self) if generic else self._label
-
-        from itertools import count
-        tids = count(1)
-        Wmax = int(args.get('wmax', 4) or 0)
-        Tmax = int(args.get('tmax', 3) or 0)
-
-        def maybe_abbr(l, name, max=Wmax):
-            size = len(l)
-            abbr = max and size > max
-            if 'enumerate' in args:
-                l = ['{0}{1}'.format(name, subscript(i + 1))
-                        for i, obj in enumerate(l)]
-            if abbr:
-                l = l[0:max -1] + [l[size - 1]]
-                l[max - 2] = '{0}⎨…{1}⎬'.format(
-                    name[0], subscript(size - (max - 1)))
-            return l
-
-        try:
-            workers = args['nodes']
-            threads = args.get('threads') or []
-        except KeyError:
-            replies = self.app.control.inspect().stats()
-            workers, threads = [], []
-            for worker, reply in replies.iteritems():
-                workers.append(worker)
-                threads.append(reply['pool']['max-concurrency'])
-
-        wlen = len(workers)
-        backend = args.get('backend', self.app.conf.CELERY_RESULT_BACKEND)
-        threads_for = {}
-        workers = maybe_abbr(workers, 'Worker')
-        if Wmax and wlen > Wmax:
-            threads = threads[0:3] + [threads[-1]]
-        for i, threads in enumerate(threads):
-            threads_for[workers[i]] = maybe_abbr(
-                range(int(threads)), 'P', Tmax,
-            )
-
-        broker = Broker(args.get('broker', self.app.connection().as_uri()))
-        backend = Backend(backend) if backend else None
-        graph = DependencyGraph(formatter=Formatter())
-        graph.add_arc(broker)
-        if backend:
-            graph.add_arc(backend)
-        curworker = [0]
-        for i, worker in enumerate(workers):
-            worker = Worker(worker, pos=i)
-            graph.add_arc(worker)
-            graph.add_edge(worker, broker)
-            if backend:
-                graph.add_edge(worker, backend)
-            threads = threads_for.get(worker._label)
-            if threads:
-                for thread in threads:
-                    thread = Thread(thread)
-                    graph.add_arc(thread)
-                    graph.add_edge(thread, worker)
-
-            curworker[0] += 1
-
-        graph.to_dot(self.stdout)
-
-
 class _RemoteControl(Command):
     name = None
     choices = None
@@ -707,7 +555,7 @@ class _RemoteControl(Command):
         destination = kwargs.get('destination')
         timeout = kwargs.get('timeout') or self.choices[method][0]
         if destination and isinstance(destination, basestring):
-            destination = list(imap(str.strip, destination.split(',')))
+            destination = [dest.strip() for dest in destination.split(',')]
 
         try:
             handler = getattr(self, method)
@@ -1116,26 +964,177 @@ class CeleryCommand(BaseCommand):
         load_extension_commands()
 
 
-def determine_exit_status(ret):
-    if isinstance(ret, int):
-        return ret
-    return EX_OK if ret else EX_FAILURE
+@command
+class graph(Command):
+    args = """<TYPE> [arguments]
+            .....  bootsteps [worker] [consumer]
+            .....  workers   [enumerate]
+    """
 
+    def run(self, what=None, *args, **kwargs):
+        map = {'bootsteps': self.bootsteps, 'workers': self.workers}
+        not what and self.exit_help('graph')
+        if what not in map:
+            raise Error('no graph {0} in {1}'.format(what, '|'.join(map)))
+        return map[what](*args, **kwargs)
 
-def main(argv=None):
-    # Fix for setuptools generated scripts, so that it will
-    # work with multiprocessing fork emulation.
-    # (see multiprocessing.forking.get_preparation_data())
-    try:
-        if __name__ != '__main__':  # pragma: no cover
-            sys.modules['__main__'] = sys.modules[__name__]
-        cmd = CeleryCommand()
-        cmd.maybe_patch_concurrency()
-        from billiard import freeze_support
-        freeze_support()
-        cmd.execute_from_commandline(argv)
-    except KeyboardInterrupt:
-        pass
+    def bootsteps(self, *args, **kwargs):
+        worker = self.app.WorkController()
+        include = set(arg.lower() for arg in args or ['worker', 'consumer'])
+        if 'worker' in include:
+            graph = worker.namespace.graph
+            if 'consumer' in include:
+                worker.namespace.connect_with(worker.consumer.namespace)
+        else:
+            graph = worker.consumer.namespace.graph
+        graph.to_dot(self.stdout)
+
+    def workers(self, *args, **kwargs):
+
+        def simplearg(arg):
+            return maybe_list(itemgetter(0, 2)(arg.partition(':')))
+
+        def maybe_list(l, sep=','):
+            return (l[0], l[1].split(sep) if sep in l[1] else l[1])
+
+        args = dict(map(simplearg, args))
+        generic = 'generic' in args
+
+        def generic_label(node):
+            return '{0} ({1}://)'.format(type(node).__name__,
+                                         node._label.split('://')[0])
+
+        class Node(object):
+            force_label = None
+            scheme = {}
+
+            def __init__(self, label, pos=None):
+                self._label = label
+                self.pos = pos
+
+            def label(self):
+                return self._label
+
+            def __str__(self):
+                return self.label()
+
+        class Thread(Node):
+            scheme = {'fillcolor': 'lightcyan4', 'fontcolor': 'yellow',
+                      'shape': 'oval', 'fontsize': 10, 'width': 0.3,
+                      'color': 'black'}
+
+            def __init__(self, label, **kwargs):
+                self._label = 'thr-{0}'.format(next(tids))
+                self.real_label = label
+                self.pos = 0
+
+        class Formatter(GraphFormatter):
+
+            def label(self, obj):
+                return obj and obj.label()
+
+            def node(self, obj):
+                scheme = dict(obj.scheme) if obj.pos else obj.scheme
+                if isinstance(obj, Thread):
+                    scheme['label'] = obj.real_label
+                return self.draw_node(
+                    obj, dict(self.node_scheme, **scheme),
+                )
+
+            def terminal_node(self, obj):
+                return self.draw_node(
+                    obj, dict(self.term_scheme, **obj.scheme),
+                )
+
+            def edge(self, a, b, **attrs):
+                if isinstance(a, Thread):
+                    attrs.update(arrowhead='none', arrowtail='tee')
+                return self.draw_edge(a, b, self.edge_scheme, attrs)
+
+        def subscript(n):
+            S = {'0': '₀', '1': '₁', '2': '₂', '3': '₃', '4': '₄',
+                 '5': '₅', '6': '₆', '7': '₇', '8': '₈', '9': '₉'}
+            return ''.join([S[i] for i in str(n)])
+
+        class Worker(Node):
+            pass
+
+        class Backend(Node):
+            scheme = {'shape': 'folder', 'width': 2,
+                      'height': 1, 'color': 'black',
+                      'fillcolor': 'peachpuff3', 'color': 'peachpuff4'}
+
+            def label(self):
+                return generic_label(self) if generic else self._label
+
+        class Broker(Node):
+            scheme = {'shape': 'circle', 'fillcolor': 'cadetblue3',
+                      'color': 'cadetblue4', 'height': 1}
+
+            def label(self):
+                return generic_label(self) if generic else self._label
+
+        from itertools import count
+        tids = count(1)
+        Wmax = int(args.get('wmax', 4) or 0)
+        Tmax = int(args.get('tmax', 3) or 0)
+
+        def maybe_abbr(l, name, max=Wmax):
+            size = len(l)
+            abbr = max and size > max
+            if 'enumerate' in args:
+                l = ['{0}{1}'.format(name, subscript(i + 1))
+                        for i, obj in enumerate(l)]
+            if abbr:
+                l = l[0:max - 1] + [l[size - 1]]
+                l[max - 2] = '{0}⎨…{1}⎬'.format(
+                    name[0], subscript(size - (max - 1)))
+            return l
+
+        try:
+            workers = args['nodes']
+            threads = args.get('threads') or []
+        except KeyError:
+            replies = self.app.control.inspect().stats()
+            workers, threads = [], []
+            for worker, reply in replies.iteritems():
+                workers.append(worker)
+                threads.append(reply['pool']['max-concurrency'])
+
+        wlen = len(workers)
+        backend = args.get('backend', self.app.conf.CELERY_RESULT_BACKEND)
+        threads_for = {}
+        workers = maybe_abbr(workers, 'Worker')
+        if Wmax and wlen > Wmax:
+            threads = threads[0:3] + [threads[-1]]
+        for i, threads in enumerate(threads):
+            threads_for[workers[i]] = maybe_abbr(
+                range(int(threads)), 'P', Tmax,
+            )
+
+        broker = Broker(args.get('broker', self.app.connection().as_uri()))
+        backend = Backend(backend) if backend else None
+        graph = DependencyGraph(formatter=Formatter())
+        graph.add_arc(broker)
+        if backend:
+            graph.add_arc(backend)
+        curworker = [0]
+        for i, worker in enumerate(workers):
+            worker = Worker(worker, pos=i)
+            graph.add_arc(worker)
+            graph.add_edge(worker, broker)
+            if backend:
+                graph.add_edge(worker, backend)
+            threads = threads_for.get(worker._label)
+            if threads:
+                for thread in threads:
+                    thread = Thread(thread)
+                    graph.add_arc(thread)
+                    graph.add_edge(thread, worker)
+
+            curworker[0] += 1
+
+        graph.to_dot(self.stdout)
 
 
 if __name__ == '__main__':          # pragma: no cover

+ 1 - 4
celery/bootsteps.py

@@ -57,7 +57,7 @@ class StepFormatter(GraphFormatter):
     }
 
     def label(self, step):
-        return '{0}{1}'.format(self._get_prefix(step),
+        return step and '{0}{1}'.format(self._get_prefix(step),
             (step.label or _label(step)).encode('utf-8', 'ignore'),
         )
 
@@ -210,9 +210,6 @@ class Namespace(object):
                 if node.name not in self.steps:
                     steps[node.name] = node
                 stream.append(node.requires)
-        # Make sure we have all the steps
-        assert [steps[req.name] for step in steps.values()
-                    for req in step.requires]
 
     def _finalize_steps(self, steps):
         last = self._find_last()

+ 1 - 1
celery/fixups/django.py

@@ -7,7 +7,6 @@ import warnings
 from datetime import datetime
 
 from celery import signals
-from celery.utils.imports import import_from_cwd
 
 SETTINGS_MODULE = os.environ.get('DJANGO_SETTINGS_MODULE')
 
@@ -89,6 +88,7 @@ class DjangoFixup(object):
         )
 
     def install(self):
+        # Need to add project directory to path
         sys.path.append(os.getcwd())
         signals.beat_embedded_init.connect(self.close_database)
         signals.worker_ready.connect(self.on_worker_ready)

+ 3 - 2
celery/tests/bin/test_celery.py

@@ -45,9 +45,10 @@ class test_Command(AppCase):
         self.err = WhateverIO()
         self.cmd = Command(self.app, stdout=self.out, stderr=self.err)
 
-    def test_show_help(self):
+    def test_exit_help(self):
         self.cmd.run_from_argv = Mock()
-        self.assertEqual(self.cmd.show_help('foo'), EX_USAGE)
+        with self.assertRaises(SystemExit):
+            self.cmd.exit_help('foo')
         self.cmd.run_from_argv.assert_called_with(
                 self.cmd.prog_name, ['foo', '--help']
         )

+ 9 - 9
celery/tests/worker/test_worker.py

@@ -187,7 +187,7 @@ class test_QoS(Case):
 
     def test_consumer_increment_decrement(self):
         mconsumer = Mock()
-        qos = QoS(mconsumer, 10)
+        qos = QoS(mconsumer.qos, 10)
         qos.update()
         self.assertEqual(qos.value, 10)
         mconsumer.qos.assert_called_with(prefetch_count=10)
@@ -209,7 +209,7 @@ class test_QoS(Case):
 
     def test_consumer_decrement_eventually(self):
         mconsumer = Mock()
-        qos = QoS(mconsumer, 10)
+        qos = QoS(mconsumer.qos, 10)
         qos.decrement_eventually()
         self.assertEqual(qos.value, 9)
         qos.value = 0
@@ -218,7 +218,7 @@ class test_QoS(Case):
 
     def test_set(self):
         mconsumer = Mock()
-        qos = QoS(mconsumer, 10)
+        qos = QoS(mconsumer.qos, 10)
         qos.set(12)
         self.assertEqual(qos.prev, 12)
         qos.set(qos.prev)
@@ -235,7 +235,8 @@ class test_Consumer(Case):
 
     def test_info(self):
         l = MyKombuConsumer(self.ready_queue, timer=self.timer)
-        l.qos = QoS(l.task_consumer, 10)
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer.qos, 10)
         info = l.info
         self.assertEqual(info['prefetch_count'], 10)
         self.assertFalse(info['broker'])
@@ -432,7 +433,7 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.task_consumer = Mock()
         l.connection.obj = l
-        l.qos = QoS(l.task_consumer, 10)
+        l.qos = QoS(l.task_consumer.qos, 10)
         l.loop(*l.loop_args())
 
     def test_loop_when_socket_error(self):
@@ -449,7 +450,7 @@ class test_Consumer(Case):
         c = l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10)
+        l.qos = QoS(l.task_consumer.qos, 10)
         with self.assertRaises(socket.error):
             l.loop(*l.loop_args())
 
@@ -469,13 +470,12 @@ class test_Consumer(Case):
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 10)
+        l.qos = QoS(l.task_consumer.qos, 10)
 
         l.loop(*l.loop_args())
         l.loop(*l.loop_args())
         self.assertTrue(l.task_consumer.consume.call_count)
         l.task_consumer.qos.assert_called_with(prefetch_count=10)
-        l.task_consumer.qos = Mock()
         self.assertEqual(l.qos.value, 10)
         l.qos.decrement_eventually()
         self.assertEqual(l.qos.value, 9)
@@ -513,7 +513,7 @@ class test_Consumer(Case):
                            args=[2, 4, 8], kwargs={})
 
         l.task_consumer = Mock()
-        l.qos = QoS(l.task_consumer, 1)
+        l.qos = QoS(l.task_consumer.qos, 1)
         current_pcount = l.qos.value
         l.event_dispatcher = Mock()
         l.enabled = False