Browse Source

Tests passing

Ask Solem 13 years ago
parent
commit
73ea354d3d

+ 33 - 0
celery/app/task.py

@@ -778,6 +778,21 @@ class BaseTask(object):
             task_id = self.request.id
         self.backend.store_result(task_id, meta, state)
 
+    def on_success(self, retval, task_id, args, kwargs):
+        """Success handler.
+
+        Run by the worker if the task executes successfully.
+
+        :param retval: The return value of the task.
+        :param task_id: Unique id of the executed task.
+        :param args: Original arguments for the executed task.
+        :param kwargs: Original keyword arguments for the executed task.
+
+        The return value of this handler is ignored.
+
+        """
+        pass
+
     def on_retry(self, exc, task_id, args, kwargs, einfo):
         """Retry handler.
 
@@ -815,6 +830,24 @@ class BaseTask(object):
         """
         pass
 
+    def after_return(self, status, retval, task_id, args, kwargs, einfo):
+        """Handler called after the task returns.
+
+        :param status: Current task state.
+        :param retval: Task return value/exception.
+        :param task_id: Unique id of the task.
+        :param args: Original arguments for the task that failed.
+        :param kwargs: Original keyword arguments for the task
+                       that failed.
+
+        :keyword einfo: :class:`~celery.datastructures.ExceptionInfo`
+                        instance, containing the traceback (if any).
+
+        The return value of this handler is ignored.
+
+        """
+        pass
+
     def send_error_email(self, context, exc, **kwargs):
         if self.send_error_emails and not self.disable_error_emails:
             self.ErrorMail(self, **kwargs).send(context, exc)

+ 10 - 6
celery/task/trace.py

@@ -65,10 +65,10 @@ def mro_lookup(cls, attr, stop=()):
             return node
 
 
-def defines_custom_call(task):
+def task_has_custom(task, attr):
     """Returns true if the task or one of its bases
-    defines __call__ (excluding the one in BaseTask)."""
-    return mro_lookup(task.__class__, "__call__", stop=(BaseTask, object))
+    defines ``attr`` (excluding the one in BaseTask)."""
+    return mro_lookup(task.__class__, attr, stop=(BaseTask, object))
 
 
 class TraceInfo(object):
@@ -157,7 +157,7 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     # If the task doesn't define a custom __call__ method
     # we optimize it away by simply calling the run method directly,
     # saving the extra method call and a line less in the stack trace.
-    fun = task if defines_custom_call(task) else task.run
+    fun = task if task_has_custom(task, "__call__") else task.run
 
     loader = loader or current_app.loader
     backend = task.backend
@@ -170,8 +170,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
     loader_task_init = loader.on_task_init
     loader_cleanup = loader.on_process_cleanup
 
-    task_on_success = getattr(task, "on_success", None)
-    task_after_return = getattr(task, "after_return", None)
+    task_on_success = None
+    task_after_return = None
+    if task_has_custom(task, "on_success"):
+        task_on_success = task.on_success
+    if task_has_custom(task, "after_return"):
+        task_after_return = task.after_return
 
     store_result = backend.store_result
     backend_cleanup = backend.process_cleanup

+ 128 - 50
celery/tests/bin/test_celeryd.py

@@ -19,7 +19,7 @@ from celery import signals
 from celery import current_app
 from celery.apps import worker as cd
 from celery.bin.celeryd import WorkerCommand, main as celeryd_main
-from celery.exceptions import ImproperlyConfigured
+from celery.exceptions import ImproperlyConfigured, SystemTerminate
 from celery.utils.log import ensure_process_aware_logger
 from celery.worker import state
 
@@ -32,12 +32,17 @@ def disable_stdouts(fun):
 
     @wraps(fun)
     def disable(*args, **kwargs):
-        sys.stdout, sys.stderr = WhateverIO(), WhateverIO()
+        prev_out, prev_err = sys.stdout, sys.stderr
+        prev_rout, prev_rerr = sys.__stdout__, sys.__stderr__
+        sys.stdout = sys.__stdout__ = WhateverIO()
+        sys.stderr = sys.__stderr__ = WhateverIO()
         try:
             return fun(*args, **kwargs)
         finally:
-            sys.stdout = sys.__stdout__
-            sys.stderr = sys.__stderr__
+            sys.stdout = prev_out
+            sys.stderr = prev_err
+            sys.__stdout__ = prev_rout
+            sys.__stderr__ = prev_rerr
 
     return disable
 
@@ -58,6 +63,9 @@ class Worker(cd.Worker):
 class test_Worker(AppCase):
     Worker = Worker
 
+    def teardown(self):
+        self.app.conf.CELERY_INCLUDE = ()
+
     @disable_stdouts
     def test_queues_string(self):
         celery = Celery(set_as_current=False)
@@ -402,19 +410,33 @@ class test_signal_handlers(AppCase):
             def __setitem__(self, sig, handler):
                 next_handlers[sig] = handler
 
-        p, platforms.signals = platforms.signals, Signals()
-        try:
-            handlers["SIGINT"]("SIGINT", object())
-            self.assertTrue(state.should_stop)
-        finally:
-            platforms.signals = p
-            state.should_stop = False
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 3
+            p, platforms.signals = platforms.signals, Signals()
+            try:
+                handlers["SIGINT"]("SIGINT", object())
+                self.assertTrue(state.should_stop)
+            finally:
+                platforms.signals = p
+                state.should_stop = False
 
-        try:
-            next_handlers["SIGINT"]("SIGINT", object())
-            self.assertTrue(state.should_terminate)
-        finally:
-            state.should_terminate = False
+            try:
+                next_handlers["SIGINT"]("SIGINT", object())
+                self.assertTrue(state.should_terminate)
+            finally:
+                state.should_terminate = False
+
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 1
+            p, platforms.signals = platforms.signals, Signals()
+            try:
+                with self.assertRaises(SystemExit):
+                    handlers["SIGINT"]("SIGINT", object())
+            finally:
+                platforms.signals = p
+
+            with self.assertRaises(SystemTerminate):
+                next_handlers["SIGINT"]("SIGINT", object())
 
     @disable_stdouts
     def test_worker_int_handler_only_stop_MainProcess(self):
@@ -424,14 +446,27 @@ class test_signal_handlers(AppCase):
             raise SkipTest("only relevant for multiprocessing")
         process = current_process()
         name, process.name = process.name, "OtherProcess"
-        try:
-            worker = self._Worker()
-            handlers = self.psig(cd.install_worker_int_handler, worker)
-            handlers["SIGINT"]("SIGINT", object())
-            self.assertTrue(state.should_stop)
-        finally:
-            process.name = name
-            state.should_stop = False
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 3
+            try:
+                worker = self._Worker()
+                handlers = self.psig(cd.install_worker_int_handler, worker)
+                handlers["SIGINT"]("SIGINT", object())
+                self.assertTrue(state.should_stop)
+            finally:
+                process.name = name
+                state.should_stop = False
+
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 1
+            try:
+                worker = self._Worker()
+                handlers = self.psig(cd.install_worker_int_handler, worker)
+                with self.assertRaises(SystemExit):
+                    handlers["SIGINT"]("SIGINT", object())
+            finally:
+                process.name = name
+                state.should_stop = False
 
     @disable_stdouts
     def test_install_HUP_not_supported_handler(self):
@@ -448,25 +483,49 @@ class test_signal_handlers(AppCase):
         process = current_process()
         name, process.name = process.name, "OtherProcess"
         try:
+            with patch("celery.apps.worker.active_thread_count") as c:
+                c.return_value = 3
+                worker = self._Worker()
+                handlers = self.psig(
+                        cd.install_worker_term_hard_handler, worker)
+                try:
+                    handlers["SIGQUIT"]("SIGQUIT", object())
+                    self.assertTrue(state.should_terminate)
+                finally:
+                    state.should_terminate = False
+            with patch("celery.apps.worker.active_thread_count") as c:
+                c.return_value = 1
+                worker = self._Worker()
+                handlers = self.psig(
+                        cd.install_worker_term_hard_handler, worker)
+                with self.assertRaises(SystemTerminate):
+                    handlers["SIGQUIT"]("SIGQUIT", object())
+        finally:
+            process.name = name
+
+    @disable_stdouts
+    def test_worker_term_handler_when_threads(self):
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 3
             worker = self._Worker()
-            handlers = self.psig(cd.install_worker_term_hard_handler, worker)
+            handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
-                handlers["SIGQUIT"]("SIGQUIT", object())
-                self.assertTrue(state.should_terminate)
+                handlers["SIGTERM"]("SIGTERM", object())
+                self.assertTrue(state.should_stop)
             finally:
-                state.should_terminate = False
-        finally:
-            process.name = name
+                state.should_stop = False
 
     @disable_stdouts
-    def test_worker_term_handler(self):
-        worker = self._Worker()
-        handlers = self.psig(cd.install_worker_term_handler, worker)
-        try:
-            handlers["SIGTERM"]("SIGTERM", object())
-            self.assertTrue(state.should_stop)
-        finally:
-            state.should_stop = False
+    def test_worker_term_handler_when_single_thread(self):
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 1
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_term_handler, worker)
+            try:
+                with self.assertRaises(SystemExit):
+                    handlers["SIGTERM"]("SIGTERM", object())
+            finally:
+                state.should_stop = False
 
     @patch("sys.__stderr__")
     def test_worker_cry_handler(self, stderr):
@@ -490,10 +549,18 @@ class test_signal_handlers(AppCase):
         process = current_process()
         name, process.name = process.name, "OtherProcess"
         try:
-            worker = self._Worker()
-            handlers = self.psig(cd.install_worker_term_handler, worker)
-            handlers["SIGTERM"]("SIGTERM", object())
-            self.assertTrue(state.should_stop)
+            with patch("celery.apps.worker.active_thread_count") as c:
+                c.return_value = 3
+                worker = self._Worker()
+                handlers = self.psig(cd.install_worker_term_handler, worker)
+                handlers["SIGTERM"]("SIGTERM", object())
+                self.assertTrue(state.should_stop)
+            with patch("celery.apps.worker.active_thread_count") as c:
+                c.return_value = 1
+                worker = self._Worker()
+                handlers = self.psig(cd.install_worker_term_handler, worker)
+                with self.assertRaises(SystemExit):
+                    handlers["SIGTERM"]("SIGTERM", object())
         finally:
             process.name = name
             state.should_stop = False
@@ -521,11 +588,22 @@ class test_signal_handlers(AppCase):
             state.should_stop = False
 
     @disable_stdouts
-    def test_worker_term_hard_handler(self):
-        worker = self._Worker()
-        handlers = self.psig(cd.install_worker_term_hard_handler, worker)
-        try:
-            handlers["SIGQUIT"]("SIGQUIT", object())
-            self.assertTrue(state.should_terminate)
-        finally:
-            state.should_terminate = False
+    def test_worker_term_hard_handler_when_threaded(self):
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 3
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_term_hard_handler, worker)
+            try:
+                handlers["SIGQUIT"]("SIGQUIT", object())
+                self.assertTrue(state.should_terminate)
+            finally:
+                state.should_terminate = False
+
+    @disable_stdouts
+    def test_worker_term_hard_handler_when_single_threaded(self):
+        with patch("celery.apps.worker.active_thread_count") as c:
+            c.return_value = 1
+            worker = self._Worker()
+            handlers = self.psig(cd.install_worker_term_hard_handler, worker)
+            with self.assertRaises(SystemTerminate):
+                handlers["SIGQUIT"]("SIGQUIT", object())

+ 27 - 23
celery/tests/utilities/test_timer2.py

@@ -68,12 +68,12 @@ class test_Timer(Case):
     @skip_if_quick
     def test_enter_after(self):
         t = timer2.Timer()
-        done = [False]
+        try:
+            done = [False]
 
-        def set_done():
-            done[0] = True
+            def set_done():
+                done[0] = True
 
-        try:
             t.apply_after(300, set_done)
             while not done[0]:
                 time.sleep(0.1)
@@ -88,25 +88,29 @@ class test_Timer(Case):
 
     def test_apply_interval(self):
         t = timer2.Timer()
-        t.schedule.enter_after = Mock()
-
-        myfun = Mock()
-        t.apply_interval(30, myfun)
-
-        self.assertEqual(t.schedule.enter_after.call_count, 1)
-        args1, _ = t.schedule.enter_after.call_args_list[0]
-        msec1, tref1, _ = args1
-        self.assertEqual(msec1, 30)
-        tref1()
-
-        self.assertEqual(t.schedule.enter_after.call_count, 2)
-        args2, _ = t.schedule.enter_after.call_args_list[1]
-        msec2, tref2, _ = args2
-        self.assertEqual(msec2, 30)
-        tref2.cancelled = True
-        tref2()
-
-        self.assertEqual(t.schedule.enter_after.call_count, 2)
+        try:
+            t.schedule.enter_after = Mock()
+
+            myfun = Mock()
+            myfun.__name__ = "myfun"
+            t.apply_interval(30, myfun)
+
+            self.assertEqual(t.schedule.enter_after.call_count, 1)
+            args1, _ = t.schedule.enter_after.call_args_list[0]
+            msec1, tref1, _ = args1
+            self.assertEqual(msec1, 30)
+            tref1()
+
+            self.assertEqual(t.schedule.enter_after.call_count, 2)
+            args2, _ = t.schedule.enter_after.call_args_list[1]
+            msec2, tref2, _ = args2
+            self.assertEqual(msec2, 30)
+            tref2.cancelled = True
+            tref2()
+
+            self.assertEqual(t.schedule.enter_after.call_count, 2)
+        finally:
+            t.stop()
 
     @patch("celery.utils.timer2.logger")
     def test_apply_entry_error_handled(self, logger):

+ 8 - 0
celery/utils/timer2.py

@@ -33,6 +33,7 @@ __homepage__ = "http://github.com/ask/timer2/"
 __docformat__ = "restructuredtext"
 
 DEFAULT_MAX_INTERVAL = 2
+TIMER_DEBUG = os.environ.get("TIMER_DEBUG")
 
 logger = get_logger("timer2")
 
@@ -215,6 +216,13 @@ class Timer(Thread):
     on_tick = None
     _timer_count = count(1).next
 
+    if TIMER_DEBUG:
+        def start(self, *args, **kwargs):
+            import traceback
+            print("TIMER START")
+            traceback.print_stack()
+            super(Timer, self).start(*args, **kwargs)
+
     def __init__(self, schedule=None, on_error=None, on_tick=None,
             max_interval=None, **kwargs):
         self.schedule = schedule or self.Schedule(on_error=on_error,

+ 1 - 1
celery/worker/job.py

@@ -419,7 +419,7 @@ class Request(object):
                                    "hostname": self.hostname,
                                    "internal": internal}})
 
-        self.task.send_error_email(context, exc_info.exception)
+        self.task.send_error_email(context, einfo.exception)
 
     def acknowledge(self):
         """Acknowledge task."""