Browse Source

87% coverage.

Ask Solem 15 years ago
parent
commit
7cab9e4c21

+ 48 - 0
celery/tests/test_task_control.py

@@ -0,0 +1,48 @@
+import unittest
+
+from celery.task import control
+from celery.task.builtins import PingTask
+
+
+class MockBroadcastPublisher(object):
+    sent = []
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def send(self, command, *args, **kwargs):
+        self.__class__.sent.append(command)
+
+    def close(self):
+        pass
+
+
+def with_mock_broadcast(fun):
+
+    def _mocked(*args, **kwargs):
+        old_pub = control.BroadcastPublisher
+        control.BroadcastPublisher = MockBroadcastPublisher
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            MockBroadcastPublisher.sent = []
+            control.BroadcastPublisher = old_pub
+    return _mocked
+
+
+class TestBroadcast(unittest.TestCase):
+
+    @with_mock_broadcast
+    def test_broadcast(self):
+        control.broadcast("foobarbaz", arguments=[])
+        self.assertTrue("foobarbaz" in MockBroadcastPublisher.sent)
+
+    @with_mock_broadcast
+    def test_rate_limit(self):
+        control.rate_limit(PingTask.name, "100/m")
+        self.assertTrue("rate_limit" in MockBroadcastPublisher.sent)
+
+    @with_mock_broadcast
+    def test_revoke(self):
+        control.revoke("foozbaaz")
+        self.assertTrue("revoke" in MockBroadcastPublisher.sent)

+ 58 - 0
celery/tests/test_worker_control.py

@@ -0,0 +1,58 @@
+import socket
+import unittest
+
+from celery.task.builtins import PingTask
+from celery.utils import gen_unique_id
+from celery.worker import control
+from celery.worker.revoke import revoked
+from celery.registry import tasks
+
+hostname = socket.gethostname()
+
+class TestControlPanel(unittest.TestCase):
+
+    def setUp(self):
+        self.panel = control.ControlDispatch(hostname=hostname)
+
+    def test_shutdown(self):
+        self.assertRaises(SystemExit, self.panel.execute, "shutdown")
+
+    def test_dump_tasks(self):
+        self.panel.execute("dump_tasks")
+
+    def test_rate_limit(self):
+        task = tasks[PingTask.name]
+        old_rate_limit = task.rate_limit
+        try:
+            self.panel.execute("rate_limit", kwargs=dict(
+                                                task_name=task.name,
+                                                rate_limit="100/m"))
+            self.assertEquals(task.rate_limit, "100/m")
+            self.panel.execute("rate_limit", kwargs=dict(
+                                                task_name=task.name,
+                                                rate_limit=0))
+            self.assertEquals(task.rate_limit, 0)
+        finally:
+            task.rate_limit = old_rate_limit
+
+    def test_rate_limit_nonexistant_task(self):
+        self.panel.execute("rate_limit", kwargs={
+                                "task_name": "xxxx.does.not.exist",
+                                "rate_limit": "1000/s"})
+
+    def test_unexposed_command(self):
+        self.panel.execute("foo", kwargs={})
+
+    def test_revoke(self):
+        uuid = gen_unique_id()
+        m = {"command": "revoke",
+             "destination": hostname,
+             "task_id": uuid}
+        self.panel.dispatch_from_message(m)
+        self.assertTrue(uuid in revoked)
+
+        m = {"command": "revoke",
+             "destination": "does.not.exist",
+             "task_id": uuid + "xxx"}
+        self.panel.dispatch_from_message(m)
+        self.assertTrue(uuid + "xxx" not in revoked)

+ 115 - 0
celery/tests/test_worker_job.py

@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from __future__ import with_statement
+
 import sys
 import logging
 import unittest
@@ -156,12 +158,125 @@ class TestJail(unittest.TestCase):
             del(cache.parse_backend_uri)
 
 
+class MockEventDispatcher(object):
+
+    def __init__(self):
+        self.sent = []
+
+    def send(self, event):
+        self.sent.append(event)
+
+
 class TestTaskWrapper(unittest.TestCase):
 
     def test_task_wrapper_repr(self):
         tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
         self.assertTrue(repr(tw))
 
+    def test_send_event(self):
+        tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.eventer = MockEventDispatcher()
+        tw.send_event("task-frobulated")
+        self.assertTrue("task-frobulated" in tw.eventer.sent)
+
+    def test_send_email(self):
+        from celery import conf
+        from celery.worker import job
+        old_mail_admins = job.mail_admins
+        old_enable_mails = conf.CELERY_SEND_TASK_ERROR_EMAILS
+        mail_sent = [False]
+
+        def mock_mail_admins(*args, **kwargs):
+            mail_sent[0] = True
+
+        job.mail_admins = mock_mail_admins
+        conf.CELERY_SEND_TASK_ERROR_EMAILS = True
+        try:
+            tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
+            try:
+                raise KeyError("foo")
+            except KeyError, exc:
+                einfo = ExceptionInfo(sys.exc_info())
+
+            tw.on_failure(einfo)
+            self.assertTrue(mail_sent[0])
+
+            mail_sent[0] = False
+            conf.CELERY_SEND_TASK_ERROR_EMAILS = False
+            tw.on_failure(einfo)
+            self.assertFalse(mail_sent[0])
+
+        finally:
+            job.mail_admins = old_mail_admins
+            conf.CELERY_SEND_TASK_ERROR_EMAILS = old_enable_mails
+
+    def test_execute_and_trace(self):
+        from celery.worker.job import execute_and_trace
+        res = execute_and_trace(mytask.name, gen_unique_id(), [4], {})
+        self.assertEquals(res, 4 ** 4)
+
+    def test_execute_safe_catches_exception(self):
+        from celery.worker.job import execute_and_trace, WorkerTaskTrace
+        old_exec = WorkerTaskTrace.execute
+
+        def _error_exec(self, *args, **kwargs):
+            raise KeyError("baz")
+
+        WorkerTaskTrace.execute = _error_exec
+        try:
+            import warnings
+            with warnings.catch_warnings(record=True) as log:
+                res = execute_and_trace(mytask.name, gen_unique_id(),
+                                        [4], {})
+                self.assertTrue(isinstance(res, ExceptionInfo))
+                self.assertTrue(log)
+                self.assertTrue("Exception outside" in log[0].message.args[0])
+                self.assertTrue("KeyError" in log[0].message.args[0])
+        finally:
+            WorkerTaskTrace.execute = old_exec
+
+    def create_exception(self, exc):
+        try:
+            raise exc
+        except exc.__class__, thrown:
+            return sys.exc_info()
+
+    def test_worker_task_trace_handle_retry(self):
+        from celery.exceptions import RetryTaskError
+        uuid = gen_unique_id()
+        w = WorkerTaskTrace(mytask.name, uuid, [4], {})
+        type_, value_, tb_ = self.create_exception(ValueError("foo"))
+        type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
+                                                                  exc=value_))
+        w._store_errors = False
+        w.handle_retry(value_, type_, tb_, "")
+        self.assertEquals(mytask.backend.get_status(uuid), "PENDING")
+        w._store_errors = True
+        w.handle_retry(value_, type_, tb_, "")
+        self.assertEquals(mytask.backend.get_status(uuid), "RETRY")
+
+
+    def test_worker_task_trace_handle_failure(self):
+        from celery.worker.job import WorkerTaskTrace
+        uuid = gen_unique_id()
+        w = WorkerTaskTrace(mytask.name, uuid, [4], {})
+        type_, value_, tb_ = self.create_exception(ValueError("foo"))
+        w._store_errors = False
+        w.handle_failure(value_, type_, tb_, "")
+        self.assertEquals(mytask.backend.get_status(uuid), "PENDING")
+        w._store_errors = True
+        w.handle_failure(value_, type_, tb_, "")
+        self.assertEquals(mytask.backend.get_status(uuid), "FAILURE")
+
+
+    def test_executed_bit(self):
+        from celery.worker.job import AlreadyExecutedError
+        tw = TaskWrapper(mytask.name, gen_unique_id(), [], {})
+        self.assertFalse(tw.executed)
+        tw._set_executed_bit()
+        self.assertTrue(tw.executed)
+        self.assertRaises(AlreadyExecutedError, tw._set_executed_bit)
+
     def test_task_wrapper_mail_attrs(self):
         tw = TaskWrapper(mytask.name, gen_unique_id(), [], {})
         x = tw.success_msg % {"name": tw.task_name,

+ 2 - 70
celery/utils/__init__.py

@@ -19,6 +19,8 @@ from itertools import repeat
 
 from billiard.utils.functional import curry
 
+from celery.utils.compat import all, any, defaultdict
+
 noop = lambda *args, **kwargs: None
 
 
@@ -159,73 +161,3 @@ def fun_takes_kwargs(fun, kwlist=[]):
     if keywords != None:
         return kwlist
     return filter(curry(operator.contains, args), kwlist)
-
-
-try:
-    from collections import defaultdict
-except ImportError:
-    # Written by Jason Kirtland, taken from Python Cookbook:
-    # <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/523034>
-    class defaultdict(dict):
-
-        def __init__(self, default_factory=None, *args, **kwargs):
-            dict.__init__(self, *args, **kwargs)
-            self.default_factory = default_factory
-
-        def __getitem__(self, key):
-            try:
-                return dict.__getitem__(self, key)
-            except KeyError:
-                return self.__missing__(key)
-
-        def __missing__(self, key):
-            if self.default_factory is None:
-                raise KeyError(key)
-            self[key] = value = self.default_factory()
-            return value
-
-        def __reduce__(self):
-            f = self.default_factory
-            args = f is None and tuple() or f
-            return type(self), args, None, None, self.iteritems()
-
-        def copy(self):
-            return self.__copy__()
-
-        def __copy__(self):
-            return type(self)(self.default_factory, self)
-
-        def __deepcopy__(self):
-            import copy
-            return type(self)(self.default_factory,
-                        copy.deepcopy(self.items()))
-
-        def __repr__(self):
-            return "defaultdict(%s, %s)" % (self.default_factory,
-                                            dict.__repr__(self))
-    import collections
-    collections.defaultdict = defaultdict # Pickle needs this.
-
-
-try:
-    all([True])
-    all = all
-except NameError:
-    def all(iterable):
-        for item in iterable:
-            if not item:
-                return False
-        return True
-
-
-try:
-    any([True])
-    any = any
-except NameError:
-    def any(iterable):
-        for item in iterable:
-            if item:
-                return True
-        return False
-
-

+ 66 - 0
celery/utils/compat.py

@@ -0,0 +1,66 @@
+try:
+    from collections import defaultdict
+except ImportError:
+    # Written by Jason Kirtland, taken from Python Cookbook:
+    # <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/523034>
+    class defaultdict(dict):
+
+        def __init__(self, default_factory=None, *args, **kwargs):
+            dict.__init__(self, *args, **kwargs)
+            self.default_factory = default_factory
+
+        def __getitem__(self, key):
+            try:
+                return dict.__getitem__(self, key)
+            except KeyError:
+                return self.__missing__(key)
+
+        def __missing__(self, key):
+            if self.default_factory is None:
+                raise KeyError(key)
+            self[key] = value = self.default_factory()
+            return value
+
+        def __reduce__(self):
+            f = self.default_factory
+            args = f is None and tuple() or f
+            return type(self), args, None, None, self.iteritems()
+
+        def copy(self):
+            return self.__copy__()
+
+        def __copy__(self):
+            return type(self)(self.default_factory, self)
+
+        def __deepcopy__(self):
+            import copy
+            return type(self)(self.default_factory,
+                        copy.deepcopy(self.items()))
+
+        def __repr__(self):
+            return "defaultdict(%s, %s)" % (self.default_factory,
+                                            dict.__repr__(self))
+    import collections
+    collections.defaultdict = defaultdict # Pickle needs this.
+
+
+try:
+    all([True])
+    all = all
+except NameError:
+    def all(iterable):
+        for item in iterable:
+            if not item:
+                return False
+        return True
+
+
+try:
+    any([True])
+    any = any
+except NameError:
+    def any(iterable):
+        for item in iterable:
+            if item:
+                return True
+        return False

+ 3 - 2
celery/worker/control.py

@@ -82,7 +82,7 @@ class ControlDispatch(object):
 
     panel_cls = Control
 
-    def __init__(self, logger, hostname=None):
+    def __init__(self, logger=None, hostname=None):
         self.logger = logger or log.get_default_logger()
         self.hostname = hostname
         self.panel = self.panel_cls(self.logger, hostname=self.hostname)
@@ -104,13 +104,14 @@ class ControlDispatch(object):
         if not destination or self.hostname in destination:
             return self.execute(command, message)
 
-    def execute(self, command, kwargs):
+    def execute(self, command, kwargs=None):
         """Execute control command by name and keyword arguments.
 
         :param command: Name of the command to execute.
         :param kwargs: Keyword arguments.
 
         """
+        kwargs = kwargs or {}
         control = None
         try:
             control = getattr(self.panel, command)

+ 3 - 3
celery/worker/job.py

@@ -82,9 +82,9 @@ class WorkerTaskTrace(TaskTrace):
         try:
             return self.execute(*args, **kwargs)
         except Exception, exc:
-            exc_info = sys.exc_info()
-            exc_info[1] = self.task_backend.prepare_exception(exc)
-            exc_info = ExceptionInfo(exc_info)
+            _type, _value, _tb = sys.exc_info()
+            _value = self.task.backend.prepare_exception(exc)
+            exc_info = ExceptionInfo((_type, _value, _tb))
             warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
                 map(str, (exc.__class__, exc, exc_info.traceback))))
             return exc_info

+ 1 - 0
testproj/settings.py

@@ -27,6 +27,7 @@ COVERAGE_EXCLUDE_MODULES = ("celery.__init__",
                             "celery.contrib.*",
                             "celery.bin.*",
                             "celery.utils.patch",
+                            "celery.utils.compat",
                             "celery.task.rest",
                             "celery.platform", # FIXME
                             "celery.loaders.default", # FIXME