Browse Source

More unit tests. Coverage up to 83%

Ask Solem 16 years ago
parent
commit
1f467a13ff

+ 2 - 1
celery/execute.py

@@ -1,5 +1,5 @@
 from carrot.connection import DjangoAMQPConnection
 from carrot.connection import DjangoAMQPConnection
-from celery.conf import AMQP_CONNECTION_TIMEOUT, ALWAYS_EAGER
+from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.result import AsyncResult, EagerResult
 from celery.result import AsyncResult, EagerResult
 from celery.messaging import TaskPublisher
 from celery.messaging import TaskPublisher
 from celery.registry import tasks
 from celery.registry import tasks
@@ -62,6 +62,7 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     if countdown:
     if countdown:
         eta = datetime.now() + timedelta(seconds=countdown)
         eta = datetime.now() + timedelta(seconds=countdown)
 
 
+    from celery.conf import ALWAYS_EAGER
     if ALWAYS_EAGER:
     if ALWAYS_EAGER:
         return apply(task, args, kwargs)
         return apply(task, args, kwargs)
 
 

+ 2 - 1
celery/task/__init__.py

@@ -9,7 +9,8 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.backends import default_backend
 from celery.task.base import Task, TaskSet, PeriodicTask
 from celery.task.base import Task, TaskSet, PeriodicTask
-from celery.task.builtins import AsynchronousMapTask, ExecuteRemoteTask
+from celery.task.base import ExecuteRemoteTask
+from celery.task.base import AsynchronousMapTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
 from celery.execute import apply_async, delay_task
 from celery.execute import apply_async, delay_task
 from celery.utils import pickle
 from celery.utils import pickle

+ 46 - 10
celery/task/base.py

@@ -6,10 +6,10 @@ from celery.result import TaskSetResult
 from celery.execute import apply_async, delay_task, apply
 from celery.execute import apply_async, delay_task, apply
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
 from datetime import timedelta
 from datetime import timedelta
+from celery.registry import tasks
 from celery.utils import pickle
 from celery.utils import pickle
 
 
 
 
-
 class Task(object):
 class Task(object):
     """A task that can be delayed for execution by the ``celery`` daemon.
     """A task that can be delayed for execution by the ``celery`` daemon.
 
 
@@ -216,6 +216,44 @@ class Task(object):
         return apply(cls, args, kwargs, **options)
         return apply(cls, args, kwargs, **options)
 
 
 
 
+class ExecuteRemoteTask(Task):
+    """Execute an arbitrary function or object.
+
+    *Note* You probably want :func:`execute_remote` instead, which this
+    is an internal component of.
+
+    The object must be pickleable, so you can't use lambdas or functions
+    defined in the REPL (that is the python shell, or ``ipython``).
+
+    """
+    name = "celery.execute_remote"
+
+    def run(self, ser_callable, fargs, fkwargs, **kwargs):
+        """
+        :param ser_callable: A pickled function or callable object.
+
+        :param fargs: Positional arguments to apply to the function.
+
+        :param fkwargs: Keyword arguments to apply to the function.
+
+        """
+        callable_ = pickle.loads(ser_callable)
+        return callable_(*fargs, **fkwargs)
+tasks.register(ExecuteRemoteTask)
+
+
+class AsynchronousMapTask(Task):
+    """Task used internally by :func:`dmap_async` and
+    :meth:`TaskSet.map_async`.  """
+    name = "celery.map_async"
+
+    def run(self, serfunc, args, **kwargs):
+        """The method run by ``celeryd``."""
+        timeout = kwargs.get("timeout")
+        return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
+tasks.register(AsynchronousMapTask)
+
+
 class TaskSet(object):
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks has been completed.
     to track how many, or when all of the tasks has been completed.
@@ -297,6 +335,13 @@ class TaskSet(object):
 
 
         """
         """
         taskset_id = gen_unique_id()
         taskset_id = gen_unique_id()
+
+        from celery.conf import ALWAYS_EAGER
+        if ALWAYS_EAGER:
+            subtasks = [apply(self.task, args, kwargs)
+                            for args, kwargs in self.arguments]
+            return TaskSetResult(taskset_id, subtasks)
+
         conn = DjangoAMQPConnection(connect_timeout=connect_timeout)
         conn = DjangoAMQPConnection(connect_timeout=connect_timeout)
         publisher = TaskPublisher(connection=conn)
         publisher = TaskPublisher(connection=conn)
         subtasks = [apply_async(self.task, args, kwargs,
         subtasks = [apply_async(self.task, args, kwargs,
@@ -306,15 +351,6 @@ class TaskSet(object):
         conn.close()
         conn.close()
         return TaskSetResult(taskset_id, subtasks)
         return TaskSetResult(taskset_id, subtasks)
 
 
-    def iterate(self):
-        """Iterate over the results returned after calling :meth:`run`.
-
-        If any of the tasks raises an exception, the exception will
-        be re-raised.
-
-        """
-        return iter(self.run())
-
     def join(self, timeout=None):
     def join(self, timeout=None):
         """Gather the results for all of the tasks in the taskset,
         """Gather the results for all of the tasks in the taskset,
         and return a list with them ordered by the order of which they
         and return a list with them ordered by the order of which they

+ 0 - 38
celery/task/builtins.py

@@ -5,44 +5,6 @@ from datetime import timedelta
 from celery.utils import pickle
 from celery.utils import pickle
 
 
 
 
-class AsynchronousMapTask(Task):
-    """Task used internally by :func:`dmap_async` and
-    :meth:`TaskSet.map_async`.  """
-    name = "celery.map_async"
-
-    def run(self, serfunc, args, **kwargs):
-        """The method run by ``celeryd``."""
-        timeout = kwargs.get("timeout")
-        return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
-tasks.register(AsynchronousMapTask)
-
-
-class ExecuteRemoteTask(Task):
-    """Execute an arbitrary function or object.
-
-    *Note* You probably want :func:`execute_remote` instead, which this
-    is an internal component of.
-
-    The object must be pickleable, so you can't use lambdas or functions
-    defined in the REPL (that is the python shell, or ``ipython``).
-
-    """
-    name = "celery.execute_remote"
-
-    def run(self, ser_callable, fargs, fkwargs, **kwargs):
-        """
-        :param ser_callable: A pickled function or callable object.
-
-        :param fargs: Positional arguments to apply to the function.
-
-        :param fkwargs: Keyword arguments to apply to the function.
-
-        """
-        callable_ = pickle.loads(ser_callable)
-        return callable_(*fargs, **fkwargs)
-tasks.register(ExecuteRemoteTask)
-
-
 class DeleteExpiredTaskMetaTask(PeriodicTask):
 class DeleteExpiredTaskMetaTask(PeriodicTask):
     """A periodic task that deletes expired task metadata every day.
     """A periodic task that deletes expired task metadata every day.
 
 

+ 26 - 1
celery/tests/test_backends/test_base.py

@@ -1,6 +1,6 @@
 import unittest
 import unittest
 from celery.backends.base import find_nearest_pickleable_exception as fnpe
 from celery.backends.base import find_nearest_pickleable_exception as fnpe
-from celery.backends.base import BaseBackend
+from celery.backends.base import BaseBackend, KeyValueStoreBackend
 from celery.backends.base import UnpickleableExceptionWrapper
 from celery.backends.base import UnpickleableExceptionWrapper
 from django.db.models.base import subclass_exception
 from django.db.models.base import subclass_exception
 
 
@@ -16,6 +16,17 @@ Lookalike = subclass_exception("Lookalike", wrapobject, "foo.module")
 b = BaseBackend()
 b = BaseBackend()
 
 
 
 
+class TestBaseBackendInterface(unittest.TestCase):
+
+    def test_get_status(self):
+        self.assertRaises(NotImplementedError,
+                b.is_done, "SOMExx-N0Nex1stant-IDxx-")
+
+    def test_store_result(self):
+        self.assertRaises(NotImplementedError,
+                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, "DONE")
+
+
 class TestPickleException(unittest.TestCase):
 class TestPickleException(unittest.TestCase):
 
 
     def test_BaseException(self):
     def test_BaseException(self):
@@ -46,3 +57,17 @@ class TestPrepareException(unittest.TestCase):
         self.assertTrue(isinstance(x, KeyError))
         self.assertTrue(isinstance(x, KeyError))
         y = b.exception_to_python(x)
         y = b.exception_to_python(x)
         self.assertTrue(isinstance(y, KeyError))
         self.assertTrue(isinstance(y, KeyError))
+
+
+class TestKeyValueStoreBackendInterface(unittest.TestCase):
+
+    def test_get(self):
+        self.assertRaises(NotImplementedError, KeyValueStoreBackend().get,
+                "a")
+    
+    def test_set(self):
+        self.assertRaises(NotImplementedError, KeyValueStoreBackend().set,
+                "a", 1)
+
+    def test_cleanup(self):
+        self.assertFalse(KeyValueStoreBackend().cleanup())

+ 7 - 4
celery/tests/test_celery.py

@@ -6,11 +6,14 @@ class TestInitFile(unittest.TestCase):
     def test_version(self):
     def test_version(self):
         self.assertTrue(celery.VERSION)
         self.assertTrue(celery.VERSION)
         self.assertEquals(len(celery.VERSION), 3)
         self.assertEquals(len(celery.VERSION), 3)
-        is_stable = not (celery.VERSION[1] % 2)
-        self.assertTrue(celery.is_stable_release() == is_stable)
+        celery.VERSION = (0, 3, 0)
+        self.assertFalse(celery.is_stable_release())
         self.assertEquals(celery.__version__.count("."), 2)
         self.assertEquals(celery.__version__.count("."), 2)
-        self.assertTrue("(%s)" % (is_stable and "stable" or "unstable") in \
-                celery.version_with_meta())
+        self.assertTrue("(unstable)" in celery.version_with_meta())
+        celery.VERSION = (0, 4, 0)
+        self.assertTrue(celery.is_stable_release())
+        self.assertTrue("(stable)" in celery.version_with_meta())
+
 
 
     def test_meta(self):
     def test_meta(self):
         for m in ("__author__", "__contact__", "__homepage__",
         for m in ("__author__", "__contact__", "__homepage__",

+ 4 - 0
celery/tests/test_discovery.py

@@ -16,3 +16,7 @@ class TestDiscovery(unittest.TestCase):
     def test_discovery(self):
     def test_discovery(self):
         if "someapp" in settings.INSTALLED_APPS:
         if "someapp" in settings.INSTALLED_APPS:
             self.assertDiscovery()
             self.assertDiscovery()
+
+    def test_discovery_with_broken(self):
+        settings.INSTALLED_APPS = settings.INSTALLED_APPS + ["xxxnot.aexist"]
+        self.assertDiscovery()

+ 77 - 2
celery/tests/test_task.py

@@ -9,11 +9,13 @@ from celery.log import setup_logger
 from celery import messaging
 from celery import messaging
 from celery.result import EagerResult
 from celery.result import EagerResult
 from celery.backends import default_backend
 from celery.backends import default_backend
+from datetime import datetime, timedelta
 
 
 
 
 def return_True(self, **kwargs):
 def return_True(self, **kwargs):
     # Task run functions can't be closures/lambdas, as they're pickled.
     # Task run functions can't be closures/lambdas, as they're pickled.
     return True
     return True
+registry.tasks.register(return_True, "cu.return-true")
 
 
 
 
 def raise_exception(self, **kwargs):
 def raise_exception(self, **kwargs):
@@ -47,13 +49,46 @@ class TestCeleryTasks(unittest.TestCase):
         cls.run = return_True
         cls.run = return_True
         return cls
         return cls
 
 
+    def test_ping(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True
+        self.assertEquals(task.ping(), 'pong')
+        conf.ALWAYS_EAGER = False
+
+    def test_execute_remote(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True
+        self.assertEquals(task.execute_remote(return_True, ["foo"]).get(),
+                          True)
+        conf.ALWAYS_EAGER = False
+
+    def test_dmap(self):
+        from celery import conf
+        import operator
+        conf.ALWAYS_EAGER = True
+        res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
+        self.assertTrue(res, sum([operator.add(x, x)
+                                    for x in xrange(10)]))
+        conf.ALWAYS_EAGER = False
+    
+    def test_dmap_async(self):
+        from celery import conf
+        import operator
+        conf.ALWAYS_EAGER = True
+        res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
+        self.assertTrue(res.get(), sum([operator.add(x, x)
+                                            for x in xrange(10)]))
+        conf.ALWAYS_EAGER = False
+
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
-            **kwargs):
+            test_eta=False, **kwargs):
         next_task = consumer.fetch()
         next_task = consumer.fetch()
-        task_data = consumer.decoder(next_task.body)
+        task_data = next_task.decode()
         self.assertEquals(task_data["id"], presult.task_id)
         self.assertEquals(task_data["id"], presult.task_id)
         self.assertEquals(task_data["task"], task_name)
         self.assertEquals(task_data["task"], task_name)
         task_kwargs = task_data.get("kwargs", {})
         task_kwargs = task_data.get("kwargs", {})
+        if test_eta:
+            self.assertTrue(isinstance(task_data.get("eta"), datetime))
         for arg_name, arg_value in kwargs.items():
         for arg_name, arg_value in kwargs.items():
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
 
 
@@ -92,6 +127,18 @@ class TestCeleryTasks(unittest.TestCase):
         presult2 = task.delay_task(t1.name, name="George Constanza")
         presult2 = task.delay_task(t1.name, name="George Constanza")
         self.assertNextTaskDataEquals(consumer, presult2, t1.name,
         self.assertNextTaskDataEquals(consumer, presult2, t1.name,
                 name="George Constanza")
                 name="George Constanza")
+        
+        # With eta.
+        presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
+                                    eta=datetime.now() + timedelta(days=1))
+        self.assertNextTaskDataEquals(consumer, presult2, t1.name,
+                name="George Constanza", test_eta=True)
+
+        # With countdown.
+        presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
+                                    countdown=10)
+        self.assertNextTaskDataEquals(consumer, presult2, t1.name,
+                name="George Constanza", test_eta=True)
 
 
         self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
         self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
                 "some.task.that.should.never.exist.X.X.X.X.X")
                 "some.task.that.should.never.exist.X.X.X.X.X")
@@ -112,9 +159,26 @@ class TestCeleryTasks(unittest.TestCase):
         publisher = t1.get_publisher()
         publisher = t1.get_publisher()
         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 
 
+    def test_get_logger(self):
+        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+        t1 = T1()
+        logfh = StringIO()
+        logger = t1.get_logger(logfile=logfh, loglevel=0)
+        self.assertTrue(logger)
+
 
 
 class TestTaskSet(unittest.TestCase):
 class TestTaskSet(unittest.TestCase):
 
 
+    def test_function_taskset(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True 
+        ts = task.TaskSet("cu.return-true", [
+            [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
+        res = ts.run()
+        self.assertEquals(res.join(), [True, True, True, True, True])
+
+        conf.ALWAYS_EAGER = False
+
     def test_counter_taskset(self):
     def test_counter_taskset(self):
         IncrementCounterTask.count = 0
         IncrementCounterTask.count = 0
         ts = task.TaskSet(IncrementCounterTask, [
         ts = task.TaskSet(IncrementCounterTask, [
@@ -169,3 +233,14 @@ class TestTaskApply(unittest.TestCase):
         self.assertTrue(f.is_ready())
         self.assertTrue(f.is_ready())
         self.assertFalse(f.is_done())
         self.assertFalse(f.is_done())
         self.assertRaises(KeyError, f.get)
         self.assertRaises(KeyError, f.get)
+
+
+class TestPeriodicTask(unittest.TestCase):
+
+    def test_interface(self):
+
+        class MyPeriodicTask(task.PeriodicTask):
+            run_every = None
+
+        self.assertRaises(NotImplementedError, MyPeriodicTask)
+

+ 2 - 1
celery/tests/test_task_builtins.py

@@ -1,5 +1,6 @@
 import unittest
 import unittest
-from celery.task.builtins import PingTask, ExecuteRemoteTask
+from celery.task.builtins import PingTask
+from celery.task.base import ExecuteRemoteTask
 from celery.utils import pickle
 from celery.utils import pickle