Browse Source

89% coverage

Ask Solem 15 năm trước cách đây
mục cha
commit
87960ce141

+ 4 - 2
Makefile

@@ -26,10 +26,12 @@ readme: clean_readme
 bump:
 	contrib/bump -c celery
 
-coverage:
+cover:
 	(cd testproj; python manage.py test --coverage)
 
-quickcoverage:
+coverage: cover
+
+quickcover:
 	(cd testproj; env QUICKTEST=1 SKIP_RLIMITS=1 python manage.py test --coverage)
 
 test:

+ 0 - 4
celery/task/base.py

@@ -168,10 +168,6 @@ class Task(object):
 
     MaxRetriesExceededError = MaxRetriesExceededError
 
-    def __init__(self):
-        if not self.__class__.name:
-            self.__class__.name = get_full_cls_name(self.__class__)
-
     def __call__(self, *args, **kwargs):
         return self.run(*args, **kwargs)
 

+ 109 - 0
celery/tests/test_task.py

@@ -8,6 +8,7 @@ from celery.result import EagerResult
 from celery.backends import default_backend
 from celery.decorators import task as task_dec
 from celery.worker.listener import parse_iso8601
+from celery.exceptions import RetryTaskError
 
 def return_True(*args, **kwargs):
     # Task run functions can't be closures/lambdas, as they're pickled.
@@ -21,6 +22,16 @@ def raise_exception(self, **kwargs):
     raise Exception("%s error" % self.__class__)
 
 
+class MockApplyTask(task.Task):
+
+    def run(self, x, y):
+        return x * y
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        pass
+
+
 class IncrementCounterTask(task.Task):
     name = "c.unittest.increment_counter_task"
     count = 0
@@ -53,6 +64,27 @@ class RetryTask(task.Task):
             return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
 
 
+class RetryTaskMockApply(task.Task):
+    max_retries = 3
+    iterations = 0
+    applied = 0
+
+    def run(self, arg1, arg2, kwarg=1, **kwargs):
+        self.__class__.iterations += 1
+
+        retries = kwargs["task_retries"]
+        if retries >= 3:
+            return arg1
+        else:
+            kwargs.update({"kwarg": kwarg})
+            return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
+
+
+    @classmethod
+    def apply_async(self, *args, **kwargs):
+        self.applied = 1
+
+
 class MyCustomException(Exception):
     """Random custom exception."""
 
@@ -85,6 +117,23 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEquals(result.get(), 0xFF)
         self.assertEquals(RetryTask.iterations, 4)
 
+    def test_retry_not_eager(self):
+        exc = Exception("baz")
+        try:
+            RetryTaskMockApply.retry(args=[4, 4], kwargs={},
+                                     exc=exc, throw=False)
+            self.assertTrue(RetryTaskMockApply.applied)
+        finally:
+            RetryTaskMockApply.applied = 0
+
+        try:
+            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
+                    args=[4, 4], kwargs={}, exc=exc, throw=True)
+            self.assertTrue(RetryTaskMockApply.applied)
+        finally:
+            RetryTaskMockApply.applied = 0
+
+
     def test_retry_with_kwargs(self):
         RetryTaskCustomExc.max_retries = 3
         RetryTaskCustomExc.iterations = 0
@@ -116,6 +165,12 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEquals(RetryTask.iterations, 2)
 
 
+class MockPublisher(object):
+
+    def __init__(self, *args, **kwargs):
+        self.kwargs = kwargs
+
+
 class TestCeleryTasks(unittest.TestCase):
 
     def createTaskCls(self, cls_name, task_name=None):
@@ -232,6 +287,18 @@ class TestCeleryTasks(unittest.TestCase):
         publisher = t1.get_publisher()
         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 
+    def test_get_publisher(self):
+        from celery.task import base
+        old_pub = base.TaskPublisher
+        base.TaskPublisher = MockPublisher
+        try:
+            p = IncrementCounterTask.get_publisher(exchange="foo",
+                                                   connection="bar")
+            self.assertEquals(p.kwargs["exchange"], "foo")
+        finally:
+            base.TaskPublisher = old_pub
+
+
     def test_get_logger(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
         t1 = T1()
@@ -308,3 +375,45 @@ class TestTaskApply(unittest.TestCase):
         self.assertFalse(f.successful())
         self.assertTrue(f.traceback)
         self.assertRaises(KeyError, f.get)
+
+
+class MyPeriodic(task.PeriodicTask):
+    run_every = timedelta(hours=1)
+
+
+
+class TestPeriodicTask(unittest.TestCase):
+
+    def test_must_have_run_every(self):
+        self.assertRaises(NotImplementedError, type, "Foo",
+            (task.PeriodicTask, ), {"__module__": __name__})
+
+    def test_remaining_estimate(self):
+        self.assertTrue(isinstance(
+            MyPeriodic().remaining_estimate(datetime.now()),
+            timedelta))
+
+    def test_timedelta_seconds_returns_0_on_negative_time(self):
+        delta = timedelta(days=-2)
+        self.assertEquals(MyPeriodic().timedelta_seconds(delta), 0)
+
+    def test_timedelta_seconds(self):
+        deltamap = ((timedelta(seconds=1), 1),
+                    (timedelta(seconds=27), 27),
+                    (timedelta(minutes=3), 3 * 60),
+                    (timedelta(hours=4), 4 * 60 * 60),
+                    (timedelta(days=3), 3 * 86400))
+        for delta, seconds in deltamap:
+            self.assertEquals(MyPeriodic().timedelta_seconds(delta), seconds)
+
+    def test_is_due_not_due(self):
+        due, remaining = MyPeriodic().is_due(datetime.now())
+        self.assertFalse(due)
+        self.assertTrue(remaining > 60)
+
+    def test_is_due(self):
+        p = MyPeriodic()
+        due, remaining = p.is_due(datetime.now() - p.run_every)
+        self.assertTrue(due)
+        self.assertEquals(remaining, p.timedelta_seconds(p.run_every))
+

+ 94 - 4
celery/tests/test_utils.py

@@ -1,5 +1,10 @@
+import sys
+import socket
 import unittest
-from celery.utils import chunks
+
+from billiard.utils.functional import wraps
+
+from celery import utils
 
 
 class TestChunks(unittest.TestCase):
@@ -7,16 +12,101 @@ class TestChunks(unittest.TestCase):
     def test_chunks(self):
 
         # n == 2
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
         self.assertEquals(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
 
         # n == 3
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
         self.assertEquals(list(x),
             [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
 
         # n == 2 (exact)
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
+        x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
         self.assertEquals(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
+
+
+class TestGenUniqueId(unittest.TestCase):
+
+    def test_gen_unique_id_without_ctypes(self):
+        from celery.tests.utils import mask_modules
+        old_utils = sys.modules.pop("celery.utils")
+        try:
+            with mask_modules("ctypes"):
+                from celery.utils import ctypes, gen_unique_id
+                self.assertTrue(ctypes is None)
+                uuid = gen_unique_id()
+                self.assertTrue(uuid)
+                self.assertTrue(isinstance(uuid, basestring))
+        finally:
+            sys.modules["celery.utils"] = old_utils
+
+
+class TestDivUtils(unittest.TestCase):
+
+    def test_repeatlast(self):
+        items = range(6)
+        it = utils.repeatlast(items)
+        for i in items:
+            self.assertEquals(it.next(), i)
+        for j in items:
+            self.assertEquals(it.next(), i)
+
+
+def sleepdeprived(fun):
+
+    @wraps(fun)
+    def _sleepdeprived(*args, **kwargs):
+        import time
+        old_sleep = time.sleep
+        time.sleep = utils.noop
+        try:
+            return fun(*args, **kwargs)
+        finally:
+            time.sleep = old_sleep
+
+    return _sleepdeprived
+
+
+class TestRetryOverTime(unittest.TestCase):
+
+    def test_returns_retval_on_success(self):
+
+        def _fun(x, y):
+            return x * y
+
+        ret = utils.retry_over_time(_fun, (socket.error, ), args=[16, 16],
+                                    max_retries=3)
+
+        self.assertEquals(ret, 256)
+
+    @sleepdeprived
+    def test_raises_on_unlisted_exception(self):
+
+        def _fun(x, y):
+            raise KeyError("bar")
+
+        self.assertRaises(KeyError, utils.retry_over_time, _fun,
+                         (socket.error, ), args=[32, 32], max_retries=3)
+
+
+    @sleepdeprived
+    def test_retries_on_failure(self):
+
+        iterations = [0]
+
+        def _fun(x, y):
+            iterations[0] += 1
+            if iterations[0] == 3:
+                return x * y
+            raise socket.error("foozbaz")
+
+        ret = utils.retry_over_time(_fun, (socket.error, ), args=[32, 32],
+                                    max_retries=None)
+
+        self.assertEquals(iterations[0], 3)
+        self.assertEquals(ret, 1024)
+
+        self.assertRaises(socket.error, utils.retry_over_time,
+                        _fun, (socket.error, ), args=[32, 32], max_retries=1)

+ 2 - 6
celery/utils/__init__.py

@@ -9,11 +9,7 @@ try:
     import ctypes
 except ImportError:
     ctypes = None
-from uuid import UUID, uuid4
-try:
-    from uuid import _uuid_generate_random
-except ImportError:
-    _uuid_generate_random = None
+from uuid import UUID, uuid4, _uuid_generate_random
 from inspect import getargspec
 from itertools import repeat
 
@@ -87,7 +83,7 @@ def repeatlast(it):
     yield the last value infinitely."""
     for item in it:
         yield item
-    for item in repeat(item):
+    for item in repeat(item): # pragma: no cover
         yield item
 
 

+ 9 - 7
celery/worker/listener.py

@@ -150,17 +150,19 @@ class CarrotListener(object):
             return
         self._state = CLOSE
 
-        self.logger.debug("Heart: Going into cardiac arrest...")
-        self.heart = self.heart and self.heart.stop()
+        if self.heart:
+            self.logger.debug("Heart: Going into cardiac arrest...")
+            self.heart = self.heart.stop()
 
         self.logger.debug("TaskConsumer: Shutting down...")
         self.task_consumer = self.task_consumer and self.task_consumer.close()
 
-        self.logger.debug("EventDispatcher: Shutting down...")
-        self.event_dispatcher = self.event_dispatcher and \
-                                    self.event_dispatcher.close()
-        self.logger.debug(
-                "CarrotListener: Closing connection to broker...")
+        if self.event_dispatcher:
+            self.logger.debug("EventDispatcher: Shutting down...")
+            self.event_dispatcher = self.event_dispatcher.close()
+
+        self.logger.debug("CarrotListener: "
+                          "Closing connection to broker...")
         self.connection = self.connection and self.connection.close()
 
     def reset_connection(self):

+ 4 - 2
examples/pythonproject/demoapp/celeryconfig.py

@@ -7,6 +7,8 @@ DATABASE_NAME = "celery.db"
 BROKER_HOST = "localhost"
 BROKER_USER = "guest"
 BROKER_PASSWORD = "guest"
-BROKER_VHOST = "/"
-CELERY_BACKEND = "amqp"
+BROKER_VHOST = "celery"
+CELERY_DEFAULT_EXCHANGE = "celery"
+CARROT_BACKEND = "ghettoq.taproot.Redis"
+CELERY_BACKEND = "database"
 CELERY_IMPORTS = ("tasks", )