Browse Source

More fixes for #1570

Ask Solem 11 years ago
parent
commit
57cdafab6f

+ 2 - 1
celery/bin/celery.py

@@ -424,7 +424,8 @@ class control(_RemoteControl):
 
     def time_limit(self, method, task_name, soft, hard=None, **kwargs):
         """<task_name> <soft_secs> [hard_secs]"""
-        return self.call(method, task_name, float(soft), float(hard), reply=True, **kwargs)
+        return self.call(method, task_name,
+                         float(soft), float(hard), reply=True, **kwargs)
 
     def add_consumer(self, method, queue, exchange=None,
                      exchange_type='direct', routing_key=None, **kwargs):

+ 6 - 2
celery/worker/autoscale.py

@@ -49,7 +49,7 @@ class WorkerComponent(bootsteps.StartStopStep):
         scaler = w.autoscaler = self.instantiate(
             w.autoscaler_cls,
             w.pool, w.max_concurrency, w.min_concurrency,
-            mutex=DummyLock() if w.use_eventloop else None,
+            worker=w, mutex=DummyLock() if w.use_eventloop else None,
         )
         return scaler if not w.use_eventloop else None
 
@@ -63,7 +63,8 @@ class WorkerComponent(bootsteps.StartStopStep):
 class Autoscaler(bgThread):
 
     def __init__(self, pool, max_concurrency,
-                 min_concurrency=0, keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
+                 min_concurrency=0, worker=None,
+                 keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
         super(Autoscaler, self).__init__()
         self.pool = pool
         self.mutex = mutex or threading.Lock()
@@ -71,6 +72,7 @@ class Autoscaler(bgThread):
         self.min_concurrency = min_concurrency
         self.keepalive = keepalive
         self._last_action = None
+        self.worker = worker
 
         assert self.keepalive, 'cannot scale down too fast.'
 
@@ -133,6 +135,7 @@ class Autoscaler(bgThread):
     def _grow(self, n):
         info('Scaling up %s processes.', n)
         self.pool.grow(n)
+        self.worker.consumer.increment_prefetch_count(n, True)
 
     def _shrink(self, n):
         info('Scaling down %s processes.', n)
@@ -142,6 +145,7 @@ class Autoscaler(bgThread):
             debug("Autoscaler won't scale down: all processes busy.")
         except Exception as exc:
             error('Autoscaler: scale_down: %r', exc, exc_info=True)
+        self.worker.consumer.decrement_prefetch_count(n, True)
 
     def info(self):
         return {'max': self.max_concurrency,

+ 5 - 1
celery/worker/components.py

@@ -212,7 +212,10 @@ class Consumer(bootsteps.StartStopStep):
     last = True
 
     def create(self, w):
-        prefetch_count = w.concurrency * w.prefetch_multiplier
+        if w.max_concurrency:
+            prefetch_count = max(w.min_concurrency, 1) * w.prefetch_multiplier
+        else:
+            prefetch_count = w.concurrency * w.prefetch_multiplier
         c = w.consumer = self.instantiate(
             w.consumer_cls, w.process_task,
             hostname=w.hostname,
@@ -226,5 +229,6 @@ class Consumer(bootsteps.StartStopStep):
             hub=w.hub,
             worker_options=w.options,
             disable_rate_limits=w.disable_rate_limits,
+            prefetch_multiplier=w.prefetch_multiplier,
         )
         return c

+ 13 - 6
celery/worker/consumer.py

@@ -166,7 +166,7 @@ class Consumer(object):
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
                  worker_options=None, disable_rate_limits=False,
-                 initial_prefetch_count=2, **kwargs):
+                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
         self.app = app
         self.controller = controller
         self.init_callback = init_callback
@@ -186,6 +186,7 @@ class Consumer(object):
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
         self.disable_rate_limits = disable_rate_limits
         self.initial_prefetch_count = initial_prefetch_count
+        self.prefetch_multiplier = prefetch_multiplier
 
         # this contains a tokenbucket for each task type by name, used for
         # rate limits, or None if rate limits are disabled for that task.
@@ -225,28 +226,35 @@ class Consumer(object):
             (n, self.bucket_for_task(t)) for n, t in items(self.app.tasks)
         )
 
-    def increment_prefetch_count(self, n=1):
+    def increment_prefetch_count(self, n=1, use_multiplier=False):
         """Increase the prefetch count by ``n``.
 
         This will also increase the initial value so it'll persist between
         consumer restarts.  If you want the change to be temporary,
         you can use ``self.qos.increment_eventually(n)`` instead.
 
+        :keyword use_multiplier: If True the value will be multiplied
+            using the current prefetch multiplier setting.
+
         """
+        n = n * self.prefetch_multiplier if use_multiplier else n
         # initial value must be changed for consumer restart.
         if self.initial_prefetch_count:
             # only increase if prefetch enabled (>0)
             self.initial_prefetch_count += n
         self.qos.increment_eventually(n)
 
-    def decrement_prefetch_count(self, n=1):
+    def decrement_prefetch_count(self, n=1, use_multiplier=False):
         """Decrease prefetch count by ``n``.
 
         This will also decrease the initial value so it'll persist between
         consumer restarts.  If you want the change to be temporary,
         you can use ``self.qos.decrement_eventually(n)`` instead.
 
+        :keyword use_multiplier: If True the value will be multiplied
+            using the current prefetch multiplier setting.
         """
+        n = n * self.prefetch_multiplier if use_multiplier else n
         initial = self.initial_prefetch_count
         if initial:  # was not disabled (>0)
             # must not get lower than 1, since that will disable the limit.
@@ -431,8 +439,6 @@ class Consumer(object):
         callbacks = self.on_task_message
 
         def on_task_received(body, message):
-            if callbacks:
-                [callback() for callback in callbacks]
             try:
                 name = body['task']
             except (KeyError, TypeError):
@@ -441,7 +447,8 @@ class Consumer(object):
             try:
                 strategies[name](message, body,
                                  message.ack_log_error,
-                                 message.reject_log_error)
+                                 message.reject_log_error,
+                                 callbacks)
             except KeyError as exc:
                 on_unknown_task(body, message, exc)
             except InvalidTaskError as exc:

+ 2 - 2
celery/worker/control.py

@@ -283,7 +283,7 @@ def pool_grow(state, n=1, **kwargs):
         state.consumer.controller.autoscaler.force_scale_up(n)
     else:
         state.consumer.pool.grow(n)
-    state.consumer.increment_prefetch_count(n)
+        state.consumer.increment_prefetch_count(n, True)
     return {'ok': 'pool will grow'}
 
 
@@ -293,7 +293,7 @@ def pool_shrink(state, n=1, **kwargs):
         state.consumer.controller.autoscaler.force_scale_down(n)
     else:
         state.consumer.pool.shrink(n)
-    state.consumer.decrement_prefetch_count(n)
+        state.consumer.decrement_prefetch_count(n, True)
     return {'ok': 'pool will shrink'}
 
 

+ 3 - 1
celery/worker/strategy.py

@@ -41,7 +41,7 @@ def default(task, app, consumer,
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
 
-    def task_message_handler(message, body, ack, reject,
+    def task_message_handler(message, body, ack, reject, callbacks,
                              to_timestamp=to_timestamp):
         req = Req(body, on_ack=ack, on_reject=reject,
                   app=app, hostname=hostname,
@@ -82,6 +82,8 @@ def default(task, app, consumer,
                 if bucket:
                     return limit_task(req, bucket, 1)
             task_reserved(req)
+            if callbacks:
+                [callback() for callback in callbacks]
             handle(req)
 
     return task_message_handler