Browse Source

Fix Request to pass Context to backend store_result functions. (#5068)

* Fix Request to pass Context to backend store_result functions.

* Insert a blank line.

* Update the test to expect new values.

* Rename Request.request to Request._context.

* Fix test_request.py to follow the function renaming.

* Add a docstring for Request._context.
Kiyohiro Yamaguchi 6 years ago
parent
commit
8d3c694ef8
2 changed files with 28 additions and 10 deletions
  1. 22 4
      celery/worker/request.py
  2. 6 6
      t/unit/worker/test_request.py

+ 22 - 4
celery/worker/request.py

@@ -17,6 +17,7 @@ from kombu.utils.encoding import safe_repr, safe_str
 from kombu.utils.objects import cached_property
 
 from celery import signals
+from celery.app.task import Context
 from celery.app.trace import trace_task, trace_task_ret
 from celery.exceptions import (Ignore, InvalidTaskError, Reject, Retry,
                                TaskRevokedError, Terminated,
@@ -260,11 +261,12 @@ class Request(object):
         self.send_event('task-revoked',
                         terminated=terminated, signum=signum, expired=expired)
         self.task.backend.mark_as_revoked(
-            self.id, reason, request=self, store_result=self.store_errors,
+            self.id, reason, request=self._context,
+            store_result=self.store_errors,
         )
         self.acknowledge()
         self._already_revoked = True
-        send_revoked(self.task, request=self,
+        send_revoked(self.task, request=self._context,
                      terminated=terminated, signum=signum, expired=expired)
 
     def revoked(self):
@@ -312,7 +314,8 @@ class Request(object):
             exc = TimeLimitExceeded(timeout)
 
             self.task.backend.mark_as_failure(
-                self.id, exc, request=self, store_result=self.store_errors,
+                self.id, exc, request=self._context,
+                store_result=self.store_errors,
             )
 
             if self.task.acks_late and self.task.acks_on_failure_or_timeout:
@@ -364,7 +367,8 @@ class Request(object):
             send_failed_event = False  # already sent revoked event
         elif isinstance(exc, WorkerLostError) or not return_ok:
             self.task.backend.mark_as_failure(
-                self.id, exc, request=self, store_result=self.store_errors,
+                self.id, exc, request=self._context,
+                store_result=self.store_errors,
             )
         # (acks_late) acknowledge after result stored.
         if self.task.acks_late:
@@ -502,6 +506,20 @@ class Request(object):
         # by parent process
         return self.request_dict.get('group')
 
+    @cached_property
+    def _context(self):
+        """Context (:class:`~celery.app.task.Context`) of this task."""
+        request = self.request_dict
+        # pylint: disable=unpacking-non-sequence
+        #    payload is a property, so pylint doesn't think it's a tuple.
+        args, kwargs, embed = self._payload
+        request.update({
+            'hostname': self.hostname,
+            'args': args,
+            'kwargs': kwargs
+        }, **embed or {})
+        return Context(request)
+
 
 def create_request_cls(base, task, pool, hostname, eventer,
                        ref=ref, revoked_tasks=revoked_tasks,

+ 6 - 6
t/unit/worker/test_request.py

@@ -410,7 +410,7 @@ class test_Request(RequestCase):
         job = self.get_request(self.mytask.s(1, f='x'))
         job._apply_result = Mock(name='_apply_result')
         with self.assert_signal_called(
-                task_revoked, sender=job.task, request=job,
+                task_revoked, sender=job.task, request=job._context,
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
             job.worker_pid = 314
@@ -426,7 +426,7 @@ class test_Request(RequestCase):
         signum = signal.SIGTERM
         job = self.get_request(self.mytask.s(1, f='x'))
         with self.assert_signal_called(
-                task_revoked, sender=job.task, request=job,
+                task_revoked, sender=job.task, request=job._context,
                 terminated=True, expired=False, signum=signum):
             job.time_start = monotonic()
             job.worker_pid = 313
@@ -447,7 +447,7 @@ class test_Request(RequestCase):
             expires=datetime.utcnow() - timedelta(days=1)
         ))
         with self.assert_signal_called(
-                task_revoked, sender=job.task, request=job,
+                task_revoked, sender=job.task, request=job._context,
                 terminated=False, expired=True, signum=None):
             job.revoked()
             assert job.id in revoked
@@ -479,7 +479,7 @@ class test_Request(RequestCase):
     def test_revoked(self):
         job = self.xRequest()
         with self.assert_signal_called(
-                task_revoked, sender=job.task, request=job,
+                task_revoked, sender=job.task, request=job._context,
                 terminated=False, expired=False, signum=None):
             revoked.add(job.id)
             assert job.revoked()
@@ -528,7 +528,7 @@ class test_Request(RequestCase):
         pool = Mock()
         job = self.xRequest()
         with self.assert_signal_called(
-                task_revoked, sender=job.task, request=job,
+                task_revoked, sender=job.task, request=job._context,
                 terminated=True, expired=False, signum=signum):
             job.terminate(pool, signal='TERM')
             assert not pool.terminate_job.call_count
@@ -933,7 +933,7 @@ class test_Request(RequestCase):
         exc = WorkerLostError()
         job = self._test_on_failure(exc)
         job.task.backend.mark_as_failure.assert_called_with(
-            job.id, exc, request=job, store_result=True,
+            job.id, exc, request=job._context, store_result=True,
         )
 
     def test_on_failure__return_ok(self):