Browse Source

100% coverage for .security

Ask Solem 13 years ago
parent
commit
8e7646c68e

+ 1 - 1
celery/security/__init__.py

@@ -71,7 +71,7 @@ def setup_security(allowed_serializers=None, key=None, cert=None, store=None,
     cert = cert or conf.CELERY_SECURITY_CERTIFICATE
     store = store or conf.CELERY_SECURITY_CERT_STORE
 
-    if any(not v for v in (key, cert, store)):
+    if not (key and cert and store):
         raise ImproperlyConfigured(SETTING_MISSING)
 
     with open(key) as kf:

+ 24 - 0
celery/tests/test_security/__init__.py

@@ -13,9 +13,12 @@ Generated with::
 
 """
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import __builtin__
 
+from mock import Mock, patch
+
 from celery import current_app
 from celery.exceptions import ImproperlyConfigured
 from celery.security import setup_security, disable_untrusted_serializers
@@ -23,6 +26,8 @@ from kombu.serialization import registry
 
 from .case import SecurityCase
 
+from celery.tests.utils import mock_open
+
 
 KEY1 = """-----BEGIN RSA PRIVATE KEY-----
 MIICXgIBAAKBgQDCsmLC+eqL4z6bhtv0nzbcnNXuQrZUoh827jGfDI3kxNZ2LbEy
@@ -116,6 +121,25 @@ class test_security(SecurityCase):
         self.assertIn('application/x-python-serialize', disabled)
         disabled.clear()
 
+    @patch("celery.security.register_auth")
+    @patch("celery.security.disable_untrusted_serializers")
+    def test_setup_registry_complete(self, dis, reg, key="KEY", cert="CERT"):
+        calls = [0]
+        def effect(*args):
+            try:
+                m = Mock()
+                m.read.return_value = "B" if calls[0] else "A"
+                return m
+            finally:
+                calls[0] += 1
+
+        with mock_open(side_effect=effect):
+            store = Mock()
+            setup_security(["json"], key, cert, store)
+            dis.assert_called_with(["json"])
+            reg.assert_called_with("A", "B", store)
+
+
     def test_security_conf(self):
         current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
 

+ 21 - 29
celery/tests/test_security/test_certificate.py

@@ -8,6 +8,8 @@ from mock import Mock, patch
 from . import CERT1, CERT2, KEY1
 from .case import SecurityCase
 
+from celery.tests.utils import mock_open
+
 
 class test_Certificate(SecurityCase):
 
@@ -51,37 +53,27 @@ class test_FSCertStore(SecurityCase):
     @patch("os.path.isdir")
     @patch("glob.glob")
     @patch("celery.security.certificate.Certificate")
-    @patch("__builtin__.open")
-    def test_init(self, open_, Certificate, glob, isdir):
+    def test_init(self, Certificate, glob, isdir):
         cert = Certificate.return_value = Mock()
         cert.has_expired.return_value = False
         isdir.return_value = True
         glob.return_value = ["foo.cert"]
-        op = open_.return_value = Mock()
-        op.__enter__ = Mock()
-        def on_exit(*x):
-            if x[0]:
-                print(x)
-                raise x[0], x[1], x[2]
-        op.__exit__ = Mock()
-        op.__exit__.side_effect = on_exit
-        cert.get_id.return_value = 1
-        x = FSCertStore("/var/certs")
-        self.assertIn(1, x._certs)
-        glob.assert_called_with("/var/certs/*")
-        op.__enter__.assert_called_with()
-        op.__exit__.assert_called_with(None, None, None)
-
-        # they both end up with the same id
-        glob.return_value = ["foo.cert", "bar.cert"]
-        with self.assertRaises(SecurityError):
-            x = FSCertStore("/var/certs")
-        glob.return_value = ["foo.cert"]
-
-        cert.has_expired.return_value = True
-        with self.assertRaises(SecurityError):
-            x = FSCertStore("/var/certs")
-
-        isdir.return_value = False
-        with self.assertRaises(SecurityError):
+        with mock_open():
+            cert.get_id.return_value = 1
             x = FSCertStore("/var/certs")
+            self.assertIn(1, x._certs)
+            glob.assert_called_with("/var/certs/*")
+
+            # they both end up with the same id
+            glob.return_value = ["foo.cert", "bar.cert"]
+            with self.assertRaises(SecurityError):
+                x = FSCertStore("/var/certs")
+            glob.return_value = ["foo.cert"]
+
+            cert.has_expired.return_value = True
+            with self.assertRaises(SecurityError):
+                x = FSCertStore("/var/certs")
+
+            isdir.return_value = False
+            with self.assertRaises(SecurityError):
+                x = FSCertStore("/var/certs")

+ 1 - 1
celery/tests/test_task/test_result.py

@@ -350,7 +350,7 @@ class test_pending_AsyncResult(AppCase):
         self.assertIsNone(self.task.result)
 
 
-class test_failed_AsyncResult(TestTaskSetResult):
+class test_failed_AsyncResult(test_TaskSetResult):
 
     def setup(self):
         self.size = 11

+ 10 - 14
celery/tests/test_worker/test_worker_autoreload.py

@@ -20,7 +20,7 @@ from celery.worker.autoreload import (
     Autoreloader,
 )
 
-from celery.tests.utils import AppCase, Case, WhateverIO
+from celery.tests.utils import AppCase, Case, WhateverIO, mock_open
 
 
 class test_WorkerComponent(AppCase):
@@ -37,19 +37,15 @@ class test_WorkerComponent(AppCase):
 
 class test_file_hash(Case):
 
-    @patch("__builtin__.open")
-    def test_hash(self, open_):
-        context = open_.return_value = Mock()
-        context.__enter__ = Mock()
-        context.__exit__ = Mock()
-        a = context.__enter__.return_value = WhateverIO()
-        a.write("the quick brown fox\n")
-        a.seek(0)
-        A = file_hash("foo")
-        b = context.__enter__.return_value = WhateverIO()
-        b.write("the quick brown bar\n")
-        b.seek(0)
-        B = file_hash("bar")
+    def test_hash(self):
+        with mock_open() as a:
+            a.write("the quick brown fox\n")
+            a.seek(0)
+            A = file_hash("foo")
+        with mock_open() as b:
+            b.write("the quick brown bar\n")
+            b.seek(0)
+            B = file_hash("bar")
         self.assertNotEqual(A, B)
 
 

+ 26 - 0
celery/tests/utils.py

@@ -498,3 +498,29 @@ def mock_module(*names):
     for name in names:
         if prev[name]:
             sys.modules[name] = prev[name]
+
+
+@contextmanager
+def mock_context(mock, typ=Mock):
+    context = mock.return_value = Mock()
+    context.__enter__ = typ()
+    context.__exit__ = typ()
+
+    def on_exit(*x):
+        if x[0]:
+            raise x[0], x[1], x[2]
+    context.__exit__.side_effect = on_exit
+    context.__enter__.return_value = context
+    yield context
+    context.reset()
+
+
+@contextmanager
+def mock_open(typ=WhateverIO, side_effect=None):
+    with mock.patch("__builtin__.open") as open_:
+        with mock_context(open_) as context:
+            if side_effect is not None:
+                context.__enter__.side_effect = side_effect
+            val = context.__enter__.return_value = typ()
+            yield val
+