Browse Source

97% Coverage

Ask Solem 13 years ago
parent
commit
80d5249517

+ 1 - 1
celery/bin/celeryd.py

@@ -75,7 +75,7 @@ import sys
 
 try:
     from multiprocessing import freeze_support
-except ImportError:
+except ImportError:  # pragma: no cover
     freeze_support = lambda: True  # noqa
 
 from celery.bin.base import Command, Option

+ 1 - 1
celery/concurrency/processes/__init__.py

@@ -11,7 +11,7 @@ from os import kill as _kill
 from celery.concurrency.base import BasePool
 from celery.concurrency.processes.pool import Pool, RUN
 
-if platform.system() == "Windows":
+if platform.system() == "Windows":  # pragma: no cover
     # On Windows os.kill calls TerminateProcess which cannot be
     # handled by # any process, so this is needed to terminate the task
     # *and its children* (if any).

+ 59 - 1
celery/tests/test_backends/test_base.py

@@ -1,3 +1,5 @@
+from __future__ import with_statement
+
 import sys
 import types
 
@@ -12,7 +14,7 @@ from celery.utils.serialization import get_pickleable_exception as gpe
 
 from celery import states
 from celery.backends.base import BaseBackend, KeyValueStoreBackend
-from celery.backends.base import BaseDictBackend
+from celery.backends.base import BaseDictBackend, DisabledBackend
 from celery.utils import gen_unique_id
 
 from celery.tests.utils import unittest
@@ -54,6 +56,10 @@ class test_BaseBackend_interface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
 
+    def test_mark_as_started(self):
+        self.assertRaises(NotImplementedError,
+                b.mark_as_started, "SOMExx-N0nex1stant-IDxx-")
+
     def test_reload_task_result(self):
         self.assertRaises(NotImplementedError,
                 b.reload_task_result, "SOMExx-N0nex1stant-IDxx-")
@@ -86,6 +92,15 @@ class test_BaseBackend_interface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.forget, "SOMExx-N0nex1stant-IDxx-")
 
+    def test_on_chord_apply(self, unlock="celery.chord_unlock"):
+        from celery.registry import tasks
+        p, tasks[unlock] = tasks.get(unlock), Mock()
+        try:
+            b.on_chord_apply("dakj221", "sdokqweok")
+            self.assertTrue(tasks[unlock].apply_async.call_count)
+        finally:
+            tasks[unlock] = p
+
 
 class test_exception_pickle(unittest.TestCase):
 
@@ -130,6 +145,7 @@ class test_prepare_exception(unittest.TestCase):
 
 
 class KVBackend(KeyValueStoreBackend):
+    mget_returns_dict = False
 
     def __init__(self, *args, **kwargs):
         self.db = {}
@@ -141,6 +157,12 @@ class KVBackend(KeyValueStoreBackend):
     def set(self, key, value):
         self.db[key] = value
 
+    def mget(self, keys):
+        if self.mget_returns_dict:
+            return dict((key, self.get(key)) for key in keys)
+        else:
+            return [self.get(key) for key in keys]
+
     def delete(self, key):
         self.db.pop(key, None)
 
@@ -178,6 +200,11 @@ class test_BaseDictBackend(unittest.TestCase):
         b.save_taskset("foofoo", "xxx")
         b._save_taskset.assert_called_with("foofoo", "xxx")
 
+    def test_forget_interface(self):
+        b = BaseDictBackend()
+        with self.assertRaises(NotImplementedError):
+            b.forget("foo")
+
     def test_restore_taskset(self):
         self.assertIsNone(self.b.restore_taskset("missing"))
         self.assertIsNone(self.b.restore_taskset("missing"))
@@ -210,6 +237,23 @@ class test_KeyValueStoreBackend(unittest.TestCase):
         self.b.forget(tid)
         self.assertEqual(self.b.get_status(tid), states.PENDING)
 
+    def test_strip_prefix(self):
+        x = self.b.get_key_for_task("x1b34")
+        self.assertEqual(self.b._strip_prefix(x), "x1b34")
+        self.assertEqual(self.b._strip_prefix("x1b34"), "x1b34")
+
+    def test_get_many(self):
+        for is_dict in True, False:
+            self.b.mget_returns_dict = is_dict
+            ids = dict((gen_unique_id(), i) for i in xrange(10))
+            for id, i in ids.items():
+                self.b.mark_as_done(id, i)
+            it = self.b.get_many(ids.keys())
+            for i, (got_id, got_state) in enumerate(it):
+                self.assertEqual(got_state["result"], ids[got_id])
+            self.assertEqual(i, 9)
+            self.assertTrue(list(self.b.get_many(ids.keys())))
+
     def test_get_missing_meta(self):
         self.assertIsNone(self.b.get_result("xxx-missing"))
         self.assertEqual(self.b.get_status("xxx-missing"), states.PENDING)
@@ -242,6 +286,20 @@ class test_KeyValueStoreBackend_interface(unittest.TestCase):
         self.assertRaises(NotImplementedError, KeyValueStoreBackend().delete,
                 "a")
 
+    def test_mget(self):
+        self.assertRaises(NotImplementedError, KeyValueStoreBackend().mget,
+                ["a"])
+
     def test_forget(self):
         self.assertRaises(NotImplementedError, KeyValueStoreBackend().forget,
                 "a")
+
+
+class test_DisabledBackend(unittest.TestCase):
+
+    def test_store_result(self):
+        DisabledBackend().store_result()
+
+    def test_is_disabled(self):
+        with self.assertRaises(NotImplementedError):
+            DisabledBackend().get_status("foo")

+ 51 - 41
celery/tests/test_backends/test_cache.py

@@ -1,6 +1,9 @@
+from __future__ import with_statement
+
 import sys
 import types
-from celery.tests.utils import unittest
+
+from contextlib import contextmanager
 
 from celery import states
 from celery.backends.cache import CacheBackend, DummyClient
@@ -8,7 +11,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.result import AsyncResult
 from celery.utils import gen_unique_id
 
-from celery.tests.utils import mask_modules
+from celery.tests.utils import unittest, mask_modules, reset_modules
 
 
 class SomeClass(object):
@@ -54,6 +57,14 @@ class test_CacheBackend(unittest.TestCase):
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertIsInstance(tb.get_result(tid3), KeyError)
 
+    def test_mget(self):
+        tb = CacheBackend(backend="memory://")
+        tb.set("foo", 1)
+        tb.set("bar", 2)
+
+        self.assertDictEqual(tb.mget(["foo", "bar"]),
+                             {"foo": 1, "bar": 2})
+
     def test_forget(self):
         tb = CacheBackend(backend="memory://")
         tid = gen_unique_id()
@@ -81,6 +92,7 @@ class MyClient(DummyClient):
 
 class test_get_best_memcache(unittest.TestCase):
 
+    @contextmanager
     def mock_memcache(self):
         memcache = types.ModuleType("memcache")
         memcache.Client = MyClient
@@ -89,8 +101,8 @@ class test_get_best_memcache(unittest.TestCase):
         yield True
         if prev is not None:
             sys.modules["memcache"] = prev
-        yield True
 
+    @contextmanager
     def mock_pylibmc(self):
         pylibmc = types.ModuleType("pylibmc")
         pylibmc.Client = MyClient
@@ -100,43 +112,41 @@ class test_get_best_memcache(unittest.TestCase):
         yield True
         if prev is not None:
             sys.modules["pylibmc"] = prev
-        yield True
 
     def test_pylibmc(self):
-        pylibmc = self.mock_pylibmc()
-        pylibmc.next()
-        from celery.backends import cache
-        cache._imp = [None]
-        self.assertEqual(cache.get_best_memcache().__module__, "pylibmc")
-        pylibmc.next()
-
-    def xxx_memcache(self):
-
-        def with_no_pylibmc():
-            from celery.backends import cache
-            cache._imp = [None]
-            self.assertEqual(cache.get_best_memcache().__module__, "memcache")
-
-        context = mask_modules("pylibmc")
-        context.__enter__()
-        try:
-            memcache = self.mock_memcache()
-            memcache.next()
-            with_no_pylibmc()
-            memcache.next()
-        finally:
-            context.__exit__(None, None, None)
-
-    def xxx_no_implementations(self):
-
-        def with_no_memcache_libs():
-            from celery.backends import cache
-            cache._imp = [None]
-            self.assertRaises(ImproperlyConfigured, cache.get_best_memcache)
-
-        context = mask_modules("pylibmc", "memcache")
-        context.__enter__()
-        try:
-            with_no_memcache_libs()
-        finally:
-            context.__exit__(None, None, None)
+        with reset_modules("celery.backends.cache"):
+            with self.mock_pylibmc():
+                from celery.backends import cache
+                cache._imp = [None]
+                self.assertEqual(cache.get_best_memcache().__module__,
+                                 "pylibmc")
+
+    def test_memcache(self):
+        with self.mock_memcache():
+            with reset_modules("celery.backends.cache"):
+                with mask_modules("pylibmc"):
+                    from celery.backends import cache
+                    cache._imp = [None]
+                    self.assertEqual(cache.get_best_memcache().__module__,
+                                     "memcache")
+
+    def test_no_implementations(self):
+        with mask_modules("pylibmc", "memcache"):
+            with reset_modules("celery.backends.cache"):
+                from celery.backends import cache
+                cache._imp = [None]
+                with self.assertRaises(ImproperlyConfigured):
+                    cache.get_best_memcache()
+
+    def test_cached(self):
+        with self.mock_pylibmc():
+            with reset_modules("celery.backends.cache"):
+                from celery.backends import cache
+                cache.get_best_memcache(behaviors={"foo": "bar"})
+                self.assertTrue(cache._imp[0])
+                cache.get_best_memcache()
+
+    def test_backends(self):
+        from celery.backends.cache import backends
+        for name, fun in backends.items():
+            self.assertTrue(fun())

+ 13 - 0
celery/tests/test_backends/test_pyredis_compat.py

@@ -0,0 +1,13 @@
+from celery.backends import pyredis
+from celery.tests.utils import unittest
+
+
+class test_RedisBackend(unittest.TestCase):
+
+    def test_constructor(self):
+        x = pyredis.RedisBackend(redis_host="foobar", redis_port=312,
+                                 redis_db=1, redis_password="foo")
+        self.assertEqual(x.redis_host, "foobar")
+        self.assertEqual(x.redis_port, 312)
+        self.assertEqual(x.redis_db, 1)
+        self.assertEqual(x.redis_password, "foo")

+ 54 - 0
celery/tests/test_backends/test_redis_unit.py

@@ -1,5 +1,8 @@
 from datetime import timedelta
 
+from mock import Mock, patch
+from kombu.utils import cached_property
+
 from celery import current_app
 from celery import states
 from celery.utils import gen_unique_id
@@ -55,6 +58,14 @@ class test_RedisBackend(unittest.TestCase):
     def setUp(self):
         self.Backend = self.get_backend()
 
+        class MockBackend(self.Backend):
+
+            @cached_property
+            def client(self):
+                return Mock()
+
+        self.MockBackend = MockBackend
+
     def test_expires_defaults_to_config(self):
         conf = current_app.conf
         prev = conf.CELERY_TASK_RESULT_EXPIRES
@@ -78,6 +89,49 @@ class test_RedisBackend(unittest.TestCase):
         b = self.Backend(expires=timedelta(minutes=1))
         self.assertEqual(b.expires, 60)
 
+    def test_on_chord_apply(self):
+        self.Backend().on_chord_apply()
+
+    def test_mget(self):
+        b = self.MockBackend()
+        self.assertTrue(b.mget(["a", "b", "c"]))
+        b.client.mget.assert_called_with(["a", "b", "c"])
+
+    def test_set_no_expire(self):
+        b = self.MockBackend()
+        b.expires = None
+        b.set("foo", "bar")
+
+    @patch("celery.result.TaskSetResult")
+    def test_on_chord_part_return(self, setresult):
+        from celery.registry import tasks
+        from celery.task import subtask
+        b = self.MockBackend()
+        deps = Mock()
+        deps.total = 10
+        setresult.restore.return_value = deps
+        b.client.incr.return_value = 1
+        task = Mock()
+        task.name = "foobarbaz"
+        try:
+            tasks["foobarbaz"] = task
+            task.request.chord = subtask(task)
+
+            b.on_chord_part_return(task)
+            self.assertTrue(b.client.incr.call_count)
+
+            b.client.incr.return_value = deps.total
+            b.on_chord_part_return(task)
+            deps.join.assert_called_with()
+            deps.delete.assert_called_with()
+
+            self.assertTrue(b.client.expire.call_count)
+        finally:
+            tasks.pop("foobarbaz")
+
+    def test_process_cleanup(self):
+        self.Backend().process_cleanup()
+
     def test_get_set_forget(self):
         b = self.Backend()
         uuid = gen_unique_id()

+ 16 - 3
celery/tests/test_compat/test_log.py

@@ -8,7 +8,8 @@ from tempfile import mktemp
 from celery import log
 from celery.log import (setup_logger, setup_task_logger,
                         get_default_logger, get_task_logger,
-                        redirect_stdouts_to_logger, LoggingProxy)
+                        redirect_stdouts_to_logger, LoggingProxy,
+                        setup_logging_subsystem)
 from celery.utils import gen_unique_id
 from celery.utils.compat import _CompatLoggerAdapter
 from celery.tests.utils import (override_stdouts, wrap_logger,
@@ -22,6 +23,18 @@ class test_default_logger(unittest.TestCase):
         self.get_logger = get_default_logger
         log._setup = False
 
+    def test_setup_logging_subsystem_colorize(self):
+        setup_logging_subsystem(colorize=None)
+        setup_logging_subsystem(colorize=True)
+
+    def test_setup_logging_subsystem_no_mputil(self):
+        mputil, log.mputil = log.mputil, None
+        log.mputil
+        try:
+            log.setup_logging_subsystem()
+        finally:
+            log.mputil = mputil
+
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
         with wrap_logger(logger, loglevel=loglevel) as sio:
@@ -38,10 +51,10 @@ class test_default_logger(unittest.TestCase):
 
     def test_setup_logger(self):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                   root=False)
+                                   root=False, colorize=True)
         set_handlers(logger, [])
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
-                                   root=False)
+                                   root=False, colorize=None)
         self.assertIs(get_handlers(logger)[0].stream, sys.__stderr__,
                 "setup_logger logs to stderr without logfile argument.")
         self.assertDidLogFalse(logger, "Logging something",

+ 8 - 0
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -1,7 +1,9 @@
+import signal
 import sys
 
 from itertools import cycle
 
+from mock import patch
 from nose import SkipTest
 
 try:
@@ -187,6 +189,12 @@ class test_TaskPool(unittest.TestCase):
         pool.start()
         pool.apply_async(lambda x: x, (2, ), {})
 
+    @patch("celery.concurrency.processes._kill")
+    def test_terminate_job(self, _kill):
+        pool = TaskPool(10)
+        pool.terminate_job(1341)
+        _kill.assert_called_with(1341, signal.SIGTERM)
+
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool.start()

+ 23 - 0
celery/tests/test_concurrency/test_concurrency_solo.py

@@ -0,0 +1,23 @@
+import operator
+
+from celery.concurrency import solo
+from celery.utils import noop
+from celery.tests.utils import unittest
+
+
+class test_solo_TaskPool(unittest.TestCase):
+
+    def test_on_start(self):
+        x = solo.TaskPool()
+        x.on_start()
+        self.assertTrue(x.pid)
+
+    def test_on_apply(self):
+        x = solo.TaskPool()
+        x.on_start()
+        x.on_apply(operator.add, (2, 2), {}, noop, noop)
+
+    def test_info(self):
+        x = solo.TaskPool()
+        x.on_start()
+        self.assertTrue(x.info)

+ 3 - 4
celery/tests/test_task/test_chord.py

@@ -7,6 +7,7 @@ from celery.tests.utils import AppCase, Mock
 
 passthru = lambda x: x
 
+
 @current_app.task
 def add(x, y):
     return x + y
@@ -62,7 +63,6 @@ class test_chord(AppCase):
         self.assertTrue(chord.Chord.apply_async.call_count)
 
 
-
 class test_Chord_task(AppCase):
 
     def test_run(self):
@@ -71,7 +71,6 @@ class test_Chord_task(AppCase):
             backend = Mock()
 
         body = dict()
-        r = Chord()(TaskSet(add.subtask((i, i)) for i in xrange(5)), body)
-        r = Chord()([add.subtask((i, i)) for i in xrange(5)], body)
+        Chord()(TaskSet(add.subtask((i, i)) for i in xrange(5)), body)
+        Chord()([add.subtask((i, i)) for i in xrange(5)], body)
         self.assertEqual(Chord.backend.on_chord_apply.call_count, 2)
-

+ 27 - 1
celery/tests/test_task/test_result.py

@@ -2,7 +2,7 @@ from celery import states
 from celery.app import app_or_default
 from celery.utils import gen_unique_id
 from celery.utils.serialization import pickle
-from celery.result import AsyncResult, EagerResult, TaskSetResult
+from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
@@ -93,6 +93,12 @@ class TestAsyncResult(unittest.TestCase):
         self.assertEqual(repr(pending_res), "<AsyncResult: %s>" % (
                 pending_id))
 
+    def test_hash(self):
+        self.assertEqual(hash(AsyncResult("x0w991")),
+                         hash(AsyncResult("x0w991")))
+        self.assertNotEqual(hash(AsyncResult("x0w991")),
+                            hash(AsyncResult("x1w991")))
+
     def test_get_traceback(self):
         ok_res = AsyncResult(self.task1["id"])
         nok_res = AsyncResult(self.task3["id"])
@@ -138,6 +144,26 @@ class TestAsyncResult(unittest.TestCase):
         self.assertFalse(AsyncResult(gen_unique_id()).ready())
 
 
+class test_ResultSet(unittest.TestCase):
+
+    def test_add_discard(self):
+        x = ResultSet([])
+        x.add(AsyncResult("1"))
+        self.assertIn(AsyncResult("1"), x.results)
+        x.discard(AsyncResult("1"))
+        x.discard(AsyncResult("1"))
+        x.discard("1")
+        self.assertNotIn(AsyncResult("1"), x.results)
+
+        x.update([AsyncResult("2")])
+
+    def test_clear(self):
+        x = ResultSet([])
+        r = x.results
+        x.clear()
+        self.assertIs(x.results, r)
+
+
 class MockAsyncResultFailure(AsyncResult):
 
     @property

+ 91 - 2
celery/tests/test_utils/test_timer2.py

@@ -1,8 +1,16 @@
+from __future__ import with_statement
+
+import sys
 import time
-from celery.tests.utils import unittest
+import warnings
+
+from kombu.tests.utils import redirect_stdouts
+from mock import Mock, patch
+
 import celery.utils.timer2 as timer2
 
-from celery.tests.utils import skip_if_quick
+from celery.tests.utils import unittest, skip_if_quick
+from celery.tests.compat import catch_warnings
 
 
 class test_Entry(unittest.TestCase):
@@ -43,6 +51,12 @@ class test_Schedule(unittest.TestCase):
         try:
             s.enter(timer2.Entry(lambda: None, (), {}),
                     eta=datetime.now())
+            s.enter(timer2.Entry(lambda: None, (), {}),
+                    eta=None)
+            s.on_error = None
+            with self.assertRaises(OverflowError):
+                s.enter(timer2.Entry(lambda: None, (), {}),
+                        eta=datetime.now())
         finally:
             timer2.mktime = mktime
 
@@ -66,3 +80,78 @@ class test_Timer(unittest.TestCase):
                 time.sleep(0.1)
         finally:
             t.stop()
+
+    def test_exit_after(self):
+        t = timer2.Timer()
+        t.apply_after = Mock()
+        t.exit_after(300, priority=10)
+        t.apply_after.assert_called_with(300, sys.exit, 10)
+
+    def test_apply_interval(self):
+        t = timer2.Timer()
+        t.enter_after = Mock()
+
+        myfun = Mock()
+        t.apply_interval(30, myfun)
+
+        self.assertEqual(t.enter_after.call_count, 1)
+        args1, _ = t.enter_after.call_args_list[0]
+        msec1, tref1, _ = args1
+        self.assertEqual(msec1, 30)
+        tref1()
+
+        self.assertEqual(t.enter_after.call_count, 2)
+        args2, _ = t.enter_after.call_args_list[1]
+        msec2, tref2, _ = args2
+        self.assertEqual(msec2, 30)
+        tref2.cancelled = True
+        tref2()
+
+        self.assertEqual(t.enter_after.call_count, 2)
+
+    @redirect_stdouts
+    def test_apply_entry_error_handled(self, stdout, stderr):
+        t = timer2.Timer()
+        t.schedule.on_error = None
+
+        fun = Mock()
+        fun.side_effect = ValueError()
+        warnings.resetwarnings()
+
+        with catch_warnings(record=True) as log:
+            t.apply_entry(fun)
+            fun.assert_called_with()
+            self.assertTrue(log)
+            self.assertTrue(stderr.getvalue())
+
+    @redirect_stdouts
+    def test_apply_entry_error_not_handled(self, stdout, stderr):
+        t = timer2.Timer()
+        t.schedule.on_error = Mock()
+
+        fun = Mock()
+        fun.side_effect = ValueError()
+        warnings.resetwarnings()
+
+        with catch_warnings(record=True) as log:
+            t.apply_entry(fun)
+            fun.assert_called_with()
+            self.assertFalse(log)
+            self.assertFalse(stderr.getvalue())
+
+    @patch("os._exit")
+    def test_thread_crash(self, _exit):
+        t = timer2.Timer()
+        t.next = Mock()
+        t.next.side_effect = OSError(131)
+        t.run()
+        _exit.assert_called_with(1)
+
+    def test_gc_race_lost(self):
+        t = timer2.Timer()
+        t._stopped.set = Mock()
+        t._stopped.set.side_effect = TypeError()
+
+        t._shutdown.set()
+        t.run()
+        t._stopped.set.assert_called_with()

+ 1 - 1
celery/utils/__init__.py

@@ -406,7 +406,7 @@ def import_from_cwd(module, imp=None):
         return imp(module)
 
 
-def cry():
+def cry():  # pragma: no cover
     """Return stacktrace of all active threads.
 
     From https://gist.github.com/737056

+ 5 - 6
celery/utils/timer2.py

@@ -77,16 +77,15 @@ class Schedule(object):
         :keyword priority: Unused.
 
         """
+        if eta is None:  # schedule now
+            eta = datetime.now()
+
         try:
             eta = to_timestamp(eta)
         except OverflowError:
             if not self.handle_error(sys.exc_info()):
                 raise
 
-        if eta is None:
-            # schedule now.
-            eta = time()
-
         heapq.heappush(self._queue, (eta, priority, entry))
         return entry
 
@@ -184,12 +183,12 @@ class Timer(Thread):
                 if delay:
                     if self.on_tick:
                         self.on_tick(delay)
-                    if sleep is None:
+                    if sleep is None:  # pragma: no cover
                         break
                     sleep(delay)
             try:
                 self._stopped.set()
-            except TypeError:           # pragma: no cover
+            except TypeError:  # pragma: no cover
                 # we lost the race at interpreter shutdown,
                 # so gc collected built-in modules.
                 pass