batches.py

# -*- coding: utf-8 -*-
"""
celery.contrib.batches
======================

Collect messages and process them as a list.

**Example**

A click counter that flushes the buffer every 100 messages, and every
10 seconds.

.. code-block:: python

    from celery.task import task
    from celery.contrib.batches import Batches

    # Flush after 100 messages, or 10 seconds.
    @task(base=Batches, flush_every=100, flush_interval=10)
    def count_click(requests):
        from collections import Counter
        counts = Counter(request.kwargs['url'] for request in requests)
        for url, count in counts.items():
            print('>>> Clicks: %s -> %s' % (url, count))

Registering a click is done as follows:

    >>> count_click.delay(url='http://example.com')

.. warning::

    For this to work you have to set
    :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or some value where
    the final multiplied value is higher than ``flush_every``.

    In the future we hope to add the ability to direct batching tasks
    to a channel with different QoS requirements than the task channel.

"""

from __future__ import absolute_import

from itertools import count
from Queue import Empty, Queue

from celery.task import Task
from celery.utils import timer2
from celery.utils.log import get_logger
from celery.worker import state

logger = get_logger(__name__)


def consume_queue(queue):
    """Iterator yielding all immediately available items in a
    :class:`Queue.Queue`.

    The iterator stops as soon as the queue raises :exc:`Queue.Empty`.

    *Examples*

        >>> q = Queue()
        >>> for i in range(4):
        ...     q.put(i)
        >>> list(consume_queue(q))
        [0, 1, 2, 3]
        >>> list(consume_queue(q))
        []

    """
    get = queue.get_nowait
    while 1:
        try:
            yield get()
        except Empty:
            break


def apply_batches_task(task, args, loglevel, logfile):
    task.push_request(loglevel=loglevel, logfile=logfile)
    try:
        result = task(*args)
    except Exception as exc:
        result = None
        task.logger.error('Error: %r', exc, exc_info=True)
    finally:
        task.pop_request()
    return result
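
# Note: ``apply_batches_task`` is the callable that ``Batches.apply_buffer``
# submits to the worker pool below; ``args`` is a one-element tuple holding
# the list of :class:`SimpleRequest` objects for a flush.  Exceptions raised
# by the batch task are logged and swallowed here, so a failed flush is not
# automatically retried.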


class SimpleRequest(object):
    """Pickleable request."""

    #: task id
    id = None

    #: task name
    name = None

    #: positional arguments
    args = ()

    #: keyword arguments
    kwargs = {}

    #: message delivery information.
    delivery_info = None

    #: worker node name
    hostname = None

    def __init__(self, id, name, args, kwargs, delivery_info, hostname):
        self.id = id
        self.name = name
        self.args = args
        self.kwargs = kwargs
        self.delivery_info = delivery_info
        self.hostname = hostname

    @classmethod
    def from_request(cls, request):
        return cls(request.id, request.name, request.args,
                   request.kwargs, request.delivery_info, request.hostname)
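
# A :class:`SimpleRequest` carries enough information to re-dispatch a
# message from inside a batch task.  A sketch, assuming the ``current_app``
# proxy is available (adapt to however you access your app instance):
#
#     from celery import current_app
#
#     def requeue(simple_request):
#         current_app.send_task(simple_request.name,
#                               args=simple_request.args,
#                               kwargs=simple_request.kwargs)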


class Batches(Task):
    abstract = True

    #: Maximum number of messages in the buffer.
    flush_every = 10

    #: Timeout in seconds before the buffer is flushed anyway.
    flush_interval = 30

    def __init__(self):
        self._buffer = Queue()
        self._count = count(1).next
        self._tref = None
        self._pool = None
        self._logging = None

    def run(self, requests):
        raise NotImplementedError('%r must implement run(requests)' % (self, ))
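
    # ``run`` is called with a single positional argument: the list of
    # :class:`SimpleRequest` objects collected for one flush.  Its return
    # value is not stored in a result backend automatically; see the note
    # at the end of this module.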

    def flush(self, requests):
        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
                                             for r in requests], ))

    def execute(self, request, pool, loglevel, logfile):
        if not self._pool:         # just take pool from first task.
            self._pool = pool
        if not self._logging:
            self._logging = loglevel, logfile
        state.task_ready(request)  # immediately remove from worker state.
        self._buffer.put(request)
        if self._tref is None:     # first request starts flush timer.
            self._tref = timer2.apply_interval(self.flush_interval * 1000,
                                               self._do_flush)
        if not self._count() % self.flush_every:
            self._do_flush()

    def _do_flush(self):
        logger.debug('Batches: Wake-up to flush buffer...')
        requests = None
        if self._buffer.qsize():
            requests = list(consume_queue(self._buffer))
            if requests:
                logger.debug('Batches: Buffer complete: %s', len(requests))
                self.flush(requests)
        if not requests:
            logger.debug('Batches: Cancelling timer: Nothing in buffer.')
            self._tref.cancel()  # cancel timer.
            self._tref = None

    def apply_buffer(self, requests, args=(), kwargs={}):
        # Partition requests by their task's ``acks_late`` setting, so
        # early-acking requests are acknowledged when the pool accepts the
        # job and late-acking ones only after it returns.
        acks_late = [], []
        for request in requests:
            acks_late[request.task.acks_late].append(request)
        assert requests and (acks_late[True] or acks_late[False])

        def on_accepted(pid, time_accepted):
            for req in acks_late[False]:
                req.acknowledge()

        def on_return(result):
            for req in acks_late[True]:
                req.acknowledge()

        loglevel, logfile = self._logging
        return self._pool.apply_async(
            apply_batches_task, (self, args, loglevel, logfile),
            accept_callback=on_accepted,
            callback=on_return if acks_late[True] else None)
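

# Because ``execute`` marks requests as ready the moment they are buffered,
# no result is stored for the individual task ids.  A sketch of storing a
# per-request result from inside a batch task, using the result backend
# directly (``do_lookup`` is a hypothetical helper; ``mark_as_done`` is the
# backend method that records a successful result):
#
#     @task(base=Batches, flush_every=100, flush_interval=10)
#     def lookup(requests):
#         responses = [do_lookup(r.kwargs['url']) for r in requests]
#         for response, request in zip(responses, requests):
#             lookup.backend.mark_as_done(request.id, response)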