test_serialization.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import base64
  2. import os
  3. import pytest
  4. from kombu.serialization import registry
  5. from kombu.utils.encoding import bytes_to_str
  6. from celery.exceptions import SecurityError
  7. from celery.security.serialization import SecureSerializer, register_auth
  8. from celery.security.certificate import Certificate, CertStore
  9. from celery.security.key import PrivateKey
  10. from . import CERT1, CERT2, KEY1, KEY2
  11. from .case import SecurityCase
  12. class test_SecureSerializer(SecurityCase):
  13. def _get_s(self, key, cert, certs):
  14. store = CertStore()
  15. for c in certs:
  16. store.add_cert(Certificate(c))
  17. return SecureSerializer(PrivateKey(key), Certificate(cert), store)
  18. def test_serialize(self):
  19. s = self._get_s(KEY1, CERT1, [CERT1])
  20. assert s.deserialize(s.serialize('foo')) == 'foo'
  21. def test_deserialize(self):
  22. s = self._get_s(KEY1, CERT1, [CERT1])
  23. with pytest.raises(SecurityError):
  24. s.deserialize('bad data')
  25. def test_unmatched_key_cert(self):
  26. s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
  27. with pytest.raises(SecurityError):
  28. s.deserialize(s.serialize('foo'))
  29. def test_unknown_source(self):
  30. s1 = self._get_s(KEY1, CERT1, [CERT2])
  31. s2 = self._get_s(KEY1, CERT1, [])
  32. with pytest.raises(SecurityError):
  33. s1.deserialize(s1.serialize('foo'))
  34. with pytest.raises(SecurityError):
  35. s2.deserialize(s2.serialize('foo'))
  36. def test_self_send(self):
  37. s1 = self._get_s(KEY1, CERT1, [CERT1])
  38. s2 = self._get_s(KEY1, CERT1, [CERT1])
  39. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  40. def test_separate_ends(self):
  41. s1 = self._get_s(KEY1, CERT1, [CERT2])
  42. s2 = self._get_s(KEY2, CERT2, [CERT1])
  43. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  44. def test_register_auth(self):
  45. register_auth(KEY1, CERT1, '')
  46. assert 'application/data' in registry._decoders
  47. def test_lots_of_sign(self):
  48. for i in range(1000):
  49. rdata = bytes_to_str(base64.urlsafe_b64encode(os.urandom(265)))
  50. s = self._get_s(KEY1, CERT1, [CERT1])
  51. assert s.deserialize(s.serialize(rdata)) == rdata