Explorar o código

Result: .get() should not call get_task_meta after receiving the result. Closes #2245

Ask Solem %!s(int64=10) %!d(string=hai) anos
pai
achega
28771544e1
Modificáronse 4 ficheiros con 39 adicións e 42 borrados
  1. 3 8
      celery/backends/amqp.py
  2. 4 10
      celery/backends/base.py
  3. 27 18
      celery/result.py
  4. 5 6
      celery/tests/backends/test_amqp.py

+ 3 - 8
celery/backends/amqp.py

@@ -140,7 +140,7 @@ class AMQPBackend(BaseBackend):
     def on_reply_declare(self, task_id):
         return [self._create_binding(task_id)]
 
-    def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
+    def wait_for(self, task_id, timeout=None, cache=True,
                  no_ack=True, on_interval=None,
                  READY_STATES=states.READY_STATES,
                  PROPAGATE_STATES=states.PROPAGATE_STATES,
@@ -148,19 +148,14 @@ class AMQPBackend(BaseBackend):
         cached_meta = self._cache.get(task_id)
         if cache and cached_meta and \
                 cached_meta['status'] in READY_STATES:
-            meta = cached_meta
+            return cached_meta
         else:
             try:
-                meta = self.consume(task_id, timeout=timeout, no_ack=no_ack,
+                return self.consume(task_id, timeout=timeout, no_ack=no_ack,
                                     on_interval=on_interval)
             except socket.timeout:
                 raise TimeoutError('The operation timed out.')
 
-        if meta['status'] in PROPAGATE_STATES and propagate:
-            raise self.exception_to_python(meta['result'])
-        # consume() always returns READY_STATE.
-        return meta['result']
-
     def get_task_meta(self, task_id, backlog_limit=1000):
         # Polling and using basic_get
         with self.app.pool.acquire_channel(block=True) as (_, channel):

+ 4 - 10
celery/backends/base.py

@@ -189,8 +189,7 @@ class BaseBackend(object):
                      accept=self.accept)
 
     def wait_for(self, task_id,
-                 timeout=None, propagate=True, interval=0.5, no_ack=True,
-                 on_interval=None):
+                 timeout=None, interval=0.5, no_ack=True, on_interval=None):
         """Wait for task and return its result.
 
         If the task raises an exception, this exception
@@ -205,14 +204,9 @@ class BaseBackend(object):
         time_elapsed = 0.0
 
         while 1:
-            status = self.get_status(task_id)
-            if status == states.SUCCESS:
-                return self.get_result(task_id)
-            elif status in states.PROPAGATE_STATES:
-                result = self.get_result(task_id)
-                if propagate:
-                    raise result
-                return result
+            meta = self.get_task_meta(task_id)
+            if meta['status'] in states.READY_STATES:
+                return meta
             if on_interval:
                 on_interval()
             # avoid hammering the CPU checking status.

+ 27 - 18
celery/result.py

@@ -119,8 +119,10 @@ class AsyncResult(ResultBase):
                                 terminate=terminate, signal=signal,
                                 reply=wait, timeout=timeout)
 
-    def get(self, timeout=None, propagate=True, interval=0.5, no_ack=True,
-            follow_parents=True):
+    def get(self, timeout=None, propagate=True, interval=0.5,
+            no_ack=True, follow_parents=True,
+            EXCEPTION_STATES=states.EXCEPTION_STATES,
+            PROPAGATE_STATES=states.PROPAGATE_STATES):
         """Wait until task is ready, and return its result.
 
         .. warning::
@@ -159,16 +161,21 @@ class AsyncResult(ResultBase):
                 self.maybe_reraise()
             return self.result
 
-        try:
-            return self.backend.wait_for(
-                self.id, timeout=timeout,
-                propagate=propagate,
-                interval=interval,
-                on_interval=on_interval,
-                no_ack=no_ack,
-            )
-        finally:
-            self._get_task_meta()  # update self._cache
+        meta = self.backend.wait_for(
+            self.id, timeout=timeout,
+            propagate=propagate,
+            interval=interval,
+            on_interval=on_interval,
+            no_ack=no_ack,
+        )
+        if meta:
+            self._maybe_set_cache(meta)
+            status = meta['status']
+            if status in EXCEPTION_STATES:
+                return self.backend.exception_to_python(meta['result'])
+            if status in PROPAGATE_STATES and propagate:
+                raise self.backend.exception_to_python(meta['result'])
+            return meta['result']
     wait = get  # deprecated alias to :meth:`get`.
 
     def _maybe_reraise_parent_error(self):
@@ -322,14 +329,16 @@ class AsyncResult(ResultBase):
     def children(self):
         return self._get_task_meta().get('children')
 
+    def _maybe_set_cache(self, meta):
+        if meta:
+            state = meta['status']
+            if state == states.SUCCESS or state in states.PROPAGATE_STATES:
+                return self._set_cache(meta)
+        return meta
+
     def _get_task_meta(self):
         if self._cache is None:
-            meta = self.backend.get_task_meta(self.id)
-            if meta:
-                state = meta['status']
-                if state == states.SUCCESS or state in states.PROPAGATE_STATES:
-                    return self._set_cache(meta)
-            return meta
+            return self._maybe_set_cache(self.backend.get_task_meta(self.id))
         return self._cache
 
     def _set_cache(self, d):

+ 5 - 6
celery/tests/backends/test_amqp.py

@@ -234,15 +234,14 @@ class test_AMQPBackend(AppCase):
         with self.assertRaises(TimeoutError):
             b.wait_for(tid, timeout=0.1)
         b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(b.wait_for(tid, timeout=1), 42)
+        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42)
         b.store_result(tid, 56, states.SUCCESS)
-        self.assertEqual(b.wait_for(tid, timeout=1), 42,
+        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42,
                          'result is cached')
-        self.assertEqual(b.wait_for(tid, timeout=1, cache=False), 56)
+        self.assertEqual(b.wait_for(tid, timeout=1, cache=False)['result'], 56)
         b.store_result(tid, KeyError('foo'), states.FAILURE)
-        with self.assertRaises(KeyError):
-            b.wait_for(tid, timeout=1, cache=False)
-        self.assertTrue(b.wait_for(tid, timeout=1, propagate=False))
+        res = b.wait_for(tid, timeout=1, cache=False)
+        self.assertEqual(res['status'], states.FAILURE)
         b.store_result(tid, KeyError('foo'), states.PENDING)
         with self.assertRaises(TimeoutError):
             b.wait_for(tid, timeout=0.01, cache=False)