batches.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.contrib.batches
  4. ======================
  5. Collects messages and processes them as a list.
  6. **Example**
  7. A click counter that flushes the buffer every 100 messages, and every
  8. 10 seconds.
  9. .. code-block:: python
  10. from celery.task import task
  11. from celery.contrib.batches import Batches
  12. # Flush after 100 messages, or 10 seconds.
  13. @task(base=Batches, flush_every=100, flush_interval=10)
  14. def count_click(requests):
  15. from collections import Counter
  16. counts = Counter(request.kwargs["url"] for request in requests)
  17. for url, clicks in counts.items():
  18. print(">>> Clicks: %s -> %s" % (url, clicks))
  19. Registering the click is done as follows:
  20. >>> count_click.delay(url="http://example.com")
  21. .. warning::
  22. For this to work you have to set
  23. :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or some value where
  24. the final multiplied value is higher than ``flush_every``.
  25. In the future we hope to add the ability to direct batching tasks
  26. to a channel with different QoS requirements than the task channel.
  27. :copyright: (c) 2009 - 2012 by Ask Solem.
  28. :license: BSD, see LICENSE for more details.
  29. """
  30. from __future__ import absolute_import
  31. from itertools import count
  32. from Queue import Empty, Queue
  33. from celery.task import Task
  34. from celery.utils import timer2
  35. from celery.utils.log import get_logger
  36. from celery.worker import state
  37. logger = get_logger(__name__)
  38. def consume_queue(queue):
  39. """Iterator yielding all immediately available items in a
  40. :class:`Queue.Queue`.
  41. The iterator stops as soon as the queue raises :exc:`Queue.Empty`.
  42. *Examples*
  43. >>> q = Queue()
  44. >>> map(q.put, range(4))
  45. >>> list(consume_queue(q))
  46. [0, 1, 2, 3]
  47. >>> list(consume_queue(q))
  48. []
  49. """
  50. get = queue.get_nowait
  51. while 1:
  52. try:
  53. yield get()
  54. except Empty:
  55. break
  56. def apply_batches_task(task, args, loglevel, logfile):
  57. task.push_request(loglevel=loglevel, logfile=logfile)
  58. try:
  59. result = task(*args)
  60. except Exception, exc:
  61. result = None
  62. task.logger.error("Error: %r", exc, exc_info=True)
  63. finally:
  64. task.pop_request()
  65. return result
  66. class SimpleRequest(object):
  67. """Pickleable request."""
  68. #: task id
  69. id = None
  70. #: task name
  71. name = None
  72. #: positional arguments
  73. args = ()
  74. #: keyword arguments
  75. kwargs = {}
  76. #: message delivery information.
  77. delivery_info = None
  78. #: worker node name
  79. hostname = None
  80. def __init__(self, id, name, args, kwargs, delivery_info, hostname):
  81. self.id = id
  82. self.name = name
  83. self.args = args
  84. self.kwargs = kwargs
  85. self.delivery_info = delivery_info
  86. self.hostname = hostname
  87. @classmethod
  88. def from_request(cls, request):
  89. return cls(request.id, request.name, request.args,
  90. request.kwargs, request.delivery_info, request.hostname)
  91. class Batches(Task):
  92. abstract = True
  93. #: Maximum number of message in buffer.
  94. flush_every = 10
  95. #: Timeout in seconds before buffer is flushed anyway.
  96. flush_interval = 30
  97. def __init__(self):
  98. self._buffer = Queue()
  99. self._count = count(1).next
  100. self._tref = None
  101. self._pool = None
  102. self._logging = None
  103. def run(self, requests):
  104. raise NotImplementedError("%r must implement run(requests)" % (self, ))
  105. def flush(self, requests):
  106. return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
  107. for r in requests], ))
  108. def execute(self, request, pool, loglevel, logfile):
  109. if not self._pool: # just take pool from first task.
  110. self._pool = pool
  111. if not self._logging:
  112. self._logging = loglevel, logfile
  113. state.task_ready(request) # immediately remove from worker state.
  114. self._buffer.put(request)
  115. if self._tref is None: # first request starts flush timer.
  116. self._tref = timer2.apply_interval(self.flush_interval * 1000,
  117. self._do_flush)
  118. if not self._count() % self.flush_every:
  119. self._do_flush()
  120. def _do_flush(self):
  121. logger.debug("Batches: Wake-up to flush buffer...")
  122. requests = None
  123. if self._buffer.qsize():
  124. requests = list(consume_queue(self._buffer))
  125. if requests:
  126. logger.debug("Batches: Buffer complete: %s", len(requests))
  127. self.flush(requests)
  128. if not requests:
  129. logger.debug("Batches: Cancelling timer: Nothing in buffer.")
  130. self._tref.cancel() # cancel timer.
  131. self._tref = None
  132. def apply_buffer(self, requests, args=(), kwargs={}):
  133. acks_late = [], []
  134. [acks_late[r.task.acks_late].append(r) for r in requests]
  135. assert requests and (acks_late[True] or acks_late[False])
  136. def on_accepted(pid, time_accepted):
  137. [req.acknowledge() for req in acks_late[False]]
  138. def on_return(result):
  139. [req.acknowledge() for req in acks_late[True]]
  140. loglevel, logfile = self._logging
  141. return self._pool.apply_async(apply_batches_task,
  142. (self, args, loglevel, logfile),
  143. accept_callback=on_accepted,
  144. callback=acks_late[True] and on_return or None)