
Back at 80% coverage

Ask Solem 15 years ago
parent
commit
5305fad910

+ 3 - 5
Makefile

@@ -26,14 +26,12 @@ readme: clean_readme
 bump:
 	contrib/bump -c celery
 
-coverage2:
-	[ -d testproj/temp ] || mkdir -p testproj/temp
-	(cd testproj; python manage.py test --figleaf)
-
 coverage:
-	[ -d testproj/temp ] || mkdir -p testproj/temp
 	(cd testproj; python manage.py test --coverage)
 
+quickcoverage:
+	(cd testproj; env QUICKTEST=1 SKIP_RLIMITS=1 python manage.py test --coverage)
+
 test:
 	(cd testproj; python manage.py test)
 

+ 1 - 1
celery/backends/__init__.py

@@ -9,7 +9,7 @@ BACKEND_ALIASES = {
     "amqp": "celery.backends.amqp.AMQPBackend",
     "database": "celery.backends.database.DatabaseBackend",
     "db": "celery.backends.database.DatabaseBackend",
-    "redis": "celery.backends.redis.RedisBackend",
+    "redis": "celery.backends.pyredis.RedisBackend",
     "cache": "celery.backends.cache.CacheBackend",
     "mongodb": "celery.backends.mongodb.MongoBackend",
     "tyrant": "celery.backends.tyrant.TyrantBackend",

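Note: the alias table now points "redis" at the renamed celery.backends.pyredis module. The new celery/tests/test_backends/__init__.py further down exercises backends.get_backend_cls() and a module-level _backend_cache dict. A minimal sketch of how such an alias resolver could look (the resolve_backend helper and the __import__-based loading are illustrative assumptions, not the actual celery code):

    _backend_cache = {}

    BACKEND_ALIASES = {
        "redis": "celery.backends.pyredis.RedisBackend",
        # ... remaining aliases as in the hunk above
    }

    def resolve_backend(name):
        # Map an alias to its dotted path, then import the module and
        # pull the class off it.
        name = BACKEND_ALIASES.get(name, name)
        module_name, _, cls_name = name.rpartition(".")
        module = __import__(module_name, {}, {}, [cls_name])
        return getattr(module, cls_name)

    def get_backend_cls(backend):
        # Cache resolved classes so repeated lookups return the same object,
        # which is what the caching test added below asserts.
        if backend not in _backend_cache:
            _backend_cache[backend] = resolve_backend(backend)
        return _backend_cache[backend]
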
+ 1 - 1
celery/backends/base.py

@@ -149,7 +149,7 @@ class KeyValueStoreBackend(BaseBackend):
     def get_result(self, task_id):
         """Get the result of a task."""
         meta = self._get_task_meta_for(task_id)
-        if meta["status"] == "FAILURE":
+        if meta["status"] in self.EXCEPTION_STATES:
             return self.exception_to_python(meta["result"])
         else:
             return meta["result"]

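Note: get_result() now hands back the stored exception for every state in EXCEPTION_STATES rather than only "FAILURE", so a task that is waiting to be retried also exposes its exception object while it is still not "ready". A standalone sketch of that branching; the frozenset definitions are an assumption, only the state names appear in this commit:

    EXCEPTION_STATES = frozenset(["RETRY", "FAILURE"])   # assumed: meta["result"] holds an exception
    READY_STATES = frozenset(["SUCCESS", "FAILURE"])     # assumed: the task has stopped running

    def result_from_meta(meta, exception_to_python=lambda stored: stored):
        # Mirrors KeyValueStoreBackend.get_result(): exceptional states return the
        # (deserialized) exception to the caller instead of a plain return value.
        if meta["status"] in EXCEPTION_STATES:
            return exception_to_python(meta["result"])
        return meta["result"]
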
+ 3 - 6
celery/backends/cache.py

@@ -32,12 +32,9 @@ class DjangoMemcacheWrapper(object):
 # Check if django is using memcache as the cache backend. If so, wrap the
 # cache object in a DjangoMemcacheWrapper that fixes a bug with retrieving
 # pickled data
-try:
-    from django.core.cache.backends.memcached import CacheClass
-    if isinstance(cache, CacheClass):
-        cache = DjangoMemcacheWrapper(cache)
-except InvalidCacheBackendError:
-    pass
+from django.core.cache.backends.memcached import CacheClass
+if isinstance(cache, CacheClass):
+    cache = DjangoMemcacheWrapper(cache)
 
 
 class CacheBackend(KeyValueStoreBackend):

+ 10 - 9
celery/backends/redis.py → celery/backends/pyredis.py

@@ -1,13 +1,13 @@
-"""celery.backends.tyrant"""
 from django.core.exceptions import ImproperlyConfigured
+
 from celery.backends.base import KeyValueStoreBackend
+from celery.loaders import settings
+
 try:
     import redis
 except ImportError:
     redis = None
 
-from celery.loaders import settings
-
 
 class RedisBackend(KeyValueStoreBackend):
     """Redis based task backend store.
@@ -24,8 +24,8 @@ class RedisBackend(KeyValueStoreBackend):
         :setting:`REDIS_HOST` or :setting:`REDIS_PORT` is not set.
 
     """
-    redis_host = None
-    redis_port = None
+    redis_host = "localhost"
+    redis_port = 6379
     redis_db = "celery_results"
     redis_timeout = None
     redis_connect_retry = None
@@ -34,7 +34,7 @@ class RedisBackend(KeyValueStoreBackend):
             redis_timeout=None,
             redis_connect_retry=None,
             redis_connect_timeout=None):
-        if not redis:
+        if redis is None:
             raise ImproperlyConfigured(
                     "You need to install the redis library in order to use "
                   + "Redis result store backend.")
@@ -73,13 +73,14 @@ class RedisBackend(KeyValueStoreBackend):
                                     port=self.redis_port,
                                     db=self.redis_db,
                                     timeout=self.redis_timeout,
-                                    connect_retry=self.redis_connect_retry)
+                                    retry_connection=self.redis_connect_retry)
+            self._connection.connect()
         return self._connection
 
     def close(self):
-        """Close the redis connection and remove the cache."""
+        """Close the connection to redis."""
         if self._connection is not None:
-            self._connection.close()
+            self._connection.disconnect()
             self._connection = None
 
     def process_cleanup(self):

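Note: the backend now defaults to localhost:6379, opens its connection lazily in open(), caches it on self._connection, and close() calls disconnect() on the redis-py client of this era instead of close(). A hedged usage sketch based on the methods exercised by the new test_redis.py below (assumes a Redis server listening locally and the settings the backend reads being configured):

    from celery.backends.pyredis import RedisBackend
    from celery.utils import gen_unique_id

    tb = RedisBackend(redis_db="celery_unittest")   # host/port fall back to localhost:6379
    tid = gen_unique_id()

    tb.mark_as_done(tid, 42)            # stores a pickled result under the task id
    assert tb.get_status(tid) == "SUCCESS"
    assert tb.get_result(tid) == 42

    tb.close()                          # disconnects and drops the cached connection
    assert tb._connection is None
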
+ 40 - 28
celery/beat.py

@@ -1,14 +1,13 @@
 import time
 import math
 import shelve
-import atexit
 import threading
 from datetime import datetime
 from UserDict import UserDict
 
 from celery import log
 from celery import conf
-from celery import registry
+from celery import registry as _registry
 from celery.utils.info import humanize_seconds
 
 
@@ -82,8 +81,8 @@ class Scheduler(UserDict):
 
     def __init__(self, registry=None, schedule=None, logger=None,
             max_interval=None):
-        self.registry = registry or {}
-        self.schedule = schedule or {}
+        self.registry = registry or _registry.TaskRegistry()
+        self.data = schedule or {}
         self.logger = logger or log.get_default_logger()
         self.max_interval = max_interval or conf.CELERYBEAT_MAX_LOOP_INTERVAL
 
@@ -153,59 +152,72 @@ class Scheduler(UserDict):
 
 class ClockService(object):
     scheduler_cls = Scheduler
-    registry = registry.tasks
+    registry = _registry.tasks
+    open_schedule = shelve.open
 
     def __init__(self, logger=None, is_detached=False,
             max_interval=conf.CELERYBEAT_MAX_LOOP_INTERVAL,
             schedule_filename=conf.CELERYBEAT_SCHEDULE_FILENAME):
-        self.logger = logger
+        self.logger = logger or log.get_default_logger()
         self.max_interval = max_interval
         self.schedule_filename = schedule_filename
         self._shutdown = threading.Event()
         self._stopped = threading.Event()
+        self._schedule = None
+        self._scheduler = None
+        self._in_sync = False
+        silence = self.max_interval < 60 and 10 or 1
+        self.debug = log.SilenceRepeated(self.logger.debug,
+                                         max_iterations=silence)
 
     def start(self):
         self.logger.info("ClockService: Starting...")
-        schedule = shelve.open(filename=self.schedule_filename)
-        atexit.register(schedule.close)
-        scheduler = self.scheduler_cls(schedule=schedule,
-                                       registry=self.registry,
-                                       logger=self.logger,
-                                       max_interval=self.max_interval)
         self.logger.debug("ClockService: "
             "Ticking with max interval->%s, schedule->%s" % (
                     humanize_seconds(self.max_interval),
                     self.schedule_filename))
 
-        synced = [False]
-        def _stop():
-            if not synced[0]:
-                self.logger.debug("ClockService: Syncing schedule to disk...")
-                schedule.sync()
-                schedule.close()
-                synced[0] = True
-                self._stopped.set()
-
-        silence = self.max_interval < 60 and 10 or 1
-        debug = log.SilenceRepeated(self.logger.debug, max_iterations=silence)
-
         try:
             while True:
                 if self._shutdown.isSet():
                     break
-                interval = scheduler.tick()
-                debug("ClockService: Waking up %s." % (
+                interval = self.scheduler.tick()
+                self.debug("ClockService: Waking up %s." % (
                         humanize_seconds(interval, prefix="in ")))
                 time.sleep(interval)
         except (KeyboardInterrupt, SystemExit):
-            _stop()
+            self.sync()
         finally:
-            _stop()
+            self.sync()
+
+    def sync(self):
+        if self._schedule is not None and not self._in_sync:
+            self.logger.debug("ClockService: Syncing schedule to disk...")
+            self._schedule.sync()
+            self._schedule.close()
+            self._in_sync = True
+            self._stopped.set()
 
     def stop(self, wait=False):
         self._shutdown.set()
         wait and self._stopped.wait() # block until shutdown done.
 
+    @property
+    def schedule(self):
+        if self._schedule is None:
+            filename = self.schedule_filename
+            self._schedule = self.open_schedule(filename=filename)
+        return self._schedule
+
+    @property
+    def scheduler(self):
+        if self._scheduler is None:
+            self._scheduler = self.scheduler_cls(schedule=self.schedule,
+                                            registry=self.registry,
+                                            logger=self.logger,
+                                            max_interval=self.max_interval)
+        return self._scheduler
+
 
 class ClockServiceThread(threading.Thread):
 

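Note: ClockService no longer opens the shelve file eagerly in start() and no longer registers an atexit hook; the schedule and the scheduler are built on first access through properties, and sync() is guarded by _in_sync so calling it from both the except and the finally branch is harmless. The lazily-cached-attribute pattern in isolation (names here are illustrative, not celery API):

    class LazyScheduleHolder(object):
        open_schedule = dict            # overridable factory, like ClockService.open_schedule

        def __init__(self):
            self._schedule = None

        @property
        def schedule(self):
            # Build the underlying object on first access and reuse it afterwards,
            # so a test can swap open_schedule for a mock before anything is opened.
            if self._schedule is None:
                self._schedule = self.open_schedule()
            return self._schedule

test_beat.py below relies on exactly this: it assigns a MockShelve-producing open_schedule on the instance and then checks that s.schedule is the mock.
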
+ 1 - 3
celery/result.py

@@ -97,9 +97,7 @@ class BaseAsyncResult(object):
         If the task raised an exception, this will be the exception instance.
 
         """
-        if self.status in self.backend.READY_STATES:
-            return self.backend.get_result(self.task_id)
-        return None
+        return self.backend.get_result(self.task_id)
 
     @property
     def traceback(self):

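Note: BaseAsyncResult.result previously returned None unless the task had reached a ready state; it now always delegates to backend.get_result(), so together with the base.py change above a retried task exposes its pending exception through .result while .ready() and .successful() still report False. A small example of the resulting semantics, lifted from the updated test_result.py below (retried_task_id is a placeholder for an id saved with mark_as_retry()):

    from celery.result import AsyncResult

    res = AsyncResult(retried_task_id)
    assert not res.ready()                   # RETRY is not a ready state
    assert not res.successful()
    assert isinstance(res.result, KeyError)  # but the stored exception is visible
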
+ 27 - 0
celery/tests/test_backends/__init__.py

@@ -0,0 +1,27 @@
+import unittest
+
+
+from celery.backends.database import DatabaseBackend
+from celery.backends.amqp import AMQPBackend
+from celery.backends.pyredis import RedisBackend
+from celery import backends
+
+
+class TestBackends(unittest.TestCase):
+
+    def test_get_backend_aliases(self):
+        self.assertTrue(issubclass(
+            backends.get_backend_cls("amqp"), AMQPBackend))
+        self.assertTrue(issubclass(
+            backends.get_backend_cls("database"), DatabaseBackend))
+        self.assertTrue(issubclass(
+            backends.get_backend_cls("db"), DatabaseBackend))
+        self.assertTrue(issubclass(
+            backends.get_backend_cls("redis"), RedisBackend))
+
+    def test_get_backend_cache(self):
+        backends._backend_cache = {}
+        backends.get_backend_cls("amqp")
+        self.assertTrue("amqp" in backends._backend_cache)
+        amqp_backend = backends.get_backend_cls("amqp")
+        self.assertTrue(amqp_backend is backends._backend_cache["amqp"])

+ 4 - 0
celery/tests/test_backends/test_base.py

@@ -36,6 +36,10 @@ class TestBaseBackendInterface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.get_result, "SOMExx-N0nex1stant-IDxx-")
 
+    def test_get_traceback(self):
+        self.assertRaises(NotImplementedError,
+                b.get_traceback, "SOMExx-N0nex1stant-IDxx-")
+
 
 class TestPickleException(unittest.TestCase):
 

+ 55 - 2
celery/tests/test_backends/test_cache.py

@@ -1,6 +1,11 @@
+import sys
 import unittest
-from celery.backends.cache import CacheBackend
+
+from billiard.serialization import pickle
+
 from celery.utils import gen_unique_id
+from celery.backends.cache import CacheBackend
+from celery.datastructures import ExceptionInfo
 
 
 class SomeClass(object):
@@ -41,17 +46,65 @@ class TestCacheBackend(unittest.TestCase):
     def test_mark_as_failure(self):
         cb = CacheBackend()
 
+        einfo = None
         tid3 = gen_unique_id()
         try:
             raise KeyError("foo")
         except KeyError, exception:
+            einfo = ExceptionInfo(sys.exc_info())
             pass
-        cb.mark_as_failure(tid3, exception)
+        cb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(cb.is_successful(tid3))
         self.assertEquals(cb.get_status(tid3), "FAILURE")
         self.assertTrue(isinstance(cb.get_result(tid3), KeyError))
+        self.assertEquals(cb.get_traceback(tid3), einfo.traceback)
 
     def test_process_cleanup(self):
         cb = CacheBackend()
 
         cb.process_cleanup()
+
+
+class TestCustomCacheBackend(unittest.TestCase):
+
+    def test_custom_cache_backend(self):
+        from celery import conf
+        prev_backend = conf.CELERY_CACHE_BACKEND
+        prev_module = sys.modules["celery.backends.cache"]
+        conf.CELERY_CACHE_BACKEND = "dummy://"
+        sys.modules.pop("celery.backends.cache")
+        try:
+            from celery.backends.cache import cache
+            from django.core.cache import cache as django_cache
+            self.assertEquals(cache.__class__.__module__,
+                              "django.core.cache.backends.dummy")
+            self.assertTrue(cache is not django_cache)
+        finally:
+            conf.CELERY_CACHE_BACKEND = prev_backend
+            sys.modules["celery.backends.cache"] = prev_module
+
+
+class TestMemcacheWrapper(unittest.TestCase):
+
+    def test_memcache_wrapper(self):
+
+        from django.core.cache.backends import memcached
+        from django.core.cache.backends import locmem
+        prev_cache_cls = memcached.CacheClass
+        memcached.CacheClass = locmem.CacheClass
+        prev_backend_module = sys.modules.pop("celery.backends.cache")
+        try:
+            from celery.backends.cache import cache, DjangoMemcacheWrapper
+            self.assertTrue(isinstance(cache, DjangoMemcacheWrapper))
+
+            key = "cu.test_memcache_wrapper"
+            val = "The quick brown fox."
+            default = "The lazy dog."
+
+            self.assertEquals(cache.get(key, default=default), default)
+            cache.set(key, val)
+            self.assertEquals(pickle.loads(cache.get(key, default=default)),
+                              val)
+        finally:
+            memcached.CacheClass = prev_cache_cls
+            sys.modules["celery.backends.cache"] = prev_backend_module

+ 156 - 0
celery/tests/test_backends/test_redis.py

@@ -0,0 +1,156 @@
+from __future__ import with_statement
+
+import sys
+import unittest
+import errno
+
+from django.core.exceptions import ImproperlyConfigured
+
+from celery.backends import pyredis
+from celery.backends.pyredis import RedisBackend
+from celery.utils import gen_unique_id
+
+_no_redis_msg = "* Redis %s. Will not execute related tests."
+_no_redis_msg_emitted = False
+
+
+class SomeClass(object):
+
+    def __init__(self, data):
+        self.data = data
+
+
+def get_redis_or_None():
+
+    def emit_no_redis_msg(reason):
+        global _no_redis_msg_emitted
+        if not _no_redis_msg_emitted:
+            sys.stderr.write("\n" + _no_redis_msg % reason + "\n")
+            _no_redis_msg_emitted = True
+
+    if pyredis.redis is None:
+        return emit_no_redis_msg("not installed")
+    try:
+        tb = RedisBackend(redis_db="celery_unittest")
+        try:
+            tb.open()
+        except pyredis.redis.ConnectionError, exc:
+            return emit_no_redis_msg("not running")
+        return tb
+    except ImproperlyConfigured, exc:
+        if "need to install" in str(exc):
+            return emit_no_redis_msg("not installed")
+        return emit_no_redis_msg("not configured")
+
+
+class TestRedisBackend(unittest.TestCase):
+
+    def test_cached_connection(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return # Skip test
+
+        self.assertTrue(tb._connection is not None)
+        tb.close()
+        self.assertTrue(tb._connection is None)
+        tb.open()
+        self.assertTrue(tb._connection is not None)
+
+    def test_mark_as_done(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return
+
+        tid = gen_unique_id()
+
+        self.assertFalse(tb.is_successful(tid))
+        self.assertEquals(tb.get_status(tid), "PENDING")
+        self.assertEquals(tb.get_result(tid), None)
+
+        tb.mark_as_done(tid, 42)
+        self.assertTrue(tb.is_successful(tid))
+        self.assertEquals(tb.get_status(tid), "SUCCESS")
+        self.assertEquals(tb.get_result(tid), 42)
+        self.assertTrue(tb._cache.get(tid))
+        self.assertTrue(tb.get_result(tid), 42)
+
+    def test_is_pickled(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return
+
+        tid2 = gen_unique_id()
+        result = {"foo": "baz", "bar": SomeClass(12345)}
+        tb.mark_as_done(tid2, result)
+        # is serialized properly.
+        rindb = tb.get_result(tid2)
+        self.assertEquals(rindb.get("foo"), "baz")
+        self.assertEquals(rindb.get("bar").data, 12345)
+
+    def test_mark_as_failure(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return
+
+        tid3 = gen_unique_id()
+        try:
+            raise KeyError("foo")
+        except KeyError, exception:
+            pass
+        tb.mark_as_failure(tid3, exception)
+        self.assertFalse(tb.is_successful(tid3))
+        self.assertEquals(tb.get_status(tid3), "FAILURE")
+        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+
+    def test_process_cleanup(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return
+
+        tb.process_cleanup()
+
+        self.assertTrue(tb._connection is None)
+
+    def test_connection_close_if_connected(self):
+        tb = get_redis_or_None()
+        if not tb:
+            return
+
+        tb.open()
+        self.assertTrue(tb._connection is not None)
+        tb.close()
+        self.assertTrue(tb._connection is None)
+        tb.close()
+        self.assertTrue(tb._connection is None)
+
+
+class TestRedisBackendNoRedis(unittest.TestCase):
+
+    def test_redis_None_if_redis_not_installed(self):
+        from celery.tests.utils import mask_modules
+        prev = sys.modules.pop("celery.backends.pyredis")
+        with mask_modules("redis"):
+            from celery.backends.pyredis import redis
+            self.assertTrue(redis is None)
+        sys.modules["celery.backends.pyredis"] = prev
+
+    def test_constructor_raises_if_redis_not_installed(self):
+        from celery.backends import pyredis
+        prev = pyredis.redis
+        pyredis.redis = None
+        try:
+            self.assertRaises(ImproperlyConfigured, pyredis.RedisBackend)
+        finally:
+            pyredis.redis = prev
+
+    def test_constructor_raises_if_not_host_or_port(self):
+        from celery.backends import pyredis
+        prev_host = pyredis.RedisBackend.redis_host
+        prev_port = pyredis.RedisBackend.redis_port
+        pyredis.RedisBackend.redis_host = None
+        pyredis.RedisBackend.redis_port = None
+        try:
+            self.assertRaises(ImproperlyConfigured, pyredis.RedisBackend)
+        finally:
+            pyredis.RedisBackend.redis_host = prev_host
+            pyredis.RedisBackend.redis_port = prev_port

+ 8 - 8
celery/tests/test_backends/test_tyrant.py

@@ -7,7 +7,7 @@ from celery.backends.tyrant import TyrantBackend
 from celery.utils import gen_unique_id
 from django.core.exceptions import ImproperlyConfigured
 
-_no_tyrant_msg = "* Tokyo Tyrant not running. Will not execute related tests."
+_no_tyrant_msg = "* Tokyo Tyrant %s. Will not execute related tests."
 _no_tyrant_msg_emitted = False
 
 
@@ -19,28 +19,28 @@ class SomeClass(object):
 
 def get_tyrant_or_None():
 
-    def emit_no_tyrant_msg():
+    def emit_no_tyrant_msg(reason):
         global _no_tyrant_msg_emitted
         if not _no_tyrant_msg_emitted:
-            sys.stderr.write("\n" + _no_tyrant_msg + "\n")
+            sys.stderr.write("\n" + _no_tyrant_msg % reason + "\n")
             _no_tyrant_msg_emitted = True
 
     if tyrant.pytyrant is None:
-        emit_no_tyrant_msg()
-        return None
+        return emit_no_tyrant_msg("not installed")
     try:
         tb = TyrantBackend()
         try:
             tb.open()
         except socket.error, exc:
             if exc.errno == errno.ECONNREFUSED:
-                emit_no_tyrant_msg()
-                return None
+                return emit_no_tyrant_msg("not running")
             else:
                 raise
         return tb
     except ImproperlyConfigured, exc:
-        return None
+        if "need to install" in str(exc):
+            return emit_no_tyrant_msg("not installed")
+        return emit_no_tyrant_msg("not configured")
 
 
 class TestTyrantBackend(unittest.TestCase):

+ 217 - 0
celery/tests/test_beat.py

@@ -0,0 +1,217 @@
+import unittest
+import logging
+from datetime import datetime, timedelta
+
+from celery import log
+from celery import beat
+from celery import conf
+from celery.utils import gen_unique_id
+from celery.task.base import PeriodicTask
+from celery.registry import TaskRegistry
+from celery.result import AsyncResult
+
+
+class MockShelve(dict):
+    closed = False
+    synced = False
+
+    def close(self):
+        self.closed = True
+
+    def sync(self):
+        self.synced = True
+
+
+class MockClockService(object):
+    started = False
+    stopped = False
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def start(self, **kwargs):
+        self.started = True
+
+    def stop(self, **kwargs):
+        self.stopped = True
+
+
+class DuePeriodicTask(PeriodicTask):
+    run_every = timedelta(seconds=1)
+    applied = False
+
+    def is_due(self, *args, **kwargs):
+        return True, 100
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        self.applied = True
+        return AsyncResult(gen_unique_id())
+
+
+class DuePeriodicTaskRaising(PeriodicTask):
+    run_every = timedelta(seconds=1)
+    applied = False
+
+    def is_due(self, *args, **kwargs):
+        return True, 0
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        raise Exception("FoozBaaz")
+
+
+class PendingPeriodicTask(PeriodicTask):
+    run_every = timedelta(seconds=1)
+    applied = False
+
+    def is_due(self, *args, **kwargs):
+        return False, 100
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        self.applied = True
+        return AsyncResult(gen_unique_id())
+
+
+class AdditionalTask(PeriodicTask):
+    run_every = timedelta(days=7)
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        raise Exception("FoozBaaz")
+
+
+class TestScheduleEntry(unittest.TestCase):
+
+    def test_constructor(self):
+        s = beat.ScheduleEntry(DuePeriodicTask.name)
+        self.assertEquals(s.name, DuePeriodicTask.name)
+        self.assertTrue(isinstance(s.last_run_at, datetime))
+        self.assertEquals(s.total_run_count, 0)
+
+        now = datetime.now()
+        s = beat.ScheduleEntry(DuePeriodicTask.name, now, 300)
+        self.assertEquals(s.name, DuePeriodicTask.name)
+        self.assertEquals(s.last_run_at, now)
+        self.assertEquals(s.total_run_count, 300)
+
+    def test_next(self):
+        s = beat.ScheduleEntry(DuePeriodicTask.name, None, 300)
+        n = s.next()
+        self.assertEquals(n.name, s.name)
+        self.assertEquals(n.total_run_count, 301)
+        self.assertTrue(n.last_run_at > s.last_run_at)
+
+    def test_is_due(self):
+        due = beat.ScheduleEntry(DuePeriodicTask.name)
+        pending = beat.ScheduleEntry(PendingPeriodicTask.name)
+
+        self.assertTrue(due.is_due(DuePeriodicTask())[0])
+        self.assertFalse(pending.is_due(PendingPeriodicTask())[0])
+
+
+class TestScheduler(unittest.TestCase):
+
+    def setUp(self):
+        self.registry = TaskRegistry()
+        self.registry.register(DuePeriodicTask)
+        self.registry.register(PendingPeriodicTask)
+        self.scheduler = beat.Scheduler(self.registry,
+                                        max_interval=0.0001,
+                                        logger=log.get_default_logger())
+
+    def test_constructor(self):
+        s = beat.Scheduler()
+        self.assertTrue(isinstance(s.registry, TaskRegistry))
+        self.assertTrue(isinstance(s.schedule, dict))
+        self.assertTrue(isinstance(s.logger, logging.Logger))
+        self.assertEquals(s.max_interval, conf.CELERYBEAT_MAX_LOOP_INTERVAL)
+
+    def test_cleanup(self):
+        self.scheduler.schedule["fbz"] = beat.ScheduleEntry("fbz")
+        self.scheduler.cleanup()
+        self.assertTrue("fbz" not in self.scheduler.schedule)
+
+    def test_schedule_registry(self):
+        self.registry.register(AdditionalTask)
+        self.scheduler.schedule_registry()
+        self.assertTrue(AdditionalTask.name in self.scheduler.schedule)
+
+    def test_apply_async(self):
+        due_task = self.registry[DuePeriodicTask.name]
+        self.scheduler.apply_async(self.scheduler[due_task.name])
+        self.assertTrue(due_task.applied)
+
+    def test_apply_async_raises_SchedulingError_on_error(self):
+        self.registry.register(AdditionalTask)
+        self.scheduler.schedule_registry()
+        add_task = self.registry[AdditionalTask.name]
+        self.assertRaises(beat.SchedulingError,
+                          self.scheduler.apply_async,
+                          self.scheduler[add_task.name])
+
+    def test_is_due(self):
+        due = self.scheduler[DuePeriodicTask.name]
+        pending = self.scheduler[PendingPeriodicTask.name]
+
+        self.assertTrue(self.scheduler.is_due(due)[0])
+        self.assertFalse(self.scheduler.is_due(pending)[0])
+
+
+    def test_tick(self):
+        self.scheduler.schedule.pop(DuePeriodicTaskRaising.name, None)
+        self.registry.pop(DuePeriodicTaskRaising.name, None)
+        self.assertEquals(self.scheduler.tick(),
+                            self.scheduler.max_interval)
+
+    def test_quick_schedulingerror(self):
+        self.registry.register(DuePeriodicTaskRaising)
+        self.scheduler.schedule_registry()
+        self.assertEquals(self.scheduler.tick(),
+                            self.scheduler.max_interval)
+
+
+class TestClockService(unittest.TestCase):
+
+    def test_start(self):
+        s = beat.ClockService()
+        sh = MockShelve()
+        s.open_schedule = lambda *a, **kw: sh
+
+        self.assertTrue(isinstance(s.schedule, dict))
+        self.assertTrue(isinstance(s.schedule, dict))
+        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
+        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
+
+        self.assertTrue(s.schedule is sh)
+        self.assertTrue(s._schedule is sh)
+
+        s._in_sync = False
+        s.sync()
+        self.assertTrue(sh.closed)
+        self.assertTrue(sh.synced)
+        self.assertTrue(s._stopped.isSet())
+        s.sync()
+
+        s.stop(wait=False)
+        self.assertTrue(s._shutdown.isSet())
+        s.stop(wait=True)
+        self.assertTrue(s._shutdown.isSet())
+
+
+
+
+class TestClockServiceThread(unittest.TestCase):
+
+    def test_start_stop(self):
+        s = beat.ClockServiceThread()
+        self.assertTrue(isinstance(s.clockservice, beat.ClockService))
+        s.clockservice = MockClockService()
+
+        s.run()
+        self.assertTrue(s.clockservice.started)
+
+        s.stop()
+        self.assertTrue(s.clockservice.stopped)
+

+ 27 - 12
celery/tests/test_result.py

@@ -1,9 +1,10 @@
 import unittest
-from celery.backends import default_backend
-from celery.result import AsyncResult
-from celery.result import TaskSetResult
-from celery.result import TimeoutError
+
 from celery.utils import gen_unique_id
+from celery.tests.utils import skip_if_quick
+from celery.result import AsyncResult, TaskSetResult
+from celery.backends import default_backend
+from celery.exceptions import TimeoutError
 
 
 def mock_task(name, status, result):
@@ -13,6 +14,8 @@ def mock_task(name, status, result):
 def save_result(task):
     if task["status"] == "SUCCESS":
         default_backend.mark_as_done(task["id"], task["result"])
+    elif task["status"] == "RETRY":
+        default_backend.mark_as_retry(task["id"], task["result"])
     else:
         default_backend.mark_as_failure(task["id"], task["result"])
 
@@ -29,23 +32,19 @@ class TestAsyncResult(unittest.TestCase):
         self.task1 = mock_task("task1", "SUCCESS", "the")
         self.task2 = mock_task("task2", "SUCCESS", "quick")
         self.task3 = mock_task("task3", "FAILURE", KeyError("brown"))
+        self.task4 = mock_task("task3", "RETRY", KeyError("red"))
 
-        for task in (self.task1, self.task2, self.task3):
+        for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(task)
 
     def test_successful(self):
         ok_res = AsyncResult(self.task1["id"])
         nok_res = AsyncResult(self.task3["id"])
+        nok_res2 = AsyncResult(self.task4["id"])
 
         self.assertTrue(ok_res.successful())
         self.assertFalse(nok_res.successful())
-
-    def test_sucessful(self):
-        ok_res = AsyncResult(self.task1["id"])
-        nok_res = AsyncResult(self.task3["id"])
-
-        self.assertTrue(ok_res.successful())
-        self.assertFalse(nok_res.successful())
+        self.assertFalse(nok_res2.successful())
 
     def test_str(self):
         ok_res = AsyncResult(self.task1["id"])
@@ -70,16 +69,28 @@ class TestAsyncResult(unittest.TestCase):
         ok_res = AsyncResult(self.task1["id"])
         ok2_res = AsyncResult(self.task2["id"])
         nok_res = AsyncResult(self.task3["id"])
+        nok2_res = AsyncResult(self.task4["id"])
 
         self.assertEquals(ok_res.get(), "the")
         self.assertEquals(ok2_res.get(), "quick")
         self.assertRaises(KeyError, nok_res.get)
+        self.assertTrue(isinstance(nok2_res.result, KeyError))
+
+    def test_get_timeout(self):
+        res = AsyncResult(self.task4["id"]) # has RETRY status
+        self.assertRaises(TimeoutError, res.get, timeout=0.1)
+
+    @skip_if_quick
+    def test_get_timeout_longer(self):
+        res = AsyncResult(self.task4["id"]) # has RETRY status
+        self.assertRaises(TimeoutError, res.get, timeout=1)
 
 
     def test_ready(self):
         oks = (AsyncResult(self.task1["id"]),
                AsyncResult(self.task2["id"]),
                AsyncResult(self.task3["id"]))
         [self.assertTrue(ok.ready()) for ok in oks]
+        self.assertFalse(AsyncResult(self.task4["id"]).ready())
 
 
 class TestTaskSetResult(unittest.TestCase):
 
 
     def x_join(self):
         self.assertRaises(TimeoutError, self.ts.join, timeout=0.001)
+    @skip_if_quick
+    def x_join_longer(self):
+        self.assertRaises(TimeoutError, self.ts.join, timeout=1)

+ 4 - 0
celery/tests/utils.py

@@ -25,6 +25,10 @@ def skip_if_environ(env_var_name):
     return _wrap_test
 
 
+def skip_if_quick(fun):
+    return skip_if_environ("QUICKTEST")(fun)
+
+
 def _skip_test(reason, sign):
 
     def _wrap_test(fun):

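Note: skip_if_quick is just skip_if_environ("QUICKTEST"), which ties in with the QUICKTEST=1 added to QuickRunTests in setup.py below and the new quickcoverage target in the Makefile above. skip_if_environ itself is not part of this diff; a minimal sketch of how such a decorator could be written (an assumption, not the actual celery helper):

    import os
    from functools import wraps

    def skip_if_environ(env_var_name):
        # Turn the decorated test into a no-op when the environment variable is set.
        def _wrap_test(fun):
            @wraps(fun)
            def _skipped_test(*args, **kwargs):
                if os.environ.get(env_var_name):
                    return None         # silently skip; unittest of this era has no skip API
                return fun(*args, **kwargs)
            return _skipped_test
        return _wrap_test
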
+ 1 - 1
setup.py

@@ -46,7 +46,7 @@ class RunTests(Command):
 
 class QuickRunTests(RunTests):
 
-    quicktest_envs = dict(SKIP_RLIMITS=1)
+    quicktest_envs = dict(SKIP_RLIMITS=1, QUICKTEST=1)
 
     def run(self):
         for env_name, env_value in self.quicktest_envs.items():

+ 8 - 1
testproj/settings.py

@@ -20,11 +20,18 @@ TEST_RUNNER = "celery.tests.runners.run_tests"
 TEST_APPS = (
     "celery",
 )
-COVERAGE_EXCLUDE_MODULES = ("celery.tests.*",
+COVERAGE_EXCLUDE_MODULES = ("celery.__init__",
+                            "celery.conf",
+                            "celery.tests.*",
                             "celery.management.*",
                             "celery.contrib.*",
                             "celery.bin.*",
                             "celery.utils.patch",
+                            "celery.task.rest",
+                            "celery.platform", # FIXME
+                            "celery.loaders.default", # FIXME
+                            "celery.backends.mongodb", # FIXME
+                            "celery.backends.tyrant", # FIXME
                             "celery.task.strategy")
 COVERAGE_HTML_REPORT = True
 COVERAGE_BRANCH_COVERAGE = True