datastructures.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.datastructures
  4. ~~~~~~~~~~~~~~~~~~~~~
  5. Custom types and data structures.
  6. :copyright: (c) 2009 - 2012 by Ask Solem.
  7. :license: BSD, see LICENSE for more details.
  8. """
  9. from __future__ import absolute_import
  10. from __future__ import with_statement
  11. import sys
  12. import time
  13. import traceback
  14. from collections import defaultdict
  15. from itertools import chain
  16. from threading import RLock
  17. from kombu.utils.limits import TokenBucket # noqa
  18. from .utils import uniq
  19. from .utils.compat import UserDict, OrderedDict
  20. class CycleError(Exception):
  21. """A cycle was detected in an acyclic graph."""
  22. class DependencyGraph(object):
  23. """A directed acyclic graph of objects and their dependencies.
  24. Supports a robust topological sort
  25. to detect the order in which they must be handled.
  26. Takes an optional iterator of ``(obj, dependencies)``
  27. tuples to build the graph from.
  28. .. warning::
  29. Does not support cycle detection.
  30. """
  31. def __init__(self, it=None):
  32. self.adjacent = {}
  33. if it is not None:
  34. self.update(it)
  35. def add_arc(self, obj):
  36. """Add an object to the graph."""
  37. self.adjacent[obj] = []
  38. def add_edge(self, A, B):
  39. """Add an edge from object ``A`` to object ``B``
  40. (``A`` depends on ``B``)."""
  41. self[A].append(B)
  42. def topsort(self):
  43. """Sort the graph topologically.
  44. :returns: a list of objects in the order
  45. in which they must be handled.
  46. """
  47. graph = DependencyGraph()
  48. components = self._tarjan72()
  49. NC = dict((node, component)
  50. for component in components
  51. for node in component)
  52. for component in components:
  53. graph.add_arc(component)
  54. for node in self:
  55. node_c = NC[node]
  56. for successor in self[node]:
  57. successor_c = NC[successor]
  58. if node_c != successor_c:
  59. graph.add_edge(node_c, successor_c)
  60. return [t[0] for t in graph._khan62()]
  61. def valency_of(self, obj):
  62. """Returns the velency (degree) of a vertex in the graph."""
  63. l = [len(self[obj])]
  64. for node in self[obj]:
  65. l.append(self.valency_of(node))
  66. return sum(l)
  67. def update(self, it):
  68. """Update the graph with data from a list
  69. of ``(obj, dependencies)`` tuples."""
  70. tups = list(it)
  71. for obj, _ in tups:
  72. self.add_arc(obj)
  73. for obj, deps in tups:
  74. for dep in deps:
  75. self.add_edge(obj, dep)
  76. def edges(self):
  77. """Returns generator that yields for all edges in the graph."""
  78. return (obj for obj, adj in self.iteritems() if adj)
  79. def _khan62(self):
  80. """Khans simple topological sort algorithm from '62
  81. See http://en.wikipedia.org/wiki/Topological_sorting
  82. """
  83. count = defaultdict(lambda: 0)
  84. result = []
  85. for node in self:
  86. for successor in self[node]:
  87. count[successor] += 1
  88. ready = [node for node in self if not count[node]]
  89. while ready:
  90. node = ready.pop()
  91. result.append(node)
  92. for successor in self[node]:
  93. count[successor] -= 1
  94. if count[successor] == 0:
  95. ready.append(successor)
  96. result.reverse()
  97. return result
  98. def _tarjan72(self):
  99. """Tarjan's algorithm to find strongly connected components.
  100. See http://bit.ly/vIMv3h.
  101. """
  102. result, stack, low = [], [], {}
  103. def visit(node):
  104. if node in low:
  105. return
  106. num = len(low)
  107. low[node] = num
  108. stack_pos = len(stack)
  109. stack.append(node)
  110. for successor in self[node]:
  111. visit(successor)
  112. low[node] = min(low[node], low[successor])
  113. if num == low[node]:
  114. component = tuple(stack[stack_pos:])
  115. stack[stack_pos:] = []
  116. result.append(component)
  117. for item in component:
  118. low[item] = len(self)
  119. for node in self:
  120. visit(node)
  121. return result
  122. def to_dot(self, fh, ws=" " * 4):
  123. """Convert the graph to DOT format.
  124. :param fh: A file, or a file-like object to write the graph to.
  125. """
  126. fh.write("digraph dependencies {\n")
  127. for obj, adjacent in self.iteritems():
  128. if not adjacent:
  129. fh.write(ws + '"%s"\n' % (obj, ))
  130. for req in adjacent:
  131. fh.write(ws + '"%s" -> "%s"\n' % (obj, req))
  132. fh.write("}\n")
  133. def __iter__(self):
  134. return self.adjacent.iterkeys()
  135. def __getitem__(self, node):
  136. return self.adjacent[node]
  137. def __len__(self):
  138. return len(self.adjacent)
  139. def _iterate_items(self):
  140. return self.adjacent.iteritems()
  141. items = iteritems = _iterate_items
  142. def __repr__(self):
  143. return '\n'.join(self.repr_node(N) for N in self)
  144. def repr_node(self, obj, level=1):
  145. output = ["%s(%s)" % (obj, self.valency_of(obj))]
  146. for other in self[obj]:
  147. d = "%s(%s)" % (other, self.valency_of(other))
  148. output.append(' ' * level + d)
  149. output.extend(self.repr_node(other, level + 1).split('\n')[1:])
  150. return '\n'.join(output)
  151. class AttributeDictMixin(object):
  152. """Adds attribute access to mappings.
  153. `d.key -> d[key]`
  154. """
  155. def __getattr__(self, key):
  156. """`d.key -> d[key]`"""
  157. try:
  158. return self[key]
  159. except KeyError:
  160. raise AttributeError("'%s' object has no attribute '%s'" % (
  161. self.__class__.__name__, key))
  162. def __setattr__(self, key, value):
  163. """`d[key] = value -> d.key = value`"""
  164. self[key] = value
  165. class AttributeDict(dict, AttributeDictMixin):
  166. """Dict subclass with attribute access."""
  167. pass
  168. class DictAttribute(object):
  169. """Dict interface to attributes.
  170. `obj[k] -> obj.k`
  171. """
  172. def __init__(self, obj):
  173. self.obj = obj
  174. def get(self, key, default=None):
  175. try:
  176. return self[key]
  177. except KeyError:
  178. return default
  179. def setdefault(self, key, default):
  180. try:
  181. return self[key]
  182. except KeyError:
  183. self[key] = default
  184. return default
  185. def __getitem__(self, key):
  186. try:
  187. return getattr(self.obj, key)
  188. except AttributeError:
  189. raise KeyError(key)
  190. def __setitem__(self, key, value):
  191. setattr(self.obj, key, value)
  192. def __contains__(self, key):
  193. return hasattr(self.obj, key)
  194. def _iterate_items(self):
  195. return vars(self.obj).iteritems()
  196. iteritems = _iterate_items
  197. if sys.version_info >= (3, 0): # pragma: no cover
  198. items = _iterate_items
  199. else:
  200. def items(self):
  201. return list(self._iterate_items())
  202. class ConfigurationView(AttributeDictMixin):
  203. """A view over an applications configuration dicts.
  204. If the key does not exist in ``changes``, the ``defaults`` dict
  205. is consulted.
  206. :param changes: Dict containing changes to the configuration.
  207. :param defaults: Dict containing the default configuration.
  208. """
  209. changes = None
  210. defaults = None
  211. _order = None
  212. def __init__(self, changes, defaults):
  213. self.__dict__.update(changes=changes, defaults=defaults,
  214. _order=[changes] + defaults)
  215. def __getitem__(self, key):
  216. for d in self._order:
  217. try:
  218. return d[key]
  219. except KeyError:
  220. pass
  221. raise KeyError(key)
  222. def __setitem__(self, key, value):
  223. self.changes[key] = value
  224. def get(self, key, default=None):
  225. try:
  226. return self[key]
  227. except KeyError:
  228. return default
  229. def setdefault(self, key, default):
  230. try:
  231. return self[key]
  232. except KeyError:
  233. self[key] = default
  234. return default
  235. def update(self, *args, **kwargs):
  236. return self.changes.update(*args, **kwargs)
  237. def __contains__(self, key):
  238. for d in self._order:
  239. if key in d:
  240. return True
  241. return False
  242. def __repr__(self):
  243. return repr(dict(self.iteritems()))
  244. def __iter__(self):
  245. return self.iterkeys()
  246. def _iter(self, op):
  247. # defaults must be first in the stream, so values in
  248. # changes takes precedence.
  249. return chain(*[op(d) for d in reversed(self._order)])
  250. def _iterate_keys(self):
  251. return uniq(self._iter(lambda d: d.iterkeys()))
  252. iterkeys = _iterate_keys
  253. def _iterate_items(self):
  254. return ((key, self[key]) for key in self)
  255. iteritems = _iterate_items
  256. def _iterate_values(self):
  257. return (self[key] for key in self)
  258. itervalues = _iterate_values
  259. def keys(self):
  260. return list(self._iterate_keys())
  261. def items(self):
  262. return list(self._iterate_items())
  263. def values(self):
  264. return list(self._iterate_values())
  265. class _Code(object):
  266. def __init__(self, code):
  267. self.co_filename = code.co_filename
  268. self.co_name = code.co_name
  269. class _Frame(object):
  270. Code = _Code
  271. def __init__(self, frame):
  272. self.f_globals = {
  273. "__file__": frame.f_globals.get("__file__", "__main__"),
  274. "__name__": frame.f_globals.get("__name__"),
  275. "__loader__": frame.f_globals.get("__loader__"),
  276. }
  277. self.f_locals = fl = {}
  278. try:
  279. fl["__traceback_hide__"] = frame.f_locals["__traceback_hide__"]
  280. except KeyError:
  281. pass
  282. self.f_code = self.Code(frame.f_code)
  283. self.f_lineno = frame.f_lineno
  284. class _Object(object):
  285. def __init__(self, **kw):
  286. [setattr(self, k, v) for k, v in kw.iteritems()]
  287. class _Truncated(object):
  288. def __init__(self):
  289. self.tb_lineno = -1
  290. self.tb_frame = _Object(
  291. f_globals={"__file__": "",
  292. "__name__": "",
  293. "__loader__": None},
  294. f_fileno=None,
  295. f_code=_Object(co_filename="...",
  296. co_name="[rest of traceback truncated]"),
  297. )
  298. self.tb_next = None
  299. class Traceback(object):
  300. Frame = _Frame
  301. tb_frame = tb_lineno = tb_next = None
  302. max_frames = sys.getrecursionlimit() / 8
  303. def __init__(self, tb, max_frames=None, depth=0):
  304. limit = self.max_frames = max_frames or self.max_frames
  305. self.tb_frame = self.Frame(tb.tb_frame)
  306. self.tb_lineno = tb.tb_lineno
  307. if tb.tb_next is not None:
  308. if depth <= limit:
  309. self.tb_next = Traceback(tb.tb_next, limit, depth + 1)
  310. else:
  311. self.tb_next = _Truncated()
  312. class ExceptionInfo(object):
  313. """Exception wrapping an exception and its traceback.
  314. :param exc_info: The exception info tuple as returned by
  315. :func:`sys.exc_info`.
  316. """
  317. #: Exception type.
  318. type = None
  319. #: Exception instance.
  320. exception = None
  321. #: Pickleable traceback instance for use with :mod:`traceback`
  322. tb = None
  323. #: String representation of the traceback.
  324. traceback = None
  325. #: Set to true if this is an internal error.
  326. internal = False
  327. def __init__(self, exc_info, internal=False):
  328. self.type, self.exception, tb = exc_info
  329. self.tb = Traceback(tb)
  330. self.traceback = ''.join(traceback.format_exception(*exc_info))
  331. self.internal = internal
  332. def __str__(self):
  333. return self.traceback
  334. def __repr__(self):
  335. return "<ExceptionInfo: %r>" % (self.exception, )
  336. @property
  337. def exc_info(self):
  338. return self.type, self.exception, self.tb
  339. class LimitedSet(object):
  340. """Kind-of Set with limitations.
  341. Good for when you need to test for membership (`a in set`),
  342. but the list might become to big, so you want to limit it so it doesn't
  343. consume too much resources.
  344. :keyword maxlen: Maximum number of members before we start
  345. evicting expired members.
  346. :keyword expires: Time in seconds, before a membership expires.
  347. """
  348. __slots__ = ("maxlen", "expires", "_data")
  349. def __init__(self, maxlen=None, expires=None):
  350. self.maxlen = maxlen
  351. self.expires = expires
  352. self._data = {}
  353. def add(self, value):
  354. """Add a new member."""
  355. self._expire_item()
  356. self._data[value] = time.time()
  357. def clear(self):
  358. """Remove all members"""
  359. self._data.clear()
  360. def pop_value(self, value):
  361. """Remove membership by finding value."""
  362. self._data.pop(value, None)
  363. def _expire_item(self):
  364. """Hunt down and remove an expired item."""
  365. while 1:
  366. if self.maxlen and len(self) >= self.maxlen:
  367. value, when = self.first
  368. if not self.expires or time.time() > when + self.expires:
  369. try:
  370. self.pop_value(value)
  371. except TypeError: # pragma: no cover
  372. continue
  373. break
  374. def __contains__(self, value):
  375. return value in self._data
  376. def update(self, other):
  377. if isinstance(other, self.__class__):
  378. self._data.update(other._data)
  379. else:
  380. for obj in other:
  381. self.add(obj)
  382. def as_dict(self):
  383. return self._data
  384. def __iter__(self):
  385. return iter(self._data.keys())
  386. def __len__(self):
  387. return len(self._data.keys())
  388. def __repr__(self):
  389. return "LimitedSet([%s])" % (repr(self._data.keys()))
  390. @property
  391. def chronologically(self):
  392. return sorted(self._data.items(), key=lambda (value, when): when)
  393. @property
  394. def first(self):
  395. """Get the oldest member."""
  396. return self.chronologically[0]
  397. class LRUCache(UserDict):
  398. """LRU Cache implementation using a doubly linked list to track access.
  399. :keyword limit: The maximum number of keys to keep in the cache.
  400. When a new key is inserted and the limit has been exceeded,
  401. the *Least Recently Used* key will be discarded from the
  402. cache.
  403. """
  404. def __init__(self, limit=None):
  405. self.limit = limit
  406. self.mutex = RLock()
  407. self.data = OrderedDict()
  408. def __getitem__(self, key):
  409. with self.mutex:
  410. value = self[key] = self.data.pop(key)
  411. return value
  412. def keys(self):
  413. # userdict.keys in py3k calls __getitem__
  414. return self.data.keys()
  415. def values(self):
  416. return list(self._iterate_values())
  417. def items(self):
  418. return list(self._iterate_items())
  419. def __setitem__(self, key, value):
  420. # remove least recently used key.
  421. with self.mutex:
  422. if self.limit and len(self.data) >= self.limit:
  423. self.data.pop(iter(self.data).next())
  424. self.data[key] = value
  425. def __iter__(self):
  426. return self.data.iterkeys()
  427. def _iterate_items(self):
  428. for k in self:
  429. try:
  430. yield (k, self.data[k])
  431. except KeyError:
  432. pass
  433. iteritems = _iterate_items
  434. def _iterate_values(self):
  435. for k in self:
  436. try:
  437. yield self.data[k]
  438. except KeyError: # pragma: no cover
  439. pass
  440. itervalues = _iterate_values
  441. def incr(self, key, delta=1):
  442. with self.mutex:
  443. # this acts as memcached does- store as a string, but return a
  444. # integer as long as it exists and we can cast it
  445. newval = int(self.data.pop(key)) + delta
  446. self[key] = str(newval)
  447. return newval