93% total coverage.

celery.bin.celeryd and celery.concurrency.processes now fully covered.
Ask Solem committed 15 years ago (commit 1e458526f3)
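
The recurring change in this commit is turning hard-wired collaborators into class attributes (Worker.WorkController, TaskPool.Pool) so the new tests can subclass and drop in stubs instead of monkey-patching module globals. A minimal sketch of the pattern outside Celery (all names below are illustrative, not part of the codebase):

    class RealEngine(object):
        def start(self):
            print("starting the real engine")


    class Service(object):
        # Expose the collaborator as a class attribute instead of referring
        # to the module-level name inside the method body.
        Engine = RealEngine

        def run(self):
            self.engine = self.Engine()
            self.engine.start()


    class StubEngine(object):
        started = False

        def start(self):
            self.started = True


    class StubbedService(Service):
        # Tests swap the collaborator by subclassing; nothing global is patched.
        Engine = StubEngine


    service = StubbedService()
    service.run()
    assert service.engine.started

The Worker and TaskPool changes below, and the test doubles in the new test modules, follow this shape.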

celery/bin/celeryd.py (+3, -12)

@@ -160,6 +160,7 @@ OPTION_LIST = (
 
 
 class Worker(object):
+    WorkController = WorkController
 
     def __init__(self, concurrency=conf.CELERYD_CONCURRENCY,
             loglevel=conf.CELERYD_LOG_LEVEL, logfile=conf.CELERYD_LOG_FILE,
@@ -195,15 +196,6 @@ class Worker(object):
         print("celery@%s v%s is starting." % (self.hostname,
                                               celery.__version__))
 
-
-        if conf.RESULT_BACKEND == "database" \
-                and self.settings.DATABASE_ENGINE == "sqlite3" and \
-                self.concurrency > 1:
-            warnings.warn("The sqlite3 database engine doesn't handle "
-                          "concurrency well. Will use a single process only.",
-                          UserWarning)
-            self.concurrency = 1
-
         if getattr(self.settings, "DEBUG", False):
             warnings.warn("Using settings.DEBUG leads to a memory leak, "
                     "never use this setting in a production environment!")
@@ -232,7 +224,6 @@ class Worker(object):
                 if queue not in conf.QUEUES:
                     if conf.CREATE_MISSING_QUEUES:
                         Router(queues=conf.QUEUES).add_queue(queue)
-                        print("QUEUES: %s" % conf.QUEUES)
                     else:
                         raise ImproperlyConfigured(
                             "Queue '%s' not defined in CELERY_QUEUES" % queue)
@@ -293,7 +284,7 @@ class Worker(object):
         }
 
     def run_worker(self):
-        worker = WorkController(concurrency=self.concurrency,
+        worker = self.WorkController(concurrency=self.concurrency,
                                 loglevel=self.loglevel,
                                 logfile=self.logfile,
                                 hostname=self.hostname,
@@ -381,7 +372,7 @@ def set_process_status(info):
     arg_start = "manage" in sys.argv[0] and 2 or 1
     if sys.argv[arg_start:]:
         info = "%s (%s)" % (info, " ".join(sys.argv[arg_start:]))
-    platform.set_mp_process_title("celeryd", info=info)
+    return platform.set_mp_process_title("celeryd", info=info)
 
 
 def run_worker(**options):
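
Beyond the test suite, resolving the controller through self.WorkController means a Worker subclass can supply its own controller class. A hedged sketch of that seam (EmbeddedController and EmbeddedWorker are hypothetical, and this assumes WorkController is importable from celery.worker, as celeryd imported it at the time):

    from celery.bin.celeryd import Worker
    from celery.worker import WorkController


    class EmbeddedController(WorkController):
        """Hypothetical controller variant."""

        def start(self):
            # ... customise startup here ...
            return super(EmbeddedController, self).start()


    class EmbeddedWorker(Worker):
        # run_worker() now instantiates self.WorkController, so overriding
        # this attribute is all that is needed to swap the controller class.
        WorkController = EmbeddedController

The other change in this hunk, having set_process_status() return the computed title, is what lets the new test_set_process_status below assert on the returned string.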

celery/concurrency/processes/__init__.py (+7, -6)

@@ -27,6 +27,7 @@ class TaskPool(object):
         The logger used for debugging.
 
     """
+    Pool = Pool
 
     def __init__(self, limit, logger=None, initializer=None,
             maxtasksperchild=None, timeout=None, soft_timeout=None,
@@ -46,11 +47,11 @@ class TaskPool(object):
         Will pre-fork all workers so they're ready to accept tasks.
 
         """
-        self._pool = Pool(processes=self.limit,
-                          initializer=self.initializer,
-                          timeout=self.timeout,
-                          soft_timeout=self.soft_timeout,
-                          maxtasksperchild=self.maxtasksperchild)
+        self._pool = self.Pool(processes=self.limit,
+                               initializer=self.initializer,
+                               timeout=self.timeout,
+                               soft_timeout=self.soft_timeout,
+                               maxtasksperchild=self.maxtasksperchild)
 
     def stop(self):
         """Gracefully stop the pool."""
@@ -96,7 +97,7 @@ class TaskPool(object):
 
         if isinstance(ret_value, ExceptionInfo):
             if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)): # pragma: no cover
+                    SystemExit, KeyboardInterrupt)):
                 raise ret_value.exception
             [errback(ret_value) for errback in errbacks]
         else:

celery/platform.py (+3, -2)

@@ -53,10 +53,11 @@ def set_process_title(progname, info=None):
     Only works if :mod`setproctitle` is installed.
 
     """
+    proctitle = "[%s]" % progname
+    proctitle = info and "%s %s" % (proctitle, info) or proctitle
     if _setproctitle:
-        proctitle = "[%s]" % progname
-        proctitle = info and "%s %s" % (proctitle, info) or proctitle
         _setproctitle(proctitle)
+    return proctitle
 
 
 def set_mp_process_title(progname, info=None):
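
Building the title before the _setproctitle check and returning it means the computed string can be verified even on systems where setproctitle is not installed, which is what the new celeryd tests below rely on through set_process_status(). A quick illustration of the contract this hunk introduces:

    from celery import platform

    # The title is returned whether or not setproctitle is available.
    assert platform.set_process_title("celeryd") == "[celeryd]"
    assert platform.set_process_title("celeryd", info="Running") == "[celeryd] Running"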

celery/tests/test_bin/__init__.py (+0, -0)


celery/tests/test_bin/test_celeryd.py (+305, -0)

@@ -0,0 +1,305 @@
+import logging
+import os
+import sys
+import unittest2 as unittest
+
+from multiprocessing import get_logger, current_process
+from StringIO import StringIO
+
+from celery import conf
+from celery import platform
+from celery import signals
+from celery.bin import celeryd as cd
+from celery.exceptions import ImproperlyConfigured
+from celery.utils import noop
+from celery.utils import patch
+from celery.utils.functional import wraps
+
+from celery.tests.compat import catch_warnings
+from celery.tests.utils import execute_context
+
+
+patch.ensure_process_aware_logger()
+
+def disable_stdouts(fun):
+
+    @wraps(fun)
+    def disable(*args, **kwargs):
+        sys.stdout, sys.stderr = StringIO(), StringIO()
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            sys.stdout = sys.__stdout__
+            sys.stderr = sys.__stderr__
+
+    return disable
+
+
+
+class _WorkController(object):
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def start(self):
+        pass
+
+
+class Worker(cd.Worker):
+    WorkController = _WorkController
+
+
+class test_Worker(unittest.TestCase):
+    Worker = Worker
+
+    @disable_stdouts
+    def test_queues_string(self):
+        worker = self.Worker(queues="foo,bar,baz")
+        self.assertEqual(worker.queues, ["foo", "bar", "baz"])
+
+    @disable_stdouts
+    def test_loglevel_string(self):
+        worker = self.Worker(loglevel="INFO")
+        self.assertEqual(worker.loglevel, logging.INFO)
+
+    @disable_stdouts
+    def test_run_worker(self):
+        handlers = {}
+
+        def i(sig, handler):
+            handlers[sig] = handler
+
+        p = platform.install_signal_handler
+        platform.install_signal_handler = i
+        try:
+            self.Worker().run_worker()
+            for sig in "SIGINT", "SIGHUP", "SIGTERM":
+                self.assertIn(sig, handlers)
+        finally:
+            platform.install_signal_handler = p
+
+    @disable_stdouts
+    def test_startup_info(self):
+        worker = self.Worker()
+        worker.run()
+        self.assertTrue(worker.startup_info())
+        worker.loglevel = logging.DEBUG
+        self.assertTrue(worker.startup_info())
+        worker.loglevel = logging.INFO
+        self.assertTrue(worker.startup_info())
+
+    @disable_stdouts
+    def test_run(self):
+        self.Worker().run()
+        self.Worker(discard=True).run()
+
+        worker = self.Worker()
+        worker.init_loader()
+        worker.settings.DEBUG = True
+
+        def with_catch_warnings(log):
+            worker.run()
+            self.assertIn("memory leak", log[0].message.args[0])
+
+        context = catch_warnings(record=True)
+        execute_context(context, with_catch_warnings)
+        worker.settings.DEBUG = False
+
+    @disable_stdouts
+    def test_purge_messages(self):
+        self.Worker().purge_messages()
+
+    @disable_stdouts
+    def test_init_queues(self):
+        p, conf.QUEUES = conf.QUEUES, {
+                "celery": {"exchange": "celery",
+                           "binding_key": "celery"},
+                "video": {"exchange": "video",
+                           "binding_key": "video"}}
+        try:
+            self.Worker(queues=["video"]).init_queues()
+            self.assertIn("video", conf.QUEUES)
+            self.assertNotIn("celery", conf.QUEUES)
+
+            conf.CREATE_MISSING_QUEUES = False
+            self.assertRaises(ImproperlyConfigured,
+                    self.Worker(queues=["image"]).init_queues)
+            conf.CREATE_MISSING_QUEUES = True
+            self.Worker(queues=["image"]).init_queues()
+            self.assertIn("image", conf.QUEUES)
+        finally:
+            conf.QUEUES = p
+
+    @disable_stdouts
+    def test_on_listener_ready(self):
+
+        worker_ready_sent = [False]
+        def on_worker_ready(**kwargs):
+            worker_ready_sent[0] = True
+
+        signals.worker_ready.connect(on_worker_ready)
+
+        self.Worker().on_listener_ready(object())
+        self.assertTrue(worker_ready_sent[0])
+
+
+
+class test_funs(unittest.TestCase):
+
+    @disable_stdouts
+    def test_dump_version(self):
+        self.assertRaises(SystemExit, cd.dump_version)
+
+    @disable_stdouts
+    def test_set_process_status(self):
+        prev1, sys.argv = sys.argv, ["Arg0"]
+        try:
+            st = cd.set_process_status("Running")
+            self.assertIn("celeryd", st)
+            self.assertIn("Running", st)
+            prev2, sys.argv = sys.argv, ["Arg0", "Arg1"]
+            try:
+                st = cd.set_process_status("Running")
+                self.assertIn("celeryd", st)
+                self.assertIn("Running", st)
+                self.assertIn("Arg1", st)
+            finally:
+                sys.argv = prev2
+        finally:
+            sys.argv = prev1
+
+    @disable_stdouts
+    def test_parse_options(self):
+        opts = cd.parse_options(["--concurrency=512"])
+        self.assertEqual(opts.concurrency, 512)
+
+    @disable_stdouts
+    def test_run_worker(self):
+        p, cd.Worker = cd.Worker, Worker
+        try:
+            cd.run_worker(discard=True)
+        finally:
+            cd.Worker = p
+
+    @disable_stdouts
+    def test_main(self):
+        p, cd.Worker = cd.Worker, Worker
+        s, sys.argv = sys.argv, ["celeryd", "--discard"]
+        try:
+            cd.main()
+        finally:
+            cd.Worker = p
+            sys.argv = s
+
+
+
+class test_signal_handlers(unittest.TestCase):
+
+    class _Worker(object):
+        stopped = False
+        terminated = False
+        logger = get_logger()
+
+        def stop(self):
+            self.stopped = True
+
+        def terminate(self):
+            self.terminated = True
+
+    def psig(self, fun, *args, **kwargs):
+        handlers = {}
+
+        def i(sig, handler):
+            handlers[sig] = handler
+
+        p, platform.install_signal_handler = platform.install_signal_handler, i
+        try:
+            fun(*args, **kwargs)
+            return handlers
+        finally:
+            platform.install_signal_handler = p
+
+    @disable_stdouts
+    def test_worker_int_handler(self):
+        worker = self._Worker()
+        handlers = self.psig(cd.install_worker_int_handler, worker)
+
+        next_handlers = {}
+        def i(sig, handler):
+            next_handlers[sig] = handler
+        p = platform.install_signal_handler
+        platform.install_signal_handler = i
+        try:
+            self.assertRaises(SystemExit, handlers["SIGINT"],
+                              "SIGINT", object())
+            self.assertTrue(worker.stopped)
+        finally:
+            platform.install_signal_handler = p
+
+        self.assertRaises(SystemExit, next_handlers["SIGINT"],
+                          "SIGINT", object())
+        self.assertTrue(worker.terminated)
+
+    @disable_stdouts
+    def test_worker_int_handler_only_stop_MainProcess(self):
+        process = current_process()
+        name, process.name = process.name, "OtherProcess"
+        try:
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_int_handler, worker)
+            self.assertRaises(SystemExit, handlers["SIGINT"],
+                            "SIGINT", object())
+            self.assertFalse(worker.stopped)
+        finally:
+            process.name = name
+
+    @disable_stdouts
+    def test_worker_int_again_handler_only_stop_MainProcess(self):
+        process = current_process()
+        name, process.name = process.name, "OtherProcess"
+        try:
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_int_again_handler, worker)
+            self.assertRaises(SystemExit, handlers["SIGINT"],
+                            "SIGINT", object())
+            self.assertFalse(worker.terminated)
+        finally:
+            process.name = name
+
+    @disable_stdouts
+    def test_worker_term_handler(self):
+        worker = self._Worker()
+        handlers = self.psig(cd.install_worker_term_handler, worker)
+        self.assertRaises(SystemExit, handlers["SIGTERM"],
+                          "SIGTERM", object())
+        self.assertTrue(worker.stopped)
+
+    @disable_stdouts
+    def test_worker_term_handler_only_stop_MainProcess(self):
+        process = current_process()
+        name, process.name = process.name, "OtherProcess"
+        try:
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_term_handler, worker)
+            self.assertRaises(SystemExit, handlers["SIGTERM"],
+                          "SIGTERM", object())
+            self.assertFalse(worker.stopped)
+        finally:
+            process.name = name
+
+    @disable_stdouts
+    def test_worker_restart_handler(self):
+        argv = []
+
+        def _execv(*args):
+            argv.extend(args)
+
+        execv, os.execv = os.execv, _execv
+        try:
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_restart_handler, worker)
+            handlers["SIGHUP"]("SIGHUP", object())
+            self.assertTrue(worker.stopped)
+            self.assertTrue(argv)
+        finally:
+            os.execv = execv

celery/tests/test_concurrency_processes.py (+95, -0)

@@ -0,0 +1,95 @@
+import sys
+import unittest2 as unittest
+
+from celery.concurrency import processes as mp
+from celery.datastructures import ExceptionInfo
+
+
+def to_excinfo(exc):
+    try:
+        raise exc
+    except:
+        return ExceptionInfo(sys.exc_info())
+
+
+class MockPool(object):
+    started = False
+    closed = False
+    joined = False
+    terminated = False
+    _state = None
+
+    def __init__(self, *args, **kwargs):
+        self.started = True
+        self._state = mp.RUN
+
+    def close(self):
+        self.closed = True
+        self._state = "CLOSE"
+
+    def join(self):
+        self.joined = True
+
+    def terminate(self):
+        self.terminated = True
+
+    def apply_async(self, *args, **kwargs):
+        pass
+
+
+class TaskPool(mp.TaskPool):
+    Pool = MockPool
+
+
+class test_TaskPool(unittest.TestCase):
+
+    def test_start(self):
+        pool = TaskPool(10)
+        pool.start()
+        self.assertTrue(pool._pool.started)
+
+        _pool = pool._pool
+        pool.stop()
+        self.assertTrue(_pool.closed)
+        self.assertTrue(_pool.joined)
+        pool.stop()
+
+        pool.start()
+        _pool = pool._pool
+        pool.terminate()
+        pool.terminate()
+        self.assertTrue(_pool.terminated)
+
+    def test_on_ready_exception(self):
+
+        scratch = [None]
+        def errback(retval):
+            scratch[0] = retval
+
+        pool = TaskPool(10)
+        exc = to_excinfo(KeyError("foo"))
+        pool.on_ready([], [errback], exc)
+        self.assertEqual(exc, scratch[0])
+
+    def test_on_ready_value(self):
+
+        scratch = [None]
+        def callback(retval):
+            scratch[0] = retval
+
+        pool = TaskPool(10)
+        retval = "the quick brown fox"
+        pool.on_ready([callback], [], retval)
+        self.assertEqual(retval, scratch[0])
+
+    def test_on_ready_exit_exception(self):
+        pool = TaskPool(10)
+        exc = to_excinfo(SystemExit("foo"))
+        self.assertRaises(SystemExit, pool.on_ready, [], [], exc)
+
+    def test_apply_async(self):
+        pool = TaskPool(10)
+        pool.start()
+        pool.apply_async(lambda x: x, (2, ), {})
+
+

setup.cfg (+1, -2)

@@ -6,9 +6,9 @@ cover3-package = celery
 cover3-exclude = celery
                  celery.conf
                  celery.tests.*
-                 celery.bin.celeryd
                  celery.bin.celerybeat
                  celery.bin.celeryev
+                 celery.platform
                  celery.utils.patch
                  celery.utils.compat
                  celery.utils.mail
@@ -18,7 +18,6 @@ cover3-exclude = celery
                  celery.contrib*
                  celery.concurrency.threads
                  celery.concurrency.processes.pool
-                 celery.platform
                  celery.backends.mongodb
                  celery.backends.tyrant
                  celery.backends.pyredis