Browse Source

Do proper cleanup after closing pool.

Ionel Cristian Mărieș 11 years ago
parent
commit
9002feec88
2 changed files with 71 additions and 63 deletions
  1. 5 2
      celery/concurrency/workhorse.py
  2. 66 61
      celery/tests/concurrency/test_workhorse.py

+ 5 - 2
celery/concurrency/workhorse.py

@@ -157,6 +157,10 @@ class TaskPool(BasePool):
     def register_with_event_loop(self, hub):
         hub.add_reader(self.sigfd, self.on_sigchld)
 
+    def on_close(self):
+        self.sigfh.close()
+        signalfd.sigprocmask(signalfd.SIG_BLOCK, [])
+
     def on_sigchld(self):
         pending = {}
 
@@ -236,7 +240,6 @@ class TaskPool(BasePool):
         except:
             self.terminate()
             raise
-    on_close = on_stop
 
     def terminate(self, timeout=5):
         for pid in self.workers:
@@ -244,7 +247,7 @@ class TaskPool(BasePool):
 
         while self.workers and timeout > 0:
             self.on_sigchld()
-            time.sleep(1)
+            time.sleep(0.2)
             timeout -= 1
 
         for pid in self.workers:

+ 66 - 61
celery/tests/concurrency/test_workhorse.py

@@ -6,6 +6,7 @@ import signal
 import socket
 import time
 from itertools import cycle
+from contextlib import closing
 
 from celery.concurrency.workhorse import TaskPool
 from celery.five import items
@@ -26,76 +27,80 @@ class test_Workhorse(AppCase):
 
     def test_run_stop_pool(self):
         pool = TaskPool(semaphore=LaxBoundedSemaphore(10))
-        pool.start()
+        with closing(pool):
+            pool.start()
 
-        pids = []
-        accept_callback = lambda pid, ts: pids.append(pid)
-        success_callback = Mock()
-        error_callback = Mock()
-        pool.apply_async(
-            lambda x: x,
-            (2, ),
-            {},
-            accept_callback=accept_callback,
-            correlation_id='asdf-1234',
-            error_callback=error_callback,
-            callback=success_callback,
-        )
-        self.assertTrue(pool.workers)
-        self.assertEqual(pids, list(pool.workers))
-        pool.stop()
-        success_callback.assert_called_with(None)
-        self.assertFalse(error_callback.called)
+
+            pids = []
+            accept_callback = lambda pid, ts: pids.append(pid)
+            success_callback = Mock()
+            error_callback = Mock()
+            pool.apply_async(
+                lambda x: x,
+                (2, ),
+                {},
+                accept_callback=accept_callback,
+                correlation_id='asdf-1234',
+                error_callback=error_callback,
+                callback=success_callback,
+            )
+            self.assertTrue(pool.workers)
+            self.assertEqual(pids, list(pool.workers))
+            pool.stop()
+            success_callback.assert_called_with(None)
+            self.assertFalse(error_callback.called)
 
     def test_terminate_pool(self):
         pool = TaskPool(semaphore=LaxBoundedSemaphore(10))
-        pool.start()
-        pids = []
-        accept_callback = lambda pid, ts: pids.append(pid)
-        success_callback = Mock()
-        error_callback = Mock()
-        pool.apply_async(
-            lambda x: time.sleep(x),
-            (5, ),
-            {},
-            accept_callback=accept_callback,
-            correlation_id='asdf-1234',
-            error_callback=error_callback,
-            callback=success_callback,
-        )
-        self.assertTrue(pool.workers)
-        self.assertEqual(len(pids), 1)
-        self.assertEqual(pids, list(pool.workers))
-        pool.terminate()
-        self.assertFalse(success_callback.called)
-        self.assertTrue(error_callback.called)
-
+        with closing(pool):
+            pool.start()
+            pids = []
+            accept_callback = lambda pid, ts: pids.append(pid)
+            success_callback = Mock()
+            error_callback = Mock()
+            pool.apply_async(
+                lambda x: time.sleep(x),
+                (2, ),
+                {},
+                accept_callback=accept_callback,
+                correlation_id='asdf-1234',
+                error_callback=error_callback,
+                callback=success_callback,
+            )
+            self.assertTrue(pool.workers)
+            self.assertEqual(len(pids), 1)
+            self.assertEqual(pids, list(pool.workers))
+            pool.terminate()
+            self.assertFalse(success_callback.called)
+            pool.stop()
+            self.assertTrue(error_callback.called)
 
     def test_release_sem(self):
         semaphore = LaxBoundedSemaphore(1)
         pool = TaskPool(semaphore=semaphore)
-        pool.start()
-        pids = []
-        accept_callback = lambda pid, ts: pids.append(pid)
-        success_callback = Mock()
-        error_callback = Mock()
-        for i in range(3):
-            semaphore.acquire(
-                lambda args, kwargs: pool.apply_async(*args, **kwargs),
-                (lambda x: time.sleep(x), (2, ), {}),
-                dict(
-                    accept_callback=accept_callback,
-                    correlation_id='asdf-1234-%s' % i,
-                    error_callback=error_callback,
-                    callback=success_callback,
+        with closing(pool):
+            pool.start()
+            pids = []
+            accept_callback = lambda pid, ts: pids.append(pid)
+            success_callback = Mock()
+            error_callback = Mock()
+            for i in range(3):
+                semaphore.acquire(
+                    lambda args, kwargs: pool.apply_async(*args, **kwargs),
+                    (lambda x: time.sleep(x), (2, ), {}),
+                    dict(
+                        accept_callback=accept_callback,
+                        correlation_id='asdf-1234-%s' % i,
+                        error_callback=error_callback,
+                        callback=success_callback,
+                    )
                 )
-            )
-        self.assertEqual(semaphore.value, 0)
-        self.assertTrue(pool.workers)
-        self.assertEqual(pids, list(pool.workers))
-        self.assertEqual(len(pids), 1)
-        pool.terminate_job(pids[0], signal.SIGTERM)
-        pool.terminate()
+            self.assertEqual(semaphore.value, 0)
+            self.assertTrue(pool.workers)
+            self.assertEqual(pids, list(pool.workers))
+            self.assertEqual(len(pids), 1)
+            pool.terminate_job(pids[0], signal.SIGTERM)
+            pool.terminate()
         # TODO TODO TODO TODO TODO TODO
         #pool.grow()
         #self.assertFalse(success_callback.called)