Browse Source

100% coverage for celery.task.*

Ask Solem 11 years ago
parent
commit
5f724e0d39

+ 7 - 11
celery/task/trace.py

@@ -260,6 +260,8 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     except Exception as exc:
                         _logger.error('Process cleanup failed: %r', exc,
                                       exc_info=True)
+        except MemoryError:
+            raise
         except Exception as exc:
             if eager:
                 raise
@@ -333,13 +335,9 @@ def setup_worker_optimizations(app):
     _tasks = app._tasks
 
     trace_task_ret = _fast_trace_task
-    try:
-        job = sys.modules['celery.worker.job']
-    except KeyError:
-        pass
-    else:
-        job.trace_task_ret = _fast_trace_task
-        job.__optimize__()
+    from celery.worker import job as job_module
+    job_module.trace_task_ret = _fast_trace_task
+    job_module.__optimize__()
 
 
 def reset_worker_optimizations():
@@ -353,10 +351,8 @@ def reset_worker_optimizations():
         BaseTask.__call__ = _patched.pop('BaseTask.__call__')
     except KeyError:
         pass
-    try:
-        sys.modules['celery.worker.job'].trace_task_ret = _trace_task_ret
-    except KeyError:
-        pass
+    from celery.worker import job as job_module
+    job_module.trace_task_ret = _trace_task_ret
 
 
 def _install_stack_protection():

+ 9 - 1
celery/tests/tasks/test_sets.py → celery/tests/compat_modules/test_sets.py

@@ -2,6 +2,8 @@ from __future__ import absolute_import
 
 import anyjson
 
+from mock import Mock, patch
+
 from celery import current_app
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
@@ -44,7 +46,6 @@ class test_subtask(Case):
         self.assertTupleEqual(args, (2, 2))
         self.assertDictEqual(kwargs, {'cache': True})
         self.assertDictEqual(options, {'routing_key': 'CPU-bound'})
-
     def test_delay_argmerge(self):
         s = MockTask.subtask(
             (2, ), {'cache': True}, {'routing_key': 'CPU-bound'},
@@ -128,6 +129,13 @@ class test_TaskSet(Case):
             app.conf.CELERY_ALWAYS_EAGER = False
         self.assertEqual(ts.applied, 1)
 
+        with patch('celery.task.sets.get_current_worker_task') as gwt:
+            parent = gwt.return_value = Mock()
+            parent.request.children = []
+            ts.apply_async()
+            self.assertTrue(parent.request.children)
+
+
     def test_apply_async(self):
 
         applied = [0]

+ 3 - 0
celery/tests/tasks/test_tasks.py

@@ -6,6 +6,8 @@ from functools import wraps
 from mock import patch
 from pickle import loads, dumps
 
+from kombu import Queue
+
 from celery.task import (
     current,
     task,
@@ -289,6 +291,7 @@ class test_tasks(Case):
             consumer.receive('foo', 'foo')
         consumer.purge()
         self.assertIsNone(consumer.queues[0].get())
+        consumer2 = T1.get_consumer(queues=[Queue('foo')])
 
         # Without arguments.
         presult = T1.delay()

+ 136 - 31
celery/tests/tasks/test_trace.py

@@ -1,60 +1,149 @@
 from __future__ import absolute_import
 
-from mock import patch
+from mock import Mock, patch
 
-from celery import current_app
+from celery import uuid
+from celery import signals
 from celery import states
-from celery.exceptions import RetryTaskError
-from celery.task.trace import TraceInfo, eager_trace_task, trace_task
-from celery.tests.utils import Case, Mock
-
-
-@current_app.task
-def add(x, y):
-    return x + y
-
+from celery.app.task import Task as BaseTask
+from celery.exceptions import RetryTaskError, Ignore
+from celery.task.trace import (
+    TraceInfo,
+    eager_trace_task,
+    trace_task,
+    setup_worker_optimizations,
+    reset_worker_optimizations,
+)
+from celery.tests.utils import AppCase, Mock
+
+
+def trace(task, args=(), kwargs={}, propagate=False, **opts):
+    return eager_trace_task(task, 'id-1', args, kwargs,
+                            propagate=propagate, **opts)
 
-@current_app.task(ignore_result=True)
-def add_cast(x, y):
-    return x + y
 
+class TraceCase(AppCase):
 
-@current_app.task
-def raises(exc):
-    raise exc
+    def setup(self):
+        @self.app.task
+        def add(x, y):
+            return x + y
+        self.add = add
 
+        @self.app.task(ignore_result=True)
+        def add_cast(x, y):
+            return x + y
+        self.add_cast = add_cast
 
-def trace(task, args=(), kwargs={}, propagate=False):
-    return eager_trace_task(task, 'id-1', args, kwargs,
-                            propagate=propagate)
+        @self.app.task
+        def raises(exc):
+            raise exc
+        self.raises = raises
 
 
-class test_trace(Case):
+class test_trace(TraceCase):
 
     def test_trace_successful(self):
-        retval, info = trace(add, (2, 2), {})
+        retval, info = trace(self.add, (2, 2), {})
         self.assertIsNone(info)
         self.assertEqual(retval, 4)
 
+    def test_trace_on_success(self):
+
+        @self.app.task(on_success=Mock())
+        def add_with_success(x, y):
+            return x + y
+
+        trace(add_with_success, (2, 2), {})
+        self.assertTrue(add_with_success.on_success.called)
+
+    def test_trace_after_return(self):
+
+        @self.app.task(after_return=Mock())
+        def add_with_after_return(x, y):
+            return x + y
+
+        trace(add_with_after_return, (2, 2), {})
+        self.assertTrue(add_with_after_return.after_return.called)
+
+    def test_with_prerun_receivers(self):
+        on_prerun = Mock()
+        signals.task_prerun.connect(on_prerun)
+        try:
+            trace(self.add, (2, 2), {})
+            self.assertTrue(on_prerun.called)
+        finally:
+            signals.task_prerun.receivers[:] = []
+
+    def test_with_postrun_receivers(self):
+        on_postrun = Mock()
+        signals.task_postrun.connect(on_postrun)
+        try:
+            trace(self.add, (2, 2), {})
+            self.assertTrue(on_postrun.called)
+        finally:
+            signals.task_postrun.receivers[:] = []
+
+    def test_with_success_receivers(self):
+        on_success = Mock()
+        signals.task_success.connect(on_success)
+        try:
+            trace(self.add, (2, 2), {})
+            self.assertTrue(on_success.called)
+        finally:
+            signals.task_success.receivers[:] = []
+
+    def test_when_chord_part(self):
+
+        @self.app.task
+        def add(x, y):
+            return x + y
+        add.backend = Mock()
+
+        trace(add, (2, 2), {}, request={'chord': uuid()})
+        add.backend.on_chord_part_return.assert_called_with(add)
+
+    def test_when_backend_cleanup_raises(self):
+
+        @self.app.task
+        def add(x, y):
+            return x + y
+        add.backend = Mock(name='backend')
+        add.backend.process_cleanup.side_effect = KeyError()
+        trace(add, (2, 2), {}, eager=False)
+        add.backend.process_cleanup.assert_called_with()
+        add.backend.process_cleanup.side_effect = MemoryError()
+        with self.assertRaises(MemoryError):
+            trace(add, (2, 2), {}, eager=False)
+
+    def test_when_Ignore(self):
+
+        @self.app.task
+        def ignored():
+            raise Ignore()
+
+        retval, info = trace(ignored, (), {})
+        self.assertEqual(info.state, states.IGNORED)
+
     def test_trace_SystemExit(self):
         with self.assertRaises(SystemExit):
-            trace(raises, (SystemExit(), ), {})
+            trace(self.raises, (SystemExit(), ), {})
 
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError('foo', 'bar')
-        _, info = trace(raises, (exc, ), {})
+        _, info = trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.RETRY)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception(self):
         exc = KeyError('foo')
-        _, info = trace(raises, (exc, ), {})
+        _, info = trace(self.raises, (exc, ), {})
         self.assertEqual(info.state, states.FAILURE)
         self.assertIs(info.retval, exc)
 
     def test_trace_exception_propagate(self):
         with self.assertRaises(KeyError):
-            trace(raises, (KeyError('foo'), ), {}, propagate=True)
+            trace(self.raises, (KeyError('foo'), ), {}, propagate=True)
 
     @patch('celery.task.trace.build_tracer')
     @patch('celery.task.trace.report_internal_error')
@@ -63,7 +152,7 @@ class test_trace(Case):
         tracer.side_effect = KeyError('foo')
         build_tracer.return_value = tracer
 
-        @current_app.task
+        @self.app.task
         def xtask():
             pass
 
@@ -72,7 +161,7 @@ class test_trace(Case):
         self.assertIs(xtask.__trace__, tracer)
 
 
-class test_TraceInfo(Case):
+class test_TraceInfo(TraceCase):
 
     class TI(TraceInfo):
         __slots__ = TraceInfo.__slots__ + ('__dict__', )
@@ -80,8 +169,24 @@ class test_TraceInfo(Case):
     def test_handle_error_state(self):
         x = self.TI(states.FAILURE)
         x.handle_failure = Mock()
-        x.handle_error_state(add_cast)
+        x.handle_error_state(self.add_cast)
         x.handle_failure.assert_called_with(
-            add_cast,
-            store_errors=add_cast.store_errors_even_if_ignored,
+            self.add_cast,
+            store_errors=self.add_cast.store_errors_even_if_ignored,
         )
+
+
+class test_stackprotection(AppCase):
+
+    def test_stackprotection(self):
+        setup_worker_optimizations(self.app)
+        try:
+            @self.app.task(bind=True)
+            def foo(self, i):
+                if i:
+                    return foo(0)
+                return self.request
+
+            self.assertTrue(foo(1).called_directly)
+        finally:
+            reset_worker_optimizations()