Browse Source

Adds tests for the security package

Mher Movsisyan 13 years ago
parent
commit
e33fff606d

+ 59 - 1
celery/tests/test_security/__init__.py

@@ -1,4 +1,3 @@
-from __future__ import absolute_import
 """
 Keys and certificates for tests (KEY1 is a private key of CERT1, etc.)
 
@@ -13,6 +12,17 @@ Generated with::
     $ rm key1.key.org cert1.csr
 
 """
+from __future__ import absolute_import
+
+import __builtin__
+
+from celery import current_app
+from celery.exceptions import ImproperlyConfigured
+from celery.security import setup_security, disable_untrusted_serializers
+from kombu.serialization import registry
+
+from .case import SecurityCase
+
 
 KEY1 = """-----BEGIN RSA PRIVATE KEY-----
 MIICXgIBAAKBgQDCsmLC+eqL4z6bhtv0nzbcnNXuQrZUoh827jGfDI3kxNZ2LbEy
@@ -73,3 +83,51 @@ AAOBgQBzaZ5vBkzksPhnWb2oobuy6Ne/LMEtdQ//qeVY4sKl2tOJUCSdWRen9fqP
 e+zYdEdkFCd8rp568Eiwkq/553uy4rlE927/AEqs/+KGYmAtibk/9vmi+/+iZXyS
 WWZybzzDZFncq1/N1C3Y/hrCBNDFO4TsnTLAhWtZ4c0vDAiacw==
 -----END CERTIFICATE-----"""
+
+
+class TestSecurity(SecurityCase):
+
+    def tearDown(self):
+        registry._disabled_content_types.clear()
+
+    def test_disable_untrusted_serializers(self):
+        disabled = registry._disabled_content_types
+        self.assertEqual(0, len(disabled))
+
+        disable_untrusted_serializers(
+                ['application/json', 'application/x-python-serialize'])
+        self.assertIn('application/x-yaml', disabled)
+        self.assertNotIn('application/json', disabled)
+        self.assertNotIn('application/x-python-serialize', disabled)
+        disabled.clear()
+
+        disable_untrusted_serializers()
+        self.assertIn('application/x-yaml', disabled)
+        self.assertIn('application/json', disabled)
+        self.assertIn('application/x-python-serialize', disabled)
+
+    def test_setup_security(self):
+        disabled = registry._disabled_content_types
+        self.assertEqual(0, len(disabled))
+
+        current_app.conf.CELERY_TASK_SERIALIZER = 'json'
+
+        setup_security()
+        self.assertIn('application/x-python-serialize', disabled)
+        disabled.clear()
+
+    def test_security_conf(self):
+        current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
+
+        self.assertRaises(ImproperlyConfigured, setup_security)
+
+        _import = __builtin__.__import__
+
+        def import_hook(name, *args, **kwargs):
+            if name == 'OpenSSL':
+                raise ImportError
+            return _import(name, *args, **kwargs)
+
+        __builtin__.__import__ = import_hook
+        self.assertRaises(ImproperlyConfigured, setup_security)
+        __builtin__.__import__ = _import

+ 3 - 0
celery/tests/test_security/test_certificate.py

@@ -20,6 +20,9 @@ class TestCertificate(SecurityCase):
         self.assertRaises(SecurityError, Certificate, CERT1[:20] + CERT1[21:])
         self.assertRaises(SecurityError, Certificate, KEY1)
 
+    def test_has_expired(self):
+        self.assertFalse(Certificate(CERT1).has_expired())
+
 
 class TestCertStore(SecurityCase):
 

+ 5 - 0
celery/tests/test_security/test_key.py

@@ -19,3 +19,8 @@ class TestKey(SecurityCase):
         self.assertRaises(SecurityError, PrivateKey, "foo")
         self.assertRaises(SecurityError, PrivateKey, KEY1[:20] + KEY1[21:])
         self.assertRaises(SecurityError, PrivateKey, CERT1)
+
+    def test_sign(self):
+        pkey = PrivateKey(KEY1)
+        pkey.sign('test', 'sha1')
+        self.assertRaises(ValueError, pkey.sign, 'test', 'unknown')

+ 6 - 1
celery/tests/test_security/test_serialization.py

@@ -2,9 +2,10 @@ from __future__ import absolute_import
 
 from celery.exceptions import SecurityError
 
-from celery.security.serialization import SecureSerializer
+from celery.security.serialization import SecureSerializer, register_auth
 from celery.security.certificate import Certificate, CertStore
 from celery.security.key import PrivateKey
+from kombu.serialization import registry
 
 from . import CERT1, CERT2, KEY1, KEY2
 from .case import SecurityCase
@@ -48,3 +49,7 @@ class TestSecureSerializer(SecurityCase):
         s1 = self._get_s(KEY1, CERT1, [CERT2])
         s2 = self._get_s(KEY2, CERT2, [CERT1])
         self.assertEqual(s2.deserialize(s1.serialize("foo")), "foo")
+
+    def test_register_auth(self):
+        register_auth(KEY1, CERT1, '')
+        self.assertIn('application/data', registry._decoders)