Browse Source

Better graph output

Ask Solem 12 years ago
parent
commit
3b32986a69

+ 4 - 11
celery/bin/celery.py

@@ -18,6 +18,8 @@ from importlib import import_module
 from itertools import imap
 from pprint import pformat
 
+from kombu.utils.encoding import safe_str
+
 from celery.platforms import EX_OK, EX_FAILURE, EX_UNAVAILABLE, EX_USAGE
 from celery.utils import term
 from celery.utils import text
@@ -27,11 +29,6 @@ from celery.utils.timeutils import maybe_iso8601
 
 from celery.bin.base import Command as BaseCommand, Option
 
-try:
-    # print_statement does not work with io.StringIO
-    from io import BytesIO as PrintIO
-except ImportError:
-    from StringIO import StringIO as PrintIO  # noqa
 
 HELP = """
 ---- -- - - ---- Commands- -------------- --- ------------
@@ -475,9 +472,7 @@ class worker_graph(Command):
 
     def run(self, **kwargs):
         worker = self.app.WorkController()
-        out = PrintIO()
-        worker.namespace.graph.to_dot(out)
-        self.out(out.getvalue())
+        worker.namespace.graph.to_dot(self.stdout)
 
 
 @command
@@ -485,9 +480,7 @@ class consumer_graph(Command):
 
     def run(self, **kwargs):
         worker = self.app.WorkController()
-        out = PrintIO()
-        worker.consumer.namespace.graph.to_dot(out)
-        self.out(out.getvalue())
+        worker.consumer.namespace.graph.to_dot(self.stdout)
 
 
 class _RemoteControl(Command):

+ 62 - 20
celery/bootsteps.py

@@ -6,7 +6,7 @@
     The bootsteps!
 
 """
-from __future__ import absolute_import
+from __future__ import absolute_import, unicode_literals
 
 from collections import deque
 from importlib import import_module
@@ -15,7 +15,7 @@ from threading import Event
 from kombu.common import ignore_errors
 from kombu.utils import symbol_by_name
 
-from .datastructures import DependencyGraph
+from .datastructures import DependencyGraph, GraphFormatter
 from .utils.imports import instantiate, qualname
 from .utils.log import get_logger
 from .utils.threads import default_socket_timeout
@@ -42,10 +42,40 @@ def _pre(ns, fmt):
     return '| {0}: {1}'.format(ns.alias, fmt)
 
 
-def _maybe_name(s):
-    if not isinstance(s, basestring):
-        return s.name
-    return s
+def _label(s):
+    return s.name.rsplit('.', 1)[-1]
+
+
+class StepFormatter(GraphFormatter):
+
+    namespace_prefix = '⧉'
+    conditional_prefix = '∘'
+    namespace_scheme = {
+        'shape': 'parallelogram',
+        'color': 'slategray4',
+        'fillcolor': 'slategray3',
+    }
+
+    def label(self, step):
+        return '{0}{1}'.format(self._get_prefix(step),
+            (step.label or _label(step)).encode('utf-8', 'ignore'),
+        )
+
+    def _get_prefix(self, step):
+        if step.last:
+            return self.namespace_prefix
+        if step.conditional:
+            return self.conditional_prefix
+        return ''
+
+    def node(self, obj, **attrs):
+        scheme = self.namespace_scheme if obj.last else self.node_scheme
+        return self.draw_node(obj, scheme, attrs)
+
+    def edge(self, a, b, **attrs):
+        if a.last:
+            attrs.update(arrowhead='none', color='darkseagreen3')
+        return self.draw_edge(a, b, self.edge_scheme, attrs)
 
 
 class Namespace(object):
@@ -59,6 +89,8 @@ class Namespace(object):
     :keyword on_stopped: Optional callback applied after namespace stopped.
 
     """
+    GraphFormatter = StepFormatter
+
     name = None
     state = None
     started = 0
@@ -145,8 +177,9 @@ class Namespace(object):
         steps = self.steps = self.claim_steps()
 
         self._debug('Building graph...')
-        for name in self._finalize_steps(steps):
-            step = steps[name] = steps[name](parent, **kwargs)
+        for S in self._finalize_steps(steps):
+            step = S(parent, **kwargs)
+            steps[step.name] = step
             order.append(step)
         self._debug('New boot order: {%s}',
                     ', '.join(s.alias for s in self.order))
@@ -173,20 +206,21 @@ class Namespace(object):
                 if node.name not in self.steps:
                     steps[node.name] = node
                 stream.append(node.requires)
-        for node in steps.itervalues():
-            node.requires = [_maybe_name(n) for n in node.requires]
-        for step in steps.values():
-            [steps[n] for n in step.requires]
+        # Make sure we have all the steps
+        assert [steps[req.name] for step in steps.values()
+                    for req in step.requires]
 
     def _finalize_steps(self, steps):
-        self._firstpass(steps)
-        G = self.graph = DependencyGraph((C.name, C.requires)
-                            for C in steps.itervalues())
         last = self._find_last()
+        self._firstpass(steps)
+        it = ((C, C.requires) for C in steps.itervalues())
+        G = self.graph = DependencyGraph(it,
+            formatter=self.GraphFormatter(root=last),
+        )
         if last:
             for obj in G:
-                if obj != last.name:
-                    G.add_edge(last.name, obj)
+                if obj != last:
+                    G.add_edge(last, obj)
         try:
             return G.topsort()
         except KeyError as exc:
@@ -196,7 +230,6 @@ class Namespace(object):
         return dict(self.load_step(step) for step in self._all_steps())
 
     def _all_steps(self):
-        print('My NAME IS: %r' % (self.name, ))
         return self.types | self.app.steps[self.name.lower()]
 
     def load_step(self, step):
@@ -208,7 +241,7 @@ class Namespace(object):
 
     @property
     def alias(self):
-        return self.name.rsplit('.', 1)[-1]
+        return _label(self)
 
 
 class StepType(type):
@@ -224,6 +257,9 @@ class StepType(type):
         )
         return super(StepType, cls).__new__(cls, name, bases, attrs)
 
+    def __str__(self):
+        return self.name
+
     def __repr__(self):
         return 'step:{0.name}{{{0.requires!r}}}'.format(self)
 
@@ -242,6 +278,12 @@ class Step(object):
     #: Optional step name, will use qualname if not specified.
     name = None
 
+    #: Optional short name used for graph outputs and in logs.
+    label = None
+
+    #: Set this to true if the step is enabled based on some condition.
+    conditional = False
+
     #: List of other steps that that must be started before this step.
     #: Note that all dependencies must be in the same namespace.
     requires = ()
@@ -283,7 +325,7 @@ class Step(object):
 
     @property
     def alias(self):
-        return self.name.rsplit('.', 1)[-1]
+        return self.label or _label(self)
 
 
 class StartStopStep(Step):

+ 121 - 7
celery/datastructures.py

@@ -6,7 +6,7 @@
     Custom types and data structures.
 
 """
-from __future__ import absolute_import, print_function
+from __future__ import absolute_import, print_function, unicode_literals
 
 import sys
 import time
@@ -20,6 +20,106 @@ from kombu.utils.limits import TokenBucket  # noqa
 
 from .utils.functional import LRUCache, first, uniq  # noqa
 
+DOT_HEAD = """
+{IN}{type} {id} {{
+{INp}graph [{attrs}]
+"""
+DOT_ATTR = '{name}={value}'
+DOT_NODE = '{INp}"{0}" [{attrs}]'
+DOT_EDGE = '{INp}"{0}" {dir} "{1}" [{attrs}]'
+DOT_ATTRSEP = ', '
+DOT_DIRS = {'graph': '--', 'digraph': '->'}
+DOT_TAIL = '{IN}}}'
+
+
+class GraphFormatter(object):
+    _attr = DOT_ATTR.strip()
+    _node = DOT_NODE.strip()
+    _edge = DOT_EDGE.strip()
+    _head = DOT_HEAD.strip()
+    _tail = DOT_TAIL.strip()
+    _attrsep = DOT_ATTRSEP
+    _dirs = dict(DOT_DIRS)
+
+    scheme = {
+        'shape': 'box',
+        'arrowhead': 'vee',
+        'style': 'filled',
+        'fontname': 'Helvetica Neue',
+    }
+    node_scheme = {
+        'fillcolor': 'palegreen3',
+        'color': 'palegreen4',
+    }
+    term_scheme = {
+        'fillcolor': 'palegreen1',
+        'color': 'palegreen2',
+    }
+    edge_scheme = {
+        'color': 'darkseagreen4',
+        'arrowcolor': 'black',
+        'arrowsize': 0.7,
+    }
+    graph_scheme = {'bgcolor': 'mintcream'}
+
+    def __init__(self, root=None, type=None, id=None, indent=0, inw=' ' * 4):
+        self.id = id or 'dependencies'
+        self.root = root
+        self.type = type or 'digraph'
+        self.direction = self._dirs[self.type]
+        self.IN = inw * (indent or 0)
+        self.INp = self.IN + inw
+        #self.graph_scheme = dict(self.graph_scheme, root=self.label(self.root))
+
+    def attr(self, name, value):
+        value = '"{0}"'.format(str(value))
+        return self.FMT(self._attr, name=name, value=value)
+
+    def attrs(self, d, scheme=None):
+        d = dict(self.scheme, **dict(scheme, **d or {}) if scheme else d)
+        return self._attrsep.join(self.attr(k, v) for k, v in d.iteritems())
+
+    def head(self, **attrs):
+        return self.FMT(self._head, id=self.id, type=self.type,
+            attrs=self.attrs(attrs, self.graph_scheme),
+        )
+
+    def tail(self):
+        return self.FMT(self._tail)
+
+    def label(self, obj):
+        return obj
+
+    def node(self, obj, **attrs):
+        return self.draw_node(obj, self.node_scheme, attrs)
+
+    def terminal_node(self, obj, **attrs):
+        return self.draw_node(obj, self.term_scheme, attrs)
+
+    def edge(self, a, b, **attrs):
+        return self.draw_edge(a, b, **attrs)
+
+    def _enc(self, s):
+        return s.encode('utf-8', 'ignore')
+
+    def FMT(self, fmt, *args, **kwargs):
+        return self._enc(fmt.format(
+            *args, **dict(kwargs, IN=self.IN, INp=self.INp)
+        ))
+
+    def draw_edge(self, a, b, scheme=None, attrs=None):
+        return self.FMT(self._edge, self.label(a), self.label(b),
+            dir=self.direction, attrs=self.attrs(attrs, self.edge_scheme),
+        )
+
+    def draw_node(self, obj, scheme=None, attrs=None):
+        return self.FMT(self._node, self.label(obj),
+            attrs=self.attrs(attrs, scheme),
+        )
+
+
+
+
 
 class CycleError(Exception):
     """A cycle was detected in an acyclic graph."""
@@ -40,7 +140,8 @@ class DependencyGraph(object):
 
     """
 
-    def __init__(self, it=None):
+    def __init__(self, it=None, formatter=None):
+        self.formatter = formatter or GraphFormatter()
         self.adjacent = {}
         if it is not None:
             self.update(it)
@@ -158,20 +259,33 @@ class DependencyGraph(object):
 
         return result
 
-    def to_dot(self, fh, ws=' ' * 4):
+    def to_dot(self, fh, formatter=None):
         """Convert the graph to DOT format.
 
         :param fh: A file, or a file-like object to write the graph to.
 
         """
+        seen = set()
+        draw = formatter or self.formatter
         P = partial(print, file=fh)
-        P('digraph dependencies {')
+
+        def if_not_seen(fun, obj):
+            if draw.label(obj) not in seen:
+                P(fun(obj))
+                seen.add(draw.label(obj))
+
+        P(draw.head())
         for obj, adjacent in self.iteritems():
+            objl = draw.label(obj)
             if not adjacent:
-                P(ws + '"{0}"'.format(obj))
+                if_not_seen(draw.terminal_node, obj)
             for req in adjacent:
-                P(ws + '"{0}" -> "{1}"'.format(obj, req))
-        P('}')
+                if_not_seen(draw.node, obj)
+                P(draw.edge(obj, req))
+        P(draw.tail())
+
+    def format(self, obj):
+        return self.formatter(obj) if self.formatter else obj
 
     def __iter__(self):
         return iter(self.adjacent)

+ 8 - 8
celery/tests/worker/test_worker.py

@@ -25,7 +25,7 @@ from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
-from celery.worker.components import Queues, Timers, Hub, Pool
+from celery.worker import components
 from celery.worker.buckets import FastQueue
 from celery.worker.job import Request
 from celery.worker import consumer
@@ -922,14 +922,14 @@ class test_WorkController(AppCase):
         try:
             raise KeyError('foo')
         except KeyError as exc:
-            Timers(worker).on_timer_error(exc)
+            components.Timer(worker).on_timer_error(exc)
             msg, args = self.comp_logger.error.call_args[0]
             self.assertIn('KeyError', msg % args)
 
     def test_on_timer_tick(self):
         worker = WorkController(concurrency=1, loglevel=10)
 
-        Timers(worker).on_timer_tick(30.0)
+        components.Timer(worker).on_timer_tick(30.0)
         xargs = self.comp_logger.debug.call_args[0]
         fmt, arg = xargs[0], xargs[1]
         self.assertEqual(30.0, arg)
@@ -1105,18 +1105,18 @@ class test_WorkController(AppCase):
     def test_Queues_pool_not_rlimit_safe(self):
         w = Mock()
         w.pool_cls.rlimit_safe = False
-        Queues(w).create(w)
+        components.Queues(w).create(w)
         self.assertTrue(w.disable_rate_limits)
 
     def test_Queues_pool_no_sem(self):
         w = Mock()
         w.pool_cls.uses_semaphore = False
-        Queues(w).create(w)
+        components.Queues(w).create(w)
         self.assertIs(w.ready_queue.put, w.process_task)
 
     def test_Hub_crate(self):
         w = Mock()
-        x = Hub(w)
+        x = components.Hub(w)
         hub = x.create(w)
         self.assertTrue(w.timer.max_interval)
         self.assertIs(w.hub, hub)
@@ -1125,7 +1125,7 @@ class test_WorkController(AppCase):
         w = Mock()
         w.pool_cls = Mock()
         w.use_eventloop = False
-        pool = Pool(w)
+        pool = components.Pool(w)
         pool.create(w)
 
     def test_Pool_create(self):
@@ -1137,7 +1137,7 @@ class test_WorkController(AppCase):
         P = w.pool_cls.return_value = Mock()
         P.timers = {Mock(): 30}
         w.use_eventloop = True
-        pool = Pool(w)
+        pool = components.Pool(w)
         pool.create(w)
         self.assertIsInstance(w.semaphore, BoundedSemaphore)
         self.assertTrue(w.hub.on_init)

+ 1 - 1
celery/worker/__init__.py

@@ -86,7 +86,7 @@ class WorkController(configurated):
             'celery.worker.components:Queues',
             'celery.worker.components:Pool',
             'celery.worker.components:Beat',
-            'celery.worker.components:Timers',
+            'celery.worker.components:Timer',
             'celery.worker.components:StateDB',
             'celery.worker.components:Consumer',
             'celery.worker.autoscale:WorkerComponent',

+ 2 - 0
celery/worker/autoreload.py

@@ -37,6 +37,8 @@ logger = get_logger(__name__)
 
 
 class WorkerComponent(bootsteps.StartStopStep):
+    label = 'Autoreloader'
+    conditional = True
     requires = (Pool, )
 
     def __init__(self, w, autoreload=None, **kwargs):

+ 2 - 0
celery/worker/autoscale.py

@@ -31,6 +31,8 @@ debug, info, error = logger.debug, logger.info, logger.error
 
 
 class WorkerComponent(bootsteps.StartStopStep):
+    label = 'Autoscaler'
+    conditional = True
     requires = (Pool, )
 
     def __init__(self, w, **kwargs):

+ 5 - 2
celery/worker/components.py

@@ -40,6 +40,7 @@ class Hub(bootsteps.StartStopStep):
 class Queues(bootsteps.Step):
     """This bootstep initializes the internal queues
     used by the worker."""
+    label = 'Queues (intra)'
     requires = (Hub, )
 
     def __init__(self, w, **kwargs):
@@ -185,6 +186,8 @@ class Beat(bootsteps.StartStopStep):
     argument is set.
 
     """
+    label = 'Beat'
+    conditional = True
 
     def __init__(self, w, beat=False, **kwargs):
         self.enabled = w.beat = beat
@@ -198,8 +201,8 @@ class Beat(bootsteps.StartStopStep):
         return b
 
 
-class Timers(bootsteps.Step):
-    """This step initializes the internal timers used by the worker."""
+class Timer(bootsteps.Step):
+    """This step initializes the internal timer used by the worker."""
     requires = (Pool, )
 
     def include_if(self, w):

+ 2 - 1
celery/worker/consumer.py

@@ -441,7 +441,7 @@ class Tasks(bootsteps.StartStopStep):
         c.task_consumer = c.app.amqp.TaskConsumer(
             c.connection, on_decode_error=c.on_decode_error,
         )
-        c.qos = QoS(c.task_consumer, self.initial_prefetch_count)
+        c.qos = QoS(c.task_consumer.qos, self.initial_prefetch_count)
         c.qos.update()  # set initial prefetch count
 
     def stop(self, c):
@@ -469,6 +469,7 @@ class Agent(bootsteps.StartStopStep):
 
 
 class Evloop(bootsteps.StartStopStep):
+    label = 'event loop'
     last = True
 
     def start(self, c):

+ 2 - 0
celery/worker/mediator.py

@@ -30,6 +30,8 @@ logger = get_logger(__name__)
 
 
 class WorkerComponent(StartStopStep):
+    label = 'Mediator'
+    conditional = True
     requires = (components.Pool, components.Queues, )
 
     def __init__(self, w, **kwargs):