datastructures.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.datastructures
  4. ~~~~~~~~~~~~~~~~~~~~~
  5. Custom types and data structures.
  6. :copyright: (c) 2009 - 2012 by Ask Solem.
  7. :license: BSD, see LICENSE for more details.
  8. """
  9. from __future__ import absolute_import
  10. from __future__ import with_statement
  11. import sys
  12. import time
  13. import traceback
  14. from collections import defaultdict
  15. from itertools import chain
  16. from kombu.utils.limits import TokenBucket # noqa
  17. from .utils.functional import LRUCache, first, uniq # noqa
  18. class CycleError(Exception):
  19. """A cycle was detected in an acyclic graph."""
  20. class DependencyGraph(object):
  21. """A directed acyclic graph of objects and their dependencies.
  22. Supports a robust topological sort
  23. to detect the order in which they must be handled.
  24. Takes an optional iterator of ``(obj, dependencies)``
  25. tuples to build the graph from.
  26. .. warning::
  27. Does not support cycle detection.
  28. """
  29. def __init__(self, it=None):
  30. self.adjacent = {}
  31. if it is not None:
  32. self.update(it)
  33. def add_arc(self, obj):
  34. """Add an object to the graph."""
  35. self.adjacent.setdefault(obj, [])
  36. def add_edge(self, A, B):
  37. """Add an edge from object ``A`` to object ``B``
  38. (``A`` depends on ``B``)."""
  39. self[A].append(B)
  40. def topsort(self):
  41. """Sort the graph topologically.
  42. :returns: a list of objects in the order
  43. in which they must be handled.
  44. """
  45. graph = DependencyGraph()
  46. components = self._tarjan72()
  47. NC = dict((node, component)
  48. for component in components
  49. for node in component)
  50. for component in components:
  51. graph.add_arc(component)
  52. for node in self:
  53. node_c = NC[node]
  54. for successor in self[node]:
  55. successor_c = NC[successor]
  56. if node_c != successor_c:
  57. graph.add_edge(node_c, successor_c)
  58. return [t[0] for t in graph._khan62()]
  59. def valency_of(self, obj):
  60. """Returns the velency (degree) of a vertex in the graph."""
  61. try:
  62. l = [len(self[obj])]
  63. except KeyError:
  64. return 0
  65. for node in self[obj]:
  66. l.append(self.valency_of(node))
  67. return sum(l)
  68. def update(self, it):
  69. """Update the graph with data from a list
  70. of ``(obj, dependencies)`` tuples."""
  71. tups = list(it)
  72. for obj, _ in tups:
  73. self.add_arc(obj)
  74. for obj, deps in tups:
  75. for dep in deps:
  76. self.add_edge(obj, dep)
  77. def edges(self):
  78. """Returns generator that yields for all edges in the graph."""
  79. return (obj for obj, adj in self.iteritems() if adj)
  80. def _khan62(self):
  81. """Khans simple topological sort algorithm from '62
  82. See http://en.wikipedia.org/wiki/Topological_sorting
  83. """
  84. count = defaultdict(lambda: 0)
  85. result = []
  86. for node in self:
  87. for successor in self[node]:
  88. count[successor] += 1
  89. ready = [node for node in self if not count[node]]
  90. while ready:
  91. node = ready.pop()
  92. result.append(node)
  93. for successor in self[node]:
  94. count[successor] -= 1
  95. if count[successor] == 0:
  96. ready.append(successor)
  97. result.reverse()
  98. return result
  99. def _tarjan72(self):
  100. """Tarjan's algorithm to find strongly connected components.
  101. See http://bit.ly/vIMv3h.
  102. """
  103. result, stack, low = [], [], {}
  104. def visit(node):
  105. if node in low:
  106. return
  107. num = len(low)
  108. low[node] = num
  109. stack_pos = len(stack)
  110. stack.append(node)
  111. for successor in self[node]:
  112. visit(successor)
  113. low[node] = min(low[node], low[successor])
  114. if num == low[node]:
  115. component = tuple(stack[stack_pos:])
  116. stack[stack_pos:] = []
  117. result.append(component)
  118. for item in component:
  119. low[item] = len(self)
  120. for node in self:
  121. visit(node)
  122. return result
  123. def to_dot(self, fh, ws=" " * 4):
  124. """Convert the graph to DOT format.
  125. :param fh: A file, or a file-like object to write the graph to.
  126. """
  127. fh.write("digraph dependencies {\n")
  128. for obj, adjacent in self.iteritems():
  129. if not adjacent:
  130. fh.write(ws + '"%s"\n' % (obj, ))
  131. for req in adjacent:
  132. fh.write(ws + '"%s" -> "%s"\n' % (obj, req))
  133. fh.write("}\n")
  134. def __iter__(self):
  135. return self.adjacent.iterkeys()
  136. def __getitem__(self, node):
  137. return self.adjacent[node]
  138. def __len__(self):
  139. return len(self.adjacent)
  140. def __contains__(self, obj):
  141. return obj in self.adjacent
  142. def _iterate_items(self):
  143. return self.adjacent.iteritems()
  144. items = iteritems = _iterate_items
  145. def __repr__(self):
  146. return '\n'.join(self.repr_node(N) for N in self)
  147. def repr_node(self, obj, level=1):
  148. output = ["%s(%s)" % (obj, self.valency_of(obj))]
  149. if obj in self:
  150. for other in self[obj]:
  151. d = "%s(%s)" % (other, self.valency_of(other))
  152. output.append(' ' * level + d)
  153. output.extend(self.repr_node(other, level + 1).split('\n')[1:])
  154. return '\n'.join(output)
  155. class AttributeDictMixin(object):
  156. """Adds attribute access to mappings.
  157. `d.key -> d[key]`
  158. """
  159. def __getattr__(self, key):
  160. """`d.key -> d[key]`"""
  161. try:
  162. return self[key]
  163. except KeyError:
  164. raise AttributeError("'%s' object has no attribute '%s'" % (
  165. self.__class__.__name__, key))
  166. def __setattr__(self, key, value):
  167. """`d[key] = value -> d.key = value`"""
  168. self[key] = value
  169. class AttributeDict(dict, AttributeDictMixin):
  170. """Dict subclass with attribute access."""
  171. pass
  172. class DictAttribute(object):
  173. """Dict interface to attributes.
  174. `obj[k] -> obj.k`
  175. """
  176. def __init__(self, obj):
  177. self.obj = obj
  178. def get(self, key, default=None):
  179. try:
  180. return self[key]
  181. except KeyError:
  182. return default
  183. def setdefault(self, key, default):
  184. try:
  185. return self[key]
  186. except KeyError:
  187. self[key] = default
  188. return default
  189. def __getitem__(self, key):
  190. try:
  191. return getattr(self.obj, key)
  192. except AttributeError:
  193. raise KeyError(key)
  194. def __setitem__(self, key, value):
  195. setattr(self.obj, key, value)
  196. def __contains__(self, key):
  197. return hasattr(self.obj, key)
  198. def _iterate_items(self):
  199. return vars(self.obj).iteritems()
  200. iteritems = _iterate_items
  201. if sys.version_info >= (3, 0): # pragma: no cover
  202. items = _iterate_items
  203. else:
  204. def items(self):
  205. return list(self._iterate_items())
  206. class ConfigurationView(AttributeDictMixin):
  207. """A view over an applications configuration dicts.
  208. If the key does not exist in ``changes``, the ``defaults`` dict
  209. is consulted.
  210. :param changes: Dict containing changes to the configuration.
  211. :param defaults: Dict containing the default configuration.
  212. """
  213. changes = None
  214. defaults = None
  215. _order = None
  216. def __init__(self, changes, defaults):
  217. self.__dict__.update(changes=changes, defaults=defaults,
  218. _order=[changes] + defaults)
  219. def __getitem__(self, key):
  220. for d in self._order:
  221. try:
  222. return d[key]
  223. except KeyError:
  224. pass
  225. raise KeyError(key)
  226. def __setitem__(self, key, value):
  227. self.changes[key] = value
  228. def first(self, *keys):
  229. return first(None, (self.get(key) for key in keys))
  230. def get(self, key, default=None):
  231. try:
  232. return self[key]
  233. except KeyError:
  234. return default
  235. def setdefault(self, key, default):
  236. try:
  237. return self[key]
  238. except KeyError:
  239. self[key] = default
  240. return default
  241. def update(self, *args, **kwargs):
  242. return self.changes.update(*args, **kwargs)
  243. def __contains__(self, key):
  244. for d in self._order:
  245. if key in d:
  246. return True
  247. return False
  248. def __repr__(self):
  249. return repr(dict(self.iteritems()))
  250. def __iter__(self):
  251. return self.iterkeys()
  252. def _iter(self, op):
  253. # defaults must be first in the stream, so values in
  254. # changes takes precedence.
  255. return chain(*[op(d) for d in reversed(self._order)])
  256. def _iterate_keys(self):
  257. return uniq(self._iter(lambda d: d.iterkeys()))
  258. iterkeys = _iterate_keys
  259. def _iterate_items(self):
  260. return ((key, self[key]) for key in self)
  261. iteritems = _iterate_items
  262. def _iterate_values(self):
  263. return (self[key] for key in self)
  264. itervalues = _iterate_values
  265. def keys(self):
  266. return list(self._iterate_keys())
  267. def items(self):
  268. return list(self._iterate_items())
  269. def values(self):
  270. return list(self._iterate_values())
  271. class _Code(object):
  272. def __init__(self, code):
  273. self.co_filename = code.co_filename
  274. self.co_name = code.co_name
  275. class _Frame(object):
  276. Code = _Code
  277. def __init__(self, frame):
  278. self.f_globals = {
  279. "__file__": frame.f_globals.get("__file__", "__main__"),
  280. "__name__": frame.f_globals.get("__name__"),
  281. "__loader__": frame.f_globals.get("__loader__"),
  282. }
  283. self.f_locals = fl = {}
  284. try:
  285. fl["__traceback_hide__"] = frame.f_locals["__traceback_hide__"]
  286. except KeyError:
  287. pass
  288. self.f_code = self.Code(frame.f_code)
  289. self.f_lineno = frame.f_lineno
  290. class _Object(object):
  291. def __init__(self, **kw):
  292. [setattr(self, k, v) for k, v in kw.iteritems()]
  293. class _Truncated(object):
  294. def __init__(self):
  295. self.tb_lineno = -1
  296. self.tb_frame = _Object(
  297. f_globals={"__file__": "",
  298. "__name__": "",
  299. "__loader__": None},
  300. f_fileno=None,
  301. f_code=_Object(co_filename="...",
  302. co_name="[rest of traceback truncated]"),
  303. )
  304. self.tb_next = None
  305. class Traceback(object):
  306. Frame = _Frame
  307. tb_frame = tb_lineno = tb_next = None
  308. max_frames = sys.getrecursionlimit() / 8
  309. def __init__(self, tb, max_frames=None, depth=0):
  310. limit = self.max_frames = max_frames or self.max_frames
  311. self.tb_frame = self.Frame(tb.tb_frame)
  312. self.tb_lineno = tb.tb_lineno
  313. if tb.tb_next is not None:
  314. if depth <= limit:
  315. self.tb_next = Traceback(tb.tb_next, limit, depth + 1)
  316. else:
  317. self.tb_next = _Truncated()
  318. class ExceptionInfo(object):
  319. """Exception wrapping an exception and its traceback.
  320. :param exc_info: The exception info tuple as returned by
  321. :func:`sys.exc_info`.
  322. """
  323. #: Exception type.
  324. type = None
  325. #: Exception instance.
  326. exception = None
  327. #: Pickleable traceback instance for use with :mod:`traceback`
  328. tb = None
  329. #: String representation of the traceback.
  330. traceback = None
  331. #: Set to true if this is an internal error.
  332. internal = False
  333. def __init__(self, exc_info, internal=False):
  334. self.type, self.exception, tb = exc_info
  335. self.tb = Traceback(tb)
  336. self.traceback = ''.join(traceback.format_exception(*exc_info))
  337. self.internal = internal
  338. def __str__(self):
  339. return self.traceback
  340. def __repr__(self):
  341. return "<ExceptionInfo: %r>" % (self.exception, )
  342. @property
  343. def exc_info(self):
  344. return self.type, self.exception, self.tb
  345. class LimitedSet(object):
  346. """Kind-of Set with limitations.
  347. Good for when you need to test for membership (`a in set`),
  348. but the list might become to big, so you want to limit it so it doesn't
  349. consume too much resources.
  350. :keyword maxlen: Maximum number of members before we start
  351. evicting expired members.
  352. :keyword expires: Time in seconds, before a membership expires.
  353. """
  354. __slots__ = ("maxlen", "expires", "_data")
  355. def __init__(self, maxlen=None, expires=None):
  356. self.maxlen = maxlen
  357. self.expires = expires
  358. self._data = {}
  359. def add(self, value):
  360. """Add a new member."""
  361. self._expire_item()
  362. self._data[value] = time.time()
  363. def clear(self):
  364. """Remove all members"""
  365. self._data.clear()
  366. def pop_value(self, value):
  367. """Remove membership by finding value."""
  368. self._data.pop(value, None)
  369. def _expire_item(self):
  370. """Hunt down and remove an expired item."""
  371. while 1:
  372. if self.maxlen and len(self) >= self.maxlen:
  373. value, when = self.first
  374. if not self.expires or time.time() > when + self.expires:
  375. try:
  376. self.pop_value(value)
  377. except TypeError: # pragma: no cover
  378. continue
  379. break
  380. def __contains__(self, value):
  381. return value in self._data
  382. def update(self, other):
  383. if isinstance(other, self.__class__):
  384. self._data.update(other._data)
  385. else:
  386. for obj in other:
  387. self.add(obj)
  388. def as_dict(self):
  389. return self._data
  390. def __iter__(self):
  391. return iter(self._data.keys())
  392. def __len__(self):
  393. return len(self._data.keys())
  394. def __repr__(self):
  395. return "LimitedSet([%s])" % (repr(self._data.keys()))
  396. @property
  397. def chronologically(self):
  398. return sorted(self._data.items(), key=lambda (value, when): when)
  399. @property
  400. def first(self):
  401. """Get the oldest member."""
  402. return self.chronologically[0]