Selaa lähdekoodia

Fixes Rdb tests not resetting sys.stdout

Ask Solem 9 vuotta sitten
vanhempi
commit
9919825837
3 muutettua tiedostoa jossa 58 lisäystä ja 43 poistoa
  1. 15 5
      celery/contrib/rdb.py
  2. 4 4
      celery/tests/case.py
  3. 39 34
      celery/tests/contrib/test_rdb.py

+ 15 - 5
celery/contrib/rdb.py

@@ -132,13 +132,23 @@ class Rdb(Pdb):
     def say(self, m):
         print(m, file=self.out)
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        self._close_session()
+
     def _close_session(self):
         self.stdin, self.stdout = sys.stdin, sys.stdout = self._prev_handles
-        self._handle.close()
-        self._client.close()
-        self._sock.close()
-        self.active = False
-        self.say(SESSION_ENDED.format(self=self))
+        if self.active:
+            if self._handle is not None:
+                self._handle.close()
+            if self._client is not None:
+                self._client.close()
+            if self._sock is not None:
+                self._sock.close()
+            self.active = False
+            self.say(SESSION_ENDED.format(self=self))
 
     def do_continue(self, arg):
         self._close_session()

+ 4 - 4
celery/tests/case.py

@@ -452,11 +452,11 @@ class AppCase(Case):
         assert sys.__stdout__
         assert sys.__stderr__
         this = self._get_test_name()
-        if isinstance(sys.stdout, LoggingProxy) or \
-                isinstance(sys.__stdout__, LoggingProxy):
+        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
             raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
-        if isinstance(sys.stderr, LoggingProxy) or \
-                isinstance(sys.__stderr__, LoggingProxy):
+        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
             raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
         backend = self.app.__dict__.get('backend')
         if backend is not None:

+ 39 - 34
celery/tests/contrib/test_rdb.py

@@ -8,14 +8,14 @@ from celery.contrib.rdb import (
     debugger,
     set_trace,
 )
-from celery.tests.case import Case, Mock, WhateverIO, patch, skip_if_pypy
+from celery.tests.case import AppCase, Mock, WhateverIO, patch, skip_if_pypy
 
 
 class SockErr(socket.error):
     errno = None
 
 
-class test_Rdb(Case):
+class test_Rdb(AppCase):
 
     @patch('celery.contrib.rdb.Rdb')
     def test_debugger(self, Rdb):
@@ -37,56 +37,60 @@ class test_Rdb(Case):
         get_avail_port.return_value = (sock, 8000)
         sock.accept.return_value = (Mock(), ['helu'])
         out = WhateverIO()
-        rdb = Rdb(out=out)
-        self.assertTrue(get_avail_port.called)
-        self.assertIn('helu', out.getvalue())
-
-        # set_quit
-        with patch('sys.settrace') as settrace:
-            rdb.set_quit()
-            settrace.assert_called_with(None)
-
-        # set_trace
-        with patch('celery.contrib.rdb.Pdb.set_trace') as pset:
-            with patch('celery.contrib.rdb._frame'):
-                rdb.set_trace()
-                rdb.set_trace(Mock())
-                pset.side_effect = SockErr
-                pset.side_effect.errno = errno.ENOENT
-                with self.assertRaises(SockErr):
+        with Rdb(out=out) as rdb:
+            self.assertTrue(get_avail_port.called)
+            self.assertIn('helu', out.getvalue())
+
+            # set_quit
+            with patch('sys.settrace') as settrace:
+                rdb.set_quit()
+                settrace.assert_called_with(None)
+
+            # set_trace
+            with patch('celery.contrib.rdb.Pdb.set_trace') as pset:
+                with patch('celery.contrib.rdb._frame'):
                     rdb.set_trace()
+                    rdb.set_trace(Mock())
+                    pset.side_effect = SockErr
+                    pset.side_effect.errno = errno.ENOENT
+                    with self.assertRaises(SockErr):
+                        rdb.set_trace()
 
-        # _close_session
-        rdb._close_session()
+            # _close_session
+            rdb._close_session()
 
-        # do_continue
-        rdb.set_continue = Mock()
-        rdb.do_continue(Mock())
-        rdb.set_continue.assert_called_with()
+            # do_continue
+            rdb.set_continue = Mock()
+            rdb.do_continue(Mock())
+            rdb.set_continue.assert_called_with()
 
-        # do_quit
-        rdb.set_quit = Mock()
-        rdb.do_quit(Mock())
-        rdb.set_quit.assert_called_with()
+            # do_quit
+            rdb.set_quit = Mock()
+            rdb.do_quit(Mock())
+            rdb.set_quit.assert_called_with()
 
     @patch('socket.socket')
     @skip_if_pypy
     def test_get_avail_port(self, sock):
         out = WhateverIO()
         sock.return_value.accept.return_value = (Mock(), ['helu'])
-        Rdb(out=out)
+        with Rdb(out=out) as rdb:
+            pass
 
         with patch('celery.contrib.rdb.current_process') as curproc:
             curproc.return_value.name = 'PoolWorker-10'
-            Rdb(out=out)
+            with Rdb(out=out) as rdb:
+                pass
 
         err = sock.return_value.bind.side_effect = SockErr()
         err.errno = errno.ENOENT
         with self.assertRaises(SockErr):
-            Rdb(out=out)
+            with Rdb(out=out) as rdb:
+                pass
         err.errno = errno.EADDRINUSE
         with self.assertRaises(Exception):
-            Rdb(out=out)
+            with Rdb(out=out) as rdb:
+                pass
         called = [0]
 
         def effect(*a, **kw):
@@ -97,4 +101,5 @@ class test_Rdb(Case):
             finally:
                 called[0] += 1
         sock.return_value.bind.side_effect = effect
-        Rdb(out=out)
+        with Rdb(out=out) as rdb:
+            pass