Browse Source

moving initial_prefetch_count to consumer from tasks step

Dan 11 years ago
parent
commit
551546ceae
2 changed files with 6 additions and 5 deletions
  1. 1 1
      celery/tests/worker/test_consumer.py
  2. 5 4
      celery/worker/consumer.py

+ 1 - 1
celery/tests/worker/test_consumer.py

@@ -182,7 +182,7 @@ class test_Tasks(AppCase):
         tasks = Tasks(c)
         self.assertIsNone(c.task_consumer)
         self.assertIsNone(c.qos)
-        self.assertEqual(tasks.initial_prefetch_count, 2)
+        self.assertEqual(c.initial_prefetch_count, 2)
 
         c.task_consumer = Mock()
         tasks.stop(c)

+ 5 - 4
celery/worker/consumer.py

@@ -165,7 +165,8 @@ class Consumer(object):
                  init_callback=noop, hostname=None,
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
-                 worker_options=None, disable_rate_limits=False, **kwargs):
+                 worker_options=None, disable_rate_limits=False, 
+                 initial_prefetch_count=2, **kwargs):
         self.app = app
         self.controller = controller
         self.init_callback = init_callback
@@ -184,6 +185,7 @@ class Consumer(object):
         self.on_task_message = set()
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
         self.disable_rate_limits = disable_rate_limits
+        self.initial_prefetch_count = initial_prefetch_count
 
         # this contains a tokenbucket for each task type by name, used for
         # rate limits, or None if rate limits are disabled for that task.
@@ -506,16 +508,15 @@ class Control(bootsteps.StartStopStep):
 class Tasks(bootsteps.StartStopStep):
     requires = (Events, )
 
-    def __init__(self, c, initial_prefetch_count=2, **kwargs):
+    def __init__(self, c, **kwargs):
         c.task_consumer = c.qos = None
-        self.initial_prefetch_count = initial_prefetch_count
 
     def start(self, c):
         c.update_strategies()
         c.task_consumer = c.app.amqp.TaskConsumer(
             c.connection, on_decode_error=c.on_decode_error,
         )
-        c.qos = QoS(c.task_consumer.qos, self.initial_prefetch_count)
+        c.qos = QoS(c.task_consumer.qos, c.initial_prefetch_count)
         c.qos.update()  # set initial prefetch count
 
     def stop(self, c):