Bläddra i källkod

Hub<->Component refactor + non-threaded autoscale

Ask Solem 13 år sedan
förälder
incheckning
9d5aa7c1c1

+ 1 - 1
celery/apps/worker.py

@@ -173,7 +173,7 @@ class Worker(configurated):
         app = self.app
         concurrency = self.concurrency
         if self.autoscale:
-            concurrency = "{min=%s, max=%s}" % self.autoscale
+            concurrency = "{min=%s, max=%s}" % tuple(self.autoscale)
         return BANNER % {
             "hostname": self.hostname,
             "version": __version__,

+ 3 - 0
celery/concurrency/base.py

@@ -76,6 +76,9 @@ class BasePool(object):
     def on_hard_timeout(self, job):
         pass
 
+    def maintain_pool(self, *args, **kwargs):
+        pass
+
     def terminate_job(self, pid):
         raise NotImplementedError(
                 "%s does not implement kill_job" % (self.__class__, ))

+ 4 - 1
celery/concurrency/processes/__init__.py

@@ -122,6 +122,9 @@ class TaskPool(BasePool):
     def on_hard_timeout(self, job):
         self._pool._timeout_handler.on_hard_timeout(job)
 
+    def maintain_pool(self, *args, **kwargs):
+        self._pool.maintain_pool(*args, **kwargs)
+
     @property
     def num_processes(self):
         return self._pool._processes
@@ -136,4 +139,4 @@ class TaskPool(BasePool):
 
     @property
     def timers(self):
-        return {self._pool.maintain_pool: 30}
+        return {self.maintain_pool: 30.0}

+ 0 - 1
celery/task/sets.py

@@ -6,7 +6,6 @@ from celery.state import get_current_task
 from celery.app import app_or_default
 from celery.canvas import subtask, maybe_subtask  # noqa
 from celery.utils import uuid
-from celery.utils.compat import UserList
 
 
 class TaskSet(list):

+ 1 - 1
celery/utils/compat.py

@@ -161,7 +161,7 @@ if sys.version_info >= (2, 7):
     def format_d(i):
         return format(i, ',d')
 else:
-    def format_d(i):
+    def format_d(i):  # noqa
         s = '%d' % i
         groups = []
         while s and s[-1].isdigit():

+ 86 - 24
celery/worker/__init__.py

@@ -18,9 +18,13 @@ import atexit
 import logging
 import socket
 import sys
+import time
 import traceback
 
+from functools import partial
+
 from billiard import forking_enable
+from billiard.exceptions import WorkerLostError
 from kombu.syn import detect_environment
 from kombu.utils.finalize import Finalize
 
@@ -38,7 +42,7 @@ from celery.utils.timer2 import Schedule
 from . import abstract
 from . import state
 from .buckets import TaskBucket, FastQueue
-from .hub import BoundedSemaphore
+from .hub import Hub, BoundedSemaphore
 
 RUN = 0x1
 CLOSE = 0x2
@@ -91,6 +95,50 @@ class Pool(abstract.StartStopComponent):
         if w.autoscale:
             w.max_concurrency, w.min_concurrency = w.autoscale
 
+    def on_poll_init(self, pool, hub):
+        apply_after = hub.timer.apply_after
+        apply_at = hub.timer.apply_at
+        on_soft_timeout = pool.on_soft_timeout
+        on_hard_timeout = pool.on_hard_timeout
+        maintain_pool = pool.maintain_pool
+        add_reader = hub.add_reader
+        remove = hub.remove
+        now = time.time
+
+        if not pool.did_start_ok():
+            raise WorkerLostError("Could not start worker processes")
+
+        hub.update_readers(pool.readers)
+        for handler, interval in pool.timers.iteritems():
+            hub.timer.apply_interval(interval * 1000.0, handler)
+
+        def on_timeout_set(R, soft, hard):
+
+            def _on_soft_timeout():
+                if hard:
+                    R._tref = apply_at(now() + (hard - soft),
+                                       on_hard_timeout, (R, ))
+                    on_soft_timeout(R)
+            if soft:
+                R._tref = apply_after(soft * 1000.0, _on_soft_timeout)
+            elif hard:
+                R._tref = apply_after(hard * 1000.0,
+                                      on_hard_timeout, (R, ))
+
+        def on_timeout_cancel(result):
+            try:
+                result._tref.cancel()
+                delattr(result, "_tref")
+            except AttributeError:
+                pass
+
+        pool.init_callbacks(
+            on_process_up=lambda w: add_reader(w.sentinel, maintain_pool),
+            on_process_down=lambda w: remove(w.sentinel),
+            on_timeout_set=on_timeout_set,
+            on_timeout_cancel=on_timeout_cancel,
+        )
+
     def create(self, w, semaphore=None, max_restarts=None):
         threaded = not w.use_eventloop
         forking_enable(not threaded or (w.no_execv or not w.force_execv))
@@ -105,12 +153,11 @@ class Pool(abstract.StartStopComponent):
                             soft_timeout=w.task_soft_time_limit,
                             putlocks=w.pool_putlocks and threaded,
                             lost_worker_timeout=w.worker_lost_wait,
-                            with_task_thread=threaded,
-                            with_result_thread=threaded,
-                            with_supervisor_thread=threaded,
-                            with_timeout_thread=threaded,
+                            threads=threaded,
                             max_restarts=max_restarts,
                             semaphore=semaphore)
+        if w.hub:
+            w.hub.on_init.append(partial(self.on_poll_init, pool))
         return pool
 
 
@@ -139,6 +186,7 @@ class Queues(abstract.Component):
     """This component initializes the internal queues
     used by the worker."""
     name = "worker.queues"
+    requires = ("ev", )
 
     def create(self, w):
         if not w.pool_cls.rlimit_safe:
@@ -157,26 +205,38 @@ class Queues(abstract.Component):
             w.ready_queue = TaskBucket(task_registry=w.app.tasks)
 
 
+class EvLoop(abstract.StartStopComponent):
+    name = "worker.ev"
+
+    def __init__(self, w, **kwargs):
+        w.hub = None
+
+    def include_if(self, w):
+        return w.use_eventloop
+
+    def create(self, w):
+        w.timer = Schedule(max_interval=10)
+        hub = w.hub = Hub(w.timer)
+        return hub
+
+
 class Timers(abstract.Component):
     """This component initializes the internal timers used by the worker."""
     name = "worker.timers"
     requires = ("pool", )
 
-    def create(self, w):
-        options = {"on_error": self.on_timer_error,
-                   "on_tick": self.on_timer_tick}
+    def include_if(self, w):
+        return not w.use_eventloop
 
-        if w.use_eventloop:
-            # the timers are fired by the hub, so don't use the Timer thread.
-            w.timer = Schedule(max_interval=10, **options)
-        else:
-            if not w.timer_cls:
-                # Default Timer is set by the pool, as e.g. eventlet
-                # needs a custom implementation.
-                w.timer_cls = w.pool.Timer
-            w.timer = self.instantiate(w.pool.Timer,
-                                       max_interval=w.timer_precision,
-                                       **options)
+    def create(self, w):
+        if not w.timer_cls:
+            # Default Timer is set by the pool, as e.g. eventlet
+            # needs a custom implementation.
+            w.timer_cls = w.pool.Timer
+        w.timer = self.instantiate(w.pool.Timer,
+                                   max_interval=w.timer_precision,
+                                   on_timer_error=self.on_timer_error,
+                                   on_timer_tick=self.on_timer_tick)
 
     def on_timer_error(self, exc):
         logger.error("Timer error: %r", exc, exc_info=True)
@@ -271,7 +331,8 @@ class WorkController(configurated):
             for i, component in enumerate(self.components):
                 logger.debug("Starting %s...", qualname(component))
                 self._running = i + 1
-                component.start()
+                if component:
+                    component.start()
                 logger.debug("%s OK!", qualname(component))
         except SystemTerminate:
             self.terminate()
@@ -339,10 +400,11 @@ class WorkController(configurated):
 
         for component in reversed(self.components):
             logger.debug("%s %s...", what, qualname(component))
-            stop = component.stop
-            if not warm:
-                stop = getattr(component, "terminate", None) or stop
-            stop()
+            if component:
+                stop = component.stop
+                if not warm:
+                    stop = getattr(component, "terminate", None) or stop
+                stop()
 
         self.timer.stop()
         self.consumer.close_connection()

+ 37 - 11
celery/worker/autoscale.py

@@ -19,6 +19,7 @@ from __future__ import with_statement
 
 import threading
 
+from functools import partial
 from time import sleep, time
 
 from celery.utils.log import get_logger
@@ -26,6 +27,7 @@ from celery.utils.threads import bgThread
 
 from . import state
 from .abstract import StartStopComponent
+from .hub import DummyLock
 
 logger = get_logger(__name__)
 debug, info, error = logger.debug, logger.info, logger.error
@@ -39,18 +41,33 @@ class WorkerComponent(StartStopComponent):
         self.enabled = w.autoscale
         w.autoscaler = None
 
-    def create(self, w):
-        scaler = w.autoscaler = self.instantiate(
-            w.autoscaler_cls, w.pool, w.max_concurrency, w.min_concurrency)
+    def create_threaded(self, w):
+        scaler = w.autoscaler = self.instantiate(w.autoscaler_cls,
+            w.pool, w.max_concurrency, w.min_concurrency)
         return scaler
 
+    def on_poll_init(self, scaler, hub):
+        hub.on_task.append(scaler.maybe_scale)
+        hub.timer.apply_interval(scaler.keepalive * 1000.0, scaler.maybe_scale)
+
+    def create_ev(self, w):
+        scaler = w.autoscaler = self.instantiate(w.autoscaler_cls,
+            w.pool, w.max_concurrency, w.min_concurrency,
+            mutex=DummyLock())
+        w.hub.on_init.append(partial(self.on_poll_init, scaler))
+
+    def create(self, w):
+        return (self.create_ev if w.use_eventloop
+                               else self.create_threaded)(w)
+
 
 class Autoscaler(bgThread):
 
-    def __init__(self, pool, max_concurrency, min_concurrency=0, keepalive=30):
+    def __init__(self, pool, max_concurrency, min_concurrency=0, keepalive=30,
+            mutex=None):
         super(Autoscaler, self).__init__()
         self.pool = pool
-        self.mutex = threading.Lock()
+        self.mutex = mutex or threading.Lock()
         self.max_concurrency = max_concurrency
         self.min_concurrency = min_concurrency
         self.keepalive = keepalive
@@ -60,14 +77,23 @@ class Autoscaler(bgThread):
 
     def body(self):
         with self.mutex:
-            procs = self.processes
-            cur = min(self.qty, self.max_concurrency)
-            if cur > procs:
-                self.scale_up(cur - procs)
-            elif cur < procs:
-                self.scale_down((procs - cur) - self.min_concurrency)
+            self.maybe_scale()
         sleep(1.0)
 
+    def _maybe_scale(self):
+        procs = self.processes
+        cur = min(self.qty, self.max_concurrency)
+        if cur > procs:
+            self.scale_up(cur - procs)
+            return True
+        elif cur < procs:
+            self.scale_down((procs - cur) - self.min_concurrency)
+            return True
+
+    def maybe_scale(self):
+        if self._maybe_scale():
+            self.pool.maintain_pool()
+
     def update(self, max=None, min=None):
         with self.mutex:
             if max is not None:

+ 34 - 60
celery/worker/consumer.py

@@ -80,10 +80,9 @@ import logging
 import socket
 import threading
 
-from time import sleep, time
+from time import sleep
 from Queue import Empty
 
-from billiard.exceptions import WorkerLostError
 from kombu.utils.encoding import safe_repr
 
 from celery.app import app_or_default
@@ -97,7 +96,6 @@ from . import state
 from .abstract import StartStopComponent
 from .control import Panel
 from .heartbeat import Heart
-from .hub import Hub
 
 RUN = 0x1
 CLOSE = 0x2
@@ -160,7 +158,7 @@ class Component(StartStopComponent):
 
     def Consumer(self, w):
         return (w.consumer_cls or
-                Consumer if w.use_eventloop else BlockingConsumer)
+                Consumer if w.hub else BlockingConsumer)
 
     def create(self, w):
         prefetch_count = w.concurrency * w.prefetch_multiplier
@@ -174,7 +172,7 @@ class Component(StartStopComponent):
                 timer=w.timer,
                 app=w.app,
                 controller=w,
-                use_eventloop=w.use_eventloop)
+                hub=w.hub)
         return c
 
 
@@ -305,7 +303,7 @@ class Consumer(object):
     def __init__(self, ready_queue,
             init_callback=noop, send_events=False, hostname=None,
             initial_prefetch_count=2, pool=None, app=None,
-            timer=None, controller=None, use_eventloop=False, **kwargs):
+            timer=None, controller=None, hub=None, **kwargs):
         self.app = app_or_default(app)
         self.connection = None
         self.task_consumer = None
@@ -320,7 +318,6 @@ class Consumer(object):
         self.heart = None
         self.pool = pool
         self.timer = timer or timer2.default_timer
-        self.use_eventloop = use_eventloop
         pidbox_state = AttributeDict(app=self.app,
                                      hostname=self.hostname,
                                      listener=self,     # pre 2.2
@@ -334,8 +331,9 @@ class Consumer(object):
 
         self._does_info = logger.isEnabledFor(logging.INFO)
         self.strategies = {}
-        if self.use_eventloop:
-            self.hub = Hub(self.timer)
+        if hub:
+            hub.on_init.append(self.on_poll_init)
+        self.hub = hub
 
     def update_strategies(self):
         S = self.strategies
@@ -361,6 +359,10 @@ class Consumer(object):
             except self.connection_errors + self.channel_errors:
                 error(RETRY_CONNECTION, exc_info=True)
 
+    def on_poll_init(self, hub):
+        hub.update_readers(self.connection.eventmap)
+        self.connection.transport.on_poll_init(hub.poller)
+
     def consume_messages(self, sleep=sleep, min=min, Empty=Empty):
         """Consume messages forever (or until an exception is raised)."""
 
@@ -371,13 +373,12 @@ class Consumer(object):
             fdmap = hub.fdmap
             poll = hub.poller.poll
             fire_timers = hub.fire_timers
-            apply_interval = hub.timer.apply_interval
-            apply_after = hub.timer.apply_after
-            apply_at = hub.timer.apply_at
             scheduled = hub.timer._queue
-            transport = self.connection.transport
-            on_poll_start = transport.on_poll_start
+            on_poll_start = self.connection.transport.on_poll_start
             strategies = self.strategies
+            connection = self.connection
+            drain_nowait = connection.drain_nowait
+            on_task_callbacks = hub.on_task
             buffer = []
 
             def flush_buffer():
@@ -391,54 +392,25 @@ class Consumer(object):
                 buffer[:] = []
 
             def on_task_received(body, message):
+                if on_task_callbacks:
+                    [callback() for callback in on_task_callbacks]
                 try:
                     name = body["task"]
                 except (KeyError, TypeError):
                     return self.handle_unknown_message(body, message)
-                bufferlen = len(buffer)
-                buffer.append((name, body, message))
-                if bufferlen + 1 >= 4:
-                    flush_buffer()
-                if bufferlen:
-                    fire_timers()
-
-            if not self.pool.did_start_ok():
-                raise WorkerLostError("Could not start worker processes")
-
-            update_readers(self.connection.eventmap, self.pool.readers)
-            for handler, interval in self.pool.timers.iteritems():
-                apply_interval(interval * 1000.0, handler)
-
-            def on_timeout_set(R, soft, hard):
-
-                def on_soft_timeout():
-                    if hard:
-                        R._tref = apply_at(time() + (hard - soft),
-                                        self.pool.on_hard_timeout, (R, ))
-                    self.pool.on_soft_timeout(R)
-                if soft:
-                    R._tref = apply_after(soft * 1000.0, on_soft_timeout)
-                elif hard:
-                    R._tref = apply_after(hard * 1000.0,
-                            self.pool_on_hard_timeout, (R, ))
-
-
-            def on_timeout_cancel(result):
                 try:
-                    result._tref.cancel()
-                    delattr(result, "_tref")
-                except AttributeError:
-                    pass
-
-            self.pool.init_callbacks(
-                on_process_up=lambda w: hub.add_reader(w.sentinel,
-                    self.pool._pool.maintain_pool),
-                on_process_down=lambda w: hub.remove(w.sentinel),
-                on_timeout_set=on_timeout_set,
-                on_timeout_cancel=on_timeout_cancel,
-            )
-
-            transport.on_poll_init(hub.poller)
+                    strategies[name](message, body, message.ack_log_error)
+                except KeyError, exc:
+                    self.handle_unknown_task(body, message, exc)
+                except InvalidTaskError, exc:
+                    self.handle_invalid_task(body, message, exc)
+                #bufferlen = len(buffer)
+                #buffer.append((name, body, message))
+                #if bufferlen + 1 >= 4:
+                #    flush_buffer()
+                #if bufferlen:
+                #    fire_timers()
+
             self.task_consumer.callbacks = [on_task_received]
             self.task_consumer.consume()
 
@@ -460,7 +432,8 @@ class Consumer(object):
 
                 update_readers(on_poll_start())
                 if fdmap:
-                    #for timeout in (time_to_sleep, 0.001):
+                    connection.more_to_read = True
+                    while connection.more_to_read:
                         for fileno, event in poll(time_to_sleep) or ():
                             try:
                                 fdmap[fileno](fileno, event)
@@ -469,8 +442,9 @@ class Consumer(object):
                             except socket.error:
                                 if self._state != CLOSE:  # pragma: no cover
                                     raise
-                        #if buffer:
-                        #    flush_buffer()
+                        if connection.more_to_read:
+                            drain_nowait()
+                            time_to_sleep = 0
                 else:
                     sleep(min(time_to_sleep, 0.1))
 

+ 27 - 6
celery/worker/hub.py

@@ -6,6 +6,15 @@ from kombu.utils.eventio import poll, READ, WRITE, ERR
 from celery.utils.timer2 import Schedule
 
 
+class DummyLock(object):
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        pass
+
+
 class BoundedSemaphore(object):
 
     def __init__(self, value=1):
@@ -14,6 +23,7 @@ class BoundedSemaphore(object):
 
     def grow(self):
         self.initial_value += 1
+        self.release()
 
     def shrink(self):
         self.initial_value -= 1
@@ -42,15 +52,27 @@ class Hub(object):
 
     def __init__(self, timer=None):
         self.fdmap = {}
-        self.poller = poll()
         self.timer = Schedule() if timer is None else timer
+        self.on_init = []
+        self.on_task = []
+
+    def start(self):
+        self.poller = poll()
+
+    def stop(self):
+        self.poller.close()
 
     def __enter__(self):
+        self.init()
         return self
 
     def __exit__(self, *exc_info):
         return self.close()
 
+    def init(self):
+        for callback in self.on_init:
+            callback(self)
+
     def fire_timers(self, min_delay=1, max_delay=10, max_timers=10):
         delay = None
         if self.timer._queue:
@@ -75,11 +97,11 @@ class Hub(object):
     def add_writer(self, fd, callback):
         return self.add(fd, callback, WRITE)
 
-    def update_readers(self, *maps):
-        [self.add_reader(*x) for row in maps for x in row.iteritems()]
+    def update_readers(self, map):
+        [self.add_reader(*x) for x in map.iteritems()]
 
-    def update_writers(self, *maps):
-        [self.add_writer(*x) for row in maps for x in row.iteritems()]
+    def update_writers(self, map):
+        [self.add_writer(*x) for x in map.iteritems()]
 
     def remove(self, fd):
         try:
@@ -89,7 +111,6 @@ class Hub(object):
 
     def close(self):
         [self.remove(fd) for fd in self.fdmap.keys()]
-        self.poller.close()
 
     @cached_property
     def scheduler(self):

+ 25 - 10
celery/worker/state.py

@@ -76,7 +76,7 @@ if os.environ.get("CELERY_BENCH"):  # pragma: no cover
 
     all_count = 0
     bench_first = None
-    bench_mem_first = None
+    bench_mem_sample = []
     bench_start = None
     bench_last = None
     bench_every = int(os.environ.get("CELERY_BENCH_EVERY", 1000))
@@ -97,33 +97,46 @@ if os.environ.get("CELERY_BENCH"):  # pragma: no cover
 
     def mem_rss():
         p = ps()
-        if p is None:
-            return "(psutil not installed)"
-        return "%s MB" % (format_d(p.get_memory_info().rss // 1024), )
+        if p is not None:
+            return "%sMB" % (format_d(p.get_memory_info().rss // 1024), )
+
+    def sample(x, n=10, k=0):
+        j = len(x) // n
+        for _ in xrange(n):
+            yield x[k]
+            k += j
 
     if current_process()._name == 'MainProcess':
         @atexit.register
         def on_shutdown():
             if bench_first is not None and bench_last is not None:
-                print("\n- Time spent in benchmark: %r" % (
+                print("- Time spent in benchmark: %r" % (
                     bench_last - bench_first))
                 print("- Avg: %s" % (sum(bench_sample) / len(bench_sample)))
-                print("- RSS: %s --> %s" % (bench_mem_first, mem_rss()))
+                if filter(None, bench_mem_sample):
+                    print("- rss (sample):")
+                    for mem in sample(bench_mem_sample):
+                        print("-    > %s," % mem)
+                    bench_mem_sample[:] = []
+                    bench_sample[:] = []
+                    import gc
+                    gc.collect()
+                    print("- rss (shutdown): %s." % (mem_rss()))
+                else:
+                    print("- rss: (psutil not installed).")
 
     def task_reserved(request):  # noqa
         global bench_start
         global bench_first
-        global bench_mem_first
         now = None
         if bench_start is None:
             bench_start = now = time()
         if bench_first is None:
             bench_first = now
-        if bench_mem_first is None:
-            bench_mem_first = mem_rss()
 
         return __reserved(request)
 
+    import sys
     def task_ready(request):  # noqa
         global all_count
         global bench_start
@@ -134,8 +147,10 @@ if os.environ.get("CELERY_BENCH"):  # pragma: no cover
             diff = now - bench_start
             print("- Time spent processing %s tasks (since first "
                     "task received): ~%.4fs\n" % (bench_every, diff))
-            bench_start, bench_last = None, now
+            sys.stdout.flush()
+            bench_start = bench_last = now
             bench_sample.append(diff)
+            bench_mem_sample.append(mem_rss())
 
         return __ready(request)