|
@@ -19,9 +19,10 @@ from __future__ import absolute_import
|
|
|
from mock import Mock, patch
|
|
|
|
|
|
from celery import current_app
|
|
|
-from celery.exceptions import ImproperlyConfigured
|
|
|
+from celery.exceptions import ImproperlyConfigured, SecurityError
|
|
|
from celery.five import builtins
|
|
|
from celery.security import setup_security, disable_untrusted_serializers
|
|
|
+from celery.security.utils import reraise_errors
|
|
|
from kombu.serialization import registry
|
|
|
|
|
|
from .case import SecurityCase
|
|
@@ -74,10 +75,12 @@ class test_security(SecurityCase):
|
|
|
calls[0] += 1
|
|
|
|
|
|
with mock_open(side_effect=effect):
|
|
|
- store = Mock()
|
|
|
- setup_security(['json'], key, cert, store)
|
|
|
- dis.assert_called_with(['json'])
|
|
|
- reg.assert_called_with('A', 'B', store, 'sha1', 'json')
|
|
|
+ with patch('celery.security.registry') as registry:
|
|
|
+ store = Mock()
|
|
|
+ setup_security(['json'], key, cert, store)
|
|
|
+ dis.assert_called_with(['json'])
|
|
|
+ reg.assert_called_with('A', 'B', store, 'sha1', 'json')
|
|
|
+ registry._set_default_serializer.assert_called_with('auth')
|
|
|
|
|
|
def test_security_conf(self):
|
|
|
current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
|
|
@@ -94,3 +97,11 @@ class test_security(SecurityCase):
|
|
|
builtins.__import__ = import_hook
|
|
|
self.assertRaises(ImproperlyConfigured, setup_security)
|
|
|
builtins.__import__ = _import
|
|
|
+
|
|
|
+ def test_reraise_errors(self):
|
|
|
+ with self.assertRaises(SecurityError):
|
|
|
+ with reraise_errors(errors=(KeyError, )):
|
|
|
+ raise KeyError('foo')
|
|
|
+ with self.assertRaises(KeyError):
|
|
|
+ with reraise_errors(errors=(ValueError, )):
|
|
|
+ raise KeyError('bar')
|