graph.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # -*- coding: utf-8 -*-
  2. """Dependency graph implementation."""
  3. from __future__ import absolute_import, print_function, unicode_literals
  4. from collections import Counter
  5. from textwrap import dedent
  6. from kombu.utils.encoding import safe_str, bytes_to_str
  7. from celery.five import items, python_2_unicode_compatible
  8. __all__ = ['DOT', 'CycleError', 'DependencyGraph', 'GraphFormatter']
  9. class DOT:
  10. """Constants related to the dot format."""
  11. HEAD = dedent("""
  12. {IN}{type} {id} {{
  13. {INp}graph [{attrs}]
  14. """)
  15. ATTR = '{name}={value}'
  16. NODE = '{INp}"{0}" [{attrs}]'
  17. EDGE = '{INp}"{0}" {dir} "{1}" [{attrs}]'
  18. ATTRSEP = ', '
  19. DIRS = {'graph': '--', 'digraph': '->'}
  20. TAIL = '{IN}}}'
  21. class CycleError(Exception):
  22. """A cycle was detected in an acyclic graph."""
  23. @python_2_unicode_compatible
  24. class DependencyGraph(object):
  25. """A directed acyclic graph of objects and their dependencies.
  26. Supports a robust topological sort
  27. to detect the order in which they must be handled.
  28. Takes an optional iterator of ``(obj, dependencies)``
  29. tuples to build the graph from.
  30. Warning:
  31. Does not support cycle detection.
  32. """
  33. def __init__(self, it=None, formatter=None):
  34. self.formatter = formatter or GraphFormatter()
  35. self.adjacent = {}
  36. if it is not None:
  37. self.update(it)
  38. def add_arc(self, obj):
  39. """Add an object to the graph."""
  40. self.adjacent.setdefault(obj, [])
  41. def add_edge(self, A, B):
  42. """Add an edge from object ``A`` to object ``B``.
  43. I.e. ``A`` depends on ``B``.
  44. """
  45. self[A].append(B)
  46. def connect(self, graph):
  47. """Add nodes from another graph."""
  48. self.adjacent.update(graph.adjacent)
  49. def topsort(self):
  50. """Sort the graph topologically.
  51. Returns:
  52. List: of objects in the order in which they must be handled.
  53. """
  54. graph = DependencyGraph()
  55. components = self._tarjan72()
  56. NC = {
  57. node: component for component in components for node in component
  58. }
  59. for component in components:
  60. graph.add_arc(component)
  61. for node in self:
  62. node_c = NC[node]
  63. for successor in self[node]:
  64. successor_c = NC[successor]
  65. if node_c != successor_c:
  66. graph.add_edge(node_c, successor_c)
  67. return [t[0] for t in graph._khan62()]
  68. def valency_of(self, obj):
  69. """Return the valency (degree) of a vertex in the graph."""
  70. try:
  71. l = [len(self[obj])]
  72. except KeyError:
  73. return 0
  74. for node in self[obj]:
  75. l.append(self.valency_of(node))
  76. return sum(l)
  77. def update(self, it):
  78. """Update graph with data from a list of ``(obj, deps)`` tuples."""
  79. tups = list(it)
  80. for obj, _ in tups:
  81. self.add_arc(obj)
  82. for obj, deps in tups:
  83. for dep in deps:
  84. self.add_edge(obj, dep)
  85. def edges(self):
  86. """Return generator that yields for all edges in the graph."""
  87. return (obj for obj, adj in items(self) if adj)
  88. def _khan62(self):
  89. """Perform Khan's simple topological sort algorithm from '62.
  90. See https://en.wikipedia.org/wiki/Topological_sorting
  91. """
  92. count = Counter()
  93. result = []
  94. for node in self:
  95. for successor in self[node]:
  96. count[successor] += 1
  97. ready = [node for node in self if not count[node]]
  98. while ready:
  99. node = ready.pop()
  100. result.append(node)
  101. for successor in self[node]:
  102. count[successor] -= 1
  103. if count[successor] == 0:
  104. ready.append(successor)
  105. result.reverse()
  106. return result
  107. def _tarjan72(self):
  108. """Perform Tarjan's algorithm to find strongly connected components.
  109. See Also:
  110. :wikipedia:`Tarjan%27s_strongly_connected_components_algorithm`
  111. """
  112. result, stack, low = [], [], {}
  113. def visit(node):
  114. if node in low:
  115. return
  116. num = len(low)
  117. low[node] = num
  118. stack_pos = len(stack)
  119. stack.append(node)
  120. for successor in self[node]:
  121. visit(successor)
  122. low[node] = min(low[node], low[successor])
  123. if num == low[node]:
  124. component = tuple(stack[stack_pos:])
  125. stack[stack_pos:] = []
  126. result.append(component)
  127. for item in component:
  128. low[item] = len(self)
  129. for node in self:
  130. visit(node)
  131. return result
  132. def to_dot(self, fh, formatter=None):
  133. """Convert the graph to DOT format.
  134. Arguments:
  135. fh (IO): A file, or a file-like object to write the graph to.
  136. formatter (celery.utils.graph.GraphFormatter): Custom graph
  137. formatter to use.
  138. """
  139. seen = set()
  140. draw = formatter or self.formatter
  141. def P(s):
  142. print(bytes_to_str(s), file=fh)
  143. def if_not_seen(fun, obj):
  144. if draw.label(obj) not in seen:
  145. P(fun(obj))
  146. seen.add(draw.label(obj))
  147. P(draw.head())
  148. for obj, adjacent in items(self):
  149. if not adjacent:
  150. if_not_seen(draw.terminal_node, obj)
  151. for req in adjacent:
  152. if_not_seen(draw.node, obj)
  153. P(draw.edge(obj, req))
  154. P(draw.tail())
  155. def format(self, obj):
  156. return self.formatter(obj) if self.formatter else obj
  157. def __iter__(self):
  158. return iter(self.adjacent)
  159. def __getitem__(self, node):
  160. return self.adjacent[node]
  161. def __len__(self):
  162. return len(self.adjacent)
  163. def __contains__(self, obj):
  164. return obj in self.adjacent
  165. def _iterate_items(self):
  166. return items(self.adjacent)
  167. items = iteritems = _iterate_items
  168. def __repr__(self):
  169. return '\n'.join(self.repr_node(N) for N in self)
  170. def repr_node(self, obj, level=1, fmt='{0}({1})'):
  171. output = [fmt.format(obj, self.valency_of(obj))]
  172. if obj in self:
  173. for other in self[obj]:
  174. d = fmt.format(other, self.valency_of(other))
  175. output.append(' ' * level + d)
  176. output.extend(self.repr_node(other, level + 1).split('\n')[1:])
  177. return '\n'.join(output)
  178. class GraphFormatter(object):
  179. """Format dependency graphs."""
  180. _attr = DOT.ATTR.strip()
  181. _node = DOT.NODE.strip()
  182. _edge = DOT.EDGE.strip()
  183. _head = DOT.HEAD.strip()
  184. _tail = DOT.TAIL.strip()
  185. _attrsep = DOT.ATTRSEP
  186. _dirs = dict(DOT.DIRS)
  187. scheme = {
  188. 'shape': 'box',
  189. 'arrowhead': 'vee',
  190. 'style': 'filled',
  191. 'fontname': 'HelveticaNeue',
  192. }
  193. edge_scheme = {
  194. 'color': 'darkseagreen4',
  195. 'arrowcolor': 'black',
  196. 'arrowsize': 0.7,
  197. }
  198. node_scheme = {'fillcolor': 'palegreen3', 'color': 'palegreen4'}
  199. term_scheme = {'fillcolor': 'palegreen1', 'color': 'palegreen2'}
  200. graph_scheme = {'bgcolor': 'mintcream'}
  201. def __init__(self, root=None, type=None, id=None,
  202. indent=0, inw=' ' * 4, **scheme):
  203. self.id = id or 'dependencies'
  204. self.root = root
  205. self.type = type or 'digraph'
  206. self.direction = self._dirs[self.type]
  207. self.IN = inw * (indent or 0)
  208. self.INp = self.IN + inw
  209. self.scheme = dict(self.scheme, **scheme)
  210. self.graph_scheme = dict(self.graph_scheme, root=self.label(self.root))
  211. def attr(self, name, value):
  212. value = '"{0}"'.format(value)
  213. return self.FMT(self._attr, name=name, value=value)
  214. def attrs(self, d, scheme=None):
  215. d = dict(self.scheme, **dict(scheme, **d or {}) if scheme else d)
  216. return self._attrsep.join(
  217. safe_str(self.attr(k, v)) for k, v in items(d)
  218. )
  219. def head(self, **attrs):
  220. return self.FMT(
  221. self._head, id=self.id, type=self.type,
  222. attrs=self.attrs(attrs, self.graph_scheme),
  223. )
  224. def tail(self):
  225. return self.FMT(self._tail)
  226. def label(self, obj):
  227. return obj
  228. def node(self, obj, **attrs):
  229. return self.draw_node(obj, self.node_scheme, attrs)
  230. def terminal_node(self, obj, **attrs):
  231. return self.draw_node(obj, self.term_scheme, attrs)
  232. def edge(self, a, b, **attrs):
  233. return self.draw_edge(a, b, **attrs)
  234. def _enc(self, s):
  235. return s.encode('utf-8', 'ignore')
  236. def FMT(self, fmt, *args, **kwargs):
  237. return self._enc(fmt.format(
  238. *args, **dict(kwargs, IN=self.IN, INp=self.INp)
  239. ))
  240. def draw_edge(self, a, b, scheme=None, attrs=None):
  241. return self.FMT(
  242. self._edge, self.label(a), self.label(b),
  243. dir=self.direction, attrs=self.attrs(attrs, self.edge_scheme),
  244. )
  245. def draw_node(self, obj, scheme=None, attrs=None):
  246. return self.FMT(
  247. self._node, self.label(obj), attrs=self.attrs(attrs, scheme),
  248. )