ソースを参照

100% coverage for .security.certificate

Ask Solem 13 年 前
コミット
539a10687b

+ 1 - 1
celery/tests/test_security/__init__.py

@@ -85,7 +85,7 @@ WWZybzzDZFncq1/N1C3Y/hrCBNDFO4TsnTLAhWtZ4c0vDAiacw==
 -----END CERTIFICATE-----"""
 
 
-class TestSecurity(SecurityCase):
+class test_security(SecurityCase):
 
     def tearDown(self):
         registry._disabled_content_types.clear()

+ 46 - 3
celery/tests/test_security/test_certificate.py

@@ -1,13 +1,15 @@
 from __future__ import absolute_import
 
 from celery.exceptions import SecurityError
-from celery.security.certificate import Certificate, CertStore
+from celery.security.certificate import Certificate, CertStore, FSCertStore
+
+from mock import Mock, patch
 
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 
 
-class TestCertificate(SecurityCase):
+class test_Certificate(SecurityCase):
 
     def test_valid_certificate(self):
         Certificate(CERT1)
@@ -24,7 +26,7 @@ class TestCertificate(SecurityCase):
         self.assertFalse(Certificate(CERT1).has_expired())
 
 
-class TestCertStore(SecurityCase):
+class test_CertStore(SecurityCase):
 
     def test_itercerts(self):
         cert1 = Certificate(CERT1)
@@ -42,3 +44,44 @@ class TestCertStore(SecurityCase):
         certstore = CertStore()
         certstore.add_cert(cert1)
         self.assertRaises(SecurityError, certstore.add_cert, cert1)
+
+
+class test_FSCertStore(SecurityCase):
+
+    @patch("os.path.isdir")
+    @patch("glob.glob")
+    @patch("celery.security.certificate.Certificate")
+    @patch("__builtin__.open")
+    def test_init(self, open_, Certificate, glob, isdir):
+        cert = Certificate.return_value = Mock()
+        cert.has_expired.return_value = False
+        isdir.return_value = True
+        glob.return_value = ["foo.cert"]
+        op = open_.return_value = Mock()
+        op.__enter__ = Mock()
+        def on_exit(*x):
+            if x[0]:
+                print(x)
+                raise x[0], x[1], x[2]
+        op.__exit__ = Mock()
+        op.__exit__.side_effect = on_exit
+        cert.get_id.return_value = 1
+        x = FSCertStore("/var/certs")
+        self.assertIn(1, x._certs)
+        glob.assert_called_with("/var/certs/*")
+        op.__enter__.assert_called_with()
+        op.__exit__.assert_called_with(None, None, None)
+
+        # they both end up with the same id
+        glob.return_value = ["foo.cert", "bar.cert"]
+        with self.assertRaises(SecurityError):
+            x = FSCertStore("/var/certs")
+        glob.return_value = ["foo.cert"]
+
+        cert.has_expired.return_value = True
+        with self.assertRaises(SecurityError):
+            x = FSCertStore("/var/certs")
+
+        isdir.return_value = False
+        with self.assertRaises(SecurityError):
+            x = FSCertStore("/var/certs")

+ 1 - 1
celery/tests/test_security/test_key.py

@@ -7,7 +7,7 @@ from . import CERT1, KEY1, KEY2
 from .case import SecurityCase
 
 
-class TestKey(SecurityCase):
+class test_PrivateKey(SecurityCase):
 
     def test_valid_private_key(self):
         PrivateKey(KEY1)

+ 1 - 1
celery/tests/test_security/test_serialization.py

@@ -11,7 +11,7 @@ from . import CERT1, CERT2, KEY1, KEY2
 from .case import SecurityCase
 
 
-class TestSecureSerializer(SecurityCase):
+class test_SecureSerializer(SecurityCase):
 
     def _get_s(self, key, cert, certs):
         store = CertStore()