Working implementation of celery.contrib.batches

Ask Solem, 14 years ago
Parent commit 1d1ead532f

+ 136 - 37
celery/contrib/batches.py

@@ -1,57 +1,156 @@
+"""
+celery.contrib.batches
+======================
+
+Collect messages and process them as a list.
+
+**Example**
+
+A click counter that flushes the buffer every 100 messages, or every
+10 seconds.
+
+.. code-block:: python
+
+    from celery.task import task
+    from celery.contrib.batches import Batches
+
+    # Flush after 100 messages, or 10 seconds.
+    @task(base=Batches, flush_every=100, flush_interval=10)
+    def count_click(requests):
+        from collections import Counter
+        clicks = Counter(request.kwargs["url"] for request in requests)
+        for url, count in clicks.items():
+            print(">>> Clicks: %s -> %s" % (url, count))
+
+Registering a click is done as follows:
+
+    >>> count_click.delay(url="http://example.com")
+
+.. warning::
+
+    For this to work you have to set
+    :setting:`CELERYD_PREFETCH_MULTIPLIER` to zero, or to some value
+    where the resulting prefetch count is higher than ``flush_every``.
+
+    In the future we hope to add the ability to direct batching tasks
+    to a channel with different QoS requirements than the task channel.
+
+:copyright: (c) 2009 - 2010 by Ask Solem.
+:license: BSD, see LICENSE for more details.
+
+"""
 from itertools import count
-from collections import deque
+from Queue import Queue
 
-from celery.task.base import Task
-from celery.utils.compat import defaultdict
+from celery.datastructures import consume_queue
+from celery.task import Task
+from celery.utils import timer2
+from celery.utils import cached_property
+from celery.worker import state
 
 
-class Batches(Task):
-    abstract = True
-    flush_every = 10
+class SimpleRequest(object):
+    """Pickleable request."""
 
-    def __init__(self):
-        self._buffer = deque()
-        self._count = count().next
+    #: task id
+    id = None
 
-    def execute(self, wrapper, pool, loglevel, logfile):
-        self._buffer.append((wrapper, pool, loglevel, logfile))
+    #: task name
+    name = None
 
-        if not self._count() % self.flush_every:
-            self.flush(self._buffer)
-            self._buffer.clear()
+    #: positional arguments
+    args = ()
+
+    #: keyword arguments
+    kwargs = {}
+
+    #: message delivery information.
+    delivery_info = None
+
+    #: worker node name
+    hostname = None
+
+    def __init__(self, id, name, args, kwargs, delivery_info, hostname):
+        self.id = id
+        self.name = name
+        self.args = args
+        self.kwargs = kwargs
+        self.delivery_info = delivery_info
+        self.hostname = hostname
 
-    def flush(self, tasks):
-        for wrapper, pool, loglevel, logfile in tasks:
-            wrapper.execute_using_pool(pool, loglevel, logfile)
+    @classmethod
+    def from_request(cls, request):
+        return cls(request.task_id, request.task_name, request.args,
+                   request.kwargs, request.delivery_info, request.hostname)
 
 
-class Counter(Task):
+class Batches(Task):
     abstract = True
+
+    #: Maximum number of messages in the buffer.
     flush_every = 10
 
+    #: Timeout in seconds before the buffer is flushed anyway.
+    flush_interval = 30
+
     def __init__(self):
-        self._buffer = deque()
-        self._count = count().next
+        self._buffer = Queue()
+        self._count = count(1).next
+        self._tref = None
+        self._pool = None
 
-    def execute(self, wrapper, pool, loglevel, logfile):
-        self._buffer.append((wrapper.args, wrapper.kwargs))
+    def run(self, requests):
+        raise NotImplementedError("%r must implement run(requests)" % (self, ))
 
-        if not self._count() % self.flush_every:
-            self.flush(self._buffer)
-            self._buffer.clear()
+    def flush(self, requests):
+        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
+                                                for r in requests], ))
 
-    def flush(self, buffer):
-        raise NotImplementedError("Counters must implement 'flush'")
+    def execute(self, request, pool, loglevel, logfile, consumer):
+        if not self._pool:         # just take pool from first task.
+            self._pool = pool
 
+        state.task_ready(request)  # immediately remove from worker state.
+        self._buffer.put(request)
 
-class ClickCounter(Task):
-    flush_every = 1000
+        if self._tref is None:     # first request starts flush timer.
+            self._tref = timer2.apply_interval(self.flush_interval * 1000,
+                                               self._do_flush)
 
-    def flush(self, buffer):
-        urlcount = defaultdict(lambda: 0)
-        for args, kwargs in buffer:
-            urlcount[kwargs["url"]] += 1
-
-        for url, count in urlcount.items():
-            print(">>> Clicks: %s -> %s" % (url, count))
-            # increment_in_db(url, n=count)
+        if not self._count() % self.flush_every:
+            self._do_flush()
+
+    def _do_flush(self):
+        self.debug("Waking up to flush buffer...")
+        requests = None
+        if self._buffer.qsize():
+            requests = list(consume_queue(self._buffer))
+            if requests:
+                self.debug("Buffer complete: %s" % (len(requests), ))
+                self.flush(requests)
+        if not requests:
+            self.debug("Cancelling timer: Nothing in buffer.")
+            self._tref.cancel()  # cancel timer.
+            self._tref = None
+
+    def apply_buffer(self, requests, args=(), kwargs={}):
+        # Partition by each task's acks_late flag: index 0 collects
+        # requests acknowledged on accept, index 1 those acknowledged
+        # only after the batch returns.
+        acks_late = [], []
+        for request in requests:
+            acks_late[request.task.acks_late].append(request)
+        assert requests and (acks_late[True] or acks_late[False])
+
+        def on_accepted(pid, time_accepted):
+            # Early acks: acknowledge as soon as the pool accepts the batch.
+            for request in acks_late[False]:
+                request.acknowledge()
+
+        def on_return(result):
+            # Late acks: acknowledge only after the batch has returned.
+            for request in acks_late[True]:
+                request.acknowledge()
+
+        return self._pool.apply_async(self, args,
+                    accept_callback=on_accepted,
+                    callbacks=acks_late[True] and [on_return] or [])
+
+    def debug(self, msg):
+        self.logger.debug("%s: %s" % (self.name, msg))
+
+    @cached_property
+    def logger(self):
+        return self.app.log.get_default_logger()
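
For reference, a minimal end-to-end sketch of how the new module is meant to be used, combining the docstring example with the prefetch warning. The `tasks.py`/`celeryconfig.py` layout and variable names are illustrative, not part of this commit:

```python
# tasks.py -- based on the docstring example above.
from collections import Counter

from celery.task import task
from celery.contrib.batches import Batches

@task(base=Batches, flush_every=100, flush_interval=10)
def count_click(requests):
    # `requests` is a list of SimpleRequest objects, one per buffered call.
    clicks = Counter(request.kwargs["url"] for request in requests)
    for url, count in clicks.items():
        print(">>> Clicks: %s -> %s" % (url, count))

# celeryconfig.py -- per the warning in the docstring: with the default
# prefetch multiplier the worker may never buffer flush_every messages
# at once, so disable the limit (or keep it above flush_every).
CELERYD_PREFETCH_MULTIPLIER = 0
```

Clicks are then registered exactly as in the docstring, with `count_click.delay(url=...)`.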

+ 5 - 3
celery/task/base.py

@@ -665,16 +665,18 @@ class BaseTask(object):
         """
         pass
 
-    def execute(self, wrapper, pool, loglevel, logfile):
+    def execute(self, request, pool, loglevel, logfile, **kwargs):
         """The method the worker calls to execute the task.
 
-        :param wrapper: A :class:`~celery.worker.job.TaskRequest`.
+        :param request: A :class:`~celery.worker.job.TaskRequest`.
         :param pool: A task pool.
         :param loglevel: Current loglevel.
         :param logfile: Name of the currently used logfile.
 
+        :keyword consumer: The :class:`~celery.worker.consumer.Consumer`.
+
         """
-        wrapper.execute_using_pool(pool, loglevel, logfile)
+        request.execute_using_pool(pool, loglevel, logfile)
 
     def __repr__(self):
         """`repr(task)`"""

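The trailing `**kwargs` here keeps existing task classes working now that the worker passes the extra `consumer` keyword (see `process_task` below). A hypothetical subclass that wants the consumer can simply name the argument; this sketch assumes the consumer exposes its QoS instance as `consumer.qos`:

```python
from celery.task import Task

class ConsumerAwareTask(Task):
    """Hypothetical task base that peeks at the worker's consumer."""
    abstract = True

    def execute(self, request, pool, loglevel, logfile,
                consumer=None, **kwargs):
        if consumer is not None:
            # For example, inspect the current prefetch counter before
            # delegating to the normal pool execution path.
            print("prefetch count: %s" % (int(consumer.qos.value), ))
        request.execute_using_pool(pool, loglevel, logfile)
```
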
+ 4 - 3
celery/worker/__init__.py

@@ -258,11 +258,12 @@ class WorkController(object):
             self.stop()
             raise exc
 
-    def process_task(self, wrapper):
+    def process_task(self, request):
         """Process task by sending it to the pool of workers."""
         try:
-            wrapper.task.execute(wrapper, self.pool,
-                                 self.loglevel, self.logfile)
+            request.task.execute(request, self.pool,
+                                 self.loglevel, self.logfile,
+                                 consumer=self.consumer)
         except SystemTerminate:
             self.terminate()
             raise SystemExit()

+ 7 - 6
celery/worker/consumer.py

@@ -106,24 +106,25 @@ class QoS(object):
         self.logger = logger
         self.value = SharedCounter(initial_value)
 
-    def increment(self):
+    def increment(self, n=1):
         """Increment the current prefetch count value by one."""
         if int(self.value):
-            return self.set(self.value.increment())
+            return self.set(self.value.increment(n))
 
-    def decrement(self):
+    def decrement(self, n=1):
         """Decrement the current prefetch count value by one."""
         if int(self.value):
-            return self.set(self.value.decrement())
+            return self.set(self.value.decrement(n))
 
-    def decrement_eventually(self):
+    def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.
 
         The MainThread will be responsible for calling :meth:`update`
         when necessary.
 
         """
-        self.value.decrement()
+        if int(self.value):
+            self.value.decrement(n)
 
     def set(self, pcount):
         """Set channel prefetch_count setting."""

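All three `QoS` methods now take an `n`, so batch-aware code can account for a whole buffer in one call rather than looping, and `decrement_eventually` gains the same zero-guard as its siblings (a prefetch count of zero means "unlimited" and must not be decremented). A hypothetical fragment, where `qos` is an existing `QoS` instance and `requests` a just-flushed batch:

```python
# One call instead of len(requests) single-step decrements; as before,
# the actual basic_qos update is left to the MainThread via update().
qos.decrement_eventually(len(requests))
```
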
+ 12 - 0
docs/reference/celery.contrib.batches.rst

@@ -0,0 +1,12 @@
+.. currentmodule:: celery.contrib.batches
+
+.. automodule:: celery.contrib.batches
+
+    **API**
+
+    .. autoclass:: Batches
+        :members:
+        :undoc-members:
+    .. autoclass:: SimpleRequest
+        :members:
+        :undoc-members:

+ 2 - 1
docs/reference/index.rst

@@ -27,8 +27,9 @@
     celery.loaders.base
     celery.registry
     celery.states
-    celery.contrib.rdb
     celery.contrib.abortable
+    celery.contrib.batches
+    celery.contrib.rdb
     celery.events
     celery.events.state
     celery.apps.worker