Sfoglia il codice sorgente

Tests passing now without enable_insecure_serializers()

Ask Solem 11 anni fa
parent
commit
3ccae08eae

+ 5 - 4
celery/backends/amqp.py

@@ -62,8 +62,7 @@ class AMQPBackend(BaseBackend):
     }
 
     def __init__(self, app, connection=None, exchange=None, exchange_type=None,
-                 persistent=None, serializer=None, auto_delete=True,
-                 accept=None, **kwargs):
+                 persistent=None, serializer=None, auto_delete=True, **kwargs):
         super(AMQPBackend, self).__init__(app, **kwargs)
         conf = self.app.conf
         self._connection = connection
@@ -74,7 +73,6 @@ class AMQPBackend(BaseBackend):
         self.exchange = self._create_exchange(exchange, exchange_type,
                                               self.persistent)
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
-        self.accept = conf.CELERY_ACCEPT_CONTENT if accept is None else accept
         self.auto_delete = auto_delete
 
         self.expires = None
@@ -152,8 +150,11 @@ class AMQPBackend(BaseBackend):
             binding.declare()
 
             prev = latest = acc = None
+            print('binding.get: %r' % (binding.get, ))
             for i in range(backlog_limit):  # spool ffwd
-                prev, latest, acc = latest, acc, binding.get(no_ack=False)
+                prev, latest, acc = latest, acc, binding.get(
+                    accept=self.accept, no_ack=False,
+                )
                 if not acc:  # no more messages
                     break
                 if prev:

+ 14 - 7
celery/backends/base.py

@@ -19,7 +19,10 @@ import sys
 from datetime import timedelta
 
 from billiard.einfo import ExceptionInfo
-from kombu import serialization
+from kombu.serialization import (
+    encode, decode, prepare_accept_encoding,
+    registry as serializer_registry,
+)
 from kombu.utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 
 from celery import states
@@ -66,16 +69,19 @@ class BaseBackend(object):
     supports_autoexpire = False
 
     def __init__(self, app, serializer=None,
-                 max_cached_results=None, **kwargs):
+                 max_cached_results=None, accept=None, **kwargs):
         self.app = app
         conf = self.app.conf
         self.serializer = serializer or conf.CELERY_RESULT_SERIALIZER
         (self.content_type,
          self.content_encoding,
-         self.encoder) = serialization.registry._encoders[self.serializer]
+         self.encoder) = serializer_registry._encoders[self.serializer]
         self._cache = LRUCache(
             limit=max_cached_results or conf.CELERY_MAX_CACHED_RESULTS,
         )
+        self.accept = prepare_accept_encoding(
+            conf.CELERY_ACCEPT_CONTENT if accept is None else accept,
+        )
 
     def mark_as_started(self, task_id, **meta):
         """Mark a task as started"""
@@ -130,14 +136,15 @@ class BaseBackend(object):
         return result
 
     def encode(self, data):
-        _, _, payload = serialization.encode(data, serializer=self.serializer)
+        _, _, payload = encode(data, serializer=self.serializer)
         return payload
 
     def decode(self, payload):
         payload = PY3 and payload or str(payload)
-        return serialization.decode(payload,
-                                    content_type=self.content_type,
-                                    content_encoding=self.content_encoding)
+        return decode(payload,
+                      content_type=self.content_type,
+                      content_encoding=self.content_encoding,
+                      accept=self.accept)
 
     def wait_for(self, task_id, timeout=None, propagate=True, interval=0.5):
         """Wait for task and return its result.

+ 2 - 0
celery/security/__init__.py

@@ -35,6 +35,8 @@ Please see the configuration reference for more information.
 def disable_untrusted_serializers(whitelist=None):
     for name in set(registry._decoders) - set(whitelist or []):
         registry.disable(name)
+    for name in whitelist or []:
+        registry.enable(name)
 
 
 def setup_security(allowed_serializers=None, key=None, cert=None, store=None,

+ 12 - 8
celery/tests/app/test_amqp.py

@@ -58,15 +58,19 @@ class test_TaskConsumer(AppCase):
 
     def test_accept_content(self):
         with self.app.pool.acquire(block=True) as conn:
+            prev = self.app.conf.CELERY_ACCEPT_CONTENT
             self.app.conf.CELERY_ACCEPT_CONTENT = ['application/json']
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn).accept,
-                set(['application/json'])
-            )
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn, accept=['json']).accept,
-                set(['application/json']),
-            )
+            try:
+                self.assertEqual(
+                    self.app.amqp.TaskConsumer(conn).accept,
+                    set(['application/json'])
+                )
+                self.assertEqual(
+                    self.app.amqp.TaskConsumer(conn, accept=['json']).accept,
+                    set(['application/json']),
+                )
+            finally:
+                self.app.conf.CELERY_ACCEPT_CONTENT = prev
 
 
 class test_compat_TaskPublisher(AppCase):

+ 5 - 2
celery/tests/backends/test_amqp.py

@@ -158,9 +158,12 @@ class test_AMQPBackend(AppCase):
             def declare(self):
                 pass
 
-            def get(self, no_ack=False):
+            def get(self, no_ack=False, accept=None):
                 try:
-                    return results.get(block=False)
+                    m = results.get(block=False)
+                    if m:
+                        m.accept = accept
+                    return m
                 except Empty:
                     pass
 

+ 1 - 0
celery/tests/backends/test_redis.py

@@ -106,6 +106,7 @@ class test_RedisBackend(AppCase):
     def test_conf_raises_KeyError(self):
         conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
                               'CELERY_MAX_CACHED_RESULTS': 1,
+                              'CELERY_ACCEPT_CONTENT': ['json'],
                               'CELERY_TASK_RESULT_EXPIRES': None})
         prev, self.app.conf = self.app.conf, conf
         try:

+ 1 - 1
celery/tests/security/test_security.py

@@ -36,7 +36,7 @@ class test_security(SecurityCase):
 
     def test_disable_untrusted_serializers(self):
         disabled = registry._disabled_content_types
-        self.assertEqual(0, len(disabled))
+        self.assertTrue(disabled)
 
         disable_untrusted_serializers(
             ['application/json', 'application/x-python-serialize'])

+ 2 - 2
celery/tests/worker/test_control.py

@@ -441,10 +441,10 @@ class test_ControlPanel(AppCase):
             r = control.revoke(Mock(), tid, terminate=True)
             self.assertIn(tid, revoked)
             self.assertTrue(request.terminate.call_count)
-            self.assertIn('terminating', r['ok'])
+            self.assertIn('terminate:', r['ok'])
             # unknown task id only revokes
             r = control.revoke(Mock(), uuid(), terminate=True)
-            self.assertIn('not found', r['ok'])
+            self.assertIn('tasks unknown', r['ok'])
         finally:
             worker_state.reserved_requests.discard(request)
 

+ 6 - 4
celery/tests/worker/test_worker.py

@@ -121,10 +121,12 @@ def foo_periodic_task():
 def create_message(channel, **data):
     data.setdefault('id', uuid())
     channel.no_ack_consumers = set()
-    return Message(channel, body=pickle.dumps(dict(**data)),
-                   content_type='application/x-python-serialize',
-                   content_encoding='binary',
-                   delivery_info={'consumer_tag': 'mock'})
+    m = Message(channel, body=pickle.dumps(dict(**data)),
+                content_type='application/x-python-serialize',
+                content_encoding='binary',
+                delivery_info={'consumer_tag': 'mock'})
+    m.accept = ['application/x-python-serialize']
+    return m
 
 
 class test_QoS(Case):