Browse Source

Merge pull request #2691 from sukrit007/patch_json-backend-decoding-issue

Fix de-serialization of exception (when used with json serializer)
Dmitry Malinovsky 9 years ago
parent
commit
30c9cb2cc1
2 changed files with 59 additions and 7 deletions
  1. 2 1
      celery/backends/amqp.py
  2. 57 6
      celery/tests/backends/test_amqp.py

+ 2 - 1
celery/backends/amqp.py

@@ -180,7 +180,8 @@ class AMQPBackend(BaseBackend):
                 raise self.BacklogLimitExceeded(task_id)
 
             if latest:
-                payload = self._cache[task_id] = latest.payload
+                payload = self._cache[task_id] = \
+                    self.meta_from_decoded(latest.payload)
                 latest.requeue()
                 return payload
             else:

+ 57 - 6
celery/tests/backends/test_amqp.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 
+import json
 import pickle
 import socket
 
@@ -129,7 +130,7 @@ class test_AMQPBackend(AppCase):
         self.assertState(b.get_task_meta(uuid()), states.PENDING)
 
     @contextmanager
-    def _result_context(self):
+    def _result_context(self, serializer='pickle'):
         results = Queue()
 
         class Message(object):
@@ -139,9 +140,13 @@ class test_AMQPBackend(AppCase):
             def __init__(self, **merge):
                 self.payload = dict({'status': states.STARTED,
                                      'result': None}, **merge)
-                self.body = pickle.dumps(self.payload)
-                self.content_type = 'application/x-python-serialize'
-                self.content_encoding = 'binary'
+                if serializer == 'json':
+                    self.body = json.dumps(self.payload)
+                    self.content_type = 'application/json'
+                else:
+                    self.body = pickle.dumps(self.payload)
+                    self.content_type = 'application/x-python-serialize'
+                    self.content_encoding = 'binary'
 
             def ack(self, *args, **kwargs):
                 self.acked += 1
@@ -176,6 +181,7 @@ class test_AMQPBackend(AppCase):
             Queue = MockBinding
 
         backend = MockBackend(self.app, max_cached_results=100)
+        backend.serializer = serializer
         backend._republish = Mock()
 
         yield results, backend, Message
@@ -200,8 +206,10 @@ class test_AMQPBackend(AppCase):
                 results.put(state_message)
             r1 = backend.get_task_meta(tid)
             self.assertDictContainsSubset(
-                {'status': states.FAILURE, 'seq': 3}, r1,
-                'FFWDs to the last state',
+                {
+                    'status': states.FAILURE,
+                    'seq': 3
+                }, r1, 'FFWDs to the last state',
             )
 
             # Caches last known state.
@@ -221,6 +229,49 @@ class test_AMQPBackend(AppCase):
                 'Returns cache if no new states',
             )
 
+    def test_poll_result_for_json_serializer(self):
+        with self._result_context(serializer='json') as \
+                (results, backend, Message):
+            tid = uuid()
+            # FFWD's to the latest state.
+            state_messages = [
+                Message(task_id=tid, status=states.RECEIVED, seq=1),
+                Message(task_id=tid, status=states.STARTED, seq=2),
+                Message(task_id=tid, status=states.FAILURE, seq=3,
+                        result={
+                            'exc_type': 'RuntimeError',
+                            'exc_message': 'Mock'
+                        }),
+                ]
+            for state_message in state_messages:
+                results.put(state_message)
+            r1 = backend.get_task_meta(tid)
+            self.assertDictContainsSubset(
+                {
+                    'status': states.FAILURE,
+                    'seq': 3
+                }, r1, 'FFWDs to the last state',
+            )
+            self.assertEquals(type(r1['result']).__name__, 'RuntimeError')
+            self.assertEqual(str(r1['result']), 'Mock')
+
+            # Caches last known state.
+            tid = uuid()
+            results.put(Message(task_id=tid))
+            backend.get_task_meta(tid)
+            self.assertIn(tid, backend._cache, 'Caches last known state')
+
+            self.assertTrue(state_messages[-1].requeued)
+
+            # Returns cache if no new states.
+            results.queue.clear()
+            assert not results.qsize()
+            backend._cache[tid] = 'hello'
+            self.assertEqual(
+                backend.get_task_meta(tid), 'hello',
+                'Returns cache if no new states',
+                )
+
     def test_wait_for(self):
         b = self.create_backend()