@@ -1,57 +1,156 @@
+"""
+celery.contrib.batches
+======================
+
+Collect messages and process them as a list.
+
+**Example**
+
+A click counter that flushes the buffer every 100 messages, and every
+10 seconds.
+
+.. code-block:: python
+
+    from celery.task import task
+    from celery.contrib.batches import Batches
+
+    # Flush after 100 messages, or 10 seconds.
+    @task(base=Batches, flush_every=100, flush_interval=10)
+    def count_click(requests):
+        from collections import Counter
+        clicks = Counter(request.kwargs["url"] for request in requests)
+        for url, count in clicks.items():
+            print(">>> Clicks: %s -> %s" % (url, count))
+
+Registering the click is done as follows:
+
+    >>> count_click.delay(url="http://example.com")
+
+.. warning::
+
+    For this to work you have to set
+    :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or some value where
+    the final multiplied value is higher than ``flush_every``.
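+
+    A minimal sketch of that configuration (in a hypothetical
+    ``celeryconfig.py``; zero removes the prefetch limit entirely):
+
+    .. code-block:: python
+
+        CELERYD_PREFETCH_MULTIPLIER = 0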
+
+    In the future we hope to add the ability to direct batching tasks
+    to a channel with different QoS requirements than the task channel.
+
+:copyright: (c) 2009 - 2010 by Ask Solem.
+:license: BSD, see LICENSE for more details.
+
+"""
 from itertools import count
-from collections import deque
+from Queue import Queue
 
-from celery.task.base import Task
-from celery.utils.compat import defaultdict
+from celery.datastructures import consume_queue
+from celery.task import Task
+from celery.utils import timer2
+from celery.utils import cached_property
+from celery.worker import state
 
 
-class Batches(Task):
-    abstract = True
-    flush_every = 10
+class SimpleRequest(object):
+    """Pickleable request."""
 
-    def __init__(self):
-        self._buffer = deque()
-        self._count = count().next
+    #: task id
+    id = None
 
-    def execute(self, wrapper, pool, loglevel, logfile):
-        self._buffer.append((wrapper, pool, loglevel, logfile))
+    #: task name
+    name = None
 
-        if not self._count() % self.flush_every:
-            self.flush(self._buffer)
-            self._buffer.clear()
+    #: positional arguments
+    args = ()
+
+    #: keyword arguments
+    kwargs = {}
+
+    #: message delivery information.
+    delivery_info = None
+
+    #: worker node name
+    hostname = None
+
+    def __init__(self, id, name, args, kwargs, delivery_info, hostname):
+        self.id = id
+        self.name = name
+        self.args = args
+        self.kwargs = kwargs
+        self.delivery_info = delivery_info
+        self.hostname = hostname
 
-    def flush(self, tasks):
-        for wrapper, pool, loglevel, logfile in tasks:
-            wrapper.execute_using_pool(pool, loglevel, logfile)
+    @classmethod
+    def from_request(cls, request):
+        return cls(request.task_id, request.task_name, request.args,
+                   request.kwargs, request.delivery_info, request.hostname)
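+
+# Note: SimpleRequest carries only picklable fields copied from the
+# worker's request object, so a whole batch can be serialized and
+# handed to a pool worker process in a single apply_async() call.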
 
 
-class Counter(Task):
+class Batches(Task):
     abstract = True
+
+    #: Maximum number of messages in the buffer.
     flush_every = 10
 
+    #: Timeout in seconds before the buffer is flushed anyway.
+    flush_interval = 30
+
     def __init__(self):
-        self._buffer = deque()
-        self._count = count().next
+        self._buffer = Queue()
+        self._count = count(1).next
+        self._tref = None
+        self._pool = None
 
-    def execute(self, wrapper, pool, loglevel, logfile):
-        self._buffer.append((wrapper.args, wrapper.kwargs))
+    def run(self, requests):
+        raise NotImplementedError("%r must implement run(requests)" % (self, ))
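+        # (Subclasses override run(); it receives the list of
+        # SimpleRequest objects built by flush() below.)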
 
-        if not self._count() % self.flush_every:
-            self.flush(self._buffer)
-            self._buffer.clear()
+    def flush(self, requests):
+        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
+                                             for r in requests], ))
 
-    def flush(self, buffer):
-        raise NotImplementedError("Counters must implement 'flush'")
+    def execute(self, request, pool, loglevel, logfile, consumer):
+        if not self._pool:         # just take pool from first task.
+            self._pool = pool
 
+        state.task_ready(request)  # immediately remove from worker state.
+        self._buffer.put(request)
 
-class ClickCounter(Task):
-    flush_every = 1000
+        if self._tref is None:     # first request starts flush timer.
+            self._tref = timer2.apply_interval(self.flush_interval * 1000,
+                                               self._do_flush)
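+        # (timer2 intervals are given in milliseconds, hence the
+        # flush_interval * 1000 above.)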
 
-    def flush(self, buffer):
-        urlcount = defaultdict(lambda: 0)
-        for args, kwargs in buffer:
-            urlcount[kwargs["url"]] += 1
-
-        for url, count in urlcount.items():
-            print(">>> Clicks: %s -> %s" % (url, count))
-            # increment_in_db(url, n=count)
+        if not self._count() % self.flush_every:
+            self._do_flush()
+
+    def _do_flush(self):
+        self.debug("Wake-up to flush buffer...")
+        requests = None
+        if self._buffer.qsize():
+            requests = list(consume_queue(self._buffer))
+            if requests:
+                self.debug("Buffer complete: %s" % (len(requests), ))
+                self.flush(requests)
+        if not requests:
+            self.debug("Cancelling timer: Nothing in buffer.")
+            self._tref.cancel()  # cancel timer.
+            self._tref = None
+
+    def apply_buffer(self, requests, args=(), kwargs={}):
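+        # Partition requests into (early-ack, late-ack) lists: indexing
+        # the pair with the boolean acks_late flag appends each request
+        # to the matching list.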
+        acks_late = [], []
+        [acks_late[r.task.acks_late].append(r) for r in requests]
+        assert requests and (acks_late[True] or acks_late[False])
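+
+        # Early-ack requests are acknowledged as soon as the pool
+        # accepts the job; late-ack requests only once a result comes
+        # back, so on_return is attached as a callback only when needed.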
+
+        def on_accepted(pid, time_accepted):
+            [req.acknowledge() for req in acks_late[False]]
+
+        def on_return(result):
+            [req.acknowledge() for req in acks_late[True]]
+
+        return self._pool.apply_async(self, args,
+                                      accept_callback=on_accepted,
+                                      callbacks=acks_late[True] and [on_return] or [])
+
+    def debug(self, msg):
+        self.logger.debug("%s: %s" % (self.name, msg))
+
+    @cached_property
+    def logger(self):
+        return self.app.log.get_default_logger()