
celeryd: Added --autoscale=max_concurrency,min_concurrency

min_concurrency is the number of worker processes to keep at all times (defaults to 0 if --autoscale is enabled).
max_concurrency is the maximum number of worker processes to create. New processes are created and old ones terminated based on task load.

Ask Solem, 14 years ago
commit f8ca5ea4f5
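
For reference, a minimal sketch of how the flag's value is interpreted, mirroring the partition()-based parsing added to celery/apps/worker.py below (parse_autoscale is a hypothetical helper, not part of this commit; carrot.utils.partition has the same semantics as str.partition):

    # Illustrative only -- the real parsing lives in celery.apps.worker.Worker.
    def parse_autoscale(value):
        max_c, _, min_c = value.partition(",")
        return [int(max_c), min_c and int(min_c) or 0]

    assert parse_autoscale("10,3") == [10, 3]   # celeryd --autoscale=10,3
    assert parse_autoscale("10") == [10, 0]     # min_concurrency defaults to 0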

+ 10 - 2
celery/apps/worker.py

@@ -7,6 +7,8 @@ import socket
 import sys
 import warnings
 
+from carrot.utils import partition
+
 from celery import __version__
 from celery import platforms
 from celery import signals
@@ -43,7 +45,8 @@ class Worker(object):
             hostname=None, discard=False, run_clockservice=False,
             schedule=None, task_time_limit=None, task_soft_time_limit=None,
             max_tasks_per_child=None, queues=None, events=False, db=None,
-            include=None, defaults=None, pidfile=None, **kwargs):
+            include=None, defaults=None, pidfile=None, autoscale=None,
+            **kwargs):
         if defaults is None:
             from celery import conf
             defaults = conf
@@ -68,6 +71,10 @@ class Worker(object):
         self.queues = queues or []
         self.include = include or []
         self.pidfile = pidfile
+        self.autoscale = None
+        if autoscale:
+            max_c, _, min_c = partition(autoscale, ",")
+            self.autoscale = [int(max_c), min_c and int(min_c) or 0]
         self._isatty = sys.stdout.isatty()
 
         if isinstance(self.queues, basestring):
@@ -194,7 +201,8 @@ class Worker(object):
                                 db=self.db,
                                 max_tasks_per_child=self.max_tasks_per_child,
                                 task_time_limit=self.task_time_limit,
-                                task_soft_time_limit=self.task_soft_time_limit)
+                                task_soft_time_limit=self.task_soft_time_limit,
+                                autoscale=self.autoscale)
         self.install_platform_tweaks(worker)
         worker.start()

+ 5 - 0
celery/bin/celeryd.py

@@ -148,6 +148,11 @@ class WorkerCommand(Command):
                 help="Optional file used to store the workers pid. "
                 help="Optional file used to store the workers pid. "
                      "The worker will not start if this file already exists "
                      "The worker will not start if this file already exists "
                      "and the pid is still alive."),
                      "and the pid is still alive."),
+            Option('--autoscale', default=None,
+                help="Enable autoscaling by providing "
+                     "max_concurrency,min_concurrency. Example: "
+                     "--autoscale=10,3 (always keep 3 processes, "
+                     "but grow to 10 if necessary)."),
         )
         )
 
 
 
 

+ 31 - 10
celery/concurrency/processes/pool.py

@@ -81,10 +81,23 @@ def soft_timeout_sighandler(signum, frame):
 
 
 
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
-    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     pid = os.getpid()
+    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
     get = inqueue.get
+
+    if hasattr(inqueue, '_reader'):
+        def poll(timeout):
+            if inqueue._reader.poll(timeout):
+                return True, get()
+            return False, None
+    else:
+        def poll(timeout):
+            try:
+                return True, get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+
     if hasattr(inqueue, '_writer'):
         inqueue._writer.close()
         outqueue._reader.close()
@@ -95,10 +108,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     if SIG_SOFT_TIMEOUT is not None:
         signal.signal(SIG_SOFT_TIMEOUT, soft_timeout_sighandler)
 
+
     completed = 0
     while maxtasks is None or (maxtasks and completed < maxtasks):
         try:
-            task = get()
+            ready, task = poll(1.0)
+            if not ready:
+                continue
         except (EOFError, IOError):
             debug('worker got EOFError or IOError -- exiting')
             break
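
The switch from a blocking get() to poll(1.0) is what lets an idle pool process wake up regularly: with a blocking get() the while condition (and hence maxtasks) is only re-checked after a task arrives. A toy, self-contained illustration of the pattern, using threads and a plain Queue rather than the pool's actual internals:

    import Queue
    import threading

    def loop(inqueue, stop):
        while not stop.isSet():
            try:
                task = inqueue.get(timeout=1.0)   # wake up at least once a second
            except Queue.Empty:
                continue                          # idle: re-check the loop condition
            task()

    q = Queue.Queue()
    stop = threading.Event()
    t = threading.Thread(target=loop, args=(q, stop))
    t.start()
    q.put(lambda: None)   # one task runs...
    stop.set()            # ...then the idle worker exits within a second
    t.join()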
@@ -122,7 +138,6 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
-
 #
 # Class representing a process pool
 #
@@ -436,7 +451,7 @@ class Pool(object):
         self._worker_handler = self.Supervisor(self)
         self._worker_handler.start()
 
-        self._putlock = threading.BoundedSemaphore(self._processes, True)
+        self._putlock = threading.BoundedSemaphore(self._processes)
 
         self._task_handler = self.TaskHandler(self._taskqueue,
                                               self._quick_put,
@@ -517,7 +532,7 @@ class Pool(object):
             self._processes -= 1
             if self._putlock:
                 self._putlock._initial_value -= 1
-                self._putlock._Semaphore__value -= 1
+                self._putlock.acquire()
             worker.terminate()
             if i == n - 1:
                 return
@@ -525,11 +540,17 @@ class Pool(object):
 
 
     def grow(self, n=1):
         for i in xrange(n):
+            #assert len(self._pool) == self._processes
             self._processes += 1
             if self._putlock:
-                self._putlock._initial_value += 1
-                self._putlock._Semaphore__value += 1
-            self._create_worker_process()
+                cond = self._putlock._Semaphore__cond
+                cond.acquire()
+                try:
+                    self._putlock._initial_value += 1
+                    self._putlock._Semaphore__value += 1
+                    cond.notify()
+                finally:
+                    cond.release()
 
     def _iterinactive(self):
         for worker in self._pool:
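
The putlock is a BoundedSemaphore counting free pool slots, so resizing must keep both the live value and the release ceiling in sync: shrink() now simply acquires the semaphore to retire a slot, while grow() adds a slot under the semaphore's own condition lock and wakes a blocked producer. A hedged, standalone sketch of the same idea (it pokes at CPython 2's private _Semaphore__* attributes, exactly as the patch does):

    import threading

    def grow_semaphore(sem, n=1):
        # Raise both the bound and the current value, waking blocked acquirers.
        cond = sem._Semaphore__cond
        cond.acquire()
        try:
            sem._initial_value += n        # BoundedSemaphore's release() ceiling
            sem._Semaphore__value += n     # free slots available right now
            cond.notify(n)                 # wake up to n threads blocked in acquire()
        finally:
            cond.release()

    def shrink_semaphore(sem, n=1):
        for _ in xrange(n):
            sem._initial_value -= 1
            sem.acquire()                  # permanently consume one slot

    sem = threading.BoundedSemaphore(2)
    grow_semaphore(sem, 3)                 # now 5 slots
    shrink_semaphore(sem, 4)               # back down to 1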
@@ -558,8 +579,8 @@ class Pool(object):
     def _maintain_pool(self):
         """"Clean up any exited workers and start replacements for them.
         """
-        if self._join_exited_workers():
-            self._repopulate_pool()
+        self._join_exited_workers()
+        self._repopulate_pool()
 
     def _setup_queues(self):
         from multiprocessing.queues import SimpleQueue
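
Note how this works together with grow() above: grow() no longer calls _create_worker_process() itself, it only raises self._processes and the semaphore, and _maintain_pool() now repopulates unconditionally, so the supervisor thread spawns the missing processes on its next maintenance pass rather than only after a worker has exited.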

+ 2 - 0
celery/conf.py

@@ -54,6 +54,7 @@ _DEFAULTS = {
     "CELERY_ACKS_LATE": False,
     "CELERY_ACKS_LATE": False,
     "CELERYD_POOL_PUTLOCKS": True,
     "CELERYD_POOL_PUTLOCKS": True,
     "CELERYD_POOL": "celery.concurrency.processes.TaskPool",
     "CELERYD_POOL": "celery.concurrency.processes.TaskPool",
+    "CELERYD_AUTOSCALER": "celery.worker.controllers.Autoscaler",
     "CELERYD_MEDIATOR": "celery.worker.controllers.Mediator",
     "CELERYD_MEDIATOR": "celery.worker.controllers.Mediator",
     "CELERYD_ETA_SCHEDULER": "celery.utils.timer2.Timer",
     "CELERYD_ETA_SCHEDULER": "celery.utils.timer2.Timer",
     "CELERYD_LISTENER": "celery.worker.listener.CarrotListener",
     "CELERYD_LISTENER": "celery.worker.listener.CarrotListener",
@@ -184,6 +185,7 @@ def prepare(m, source=settings, defaults=_DEFAULTS):
     m.CELERYD_POOL_PUTLOCKS = _get("CELERYD_POOL_PUTLOCKS")
 
     m.CELERYD_POOL = _get("CELERYD_POOL")
+    m.CELERYD_AUTOSCALER = _get("CELERYD_AUTOSCALER")
     m.CELERYD_LISTENER = _get("CELERYD_LISTENER")
     m.CELERYD_MEDIATOR = _get("CELERYD_MEDIATOR")
     m.CELERYD_ETA_SCHEDULER = _get("CELERYD_ETA_SCHEDULER")

+ 17 - 2
celery/worker/__init__.py

@@ -126,7 +126,9 @@ class WorkController(object):
             max_tasks_per_child=conf.CELERYD_MAX_TASKS_PER_CHILD,
             pool_putlocks=conf.CELERYD_POOL_PUTLOCKS,
             disable_rate_limits=conf.DISABLE_RATE_LIMITS,
-            db=conf.CELERYD_STATE_DB):
+            db=conf.CELERYD_STATE_DB,
+            autoscale=None,
+            autoscaler_cls=conf.CELERYD_AUTOSCALER):
 
         # Options
         self.loglevel = loglevel or self.loglevel
@@ -160,7 +162,13 @@ class WorkController(object):
         self.logger.debug("Instantiating thread components...")
 
         # Threads + Pool + Consumer
-        self.pool = instantiate(pool_cls, self.concurrency,
+        self.autoscaler = None
+        max_concurrency = None
+        min_concurrency = concurrency
+        if autoscale:
+            max_concurrency, min_concurrency = autoscale
+
+        self.pool = instantiate(pool_cls, min_concurrency,
                                 logger=self.logger,
                                 initializer=process_initializer,
                                 initargs=(self.hostname, ),
@@ -169,6 +177,12 @@ class WorkController(object):
                                 soft_timeout=self.task_soft_time_limit,
                                 putlocks=self.pool_putlocks)
 
+        if autoscale:
+            self.autoscaler = instantiate(autoscaler_cls, self.pool,
+                                          max_concurrency=max_concurrency,
+                                          min_concurrency=min_concurrency,
+                                          logger=self.logger)
+
         self.mediator = None
         if not disable_rate_limits:
             self.mediator = instantiate(mediator_cls, self.ready_queue,
@@ -202,6 +216,7 @@ class WorkController(object):
                                         self.mediator,
                                         self.scheduler,
                                         self.beat,
+                                        self.autoscaler,
                                         self.listener))
 
     def start(self):
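
In other words: with autoscaling enabled the pool is created with only min_concurrency processes, and the Autoscaler (instantiated from the new CELERYD_AUTOSCALER setting and started alongside the other worker components) is responsible for growing it toward max_concurrency as load demands.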

+ 66 - 0
celery/worker/controllers.py

@@ -5,9 +5,75 @@ Worker Controller Threads
 """
 """
 import threading
 import threading
 import traceback
 import traceback
+
+from time import sleep, time
 from Queue import Empty as QueueEmpty
 from Queue import Empty as QueueEmpty
 
 
 from celery import log
 from celery import log
+from celery.worker import state
+
+
+class Autoscaler(threading.Thread):
+
+    def __init__(self, pool, max_concurrency, min_concurrency=0,
+            keepalive=30, logger=None):
+        threading.Thread.__init__(self)
+        self.pool = pool
+        self.max_concurrency = max_concurrency
+        self.min_concurrency = min_concurrency
+        self.keepalive = keepalive
+        self.logger = logger or log.get_default_logger()
+        self._last_action = None
+        self._shutdown = threading.Event()
+        self._stopped = threading.Event()
+        self.setDaemon(True)
+        self.setName(self.__class__.__name__)
+
+        assert self.keepalive, "can't scale down too fast."
+
+    def scale(self):
+        current = min(self.qty, self.max_concurrency)
+        if current > self.processes:
+            self.scale_up(current - self.processes)
+        elif current < self.processes:
+            self.scale_down((self.processes - current) - self.min_concurrency)
+        sleep(1.0)
+
+    def scale_up(self, n):
+        self.logger.info("Scaling up %s processes." % (n, ))
+        self._last_action = time()
+        return self.pool.grow(n)
+
+    def scale_down(self, n):
+        if not self._last_action or not n:
+            return
+        if time() - self._last_action > self.keepalive:
+            self.logger.info("Scaling down %s processes." % (n, ))
+            self._last_action = time()
+            try:
+                self.pool.shrink(n)
+            except Exception, exc:
+                import traceback
+                traceback.print_stack()
+                self.logger.error("Autoscaler: scale_down: %r" % (exc, ))
+
+    def run(self):
+        while not self._shutdown.isSet():
+            self.scale()
+        self._stopped.set()                 # indicate that we are stopped
+
+    def stop(self):
+        self._shutdown.set()
+        self._stopped.wait()                # block until this thread is done
+        self.join(1e100)
+
+    @property
+    def qty(self):
+        return len(state.reserved_requests)
+
+    @property
+    def processes(self):
+        return self.pool._pool._processes
 
 
 
 
 class Mediator(threading.Thread):
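
Distilled into a side-effect-free function, scale()'s decision rule above amounts to the following (a hedged restatement for clarity, not code from the commit): target the number of reserved tasks capped at max_concurrency, grow by the shortfall, and otherwise propose a shrink reduced by min_concurrency (scale_down additionally refuses to act within keepalive seconds of the last action):

    def scale_decision(reserved, processes, max_c, min_c):
        current = min(reserved, max_c)
        if current > processes:
            return ("grow", current - processes)
        if current < processes:
            return ("shrink", (processes - current) - min_c)
        return ("noop", 0)

    assert scale_decision(reserved=8, processes=3, max_c=10, min_c=3) == ("grow", 5)
    assert scale_decision(reserved=0, processes=8, max_c=10, min_c=3) == ("shrink", 5)
    assert scale_decision(reserved=3, processes=3, max_c=10, min_c=3) == ("noop", 0)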

+ 3 - 0
celery/worker/listener.py

@@ -82,6 +82,7 @@ from carrot.connection import AMQPConnectionException
 
 
 from celery import conf
 from celery.utils import noop, retry_over_time
+from celery.worker import state
 from celery.worker.job import TaskRequest, InvalidTaskError
 from celery.worker.control import ControlDispatch
 from celery.worker.heartbeat import Heart
@@ -273,9 +274,11 @@ class CarrotListener(object):
             self.eta_schedule.apply_at(task.eta,
                                        self.apply_eta_task, (task, ))
         else:
+            state.task_reserved(task)
             self.ready_queue.put(task)
 
     def apply_eta_task(self, task):
+        state.task_reserved(task)
         self.ready_queue.put(task)
         self.qos.decrement_eventually()
 
 

+ 6 - 0
celery/worker/state.py

@@ -24,11 +24,16 @@ Count of tasks executed by the worker, sorted by type.
 The list of currently revoked tasks. (PERSISTENT if statedb set).
 
 """
+reserved_requests = set()
 active_requests = set()
 total_count = defaultdict(lambda: 0)
 revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)
 
 
+def task_reserved(request):
+    reserved_requests.add(request)
+
+
 def task_accepted(request):
     """Updates global state when a task has been accepted."""
     active_requests.add(request)
@@ -38,6 +43,7 @@ def task_accepted(request):
 def task_ready(request):
     """Updates global state when a task is ready."""
     active_requests.discard(request)
+    reserved_requests.discard(request)
 
 
 class Persistent(object):
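
Taken together, reserved_requests is what feeds the autoscaler: the listener marks a task reserved as soon as it is put on the ready queue (or when its ETA fires), and task_ready() clears it when the task finishes, so len(reserved_requests) — the Autoscaler's qty property — measures outstanding work. An illustrative walk-through (Req is a hypothetical stand-in for a TaskRequest):

    from celery.worker import state

    class Req(object):             # hypothetical stand-in for a TaskRequest
        task_name = "tasks.add"

    req = Req()
    state.task_reserved(req)       # CarrotListener: task enqueued for execution
    state.task_accepted(req)       # pool callback: task is now executing
    assert req in state.reserved_requests and req in state.active_requests
    state.task_ready(req)          # finished: dropped from both sets
    assert req not in state.reserved_requests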