Browse Source

Created unittests for task.py (not all of it is covered yet)

Ask Solem 16 years ago
parent
commit
a7b7a3d21f
1 changed files with 69 additions and 0 deletions
  1. 69 0
      celery/tests/test_task.py

+ 69 - 0
celery/tests/test_task.py

@@ -0,0 +1,69 @@
+import unittest
+
+from celery import task
+from celery import registry
+
+# Task run functions can't be closures/lambdas, as they're pickled.
+def return_True(self, **kwargs):
+    return True
+
+class TestCeleryTasks(unittest.TestCase):
+
+    def createTaskCls(self, cls_name, task_name=None):
+        attrs = {}
+        if task_name:
+            attrs["name"] = task_name
+        cls = type(cls_name, (task.Task, ), attrs)
+        cls.run = return_True
+        return cls
+
+    def assertNextTaskDataEquals(self, consumer, task_id, task_name,
+            **kwargs):
+        next_task = consumer.fetch()
+        task_data = consumer.decoder(next_task.body)
+        self.assertEquals(task_data["celeryID"], task_id)
+        self.assertEquals(task_data["celeryTASK"], task_name)
+        for arg_name, arg_value in kwargs.items():
+            self.assertEquals(task_data.get(arg_name), arg_value)
+
+    def test_regular_task(self):
+    
+            T1 = self.createTaskCls("T1", "c.unittest.t.t1")
+            self.assertTrue(isinstance(T1(), T1))
+            self.assertTrue(T1().run())
+            self.assertTrue(callable(T1()),
+                    "Task class is callable()")
+            self.assertTrue(T1()(),
+                    "Task class runs run() when called")
+            
+            # task without name raises NotImplementedError
+            T2 = self.createTaskCls("T2")
+            self.assertRaises(NotImplementedError, T2)
+
+            registry.tasks.register(T1)
+            t1 = T1()
+            consumer = t1.get_consumer()
+            consumer.discard_all()
+            self.assertTrue(consumer.fetch() is None)
+
+            # Without arguments.
+            tid = t1.delay()
+            self.assertNextTaskDataEquals(consumer, tid, t1.name)
+
+            # With arguments.
+            tid2 = task.delay_task(t1.name, name="George Constanza")
+            self.assertNextTaskDataEquals(consumer, tid2, t1.name,
+                    name="George Constanza")
+
+            self.assertRaises(registry.tasks.NotRegistered, task.delay_task,
+                    "some.task.that.should.never.exist.X.X.X.X.X")
+
+            # Discarding all tasks.
+            task.discard_all()
+            tid3 = task.delay_task(t1.name)
+            self.assertEquals(task.discard_all(), 1)
+            self.assertTrue(consumer.fetch() is None)
+
+            self.assertFalse(task.is_done(tid))
+            task.mark_as_done(tid, result=None)
+            self.assertTrue(task.is_done(tid))