
Merge branch 'autoscale'

Conflicts:
	celery/apps/worker.py
	celery/conf.py
	celery/worker/__init__.py
	celery/worker/controllers.py
	celery/worker/listener.py
Ask Solem committed 14 years ago
parent commit: ba8b38c848

+ 1 - 0
celery/app/defaults.py

@@ -96,6 +96,7 @@ NAMESPACES = {
         "REDIRECT_STDOUTS_LEVEL": Option("WARNING"),
     },
     "CELERYD": {
+        "AUTOSCALER": Option("celery.worker.controllers.Autoscaler"),
         "CONCURRENCY": Option(0, type="int"),
         "ETA_SCHEDULER": Option("celery.utils.timer2.Timer"),
         "ETA_SCHEDULER_PRECISION": Option(1.0, type="float"),

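The new CELERYD_AUTOSCALER setting points at the default autoscaler class added later in this commit. A minimal sketch of how a deployment could override it, assuming the celeryconfig-module style of this era; the SlowAutoscaler class and myproject module path are illustrative, not part of the patch:

# myproject/autoscale.py -- hypothetical custom autoscaler
from celery.worker.controllers import Autoscaler

class SlowAutoscaler(Autoscaler):
    """Autoscaler that waits longer before scaling down."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("keepalive", 120)   # default in this patch is 30 seconds
        super(SlowAutoscaler, self).__init__(*args, **kwargs)

# celeryconfig.py -- point the worker at the custom class
CELERYD_AUTOSCALER = "myproject.autoscale.SlowAutoscaler"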
+ 10 - 2
celery/apps/worker.py

@@ -6,6 +6,8 @@ import socket
 import sys
 import warnings
 
+from carrot.utils import partition
+
 from celery import __version__
 from celery import platforms
 from celery import signals
@@ -40,7 +42,8 @@ class Worker(object):
             schedule=None, task_time_limit=None, task_soft_time_limit=None,
             max_tasks_per_child=None, queues=None, events=False, db=None,
             include=None, app=None, pidfile=None,
-            redirect_stdouts=None, redirect_stdouts_level=None, **kwargs):
+            redirect_stdouts=None, redirect_stdouts_level=None,
+            autoscale=None, **kwargs):
         self.app = app = app_or_default(app)
         self.concurrency = (concurrency or
                             app.conf.CELERYD_CONCURRENCY or
@@ -67,6 +70,10 @@ class Worker(object):
         self.queues = None
         self.include = include or []
         self.pidfile = pidfile
+        self.autoscale = None
+        if autoscale:
+            max_c, _, min_c = partition(autoscale, ",")
+            self.autoscale = [int(max_c), min_c and int(min_c) or 0]
         self._isatty = sys.stdout.isatty()
         self.colored = term.colored(enabled=app.conf.CELERYD_LOG_COLOR)
 
@@ -194,7 +201,8 @@ class Worker(object):
                                 queues=self.queues,
                                 max_tasks_per_child=self.max_tasks_per_child,
                                 task_time_limit=self.task_time_limit,
-                                task_soft_time_limit=self.task_soft_time_limit)
+                                task_soft_time_limit=self.task_soft_time_limit,
+                                autoscale=self.autoscale)
         self.install_platform_tweaks(worker)
         worker.start()
 

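The --autoscale value reaches Worker.__init__ as a plain "max,min" string; carrot.utils.partition splits it the way str.partition does. A self-contained sketch of the equivalent parsing (parse_autoscale is an illustrative name, not part of the patch):

def parse_autoscale(value):
    # "10,3" -> [10, 3]; a bare "10" leaves the minimum at 0, mirroring
    # the `min_c and int(min_c) or 0` expression above.
    max_c, _, min_c = value.partition(",")
    return [int(max_c), int(min_c) if min_c else 0]

print(parse_autoscale("10,3"))   # [10, 3]
print(parse_autoscale("10"))     # [10, 0]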
+ 5 - 0
celery/bin/celeryd.py

@@ -150,6 +150,11 @@ class WorkerCommand(Command):
                 help="Optional file used to store the workers pid. "
                      "The worker will not start if this file already exists "
                      "and the pid is still alive."),
+            Option('--autoscale', default=None,
+                help="Enable autoscaling by providing "
+                     "max_concurrency,min_concurrency. Example: "
+                     "--autoscale=10,3 (always keep 3 processes, "
+                     "but grow to 10 if necessary)."),
         )
 
 

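The new option is a plain optparse Option with no type coercion, so the raw string is handed on to Worker(autoscale=...) above. A minimal sketch of that hand-off, assuming nothing about WorkerCommand beyond optparse itself (the parser setup is illustrative):

from optparse import OptionParser

parser = OptionParser()
parser.add_option("--autoscale", default=None,
                  help="max_concurrency,min_concurrency, e.g. --autoscale=10,3")

options, args = parser.parse_args(["--autoscale=10,3"])
print(options.autoscale)   # "10,3", later split apart by Worker.__init__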
+ 6 - 0
celery/concurrency/processes/__init__.py

@@ -102,6 +102,12 @@ class TaskPool(object):
                                       error_callback=on_worker_error,
                                       waitforslot=self.putlocks)
 
+    def grow(self, n=1):
+        return self._pool.grow(n)
+
+    def shrink(self, n=1):
+        return self._pool.shrink(n)
+
     def on_worker_error(self, errbacks, exc):
         einfo = ExceptionInfo((exc.__class__, exc, None))
         [errback(einfo) for errback in errbacks]

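grow() and shrink() only delegate to the underlying pool. A hedged usage sketch; it assumes this patched checkout is importable and that TaskPool in this era takes the process limit as its first argument and exposes start()/stop():

from celery.concurrency.processes import TaskPool

pool = TaskPool(3)        # start with three worker processes
pool.start()
pool.grow(2)              # now five processes
try:
    pool.shrink(1)        # terminate one inactive process
except ValueError:
    pass                  # raised when every process is busy
pool.stop()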
+ 42 - 4
celery/concurrency/processes/pool.py

@@ -81,8 +81,8 @@ def soft_timeout_sighandler(signum, frame):
 
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
-    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     pid = os.getpid()
+    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
     get = inqueue.get
 
@@ -108,6 +108,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     if SIG_SOFT_TIMEOUT is not None:
         signal.signal(SIG_SOFT_TIMEOUT, soft_timeout_sighandler)
 
+
     completed = 0
     while maxtasks is None or (maxtasks and completed < maxtasks):
         try:
@@ -137,7 +138,6 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
-
 #
 # Class representing a process pool
 #
@@ -527,6 +527,44 @@ class Pool(object):
             return True
         return False
 
+    def shrink(self, n=1):
+        for i, worker in enumerate(self._iterinactive()):
+            self._processes -= 1
+            if self._putlock:
+                self._putlock._initial_value -= 1
+                self._putlock.acquire()
+            worker.terminate()
+            if i == n - 1:
+                return
+        raise ValueError("Can't shrink pool. All processes busy!")
+
+    def grow(self, n=1):
+        for i in xrange(n):
+            #assert len(self._pool) == self._processes
+            self._processes += 1
+            if self._putlock:
+                cond = self._putlock._Semaphore__cond
+                cond.acquire()
+                try:
+                    self._putlock._initial_value += 1
+                    self._putlock._Semaphore__value += 1
+                    cond.notify()
+                finally:
+                    cond.release()
+
+    def _iterinactive(self):
+        for worker in self._pool:
+            if not self._worker_active(worker):
+                yield worker
+        raise StopIteration()
+
+    def _worker_active(self, worker):
+        jobs = []
+        for job in self._cache.values():
+            if worker.pid in job.worker_pids():
+                return True
+        return False
+
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
@@ -541,8 +579,8 @@ class Pool(object):
     def _maintain_pool(self):
         """"Clean up any exited workers and start replacements for them.
         """
-        if self._join_exited_workers():
-            self._repopulate_pool()
+        self._join_exited_workers()
+        self._repopulate_pool()
 
     def _setup_queues(self):
         from multiprocessing.queues import SimpleQueue

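grow() must both raise the put-lock semaphore's remembered initial value and release extra slots, while shrink() lowers the ceiling and consumes one slot per terminated process; the patch does this by reaching into threading.Semaphore's name-mangled internals (_Semaphore__cond, _Semaphore__value), which only exist on Python 2. A stand-in class sketching the same bookkeeping without touching those internals (ResizableSemaphore is illustrative, not part of the patch):

import threading

class ResizableSemaphore(object):
    """Counting semaphore whose capacity can be changed at runtime."""

    def __init__(self, value):
        self._cond = threading.Condition()
        self.initial_value = value   # capacity, analogous to _initial_value above
        self.value = value           # free slots

    def acquire(self):
        with self._cond:
            while self.value == 0:
                self._cond.wait()
            self.value -= 1

    def release(self):
        with self._cond:
            self.value += 1
            self._cond.notify()

    def grow(self, n=1):
        # like Pool.grow(): raise the capacity and hand out the new slots
        with self._cond:
            self.initial_value += n
            self.value += n
            self._cond.notify(n)

    def shrink(self, n=1):
        # like Pool.shrink(): lower the capacity and swallow a slot per process removed
        for _ in range(n):
            with self._cond:
                self.initial_value -= 1
            self.acquire()

sem = ResizableSemaphore(2)
sem.grow(1)     # a third job may now be put on the pool
sem.shrink(1)   # back to a capacity of two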
+ 18 - 2
celery/worker/__init__.py

@@ -120,7 +120,8 @@ class WorkController(object):
             task_soft_time_limit=None, max_tasks_per_child=None,
             pool_putlocks=None, db=None, prefetch_multiplier=None,
             eta_scheduler_precision=None, queues=None,
-            disable_rate_limits=None, app=None):
+            disable_rate_limits=None, autoscale=None,
+            autoscaler_cls=None, app=None):
 
         self.app = app_or_default(app)
         conf = self.app.conf
@@ -138,6 +139,8 @@ class WorkController(object):
         self.mediator_cls = mediator_cls or conf.CELERYD_MEDIATOR
         self.eta_scheduler_cls = eta_scheduler_cls or \
                                     conf.CELERYD_ETA_SCHEDULER
+        self.autoscaler_cls = autoscaler_cls or \
+                                    conf.CELERYD_AUTOSCALER
         self.schedule_filename = schedule_filename or \
                                     conf.CELERYBEAT_SCHEDULE_FILENAME
         self.hostname = hostname or socket.gethostname()
@@ -178,7 +181,13 @@ class WorkController(object):
         self.logger.debug("Instantiating thread components...")
 
         # Threads + Pool + Consumer
-        self.pool = instantiate(self.pool_cls, self.concurrency,
+        self.autoscaler = None
+        max_concurrency = None
+        min_concurrency = concurrency
+        if autoscale:
+            max_concurrency, min_concurrency = autoscale
+
+        self.pool = instantiate(self.pool_cls, min_concurrency,
                                 logger=self.logger,
                                 initializer=process_initializer,
                                 initargs=(self.app, self.hostname),
@@ -187,6 +196,12 @@ class WorkController(object):
                                 soft_timeout=self.task_soft_time_limit,
                                 putlocks=self.pool_putlocks)
 
+        if autoscale:
+            self.autoscaler = instantiate(self.autoscaler_cls, self.pool,
+                                          max_concurrency=max_concurrency,
+                                          min_concurrency=min_concurrency,
+                                          logger=self.logger)
+
         self.mediator = None
         if not disable_rate_limits:
             self.mediator = instantiate(self.mediator_cls, self.ready_queue,
@@ -224,6 +239,7 @@ class WorkController(object):
                                         self.mediator,
                                         self.scheduler,
                                         self.beat,
+                                        self.autoscaler,
                                         self.listener))
 
     def start(self):

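With --autoscale the pool is started at the minimum size and the Autoscaler thread is responsible for growing it toward the maximum. A small sketch of just that sizing decision (initial_pool_size is an illustrative name, not part of the patch):

def initial_pool_size(concurrency, autoscale=None):
    # Mirrors the logic above: no autoscale -> fixed pool of `concurrency`,
    # autoscale -> start at the minimum, keep the maximum for the Autoscaler.
    max_concurrency = None
    min_concurrency = concurrency
    if autoscale:
        max_concurrency, min_concurrency = autoscale
    return min_concurrency, max_concurrency

print(initial_pool_size(8))           # (8, None)  fixed-size pool
print(initial_pool_size(8, (10, 3)))  # (3, 10)    start at 3, cap at 10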
+ 11 - 0
celery/worker/control/builtins.py

@@ -167,6 +167,17 @@ def ping(panel, **kwargs):
     return "pong"
 
 
+@Panel.register
+def pool_grow(panel, n=1, **kwargs):
+    panel.listener.pool.grow(n)
+    return {"ok": "spawned worker processes"}
+
+@Panel.register
+def pool_shrink(panel, n=1, **kwargs):
+    panel.listener.pool.shrink(n)
+    return {"ok": "terminated worker processes"}
+
+
 @Panel.register
 def shutdown(panel, **kwargs):
     panel.logger.critical("Got shutdown from remote.")

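The two new panel commands can be invoked remotely from any client. A hedged usage sketch, assuming the broadcast() helper from celery.task.control as it existed around this release:

from celery.task.control import broadcast

# Ask every worker to add two pool processes, then remove one again.
print(broadcast("pool_grow", arguments={"n": 2}, reply=True))
print(broadcast("pool_shrink", arguments={"n": 1}, reply=True))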
+ 66 - 0
celery/worker/controllers.py

@@ -7,10 +7,76 @@ import logging
 import sys
 import threading
 import traceback
+
+from time import sleep, time
 from Queue import Empty as QueueEmpty
 
 from celery.app import app_or_default
 from celery.utils.compat import log_with_extra
+from celery.worker import state
+
+
+class Autoscaler(threading.Thread):
+
+    def __init__(self, pool, max_concurrency, min_concurrency=0,
+            keepalive=30, logger=None):
+        threading.Thread.__init__(self)
+        self.pool = pool
+        self.max_concurrency = max_concurrency
+        self.min_concurrency = min_concurrency
+        self.keepalive = keepalive
+        self.logger = logger or log.get_default_logger()
+        self._last_action = None
+        self._shutdown = threading.Event()
+        self._stopped = threading.Event()
+        self.setDaemon(True)
+        self.setName(self.__class__.__name__)
+
+        assert self.keepalive, "can't scale down too fast."
+
+    def scale(self):
+        current = min(self.qty, self.max_concurrency)
+        if current > self.processes:
+            self.scale_up(current - self.processes)
+        elif current < self.processes:
+            self.scale_down((self.processes - current) - self.min_concurrency)
+        sleep(1.0)
+
+    def scale_up(self, n):
+        self.logger.info("Scaling up %s processes." % (n, ))
+        self._last_action = time()
+        return self.pool.grow(n)
+
+    def scale_down(self, n):
+        if not self._last_action or not n:
+            return
+        if time() - self._last_action > self.keepalive:
+            self.logger.info("Scaling down %s processes." % (n, ))
+            self._last_action = time()
+            try:
+                self.pool.shrink(n)
+            except Exception, exc:
+                import traceback
+                traceback.print_stack()
+                self.logger.error("Autoscaler: scale_down: %r" % (exc, ))
+
+    def run(self):
+        while not self._shutdown.isSet():
+            self.scale()
+        self._stopped.set()                 # indicate that we are stopped
+
+    def stop(self):
+        self._shutdown.set()
+        self._stopped.wait()                # block until this thread is done
+        self.join(1e100)
+
+    @property
+    def qty(self):
+        return len(state.reserved_requests)
+
+    @property
+    def processes(self):
+        return self.pool._pool._processes
 
 
 class Mediator(threading.Thread):

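The Autoscaler compares the number of reserved requests with the current process count about once a second: it grows immediately when demand exceeds the pool, but only shrinks once keepalive seconds have passed since the last action. The decision extracted into a pure function, to make that behaviour easier to follow (plan() is illustrative, not part of the patch):

from time import time

def plan(reserved, processes, max_concurrency, min_concurrency,
         last_action, keepalive=30, now=None):
    now = time() if now is None else now
    target = min(reserved, max_concurrency)
    if target > processes:
        return ("grow", target - processes)
    if target < processes:
        n = (processes - target) - min_concurrency
        if n and last_action and now - last_action > keepalive:
            return ("shrink", n)
    return ("noop", 0)

print(plan(reserved=12, processes=3, max_concurrency=10, min_concurrency=3,
           last_action=None))              # ('grow', 7)
print(plan(reserved=0, processes=10, max_concurrency=10, min_concurrency=3,
           last_action=time() - 60))       # ('shrink', 7)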
+ 3 - 0
celery/worker/listener.py

@@ -88,6 +88,7 @@ from celery.exceptions import NotRegistered
 from celery.pidbox import BroadcastConsumer
 from celery.utils import noop, retry_over_time
 from celery.utils.timer2 import to_timestamp
+from celery.worker import state
 from celery.worker.job import TaskRequest, InvalidTaskError
 from celery.worker.control import ControlDispatch
 from celery.worker.heartbeat import Heart
@@ -285,9 +286,11 @@ class CarrotListener(object):
                 self.eta_schedule.apply_at(eta,
                                            self.apply_eta_task, (task, ))
         else:
+            state.task_reserved(task)
             self.ready_queue.put(task)
 
     def apply_eta_task(self, task):
+        state.task_reserved(task)
         self.ready_queue.put(task)
         self.qos.decrement_eventually()
 

+ 6 - 0
celery/worker/state.py

@@ -24,11 +24,16 @@ Count of tasks executed by the worker, sorted by type.
 The list of currently revoked tasks. (PERSISTENT if statedb set).
 
 """
+reserved_requests = set()
 active_requests = set()
 total_count = defaultdict(lambda: 0)
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
 
+def task_reserved(request):
+    reserved_requests.add(request)
+
+
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
@@ -38,6 +43,7 @@ def task_accepted(request):
 def task_ready(request):
     """Updates global state when a task is ready."""
     active_requests.discard(request)
+    reserved_requests.discard(request)
 
 
 class Persistent(object):
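Together with the listener change above, this gives the worker a running count of requests that are queued or executing locally, which is exactly what Autoscaler.qty reads. A hedged sketch of the lifecycle, assuming this patched checkout is importable (FakeRequest is a stand-in for TaskRequest, illustration only):

from celery.worker import state

class FakeRequest(object):
    task_name = "tasks.add"   # task_accepted() keys total_count by task_name

req = FakeRequest()
state.task_reserved(req)               # listener: request put on the ready queue
print(len(state.reserved_requests))    # 1 -- this is what Autoscaler.qty returns
state.task_accepted(req)               # pool: execution started
state.task_ready(req)                  # done: dropped from reserved and active
print(len(state.reserved_requests))    # 0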