test_serialization.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from __future__ import absolute_import, unicode_literals
  2. import base64
  3. import os
  4. import pytest
  5. from kombu.serialization import registry
  6. from kombu.utils.encoding import bytes_to_str
  7. from celery.exceptions import SecurityError
  8. from celery.security.certificate import Certificate, CertStore
  9. from celery.security.key import PrivateKey
  10. from celery.security.serialization import SecureSerializer, register_auth
  11. from . import CERT1, CERT2, KEY1, KEY2
  12. from .case import SecurityCase
  13. class test_SecureSerializer(SecurityCase):
  14. def _get_s(self, key, cert, certs):
  15. store = CertStore()
  16. for c in certs:
  17. store.add_cert(Certificate(c))
  18. return SecureSerializer(PrivateKey(key), Certificate(cert), store)
  19. def test_serialize(self):
  20. s = self._get_s(KEY1, CERT1, [CERT1])
  21. assert s.deserialize(s.serialize('foo')) == 'foo'
  22. def test_deserialize(self):
  23. s = self._get_s(KEY1, CERT1, [CERT1])
  24. with pytest.raises(SecurityError):
  25. s.deserialize('bad data')
  26. def test_unmatched_key_cert(self):
  27. s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
  28. with pytest.raises(SecurityError):
  29. s.deserialize(s.serialize('foo'))
  30. def test_unknown_source(self):
  31. s1 = self._get_s(KEY1, CERT1, [CERT2])
  32. s2 = self._get_s(KEY1, CERT1, [])
  33. with pytest.raises(SecurityError):
  34. s1.deserialize(s1.serialize('foo'))
  35. with pytest.raises(SecurityError):
  36. s2.deserialize(s2.serialize('foo'))
  37. def test_self_send(self):
  38. s1 = self._get_s(KEY1, CERT1, [CERT1])
  39. s2 = self._get_s(KEY1, CERT1, [CERT1])
  40. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  41. def test_separate_ends(self):
  42. s1 = self._get_s(KEY1, CERT1, [CERT2])
  43. s2 = self._get_s(KEY2, CERT2, [CERT1])
  44. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  45. def test_register_auth(self):
  46. register_auth(KEY1, CERT1, '')
  47. assert 'application/data' in registry._decoders
  48. def test_lots_of_sign(self):
  49. for i in range(1000):
  50. rdata = bytes_to_str(base64.urlsafe_b64encode(os.urandom(265)))
  51. s = self._get_s(KEY1, CERT1, [CERT1])
  52. assert s.deserialize(s.serialize(rdata)) == rdata