Browse Source

Pool: Start the TimeoutHandler thread on demand

Ask Solem 14 years ago
parent
commit
15c34855a6

+ 4 - 2
celery/concurrency/base.py

@@ -66,7 +66,7 @@ class BasePool(object):
 
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
             errbacks=None, accept_callback=None, timeout_callback=None,
-            **compat):
+            soft_timeout=None, timeout=None, **compat):
         """Equivalent of the :func:`apply` built-in function.
 
         All `callbacks` and `errbacks` should complete immediately since
@@ -89,7 +89,9 @@ class BasePool(object):
                              accept_callback=accept_callback,
                              timeout_callback=timeout_callback,
                              error_callback=on_worker_error,
-                             waitforslot=self.putlocks)
+                             waitforslot=self.putlocks,
+                             soft_timeout=soft_timeout,
+                             timeout=timeout)
 
     def on_ready(self, callbacks, errbacks, ret_value):
         """What to do when a worker task is ready and its return value has

+ 0 - 54
celery/concurrency/processes/__init__.py

@@ -50,60 +50,6 @@ class TaskPool(BasePool):
             self._pool.terminate()
             self._pool = None
 
-    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
-            errbacks=None, accept_callback=None, timeout_callback=None,
-            soft_timeout=None, timeout=None, **compat):
-        """Equivalent of the :func:``apply`` built-in function.
-
-        All ``callbacks`` and ``errbacks`` should complete immediately since
-        otherwise the thread which handles the result will get blocked.
-
-        """
-        args = args or []
-        kwargs = kwargs or {}
-        callbacks = callbacks or []
-        errbacks = errbacks or []
-
-        on_ready = curry(self.on_ready, callbacks, errbacks)
-        on_worker_error = curry(self.on_worker_error, errbacks)
-
-        self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
-            target, args, kwargs))
-
-        return self._pool.apply_async(target, args, kwargs,
-                                      callback=on_ready,
-                                      accept_callback=accept_callback,
-                                      timeout_callback=timeout_callback,
-                                      error_callback=on_worker_error,
-                                      waitforslot=self.putlocks,
-                                      soft_timeout=soft_timeout,
-                                      timeout=timeout)
-
-    def on_worker_error(self, errbacks, exc):
-        einfo = ExceptionInfo((exc.__class__, exc, None))
-        [errback(einfo) for errback in errbacks]
-
-    def on_ready(self, callbacks, errbacks, ret_value):
-        """What to do when a worker task is ready and its return value has
-        been collected."""
-
-        if isinstance(ret_value, ExceptionInfo):
-            if isinstance(ret_value.exception, (
-                    SystemExit, KeyboardInterrupt)):
-                raise ret_value.exception
-            [self.safe_apply_callback(errback, ret_value)
-                    for errback in errbacks]
-        else:
-            [self.safe_apply_callback(callback, ret_value)
-                    for callback in callbacks]
-
-    def safe_apply_callback(self, fun, *args):
-        try:
-            fun(*args)
-        except:
-            self.logger.error("Pool callback raised exception: %s" % (
-                traceback.format_exc(), ))
-
     def terminate_job(self, pid, signal=None):
         os.kill(pid, signal or _signal.SIGTERM)
 

+ 29 - 11
celery/concurrency/processes/pool.py

@@ -21,6 +21,7 @@ import itertools
 import collections
 import time
 import signal
+import warnings
 
 from multiprocessing import Process, cpu_count, TimeoutError
 from multiprocessing import util
@@ -267,7 +268,7 @@ class TaskHandler(PoolThread):
 
 class TimeoutHandler(PoolThread):
 
-    def __init__(self, processes, cache, t_soft, t_hard, putlock):
+    def __init__(self, processes, cache, t_soft, t_hard):
         self.processes = processes
         self.cache = cache
         self.t_soft = t_soft
@@ -278,7 +279,6 @@ class TimeoutHandler(PoolThread):
     def body(self):
         processes = self.processes
         cache = self.cache
-        putlock = self.putlock
         t_hard, t_soft = self.t_hard, self.t_soft
         dirty = set()
 
@@ -492,9 +492,10 @@ class Pool(object):
         self._initializer = initializer
         self._initargs = initargs
 
-        if self.soft_timeout and SIG_SOFT_TIMEOUT is None:
-            raise NotImplementedError("Soft timeouts not supported: "
-                    "Your platform does not have the SIGUSR1 signal.")
+        if soft_timeout and SIG_SOFT_TIMEOUT is None:
+            warnings.warn(UserWarning("Soft timeouts are not supported: "
+                    "on this platform: It does not have the SIGUSR1 signal."))
+            soft_timeout = None
 
         if processes is None:
             try:
@@ -521,13 +522,10 @@ class Pool(object):
         self._task_handler.start()
 
         # Thread killing timedout jobs.
+        self._timeout_handler = None
+        self._timeout_handler_mutex = threading.Lock()
         if self.timeout is not None or self.soft_timeout is not None:
-            self._timeout_handler = self.TimeoutHandler(
-                    self._pool, self._cache,
-                    self.soft_timeout, self.timeout, self._putlock)
-            self._timeout_handler.start()
-        else:
-            self._timeout_handler = None
+            self._start_timeout_handler()
 
         # Thread processing results in the outqueue.
         self._result_handler = self.ResultHandler(self._outqueue,
@@ -665,6 +663,19 @@ class Pool(object):
             return False, None
         self._poll_result = _poll_result
 
+    def _start_timeout_handler(self):
+        # ensure more than one thread does not start the timeout handler
+        # thread at once.
+        self._timeout_handler_mutex.acquire()
+        try:
+            if self._timeout_handler is None:
+                self._timeout_handler = self.TimeoutHandler(
+                        self._pool, self._cache,
+                        self.soft_timeout, self.timeout)
+                self._timeout_handler.start()
+        finally:
+            self._timeout_handler_mutex.release()
+
     def apply(self, func, args=(), kwds={}):
         '''
         Equivalent of `apply()` builtin
@@ -736,6 +747,10 @@ class Pool(object):
 
         '''
         assert self._state == RUN
+        if soft_timeout and SIG_SOFT_TIMEOUT is None:
+            warnings.warn(UserWarning("Soft timeouts are not supported: "
+                    "on this platform: It does not have the SIGUSR1 signal."))
+            soft_timeout = None
         result = ApplyResult(self._cache, callback,
                              accept_callback, timeout_callback,
                              error_callback, soft_timeout, timeout)
@@ -744,6 +759,9 @@ class Pool(object):
             self._putlock.acquire()
             if self._state != RUN:
                 return
+        if timeout or soft_timeout:
+            # start the timeout handler thread when required.
+            self._start_timeout_handler()
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 

+ 6 - 5
celery/tests/test_worker/test_worker_job.py

@@ -388,10 +388,11 @@ class test_TaskRequest(unittest.TestCase):
 
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.logger = MockLogger()
-        tw.on_timeout(soft=True)
-        self.assertIn("Soft time limit exceeded", tw.logger.warnings[0])
-        tw.on_timeout(soft=False)
-        self.assertIn("Hard time limit exceeded", tw.logger.errors[0])
+        tw.on_timeout(soft=True, timeout=1337)
+        self.assertIn("Soft time limit (1337s) exceeded",
+                      tw.logger.warnings[0])
+        tw.on_timeout(soft=False, timeout=1337)
+        self.assertIn("Hard time limit (1337s) exceeded", tw.logger.errors[0])
         self.assertEqual(mytask.backend.get_status(tw.task_id),
                          states.FAILURE)
 
@@ -401,7 +402,7 @@ class test_TaskRequest(unittest.TestCase):
             tw.logger = MockLogger()
         finally:
             mytask.ignore_result = False
-            tw.on_timeout(soft=True)
+            tw.on_timeout(soft=True, timeout=1336)
             self.assertEqual(mytask.backend.get_status(tw.task_id),
                              states.PENDING)
 

+ 1 - 0
celery/worker/__init__.py

@@ -300,6 +300,7 @@ class WorkController(object):
                 stop = getattr(component, "terminate", stop)
             stop()
 
+        self.priority_timer.stop()
         self.consumer.close_connection()
         self._state = self.TERMINATE
 

+ 2 - 2
celery/worker/job.py

@@ -455,11 +455,11 @@ class TaskRequest(object):
         """Handler called if the task times out."""
         state.task_ready(self)
         if soft:
-            self.logger.warning("Soft time limit (%s) exceeded for %s[%s]" % (
+            self.logger.warning("Soft time limit (%ss) exceeded for %s[%s]" % (
                 timeout, self.task_name, self.task_id))
             exc = SoftTimeLimitExceeded(timeout)
         else:
-            self.logger.error("Hard time limit (%s) exceeded for %s[%s]" % (
+            self.logger.error("Hard time limit (%ss) exceeded for %s[%s]" % (
                 timeout, self.task_name, self.task_id))
             exc = TimeLimitExceeded(timeout)