test_serialization.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import absolute_import, unicode_literals
  2. import base64
  3. import os
  4. import pytest
  5. from celery.exceptions import SecurityError
  6. from celery.security.certificate import Certificate, CertStore
  7. from celery.security.key import PrivateKey
  8. from celery.security.serialization import SecureSerializer, register_auth
  9. from kombu.serialization import registry
  10. from kombu.utils.encoding import bytes_to_str
  11. from . import CERT1, CERT2, KEY1, KEY2
  12. from .case import SecurityCase
  13. class test_SecureSerializer(SecurityCase):
  14. def _get_s(self, key, cert, certs):
  15. store = CertStore()
  16. for c in certs:
  17. store.add_cert(Certificate(c))
  18. return SecureSerializer(PrivateKey(key), Certificate(cert), store)
  19. def test_serialize(self):
  20. s = self._get_s(KEY1, CERT1, [CERT1])
  21. assert s.deserialize(s.serialize('foo')) == 'foo'
  22. def test_deserialize(self):
  23. s = self._get_s(KEY1, CERT1, [CERT1])
  24. with pytest.raises(SecurityError):
  25. s.deserialize('bad data')
  26. def test_unmatched_key_cert(self):
  27. s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
  28. with pytest.raises(SecurityError):
  29. s.deserialize(s.serialize('foo'))
  30. def test_unknown_source(self):
  31. s1 = self._get_s(KEY1, CERT1, [CERT2])
  32. s2 = self._get_s(KEY1, CERT1, [])
  33. with pytest.raises(SecurityError):
  34. s1.deserialize(s1.serialize('foo'))
  35. with pytest.raises(SecurityError):
  36. s2.deserialize(s2.serialize('foo'))
  37. def test_self_send(self):
  38. s1 = self._get_s(KEY1, CERT1, [CERT1])
  39. s2 = self._get_s(KEY1, CERT1, [CERT1])
  40. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  41. def test_separate_ends(self):
  42. s1 = self._get_s(KEY1, CERT1, [CERT2])
  43. s2 = self._get_s(KEY2, CERT2, [CERT1])
  44. assert s2.deserialize(s1.serialize('foo')) == 'foo'
  45. def test_register_auth(self):
  46. register_auth(KEY1, CERT1, '')
  47. assert 'application/data' in registry._decoders
  48. def test_lots_of_sign(self):
  49. for i in range(1000):
  50. rdata = bytes_to_str(base64.urlsafe_b64encode(os.urandom(265)))
  51. s = self._get_s(KEY1, CERT1, [CERT1])
  52. assert s.deserialize(s.serialize(rdata)) == rdata