
Merge branch 'eventlet2'

Ask Solem committed 14 years ago
commit 3b4f5699ab

+ 2 - 1
celery/app/defaults.py

@@ -87,6 +87,7 @@ NAMESPACES = {
         "STORE_ERRORS_EVEN_IF_IGNORED": Option(False, type="bool"),
         "TASK_RESULT_EXPIRES": Option(timedelta(days=1), type="int"),
         "AMQP_TASK_RESULT_EXPIRES": Option(type="int"),
+        "AMQP_TASK_RESULT_CONNECTION_MAX": Option(type="int", default=1),
         "TASK_ERROR_WHITELIST": Option((), type="tuple"),
         "TASK_SERIALIZER": Option("pickle"),
         "TRACK_STARTED": Option(False, type="bool"),
@@ -96,7 +97,7 @@ NAMESPACES = {
     "CELERYD": {
         "AUTOSCALER": Option("celery.worker.controllers.Autoscaler"),
         "CONCURRENCY": Option(0, type="int"),
-        "ETA_SCHEDULER": Option("celery.utils.timer2.Timer"),
+        "ETA_SCHEDULER": Option(None, type="str"),
         "ETA_SCHEDULER_PRECISION": Option(1.0, type="float"),
         "FORCE_HIJACK_ROOT_LOGGER": Option(False, type="bool"),
         "CONSUMER": Option("celery.worker.consumer.Consumer"),

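The two new options can be set from an ordinary config module. A minimal
sketch, with an illustrative limit of 10 (the option defaults to 1):

    # celeryconfig.py (illustrative values)
    CELERY_RESULT_BACKEND = "amqp"
    # Cap how many broker connections the AMQP result backend keeps pooled.
    CELERY_AMQP_TASK_RESULT_CONNECTION_MAX = 10
    # CELERYD_ETA_SCHEDULER is now unset by default, so the pool can
    # provide its own timer implementation (see celery/worker/__init__.py).
    # CELERYD_ETA_SCHEDULER = "celery.utils.timer2.Timer"
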
+ 10 - 8
celery/apps/worker.py

@@ -12,7 +12,7 @@ from celery import __version__
 from celery import platforms
 from celery import signals
 from celery.app import app_or_default
-from celery.exceptions import ImproperlyConfigured
+from celery.exceptions import ImproperlyConfigured, SystemTerminate
 from celery.utils import get_full_cls_name, LOG_LEVELS, isatty
 from celery.utils import term
 from celery.worker import WorkController
@@ -43,7 +43,7 @@ class Worker(object):
             max_tasks_per_child=None, queues=None, events=False, db=None,
             include=None, app=None, pidfile=None,
             redirect_stdouts=None, redirect_stdouts_level=None,
-            autoscale=None, scheduler_cls=None, **kwargs):
+            autoscale=None, scheduler_cls=None, pool=None, **kwargs):
         self.app = app = app_or_default(app)
         self.concurrency = (concurrency or
                             app.conf.CELERYD_CONCURRENCY or
@@ -68,6 +68,7 @@ class Worker(object):
                                  app.conf.CELERY_REDIRECT_STDOUTS)
         self.redirect_stdouts_level = (redirect_stdouts_level or
                                        app.conf.CELERY_REDIRECT_STDOUTS_LEVEL)
+        self.pool = (pool or app.conf.CELERYD_POOL)
         self.db = db
         self.use_queues = queues or []
         self.queues = None
@@ -207,7 +208,8 @@ class Worker(object):
                                 max_tasks_per_child=self.max_tasks_per_child,
                                 task_time_limit=self.task_time_limit,
                                 task_soft_time_limit=self.task_soft_time_limit,
-                                autoscale=self.autoscale)
+                                autoscale=self.autoscale,
+                                pool_cls=self.pool)
         self.install_platform_tweaks(worker)
         worker.start()
 
@@ -257,7 +259,7 @@ def install_worker_int_handler(worker):
             install_worker_int_again_handler(worker)
             worker.logger.warn("celeryd: Warm shutdown (%s)" % (
                 process_name))
-            worker.stop()
+            worker.stop(in_sighandler=True)
         raise SystemExit()
 
     platforms.install_signal_handler("SIGINT", _stop)
@@ -270,8 +272,8 @@ def install_worker_int_again_handler(worker):
         if process_name == "MainProcess":
             worker.logger.warn("celeryd: Cold shutdown (%s)" % (
                 process_name))
-            worker.terminate()
-        raise SystemExit()
+            worker.terminate(in_sighandler=True)
+        raise SystemTerminate()
 
     platforms.install_signal_handler("SIGINT", _stop)
 
@@ -283,7 +285,7 @@ def install_worker_term_handler(worker):
         if process_name == "MainProcess":
             worker.logger.warn("celeryd: Warm shutdown (%s)" % (
                 process_name))
-            worker.stop()
+            worker.stop(in_sighandler=True)
         raise SystemExit()
 
     platforms.install_signal_handler("SIGTERM", _stop)
@@ -295,7 +297,7 @@ def install_worker_restart_handler(worker):
         """Signal handler restarting the current python program."""
         worker.logger.warn("Restarting celeryd (%s)" % (
             " ".join(sys.argv)))
-        worker.stop()
+        worker.stop(in_sighandler=True)
         os.execv(sys.executable, [sys.executable] + sys.argv)
 
     platforms.install_signal_handler("SIGHUP", restart_worker_sig_handler)

+ 89 - 72
celery/backends/amqp.py

@@ -34,12 +34,11 @@ class AMQPBackend(BaseDictBackend):
 
     """
 
-    _connection = None
-    _channel = None
+    _pool = None
 
     def __init__(self, connection=None, exchange=None, exchange_type=None,
             persistent=None, serializer=None, auto_delete=True,
-            expires=None, **kwargs):
+            expires=None, connection_max=None, **kwargs):
         super(AMQPBackend, self).__init__(**kwargs)
         conf = self.app.conf
         self._connection = connection
@@ -68,6 +67,8 @@ class AMQPBackend(BaseDictBackend):
             # x-expires must be a signed-int, or long describing
             # the expiry time in milliseconds.
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
+        self.connection_max = (connection_max or
+                               conf.CELERY_AMQP_TASK_RESULT_CONNECTION_MAX)
 
     def _create_binding(self, task_id):
         name = task_id.replace("-", "")
@@ -77,16 +78,16 @@ class AMQPBackend(BaseDictBackend):
                      durable=self.persistent,
                      auto_delete=self.auto_delete)
 
-    def _create_producer(self, task_id):
+    def _create_producer(self, task_id, channel):
         binding = self._create_binding(task_id)
-        binding(self.channel).declare()
+        binding(channel).declare()
 
-        return Producer(self.channel, exchange=self.exchange,
+        return Producer(channel, exchange=self.exchange,
                         routing_key=task_id.replace("-", ""),
                         serializer=self.serializer)
 
-    def _create_consumer(self, bindings):
-        return Consumer(self.channel, bindings, no_ack=True)
+    def _create_consumer(self, bindings, channel):
+        return Consumer(channel, bindings, no_ack=True)
 
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
@@ -99,17 +100,22 @@ class AMQPBackend(BaseDictBackend):
                 "traceback": traceback}
 
         for i in range(max_retries + 1):
+            conn = self.pool.acquire(block=True)
+            channel = conn.channel()
             try:
-                self._create_producer(task_id).publish(meta)
-            except Exception, exc:
-                if not max_retries:
-                    raise
-                self._channel = None
-                self._connection = None
-                warnings.warn(AMQResultWarning(
-                    "Error sending result %s: %r" % (task_id, exc)))
-                time.sleep(retry_delay)
-            break
+                try:
+                    self._create_producer(task_id, channel).publish(meta)
+                except Exception, exc:
+                    if not max_retries:
+                        raise
+                    warnings.warn(AMQResultWarning(
+                        "Error sending result %s: %r" % (task_id, exc)))
+                    time.sleep(retry_delay)
+                else:
+                    break
+            finally:
+                channel.close()
+                conn.release()
 
         return result
 
@@ -138,18 +144,28 @@ class AMQPBackend(BaseDictBackend):
             return self.wait_for(task_id, timeout, cache)
 
     def poll(self, task_id):
-        binding = self._create_binding(task_id)(self.channel)
-        result = binding.get()
-        if result:
-            binding.delete(if_unused=True, if_empty=True, nowait=True)
-            payload = self._cache[task_id] = result.payload
-            return payload
-        elif task_id in self._cache:
-            return self._cache[task_id]     # use previously received state.
-        return {"status": states.PENDING, "result": None}
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
+        try:
+            binding = self._create_binding(task_id)(channel)
+            result = binding.get()
+            if result:
+                try:
+                    binding.delete(if_unused=True, if_empty=True, nowait=True)
+                except conn.channel_errors:
+                    pass
+                payload = self._cache[task_id] = result.payload
+                return payload
+            elif task_id in self._cache:
+                # use previously received state.
+                return self._cache[task_id]
+            return {"status": states.PENDING, "result": None}
+        finally:
+            channel.close()
+            conn.release()
 
     def drain_events(self, consumer, timeout=None):
-        wait = self.connection.drain_events
+        wait = consumer.channel.connection.drain_events
         results = {}
 
         def callback(meta, message):
@@ -173,60 +189,61 @@ class AMQPBackend(BaseDictBackend):
         return results
 
     def consume(self, task_id, timeout=None):
-        binding = self._create_binding(task_id)
-        consumer = self._create_consumer(binding)
-        consumer.consume()
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
         try:
-            return self.drain_events(consumer, timeout=timeout).values()[0]
+            binding = self._create_binding(task_id)
+            consumer = self._create_consumer(binding, channel)
+            consumer.consume()
+            try:
+                return self.drain_events(consumer, timeout=timeout).values()[0]
+            finally:
+                consumer.cancel()
         finally:
-            consumer.cancel()
+            channel.close()
+            conn.release()
 
     def get_many(self, task_ids, timeout=None):
-        bindings = [self._create_binding(task_id) for task_id in task_ids]
-        consumer = self._create_consumer(bindings)
-        consumer.consume()
-        ids = set(task_ids)
-        results = {}
-        cached_ids = set()
-        for task_id in ids:
-            try:
-                cached = self._cache[task_id]
-            except KeyError:
-                pass
-            else:
-                if cached["status"] in states.READY_STATES:
-                    results[task_id] = cached
-                    cached_ids.add(task_id)
-        ids ^= cached_ids
+        conn = self.pool.acquire(block=True)
+        channel = conn.channel()
         try:
-            while ids:
-                r = self.drain_events(consumer, timeout=timeout)
-                results.update(r)
-                ids ^= set(r.keys())
+            bindings = [self._create_binding(task_id) for task_id in task_ids]
+            consumer = self._create_consumer(bindings, channel)
+            consumer.consume()
+            ids = set(task_ids)
+            cached_ids = set()
+            for task_id in ids:
+                try:
+                    cached = self._cache[task_id]
+                except KeyError:
+                    pass
+                else:
+                    if cached["status"] in states.READY_STATES:
+                        yield task_id, cached
+                        cached_ids.add(task_id)
+            ids ^= cached_ids
+            try:
+                while ids:
+                    r = self.drain_events(consumer, timeout=timeout)
+                    ids ^= set(r.keys())
+                    for ready_id, ready_meta in r.items():
+                        yield ready_id, ready_meta
+            finally:
+                consumer.cancel()
         finally:
-            consumer.cancel()
-
-        return results
+            channel.close()
+            conn.release()
 
     def close(self):
-        if self._channel is not None:
-            self._channel.close()
-            self._channel = None
-        if self._connection is not None:
-            self._connection.close()
-            self._connection = None
-
-    @property
-    def connection(self):
-        if not self._connection:
-            self._connection = self.app.broker_connection()
-        return self._connection
+        if self._pool is not None:
+            self._pool.close()
+            self._pool = None
 
     @property
-    def channel(self):
-        if not self._channel:
-            self._channel = self.connection.channel()
-        return self._channel
+    def pool(self):
+        if not self._pool:
+            self._pool = self.app.broker_connection().Pool(self.connection_max)
+        return self._pool
 
     def reload_task_result(self, task_id):
         raise NotImplementedError(
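
Every backend operation now follows the same acquire/use/release idiom
around the connection pool. A sketch of the pattern, assuming kombu's pool
semantics where release() returns the connection for reuse (do_work is a
hypothetical placeholder):

    conn = self.pool.acquire(block=True)   # may block until a connection is free
    channel = conn.channel()
    try:
        do_work(channel)    # declare the binding, publish or consume
    finally:
        channel.close()     # channels are opened per call, not pooled
        conn.release()      # hand the connection back to the pool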

+ 11 - 1
celery/bin/celeryd.py

@@ -83,6 +83,11 @@ class WorkerCommand(Command):
 
     def run(self, *args, **kwargs):
         kwargs.pop("app", None)
+        # Pools like eventlet/gevent need to patch libs as early
+        # as possible.
+        from celery import concurrency
+        kwargs["pool"] = concurrency.get_implementation(
+                    kwargs.get("pool") or self.app.conf.CELERYD_POOL)
         return self.app.Worker(**kwargs).run()
 
     def get_options(self):
@@ -91,7 +96,12 @@ class WorkerCommand(Command):
             Option('-c', '--concurrency',
                 default=conf.CELERYD_CONCURRENCY,
                 action="store", dest="concurrency", type="int",
-                help="Number of child processes processing the queue."),
+                help="Number of worker threads/processes."),
+            Option('-P', '--pool',
+                default=conf.CELERYD_POOL,
+                action="store", dest="pool", type="str",
+                help="Pool implementation: "
+                     "processes (default), eventlet or gevent."),
             Option('--purge', '--discard', default=False,
                 action="store_true", dest="discard",
                 help="Discard all waiting tasks before the server is"
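
With the new option the pool implementation is selected on the command
line before the worker is constructed. An illustrative invocation (the
concurrency value is only an example; green pools typically use a much
higher limit than process pools):

    $ celeryd -P eventlet -c 1000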

+ 1 - 1
celery/bin/celeryd_detach.py

@@ -1,7 +1,7 @@
 import os
 import sys
 
-from optparse import OptionParser, BadOptionError, make_option as Option
+from optparse import OptionParser, BadOptionError
 
 from celery import __version__
 from celery.bin.base import daemon_options

+ 11 - 0
celery/concurrency/__init__.py

@@ -0,0 +1,11 @@
+from celery.utils import get_cls_by_name
+
+ALIASES = {
+    "processes": "celery.concurrency.processes.TaskPool",
+    "eventlet": "celery.concurrency.evlet.TaskPool",
+    "gevent": "celery.concurrency.evg.TaskPool",
+}
+
+
+def get_implementation(cls):
+    return get_cls_by_name(cls, ALIASES)
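
Since get_cls_by_name resolves dotted paths as well as the aliases above,
the helper should behave roughly like this (the custom class in the last
line is hypothetical):

    >>> from celery import concurrency
    >>> concurrency.get_implementation("eventlet")
    <class 'celery.concurrency.evlet.TaskPool'>
    >>> concurrency.get_implementation("myproj.pools.TaskPool")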

+ 122 - 0
celery/concurrency/base.py

@@ -0,0 +1,122 @@
+import sys
+import traceback
+
+from celery import log
+from celery.datastructures import ExceptionInfo
+from celery.utils.functional import partial
+from celery.utils import timer2
+
+
+def apply_target(target, args=(), kwargs={}, callback=None,
+        accept_callback=None):
+    if accept_callback:
+        accept_callback()
+    callback(target(*args, **kwargs))
+
+
+class BasePool(object):
+    RUN = 0x1
+    CLOSE = 0x2
+    TERMINATE = 0x3
+
+    Timer = timer2.Timer
+
+    signal_safe = True
+
+    _state = None
+    _pool = None
+
+    def __init__(self, limit=None, putlocks=True, logger=None, **options):
+        self.limit = limit
+        self.putlocks = putlocks
+        self.logger = logger or log.get_default_logger()
+        self.options = options
+
+    def on_start(self):
+        pass
+
+    def on_stop(self):
+        pass
+
+    def on_terminate(self):
+        pass
+
+    def on_apply(self, *args, **kwargs):
+        pass
+
+    def stop(self):
+        self._state = self.CLOSE
+        self.on_stop()
+        self._state = self.TERMINATE
+
+    def terminate(self):
+        self._state = self.TERMINATE
+        self.on_terminate()
+
+    def start(self):
+        self.on_start()
+        self._state = self.RUN
+
+    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
+            errbacks=None, accept_callback=None, timeout_callback=None,
+            **compat):
+        """Equivalent of the :func:`apply` built-in function.
+
+        All `callbacks` and `errbacks` should complete immediately since
+        otherwise the thread which handles the result will get blocked.
+
+        """
+        args = args or []
+        kwargs = kwargs or {}
+        callbacks = callbacks or []
+        errbacks = errbacks or []
+
+        on_ready = partial(self.on_ready, callbacks, errbacks)
+        on_worker_error = partial(self.on_worker_error, errbacks)
+
+        self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
+            target, args, kwargs))
+
+        return self.on_apply(target, args, kwargs,
+                             callback=on_ready,
+                             accept_callback=accept_callback,
+                             timeout_callback=timeout_callback,
+                             error_callback=on_worker_error,
+                             waitforslot=self.putlocks)
+
+    def on_ready(self, callbacks, errbacks, ret_value):
+        """What to do when a worker task is ready and its return value has
+        been collected."""
+
+        if isinstance(ret_value, ExceptionInfo):
+            if isinstance(ret_value.exception, (
+                    SystemExit, KeyboardInterrupt)):
+                raise ret_value.exception
+            [self.safe_apply_callback(errback, ret_value)
+                    for errback in errbacks]
+        else:
+            [self.safe_apply_callback(callback, ret_value)
+                    for callback in callbacks]
+
+    def on_worker_error(self, errbacks, exc):
+        einfo = ExceptionInfo((exc.__class__, exc, None))
+        [errback(einfo) for errback in errbacks]
+
+    def safe_apply_callback(self, fun, *args):
+        try:
+            fun(*args)
+        except:
+            self.logger.error("Pool callback raised exception: %s" % (
+                traceback.format_exc(), ),
+                exc_info=sys.exc_info())
+
+    def blocking(self, fun, *args, **kwargs):
+        return fun(*args, **kwargs)
+
+    def _get_info(self):
+        return {}
+
+    @property
+    def info(self):
+        return self._get_info()
+
+    @property
+    def active(self):
+        return self._state == self.RUN
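
BasePool acts as a template class: start(), stop() and terminate() manage
_state, while the on_* hooks do the pool-specific work. A hypothetical
minimal subclass that runs tasks inline illustrates the contract:

    from celery.concurrency.base import apply_target, BasePool

    class InlinePool(BasePool):
        # Toy pool for illustration: executes each task synchronously.
        def on_apply(self, target, args=None, kwargs=None, callback=None,
                accept_callback=None, **_):
            apply_target(target, args or (), kwargs or {},
                         callback, accept_callback)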

+ 31 - 0
celery/concurrency/evg.py

@@ -0,0 +1,31 @@
+from gevent import Greenlet
+from gevent.pool import Pool
+
+from celery.concurrency.base import apply_target, BasePool
+
+
+class TaskPool(BasePool):
+    Pool = Pool
+
+    signal_safe = False
+
+    def on_start(self):
+        self._pool = self.Pool(self.limit)
+
+    def on_stop(self):
+        if self._pool is not None:
+            self._pool.join()
+
+    def on_apply(self, target, args=None, kwargs=None, callback=None,
+            accept_callback=None, **_):
+        return self._pool.spawn(apply_target, target, args, kwargs,
+                                callback, accept_callback)
+
+    def blocking(self, fun, *args, **kwargs):
+        Greenlet.spawn(fun, *args, **kwargs).get()
+
+    @classmethod
+    def on_import(cls):
+        from gevent import monkey
+        monkey.patch_all()
+TaskPool.on_import()

+ 101 - 0
celery/concurrency/evlet.py

@@ -0,0 +1,101 @@
+import sys
+
+from time import time
+
+import eventlet
+import eventlet.debug
+eventlet.monkey_patch()
+eventlet.debug.hub_prevent_multiple_readers(False)
+
+from eventlet import GreenPool
+from eventlet.greenthread import spawn, spawn_after_local
+from greenlet import GreenletExit
+
+from celery.concurrency.base import apply_target, BasePool
+from celery.utils import timer2
+
+
+class Schedule(timer2.Schedule):
+
+    def __init__(self, *args, **kwargs):
+        super(Schedule, self).__init__(*args, **kwargs)
+        self._queue = set()
+
+    def enter(self, entry, eta=None, priority=0):
+        try:
+            timer2.to_timestamp(eta)
+        except OverflowError:
+            if not self.handle_error(sys.exc_info()):
+                raise
+
+        now = time()
+        if eta is None:
+            eta = now
+        secs = eta - now
+
+        g = spawn_after_local(secs, entry)
+        self._queue.add(g)
+        g.link(self._entry_exit, entry)
+        g.entry = entry
+        g.eta = eta
+        g.priority = priority
+        g.cancelled = False
+
+        return g
+
+    def _entry_exit(self, g, entry):
+        try:
+            try:
+                g.wait()
+            except GreenletExit:
+                entry.cancel()
+                g.cancelled = True
+        finally:
+            self._queue.discard(g)
+
+    def clear(self):
+        queue = self._queue
+        while queue:
+            try:
+                queue.pop().cancel()
+            except KeyError:
+                pass
+
+    @property
+    def queue(self):
+        return [(g.eta, g.priority, g.entry) for g in self._queue]
+
+
+class Timer(timer2.Timer):
+    Schedule = Schedule
+
+    def ensure_started(self):
+        pass
+
+    def stop(self):
+        self.schedule.clear()
+
+    def start(self):
+        pass
+
+
+class TaskPool(BasePool):
+    Pool = GreenPool
+    Timer = Timer
+
+    signal_safe = False
+
+    def on_start(self):
+        self._pool = self.Pool(self.limit)
+
+    def on_stop(self):
+        if self._pool is not None:
+            self._pool.waitall()
+
+    def on_apply(self, target, args=None, kwargs=None, callback=None,
+            accept_callback=None, **_):
+        self._pool.spawn(apply_target, target, args, kwargs,
+                         callback, accept_callback)
+
+    def blocking(self, fun, *args, **kwargs):
+        return spawn(fun, *args, **kwargs).wait()
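
Note that merely importing this module monkey-patches the standard
library, which is why celery/bin/celeryd.py resolves the pool class as
early as possible. Roughly:

    # Illustrative: after this import, blocking stdlib calls are cooperative.
    import celery.concurrency.evlet     # runs eventlet.monkey_patch()
    import socket                       # now eventlet's green socket module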

+ 9 - 75
celery/concurrency/processes/__init__.py

@@ -3,17 +3,11 @@
 Process Pools.
 
 """
-import sys
-import traceback
-
-from celery import log
-from celery.datastructures import ExceptionInfo
-from celery.utils.functional import partial
-
+from celery.concurrency.base import BasePool
 from celery.concurrency.processes.pool import Pool, RUN
 
 
-class TaskPool(object):
+class TaskPool(BasePool):
     """Process Pool for processing tasks in parallel.
 
     :param processes: see :attr:`processes`.
@@ -31,96 +25,36 @@ class TaskPool(object):
     """
     Pool = Pool
 
-    def __init__(self, processes=None, putlocks=True, logger=None, **options):
-        self.processes = processes
-        self.putlocks = putlocks
-        self.logger = logger or log.get_default_logger()
-        self.options = options
-        self._pool = None
-
-    def start(self):
+    def on_start(self):
         """Run the task pool.
 
         Will pre-fork all workers so they're ready to accept tasks.
 
         """
-        self._pool = self.Pool(processes=self.processes, **self.options)
+        self._pool = self.Pool(processes=self.limit, **self.options)
+        self.on_apply = self._pool.apply_async
 
-    def stop(self):
+    def on_stop(self):
         """Gracefully stop the pool."""
         if self._pool is not None and self._pool._state == RUN:
             self._pool.close()
             self._pool.join()
             self._pool = None
 
-    def terminate(self):
+    def on_terminate(self):
         """Force terminate the pool."""
         if self._pool is not None:
             self._pool.terminate()
             self._pool = None
 
-    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
-            errbacks=None, accept_callback=None, timeout_callback=None,
-            **compat):
-        """Equivalent of the :func:`apply` built-in function.
-
-        All `callbacks` and `errbacks` should complete immediately since
-        otherwise the thread which handles the result will get blocked.
-
-        """
-        args = args or []
-        kwargs = kwargs or {}
-        callbacks = callbacks or []
-        errbacks = errbacks or []
-
-        on_ready = partial(self.on_ready, callbacks, errbacks)
-        on_worker_error = partial(self.on_worker_error, errbacks)
-
-        self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
-            target, args, kwargs))
-
-        return self._pool.apply_async(target, args, kwargs,
-                                      callback=on_ready,
-                                      accept_callback=accept_callback,
-                                      timeout_callback=timeout_callback,
-                                      error_callback=on_worker_error,
-                                      waitforslot=self.putlocks)
-
     def grow(self, n=1):
         return self._pool.grow(n)
 
     def shrink(self, n=1):
         return self._pool.shrink(n)
 
-    def on_worker_error(self, errbacks, exc):
-        einfo = ExceptionInfo((exc.__class__, exc, None))
-        [errback(einfo) for errback in errbacks]
-
-    def on_ready(self, callbacks, errbacks, ret_value):
-        """What to do when a worker task is ready and its return value has
-        been collected."""
-
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
-                raise ret_value.exception
-            [self.safe_apply_callback(errback, ret_value)
-                    for errback in errbacks]
-        else:
-            [self.safe_apply_callback(callback, ret_value)
-                    for callback in callbacks]
-
-    def safe_apply_callback(self, fun, *args):
-        try:
-            fun(*args)
-        except:
-            self.logger.error("Pool callback raised exception: %s" % (
-                traceback.format_exc(), ),
-                exc_info=sys.exc_info())
-
-    @property
-    def info(self):
-        return {"max-concurrency": self.processes,
+    def _get_info(self):
+        return {"max-concurrency": self.limit,
                 "processes": [p.pid for p in self._pool._pool],
                 "max-tasks-per-child": self._pool._maxtasksperchild,
                 "put-guarded-by-semaphore": self.putlocks,

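The shared apply_async/callback machinery now lives in BasePool, so the
public API is unchanged apart from the processes attribute becoming limit.
A sketch (some_task is hypothetical):

    pool = TaskPool(4)          # limit=4; previously TaskPool(processes=4)
    pool.start()                # pre-forks the worker processes
    pool.apply_async(some_task, args=(1, 2))
    pool.stop()
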
+ 10 - 54
celery/concurrency/threads.py

@@ -1,52 +1,20 @@
-
-import threading
 from threadpool import ThreadPool, WorkRequest
 
-from celery import log
-from celery.utils.functional import partial
-from celery.datastructures import ExceptionInfo
-
-
-accept_lock = threading.Lock()
-
-
-def do_work(target, args=(), kwargs={}, callback=None,
-        accept_callback=None):
-    accept_lock.acquire()
-    try:
-        accept_callback()
-    finally:
-        accept_lock.release()
-    callback(target(*args, **kwargs))
-
+from celery.concurrency.base import apply_target, BasePool
 
-class TaskPool(object):
 
-    def __init__(self, processes, logger=None, **kwargs):
-        self.processes = processes
-        self.logger = logger or log.get_default_logger()
-        self._pool = None
+class TaskPool(BasePool):
 
-    def start(self):
-        self._pool = ThreadPool(self.processes)
+    def on_start(self):
+        self._pool = ThreadPool(self.limit)
 
-    def stop(self):
-        self._pool.dismissWorkers(self.processes, do_join=True)
+    def on_stop(self):
+        self._pool.dismissWorkers(self.limit, do_join=True)
 
-    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
-            errbacks=None, accept_callback=None, **compat):
-        args = args or []
-        kwargs = kwargs or {}
-        callbacks = callbacks or []
-        errbacks = errbacks or []
-
-        on_ready = partial(self.on_ready, callbacks, errbacks)
-
-        self.logger.debug("ThreadPool: Apply %s (args:%s kwargs:%s)" % (
-            target, args, kwargs))
-
-        req = WorkRequest(do_work, (target, args, kwargs, on_ready,
-                                    accept_callback))
+    def on_apply(self, target, args=None, kwargs=None, callback=None,
+            accept_callback=None, **_):
+        req = WorkRequest(apply_target, (target, args, kwargs, callback,
+                                         accept_callback))
         self._pool.putRequest(req)
         # threadpool also has callback support,
         # but for some reason the callback is not triggered
@@ -54,15 +22,3 @@ class TaskPool(object):
         # Clear the results (if any), so it doesn't grow too large.
         self._pool._results_queue.queue.clear()
         return req
-
-    def on_ready(self, callbacks, errbacks, ret_value):
-        """What to do when a worker task is ready and its return value has
-        been collected."""
-
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):    # pragma: no cover
-                raise ret_value.exception
-            [errback(ret_value) for errback in errbacks]
-        else:
-            [callback(ret_value) for callback in callbacks]

+ 1 - 3
celery/events/__init__.py

@@ -13,6 +13,7 @@ from celery.utils import gen_unique_id
 
 event_exchange = Exchange("celeryev", type="topic")
 
+
 def create_event(type, fields):
     std = {"type": type,
            "timestamp": fields.get("timestamp") or time.time()}
@@ -156,7 +157,6 @@ class EventReceiver(object):
             by calling `consumer.channel.close()`.
 
         """
-        conf = self.app.conf
         consumer = Consumer(self.connection.channel(),
                             queues=[self.queue],
                             no_ack=True)
@@ -188,7 +188,6 @@ class EventReceiver(object):
                               timeout=timeout,
                               wakeup=wakeup))
 
-
     def wakeup_workers(self, channel=None):
         self.app.control.broadcast("heartbeat",
                                    connection=self.connection,
@@ -211,7 +210,6 @@ class EventReceiver(object):
         self.process(type, create_event(type, message_data))
 
 
-
 class Events(object):
 
     def __init__(self, app):

+ 4 - 0
celery/exceptions.py

@@ -9,6 +9,10 @@ Task of kind %s is not registered, please make sure it's imported.
 """.strip()
 
 
+class SystemTerminate(SystemExit):
+    pass
+
+
 class QueueNotFound(KeyError):
     """Task routed to a queue not in CELERY_QUEUES."""
     pass
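
SystemTerminate subclasses SystemExit so that raising it inside a signal
handler still exits the process even if nothing catches it, while the
worker can catch it explicitly to escalate to a cold shutdown. A minimal
demonstration:

    from celery.exceptions import SystemTerminate

    try:
        raise SystemTerminate()
    except SystemExit:      # SystemTerminate is-a SystemExit
        print("cold shutdown requested")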

+ 7 - 4
celery/result.py

@@ -192,7 +192,6 @@ class AsyncResult(BaseAsyncResult):
                                           task_name=task_name, app=app)
 
 
-
 class TaskSetResult(object):
     """Working with :class:`~celery.task.sets.TaskSet` results.
 
@@ -355,6 +354,11 @@ class TaskSetResult(object):
                         time.time() >= time_start + timeout):
                     raise TimeoutError("join operation timed out.")
 
+    def iter_native(self, timeout=None):
+        backend = self.subtasks[0].backend
+        ids = [subtask.task_id for subtask in self.subtasks]
+        return backend.get_many(ids, timeout=timeout)
+
     def join_native(self, timeout=None, propagate=True):
         """Backend optimized version of :meth:`join`.
 
@@ -368,9 +372,9 @@ class TaskSetResult(object):
         """
         backend = self.subtasks[0].backend
         results = PositionQueue(length=self.total)
-        ids = [subtask.task_id for subtask in self.subtasks]
 
-        states = backend.get_many(ids, timeout=timeout)
+        ids = [subtask.task_id for subtask in self.subtasks]
+        states = dict(backend.get_many(ids, timeout=timeout))
 
         for task_id, meta in states.items():
             index = self.subtasks.index(task_id)
@@ -378,7 +382,6 @@ class TaskSetResult(object):
 
         return list(results)
 
-
     def save(self, backend=None):
         """Save taskset result for later retrieval using :meth:`restore`.
 
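Because get_many is now a generator, iter_native can hand results back as
they arrive instead of waiting for the whole set. An illustrative session
(the taskset and its ids are assumed to exist):

    >>> res = TaskSetResult(taskset_id, subtasks)
    >>> for task_id, meta in res.iter_native(timeout=10):
    ...     print("%s -> %s" % (task_id, meta["status"]))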

+ 1 - 1
celery/schedules.py

@@ -251,7 +251,7 @@ class crontab(schedule):
     def remaining_estimate(self, last_run_at):
         """Returns when the periodic task should run next as a timedelta."""
         weekday = last_run_at.isoweekday()
-        if weekday == 7: # Sunday is day 0, not day 7.
+        if weekday == 7:    # Sunday is day 0, not day 7.
             weekday = 0
 
         execute_this_hour = (weekday in self.day_of_week and

+ 2 - 2
celery/tests/test_bin/test_celeryd.py

@@ -221,10 +221,10 @@ class test_signal_handlers(unittest.TestCase):
         terminated = False
         logger = get_logger()
 
-        def stop(self):
+        def stop(self, in_sighandler=False):
             self.stopped = True
 
-        def terminate(self):
+        def terminate(self, in_sighandler=False):
             self.terminated = True
 
     def psig(self, fun, *args, **kwargs):

+ 2 - 3
celery/tests/test_concurrency_processes.py

@@ -165,13 +165,12 @@ class test_TaskPool(unittest.TestCase):
 
     def test_info(self):
         pool = TaskPool(10)
-        procs = [Object(pid=i) for i in range(pool.processes)]
+        procs = [Object(pid=i) for i in range(pool.limit)]
         pool._pool = Object(_pool=procs,
                             _maxtasksperchild=None,
                             timeout=10,
                             soft_timeout=5)
         info = pool.info
-        self.assertEqual(info["max-concurrency"], pool.processes)
-        self.assertEqual(len(info["processes"]), pool.processes)
+        self.assertEqual(info["max-concurrency"], pool.limit)
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertEqual(info["timeouts"], (5, 10))

+ 1 - 1
celery/tests/test_pool.py

@@ -27,7 +27,7 @@ class TestTaskPool(unittest.TestCase):
 
     def test_attrs(self):
         p = TaskPool(2)
-        self.assertEqual(p.processes, 2)
+        self.assertEqual(p.limit, 2)
         self.assertIsInstance(p.logger, logging.Logger)
         self.assertIsNone(p._pool)
 

+ 0 - 3
celery/tests/test_utils.py

@@ -1,12 +1,9 @@
 import pickle
-import sys
 import unittest2 as unittest
 
 from celery import utils
 from celery.utils import promise, mpromise, maybe_promise
 
-from celery.tests.utils import execute_context, mask_modules
-
 
 def double(x):
     return x * 2

+ 14 - 9
celery/tests/test_worker.py

@@ -9,6 +9,7 @@ from kombu.connection import BrokerConnection
 from celery.utils.timer2 import Timer
 
 from celery.app import app_or_default
+from celery.concurrency.base import BasePool
 from celery.decorators import task as task_dec
 from celery.decorators import periodic_task as periodic_task_dec
 from celery.serialization import pickle
@@ -117,7 +118,7 @@ class MockBackend(object):
         self._acked = True
 
 
-class MockPool(object):
+class MockPool(BasePool):
     _terminated = False
     _stopped = False
 
@@ -436,13 +437,16 @@ class test_Consumer(unittest.TestCase):
         l.broadcast_consumer = MockConsumer()
         l.qos = _QoS()
         l.connection = BrokerConnection()
+        l.iterations = 0
 
         def raises_KeyError(limit=None):
-            yield True
-            l.iterations = 1
-            raise KeyError("foo")
+            l.iterations += 1
+            if l.qos.prev != l.qos.next:
+                l.qos.update()
+            if l.iterations >= 2:
+                raise KeyError("foo")
 
-        l._mainloop = raises_KeyError
+        l.consume_messages = raises_KeyError
         self.assertRaises(KeyError, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
@@ -456,11 +460,10 @@ class test_Consumer(unittest.TestCase):
         l.connection = BrokerConnection()
 
         def raises_socket_error(limit=None):
-            yield True
             l.iterations = 1
             raise socket.error("foo")
 
-        l._mainloop = raises_socket_error
+        l.consume_messages = raises_socket_error
         self.assertRaises(socket.error, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
@@ -509,8 +512,10 @@ class test_WorkController(unittest.TestCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
-        worker.process_task(task)
-        worker.pool.stop()
+        worker.components = []
+        worker._state = worker.RUN
+        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
+        self.assertEqual(worker._state, worker.TERMINATE)
 
     def test_process_task_raise_regular(self):
         worker = self.worker

+ 2 - 1
celery/tests/test_worker_job.py

@@ -10,6 +10,7 @@ from kombu.transport.base import Message
 
 from celery import states
 from celery.app import app_or_default
+from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.decorators import task as task_dec
 from celery.exceptions import RetryTaskError, NotRegistered
@@ -408,7 +409,7 @@ class test_TaskRequest(unittest.TestCase):
         tid = gen_unique_id()
         tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
 
-        class MockPool(object):
+        class MockPool(BasePool):
             target = None
             args = None
             kwargs = None

+ 2 - 1
celery/utils/timer2.py

@@ -131,13 +131,14 @@ class Schedule(object):
 
 class Timer(Thread):
     Entry = Entry
+    Schedule = Schedule
 
     running = False
     on_tick = None
     _timer_count = count(1).next
 
     def __init__(self, schedule=None, on_error=None, on_tick=None, **kwargs):
-        self.schedule = schedule or Schedule(on_error=on_error)
+        self.schedule = schedule or self.Schedule(on_error=on_error)
         self.on_tick = on_tick or self.on_tick
 
         Thread.__init__(self)

+ 47 - 23
celery/worker/__init__.py

@@ -4,10 +4,12 @@ import traceback
 from multiprocessing.util import Finalize
 
 from celery import beat
+from celery import concurrency as _concurrency
 from celery import registry
 from celery import platforms
 from celery import signals
 from celery.app import app_or_default
+from celery.exceptions import SystemTerminate
 from celery.log import SilenceRepeated
 from celery.utils import noop, instantiate
 
@@ -50,6 +52,9 @@ def process_initializer(app, hostname):
 
 class WorkController(object):
     """Unmanaged worker instance."""
+    RUN = RUN
+    CLOSE = CLOSE
+    TERMINATE = TERMINATE
 
     #: The number of simultaneous processes doing work (default:
     #: :setting:`CELERYD_CONCURRENCY`)
@@ -80,9 +85,6 @@ class WorkController(object):
     #: processing.
     ready_queue = None
 
-    #: Instance of :class:`celery.worker.controllers.ScheduleController`.
-    schedule_controller = None
-
     #: Instance of :class:`celery.worker.controllers.Mediator`.
     mediator = None
 
@@ -114,11 +116,13 @@ class WorkController(object):
         if send_events is None:
             send_events = conf.CELERY_SEND_EVENTS
         self.send_events = send_events
-        self.pool_cls = pool_cls or conf.CELERYD_POOL
+        self.pool_cls = _concurrency.get_implementation(
+                            pool_cls or conf.CELERYD_POOL)
         self.consumer_cls = consumer_cls or conf.CELERYD_CONSUMER
         self.mediator_cls = mediator_cls or conf.CELERYD_MEDIATOR
         self.eta_scheduler_cls = eta_scheduler_cls or \
                                     conf.CELERYD_ETA_SCHEDULER
+
         self.autoscaler_cls = autoscaler_cls or \
                                     conf.CELERYD_AUTOSCALER
         self.schedule_filename = schedule_filename or \
@@ -153,7 +157,7 @@ class WorkController(object):
             Finalize(persistence, persistence.save, exitpriority=5)
 
         # Queues
-        if disable_rate_limits:
+        if self.disable_rate_limits:
             self.ready_queue = FastQueue()
             self.ready_queue.put = self.process_task
         else:
@@ -176,6 +180,10 @@ class WorkController(object):
                                 timeout=self.task_time_limit,
                                 soft_timeout=self.task_soft_time_limit,
                                 putlocks=self.pool_putlocks)
+        if not self.eta_scheduler_cls:
+            # Default Timer is set by the pool, as e.g. eventlet
+            # needs a custom implementation.
+            self.eta_scheduler_cls = self.pool.Timer
 
         if autoscale:
             self.autoscaler = instantiate(self.autoscaler_cls, self.pool,
@@ -184,15 +192,16 @@ class WorkController(object):
                                           logger=self.logger)
 
         self.mediator = None
-        if not disable_rate_limits:
+        if not self.disable_rate_limits:
             self.mediator = instantiate(self.mediator_cls, self.ready_queue,
                                         app=self.app,
                                         callback=self.process_task,
                                         logger=self.logger)
+
         self.scheduler = instantiate(self.eta_scheduler_cls,
-                                     precision=eta_scheduler_precision,
-                                     on_error=self.on_timer_error,
-                                     on_tick=self.on_timer_tick)
+                                precision=eta_scheduler_precision,
+                                on_error=self.on_timer_error,
+                                on_tick=self.on_timer_tick)
 
         self.beat = None
         if self.embed_clockservice:
@@ -226,13 +235,20 @@ class WorkController(object):
 
     def start(self):
         """Starts the workers main loop."""
-        self._state = RUN
+        self._state = self.RUN
 
-        for i, component in enumerate(self.components):
-            self.logger.debug("Starting thread %s..." % (
-                                    component.__class__.__name__))
-            self._running = i + 1
-            component.start()
+        try:
+            for i, component in enumerate(self.components):
+                self.logger.debug("Starting thread %s..." % (
+                                        component.__class__.__name__))
+                self._running = i + 1
+                self.pool.blocking(component.start)
+        except SystemTerminate:
+            self.terminate()
+            raise SystemExit()
+        except (SystemExit, KeyboardInterrupt), exc:
+            self.stop()
+            raise exc
 
     def process_task(self, wrapper):
         """Process task by sending it to the pool of workers."""
@@ -243,25 +259,33 @@ class WorkController(object):
             except Exception, exc:
                 self.logger.critical("Internal error %s: %s\n%s" % (
                                 exc.__class__, exc, traceback.format_exc()))
-        except (SystemExit, KeyboardInterrupt):
+        except SystemTerminate:
+            self.terminate()
+            raise SystemExit()
+        except (SystemExit, KeyboardInterrupt), exc:
             self.stop()
+            raise exc
 
-    def stop(self):
+    def stop(self, in_sighandler=False):
         """Graceful shutdown of the worker server."""
-        self._shutdown(warm=True)
+        if in_sighandler and not self.pool.signal_safe:
+            return
+        self.pool.blocking(self._shutdown, warm=True)
 
-    def terminate(self):
+    def terminate(self, in_sighandler=False):
         """Not so graceful shutdown of the worker server."""
-        self._shutdown(warm=False)
+        if in_sighandler and not self.pool.signal_safe:
+            return
+        self.pool.blocking(self._shutdown, warm=False)
 
     def _shutdown(self, warm=True):
         what = (warm and "stopping" or "terminating").capitalize()
 
-        if self._state != RUN or self._running != len(self.components):
+        if self._state != self.RUN or self._running != len(self.components):
             # Not fully started, can safely exit.
             return
 
-        self._state = CLOSE
+        self._state = self.CLOSE
         signals.worker_shutdown.send(sender=self)
 
         for component in reversed(self.components):
@@ -273,7 +297,7 @@ class WorkController(object):
             stop()
 
         self.consumer.close_connection()
-        self._state = TERMINATE
+        self._state = self.TERMINATE
 
     def on_timer_error(self, exc_info):
         _, exc, _ = exc_info

+ 4 - 6
celery/worker/consumer.py

@@ -246,13 +246,14 @@ class Consumer(object):
         self.logger.debug("Consumer: Starting message consumer...")
         self.task_consumer.consume()
         self.broadcast_consumer.consume()
-        wait_for_message = self._mainloop().next
         self.logger.debug("Consumer: Ready to accept tasks!")
 
         while 1:
+            if not self.connection:
+                break
             if self.qos.prev != self.qos.next:
                 self.qos.update()
-            wait_for_message()
+            self.connection.drain_events()
 
     def on_task(self, task):
         """Handle received task.
@@ -313,6 +314,7 @@ class Consumer(object):
                     self.logger.critical(
                             "Couldn't ack %r: message:%r reason:%r" % (
                                 message.delivery_tag, message_data, exc))
+
             try:
                 task = TaskRequest.from_message(message, message_data, ack,
                                                 app=self.app,
@@ -438,10 +440,6 @@ class Consumer(object):
         self.heart = Heart(self.event_dispatcher)
         self.heart.start()
 
-    def _mainloop(self):
-        while 1:
-            yield self.connection.drain_events()
-
     def _open_connection(self):
         """Open connection.  May retry opening the connection if configuration
         allows that."""

+ 0 - 1
celery/worker/heartbeat.py

@@ -63,6 +63,5 @@ class Heart(threading.Thread):
             return
         self._state = "CLOSE"
         self._shutdown.set()
-        self._stopped.wait()            # blocks until this thread is done
         if self.isAlive():
             self.join(1e100)

+ 1 - 1
celery/worker/job.py

@@ -430,7 +430,7 @@ class TaskRequest(object):
         if self.task.acks_late:
             self.acknowledge()
 
-        runtime = time.time() - self.time_start
+        runtime = self.time_start and (time.time() - self.time_start) or 0
         self.send_event("task-succeeded", uuid=self.task_id,
                         result=repr(ret_value), runtime=runtime)
 

+ 15 - 0
examples/eventlet/celeryconfig.py

@@ -0,0 +1,15 @@
+import os
+import sys
+sys.path.insert(0, os.getcwd())
+
+CELERYD_POOL = "eventlet"
+
+BROKER_HOST = "localhost"
+BROKER_USER = "guest"
+BROKER_PASSWORD = "guest"
+BROKER_VHOST = "/"
+CELERY_DISABLE_RATE_LIMITS = True
+CELERY_RESULT_BACKEND = "amqp"
+CELERY_TASK_RESULT_EXPIRES = 30 * 60
+
+CELERY_IMPORTS = ("tasks", "webcrawler")
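
A plausible way to try the example from this directory (flags and URL are
illustrative; the eventlet pool is picked up from CELERYD_POOL above):

    $ celeryd -l info -c 500

    >>> from webcrawler import crawl
    >>> crawl.delay("http://example.com/")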

+ 14 - 0
examples/eventlet/tasks.py

@@ -0,0 +1,14 @@
+from celery.decorators import task
+from eventlet.green import urllib2
+
+
+@task(ignore_result=True)
+def urlopen(url):
+    print("Opening: %r" % (url, ))
+    try:
+        body = urllib2.urlopen(url).read()
+    except Exception, exc:
+        print("Exception for %r: %r" % (url, exc, ))
+        return url, 0
+    print("Done with: %r" % (url, ))
+    return url, 1

+ 41 - 0
examples/eventlet/webcrawler.py

@@ -0,0 +1,41 @@
+"""Recursive webcrawler example.
+
+One problem with this solution is that it does not remember
+URLs it has already seen.
+
+To add support for this, a Bloom filter or Redis sets could be used.
+
+"""
+
+from __future__ import with_statement
+
+import re
+import time
+import urlparse
+
+from celery.decorators import task
+from eventlet import Timeout
+from eventlet.green import urllib2
+
+# http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
+url_regex = re.compile(
+    r'\b(([\w-]+://?|www[.])[^\s()<>]+(?:\([\w\d]+\)|([^[:punct:]\s]|/)))')
+
+
+def domain(url):
+    return urlparse.urlsplit(url)[1].split(":")[0]
+
+
+@task
+def crawl(url):
+    print("crawling: %r" % (url, ))
+    location = domain(url)
+    data = ''
+    with Timeout(5, False):
+        data = urllib2.urlopen(url).read()
+    for url_match in url_regex.finditer(data):
+        new_url = url_match.group(0)
+        # Don't destroy the internet
+        if location in domain(new_url):
+            crawl.delay(new_url)
+            time.sleep(0.3)

+ 15 - 0
examples/gevent/celeryconfig.py

@@ -0,0 +1,15 @@
+import os
+import sys
+sys.path.insert(0, os.getcwd())
+
+CELERYD_POOL = "gevent"
+
+BROKER_HOST = "localhost"
+BROKER_USER = "guest"
+BROKER_PASSWORD = "guest"
+BROKER_VHOST = "/"
+CELERY_DISABLE_RATE_LIMITS = True
+CELERY_RESULT_BACKEND = "amqp"
+CELERY_TASK_RESULT_EXPIRES = 30 * 60
+
+CELERY_IMPORTS = ("tasks", )

+ 15 - 0
examples/gevent/tasks.py

@@ -0,0 +1,15 @@
+import urllib2
+
+from celery.decorators import task
+
+
+@task(ignore_result=True)
+def urlopen(url):
+    print("Opening: %r" % (url, ))
+    try:
+        body = urllib2.urlopen(url).read()
+    except Exception, exc:
+        print("Exception for %r: %r" % (url, exc, ))
+        return url, 0
+    print("Done with: %r" % (url, ))
+    return url, 1