
More fixes for #1570

Ask Solem, 11 years ago
commit 57cdafab6f

+ 2 - 1
celery/bin/celery.py

@@ -424,7 +424,8 @@ class control(_RemoteControl):

     def time_limit(self, method, task_name, soft, hard=None, **kwargs):
         """<task_name> <soft_secs> [hard_secs]"""
-        return self.call(method, task_name, float(soft), float(hard), reply=True, **kwargs)
+        return self.call(method, task_name,
+                         float(soft), float(hard), reply=True, **kwargs)

     def add_consumer(self, method, queue, exchange=None,
                      exchange_type='direct', routing_key=None, **kwargs):

+ 6 - 2
celery/worker/autoscale.py

@@ -49,7 +49,7 @@ class WorkerComponent(bootsteps.StartStopStep):
         scaler = w.autoscaler = self.instantiate(
             w.autoscaler_cls,
             w.pool, w.max_concurrency, w.min_concurrency,
-            mutex=DummyLock() if w.use_eventloop else None,
+            worker=w, mutex=DummyLock() if w.use_eventloop else None,
         )
         return scaler if not w.use_eventloop else None

@@ -63,7 +63,8 @@ class WorkerComponent(bootsteps.StartStopStep):
 class Autoscaler(bgThread):

     def __init__(self, pool, max_concurrency,
-                 min_concurrency=0, keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
+                 min_concurrency=0, worker=None,
+                 keepalive=AUTOSCALE_KEEPALIVE, mutex=None):
         super(Autoscaler, self).__init__()
         self.pool = pool
         self.mutex = mutex or threading.Lock()
@@ -71,6 +72,7 @@ class Autoscaler(bgThread):
         self.min_concurrency = min_concurrency
         self.keepalive = keepalive
         self._last_action = None
+        self.worker = worker

         assert self.keepalive, 'cannot scale down too fast.'

@@ -133,6 +135,7 @@ class Autoscaler(bgThread):
     def _grow(self, n):
         info('Scaling up %s processes.', n)
         self.pool.grow(n)
+        self.worker.consumer.increment_prefetch_count(n, True)

     def _shrink(self, n):
         info('Scaling down %s processes.', n)
@@ -142,6 +145,7 @@ class Autoscaler(bgThread):
             debug("Autoscaler won't scale down: all processes busy.")
         except Exception as exc:
             error('Autoscaler: scale_down: %r', exc, exc_info=True)
+        self.worker.consumer.decrement_prefetch_count(n, True)

     def info(self):
         return {'max': self.max_concurrency,
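
Taken together, the autoscale.py hunks give the Autoscaler a reference to the worker so that every pool resize also resizes the consumer's prefetch window. Below is a minimal runnable sketch of that coupling; Pool, Consumer and Worker are throwaway stubs made up for illustration, and only the Autoscaler methods mirror the hunks above.

    # Sketch: scaling the pool now also scales the prefetch count through
    # worker.consumer. These stubs are not Celery classes.

    class Pool(object):
        procs = 3
        def grow(self, n): self.procs += n
        def shrink(self, n): self.procs -= n

    class Consumer(object):
        prefetch_multiplier = 4          # Celery's default multiplier
        prefetch = 3 * 4                 # 3 processes * multiplier

        def increment_prefetch_count(self, n, use_multiplier=False):
            self.prefetch += n * self.prefetch_multiplier if use_multiplier else n

        def decrement_prefetch_count(self, n, use_multiplier=False):
            self.prefetch -= n * self.prefetch_multiplier if use_multiplier else n

    class Worker(object):
        def __init__(self):
            self.pool, self.consumer = Pool(), Consumer()

    class Autoscaler(object):
        def __init__(self, pool, worker=None):    # worker is the new argument
            self.pool, self.worker = pool, worker

        def _grow(self, n):
            self.pool.grow(n)
            self.worker.consumer.increment_prefetch_count(n, True)

        def _shrink(self, n):
            self.pool.shrink(n)
            self.worker.consumer.decrement_prefetch_count(n, True)

    w = Worker()
    Autoscaler(w.pool, worker=w)._grow(2)   # pool: 5 procs, prefetch: 12 + 2*4 = 20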

+ 5 - 1
celery/worker/components.py

@@ -212,7 +212,10 @@ class Consumer(bootsteps.StartStopStep):
     last = True

     def create(self, w):
-        prefetch_count = w.concurrency * w.prefetch_multiplier
+        if w.max_concurrency:
+            prefetch_count = max(w.min_concurrency, 1) * w.prefetch_multiplier
+        else:
+            prefetch_count = w.concurrency * w.prefetch_multiplier
         c = w.consumer = self.instantiate(
             w.consumer_cls, w.process_task,
             hostname=w.hostname,
@@ -226,5 +229,6 @@ class Consumer(bootsteps.StartStopStep):
             hub=w.hub,
             worker_options=w.options,
             disable_rate_limits=w.disable_rate_limits,
+            prefetch_multiplier=w.prefetch_multiplier,
         )
         return c
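
The components.py change makes the consumer's starting prefetch follow the autoscale floor rather than the full concurrency whenever w.max_concurrency is set. A hedged sketch of just that computation; the helper name is illustrative, the `w` attribute names follow the hunk, and the example assumes the default prefetch multiplier of 4:

    # Initial prefetch in isolation. With --autoscale=10,3 and the default
    # multiplier of 4 the worker now starts at max(3, 1) * 4 = 12 and lets the
    # autoscaler raise it as processes are added, instead of deriving it from
    # w.concurrency up front.

    def initial_prefetch_count(w):
        if w.max_concurrency:                       # --autoscale=<max>,<min>
            return max(w.min_concurrency, 1) * w.prefetch_multiplier
        return w.concurrency * w.prefetch_multiplier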

+ 13 - 6
celery/worker/consumer.py

@@ -166,7 +166,7 @@ class Consumer(object):
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
                  worker_options=None, disable_rate_limits=False,
-                 initial_prefetch_count=2, **kwargs):
+                 initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
         self.app = app
         self.controller = controller
         self.init_callback = init_callback
@@ -186,6 +186,7 @@ class Consumer(object):
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
         self.disable_rate_limits = disable_rate_limits
         self.initial_prefetch_count = initial_prefetch_count
+        self.prefetch_multiplier = prefetch_multiplier

         # this contains a tokenbucket for each task type by name, used for
         # rate limits, or None if rate limits are disabled for that task.
@@ -225,28 +226,35 @@ class Consumer(object):
             (n, self.bucket_for_task(t)) for n, t in items(self.app.tasks)
         )

-    def increment_prefetch_count(self, n=1):
+    def increment_prefetch_count(self, n=1, use_multiplier=False):
         """Increase the prefetch count by ``n``.

         This will also increase the initial value so it'll persist between
         consumer restarts.  If you want the change to be temporary,
         you can use ``self.qos.increment_eventually(n)`` instead.

+        :keyword use_multiplier: If True the value will be multiplied
+            using the current prefetch multiplier setting.
+
         """
+        n = n * self.prefetch_multiplier if use_multiplier else n
         # initial value must be changed for consumer restart.
         if self.initial_prefetch_count:
             # only increase if prefetch enabled (>0)
             self.initial_prefetch_count += n
         self.qos.increment_eventually(n)

-    def decrement_prefetch_count(self, n=1):
+    def decrement_prefetch_count(self, n=1, use_multiplier=False):
         """Decrease prefetch count by ``n``.

         This will also decrease the initial value so it'll persist between
         consumer restarts.  If you want the change to be temporary,
         you can use ``self.qos.decrement_eventually(n)`` instead.

+        :keyword use_multiplier: If True the value will be multiplied
+            using the current prefetch multiplier setting.
         """
+        n = n * self.prefetch_multiplier if use_multiplier else n
         initial = self.initial_prefetch_count
         if initial:  # was not disabled (>0)
             # must not get lower than 1, since that will disable the limit.
@@ -431,8 +439,6 @@ class Consumer(object):
         callbacks = self.on_task_message

         def on_task_received(body, message):
-            if callbacks:
-                [callback() for callback in callbacks]
             try:
                 name = body['task']
             except (KeyError, TypeError):
@@ -441,7 +447,8 @@ class Consumer(object):
             try:
                 strategies[name](message, body,
                                  message.ack_log_error,
-                                 message.reject_log_error)
+                                 message.reject_log_error,
+                                 callbacks)
             except KeyError as exc:
                 on_unknown_task(body, message, exc)
             except InvalidTaskError as exc:
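
The new use_multiplier flag lets callers that add or remove whole pool processes move the QoS window in steps of prefetch_multiplier per process. A small sketch of the arithmetic; the _QoS stub stands in for self.qos, and the rest mirrors the increment method above:

    # use_multiplier in isolation: n is scaled by prefetch_multiplier before
    # both the persistent initial value and the live QoS are adjusted.

    class _QoS(object):                  # stand-in for self.qos
        value = 12
        def increment_eventually(self, n): self.value += n

    class Consumer(object):
        prefetch_multiplier = 4
        initial_prefetch_count = 12
        qos = _QoS()

        def increment_prefetch_count(self, n=1, use_multiplier=False):
            n = n * self.prefetch_multiplier if use_multiplier else n
            if self.initial_prefetch_count:      # only if prefetch enabled (>0)
                self.initial_prefetch_count += n
            self.qos.increment_eventually(n)

    c = Consumer()
    c.increment_prefetch_count(2, True)   # 2 new processes: prefetch +2*4
    assert c.initial_prefetch_count == 20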

+ 2 - 2
celery/worker/control.py

@@ -283,7 +283,7 @@ def pool_grow(state, n=1, **kwargs):
         state.consumer.controller.autoscaler.force_scale_up(n)
     else:
         state.consumer.pool.grow(n)
-    state.consumer.increment_prefetch_count(n)
+        state.consumer.increment_prefetch_count(n, True)
     return {'ok': 'pool will grow'}


@@ -293,7 +293,7 @@ def pool_shrink(state, n=1, **kwargs):
         state.consumer.controller.autoscaler.force_scale_down(n)
     else:
         state.consumer.pool.shrink(n)
-    state.consumer.decrement_prefetch_count(n)
+        state.consumer.decrement_prefetch_count(n, True)
     return {'ok': 'pool will shrink'}
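
In control.py the prefetch adjustment moves inside the else branch: when an autoscaler is present, force_scale_up/force_scale_down end up in Autoscaler._grow/_shrink, which now adjust the prefetch themselves, so doing it here as well would double-count. A sketch of pool_grow after the fix; the surrounding if guard is implied by the hunk context rather than shown in it:

    # pool_grow reassembled from the hunk: the autoscaler branch no longer
    # touches the prefetch count directly, Autoscaler._grow() does (see
    # autoscale.py above).

    def pool_grow(state, n=1, **kwargs):
        if state.consumer.controller.autoscaler:
            state.consumer.controller.autoscaler.force_scale_up(n)
        else:
            state.consumer.pool.grow(n)
            state.consumer.increment_prefetch_count(n, True)
        return {'ok': 'pool will grow'}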
 
 
 
 

+ 3 - 1
celery/worker/strategy.py

@@ -41,7 +41,7 @@ def default(task, app, consumer,
     handle = consumer.on_task_request
     limit_task = consumer._limit_task

-    def task_message_handler(message, body, ack, reject,
+    def task_message_handler(message, body, ack, reject, callbacks,
                              to_timestamp=to_timestamp):
         req = Req(body, on_ack=ack, on_reject=reject,
                   app=app, hostname=hostname,
@@ -82,6 +82,8 @@ def default(task, app, consumer,
                 if bucket:
                     return limit_task(req, bucket, 1)
             task_reserved(req)
+            if callbacks:
+                [callback() for callback in callbacks]
             handle(req)

     return task_message_handler
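
Finally, the strategy.py change moves the consumer's on_task_message callbacks from "fired for every message the consumer receives" (the block removed from consumer.py above) to "fired once a request has actually been reserved", right before it is handed to the worker. A toy illustration of the new timing; the callback itself is made up, only the call site mirrors the hunk:

    # The callbacks list is threaded through the strategy and invoked only
    # after task_reserved(req), so messages rejected as unknown or invalid,
    # and requests diverted to the rate-limit bucket, skip it at this point.

    seen = []
    callbacks = [lambda: seen.append('task accepted')]   # made-up callback

    # ...inside task_message_handler, once the request is reserved:
    if callbacks:
        [callback() for callback in callbacks]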