Ask Solem 12 лет назад
Родитель
Сommit
8e254e0fe0
1 измененных файлов с 28 добавлено и 20 удалено
  1. 28 20
      celery/contrib/batches.py

+ 28 - 20
celery/contrib/batches.py

@@ -12,7 +12,7 @@ A click counter that flushes the buffer every 100 messages, and every
 
 .. code-block:: python
 
-    from celery.task import task
+    from celery import task
     from celery.contrib.batches import Batches
 
     # Flush after 100 messages, or 10 seconds.
@@ -43,9 +43,9 @@ from itertools import count
 from Queue import Empty, Queue
 
 from celery.task import Task
-from celery.utils import timer2
 from celery.utils.log import get_logger
 from celery.worker import state
+from celery.worker.job import Request
 
 
 logger = get_logger(__name__)
@@ -136,30 +136,39 @@ class Batches(Task):
         self._count = count(1).next
         self._tref = None
         self._pool = None
-        self._logging = None
 
     def run(self, requests):
         raise NotImplementedError('%r must implement run(requests)' % (self, ))
 
-    def flush(self, requests):
-        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
-                                                for r in requests], ))
+    def Strategy(self, task, app, consumer):
+        self._pool = consumer.pool
+        hostname = consumer.hostname
+        eventer = consumer.event_dispatcher
+        Req = Request
+        connection_errors = consumer.connection_errors
+        timer = consumer.timer
+        put_buffer = self._buffer.put
+        flush_buffer = self._do_flush
+
+        def task_message_handler(message, body, ack):
+            request = Req(body, on_ack=ack, app=app, hostname=hostname,
+                          events=eventer, task=task,
+                          connection_errors=connection_errors,
+                          delivery_info=message.delivery_info)
+            put_buffer(request)
 
-    def execute(self, request, pool, loglevel, logfile):
-        if not self._pool:         # just take pool from first task.
-            self._pool = pool
-        if not self._logging:
-            self._logging = loglevel, logfile
+            if self._tref is None:     # first request starts flush timer.
+                self._tref = timer.apply_interval(self.flush_interval * 1000.0,
+                                                  flush_buffer)
 
-        state.task_ready(request)  # immediately remove from worker state.
-        self._buffer.put(request)
+            if not self._count() % self.flush_every:
+                flush_buffer()
 
-        if self._tref is None:     # first request starts flush timer.
-            self._tref = timer2.apply_interval(self.flush_interval * 1000,
-                                               self._do_flush)
+        return task_message_handler
 
-        if not self._count() % self.flush_every:
-            self._do_flush()
+    def flush(self, requests):
+        return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
+                                                for r in requests], ))
 
     def _do_flush(self):
         logger.debug('Batches: Wake-up to flush buffer...')
@@ -185,8 +194,7 @@ class Batches(Task):
         def on_return(result):
             [req.acknowledge() for req in acks_late[True]]
 
-        loglevel, logfile = self._logging
         return self._pool.apply_async(apply_batches_task,
-                    (self, args, loglevel, logfile),
+                    (self, args, 0, None),
                     accept_callback=on_accepted,
                     callback=acks_late[True] and on_return or None)