Ask Solem преди 15 години
родител
ревизия
cfa225cb0d

+ 43 - 0
celery/datastructures.py

@@ -3,6 +3,7 @@
 Custom Datastructures
 
 """
+import time
 import traceback
 from UserList import UserList
 from Queue import Queue, Empty as QueueEmpty
@@ -145,3 +146,45 @@ class SharedCounter(object):
 
     def __repr__(self):
         return "<SharedCounter: int(%s)>" % str(int(self))
+
+
+class LimitedSet(object):
+
+    def __init__(self, maxlen=None, expires=None):
+        self.maxlen = maxlen
+        self.expires = expires
+        self._data = {}
+
+    def add(self, value):
+        self._expire_item()
+        self._data[value] = time.time()
+
+    def pop_value(self, value):
+        self._data.pop(value, None)
+
+    def _expire_item(self):
+        while 1:
+            if self.maxlen and len(self) >= self.maxlen:
+                value, when = self.oldest
+                if not self.expires or time.time() > when + self.expires:
+                    try:
+                        self.pop_value(value)
+                    except TypeError: # pragma: no cover
+                        continue
+            break
+
+    def __contains__(self, value):
+        return value in self._data
+
+    def __iter__(self):
+        return iter(self._data.keys())
+
+    def __repr__(self):
+        return "LimitedSet([%s])" % (repr(self._data.keys()))
+
+    def __len__(self):
+        return len(self._data.keys())
+
+    @property
+    def oldest(self):
+        return sorted(self._data.items(), key=lambda (value, when): when)[0]

+ 81 - 0
celery/tests/test_datastructures.py

@@ -1,7 +1,10 @@
 import unittest
 import sys
+from Queue import Queue
 
 from celery.datastructures import PositionQueue, ExceptionInfo
+from celery.datastructures import LimitedSet, consume_queue
+from celery.datastructures import SharedCounter
 
 
 class TestPositionQueue(unittest.TestCase):
@@ -49,3 +52,81 @@ class TestExceptionInfo(unittest.TestCase):
         self.assertEquals(einfo.exception.args,
                 ("The quick brown fox jumps...", ))
         self.assertTrue(einfo.traceback)
+
+        r = repr(einfo)
+        self.assertTrue(r)
+
+
+class TestUtilities(unittest.TestCase):
+
+    def test_consume_queue(self):
+        x = Queue()
+        it = consume_queue(x)
+        self.assertRaises(StopIteration, it.next)
+        x.put("foo")
+        it = consume_queue(x)
+        self.assertEquals(it.next(), "foo")
+        self.assertRaises(StopIteration, it.next)
+
+
+class TestSharedCounter(unittest.TestCase):
+
+    def test_initial_value(self):
+        self.assertEquals(int(SharedCounter(10)), 10)
+
+    def test_increment(self):
+        c = SharedCounter(10)
+        c.increment()
+        self.assertEquals(int(c), 11)
+        c.increment(2)
+        self.assertEquals(int(c), 13)
+
+    def test_decrement(self):
+        c = SharedCounter(10)
+        c.decrement()
+        self.assertEquals(int(c), 9)
+        c.decrement(2)
+        self.assertEquals(int(c), 7)
+
+    def test_iadd(self):
+        c = SharedCounter(10)
+        c += 10
+        self.assertEquals(int(c), 20)
+
+    def test_isub(self):
+        c = SharedCounter(10)
+        c -= 20
+        self.assertEquals(int(c), -10)
+
+    def test_repr(self):
+        self.assertTrue(repr(SharedCounter(10)).startswith("<SharedCounter:"))
+
+
+class TestLimitedSet(unittest.TestCase):
+
+    def test_add(self):
+        s = LimitedSet(maxlen=2)
+        s.add("foo")
+        s.add("bar")
+        for n in "foo", "bar":
+            self.assertTrue(n in s)
+        s.add("baz")
+        for n in "bar", "baz":
+            self.assertTrue(n in s)
+        self.assertTrue("foo" not in s)
+
+    def test_iter(self):
+        s = LimitedSet(maxlen=2)
+        items = "foo", "bar"
+        map(s.add, items)
+        l = list(iter(items))
+        for item in items:
+            self.assertTrue(item in l)
+
+    def test_repr(self):
+        s = LimitedSet(maxlen=2)
+        items = "foo", "bar"
+        map(s.add, items)
+        self.assertTrue(repr(s).startswith("LimitedSet("))
+
+

+ 25 - 1
celery/tests/test_worker.py

@@ -13,6 +13,7 @@ from celery.worker import CarrotListener, WorkController
 from celery.worker.job import TaskWrapper
 from celery.worker.scheduler import Scheduler
 from celery.decorators import task as task_dec
+from celery.decorators import periodic_task as periodic_task_dec
 
 
 class MockEventDispatcher(object):
@@ -29,6 +30,11 @@ def foo_task(x, y, z, **kwargs):
     return x * y * z
 
 
+@periodic_task_dec()
+def foo_periodic_task():
+    return "foo"
+
+
 class MockLogger(object):
 
     def critical(self, *args, **kwargs):
@@ -86,7 +92,7 @@ class MockController(object):
 
 
 def create_message(backend, **data):
-    data["id"] = gen_unique_id()
+    data.setdefault("id", gen_unique_id())
     return BaseMessage(backend, body=pickle.dumps(dict(**data)),
                        content_type="application/x-python-serialize",
                        content_encoding="binary")
@@ -134,6 +140,24 @@ class TestCarrotListener(unittest.TestCase):
         self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_scheduler.empty())
 
+    def test_revoke(self):
+        ready_queue = Queue()
+        l = CarrotListener(ready_queue, self.eta_scheduler, self.logger,
+                           send_events=False)
+        backend = MockBackend()
+        id = gen_unique_id()
+        c = create_message(backend, control={"command": "revoke", "task_id": id})
+        t = create_message(backend, task=foo_task.name, args=[2, 4, 8],
+                           kwargs={}, id=id)
+        l.event_dispatcher = MockEventDispatcher()
+        l.receive_message(c.decode(), c)
+        from celery.worker.revoke import revoked
+        self.assertTrue(id in revoked)
+
+        l.receive_message(t.decode(), t)
+        self.assertTrue(ready_queue.empty())
+
+
     def test_receieve_message_not_registered(self):
         l = CarrotListener(self.ready_queue, self.eta_scheduler, self.logger,
                           send_events=False)

+ 19 - 5
celery/tests/test_worker_job.py

@@ -1,19 +1,22 @@
 # -*- coding: utf-8 -*-
 import sys
 import unittest
+import simplejson
+import logging
+from StringIO import StringIO
+
+from carrot.backends.base import BaseMessage
+from django.core import cache
+
 from celery.worker.job import WorkerTaskTrace, TaskWrapper
 from celery.datastructures import ExceptionInfo
 from celery.models import TaskMeta
 from celery.registry import tasks, NotRegistered
 from celery.worker.pool import TaskPool
 from celery.utils import gen_unique_id
-from carrot.backends.base import BaseMessage
-from StringIO import StringIO
 from celery.log import setup_logger
-from django.core import cache
 from celery.decorators import task as task_dec
-import simplejson
-import logging
+from celery.exceptions import RetryTaskError
 
 scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
@@ -55,6 +58,17 @@ def get_db_connection(i, **kwargs):
 get_db_connection.ignore_result = True
 
 
+class TestRetryTaskError(unittest.TestCase):
+
+    def test_retry_task_error(self):
+        try:
+            raise Exception("foo")
+        except Exception, exc:
+            ret = RetryTaskError("Retrying task", exc)
+
+        self.assertEquals(ret.exc, exc)
+
+
 class TestJail(unittest.TestCase):
 
     def test_execute_jail_success(self):

+ 14 - 0
celery/tests/test_worker_revoke.py

@@ -0,0 +1,14 @@
+import unittest
+from Queue import Queue, Empty
+from datetime import datetime, timedelta
+
+from celery.worker import revoke
+
+
+class TestRevokeRegistry(unittest.TestCase):
+
+    def test_is_working(self):
+        revoke.revoked.add("foo")
+        self.assertTrue("foo" in revoke.revoked)
+        revoke.revoked.pop_value("foo")
+        self.assertTrue("foo" not in revoke.revoked)

+ 48 - 0
celery/tests/test_worker_scheduler.py

@@ -0,0 +1,48 @@
+import unittest
+from Queue import Queue, Empty
+from datetime import datetime, timedelta
+
+from celery.worker.scheduler import Scheduler
+
+
+class TestScheduler(unittest.TestCase):
+
+    def test_sched_and_run_now(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+        now = datetime.now()
+
+        callback_called = [False]
+        def callback():
+            callback_called[0] = True
+
+        sched.enter("foo", eta=now, callback=callback)
+
+        remaining = iter(sched).next()
+        self.assertEquals(remaining, 0)
+        self.assertTrue(callback_called[0])
+        self.assertEquals(ready_queue.get_nowait(), "foo")
+
+    def test_sched_run_later(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+        now = datetime.now()
+
+        callback_called = [False]
+        def callback():
+            callback_called[0] = True
+
+        eta = now + timedelta(seconds=10)
+        sched.enter("foo", eta=eta, callback=callback)
+
+        remaining = iter(sched).next()
+        self.assertTrue(remaining > 7)
+        self.assertFalse(callback_called[0])
+        self.assertRaises(Empty, ready_queue.get_nowait)
+
+    def test_empty_queue_yields_None(self):
+        ready_queue = Queue()
+        sched = Scheduler(ready_queue)
+
+        self.assertTrue(iter(sched).next() is None)
+

+ 2 - 2
celery/worker/pool.py

@@ -52,7 +52,7 @@ class TaskPool(object):
     def replace_dead_workers(self):
         self.logger.debug("TaskPool: Finding dead pool processes...")
         dead_count = self._pool.replace_dead_workers()
-        if dead_count:
+        if dead_count: # pragma: no cover
             self.logger.info(
                 "TaskPool: Replaced %d dead pool workers..." % (
                     dead_count))
@@ -88,7 +88,7 @@ class TaskPool(object):
 
         if isinstance(ret_value, ExceptionInfo):
             if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
+                    SystemExit, KeyboardInterrupt)): # pragma: no cover
                 raise ret_value.exception
             [errback(ret_value) for errback in errbacks]
         else:

+ 2 - 36
celery/worker/revoke.py

@@ -1,40 +1,6 @@
-import time
-from UserDict import UserDict
-
-from carrot.connection import DjangoBrokerConnection
-
-from celery.messaging import BroadcastPublisher
-from celery.utils import noop
+from celery.datastructures import LimitedSet
 
 REVOKES_MAX = 1000
 REVOKE_EXPIRES = 60 * 60 # one hour.
 
-
-class RevokeRegistry(UserDict):
-
-    def __init__(self, maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES):
-        self.maxlen = maxlen
-        self.expires = expires
-        self.data = {}
-
-    def add(self, uuid):
-        self._expire_item()
-        self[uuid] = time.time()
-
-    def _expire_item(self):
-        while 1:
-            if len(self) > self.maxlen:
-                uuid, when = self.oldest
-                if time.time() > when + self.expires:
-                    try:
-                        self.pop(uuid, None)
-                    except TypeError:
-                        continue
-            break
-
-    @property
-    def oldest(self):
-        return sorted(self.items(), key=lambda (uuid, when): when)[0]
-
-
-revoked = RevokeRegistry()
+revoked = LimitedSet(maxlen=REVOKES_MAX, expires=REVOKE_EXPIRES)

+ 1 - 3
celery/worker/scheduler.py

@@ -30,10 +30,8 @@ class Scheduler(object):
                     yield eta - now
                 else:
                     event = pop(q)
-                    print("eta->%s priority->%s item->%s" % (
-                        eta, priority, item))
 
-                    if event is verify:
+                    if event is verify: # pragma: no cover
                         ready_queue.put(item)
                         callback and callback()
                         yield 0

+ 11 - 1
testproj/settings.py

@@ -20,6 +20,16 @@ TEST_RUNNER = "celery.tests.runners.run_tests"
 TEST_APPS = (
     "celery",
 )
+COVERAGE_EXCLUDE_MODULES = ("celery.tests.*",
+                            "celery.management.*",
+                            "celery.contrib.*",
+                            "celery.bin.*",
+                            "celery.patch",
+                            "celery.urls",
+                            "celery.views",
+                            "celery.task.strategy")
+COVERAGE_HTML_REPORT = True
+COVERAGE_BRANCH_COVERAGE = True
 
 BROKER_HOST = "localhost"
 BROKER_PORT = 5672
@@ -59,6 +69,6 @@ except ImportError:
     pass
 else:
     pass
-    #INSTALLED_APPS += ("test_extensions", )
+    INSTALLED_APPS += ("test_extensions", )
 
 SEND_CELERY_TASK_ERROR_EMAILS = False