|
@@ -15,6 +15,7 @@ from celery.conf import AMQP_CONNECTION_RETRY, AMQP_CONNECTION_MAX_RETRIES
|
|
|
from celery.log import setup_logger
|
|
|
from celery.pool import TaskPool
|
|
|
from celery.utils import retry_over_time
|
|
|
+from celery.datastructures import SharedCounter
|
|
|
from Queue import Queue
|
|
|
import traceback
|
|
|
import logging
|
|
@@ -43,12 +44,14 @@ class AMQPListener(object):
|
|
|
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, bucket_queue, hold_queue, logger):
|
|
|
+ def __init__(self, bucket_queue, hold_queue, logger,
|
|
|
+ initial_prefetch_count=2):
|
|
|
self.amqp_connection = None
|
|
|
self.task_consumer = None
|
|
|
self.bucket_queue = bucket_queue
|
|
|
self.hold_queue = hold_queue
|
|
|
self.logger = logger
|
|
|
+ self.prefetch_count = SharedCounter(initial_prefetch_count)
|
|
|
|
|
|
def start(self):
|
|
|
"""Start the consumer.
|
|
@@ -77,6 +80,7 @@ class AMQPListener(object):
|
|
|
self.logger.debug("AMQPListener: Ready to accept tasks!")
|
|
|
|
|
|
while True:
|
|
|
+ self.task_consumer.qos(prefetch_count=int(self.prefetch_count))
|
|
|
it.next()
|
|
|
|
|
|
def stop(self):
|
|
@@ -102,12 +106,15 @@ class AMQPListener(object):
|
|
|
if eta:
|
|
|
self.logger.info("Got task from broker: %s[%s] eta:[%s]" % (
|
|
|
task.task_name, task.task_id, eta))
|
|
|
- self.hold_queue.put((task, eta))
|
|
|
+ self.hold_queue.put((task, eta, self.prefetch_count.decrement))
|
|
|
+ self.prefetch_count.increment()
|
|
|
else:
|
|
|
+ self.prefetch_count.decrement()
|
|
|
self.logger.info("Got task from broker: %s[%s]" % (
|
|
|
task.task_name, task.task_id))
|
|
|
self.bucket_queue.put(task)
|
|
|
|
|
|
+
|
|
|
def close_connection(self):
|
|
|
"""Close the AMQP connection."""
|
|
|
if self.task_consumer:
|
|
@@ -246,7 +253,8 @@ class WorkController(object):
|
|
|
self.hold_queue)
|
|
|
self.pool = TaskPool(self.concurrency, logger=self.logger)
|
|
|
self.amqp_listener = AMQPListener(self.bucket_queue, self.hold_queue,
|
|
|
- logger=self.logger)
|
|
|
+ logger=self.logger,
|
|
|
+ initial_prefetch_count=concurrency)
|
|
|
self.mediator = Mediator(self.bucket_queue, self.safe_process_task)
|
|
|
|
|
|
# The order is important here;
|