datastructures.py 11 KB


from __future__ import generators

import threading
import time
import traceback
from itertools import chain
from UserList import UserList
from Queue import Queue, Empty as QueueEmpty

from celery.utils.compat import OrderedDict
  8. class AttributeDictMixin(object):
  9. def __getattr__(self, key):
  10. try:
  11. return self[key]
  12. except KeyError:
  13. raise AttributeError("'%s' object has no attribute '%s'" % (
  14. self.__class__.__name__, key))
  15. def __setattr__(self, key, value):
  16. self[key] = value
  17. class AttributeDict(dict, AttributeDictMixin):
  18. """Dict subclass with attribute access."""
  19. pass
  20. class DictAttribute(object):
  21. """Dict interface using attributes."""
  22. def __init__(self, obj):
  23. self.obj = obj
  24. def get(self, key, default=None):
  25. try:
  26. return self[key]
  27. except KeyError:
  28. return default
  29. def setdefault(self, key, default):
  30. try:
  31. return self[key]
  32. except KeyError:
  33. self[key] = default
  34. return default
  35. def __getitem__(self, key):
  36. try:
  37. return getattr(self.obj, key)
  38. except AttributeError:
  39. raise KeyError(key)
  40. def __setitem__(self, key, value):
  41. setattr(self.obj, key, value)
  42. def __contains__(self, key):
  43. return hasattr(self.obj, key)
  44. def iteritems(self):
  45. return vars(self.obj).iteritems()
  46. class ConfigurationView(AttributeDictMixin):
  47. changes = None
  48. defaults = None
  49. def __init__(self, changes, defaults):
  50. self.__dict__["changes"] = changes
  51. self.__dict__["defaults"] = defaults
  52. def __getitem__(self, key):
  53. for d in self.__dict__["changes"], self.__dict__["defaults"]:
  54. try:
  55. return d[key]
  56. except KeyError:
  57. pass
  58. raise KeyError(key)
  59. def __setitem__(self, key, value):
  60. self.__dict__["changes"][key] = value
  61. def get(self, key, default=None):
  62. try:
  63. return self[key]
  64. except KeyError:
  65. return default
  66. def setdefault(self, key, default):
  67. try:
  68. return self[key]
  69. except KeyError:
  70. self[key] = default
  71. return default
  72. def update(self, *args, **kwargs):
  73. return self.__dict__["changes"].update(*args, **kwargs)
  74. def __contains__(self, key):
  75. for d in self.__dict__["changes"], self.__dict__["defaults"]:
  76. if key in d:
  77. return True
  78. return False
  79. def __repr__(self):
  80. return repr(dict(iter(self)))
  81. def __iter__(self):
  82. return chain(*[d.iteritems() for d in (self.__dict__["changes"],
  83. self.__dict__["defaults"])])
  84. class PositionQueue(UserList):
  85. """A positional queue of a specific length, with slots that are either
  86. filled or unfilled. When all of the positions are filled, the queue
  87. is considered :meth:`full`.
  88. :param length: Number of items to fill.
  89. """
  90. #: The number of items required for the queue to be considered full.
  91. length = None
  92. class UnfilledPosition(object):
  93. """Describes an unfilled slot."""
  94. def __init__(self, position):
  95. # This is not used, but is an argument from xrange
  96. # so why not.
  97. self.position = position
  98. def __init__(self, length):
  99. self.length = length
  100. self.data = map(self.UnfilledPosition, xrange(length))
  101. def full(self):
  102. """Returns :const:`True` if all of the slots has been filled."""
  103. return len(self) >= self.length
  104. def __len__(self):
  105. """`len(self)` -> number of slots filled with real values."""
  106. return len(self.filled)
  107. @property
  108. def filled(self):
  109. """All filled slots as a list."""
  110. return [slot for slot in self.data
  111. if not isinstance(slot, self.UnfilledPosition)]
  112. class ExceptionInfo(object):
  113. """Exception wrapping an exception and its traceback.
  114. :param exc_info: The exception tuple info as returned by
  115. :func:`traceback.format_exception`.
  116. """
  117. #: The original exception.
  118. exception = None
  119. #: A traceback form the point when :attr:`exception` was raised.
  120. traceback = None
  121. def __init__(self, exc_info):
  122. type_, exception, tb = exc_info
  123. self.exception = exception
  124. self.traceback = ''.join(traceback.format_exception(*exc_info))
  125. def __str__(self):
  126. return self.traceback
  127. def __repr__(self):
  128. return "<%s.%s: %s>" % (
  129. self.__class__.__module__,
  130. self.__class__.__name__,
  131. str(self.exception))
  132. def consume_queue(queue):
  133. """Iterator yielding all immediately available items in a
  134. :class:`Queue.Queue`.
  135. The iterator stops as soon as the queue raises :exc:`Queue.Empty`.
  136. Example
  137. >>> q = Queue()
  138. >>> map(q.put, range(4))
  139. >>> list(consume_queue(q))
  140. [0, 1, 2, 3]
  141. >>> list(consume_queue(q))
  142. []
  143. """
  144. while 1:
  145. try:
  146. yield queue.get_nowait()
  147. except QueueEmpty:
  148. break
  149. class SharedCounter(object):
  150. """Thread-safe counter.
  151. Please note that the final value is not synchronized, this means
  152. that you should not update the value by using a previous value, the only
  153. reliable operations are increment and decrement.
  154. Example::
  155. >>> max_clients = SharedCounter(initial_value=10)
  156. # Thread one
  157. >>> max_clients += 1 # OK (safe)
  158. # Thread two
  159. >>> max_clients -= 3 # OK (safe)
  160. # Main thread
  161. >>> if client >= int(max_clients): # Max clients now at 8
  162. ... wait()
  163. >>> max_client = max_clients + 10 # NOT OK (unsafe)
  164. """
  165. def __init__(self, initial_value):
  166. self._value = initial_value
  167. self._modify_queue = Queue()
  168. def increment(self, n=1):
  169. """Increment value."""
  170. self += n
  171. return int(self)
  172. def decrement(self, n=1):
  173. """Decrement value."""
  174. self -= n
  175. return int(self)
  176. def _update_value(self):
  177. self._value += sum(consume_queue(self._modify_queue))
  178. return self._value
  179. def __iadd__(self, y):
  180. """`self += y`"""
  181. self._modify_queue.put(y * +1)
  182. return self
  183. def __isub__(self, y):
  184. """`self -= y`"""
  185. self._modify_queue.put(y * -1)
  186. return self
  187. def __int__(self):
  188. """`int(self) -> int`"""
  189. return self._update_value()
  190. def __repr__(self):
  191. return "<SharedCounter: int(%s)>" % str(int(self))
  192. class LimitedSet(object):
  193. """Kind-of Set with limitations.
  194. Good for when you need to test for membership (`a in set`),
  195. but the list might become to big, so you want to limit it so it doesn't
  196. consume too much resources.
  197. :keyword maxlen: Maximum number of members before we start
  198. deleting expired members.
  199. :keyword expires: Time in seconds, before a membership expires.
  200. """
  201. def __init__(self, maxlen=None, expires=None):
  202. self.maxlen = maxlen
  203. self.expires = expires
  204. self._data = {}
  205. def add(self, value):
  206. """Add a new member."""
  207. self._expire_item()
  208. self._data[value] = time.time()
  209. def clear(self):
  210. """Remove all members"""
  211. self._data.clear()
  212. def pop_value(self, value):
  213. """Remove membership by finding value."""
  214. self._data.pop(value, None)
  215. def _expire_item(self):
  216. """Hunt down and remove an expired item."""
  217. while 1:
  218. if self.maxlen and len(self) >= self.maxlen:
  219. value, when = self.first
  220. if not self.expires or time.time() > when + self.expires:
  221. try:
  222. self.pop_value(value)
  223. except TypeError: # pragma: no cover
  224. continue
  225. break
  226. def __contains__(self, value):
  227. return value in self._data
  228. def update(self, other):
  229. if isinstance(other, self.__class__):
  230. self._data.update(other._data)
  231. else:
  232. self._data.update(other)
  233. def as_dict(self):
  234. return self._data
  235. def __iter__(self):
  236. return iter(self._data.keys())
  237. def __len__(self):
  238. return len(self._data.keys())
  239. def __repr__(self):
  240. return "LimitedSet([%s])" % (repr(self._data.keys()))
  241. @property
  242. def chronologically(self):
  243. return sorted(self._data.items(), key=lambda (value, when): when)
  244. @property
  245. def first(self):
  246. """Get the oldest member."""
  247. return self.chronologically[0]
  248. class LocalCache(OrderedDict):
  249. """Dictionary with a finite number of keys.
  250. Older items expires first.
  251. """
  252. def __init__(self, limit=None):
  253. super(LocalCache, self).__init__()
  254. self.limit = limit
  255. def __setitem__(self, key, value):
  256. while len(self) >= self.limit:
  257. self.popitem(last=False)
  258. super(LocalCache, self).__setitem__(key, value)
  259. class TokenBucket(object):
  260. """Token Bucket Algorithm.
  261. See http://en.wikipedia.org/wiki/Token_Bucket
  262. Most of this code was stolen from an entry in the ASPN Python Cookbook:
  263. http://code.activestate.com/recipes/511490/
  264. .. admonition:: Thread safety
  265. This implementation is not thread safe.
  266. :param fill_rate: Refill rate in tokens/second.
  267. :keyword capacity: Max number of tokens. Default is 1.
  268. """
  269. #: The rate in tokens/second that the bucket will be refilled
  270. fill_rate = None
  271. #: Maximum number of tokensin the bucket.
  272. capacity = 1
  273. #: Timestamp of the last time a token was taken out of the bucket.
  274. timestamp = None
  275. def __init__(self, fill_rate, capacity=1):
  276. self.capacity = float(capacity)
  277. self._tokens = capacity
  278. self.fill_rate = float(fill_rate)
  279. self.timestamp = time.time()
  280. def can_consume(self, tokens=1):
  281. """Returns :const:`True` if `tokens` number of tokens can be consumed
  282. from the bucket."""
  283. if tokens <= self._get_tokens():
  284. self._tokens -= tokens
  285. return True
  286. return False
  287. def expected_time(self, tokens=1):
  288. """Returns the expected time in seconds when a new token should be
  289. available.
  290. .. admonition:: Warning
  291. This consumes a token from the bucket.
  292. """
  293. _tokens = self._get_tokens()
  294. tokens = max(tokens, _tokens)
  295. return (tokens - _tokens) / self.fill_rate
  296. def _get_tokens(self):
  297. if self._tokens < self.capacity:
  298. now = time.time()
  299. delta = self.fill_rate * (now - self.timestamp)
  300. self._tokens = min(self.capacity, self._tokens + delta)
  301. self.timestamp = now
  302. return self._tokens