Browse Source

Some more unittests.

Ask Solem 15 years ago
parent
commit
cfa225cb0d

+ 43 - 0
celery/datastructures.py

@@ -3,6 +3,7 @@
 Custom Datastructures
 Custom Datastructures
 
 
 """
 """
+import time
 import traceback
 import traceback
 from UserList import UserList
 from UserList import UserList
 from Queue import Queue, Empty as QueueEmpty
 from Queue import Queue, Empty as QueueEmpty
@@ -145,3 +146,45 @@ class SharedCounter(object):
 
 
     def __repr__(self):
     def __repr__(self):
         return "<SharedCounter: int(%s)>" % str(int(self))
         return "<SharedCounter: int(%s)>" % str(int(self))
+
+
+class LimitedSet(object):
+
+    def __init__(self, maxlen=None, expires=None):
+        self.maxlen = maxlen
+        self.expires = expires
+        self._data = {}
+
+    def add(self, value):
+        self._expire_item()
+        self._data[value] = time.time()
+
+    def pop_value(self, value):
+        self._data.pop(value, None)
+
+    def _expire_item(self):
+        while 1:
+            if self.maxlen and len(self) >= self.maxlen:
+                value, when = self.oldest
+                if not self.expires or time.time() > when + self.expires:
+                    try:
+                        self.pop_value(value)
+                    except TypeError: # pragma: no cover
+                        continue
+            break
+
+    def __contains__(self, value):
+        return value in self._data
+
+    def __iter__(self):
+        return iter(self._data.keys())
+
+    def __repr__(self):
+        return "LimitedSet([%s])" % (repr(self._data.keys()))
+
+    def __len__(self):
+        return len(self._data.keys())
+
+    @property
+    def oldest(self):
+        return sorted(self._data.items(), key=lambda (value, when): when)[0]

+ 81 - 0
celery/tests/test_datastructures.py

@@ -1,7 +1,10 @@
 import unittest
 import unittest
 import sys
 import sys
+from Queue import Queue
 
 
 from celery.datastructures import PositionQueue, ExceptionInfo
 from celery.datastructures import PositionQueue, ExceptionInfo
+from celery.datastructures import LimitedSet, consume_queue
+from celery.datastructures import SharedCounter
 
 
 
 
 class TestPositionQueue(unittest.TestCase):
 class TestPositionQueue(unittest.TestCase):
@@ -49,3 +52,81 @@ class TestExceptionInfo(unittest.TestCase):
         self.assertEquals(einfo.exception.args,
         self.assertEquals(einfo.exception.args,
                 ("The quick brown fox jumps...", ))
                 ("The quick brown fox jumps...", ))
         self.assertTrue(einfo.traceback)
         self.assertTrue(einfo.traceback)
+
+        r = repr(einfo)
+        self.assertTrue(r)
+
+
+class TestUtilities(unittest.TestCase):
+
+    def test_consume_queue(self):
+        x = Queue()
+        it = consume_queue(x)
+        self.assertRaises(StopIteration, it.next)
+        x.put("foo")
+        it = consume_queue(x)
+        self.assertEquals(it.next(), "foo")
+        self.assertRaises(StopIteration, it.next)
+
+
+class TestSharedCounter(unittest.TestCase):
+
+    def test_initial_value(self):
+        self.assertEquals(int(SharedCounter(10)), 10)
+
+    def test_increment(self):
+        c = SharedCounter(10)
+        c.increment()
+        self.assertEquals(int(c), 11)
+        c.increment(2)
+        self.assertEquals(int(c), 13)
+
+    def test_decrement(self):
+        c = SharedCounter(10)
+        c.decrement()
+        self.assertEquals(int(c), 9)
+        c.decrement(2)
+        self.assertEquals(int(c), 7)
+
+    def test_iadd(self):
+        c = SharedCounter(10)
+        c += 10
+        self.assertEquals(int(c), 20)
+
+    def test_isub(self):
+        c = SharedCounter(10)
+        c -= 20
+        self.assertEquals(int(c), -10)
+
+    def test_repr(self):
+        self.assertTrue(repr(SharedCounter(10)).startswith("<SharedCounter:"))
+
+
+class TestLimitedSet(unittest.TestCase):
+
+    def test_add(self):
+        s = LimitedSet(maxlen=2)
+        s.add("foo")
+        s.add("bar")
+        for n in "foo", "bar":
+            self.assertTrue(n in s)
+        s.add("baz")
+        for n in "bar", "baz":
+            self.assertTrue(n in s)
+        self.assertTrue("foo" not in s)
+
+    def test_iter(self):
+        s = LimitedSet(maxlen=2)
+        items = "foo", "bar"
+        map(s.add, items)
+        l = list(iter(items))
+        for item in items:
+            self.assertTrue(item in l)
+
+    def test_repr(self):
+        s = LimitedSet(maxlen=2)
+        items = "foo", "bar"
+        map(s.add, items)
+        self.assertTrue(repr(s).startswith("LimitedSet("))
+
+

+ 25 - 1
celery/tests/test_worker.py

@@ -13,6 +13,7 @@ from celery.worker import CarrotListener, WorkController
 from celery.worker.job import TaskWrapper
 from celery.worker.job import TaskWrapper
 from celery.worker.scheduler import Scheduler
 from celery.worker.scheduler import Scheduler
 from celery.decorators import task as task_dec
 from celery.decorators import task as task_dec
+from celery.decorators import periodic_task as periodic_task_dec
 
 
 
 
 class MockEventDispatcher(object):
 class MockEventDispatcher(object):
@@ -29,6 +30,11 @@ def foo_task(x, y, z, **kwargs):
     return x * y * z
     return x * y * z
 
 
 
 
+@periodic_task_dec()
+def foo_periodic_task():
+    return "foo"
+
+
 class MockLogger(object):
 class MockLogger(object):
 
 
     def critical(self, *args, **kwargs):
     def critical(self, *args, **kwargs):
@@ -86,7 +92,7 @@ class MockController(object):
 
 
 
 
 def create_message(backend, **data):
 def create_message(backend, **data):
-    data["id"] = gen_unique_id()
+    data.setdefault("id", gen_unique_id())
     return BaseMessage(backend, body=pickle.dumps(dict(**data)),
     return BaseMessage(backend, body=pickle.dumps(dict(**data)),
                        content_type="application/x-python-serialize",
                        content_type="application/x-python-serialize",
                        content_encoding="binary")
                        content_encoding="binary")
@@ -134,6 +140,24 @@ class TestCarrotListener(unittest.TestCase):
         self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
         self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_scheduler.empty())
         self.assertTrue(self.eta_scheduler.empty())
 
 
+    def test_revoke(self):
+        ready_queue = Queue()
+        l = CarrotListener(ready_queue, self.eta_scheduler, self.logger,
+                           send_events=False)
+        backend = MockBackend()
+        id = gen_unique_id()
+        c = create_message(backend, control={"command": "revoke", "task_id": id})
+        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+                           kwargs={}, id=id)
+        l.event_dispatcher = MockEventDispatcher()
+        l.receive_message(c.decode(), c)
+        from celery.worker.revoke import revoked
+        self.assertTrue(id in revoked)
+
+        l.receive_message(t.decode(), t)
+        self.assertTrue(ready_queue.empty())
+
+
     def test_receieve_message_not_registered(self):
     def test_receieve_message_not_registered(self):
         l = CarrotListener(self.ready_queue, self.eta_scheduler, self.logger,
         l = CarrotListener(self.ready_queue, self.eta_scheduler, self.logger,
                           send_events=False)
                           send_events=False)

+ 19 - 5
celery/tests/test_worker_job.py

@@ -1,19 +1,22 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 import sys
 import sys
 import unittest
 import unittest
+import simplejson
+import logging
+from StringIO import StringIO
+
+from carrot.backends.base import BaseMessage
+from django.core import cache
+
 from celery.worker.job import WorkerTaskTrace, TaskWrapper
 from celery.worker.job import WorkerTaskTrace, TaskWrapper
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
 from celery.models import TaskMeta
 from celery.models import TaskMeta
 from celery.registry import tasks, NotRegistered
 from celery.registry import tasks, NotRegistered
 from celery.worker.pool import TaskPool
 from celery.worker.pool import TaskPool
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
-from carrot.backends.base import BaseMessage
-from StringIO import StringIO
 from celery.log import setup_logger
 from celery.log import setup_logger
-from django.core import cache
 from celery.decorators import task as task_dec
 from celery.decorators import task as task_dec
-import simplejson
-import logging
+from celery.exceptions import RetryTaskError
 
 
 scratch = {"ACK": False}
 scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
 some_kwargs_scratchpad = {}
@@ -55,6 +58,17 @@ def get_db_connection(i, **kwargs):
 get_db_connection.ignore_result = True
 get_db_connection.ignore_result = True
 
 
 
 
+class TestRetryTaskError(unittest.TestCase):
+
+    def test_retry_task_error(self):
+        try:
+            raise Exception("foo")
+        except Exception, exc:
+            ret = RetryTaskError("Retrying task", exc)
+
+        self.assertEquals(ret.exc, exc)
+
+
 class TestJail(unittest.TestCase):
 class TestJail(unittest.TestCase):
 
 
     def test_execute_jail_success(self):
     def test_execute_jail_success(self):

+ 14 - 0
celery/tests/test_worker_revoke.py

@@ -0,0 +1,14 @@
+import unittest
+from Queue import Queue, Empty
+from datetime import datetime, timedelta
+
+from celery.worker import revoke
+
+
+class TestRevokeRegistry(unittest.TestCase):
+
+    def test_is_working(self):
+        revoke.revoked.add("foo")
+        self.assertTrue("foo" in revoke.revoked)
+        revoke.revoked.pop_value("foo")
+        self.assertTrue("foo" not in revoke.revoked)

+ 48 - 0
celery/tests/test_worker_scheduler.py

@@ -0,0 +1,48 @@
+import unittest
+from Queue import Queue, Empty
+from datetime import datetime, timedelta
+
+from celery.worker.scheduler import Scheduler
+
+
+class TestScheduler(unittest.TestCase):
+
+    def test_sched_and_run_now(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+        now = datetime.now()
+
+        callback_called = [False]
+        def callback():
+            callback_called[0] = True
+
+        sched.enter("foo", eta=now, callback=callback)
+
+        remaining = iter(sched).next()
+        self.assertEquals(remaining, 0)
+        self.assertTrue(callback_called[0])
+        self.assertEquals(ready_queue.get_nowait(), "foo")
+
+    def test_sched_run_later(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+        now = datetime.now()
+
+        callback_called = [False]
+        def callback():
+            callback_called[0] = True
+
+        eta = now + timedelta(seconds=10)
+        sched.enter("foo", eta=eta, callback=callback)
+
+        remaining = iter(sched).next()
+        self.assertTrue(remaining > 7)
+        self.assertFalse(callback_called[0])
+        self.assertRaises(Empty, ready_queue.get_nowait)
+
+    def test_empty_queue_yields_None(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+
+        self.assertTrue(iter(sched).next() is None)
+

+ 2 - 2
celery/worker/pool.py

@@ -52,7 +52,7 @@ class TaskPool(object):
     def replace_dead_workers(self):
     def replace_dead_workers(self):
         self.logger.debug("TaskPool: Finding dead pool processes...")
         self.logger.debug("TaskPool: Finding dead pool processes...")
         dead_count = self._pool.replace_dead_workers()
         dead_count = self._pool.replace_dead_workers()
-        if dead_count:
+        if dead_count: # pragma: no cover
             self.logger.info(
             self.logger.info(
                 "TaskPool: Replaced %d dead pool workers..." % (
                 "TaskPool: Replaced %d dead pool workers..." % (
                     dead_count))
                     dead_count))
@@ -88,7 +88,7 @@ class TaskPool(object):
 
 
         if isinstance(ret_value, ExceptionInfo):
         if isinstance(ret_value, ExceptionInfo):
             if isinstance(ret_value.exception, (
             if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
+                    SystemExit, KeyboardInterrupt)): # pragma: no cover
                 raise ret_value.exception
                 raise ret_value.exception
             [errback(ret_value) for errback in errbacks]
             [errback(ret_value) for errback in errbacks]
         else:
         else:

+ 2 - 36
celery/worker/revoke.py

@@ -1,40 +1,6 @@
-import time
-from UserDict import UserDict
-
-from carrot.connection import DjangoBrokerConnection
-
-from celery.messaging import BroadcastPublisher
-from celery.utils import noop
+from celery.datastructures import LimitedSet
 
 
 REVOKES_MAX = 1000
 REVOKES_MAX = 1000
 REVOKE_EXPIRES = 60 * 60 # one hour.
 REVOKE_EXPIRES = 60 * 60 # one hour.
 
 
-
-class RevokeRegistry(UserDict):
-
-    def __init__(self, maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES):
-        self.maxlen = maxlen
-        self.expires = expires
-        self.data = {}
-
-    def add(self, uuid):
-        self._expire_item()
-        self[uuid] = time.time()
-
-    def _expire_item(self):
-        while 1:
-            if len(self) > self.maxlen:
-                uuid, when = self.oldest
-                if time.time() > when + self.expires:
-                    try:
-                        self.pop(uuid, None)
-                    except TypeError:
-                        continue
-            break
-
-    @property
-    def oldest(self):
-        return sorted(self.items(), key=lambda (uuid, when): when)[0]
-
-
-revoked = RevokeRegistry()
+revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)

+ 1 - 3
celery/worker/scheduler.py

@@ -30,10 +30,8 @@ class Scheduler(object):
                     yield eta - now
                     yield eta - now
                 else:
                 else:
                     event = pop(q)
                     event = pop(q)
-                    print("eta->%s priority->%s item->%s" % (
-                        eta, priority, item))
 
 
-                    if event is verify:
+                    if event is verify: # pragma: no cover
                         ready_queue.put(item)
                         ready_queue.put(item)
                         callback and callback()
                         callback and callback()
                         yield 0
                         yield 0

+ 11 - 1
testproj/settings.py

@@ -20,6 +20,16 @@ TEST_RUNNER = "celery.tests.runners.run_tests"
 TEST_APPS = (
 TEST_APPS = (
     "celery",
     "celery",
 )
 )
+COVERAGE_EXCLUDE_MODULES = ("celery.tests.*",
+                            "celery.management.*",
+                            "celery.contrib.*",
+                            "celery.bin.*",
+                            "celery.patch",
+                            "celery.urls",
+                            "celery.views",
+                            "celery.task.strategy")
+COVERAGE_HTML_REPORT = True
+COVERAGE_BRANCH_COVERAGE = True
 
 
 BROKER_HOST = "localhost"
 BROKER_HOST = "localhost"
 BROKER_PORT = 5672
 BROKER_PORT = 5672
@@ -59,6 +69,6 @@ except ImportError:
     pass
     pass
 else:
 else:
     pass
     pass
-    #INSTALLED_APPS += ("test_extensions", )
+    INSTALLED_APPS += ("test_extensions", )
 
 
 SEND_CELERY_TASK_ERROR_EMAILS = False
 SEND_CELERY_TASK_ERROR_EMAILS = False