coroutine.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import time
  2. from collections import deque
  3. from celery.task.base import Task
  4. class CoroutineTask(Task):
  5. abstract = True
  6. _current_gen = None
  7. def body(self):
  8. while True:
  9. args, kwargs = (yield)
  10. yield self.run(*args, *kwargs)
  11. def run(self, *args, **kwargs):
  12. try:
  13. return self._gen.send((args, kwargs))
  14. finally:
  15. self._gen.next() # Go to receive-mode
  16. @property
  17. def _gen(self):
  18. if not self._current_gen:
  19. self._current_gen = self.body()
  20. self._current_gen.next() # Go to receive-mode
  21. return self._current_gen
  22. class Aggregate(CoroutineTask):
  23. abstract = True
  24. proxied = None
  25. minlen = 100
  26. time_max = 60
  27. _time_since = None
  28. def body(self):
  29. waiting = deque()
  30. timesince = time.time()
  31. while True:
  32. argtuple = (yield)
  33. waiting.append(argtuple)
  34. if self._expired() or len(waiting) >= self.minlen:
  35. yield self.process(waiting)
  36. waiting.clear()
  37. else:
  38. yield None
  39. def process(self, jobs):
  40. """Jobs is a deque with the arguments gathered so far.
  41. Arguments is a args, kwargs tuple.
  42. """
  43. raise NotImplementedError(
  44. "Subclasses of Aggregate needs to implement process()")
  45. def _expired(self):
  46. if not self._time_since:
  47. self._time_since = time.time()
  48. return False
  49. if time.time() + self.time_max > self._time_since:
  50. self._time_since = time.time()
  51. return True
  52. return False