Browse Source

Always use with form of assertRaises

Ask Solem 9 years ago
parent
commit
eb69ffa900

+ 12 - 6
celery/tests/security/test_certificate.py

@@ -16,11 +16,16 @@ class test_Certificate(SecurityCase):
         Certificate(CERT2)
 
     def test_invalid_certificate(self):
-        self.assertRaises((SecurityError, TypeError), Certificate, None)
-        self.assertRaises(SecurityError, Certificate, '')
-        self.assertRaises(SecurityError, Certificate, 'foo')
-        self.assertRaises(SecurityError, Certificate, CERT1[:20] + CERT1[21:])
-        self.assertRaises(SecurityError, Certificate, KEY1)
+        with self.assertRaises((SecurityError, TypeError)):
+            Certificate(None)
+        with self.assertRaises(SecurityError):
+            Certificate('')
+        with self.assertRaises(SecurityError):
+            Certificate('foo')
+        with self.assertRaises(SecurityError):
+            Certificate(CERT1[:20] + CERT1[21:])
+        with self.assertRaises(SecurityError):
+            Certificate(KEY1)
 
     def test_has_expired(self):
         raise SkipTest('cert expired')
@@ -49,7 +54,8 @@ class test_CertStore(SecurityCase):
         cert1 = Certificate(CERT1)
         certstore = CertStore()
         certstore.add_cert(cert1)
-        self.assertRaises(SecurityError, certstore.add_cert, cert1)
+        with self.assertRaises(SecurityError):
+            certstore.add_cert(cert1)
 
 
 class test_FSCertStore(SecurityCase):

+ 10 - 5
celery/tests/security/test_key.py

@@ -14,11 +14,16 @@ class test_PrivateKey(SecurityCase):
         PrivateKey(KEY2)
 
     def test_invalid_private_key(self):
-        self.assertRaises((SecurityError, TypeError), PrivateKey, None)
-        self.assertRaises(SecurityError, PrivateKey, '')
-        self.assertRaises(SecurityError, PrivateKey, 'foo')
-        self.assertRaises(SecurityError, PrivateKey, KEY1[:20] + KEY1[21:])
-        self.assertRaises(SecurityError, PrivateKey, CERT1)
+        with self.assertRaises((SecurityError, TypeError)):
+                PrivateKey(None)
+        with self.assertRaises(SecurityError):
+            PrivateKey('')
+        with self.assertRaises(SecurityError):
+            PrivateKey('foo')
+        with self.assertRaises(SecurityError):
+            PrivateKey(KEY1[:20] + KEY1[21:])
+        with self.assertRaises(SecurityError):
+            PrivateKey(CERT1)
 
     def test_sign(self):
         pkey = PrivateKey(KEY1)

+ 8 - 7
celery/tests/security/test_serialization.py

@@ -29,20 +29,21 @@ class test_SecureSerializer(SecurityCase):
 
     def test_deserialize(self):
         s = self._get_s(KEY1, CERT1, [CERT1])
-        self.assertRaises(SecurityError, s.deserialize, 'bad data')
+        with self.assertRaises(SecurityError):
+            s.deserialize('bad data')
 
     def test_unmatched_key_cert(self):
         s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
-        self.assertRaises(SecurityError,
-                          s.deserialize, s.serialize('foo'))
+        with self.assertRaises(SecurityError):
+            s.deserialize(s.serialize('foo'))
 
     def test_unknown_source(self):
         s1 = self._get_s(KEY1, CERT1, [CERT2])
         s2 = self._get_s(KEY1, CERT1, [])
-        self.assertRaises(SecurityError,
-                          s1.deserialize, s1.serialize('foo'))
-        self.assertRaises(SecurityError,
-                          s2.deserialize, s2.serialize('foo'))
+        with self.assertRaises(SecurityError):
+            s1.deserialize(s1.serialize('foo'))
+        with self.assertRaises(SecurityError):
+            s2.deserialize(s2.serialize('foo'))
 
     def test_self_send(self):
         s1 = self._get_s(KEY1, CERT1, [CERT1])

+ 2 - 1
celery/tests/worker/test_worker.py

@@ -355,7 +355,8 @@ class test_Consumer(AppCase):
         l.pool = l.controller.pool = Mock()
 
         l.connection_errors = (KeyError,)
-        self.assertRaises(SyntaxError, l.start)
+        with self.assertRaises(SyntaxError):
+            l.start()
         l.timer.stop()
 
     def test_loop_ignores_socket_timeout(self):