Browse Source

Fixes failing security test case. Closes #3624

Ask Solem 8 years ago
parent
commit
c22e683958
3 changed files with 6 additions and 8 deletions
  1. 2 5
      celery/security/serialization.py
  2. 3 2
      t/unit/security/test_key.py
  3. 1 1
      t/unit/worker/test_consumer.py

+ 2 - 5
celery/security/serialization.py

@@ -2,11 +2,10 @@
 """Secure serializer."""
 from __future__ import absolute_import, unicode_literals
 
-import sys
-
 from kombu.serialization import registry, dumps, loads
 from kombu.utils.encoding import bytes_to_str, str_to_bytes, ensure_bytes
 
+from celery.five import bytes_if_py2
 from celery.utils.serialization import b64encode, b64decode
 
 from .certificate import Certificate, FSCertStore
@@ -15,8 +14,6 @@ from .utils import reraise_errors
 
 __all__ = ['SecureSerializer', 'register_auth']
 
-PY3 = sys.version_info[0] == 3
-
 
 class SecureSerializer(object):
     """Signed serializer."""
@@ -26,7 +23,7 @@ class SecureSerializer(object):
         self._key = key
         self._cert = cert
         self._cert_store = cert_store
-        self._digest = str_to_bytes(digest) if not PY3 else digest
+        self._digest = bytes_if_py2(digest)
         self._serializer = serializer
 
     def serialize(self, data):

+ 3 - 2
t/unit/security/test_key.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import, unicode_literals
 import pytest
 from celery.exceptions import SecurityError
+from celery.five import bytes_if_py2
 from celery.security.key import PrivateKey
 from . import CERT1, KEY1, KEY2
 from .case import SecurityCase
@@ -26,6 +27,6 @@ class test_PrivateKey(SecurityCase):
 
     def test_sign(self):
         pkey = PrivateKey(KEY1)
-        pkey.sign('test', b'sha1')
+        pkey.sign('test', bytes_if_py2('sha1'))
         with pytest.raises(ValueError):
-            pkey.sign('test', b'unknown')
+            pkey.sign('test', bytes_if_py2('unknown'))

+ 1 - 1
t/unit/worker/test_consumer.py

@@ -49,7 +49,7 @@ class test_Consumer:
     def test_dump_body_buffer(self):
         msg = Mock()
         msg.body = 'str'
-        assert dump_body(msg, buffer(msg.body))
+        assert dump_body(msg, buffer(msg.body))  # noqa: F821
 
     def test_sets_heartbeat(self):
         c = self.get_consumer(amqheartbeat=10)