Ask Solem 15 år sedan
förälder
incheckning
61b4d9f05e

+ 1 - 1
celery/loaders/__init__.py

@@ -27,7 +27,7 @@ def get_loader_cls(loader):
     return _loader_cache[loader]
 
 
-def _detect_loader():
+def _detect_loader(): # pragma: no cover
     loader = os.environ.get("CELERY_LOADER")
     if loader:
         return get_loader_cls(loader)

+ 2 - 2
celery/loaders/base.py

@@ -27,11 +27,11 @@ class BaseLoader(object):
         pass
 
     def import_task_module(self, module):
-        __import__(module, [], [], [''])
+        return __import__(module, [], [], [''])
 
     def import_default_modules(self):
         imports = getattr(self.conf, "CELERY_IMPORTS", [])
-        map(self.import_task_module, imports)
+        return map(self.import_task_module, imports)
 
     @property
     def conf(self):

+ 1 - 4
celery/loaders/djangoapp.py

@@ -85,7 +85,4 @@ def find_related_module(app, related_name):
 
     module = importlib.import_module("%s.%s" % (app, related_name))
 
-    try:
-        return getattr(module, related_name)
-    except AttributeError:
-        return
+    return getattr(module, related_name)

+ 1 - 1
celery/log.py

@@ -107,7 +107,7 @@ class LoggingProxy(object):
         ``sys.__stderr__`` instead of ``sys.stderr`` to circumvent
         infinite loops."""
 
-        def wrap_handler(handler):
+        def wrap_handler(handler): # pragma: no cover
 
             class WithSafeHandleError(logging.Handler):
 

+ 1 - 1
celery/result.py

@@ -252,7 +252,7 @@ class TaskSetResult(object):
         :raises: The exception if any of the tasks raised an exception.
 
         """
-        results = dict((subtask.task_id, AsyncResult(subtask.task_id))
+        results = dict((subtask.task_id, subtask.__class__(subtask.task_id))
                             for subtask in self.subtasks)
         while results:
             for task_id, pending_result in results.items():

+ 87 - 0
celery/tests/test_loaders.py

@@ -0,0 +1,87 @@
+import os
+import unittest
+
+from billiard.utils.functional import wraps
+
+from celery import loaders
+from celery.loaders import base
+from celery.loaders import djangoapp
+from celery.tests.utils import with_environ
+
+
+
+class TestLoaders(unittest.TestCase):
+
+    def test_get_loader_cls(self):
+
+        self.assertEquals(loaders.get_loader_cls("django"),
+                          loaders.DjangoLoader)
+        self.assertEquals(loaders.get_loader_cls("default"),
+                          loaders.DefaultLoader)
+        # Execute cached branch.
+        self.assertEquals(loaders.get_loader_cls("django"),
+                          loaders.DjangoLoader)
+        self.assertEquals(loaders.get_loader_cls("default"),
+                          loaders.DefaultLoader)
+
+    @with_environ("CELERY_LOADER", "default")
+    def test_detect_loader_CELERY_LOADER(self):
+        self.assertEquals(loaders._detect_loader(), loaders.DefaultLoader)
+
+
+class DummyLoader(base.BaseLoader):
+
+    class Config(object):
+
+        def __init__(self, **kwargs):
+            for attr, val in kwargs.items():
+                setattr(self, attr, val)
+
+    def read_configuration(self):
+        return self.Config(foo="bar", CELERY_IMPORTS=("os", "sys"))
+
+
+class TestLoaderBase(unittest.TestCase):
+
+    def setUp(self):
+        self.loader = DummyLoader()
+
+    def test_handlers_pass(self):
+        self.loader.on_task_init("foo.task", "feedface-cafebabe")
+        self.loader.on_worker_init()
+
+    def test_import_task_module(self):
+        import sys
+        self.assertEquals(sys, self.loader.import_task_module("sys"))
+
+    def test_conf_property(self):
+        self.assertEquals(self.loader.conf.foo, "bar")
+        self.assertEquals(self.loader._conf_cache.foo, "bar")
+        self.assertEquals(self.loader.conf.foo, "bar")
+
+    def test_import_default_modules(self):
+        import os
+        import sys
+        self.assertEquals(self.loader.import_default_modules(), [os, sys])
+
+
+class TestDjangoLoader(unittest.TestCase):
+
+    def setUp(self):
+        self.loader = loaders.DjangoLoader()
+
+    def test_on_worker_init(self):
+        self.assertRaises(ImportError, self.loader.on_worker_init)
+
+    def test_race_protection(self):
+        djangoapp._RACE_PROTECTION = True
+        try:
+            self.assertFalse(self.loader.on_worker_init())
+        finally:
+            djangoapp._RACE_PROTECTION = False
+
+    def test_find_related_module_no_path(self):
+        self.assertFalse(djangoapp.find_related_module("sys", "tasks"))
+
+    def test_find_related_module_no_related(self):
+        self.assertFalse(djangoapp.find_related_module("someapp", "frobulators"))

+ 51 - 12
celery/tests/test_log.py

@@ -6,27 +6,35 @@ import logging
 import unittest
 from tempfile import mktemp
 from StringIO import StringIO
+from contextlib import contextmanager
 
 from carrot.utils import rpartition
 
-from celery.log import setup_logger, emergency_error
+from celery.log import (setup_logger, emergency_error,
+                        redirect_stdouts_to_logger, LoggingProxy)
 from celery.tests.utils import override_stdouts
 
 
+@contextmanager
+def wrap_logger(logger, loglevel=logging.ERROR):
+    old_handlers = logger.handlers
+    sio = StringIO()
+    siohandler = logging.StreamHandler(sio)
+    logger.handlers = [siohandler]
+
+    yield sio
+
+    logger.handlers = old_handlers
+
+
+
 class TestLog(unittest.TestCase):
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
-        # Save old handlers
-        old_handler = logger.handlers[0]
-        logger.removeHandler(old_handler)
-        sio = StringIO()
-        siohandler = logging.StreamHandler(sio)
-        logger.addHandler(siohandler)
-        logger.log(loglevel, logmsg)
-        logger.removeHandler(siohandler)
-        # Reset original handlers
-        logger.addHandler(old_handler)
-        return sio.getvalue().strip()
+
+        with wrap_logger(logger, loglevel=loglevel) as sio:
+            logger.log(loglevel, logmsg)
+            return sio.getvalue().strip()
 
     def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
@@ -88,3 +96,34 @@ class TestLog(unittest.TestCase):
         with open(tempfile, "r") as tempfilefh:
             self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
         os.unlink(tempfile)
+
+    def test_redirect_stdouts(self):
+        logger = setup_logger(loglevel=logging.ERROR, logfile=None)
+        try:
+            with wrap_logger(logger) as sio:
+                redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
+                logger.error("foo")
+                self.assertTrue("foo" in sio.getvalue())
+        finally:
+            sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
+
+    def test_logging_proxy(self):
+        logger = setup_logger(loglevel=logging.ERROR, logfile=None)
+        with wrap_logger(logger) as sio:
+            p = LoggingProxy(logger)
+            p.close()
+            p.write("foo")
+            self.assertTrue("foo" not in sio.getvalue())
+            p.closed = False
+            p.write("foo")
+            self.assertTrue("foo" in sio.getvalue())
+            lines = ["baz", "xuzzy"]
+            p.writelines(lines)
+            for line in lines:
+                self.assertTrue(line in sio.getvalue())
+            p.flush()
+            p.close()
+            self.assertFalse(p.isatty())
+            self.assertTrue(p.fileno() is None)
+
+

+ 75 - 2
celery/tests/test_result.py

@@ -5,6 +5,7 @@ from celery.tests.utils import skip_if_quick
 from celery.result import AsyncResult, TaskSetResult
 from celery.backends import default_backend
 from celery.exceptions import TimeoutError
+from celery.task.base import Task
 
 
 def mock_task(name, status, result):
@@ -12,12 +13,15 @@ def mock_task(name, status, result):
 
 
 def save_result(task):
+    traceback = "Some traceback"
     if task["status"] == "SUCCESS":
         default_backend.mark_as_done(task["id"], task["result"])
     elif task["status"] == "RETRY":
-        default_backend.mark_as_retry(task["id"], task["result"])
+        default_backend.mark_as_retry(task["id"], task["result"],
+                traceback=traceback)
     else:
-        default_backend.mark_as_failure(task["id"], task["result"])
+        default_backend.mark_as_failure(task["id"], task["result"],
+                traceback=traceback)
 
 
 def make_mock_taskset(size=10):
@@ -65,6 +69,15 @@ class TestAsyncResult(unittest.TestCase):
         self.assertEquals(repr(nok_res), "<AsyncResult: %s>" % (
                 self.task3["id"]))
 
+    def test_get_traceback(self):
+        ok_res = AsyncResult(self.task1["id"])
+        nok_res = AsyncResult(self.task3["id"])
+        nok_res2 = AsyncResult(self.task4["id"])
+        self.assertFalse(ok_res.traceback)
+        self.assertTrue(nok_res.traceback)
+        self.assertTrue(nok_res2.traceback)
+
+
     def test_get(self):
         ok_res = AsyncResult(self.task1["id"])
         ok2_res = AsyncResult(self.task2["id"])
@@ -93,6 +106,28 @@ class TestAsyncResult(unittest.TestCase):
         self.assertFalse(AsyncResult(self.task4["id"]).ready())
 
 
+class MockAsyncResultFailure(AsyncResult):
+
+    @property
+    def result(self):
+        return KeyError("baz")
+
+    @property
+    def status(self):
+        return "FAILURE"
+
+
+class MockAsyncResultSuccess(AsyncResult):
+
+    @property
+    def result(self):
+        return 42
+
+    @property
+    def status(self):
+        return "SUCCESS"
+
+
 class TestTaskSetResult(unittest.TestCase):
 
     def setUp(self):
@@ -102,6 +137,27 @@ class TestTaskSetResult(unittest.TestCase):
     def test_total(self):
         self.assertEquals(self.ts.total, self.size)
 
+    def test_iterate_raises(self):
+        ar = MockAsyncResultFailure(gen_unique_id())
+        ts = TaskSetResult(gen_unique_id(), [ar])
+        it = iter(ts)
+        self.assertRaises(KeyError, it.next)
+
+    def test_iterate_yields(self):
+        ar = MockAsyncResultSuccess(gen_unique_id())
+        ar2 = MockAsyncResultSuccess(gen_unique_id())
+        ts = TaskSetResult(gen_unique_id(), [ar, ar2])
+        it = iter(ts)
+        self.assertEquals(it.next(), 42)
+        self.assertEquals(it.next(), 42)
+
+    def test_join_timeout(self):
+        ar = MockAsyncResultSuccess(gen_unique_id())
+        ar2 = MockAsyncResultSuccess(gen_unique_id())
+        ar3 = AsyncResult(gen_unique_id())
+        ts = TaskSetResult(gen_unique_id(), [ar, ar2, ar3])
+        self.assertRaises(TimeoutError, ts.join, timeout=0.0000001)
+
     def test_itersubtasks(self):
 
         it = self.ts.itersubtasks()
@@ -207,3 +263,20 @@ class TestTaskSetPending(unittest.TestCase):
     @skip_if_quick
     def x_join_longer(self):
         self.assertRaises(TimeoutError, self.ts.join, timeout=1)
+
+
+class RaisingTask(Task):
+
+    def run(self, x, y):
+        raise KeyError("xy")
+
+
+class TestEagerResult(unittest.TestCase):
+
+    def test_wait_raises(self):
+        res = RaisingTask.apply(args=[3, 3])
+        self.assertRaises(KeyError, res.wait)
+
+    def test_revoke(self):
+        res = RaisingTask.apply(args=[3, 3])
+        self.assertFalse(res.revoke())

+ 15 - 0
celery/tests/test_task_control.py

@@ -2,6 +2,7 @@ import unittest
 
 from celery.task import control
 from celery.task.builtins import PingTask
+from celery.utils import gen_unique_id
 
 
 class MockBroadcastPublisher(object):
@@ -46,3 +47,17 @@ class TestBroadcast(unittest.TestCase):
     def test_revoke(self):
         control.revoke("foozbaaz")
         self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+
+    @with_mock_broadcast
+    def test_revoke_from_result(self):
+        from celery.result import AsyncResult
+        AsyncResult("foozbazzbar").revoke()
+        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+
+    @with_mock_broadcast
+    def test_revoke_from_resultset(self):
+        from celery.result import TaskSetResult, AsyncResult
+        r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
+                                                            for i in range(10)]))
+        r.revoke()
+        self.assertTrue("revoke" in MockBroadcastPublisher.sent)

+ 1 - 15
celery/tests/test_utils.py

@@ -5,6 +5,7 @@ import unittest
 from billiard.utils.functional import wraps
 
 from celery import utils
+from celery.tests.utils import sleepdeprived
 
 
 class TestChunks(unittest.TestCase):
@@ -54,21 +55,6 @@ class TestDivUtils(unittest.TestCase):
             self.assertEquals(it.next(), i)
 
 
-def sleepdeprived(fun):
-
-    @wraps(fun)
-    def _sleepdeprived(*args, **kwargs):
-        import time
-        old_sleep = time.sleep
-        time.sleep = utils.noop
-        try:
-            return fun(*args, **kwargs)
-        finally:
-            time.sleep = old_sleep
-
-    return _sleepdeprived
-
-
 class TestRetryOverTime(unittest.TestCase):
 
     def test_returns_retval_on_success(self):

+ 36 - 1
celery/tests/utils.py

@@ -1,12 +1,47 @@
 from __future__ import with_statement
 
-import sys
 import os
+import sys
 import __builtin__
 from StringIO import StringIO
 from functools import wraps
 from contextlib import contextmanager
 
+from celery.utils import noop
+
+
+def with_environ(env_name, env_value):
+
+    def _envpatched(fun):
+
+        @wraps(fun)
+        def _patch_environ(*args, **kwargs):
+            prev_val = os.environ.get(env_name)
+            os.environ[env_name] = env_value
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                if prev_val is not None:
+                    os.environ[env_name] = prev_val
+
+        return _patch_environ
+    return _envpatched
+
+
+def sleepdeprived(fun):
+
+    @wraps(fun)
+    def _sleepdeprived(*args, **kwargs):
+        import time
+        old_sleep = time.sleep
+        time.sleep = noop
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            time.sleep = old_sleep
+
+    return _sleepdeprived
+
 
 def skip_if_environ(env_var_name):