Browse Source

94% total coverage

Ask Solem 15 years ago
parent
commit
0dbdc5de4d

+ 36 - 8
celery/backends/cache.py

@@ -6,20 +6,40 @@ from celery import conf
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.utils import timeutils
+from celery.datastructures import LocalCache
 
-try:
-    import pylibmc as memcache
-except ImportError:
+
+
+def get_best_memcache(*args, **kwargs):
     try:
-        import memcache
+        import pylibmc as memcache
     except ImportError:
-        raise ImproperlyConfigured("Memcached backend requires either "
-                                   "the 'memcache' or 'pylibmc' library")
+        try:
+            import memcache
+        except ImportError:
+            raise ImproperlyConfigured("Memcached backend requires either "
+                                       "the 'memcache' or 'pylibmc' library")
+    return memcache.Client(*args, **kwargs)
 
 
-class CacheBackend(KeyValueStoreBackend):
-    Client = memcache.Client
+class DummyClient(object):
+
+    def __init__(self, *args, **kwargs):
+        self.cache = LocalCache(5000)
+
+    def get(self, key, *args, **kwargs):
+        return self.cache.get(key)
 
+    def set(self, key, value, *args, **kwargs):
+        self.cache[key] = value
+
+
+backends = {"memcache": get_best_memcache,
+            "memcached": get_best_memcache,
+            "pylibmc": get_best_memcache,
+            "memory": DummyClient}
+
+class CacheBackend(KeyValueStoreBackend):
     _client = None
 
     def __init__(self, expires=conf.TASK_RESULT_EXPIRES,
@@ -31,6 +51,14 @@ class CacheBackend(KeyValueStoreBackend):
         self.options = dict(conf.CACHE_BACKEND_OPTIONS, **options)
         self.backend, _, servers = partition(backend, "://")
         self.servers = servers.split(";")
+        try:
+            self.Client = backends[self.backend]
+        except KeyError:
+            raise ImproperlyConfigured(
+                    "Unknown cache backend: %s. Please use one of the "
+                    "following backends: %s" % (self.backend,
+                                                ", ".join(backends.keys())))
+
 
     def get(self, key):
         return self.client.get(key)

+ 15 - 16
celery/backends/database.py

@@ -44,18 +44,6 @@ class DatabaseBackend(BaseDictBackend):
             session.close()
         return result
 
-    def _save_taskset(self, taskset_id, result):
-        """Store the result of an executed taskset."""
-        taskset = TaskSet(taskset_id, result)
-        session = self.ResultSession()
-        try:
-            session.add(taskset)
-            session.flush()
-            session.commit()
-        finally:
-            session.close()
-        return result
-
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         session = self.ResultSession()
@@ -66,8 +54,19 @@ class DatabaseBackend(BaseDictBackend):
                 session.add(task)
                 session.flush()
                 session.commit()
-            if task:
-                return task.to_dict()
+            return task.to_dict()
+        finally:
+            session.close()
+
+    def _save_taskset(self, taskset_id, result):
+        """Store the result of an executed taskset."""
+        session = self.ResultSession()
+        try:
+            taskset = TaskSet(taskset_id, result)
+            session.add(taskset)
+            session.flush()
+            session.commit()
+            return result
         finally:
             session.close()
 
@@ -75,8 +74,8 @@ class DatabaseBackend(BaseDictBackend):
         """Get taskset metadata for a taskset by id."""
         session = self.ResultSession()
         try:
-            qs = session.query(TaskSet)
-            taskset = qs.filter(TaskSet.taskset_id == taskset_id).first()
+            taskset = session.query(TaskSet).filter(
+                    TaskSet.taskset_id == taskset_id).first()
             if taskset:
                 return taskset.to_dict()
         finally:

+ 7 - 18
celery/db/models.py

@@ -26,22 +26,14 @@ class Task(ResultModelBase):
     def __init__(self, task_id):
         self.task_id = task_id
 
-    def __str__(self):
-        return "<Task(%s, %s, %s, %s)>" % (self.task_id,
-                                           self.result,
-                                           self.status,
-                                           self.traceback)
-
     def to_dict(self):
         return {"task_id": self.task_id,
                 "status": self.status,
                 "result": self.result,
-                "date_done": self.date_done,
                 "traceback": self.traceback}
 
-    def __unicode__(self):
-        return u"<Task: %s successful: %s>" % (self.task_id, self.status)
-
+    def __repr__(self):
+        return "<Task %s state: %s>" % (self.task_id, self.status)
 
 class TaskSet(ResultModelBase):
     """TaskSet result"""
@@ -55,16 +47,13 @@ class TaskSet(ResultModelBase):
     date_done = sa.Column(sa.DateTime, default=datetime.now,
                        nullable=True)
 
-    def __init__(self, task_id):
-        self.task_id = task_id
-
-    def __str__(self):
-        return "<TaskSet(%s, %s)>" % (self.task_id, self.result)
+    def __init__(self, taskset_id, result):
+        self.taskset_id = taskset_id
+        self.result = result
 
     def to_dict(self):
         return {"taskset_id": self.taskset_id,
-                "result": self.result,
-                "date_done": self.date_done}
+                "result": self.result}
 
-    def __unicode__(self):
+    def __repr__(self):
         return u"<TaskSet: %s>" % (self.taskset_id, )

+ 2 - 12
celery/events/state.py

@@ -4,29 +4,19 @@ import heapq
 from carrot.utils import partition
 
 from celery import states
-from celery.datastructures import LocalCache
+from celery.datastructures import AttributeDict, LocalCache
 from celery.utils import kwdict
 
 HEARTBEAT_EXPIRE = 150 # 2 minutes, 30 seconds
 
 
-class Element(dict):
+class Element(AttributeDict):
     """Base class for types."""
     visited = False
 
     def __init__(self, **fields):
         dict.__init__(self, fields)
 
-    def __getattr__(self, key):
-        try:
-            return self[key]
-        except KeyError:
-            raise AttributeError("'%s' object has no attribute '%s'" % (
-                    self.__class__.__name__, key))
-
-    def __setattr__(self, key, value):
-        self[key] = value
-
 
 class Worker(Element):
     """Worker State."""

+ 131 - 0
celery/tests/test_backends/test_cache.py

@@ -0,0 +1,131 @@
+import sys
+import types
+import unittest2 as unittest
+
+from celery import states
+from celery.backends.cache import CacheBackend, get_best_memcache, DummyClient
+from celery.exceptions import ImproperlyConfigured
+from celery.utils import gen_unique_id
+
+from celery.tests.utils import execute_context, mask_modules
+
+
+class SomeClass(object):
+
+    def __init__(self, data):
+        self.data = data
+
+
+class test_CacheBackend(unittest.TestCase):
+
+    def test_mark_as_done(self):
+        tb = CacheBackend(backend="memory://")
+
+        tid = gen_unique_id()
+
+        self.assertEqual(tb.get_status(tid), states.PENDING)
+        self.assertIsNone(tb.get_result(tid))
+
+        tb.mark_as_done(tid, 42)
+        self.assertEqual(tb.get_status(tid), states.SUCCESS)
+        self.assertEqual(tb.get_result(tid), 42)
+
+    def test_is_pickled(self):
+        tb = CacheBackend(backend="memory://")
+
+        tid2 = gen_unique_id()
+        result = {"foo": "baz", "bar": SomeClass(12345)}
+        tb.mark_as_done(tid2, result)
+        # is serialized properly.
+        rindb = tb.get_result(tid2)
+        self.assertEqual(rindb.get("foo"), "baz")
+        self.assertEqual(rindb.get("bar").data, 12345)
+
+    def test_mark_as_failure(self):
+        tb = CacheBackend(backend="memory://")
+
+        tid3 = gen_unique_id()
+        try:
+            raise KeyError("foo")
+        except KeyError, exception:
+            pass
+        tb.mark_as_failure(tid3, exception)
+        self.assertEqual(tb.get_status(tid3), states.FAILURE)
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
+
+    def test_process_cleanup(self):
+        tb = CacheBackend(backend="memory://")
+        tb.process_cleanup()
+
+    def test_expires_as_int(self):
+        tb = CacheBackend(backend="memory://", expires=10)
+        self.assertEqual(tb.expires, 10)
+
+    def test_unknown_backend_raises_ImproperlyConfigured(self):
+        self.assertRaises(ImproperlyConfigured,
+                          CacheBackend, backend="unknown://")
+
+
+
+class test_get_best_memcache(unittest.TestCase):
+
+    def mock_memcache(self):
+        memcache = types.ModuleType("memcache")
+        memcache.Client = DummyClient
+        memcache.Client.__module__ = memcache.__name__
+        prev, sys.modules["memcache"] = sys.modules.get("memcache"), memcache
+        yield
+        if prev is not None:
+            sys.modules["memcache"] = prev
+        yield
+
+    def mock_pylibmc(self):
+        pylibmc = types.ModuleType("pylibmc")
+        pylibmc.Client = DummyClient
+        pylibmc.Client.__module__ = pylibmc.__name__
+        prev = sys.modules.get("pylibmc")
+        sys.modules["pylibmc"] = pylibmc
+        yield
+        if prev is not None:
+            sys.modules["pylibmc"] = prev
+        yield
+
+    def test_pylibmc(self):
+        pylibmc = self.mock_pylibmc()
+        pylibmc.next()
+        import __builtin__
+        sys.modules.pop("celery.backends.cache", None)
+        from celery.backends import cache
+        self.assertEqual(cache.get_best_memcache().__module__, "pylibmc")
+        pylibmc.next()
+
+    def test_memcache(self):
+
+        def with_no_pylibmc():
+            sys.modules.pop("celery.backends.cache", None)
+            from celery.backends import cache
+            self.assertEqual(cache.get_best_memcache().__module__, "memcache")
+
+        context = mask_modules("pylibmc")
+        context.__enter__()
+        try:
+            memcache = self.mock_memcache()
+            memcache.next()
+            with_no_pylibmc()
+            memcache.next()
+        finally:
+            context.__exit__(None, None, None)
+
+    def test_no_implementations(self):
+
+        def with_no_memcache_libs():
+            sys.modules.pop("celery.backends.cache", None)
+            from celery.backends import cache
+            self.assertRaises(ImproperlyConfigured, cache.get_best_memcache)
+
+        context = mask_modules("pylibmc", "memcache")
+        context.__enter__()
+        try:
+            with_no_memcache_libs()
+        finally:
+            context.__exit__(None, None, None)

+ 52 - 0
celery/tests/test_backends/test_database.py

@@ -2,9 +2,13 @@ import sys
 import socket
 import unittest2 as unittest
 
+from datetime import datetime
+
 from celery.exceptions import ImproperlyConfigured
 
+from celery import conf
 from celery import states
+from celery.db.models import Task, TaskSet
 from celery.utils import gen_unique_id
 from celery.backends.database import DatabaseBackend
 
@@ -19,6 +23,17 @@ class SomeClass(object):
 
 class test_DatabaseBackend(unittest.TestCase):
 
+    def test_missing_dburi_raises_ImproperlyConfigured(self):
+        prev, conf.RESULT_DBURI = conf.RESULT_DBURI, None
+        try:
+            self.assertRaises(ImproperlyConfigured, DatabaseBackend)
+        finally:
+            conf.RESULT_DBURI = prev
+
+    def test_missing_task_id_is_PENDING(self):
+        tb = DatabaseBackend()
+        self.assertEqual(tb.get_status("xxx-does-not-exist"), states.PENDING)
+
     def test_mark_as_done(self):
         tb = DatabaseBackend()
 
@@ -84,3 +99,40 @@ class test_DatabaseBackend(unittest.TestCase):
     def test_process_cleanup(self):
         tb = DatabaseBackend()
         tb.process_cleanup()
+
+    def test_save___restore_taskset(self):
+        tb = DatabaseBackend()
+
+        tid = gen_unique_id()
+        res = {u"something": "special"}
+        self.assertEqual(tb.save_taskset(tid, res), res)
+
+        res2 = tb.restore_taskset(tid)
+        self.assertEqual(res2, res)
+
+        self.assertIsNone(tb.restore_taskset("xxx-nonexisting-id"))
+
+    def test_cleanup(self):
+        tb = DatabaseBackend()
+        for i in range(10):
+            tb.mark_as_done(gen_unique_id(), 42)
+            tb.save_taskset(gen_unique_id(), {"foo": "bar"})
+        s = tb.ResultSession()
+        for t in s.query(Task).all():
+            t.date_done = datetime.now() - tb.result_expires * 2
+        for t in s.query(TaskSet).all():
+            t.date_done = datetime.now() - tb.result_expires * 2
+        s.commit()
+        s.close()
+
+        tb.cleanup()
+        s2 = tb.ResultSession()
+        self.assertEqual(s2.query(Task).count(), 0)
+        self.assertEqual(s2.query(TaskSet).count(), 0)
+
+    def test_Task__repr__(self):
+        self.assertIn("foo", repr(Task("foo")))
+
+    def test_TaskSet__repr__(self):
+        self.assertIn("foo", repr(TaskSet("foo", None)))
+

+ 86 - 8
celery/tests/test_buckets.py

@@ -1,18 +1,18 @@
 from __future__ import generators
+
 import os
 import sys
-sys.path.insert(0, os.getcwd())
 import time
 import unittest2 as unittest
-from itertools import chain, izip
 
+from itertools import chain, izip
 
+from celery.registry import TaskRegistry
 from celery.task.base import Task
 from celery.utils import timeutils
 from celery.utils import gen_unique_id
 from celery.utils.functional import curry
 from celery.worker import buckets
-from celery.registry import TaskRegistry
 
 from celery.tests.utils import skip_if_environ
 
@@ -41,7 +41,7 @@ class MockJob(object):
                 self.task_name, self.task_id, self.args, self.kwargs)
 
 
-class TestTokenBucketQueue(unittest.TestCase):
+class test_TokenBucketQueue(unittest.TestCase):
 
     @skip_if_disabled
     def empty_queue_yields_QueueEmpty(self):
@@ -94,7 +94,8 @@ class TestTokenBucketQueue(unittest.TestCase):
         self.assertEqual(x.get_nowait(), "The quick brown fox")
 
 
-class TestRateLimitString(unittest.TestCase):
+
+class test_rate_limit_string(unittest.TestCase):
 
     @skip_if_disabled
     def test_conversion(self):
@@ -125,7 +126,7 @@ class TaskD(Task):
     rate_limit = "1000/m"
 
 
-class TestTaskBuckets(unittest.TestCase):
+class test_TaskBucket(unittest.TestCase):
 
     def setUp(self):
         self.registry = TaskRegistry()
@@ -133,6 +134,44 @@ class TestTaskBuckets(unittest.TestCase):
         for task_cls in self.task_classes:
             self.registry.register(task_cls)
 
+    @skip_if_disabled
+    def test_get_nowait(self):
+        x = buckets.TaskBucket(task_registry=self.registry)
+        self.assertRaises(buckets.QueueEmpty, x.get_nowait)
+
+    @skip_if_disabled
+    def test_refresh(self):
+        reg = {}
+        x = buckets.TaskBucket(task_registry=reg)
+        reg["foo"] = "something"
+        x.refresh()
+        self.assertIn("foo", x.buckets)
+        self.assertTrue(x.get_bucket_for_type("foo"))
+
+    @skip_if_disabled
+    def test__get_queue_for_type(self):
+        x = buckets.TaskBucket(task_registry={})
+        x.buckets["foo"] = buckets.TokenBucketQueue(fill_rate=1)
+        self.assertIs(x._get_queue_for_type("foo"), x.buckets["foo"].queue)
+        x.buckets["bar"] = buckets.FastQueue()
+        self.assertIs(x._get_queue_for_type("bar"), x.buckets["bar"])
+
+    @skip_if_disabled
+    def test_update_bucket_for_type(self):
+        bucket = buckets.TaskBucket(task_registry=self.registry)
+        b = bucket._get_queue_for_type(TaskC.name)
+        self.assertIs(bucket.update_bucket_for_type(TaskC.name).queue, b)
+        self.assertIs(bucket.buckets[TaskC.name].queue, b)
+
+    @skip_if_disabled
+    def test_auto_add_on_missing_put(self):
+        reg = {}
+        b = buckets.TaskBucket(task_registry=reg)
+        reg["nonexisting.task"] = "foo"
+
+        b.put(MockJob(gen_unique_id(), "nonexisting.task", (), {}))
+        self.assertIn("nonexisting.task", b.buckets)
+
     @skip_if_disabled
     def test_auto_add_on_missing(self):
         b = buckets.TaskBucket(task_registry=self.registry)
@@ -227,5 +266,44 @@ class TestTaskBuckets(unittest.TestCase):
         finally:
             self.registry.unregister(TaskD)
 
-if __name__ == "__main__":
-    unittest.main()
+    @skip_if_disabled
+    def test_empty(self):
+        x = buckets.TaskBucket(task_registry=self.registry)
+        self.assertTrue(x.empty())
+        x.put(MockJob(gen_unique_id(), TaskC.name, [], {}))
+        self.assertFalse(x.empty())
+        x.clear()
+        self.assertTrue(x.empty())
+
+    @skip_if_disabled
+    def test_items(self):
+        x = buckets.TaskBucket(task_registry=self.registry)
+        x.buckets[TaskA.name].put(1)
+        x.buckets[TaskB.name].put(2)
+        x.buckets[TaskC.name].put(3)
+        self.assertItemsEqual(x.items, [1, 2, 3])
+
+class test_FastQueue(unittest.TestCase):
+
+    def test_can_consume(self):
+        x = buckets.FastQueue()
+        self.assertTrue(x.can_consume())
+
+    def test_items(self):
+        x = buckets.FastQueue()
+        x.put(10)
+        x.put(20)
+        self.assertListEqual([10, 20], list(x.items))
+
+    def test_wait(self):
+        x = buckets.FastQueue()
+        x.put(10)
+        self.assertEqual(x.wait(), 10)
+
+    def test_clear(self):
+        x = buckets.FastQueue()
+        x.put(10)
+        x.put(20)
+        self.assertFalse(x.empty())
+        x.clear()
+        self.assertTrue(x.empty())

+ 19 - 7
celery/tests/test_datastructures.py

@@ -4,9 +4,10 @@ from Queue import Queue
 
 from celery.datastructures import PositionQueue, ExceptionInfo, LocalCache
 from celery.datastructures import LimitedSet, SharedCounter, consume_queue
+from celery.datastructures import AttributeDict
 
 
-class TestPositionQueue(unittest.TestCase):
+class test_PositionQueue(unittest.TestCase):
 
     def test_position_queue_unfilled(self):
         q = PositionQueue(length=10)
@@ -36,7 +37,7 @@ class TestPositionQueue(unittest.TestCase):
         self.assertTrue(q.full())
 
 
-class TestExceptionInfo(unittest.TestCase):
+class test_ExceptionInfo(unittest.TestCase):
 
     def test_exception_info(self):
 
@@ -56,7 +57,7 @@ class TestExceptionInfo(unittest.TestCase):
         self.assertTrue(r)
 
 
-class TestUtilities(unittest.TestCase):
+class test_utilities(unittest.TestCase):
 
     def test_consume_queue(self):
         x = Queue()
@@ -68,7 +69,7 @@ class TestUtilities(unittest.TestCase):
         self.assertRaises(StopIteration, it.next)
 
 
-class TestSharedCounter(unittest.TestCase):
+class test_SharedCounter(unittest.TestCase):
 
     def test_initial_value(self):
         self.assertEqual(int(SharedCounter(10)), 10)
@@ -101,7 +102,7 @@ class TestSharedCounter(unittest.TestCase):
         self.assertIn("<SharedCounter:", repr(SharedCounter(10)))
 
 
-class TestLimitedSet(unittest.TestCase):
+class test_LimitedSet(unittest.TestCase):
 
     def test_add(self):
         s = LimitedSet(maxlen=2)
@@ -118,7 +119,7 @@ class TestLimitedSet(unittest.TestCase):
         s = LimitedSet(maxlen=2)
         items = "foo", "bar"
         map(s.add, items)
-        l = list(iter(items))
+        l = list(iter(s))
         for item in items:
             self.assertIn(item, l)
 
@@ -129,7 +130,7 @@ class TestLimitedSet(unittest.TestCase):
         self.assertIn("LimitedSet(", repr(s))
 
 
-class TestLocalCache(unittest.TestCase):
+class test_LocalCache(unittest.TestCase):
 
     def test_expires(self):
         limit = 100
@@ -138,3 +139,14 @@ class TestLocalCache(unittest.TestCase):
         for i in slots:
             x[i] = i
         self.assertListEqual(x.keys(), slots[limit:])
+
+
+class test_AttributeDict(unittest.TestCase):
+
+    def test_getattr__setattr(self):
+        x = AttributeDict({"foo": "bar"})
+        self.assertEqual(x["foo"], "bar")
+        self.assertRaises(AttributeError, getattr, x, "bar")
+        x.bar = "foo"
+        self.assertEqual(x["bar"], "foo")
+

+ 21 - 0
celery/tests/test_routes.py

@@ -62,6 +62,10 @@ class test_MapRoute(unittest.TestCase):
 
 class test_lookup_route(unittest.TestCase):
 
+    def test_init_queues(self):
+        router = routes.Router(queues=None)
+        self.assertDictEqual(router.queues, {})
+
     @with_queues(foo=a_queue, bar=b_queue)
     def test_lookup_takes_first(self):
         R = routes.prepare(({"celery.ping": "bar"},
@@ -80,3 +84,20 @@ class test_lookup_route(unittest.TestCase):
                 router.route({}, "celery.ping",
                     args=[1, 2], kwargs={}))
         self.assertEqual(router.route({}, "celery.poza"), {})
+
+
+class test_prepare(unittest.TestCase):
+
+    def test_prepare(self):
+        from celery.datastructures import LocalCache
+        o = object()
+        R = [{"foo": "bar"},
+                  "celery.datastructures.LocalCache",
+                  o]
+        p = routes.prepare(R)
+        self.assertIsInstance(p[0], routes.MapRoute)
+        self.assertIsInstance(p[1], LocalCache)
+        self.assertIs(p[2], o)
+
+        self.assertEqual(routes.prepare(o), [o])
+

+ 1 - 1
celery/tests/test_worker_job.py

@@ -43,7 +43,7 @@ def mytask(i, **kwargs):
     return i ** i
 
 
-@task_dec()
+@task_dec # traverses coverage for decorator without parens
 def mytask_no_kwargs(i):
     return i ** i
 

+ 1 - 0
celery/worker/buckets.py

@@ -142,6 +142,7 @@ class TaskBucket(object):
         task_type = self.task_registry[task_name]
         rate_limit = getattr(task_type, "rate_limit", None)
         rate_limit = timeutils.rate(rate_limit)
+        task_queue = FastQueue()
         if task_name in self.buckets:
             task_queue = self._get_queue_for_type(task_name)
         else:

+ 1 - 0
setup.cfg

@@ -8,6 +8,7 @@ cover3-exclude = celery
                  celery.tests.*
                  celery.bin.celerybeat
                  celery.bin.celeryev
+                 celery.task
                  celery.platform
                  celery.utils.patch
                  celery.utils.compat