Jelajahi Sumber

Perform a serialization roundtrip on eager apply_async. (#4456)

Simon Charette 6 tahun lalu
induk
melakukan
8dcc621a92
3 mengubah file dengan 23 tambahan dan 1 penghapusan
  1. 12 0
      celery/app/task.py
  2. 3 1
      t/unit/app/test_builtins.py
  3. 8 0
      t/unit/tasks/test_tasks.py

+ 12 - 0
celery/app/task.py

@@ -5,6 +5,7 @@ from __future__ import absolute_import, unicode_literals
 import sys
 
 from billiard.einfo import ExceptionInfo
+from kombu import serialization
 from kombu.exceptions import OperationalError
 from kombu.utils.uuid import uuid
 
@@ -514,6 +515,17 @@ class Task(object):
 
         app = self._get_app()
         if app.conf.task_always_eager:
+            with app.producer_or_acquire(producer) as eager_producer:
+                serializer = options.get(
+                    'serializer', eager_producer.serializer
+                )
+                body = args, kwargs
+                content_type, content_encoding, data = serialization.dumps(
+                    body, serializer
+                )
+                args, kwargs = serialization.loads(
+                    data, content_type, content_encoding
+                )
             with denied_join_result():
                 return self.apply(args, kwargs, task_id=task_id or uuid(),
                                   link=link, link_error=link_error, **options)

+ 3 - 1
t/unit/app/test_builtins.py

@@ -94,7 +94,9 @@ class test_group(BuiltinsCase):
         self.maybe_signature = self.patching('celery.canvas.maybe_signature')
         self.maybe_signature.side_effect = pass1
         self.app.producer_or_acquire = Mock()
-        self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
+        self.app.producer_or_acquire.attach_mock(
+            ContextMock(serializer='json'), 'return_value'
+        )
         self.app.conf.task_always_eager = True
         self.task = builtins.add_group_task(self.app)
         BuiltinsCase.setup(self)

+ 8 - 0
t/unit/tasks/test_tasks.py

@@ -7,6 +7,7 @@ from datetime import datetime, timedelta
 import pytest
 from case import ANY, ContextMock, MagicMock, Mock, patch
 from kombu import Queue
+from kombu.exceptions import EncodeError
 
 from celery import Task, group, uuid
 from celery.app.task import _reprtask
@@ -824,6 +825,13 @@ class test_apply_async(TasksCase):
             ignore_result=False
         )
 
+    def test_eager_serialization_failure(self):
+        @self.app.task
+        def task(*args, **kwargs):
+            pass
+        with pytest.raises(EncodeError):
+            task.apply_async((1, 2, 3, 4, {1}))
+
     def test_task_with_ignored_result(self):
         with patch.object(self.app, 'send_task') as send_task:
             self.task_with_ignored_result.apply_async()