فهرست منبع

More coverage

Ask Solem 14 سال پیش
والد
کامیت
32809ebc0c

+ 8 - 9
celery/apps/worker.py

@@ -49,6 +49,11 @@ def cpu_count():
     return 2
 
 
+def get_process_name():
+    if multiprocessing is not None:
+        return multiprocessing.current_process().name
+
+
 class Worker(object):
     WorkController = WorkController
 
@@ -274,9 +279,7 @@ class Worker(object):
 def install_worker_int_handler(worker):
 
     def _stop(signum, frame):
-        process_name = None
-        if multiprocessing:
-            process_name = multiprocessing.current_process().name
+        process_name = get_process_name()
         if not process_name or process_name == "MainProcess":
             worker.logger.warn(
                 "celeryd: Hitting Ctrl+C again will terminate "
@@ -293,9 +296,7 @@ def install_worker_int_handler(worker):
 def install_worker_int_again_handler(worker):
 
     def _stop(signum, frame):
-        process_name = None
-        if multiprocessing:
-            process_name = multiprocessing.current_process().name
+        process_name = get_process_name()
         if not process_name or process_name == "MainProcess":
             worker.logger.warn("celeryd: Cold shutdown (%s)" % (
                 process_name))
@@ -308,9 +309,7 @@ def install_worker_int_again_handler(worker):
 def install_worker_term_handler(worker):
 
     def _stop(signum, frame):
-        process_name = None
-        if multiprocessing:
-            process_name = multiprocessing.current_process().name
+        process_name = get_process_name()
         if not process_name or process_name == "MainProcess":
             worker.logger.warn("celeryd: Warm shutdown (%s)" % (
                 process_name))

+ 6 - 0
celery/tests/test_backends/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 from celery.tests.utils import unittest
 
 from celery import backends
@@ -20,3 +22,7 @@ class TestBackends(unittest.TestCase):
         self.assertIn("amqp", backends._backend_cache)
         amqp_backend = backends.get_backend_cls("amqp")
         self.assertIs(amqp_backend, backends._backend_cache["amqp"])
+
+    def test_unknown_backend(self):
+        with self.assertRaises(ValueError):
+            backends.get_backend_cls("fasodaopjeqijwqe")

+ 61 - 3
celery/tests/test_bin/test_celeryd.py

@@ -11,6 +11,7 @@ try:
 except ImportError:
     current_process = None  # noqa
 
+from mock import patch
 from nose import SkipTest
 from kombu.tests.utils import redirect_stdouts
 
@@ -22,13 +23,14 @@ from celery.apps import worker as cd
 from celery.bin.celeryd import WorkerCommand, windows_main, \
                                main as celeryd_main
 from celery.exceptions import ImproperlyConfigured
-from celery.utils import patch
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import AppCase, StringIO
+from celery.tests.utils import (AppCase, StringIO, mask_modules,
+                                reset_modules, patch_modules)
 
 
-patch.ensure_process_aware_logger()
+from celery.utils.patch import ensure_process_aware_logger
+ensure_process_aware_logger()
 
 
 def disable_stdouts(fun):
@@ -58,6 +60,39 @@ class Worker(cd.Worker):
     WorkController = _WorkController
 
 
+class test_compilation(AppCase):
+
+    def test_no_multiprocessing(self):
+        with mask_modules("multiprocessing"):
+            with reset_modules("celery.apps.worker"):
+                from celery.apps.worker import multiprocessing
+                self.assertIsNone(multiprocessing)
+
+    def test_cpu_count_no_mp(self):
+        with mask_modules("multiprocessing"):
+            with reset_modules("celery.apps.worker"):
+                from celery.apps.worker import cpu_count
+                self.assertEqual(cpu_count(), 2)
+
+    @patch("multiprocessing.cpu_count")
+    def test_no_cpu_count(self, pcount):
+        pcount.side_effect = NotImplementedError("cpu_count")
+        from celery.apps.worker import cpu_count
+        self.assertEqual(cpu_count(), 2)
+        pcount.assert_called_with()
+
+    def test_process_name(self):
+        with mask_modules("multiprocessing"):
+            with reset_modules("celery.apps.worker"):
+                from celery.apps.worker import get_process_name
+                self.assertIsNone(get_process_name())
+
+    @patch("multiprocessing.current_process")
+    def test_process_name(self, current_process):
+            from celery.apps.worker import get_process_name
+            self.assertTrue(get_process_name())
+
+
 class test_Worker(AppCase):
     Worker = Worker
 
@@ -69,6 +104,27 @@ class test_Worker(AppCase):
         self.assertEqual(worker.use_queues, ["foo", "bar", "baz"])
         self.assertTrue("foo" in celery.amqp.queues)
 
+    @disable_stdouts
+    def test_windows_B_option(self):
+        celery = Celery(set_as_current=False)
+        celery.IS_WINDOWS = True
+        with self.assertRaises(SystemExit):
+            celery.Worker(run_clockservice=True)
+
+    def test_tasklist(self):
+        celery = Celery(set_as_current=False)
+        worker = celery.Worker()
+        self.assertTrue(worker.tasklist(include_builtins=True))
+        worker.tasklist(include_builtins=False)
+
+    def test_extra_info(self):
+        celery = Celery(set_as_current=False)
+        worker = celery.Worker()
+        worker.loglevel = logging.WARNING
+        self.assertFalse(worker.extra_info())
+        worker.loglevel = logging.INFO
+        self.assertTrue(worker.extra_info())
+
     @disable_stdouts
     def test_loglevel_string(self):
         worker = self.Worker(loglevel="INFO")
@@ -110,6 +166,8 @@ class test_Worker(AppCase):
         self.assertTrue(worker.startup_info())
         worker.loglevel = logging.INFO
         self.assertTrue(worker.startup_info())
+        worker.autoscale = 13, 10
+        self.assertTrue(worker.startup_info())
 
     @disable_stdouts
     def test_run(self):

+ 60 - 6
celery/tests/test_worker/test_worker_control.py

@@ -8,7 +8,7 @@ from mock import Mock
 
 from celery.utils.timer2 import Timer
 
-from celery.app import app_or_default
+from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.registry import tasks
@@ -16,7 +16,9 @@ from celery.task import PingTask
 from celery.utils import gen_unique_id
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
+from celery.worker import state
 from celery.worker.state import revoked
+from celery.worker.control import builtins
 from celery.worker.control.registry import Panel
 
 hostname = socket.gethostname()
@@ -36,7 +38,7 @@ class Consumer(object):
                                          args=(2, 2),
                                          kwargs={}))
         self.eta_schedule = Timer()
-        self.app = app_or_default()
+        self.app = current_app
         self.event_dispatcher = Mock()
 
         from celery.concurrency.base import BasePool
@@ -50,7 +52,7 @@ class Consumer(object):
 class test_ControlPanel(unittest.TestCase):
 
     def setUp(self):
-        self.app = app_or_default()
+        self.app = current_app
         self.panel = self.create_panel(consumer=Consumer())
 
     def create_state(self, **kwargs):
@@ -71,7 +73,8 @@ class test_ControlPanel(unittest.TestCase):
         self.assertTrue(consumer.event_dispatcher.enable.call_count)
         self.assertIn(("worker-online", ),
                 consumer.event_dispatcher.send.call_args)
-        self.assertTrue(panel.handle("enable_events")["ok"])
+        consumer.event_dispatcher.enabled = True
+        self.assertIn("already enabled", panel.handle("enable_events")["ok"])
 
     def test_disable_events(self):
         consumer = Consumer()
@@ -81,7 +84,8 @@ class test_ControlPanel(unittest.TestCase):
         self.assertTrue(consumer.event_dispatcher.disable.call_count)
         self.assertIn(("worker-offline", ),
                       consumer.event_dispatcher.send.call_args)
-        self.assertTrue(panel.handle("disable_events")["ok"])
+        consumer.event_dispatcher.enabled = False
+        self.assertIn("already disabled", panel.handle("disable_events")["ok"])
 
     def test_heartbeat(self):
         consumer = Consumer()
@@ -91,6 +95,41 @@ class test_ControlPanel(unittest.TestCase):
         self.assertIn(("worker-heartbeat", ),
                       consumer.event_dispatcher.send.call_args)
 
+    def test_time_limit(self):
+        panel = self.create_panel(consumer=Mock())
+        th, ts = mytask.time_limit, mytask.soft_time_limit
+        try:
+            r = panel.handle("time_limit", arguments=dict(
+                task_name=mytask.name, hard=30, soft=10))
+            self.assertEqual((mytask.time_limit, mytask.soft_time_limit),
+                             (30, 10))
+            self.assertIn("ok", r)
+            r = panel.handle("time_limit", arguments=dict(
+                task_name=mytask.name, hard=None, soft=None))
+            self.assertEqual((mytask.time_limit, mytask.soft_time_limit),
+                             (None, None))
+            self.assertIn("ok", r)
+
+            r = panel.handle("time_limit", arguments=dict(
+                task_name="248e8afya9s8dh921eh928", hard=30))
+            self.assertIn("error", r)
+        finally:
+            mytask.time_limit, mytask.soft_time_limit = th, ts
+
+    def test_active_queues(self):
+        import kombu
+
+        x = kombu.Consumer(current_app.broker_connection(),
+                           [kombu.Queue("foo", kombu.Exchange("foo"), "foo"),
+                            kombu.Queue("bar", kombu.Exchange("bar"), "bar")],
+                           auto_declare=False)
+        consumer = Mock()
+        consumer.task_consumer = x
+        panel = self.create_panel(consumer=consumer)
+        r = panel.handle("active_queues")
+        self.assertListEqual(list(sorted(q["name"] for q in r)),
+                             ["bar", "foo"])
+
     def test_dump_tasks(self):
         info = "\n".join(self.panel.handle("dump_tasks"))
         self.assertIn("mytask", info)
@@ -200,7 +239,7 @@ class test_ControlPanel(unittest.TestCase):
         self.assertFalse(panel.handle("dump_reserved"))
 
     def test_rate_limit_when_disabled(self):
-        app = app_or_default()
+        app = current_app
         app.conf.CELERY_DISABLE_RATE_LIMITS = True
         try:
             e = self.panel.handle("rate_limit", arguments=dict(
@@ -285,6 +324,21 @@ class test_ControlPanel(unittest.TestCase):
         self.panel.dispatch_from_message(m)
         self.assertNotIn(uuid + "xxx", revoked)
 
+    def test_revoke_terminate(self):
+        request = Mock()
+        request.task_id = uuid = gen_unique_id()
+        state.active_requests.add(request)
+        try:
+            r = builtins.revoke(Mock(), uuid, terminate=True)
+            self.assertIn(uuid, revoked)
+            self.assertTrue(request.terminate.call_count)
+            self.assertIn("terminated", r["ok"])
+            # unknown task id only revokes
+            r = builtins.revoke(Mock(), gen_unique_id(), terminate=True)
+            self.assertIn("revoked", r["ok"])
+        finally:
+            state.active_requests.discard(request)
+
     def test_ping(self):
         m = {"method": "ping",
              "destination": hostname}

+ 24 - 0
celery/tests/utils.py

@@ -296,3 +296,27 @@ def pypy_version(value=None):
     yield
     if prev is not None:
         sys.pypy_version_info = prev
+
+
+@contextmanager
+def reset_modules(*modules):
+    prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
+    yield
+    sys.modules.update(prev)
+
+
+@contextmanager
+def patch_modules(*modules):
+    from types import ModuleType
+
+    prev = {}
+    for mod in modules:
+        prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
+    yield
+    for name, mod in prev.iteritems():
+        if mod is None:
+            sys.modules.pop(name, None)
+        else:
+            sys.modules[name] = mod
+
+

+ 1 - 1
celery/worker/control/builtins.py

@@ -19,7 +19,7 @@ def revoke(panel, task_id, terminate=False, signal=None, **kwargs):
     revoked.add(task_id)
     action = "revoked"
     if terminate:
-        signum = _signals.signum(signal)
+        signum = _signals.signum(signal or "TERM")
         for request in state.active_requests:
             if request.task_id == task_id:
                 action = "terminated (%s)" % (signum, )