
Merge branch 'autoscale'

Conflicts:
	celery/apps/worker.py
	celery/conf.py
	celery/worker/__init__.py
	celery/worker/controllers.py
	celery/worker/listener.py
Ask Solem 14 years ago
commit ba8b38c848

+ 1 - 0
celery/app/defaults.py

@@ -96,6 +96,7 @@ NAMESPACES = {
         "REDIRECT_STDOUTS_LEVEL": Option("WARNING"),
         "REDIRECT_STDOUTS_LEVEL": Option("WARNING"),
     },
     },
     "CELERYD": {
     "CELERYD": {
+        "AUTOSCALER": Option("celery.worker.controllers.Autoscaler"),
         "CONCURRENCY": Option(0, type="int"),
         "CONCURRENCY": Option(0, type="int"),
         "ETA_SCHEDULER": Option("celery.utils.timer2.Timer"),
         "ETA_SCHEDULER": Option("celery.utils.timer2.Timer"),
         "ETA_SCHEDULER_PRECISION": Option(1.0, type="float"),
         "ETA_SCHEDULER_PRECISION": Option(1.0, type="float"),

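The new CELERYD_AUTOSCALER setting stores the autoscaler implementation as a dotted class path, so the default can be swapped out from configuration. A minimal sketch, assuming celery 2.x settings conventions (the replacement class name is hypothetical, not part of this commit):

    # celeryconfig.py -- point the worker at a custom autoscaler.
    # "myapp.scalers.ConservativeAutoscaler" is a made-up example; any
    # class with the Autoscaler interface introduced below would do.
    CELERYD_AUTOSCALER = "myapp.scalers.ConservativeAutoscaler"
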
+ 10 - 2
celery/apps/worker.py

@@ -6,6 +6,8 @@ import socket
 import sys
 import warnings
 
+from carrot.utils import partition
+
 from celery import __version__
 from celery import platforms
 from celery import signals
@@ -40,7 +42,8 @@ class Worker(object):
             schedule=None, task_time_limit=None, task_soft_time_limit=None,
             max_tasks_per_child=None, queues=None, events=False, db=None,
             include=None, app=None, pidfile=None,
-            redirect_stdouts=None, redirect_stdouts_level=None, **kwargs):
+            redirect_stdouts=None, redirect_stdouts_level=None,
+            autoscale=None, **kwargs):
         self.app = app = app_or_default(app)
         self.concurrency = (concurrency or
                             app.conf.CELERYD_CONCURRENCY or
@@ -67,6 +70,10 @@ class Worker(object):
         self.queues = None
         self.include = include or []
         self.pidfile = pidfile
+        self.autoscale = None
+        if autoscale:
+            max_c, _, min_c = partition(autoscale, ",")
+            self.autoscale = [int(max_c), min_c and int(min_c) or 0]
         self._isatty = sys.stdout.isatty()
         self.colored = term.colored(enabled=app.conf.CELERYD_LOG_COLOR)
 
@@ -194,7 +201,8 @@ class Worker(object):
                                 queues=self.queues,
                                 max_tasks_per_child=self.max_tasks_per_child,
                                 task_time_limit=self.task_time_limit,
-                                task_soft_time_limit=self.task_soft_time_limit)
+                                task_soft_time_limit=self.task_soft_time_limit,
+                                autoscale=self.autoscale)
         self.install_platform_tweaks(worker)
         worker.start()
 

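The --autoscale value is split on the first comma, and the minimum defaults to 0 when omitted. A short sketch of the parsing above, assuming carrot's partition() behaves like Python's built-in str.partition():

    # Parse "max[,min]" the same way the constructor above does.
    def parse_autoscale(value):
        max_c, _, min_c = value.partition(",")
        return [int(max_c), int(min_c) if min_c else 0]

    parse_autoscale("10,3")  # -> [10, 3]
    parse_autoscale("10")    # -> [10, 0]  (minimum defaults to 0)
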
+ 5 - 0
celery/bin/celeryd.py

@@ -150,6 +150,11 @@ class WorkerCommand(Command):
                 help="Optional file used to store the workers pid. "
                      "The worker will not start if this file already exists "
                      "and the pid is still alive."),
+            Option('--autoscale', default=None,
+                help="Enable autoscaling by providing "
+                     "max_concurrency,min_concurrency. Example: "
+                     "--autoscale=10,3 (always keep 3 processes, "
+                     "but grow to 10 if necessary)."),
         )
 
 

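With the option wired up, the worker can be started with the bounds from the help text above:

    celeryd --autoscale=10,3

which keeps at least 3 pool processes around and grows to at most 10 when the backlog of reserved tasks demands it.
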
+ 6 - 0
celery/concurrency/processes/__init__.py

@@ -102,6 +102,12 @@ class TaskPool(object):
                                       error_callback=on_worker_error,
                                       waitforslot=self.putlocks)
 
+    def grow(self, n=1):
+        return self._pool.grow(n)
+
+    def shrink(self, n=1):
+        return self._pool.shrink(n)
+
     def on_worker_error(self, errbacks, exc):
         einfo = ExceptionInfo((exc.__class__, exc, None))
         [errback(einfo) for errback in errbacks]

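TaskPool.grow() and shrink() are thin facades over the multiprocessing pool extended below. A hedged usage sketch, given a running TaskPool instance named pool:

    pool.grow(2)    # spawn two additional worker processes
    pool.shrink(1)  # terminate one inactive process; raises ValueError
                    # when every process is busy (see Pool.shrink below)
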
+ 42 - 4
celery/concurrency/processes/pool.py

@@ -81,8 +81,8 @@ def soft_timeout_sighandler(signum, frame):
 
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
-    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     pid = os.getpid()
+    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
     get = inqueue.get
 
@@ -108,6 +108,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     if SIG_SOFT_TIMEOUT is not None:
         signal.signal(SIG_SOFT_TIMEOUT, soft_timeout_sighandler)
 
+
     completed = 0
     while maxtasks is None or (maxtasks and completed < maxtasks):
         try:
@@ -137,7 +138,6 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
-
 #
 # Class representing a process pool
 #
@@ -527,6 +527,44 @@ class Pool(object):
             return True
         return False
 
+    def shrink(self, n=1):
+        for i, worker in enumerate(self._iterinactive()):
+            self._processes -= 1
+            if self._putlock:
+                self._putlock._initial_value -= 1
+                self._putlock.acquire()
+            worker.terminate()
+            if i == n - 1:
+                return
+        raise ValueError("Can't shrink pool. All processes busy!")
+
+    def grow(self, n=1):
+        for i in xrange(n):
+            #assert len(self._pool) == self._processes
+            self._processes += 1
+            if self._putlock:
+                cond = self._putlock._Semaphore__cond
+                cond.acquire()
+                try:
+                    self._putlock._initial_value += 1
+                    self._putlock._Semaphore__value += 1
+                    cond.notify()
+                finally:
+                    cond.release()
+
+    def _iterinactive(self):
+        for worker in self._pool:
+            if not self._worker_active(worker):
+                yield worker
+        raise StopIteration()
+
+    def _worker_active(self, worker):
+        jobs = []
+        for job in self._cache.values():
+            if worker.pid in job.worker_pids():
+                return True
+        return False
+
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
@@ -541,8 +579,8 @@ class Pool(object):
     def _maintain_pool(self):
         """"Clean up any exited workers and start replacements for them.
         """
-        if self._join_exited_workers():
-            self._repopulate_pool()
+        self._join_exited_workers()
+        self._repopulate_pool()
 
     def _setup_queues(self):
         from multiprocessing.queues import SimpleQueue

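Two details in this file carry the weight. First, grow() reaches into the name-mangled internals of the putlock (a Python 2 threading.Semaphore) because it must raise both the counter and the recorded _initial_value under the same condition lock; shrink() does the inverse with acquire(). Expressed with the public API, the grow() side is roughly this sketch (not the committed code):

    # release() takes the condition, increments the value and notifies a
    # waiter: the same steps grow() performs by hand so it can also bump
    # _initial_value (bookkeeping the pool itself keeps on the putlock).
    def widen_putlock(putlock):
        putlock._initial_value += 1
        putlock.release()

Second, _maintain_pool() now repopulates unconditionally: grow() only raises self._processes, and it is the supervisor's next _maintain_pool() call that notices the gap and actually forks the new worker processes, instead of only doing so after an existing worker exits.
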
+ 18 - 2
celery/worker/__init__.py

@@ -120,7 +120,8 @@ class WorkController(object):
             task_soft_time_limit=None, max_tasks_per_child=None,
             pool_putlocks=None, db=None, prefetch_multiplier=None,
             eta_scheduler_precision=None, queues=None,
-            disable_rate_limits=None, app=None):
+            disable_rate_limits=None, autoscale=None,
+            autoscaler_cls=None, app=None):
 
         self.app = app_or_default(app)
         conf = self.app.conf
@@ -138,6 +139,8 @@ class WorkController(object):
         self.mediator_cls = mediator_cls or conf.CELERYD_MEDIATOR
         self.eta_scheduler_cls = eta_scheduler_cls or \
                                     conf.CELERYD_ETA_SCHEDULER
+        self.autoscaler_cls = autoscaler_cls or \
+                                    conf.CELERYD_AUTOSCALER
         self.schedule_filename = schedule_filename or \
                                     conf.CELERYBEAT_SCHEDULE_FILENAME
         self.hostname = hostname or socket.gethostname()
@@ -178,7 +181,13 @@ class WorkController(object):
         self.logger.debug("Instantiating thread components...")
 
         # Threads + Pool + Consumer
-        self.pool = instantiate(self.pool_cls, self.concurrency,
+        self.autoscaler = None
+        max_concurrency = None
+        min_concurrency = concurrency
+        if autoscale:
+            max_concurrency, min_concurrency = autoscale
+
+        self.pool = instantiate(self.pool_cls, min_concurrency,
                                 logger=self.logger,
                                 initializer=process_initializer,
                                 initargs=(self.app, self.hostname),
@@ -187,6 +196,12 @@ class WorkController(object):
                                 soft_timeout=self.task_soft_time_limit,
                                 putlocks=self.pool_putlocks)
 
+        if autoscale:
+            self.autoscaler = instantiate(self.autoscaler_cls, self.pool,
+                                          max_concurrency=max_concurrency,
+                                          min_concurrency=min_concurrency,
+                                          logger=self.logger)
+
         self.mediator = None
         if not disable_rate_limits:
             self.mediator = instantiate(self.mediator_cls, self.ready_queue,
@@ -224,6 +239,7 @@ class WorkController(object):
                                         self.mediator,
                                         self.scheduler,
                                         self.beat,
+                                        self.autoscaler,
                                         self.listener))
 
     def start(self):

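When autoscale is given, the pool is deliberately created at the minimum size, and the Autoscaler thread (registered with the other components so it starts and stops in order) owns the head-room up to the maximum. A condensed sketch of the wiring, using the names from the diff:

    # Hedged sketch: autoscale=[10, 3] as parsed from --autoscale=10,3.
    max_concurrency, min_concurrency = [10, 3]
    pool = instantiate(pool_cls, min_concurrency)   # 3 processes at start
    scaler = instantiate(autoscaler_cls, pool,
                         max_concurrency=max_concurrency,
                         min_concurrency=min_concurrency)
    # once started, scaler floats the pool between 3 and 10 processes.
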
+ 11 - 0
celery/worker/control/builtins.py

@@ -167,6 +167,17 @@ def ping(panel, **kwargs):
     return "pong"
 
 
+@Panel.register
+def pool_grow(panel, n=1, **kwargs):
+    panel.listener.pool.grow(n)
+    return {"ok": "spawned worker processes"}
+
+@Panel.register
+def pool_shrink(panel, n=1, **kwargs):
+    panel.listener.pool.shrink(n)
+    return {"ok": "terminated worker processes"}
+
+
 @Panel.register
 def shutdown(panel, **kwargs):
     panel.logger.critical("Got shutdown from remote.")

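These panel commands also expose grow/shrink over the worker's broadcast control channel, so the pool can be resized manually without the autoscaler. A hedged usage sketch, assuming the celery 2.x remote control API:

    # From any client with broker access (celery 2.x API assumed):
    from celery.task.control import broadcast

    broadcast("pool_grow", arguments={"n": 2})    # add two processes
    broadcast("pool_shrink", arguments={"n": 1})  # retire one idle process
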
+ 66 - 0
celery/worker/controllers.py

@@ -7,10 +7,76 @@ import logging
 import sys
 import threading
 import traceback
+
+from time import sleep, time
 from Queue import Empty as QueueEmpty
 
 from celery.app import app_or_default
 from celery.utils.compat import log_with_extra
+from celery.worker import state
+
+
+class Autoscaler(threading.Thread):
+
+    def __init__(self, pool, max_concurrency, min_concurrency=0,
+            keepalive=30, logger=None):
+        threading.Thread.__init__(self)
+        self.pool = pool
+        self.max_concurrency = max_concurrency
+        self.min_concurrency = min_concurrency
+        self.keepalive = keepalive
+        self.logger = logger or log.get_default_logger()
+        self._last_action = None
+        self._shutdown = threading.Event()
+        self._stopped = threading.Event()
+        self.setDaemon(True)
+        self.setName(self.__class__.__name__)
+
+        assert self.keepalive, "can't scale down too fast."
+
+    def scale(self):
+        current = min(self.qty, self.max_concurrency)
+        if current > self.processes:
+            self.scale_up(current - self.processes)
+        elif current < self.processes:
+            self.scale_down((self.processes - current) - self.min_concurrency)
+        sleep(1.0)
+
+    def scale_up(self, n):
+        self.logger.info("Scaling up %s processes." % (n, ))
+        self._last_action = time()
+        return self.pool.grow(n)
+
+    def scale_down(self, n):
+        if not self._last_action or not n:
+            return
+        if time() - self._last_action > self.keepalive:
+            self.logger.info("Scaling down %s processes." % (n, ))
+            self._last_action = time()
+            try:
+                self.pool.shrink(n)
+            except Exception, exc:
+                import traceback
+                traceback.print_stack()
+                self.logger.error("Autoscaler: scale_down: %r" % (exc, ))
+
+    def run(self):
+        while not self._shutdown.isSet():
+            self.scale()
+        self._stopped.set()                 # indicate that we are stopped
+
+    def stop(self):
+        self._shutdown.set()
+        self._stopped.wait()                # block until this thread is done
+        self.join(1e100)
+
+    @property
+    def qty(self):
+        return len(state.reserved_requests)
+
+    @property
+    def processes(self):
+        return self.pool._pool._processes
 
 
 class Mediator(threading.Thread):

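The scaling policy is easiest to see with numbers: scale() targets min(qty, max_concurrency), where qty is the count of reserved (queued but unfinished) requests, and scale_down() additionally waits keepalive seconds after the last action and damps each shrink by min_concurrency. A worked example of the arithmetic above:

    # max_concurrency=10, min_concurrency=3
    processes, qty = 3, 7
    current = min(qty, 10)          # -> 7, more than 3 processes
    # scale_up(7 - 3): grow the pool by 4

    processes, qty = 10, 2
    current = min(qty, 10)          # -> 2, less than 10 processes
    # scale_down((10 - 2) - 3): shrink by 5, leaving 5 processes.
    # The min_concurrency term softens each step rather than clamping
    # the pool at exactly 3; repeated idle cycles shrink it further.
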
+ 3 - 0
celery/worker/listener.py

@@ -88,6 +88,7 @@ from celery.exceptions import NotRegistered
 from celery.pidbox import BroadcastConsumer
 from celery.utils import noop, retry_over_time
 from celery.utils.timer2 import to_timestamp
+from celery.worker import state
 from celery.worker.job import TaskRequest, InvalidTaskError
 from celery.worker.control import ControlDispatch
 from celery.worker.heartbeat import Heart
@@ -285,9 +286,11 @@ class CarrotListener(object):
                 self.eta_schedule.apply_at(eta,
                                            self.apply_eta_task, (task, ))
         else:
+            state.task_reserved(task)
             self.ready_queue.put(task)
 
     def apply_eta_task(self, task):
+        state.task_reserved(task)
         self.ready_queue.put(task)
         self.qos.decrement_eventually()
 

+ 6 - 0
celery/worker/state.py

@@ -24,11 +24,16 @@ Count of tasks executed by the worker, sorted by type.
 The list of currently revoked tasks. (PERSISTENT if statedb set).
 
 """
+reserved_requests = set()
 active_requests = set()
 total_count = defaultdict(lambda: 0)
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
 
+def task_reserved(request):
+    reserved_requests.add(request)
+
+
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
@@ -38,6 +43,7 @@ def task_accepted(request):
 def task_ready(request):
     """Updates global state when a task is ready."""
     active_requests.discard(request)
+    reserved_requests.discard(request)
 
 
 class Persistent(object):
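
reserved_requests is the signal the autoscaler reads: requests enter it when the listener queues them (both immediately and via the ETA scheduler) and leave it when task_ready() fires. A sketch of the lifecycle, with request standing in for a TaskRequest instance:

    task_reserved(request)    # listener: task put on the ready queue
    len(reserved_requests)    # Autoscaler.qty sizes the pool from this
    task_accepted(request)    # a pool process picked the task up
    task_ready(request)       # finished: dropped from both sets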
 class Persistent(object):