Explorar o código

Perform a serialization roundtrip on eager apply_async. (#4456)

Simon Charette %!s(int64=6) %!d(string=hai) anos
pai
achega
8dcc621a92
Modificáronse 3 ficheiros con 23 adicións e 1 borrados
  1. 12 0
      celery/app/task.py
  2. 3 1
      t/unit/app/test_builtins.py
  3. 8 0
      t/unit/tasks/test_tasks.py

+ 12 - 0
celery/app/task.py

@@ -5,6 +5,7 @@ from __future__ import absolute_import, unicode_literals
 import sys
 import sys
 
 
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
+from kombu import serialization
 from kombu.exceptions import OperationalError
 from kombu.exceptions import OperationalError
 from kombu.utils.uuid import uuid
 from kombu.utils.uuid import uuid
 
 
@@ -514,6 +515,17 @@ class Task(object):
 
 
         app = self._get_app()
         app = self._get_app()
         if app.conf.task_always_eager:
         if app.conf.task_always_eager:
+            with app.producer_or_acquire(producer) as eager_producer:
+                serializer = options.get(
+                    'serializer', eager_producer.serializer
+                )
+                body = args, kwargs
+                content_type, content_encoding, data = serialization.dumps(
+                    body, serializer
+                )
+                args, kwargs = serialization.loads(
+                    data, content_type, content_encoding
+                )
             with denied_join_result():
             with denied_join_result():
                 return self.apply(args, kwargs, task_id=task_id or uuid(),
                 return self.apply(args, kwargs, task_id=task_id or uuid(),
                                   link=link, link_error=link_error, **options)
                                   link=link, link_error=link_error, **options)

+ 3 - 1
t/unit/app/test_builtins.py

@@ -94,7 +94,9 @@ class test_group(BuiltinsCase):
         self.maybe_signature = self.patching('celery.canvas.maybe_signature')
         self.maybe_signature = self.patching('celery.canvas.maybe_signature')
         self.maybe_signature.side_effect = pass1
         self.maybe_signature.side_effect = pass1
         self.app.producer_or_acquire = Mock()
         self.app.producer_or_acquire = Mock()
-        self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
+        self.app.producer_or_acquire.attach_mock(
+            ContextMock(serializer='json'), 'return_value'
+        )
         self.app.conf.task_always_eager = True
         self.app.conf.task_always_eager = True
         self.task = builtins.add_group_task(self.app)
         self.task = builtins.add_group_task(self.app)
         BuiltinsCase.setup(self)
         BuiltinsCase.setup(self)

+ 8 - 0
t/unit/tasks/test_tasks.py

@@ -7,6 +7,7 @@ from datetime import datetime, timedelta
 import pytest
 import pytest
 from case import ANY, ContextMock, MagicMock, Mock, patch
 from case import ANY, ContextMock, MagicMock, Mock, patch
 from kombu import Queue
 from kombu import Queue
+from kombu.exceptions import EncodeError
 
 
 from celery import Task, group, uuid
 from celery import Task, group, uuid
 from celery.app.task import _reprtask
 from celery.app.task import _reprtask
@@ -824,6 +825,13 @@ class test_apply_async(TasksCase):
             ignore_result=False
             ignore_result=False
         )
         )
 
 
+    def test_eager_serialization_failure(self):
+        @self.app.task
+        def task(*args, **kwargs):
+            pass
+        with pytest.raises(EncodeError):
+            task.apply_async((1, 2, 3, 4, {1}))
+
     def test_task_with_ignored_result(self):
     def test_task_with_ignored_result(self):
         with patch.object(self.app, 'send_task') as send_task:
         with patch.object(self.app, 'send_task') as send_task:
             self.task_with_ignored_result.apply_async()
             self.task_with_ignored_result.apply_async()