Browse Source

Now at 94% coverage! + Tests and fixes for the supervisor.

Ask Solem 16 years ago
parent
commit
f044601f1a
2 changed files with 94 additions and 8 deletions
  1. 14 8
      celery/supervisor.py
  2. 80 0
      celery/tests/test_supervisor.py

+ 14 - 8
celery/supervisor.py

@@ -1,6 +1,7 @@
-from multiprocessing import Process, TimeoutError
+import multiprocessing
 import threading
 import time
+from multiprocessing import TimeoutError
 
 PING_TIMEOUT = 30 # seconds
 JOIN_TIMEOUT = 2
@@ -9,6 +10,10 @@ MAX_RESTART_FREQ = 3
 MAX_RESTART_FREQ_TIME = 10
 
 
+class MaxRestartsExceededError(Exception):
+    """Restarts exceeded the maximum restart frequency."""
+
+
 def raise_ping_timeout(msg):
     """Raises :exc:`multiprocessing.TimeoutError`, for use in
     :class:`threading.Timer` callbacks."""
@@ -68,6 +73,7 @@ class OFASupervisor(object):
         The time in seconds, between process pings.
 
     """
+    Process = multiprocessing.Process
 
     def __init__(self, target, args=None, kwargs=None,
             ping_timeout=PING_TIMEOUT, join_timeout=JOIN_TIMEOUT,
@@ -91,36 +97,36 @@ class OFASupervisor(object):
 
         def _start_supervised_process():
             """Start the :attr:`target` in a new process."""
-            process = Process(target=target,
-                              args=self.args, kwargs=self.kwargs)
+            process = self.Process(target=target,
+                                   args=self.args, kwargs=self.kwargs)
             process.start()
             return process
 
-        def _restart(self, process):
+        def _restart(process):
             """Terminate the process and restart."""
             process.join(timeout=self.join_timeout)
             process.terminate()
             self.restarts_in_frame += 1
             process = _start_supervised_process()
 
+        process = _start_supervised_process()
         try:
-            process = _start_supervised_process()
             restart_frame = 0
             while True:
                 if restart_frame > self.max_restart_freq_time:
                     if self.restarts_in_frame >= self.max_restart_freq:
-                        raise Exception(
+                        raise MaxRestartsExceededError(
                                 "Supervised: Max restart frequency reached")
                 restart_frame = 0
                 self.restarts_in_frame = 0
 
                 try:
-                    proc_is_alive = self.is_alive(process)
+                    proc_is_alive = self._is_alive(process)
                 except TimeoutError:
                     proc_is_alive = False
 
                 if not proc_is_alive:
-                    self._restart()
+                    _restart(process)
 
                 time.sleep(self.check_interval)
                 restart_frame += self.check_interval

+ 80 - 0
celery/tests/test_supervisor.py

@@ -0,0 +1,80 @@
+import unittest
+from celery.supervisor import raise_ping_timeout, OFASupervisor
+from celery.supervisor import TimeoutError, MaxRestartsExceededError
+
+
+def target_one(x, y, z):
+    return x * y * z
+
+
+class MockProcess(object):
+    _started = False
+    _stopped = False
+    _terminated = False
+    _joined = False
+    alive = True
+    timeout_on_is_alive = False
+
+    def __init__(self, target, args, kwargs):
+        self.target = target
+        self.args = args
+        self.kwargs = kwargs
+
+    def start(self):
+        self._stopped = False
+        self._started = True
+
+    def stop(self):
+        self._stopped = True
+        self._started = False
+
+    def terminate(self):
+        self._terminated = False
+
+    def is_alive(self):
+        if self._started and self.alive:
+            if self.timeout_on_is_alive:
+                raise TimeoutError("Supervised: timed out.")
+            return True
+        return False
+
+    def join(self, timeout=None):
+        self._joined = True
+
+class TestDiv(unittest.TestCase):
+
+    def test_raise_ping_timeout(self):
+        self.assertRaises(TimeoutError, raise_ping_timeout, "timed out")
+
+
+class TestOFASupervisor(unittest.TestCase):
+
+    def test_init(self):
+        s = OFASupervisor(target=target_one, args=[2, 4, 8], kwargs={})
+        s.Process = MockProcess
+    
+    def test__is_alive(self):
+        s = OFASupervisor(target=target_one, args=[2, 4, 8], kwargs={})
+        s.Process = MockProcess
+        proc = MockProcess(target_one, [2, 4, 8], {})
+        proc.start()
+        self.assertTrue(s._is_alive(proc))
+        proc.alive = False
+        self.assertFalse(s._is_alive(proc))
+
+    def test_start(self):
+        MockProcess.alive = False
+        s = OFASupervisor(target=target_one, args=[2, 4, 8], kwargs={},
+                          max_restart_freq=0, max_restart_freq_time=0)
+        s.Process = MockProcess
+        self.assertRaises(MaxRestartsExceededError, s.start)
+        MockProcess.alive = True
+    
+    def test_start_is_alive_timeout(self):
+        MockProcess.alive = True
+        MockProcess.timeout_on_is_alive = True
+        s = OFASupervisor(target=target_one, args=[2, 4, 8], kwargs={},
+                          max_restart_freq=0, max_restart_freq_time=0)
+        s.Process = MockProcess
+        self.assertRaises(MaxRestartsExceededError, s.start)
+        MockProcess.timeout_on_is_alive = False