datastructures.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.datastructures
  4. ~~~~~~~~~~~~~~~~~~~~~
  5. Custom types and data structures.
  6. """
  7. from __future__ import absolute_import
  8. from __future__ import with_statement
  9. import sys
  10. import time
  11. from collections import defaultdict
  12. from itertools import chain
  13. try:
  14. from collections import MutableMapping
  15. except ImportError: # pragma: no cover
  16. MutableMapping = None # noqa
  17. from billiard.einfo import ExceptionInfo # noqa
  18. from kombu.utils.limits import TokenBucket # noqa
  19. from .utils.functional import LRUCache, first, uniq # noqa
  20. class CycleError(Exception):
  21. """A cycle was detected in an acyclic graph."""
  22. class DependencyGraph(object):
  23. """A directed acyclic graph of objects and their dependencies.
  24. Supports a robust topological sort
  25. to detect the order in which they must be handled.
  26. Takes an optional iterator of ``(obj, dependencies)``
  27. tuples to build the graph from.
  28. .. warning::
  29. Does not support cycle detection.
  30. """
  31. def __init__(self, it=None):
  32. self.adjacent = {}
  33. if it is not None:
  34. self.update(it)
  35. def add_arc(self, obj):
  36. """Add an object to the graph."""
  37. self.adjacent.setdefault(obj, [])
  38. def add_edge(self, A, B):
  39. """Add an edge from object ``A`` to object ``B``
  40. (``A`` depends on ``B``)."""
  41. self[A].append(B)
  42. def topsort(self):
  43. """Sort the graph topologically.
  44. :returns: a list of objects in the order
  45. in which they must be handled.
  46. """
  47. graph = DependencyGraph()
  48. components = self._tarjan72()
  49. NC = dict((node, component)
  50. for component in components
  51. for node in component)
  52. for component in components:
  53. graph.add_arc(component)
  54. for node in self:
  55. node_c = NC[node]
  56. for successor in self[node]:
  57. successor_c = NC[successor]
  58. if node_c != successor_c:
  59. graph.add_edge(node_c, successor_c)
  60. return [t[0] for t in graph._khan62()]
  61. def valency_of(self, obj):
  62. """Returns the velency (degree) of a vertex in the graph."""
  63. try:
  64. l = [len(self[obj])]
  65. except KeyError:
  66. return 0
  67. for node in self[obj]:
  68. l.append(self.valency_of(node))
  69. return sum(l)
  70. def update(self, it):
  71. """Update the graph with data from a list
  72. of ``(obj, dependencies)`` tuples."""
  73. tups = list(it)
  74. for obj, _ in tups:
  75. self.add_arc(obj)
  76. for obj, deps in tups:
  77. for dep in deps:
  78. self.add_edge(obj, dep)
  79. def edges(self):
  80. """Returns generator that yields for all edges in the graph."""
  81. return (obj for obj, adj in self.iteritems() if adj)
  82. def _khan62(self):
  83. """Khans simple topological sort algorithm from '62
  84. See http://en.wikipedia.org/wiki/Topological_sorting
  85. """
  86. count = defaultdict(lambda: 0)
  87. result = []
  88. for node in self:
  89. for successor in self[node]:
  90. count[successor] += 1
  91. ready = [node for node in self if not count[node]]
  92. while ready:
  93. node = ready.pop()
  94. result.append(node)
  95. for successor in self[node]:
  96. count[successor] -= 1
  97. if count[successor] == 0:
  98. ready.append(successor)
  99. result.reverse()
  100. return result
  101. def _tarjan72(self):
  102. """Tarjan's algorithm to find strongly connected components.
  103. See http://bit.ly/vIMv3h.
  104. """
  105. result, stack, low = [], [], {}
  106. def visit(node):
  107. if node in low:
  108. return
  109. num = len(low)
  110. low[node] = num
  111. stack_pos = len(stack)
  112. stack.append(node)
  113. for successor in self[node]:
  114. visit(successor)
  115. low[node] = min(low[node], low[successor])
  116. if num == low[node]:
  117. component = tuple(stack[stack_pos:])
  118. stack[stack_pos:] = []
  119. result.append(component)
  120. for item in component:
  121. low[item] = len(self)
  122. for node in self:
  123. visit(node)
  124. return result
  125. def to_dot(self, fh, ws=' ' * 4):
  126. """Convert the graph to DOT format.
  127. :param fh: A file, or a file-like object to write the graph to.
  128. """
  129. fh.write('digraph dependencies {\n')
  130. for obj, adjacent in self.iteritems():
  131. if not adjacent:
  132. fh.write(ws + '"%s"\n' % (obj, ))
  133. for req in adjacent:
  134. fh.write(ws + '"%s" -> "%s"\n' % (obj, req))
  135. fh.write('}\n')
  136. def __iter__(self):
  137. return iter(self.adjacent)
  138. def __getitem__(self, node):
  139. return self.adjacent[node]
  140. def __len__(self):
  141. return len(self.adjacent)
  142. def __contains__(self, obj):
  143. return obj in self.adjacent
  144. def _iterate_items(self):
  145. return self.adjacent.iteritems()
  146. items = iteritems = _iterate_items
  147. def __repr__(self):
  148. return '\n'.join(self.repr_node(N) for N in self)
  149. def repr_node(self, obj, level=1):
  150. output = ['%s(%s)' % (obj, self.valency_of(obj))]
  151. if obj in self:
  152. for other in self[obj]:
  153. d = '%s(%s)' % (other, self.valency_of(other))
  154. output.append(' ' * level + d)
  155. output.extend(self.repr_node(other, level + 1).split('\n')[1:])
  156. return '\n'.join(output)
  157. class AttributeDictMixin(object):
  158. """Adds attribute access to mappings.
  159. `d.key -> d[key]`
  160. """
  161. def __getattr__(self, k):
  162. """`d.key -> d[key]`"""
  163. try:
  164. return self[k]
  165. except KeyError:
  166. raise AttributeError(
  167. "'%s' object has no attribute '%s'" % (type(self).__name__, k))
  168. def __setattr__(self, key, value):
  169. """`d[key] = value -> d.key = value`"""
  170. self[key] = value
  171. class AttributeDict(dict, AttributeDictMixin):
  172. """Dict subclass with attribute access."""
  173. pass
  174. class DictAttribute(object):
  175. """Dict interface to attributes.
  176. `obj[k] -> obj.k`
  177. """
  178. obj = None
  179. def __init__(self, obj):
  180. object.__setattr__(self, 'obj', obj)
  181. def __getattr__(self, key):
  182. return getattr(self.obj, key)
  183. def __setattr__(self, key, value):
  184. return setattr(self.obj, key, value)
  185. def get(self, key, default=None):
  186. try:
  187. return self[key]
  188. except KeyError:
  189. return default
  190. def setdefault(self, key, default):
  191. try:
  192. return self[key]
  193. except KeyError:
  194. self[key] = default
  195. return default
  196. def __getitem__(self, key):
  197. try:
  198. return getattr(self.obj, key)
  199. except AttributeError:
  200. raise KeyError(key)
  201. def __setitem__(self, key, value):
  202. setattr(self.obj, key, value)
  203. def __contains__(self, key):
  204. return hasattr(self.obj, key)
  205. def _iterate_keys(self):
  206. return iter(dir(self.obj))
  207. iterkeys = _iterate_keys
  208. def __iter__(self):
  209. return self._iterate_keys()
  210. def _iterate_items(self):
  211. for key in self._iterate_keys():
  212. yield key, getattr(self.obj, key)
  213. iteritems = _iterate_items
  214. if sys.version_info[0] == 3: # pragma: no cover
  215. items = _iterate_items
  216. keys = _iterate_keys
  217. else:
  218. def keys(self):
  219. return list(self)
  220. def items(self):
  221. return list(self._iterate_items())
  222. class ConfigurationView(AttributeDictMixin):
  223. """A view over an applications configuration dicts.
  224. If the key does not exist in ``changes``, the ``defaults`` dicts
  225. are consulted.
  226. :param changes: Dict containing changes to the configuration.
  227. :param defaults: List of dicts containing the default configuration.
  228. """
  229. changes = None
  230. defaults = None
  231. _order = None
  232. def __init__(self, changes, defaults):
  233. self.__dict__.update(changes=changes, defaults=defaults,
  234. _order=[changes] + defaults)
  235. def add_defaults(self, d):
  236. self.defaults.insert(0, d)
  237. self._order.insert(1, d)
  238. def __getitem__(self, key):
  239. for d in self._order:
  240. try:
  241. return d[key]
  242. except KeyError:
  243. pass
  244. raise KeyError(key)
  245. def __setitem__(self, key, value):
  246. self.changes[key] = value
  247. def first(self, *keys):
  248. return first(None, (self.get(key) for key in keys))
  249. def get(self, key, default=None):
  250. try:
  251. return self[key]
  252. except KeyError:
  253. return default
  254. def setdefault(self, key, default):
  255. try:
  256. return self[key]
  257. except KeyError:
  258. self[key] = default
  259. return default
  260. def update(self, *args, **kwargs):
  261. return self.changes.update(*args, **kwargs)
  262. def __contains__(self, key):
  263. for d in self._order:
  264. if key in d:
  265. return True
  266. return False
  267. def __repr__(self):
  268. return repr(dict(self.iteritems()))
  269. def __iter__(self):
  270. return self._iterate_keys()
  271. def __len__(self):
  272. # The logic for iterating keys includes uniq(),
  273. # so to be safe we count by explicitly iterating
  274. return len(self.keys())
  275. def _iter(self, op):
  276. # defaults must be first in the stream, so values in
  277. # changes takes precedence.
  278. return chain(*[op(d) for d in reversed(self._order)])
  279. def _iterate_keys(self):
  280. return uniq(self._iter(lambda d: d))
  281. iterkeys = _iterate_keys
  282. def _iterate_items(self):
  283. return ((key, self[key]) for key in self)
  284. iteritems = _iterate_items
  285. def _iterate_values(self):
  286. return (self[key] for key in self)
  287. itervalues = _iterate_values
  288. def keys(self):
  289. return list(self._iterate_keys())
  290. def items(self):
  291. return list(self._iterate_items())
  292. def values(self):
  293. return list(self._iterate_values())
# Register ConfigurationView as a virtual subclass of MutableMapping.
# Skipped when the import at the top of the module failed and left
# ``MutableMapping`` set to None.
if MutableMapping:
    MutableMapping.register(ConfigurationView)
  296. class LimitedSet(object):
  297. """Kind-of Set with limitations.
  298. Good for when you need to test for membership (`a in set`),
  299. but the list might become to big, so you want to limit it so it doesn't
  300. consume too much resources.
  301. :keyword maxlen: Maximum number of members before we start
  302. evicting expired members.
  303. :keyword expires: Time in seconds, before a membership expires.
  304. """
  305. __slots__ = ('maxlen', 'expires', '_data', '__len__')
  306. def __init__(self, maxlen=None, expires=None):
  307. self.maxlen = maxlen
  308. self.expires = expires
  309. self._data = {}
  310. self.__len__ = self._data.__len__
  311. def add(self, value):
  312. """Add a new member."""
  313. self._expire_item()
  314. self._data[value] = time.time()
  315. def clear(self):
  316. """Remove all members"""
  317. self._data.clear()
  318. def pop_value(self, value):
  319. """Remove membership by finding value."""
  320. self._data.pop(value, None)
  321. def _expire_item(self):
  322. """Hunt down and remove an expired item."""
  323. while 1:
  324. if self.maxlen and len(self) >= self.maxlen:
  325. value, when = self.first
  326. if not self.expires or time.time() > when + self.expires:
  327. try:
  328. self.pop_value(value)
  329. except TypeError: # pragma: no cover
  330. continue
  331. break
  332. def __contains__(self, value):
  333. return value in self._data
  334. def update(self, other):
  335. if isinstance(other, self.__class__):
  336. self._data.update(other._data)
  337. else:
  338. for obj in other:
  339. self.add(obj)
  340. def as_dict(self):
  341. return self._data
  342. def __iter__(self):
  343. return iter(self._data)
  344. def __repr__(self):
  345. return 'LimitedSet(%r)' % (list(self._data), )
  346. @property
  347. def chronologically(self):
  348. return sorted(self._data.items(), key=lambda (value, when): when)
  349. @property
  350. def first(self):
  351. """Get the oldest member."""
  352. return self.chronologically[0]