فهرست منبع

More unit tests. Coverage up to 83%

Ask Solem 16 سال پیش
والد
کامیت
1f467a13ff

+ 2 - 1
celery/execute.py

@@ -1,5 +1,5 @@
 from carrot.connection import DjangoAMQPConnection
-from celery.conf import AMQP_CONNECTION_TIMEOUT, ALWAYS_EAGER
+from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.result import AsyncResult, EagerResult
 from celery.messaging import TaskPublisher
 from celery.registry import tasks
@@ -62,6 +62,7 @@ def apply_async(task, args=None, kwargs=None, routing_key=None,
     if countdown:
         eta = datetime.now() + timedelta(seconds=countdown)
 
+    from celery.conf import ALWAYS_EAGER
     if ALWAYS_EAGER:
         return apply(task, args, kwargs)
 

+ 2 - 1
celery/task/__init__.py

@@ -9,7 +9,8 @@ from celery.conf import AMQP_CONNECTION_TIMEOUT
 from celery.registry import tasks
 from celery.backends import default_backend
 from celery.task.base import Task, TaskSet, PeriodicTask
-from celery.task.builtins import AsynchronousMapTask, ExecuteRemoteTask
+from celery.task.base import ExecuteRemoteTask
+from celery.task.base import AsynchronousMapTask
 from celery.task.builtins import DeleteExpiredTaskMetaTask, PingTask
 from celery.execute import apply_async, delay_task
 from celery.utils import pickle

+ 46 - 10
celery/task/base.py

@@ -6,10 +6,10 @@ from celery.result import TaskSetResult
 from celery.execute import apply_async, delay_task, apply
 from celery.utils import gen_unique_id
 from datetime import timedelta
+from celery.registry import tasks
 from celery.utils import pickle
 
 
-
 class Task(object):
     """A task that can be delayed for execution by the ``celery`` daemon.
 
@@ -216,6 +216,44 @@ class Task(object):
         return apply(cls, args, kwargs, **options)
 
 
+class ExecuteRemoteTask(Task):
+    """Execute an arbitrary function or object.
+
+    *Note* You probably want :func:`execute_remote` instead, which this
+    is an internal component of.
+
+    The object must be pickleable, so you can't use lambdas or functions
+    defined in the REPL (that is the python shell, or ``ipython``).
+
+    """
+    name = "celery.execute_remote"
+
+    def run(self, ser_callable, fargs, fkwargs, **kwargs):
+        """
+        :param ser_callable: A pickled function or callable object.
+
+        :param fargs: Positional arguments to apply to the function.
+
+        :param fkwargs: Keyword arguments to apply to the function.
+
+        """
+        callable_ = pickle.loads(ser_callable)
+        return callable_(*fargs, **fkwargs)
+tasks.register(ExecuteRemoteTask)
+
+
+class AsynchronousMapTask(Task):
+    """Task used internally by :func:`dmap_async` and
+    :meth:`TaskSet.map_async`.  """
+    name = "celery.map_async"
+
+    def run(self, serfunc, args, **kwargs):
+        """The method run by ``celeryd``."""
+        timeout = kwargs.get("timeout")
+        return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
+tasks.register(AsynchronousMapTask)
+
+
 class TaskSet(object):
     """A task containing several subtasks, making it possible
     to track how many, or when all of the tasks has been completed.
@@ -297,6 +335,13 @@ class TaskSet(object):
 
         """
         taskset_id = gen_unique_id()
+
+        from celery.conf import ALWAYS_EAGER
+        if ALWAYS_EAGER:
+            subtasks = [apply(self.task, args, kwargs)
+                            for args, kwargs in self.arguments]
+            return TaskSetResult(taskset_id, subtasks)
+
         conn = DjangoAMQPConnection(connect_timeout=connect_timeout)
         publisher = TaskPublisher(connection=conn)
         subtasks = [apply_async(self.task, args, kwargs,
@@ -306,15 +351,6 @@ class TaskSet(object):
         conn.close()
         return TaskSetResult(taskset_id, subtasks)
 
-    def iterate(self):
-        """Iterate over the results returned after calling :meth:`run`.
-
-        If any of the tasks raises an exception, the exception will
-        be re-raised.
-
-        """
-        return iter(self.run())
-
     def join(self, timeout=None):
         """Gather the results for all of the tasks in the taskset,
         and return a list with them ordered by the order of which they

+ 0 - 38
celery/task/builtins.py

@@ -5,44 +5,6 @@ from datetime import timedelta
 from celery.utils import pickle
 
 
-class AsynchronousMapTask(Task):
-    """Task used internally by :func:`dmap_async` and
-    :meth:`TaskSet.map_async`.  """
-    name = "celery.map_async"
-
-    def run(self, serfunc, args, **kwargs):
-        """The method run by ``celeryd``."""
-        timeout = kwargs.get("timeout")
-        return TaskSet.map(pickle.loads(serfunc), args, timeout=timeout)
-tasks.register(AsynchronousMapTask)
-
-
-class ExecuteRemoteTask(Task):
-    """Execute an arbitrary function or object.
-
-    *Note* You probably want :func:`execute_remote` instead, which this
-    is an internal component of.
-
-    The object must be pickleable, so you can't use lambdas or functions
-    defined in the REPL (that is the python shell, or ``ipython``).
-
-    """
-    name = "celery.execute_remote"
-
-    def run(self, ser_callable, fargs, fkwargs, **kwargs):
-        """
-        :param ser_callable: A pickled function or callable object.
-
-        :param fargs: Positional arguments to apply to the function.
-
-        :param fkwargs: Keyword arguments to apply to the function.
-
-        """
-        callable_ = pickle.loads(ser_callable)
-        return callable_(*fargs, **fkwargs)
-tasks.register(ExecuteRemoteTask)
-
-
 class DeleteExpiredTaskMetaTask(PeriodicTask):
     """A periodic task that deletes expired task metadata every day.
 

+ 26 - 1
celery/tests/test_backends/test_base.py

@@ -1,6 +1,6 @@
 import unittest
 from celery.backends.base import find_nearest_pickleable_exception as fnpe
-from celery.backends.base import BaseBackend
+from celery.backends.base import BaseBackend, KeyValueStoreBackend
 from celery.backends.base import UnpickleableExceptionWrapper
 from django.db.models.base import subclass_exception
 
@@ -16,6 +16,17 @@ Lookalike = subclass_exception("Lookalike", wrapobject, "foo.module")
 b = BaseBackend()
 
 
+class TestBaseBackendInterface(unittest.TestCase):
+
+    def test_get_status(self):
+        self.assertRaises(NotImplementedError,
+                b.is_done, "SOMExx-N0Nex1stant-IDxx-")
+
+    def test_store_result(self):
+        self.assertRaises(NotImplementedError,
+                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, "DONE")
+
+
 class TestPickleException(unittest.TestCase):
 
     def test_BaseException(self):
@@ -46,3 +57,17 @@ class TestPrepareException(unittest.TestCase):
         self.assertTrue(isinstance(x, KeyError))
         y = b.exception_to_python(x)
         self.assertTrue(isinstance(y, KeyError))
+
+
+class TestKeyValueStoreBackendInterface(unittest.TestCase):
+
+    def test_get(self):
+        self.assertRaises(NotImplementedError, KeyValueStoreBackend().get,
+                "a")
+    
+    def test_set(self):
+        self.assertRaises(NotImplementedError, KeyValueStoreBackend().set,
+                "a", 1)
+
+    def test_cleanup(self):
+        self.assertFalse(KeyValueStoreBackend().cleanup())

+ 7 - 4
celery/tests/test_celery.py

@@ -6,11 +6,14 @@ class TestInitFile(unittest.TestCase):
     def test_version(self):
         self.assertTrue(celery.VERSION)
         self.assertEquals(len(celery.VERSION), 3)
-        is_stable = not (celery.VERSION[1] % 2)
-        self.assertTrue(celery.is_stable_release() == is_stable)
+        celery.VERSION = (0, 3, 0)
+        self.assertFalse(celery.is_stable_release())
         self.assertEquals(celery.__version__.count("."), 2)
-        self.assertTrue("(%s)" % (is_stable and "stable" or "unstable") in \
-                celery.version_with_meta())
+        self.assertTrue("(unstable)" in celery.version_with_meta())
+        celery.VERSION = (0, 4, 0)
+        self.assertTrue(celery.is_stable_release())
+        self.assertTrue("(stable)" in celery.version_with_meta())
+
 
     def test_meta(self):
         for m in ("__author__", "__contact__", "__homepage__",

+ 4 - 0
celery/tests/test_discovery.py

@@ -16,3 +16,7 @@ class TestDiscovery(unittest.TestCase):
     def test_discovery(self):
         if "someapp" in settings.INSTALLED_APPS:
             self.assertDiscovery()
+
+    def test_discovery_with_broken(self):
+        settings.INSTALLED_APPS = settings.INSTALLED_APPS + ["xxxnot.aexist"]
+        self.assertDiscovery()

+ 77 - 2
celery/tests/test_task.py

@@ -9,11 +9,13 @@ from celery.log import setup_logger
 from celery import messaging
 from celery.result import EagerResult
 from celery.backends import default_backend
+from datetime import datetime, timedelta
 
 
 def return_True(self, **kwargs):
     # Task run functions can't be closures/lambdas, as they're pickled.
     return True
+registry.tasks.register(return_True, "cu.return-true")
 
 
 def raise_exception(self, **kwargs):
@@ -47,13 +49,46 @@ class TestCeleryTasks(unittest.TestCase):
         cls.run = return_True
         return cls
 
+    def test_ping(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True
+        self.assertEquals(task.ping(), 'pong')
+        conf.ALWAYS_EAGER = False
+
+    def test_execute_remote(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True
+        self.assertEquals(task.execute_remote(return_True, ["foo"]).get(),
+                          True)
+        conf.ALWAYS_EAGER = False
+
+    def test_dmap(self):
+        from celery import conf
+        import operator
+        conf.ALWAYS_EAGER = True
+        res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
+        self.assertTrue(res, sum([operator.add(x, x)
+                                    for x in xrange(10)]))
+        conf.ALWAYS_EAGER = False
+    
+    def test_dmap_async(self):
+        from celery import conf
+        import operator
+        conf.ALWAYS_EAGER = True
+        res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
+        self.assertTrue(res.get(), sum([operator.add(x, x)
+                                            for x in xrange(10)]))
+        conf.ALWAYS_EAGER = False
+
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
-            **kwargs):
+            test_eta=False, **kwargs):
         next_task = consumer.fetch()
-        task_data = consumer.decoder(next_task.body)
+        task_data = next_task.decode()
         self.assertEquals(task_data["id"], presult.task_id)
         self.assertEquals(task_data["task"], task_name)
         task_kwargs = task_data.get("kwargs", {})
+        if test_eta:
+            self.assertTrue(isinstance(task_data.get("eta"), datetime))
         for arg_name, arg_value in kwargs.items():
             self.assertEquals(task_kwargs.get(arg_name), arg_value)
 
@@ -92,6 +127,18 @@ class TestCeleryTasks(unittest.TestCase):
         presult2 = task.delay_task(t1.name, name="George Constanza")
         self.assertNextTaskDataEquals(consumer, presult2, t1.name,
                 name="George Constanza")
+        
+        # With eta.
+        presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
+                                    eta=datetime.now() + timedelta(days=1))
+        self.assertNextTaskDataEquals(consumer, presult2, t1.name,
+                name="George Constanza", test_eta=True)
+
+        # With countdown.
+        presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
+                                    countdown=10)
+        self.assertNextTaskDataEquals(consumer, presult2, t1.name,
+                name="George Constanza", test_eta=True)
 
         self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
                 "some.task.that.should.never.exist.X.X.X.X.X")
@@ -112,9 +159,26 @@ class TestCeleryTasks(unittest.TestCase):
         publisher = t1.get_publisher()
         self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
 
+    def test_get_logger(self):
+        T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+        t1 = T1()
+        logfh = StringIO()
+        logger = t1.get_logger(logfile=logfh, loglevel=0)
+        self.assertTrue(logger)
+
 
 class TestTaskSet(unittest.TestCase):
 
+    def test_function_taskset(self):
+        from celery import conf
+        conf.ALWAYS_EAGER = True 
+        ts = task.TaskSet("cu.return-true", [
+            [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
+        res = ts.run()
+        self.assertEquals(res.join(), [True, True, True, True, True])
+
+        conf.ALWAYS_EAGER = False
+
     def test_counter_taskset(self):
         IncrementCounterTask.count = 0
         ts = task.TaskSet(IncrementCounterTask, [
@@ -169,3 +233,14 @@ class TestTaskApply(unittest.TestCase):
         self.assertTrue(f.is_ready())
         self.assertFalse(f.is_done())
         self.assertRaises(KeyError, f.get)
+
+
+class TestPeriodicTask(unittest.TestCase):
+
+    def test_interface(self):
+
+        class MyPeriodicTask(task.PeriodicTask):
+            run_every = None
+
+        self.assertRaises(NotImplementedError, MyPeriodicTask)
+

+ 2 - 1
celery/tests/test_task_builtins.py

@@ -1,5 +1,6 @@
 import unittest
-from celery.task.builtins import PingTask, ExecuteRemoteTask
+from celery.task.builtins import PingTask
+from celery.task.base import ExecuteRemoteTask
 from celery.utils import pickle