
100% coverage for celery.task.sets

Ask Solem · 14 years ago · parent · commit ae5795b137

+ 1 - 1
celery/task/sets.py

@@ -83,7 +83,7 @@ class subtask(AttributeDict):
         args = tuple(args) + tuple(self.args)
         kwargs = dict(self.kwargs, **kwargs)
         options = dict(self.options, **options)
-        return self.get_type().apply(args, kwargs, options)
+        return self.get_type().apply(args, kwargs, **options)
 
     def apply_async(self, args=(), kwargs={}, **options):
         """Apply this task asynchronously."""

+ 8 - 0
celery/tests/test_backends/test_base.py

@@ -34,6 +34,14 @@ class TestBaseBackendInterface(unittest.TestCase):
         self.assertRaises(NotImplementedError,
                 b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
 
+    def test_reload_task_result(self):
+        self.assertRaises(NotImplementedError,
+                b.reload_task_result, "SOMExx-N0nex1stant-IDxx-")
+
+    def test_reload_taskset_result(self):
+        self.assertRaises(NotImplementedError,
+                b.reload_taskset_result, "SOMExx-N0nex1stant-IDxx-")
+
     def test_get_result(self):
         self.assertRaises(NotImplementedError,
                 b.get_result, "SOMExx-N0nex1stant-IDxx-")
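
These two new tests pin down the base backend contract: a backend that does not support reloading task or taskset results must raise NotImplementedError rather than fail silently. A sketch of the interface shape being exercised (an assumption for illustration, not the actual celery.backends.base source):

    class SketchBaseBackend(object):
        """Assumed shape of the methods the tests above exercise."""

        def reload_task_result(self, task_id):
            raise NotImplementedError(
                "reload_task_result is not supported by this backend.")

        def reload_taskset_result(self, taskset_id):
            raise NotImplementedError(
                "reload_taskset_result is not supported by this backend.")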

+ 86 - 0
celery/tests/test_backends/test_database.py

@@ -0,0 +1,86 @@
+import sys
+import socket
+import unittest2 as unittest
+
+from celery.exceptions import ImproperlyConfigured
+
+from celery import states
+from celery.utils import gen_unique_id
+from celery.backends.database import DatabaseBackend
+
+from celery.tests.utils import execute_context, mask_modules
+
+
+class SomeClass(object):
+
+    def __init__(self, data):
+        self.data = data
+
+
+class test_DatabaseBackend(unittest.TestCase):
+
+    def test_mark_as_done(self):
+        tb = DatabaseBackend()
+
+        tid = gen_unique_id()
+
+        self.assertEqual(tb.get_status(tid), states.PENDING)
+        self.assertIsNone(tb.get_result(tid))
+
+        tb.mark_as_done(tid, 42)
+        self.assertEqual(tb.get_status(tid), states.SUCCESS)
+        self.assertEqual(tb.get_result(tid), 42)
+
+    def test_is_pickled(self):
+        tb = DatabaseBackend()
+
+        tid2 = gen_unique_id()
+        result = {"foo": "baz", "bar": SomeClass(12345)}
+        tb.mark_as_done(tid2, result)
+        # is serialized properly.
+        rindb = tb.get_result(tid2)
+        self.assertEqual(rindb.get("foo"), "baz")
+        self.assertEqual(rindb.get("bar").data, 12345)
+
+    def test_mark_as_started(self):
+        tb = DatabaseBackend()
+        tid = gen_unique_id()
+        tb.mark_as_started(tid)
+        self.assertEqual(tb.get_status(tid), states.STARTED)
+
+    def test_mark_as_revoked(self):
+        tb = DatabaseBackend()
+        tid = gen_unique_id()
+        tb.mark_as_revoked(tid)
+        self.assertEqual(tb.get_status(tid), states.REVOKED)
+
+    def test_mark_as_retry(self):
+        tb = DatabaseBackend()
+        tid = gen_unique_id()
+        try:
+            raise KeyError("foo")
+        except KeyError, exception:
+            import traceback
+            trace = "\n".join(traceback.format_stack())
+        tb.mark_as_retry(tid, exception, traceback=trace)
+        self.assertEqual(tb.get_status(tid), states.RETRY)
+        self.assertIsInstance(tb.get_result(tid), KeyError)
+        self.assertEqual(tb.get_traceback(tid), trace)
+
+    def test_mark_as_failure(self):
+        tb = DatabaseBackend()
+
+        tid3 = gen_unique_id()
+        try:
+            raise KeyError("foo")
+        except KeyError, exception:
+            import traceback
+            trace = "\n".join(traceback.format_stack())
+        tb.mark_as_failure(tid3, exception, traceback=trace)
+        self.assertEqual(tb.get_status(tid3), states.FAILURE)
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
+        self.assertEqual(tb.get_traceback(tid3), trace)
+
+    def test_process_cleanup(self):
+        tb = DatabaseBackend()
+        tb.process_cleanup()


+ 174 - 0
celery/tests/test_task_sets.py

@@ -0,0 +1,174 @@
+import unittest2 as unittest
+
+import simplejson
+
+from celery import conf
+from celery.task import Task
+from celery.task.sets import subtask, TaskSet
+
+from celery.tests.utils import execute_context, with_eager_tasks
+from celery.tests.compat import catch_warnings
+
+
+class MockTask(Task):
+    name = "tasks.add"
+
+    def run(self, x, y, **kwargs):
+        return x + y
+
+    @classmethod
+    def apply_async(cls, args, kwargs, **options):
+        return (args, kwargs, options)
+
+    @classmethod
+    def apply(cls, args, kwargs, **options):
+        return (args, kwargs, options)
+
+
+class test_subtask(unittest.TestCase):
+
+    def test_behaves_like_type(self):
+        s = subtask("tasks.add", (2, 2), {"cache": True},
+                    {"routing_key": "CPU-bound"})
+        self.assertDictEqual(subtask(s), s)
+
+    def test_task_argument_can_be_task_cls(self):
+        s = subtask(MockTask, (2, 2))
+        self.assertEqual(s.task, MockTask.name)
+
+    def test_apply_async(self):
+        s = MockTask.subtask((2, 2), {"cache": True},
+                {"routing_key": "CPU-bound"})
+        args, kwargs, options = s.apply_async()
+        self.assertTupleEqual(args, (2, 2))
+        self.assertDictEqual(kwargs, {"cache": True})
+        self.assertDictEqual(options, {"routing_key": "CPU-bound"})
+
+    def test_delay_argmerge(self):
+        s = MockTask.subtask((2, ), {"cache": True},
+                {"routing_key": "CPU-bound"})
+        args, kwargs, options = s.delay(10, cache=False, other="foo")
+        self.assertTupleEqual(args, (10, 2))
+        self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
+        self.assertDictEqual(options, {"routing_key": "CPU-bound"})
+
+    def test_apply_async_argmerge(self):
+        s = MockTask.subtask((2, ), {"cache": True},
+                {"routing_key": "CPU-bound"})
+        args, kwargs, options = s.apply_async((10, ),
+                                              {"cache": False, "other": "foo"},
+                                              routing_key="IO-bound",
+                                              exchange="fast")
+
+        self.assertTupleEqual(args, (10, 2))
+        self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
+        self.assertDictEqual(options, {"routing_key": "IO-bound",
+                                        "exchange": "fast"})
+
+    def test_apply_argmerge(self):
+        s = MockTask.subtask((2, ), {"cache": True},
+                {"routing_key": "CPU-bound"})
+        args, kwargs, options = s.apply((10, ),
+                                        {"cache": False, "other": "foo"},
+                                        routing_key="IO-bound",
+                                        exchange="fast")
+
+        self.assertTupleEqual(args, (10, 2))
+        self.assertDictEqual(kwargs, {"cache": False, "other": "foo"})
+        self.assertDictEqual(options, {"routing_key": "IO-bound",
+                                        "exchange": "fast"})
+
+    def test_is_JSON_serializable(self):
+        s = MockTask.subtask((2, ), {"cache": True},
+                {"routing_key": "CPU-bound"})
+        s.args = list(s.args) # tuples are not preserved
+                              # but this doesn't matter.
+        self.assertEqual(s,
+                         subtask(simplejson.loads(simplejson.dumps(s))))
+
+
+class test_TaskSet(unittest.TestCase):
+
+    def test_interface__compat(self):
+
+        def with_catch_warnings(log):
+            ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
+            self.assertTrue(log)
+            self.assertIn("Using this invocation of TaskSet is deprecated",
+                          log[0].message.args[0])
+            return ts
+
+        context = catch_warnings(record=True)
+        ts = execute_context(context, with_catch_warnings)
+        self.assertListEqual(ts.tasks,
+                             [MockTask.subtask((i, i))
+                                 for i in (2, 4, 8)])
+
+        # TaskSet.task (deprecated)
+        def with_catch_warnings2(log):
+            self.assertEqual(ts.task, MockTask)
+            self.assertTrue(log)
+            self.assertIn("TaskSet.task is deprecated",
+                          log[0].message.args[0])
+
+        execute_context(catch_warnings(record=True), with_catch_warnings2)
+
+        # TaskSet.task_name (deprecated)
+        def with_catch_warnings3(log):
+            self.assertEqual(ts.task_name, MockTask.name)
+            self.assertTrue(log)
+            self.assertIn("TaskSet.task_name is deprecated",
+                          log[0].message.args[0])
+
+        execute_context(catch_warnings(record=True), with_catch_warnings3)
+
+    def test_task_arg_can_be_iterable__compat(self):
+        ts = TaskSet([MockTask.subtask((i, i))
+                        for i in (2, 4, 8)])
+        self.assertEqual(len(ts), 3)
+
+
+    def test_respects_ALWAYS_EAGER(self):
+
+        class MockTaskSet(TaskSet):
+            applied = 0
+
+            def apply(self, *args, **kwargs):
+                self.applied += 1
+
+        ts = MockTaskSet([MockTask.subtask((i, i))
+                        for i in (2, 4, 8)])
+        conf.ALWAYS_EAGER = True
+        try:
+            ts.apply_async()
+        finally:
+            conf.ALWAYS_EAGER = False
+        self.assertEqual(ts.applied, 1)
+
+    def test_apply_async(self):
+
+        applied = [0]
+
+        class mocksubtask(subtask):
+
+            def apply_async(self, *args, **kwargs):
+                applied[0] += 1
+
+        ts = TaskSet([mocksubtask(MockTask, (i, i))
+                        for i in (2, 4, 8)])
+        ts.apply_async()
+        self.assertEqual(applied[0], 3)
+
+    def test_apply(self):
+
+        applied = [0]
+
+        class mocksubtask(subtask):
+
+            def apply(self, *args, **kwargs):
+                applied[0] += 1
+
+        ts = TaskSet([mocksubtask(MockTask, (i, i))
+                        for i in (2, 4, 8)])
+        ts.apply()
+        self.assertEqual(applied[0], 3)
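
Taken together, the subtask tests above describe the API this commit is covering: a subtask is a serializable signature (task name, args, kwargs, execution options) that can be shipped around as JSON and applied later with extra arguments merged in. A condensed round trip using only calls that appear in the tests, against a hypothetical "tasks.add" task:

    import simplejson

    from celery.task.sets import subtask

    # Build a signature for a hypothetical "tasks.add" task: positional
    # args, keyword args and execution options all travel with it.
    sig = subtask("tasks.add", (2, 2), {"cache": True},
                  {"routing_key": "CPU-bound"})

    # JSON does not preserve tuples, so normalise args before comparing,
    # exactly as test_is_JSON_serializable does above.
    sig.args = list(sig.args)
    assert subtask(simplejson.loads(simplejson.dumps(sig))) == sig

    # With a real registered task, sig.delay(...) or sig.apply_async(...)
    # would then merge in extra args/options and dispatch the task.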

+ 12 - 1
docs/internals/moduleindex.rst

@@ -11,14 +11,25 @@ Worker
 celery.worker
 -------------
 
-* :class:`celery.worker.WorkController`
+* :class:`~celery.worker.WorkController`
 
 This is the worker's main process. It starts and stops all the components
 required by the worker: Pool, Mediator, Scheduler, ClockService, and Listener.
 
+* :func:`~celery.worker.process_initializer`
+
+This is the function used to initialize pool processes. It sets up loggers and
+imports required task modules, etc.
+
 celery.worker.job
 -----------------
 
+* :class:`~celery.worker.job.TaskRequest`
+
+A request to execute a task. Contains the task name, id, args and kwargs.
+Handles acknowledgement, execution, writing results to backends and error handling
+(including error e-mails).
+
 celery.worker.pool
 ------------------
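
The new process_initializer entry documented above is the hook that runs inside every pool process before it accepts work. The mechanism is the standard pool-initializer pattern; a generic illustration with the stdlib (not the actual WorkController wiring):

    from multiprocessing import Pool

    def process_initializer():
        # In the worker this is where logging is configured and the
        # required task modules are imported, so each child process is
        # ready before the first job arrives.
        pass

    if __name__ == "__main__":
        # Each of the two child processes runs process_initializer()
        # once at startup, before any work is dispatched to it.
        pool = Pool(processes=2, initializer=process_initializer)
        pool.close()
        pool.join()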
 

+ 20 - 1
setup.cfg

@@ -3,7 +3,26 @@ where = celery/tests
 cover3-branch = 1
 cover3-html = 1
 cover3-package = celery
-cover3-exclude = celery.__init__,celery.conf,celery.tests.*,celery.bin.celerybeat,celery.utils.patch,celery.utils.compat,celery.platform,celery.backends.mongodb,celery.backends.tyrant
+cover3-exclude = celery
+                 celery.conf
+                 celery.tests.*
+                 celery.bin.celeryd
+                 celery.bin.celerybeat
+                 celery.bin.celeryev
+                 celery.utils.patch
+                 celery.utils.compat
+                 celery.utils.mail
+                 celery.utils.functional
+                 celery.utils.dispatch*
+                 celery.db.a805d4bd
+                 celery.contrib*
+                 celery.concurrency.threads
+                 celery.concurrency.processes.pool
+                 celery.platform
+                 celery.backends.mongodb
+                 celery.backends.tyrant
+                 celery.backends.pyredis
+                 celery.backends.amqp
 
 [build_sphinx]
 source-dir = docs/

+ 0 - 0
tests/__init__.py