Browse Source

AsyncResult now has its own cache so MAX_CACHED_RESULTS can be completely disabled

Ask Solem 11 years ago
parent
commit
633f80dc68
4 changed files with 65 additions and 26 deletions
  1. 4 1
      celery/backends/amqp.py
  2. 48 12
      celery/result.py
  3. 7 6
      celery/tests/backends/test_amqp.py
  4. 6 7
      celery/tests/tasks/test_result.py

+ 4 - 1
celery/backends/amqp.py

@@ -169,15 +169,18 @@ class AMQPBackend(BaseBackend):
 
             prev = latest = acc = None
             for i in range(backlog_limit):  # spool ffwd
-                prev, latest, acc = latest, acc, binding.get(
+                acc = binding.get(
                     accept=self.accept, no_ack=False,
                 )
                 if not acc:  # no more messages
                     break
+                if acc.payload['task_id'] == task_id:
+                    prev, latest = latest, acc
                 if prev:
                     # backends are not expected to keep history,
                     # so we delete everything except the most recent state.
                     prev.ack()
+                    prev = None
             else:
                 raise self.BacklogLimitExceeded(task_id)
 

+ 48 - 12
celery/result.py

@@ -87,6 +87,7 @@ class AsyncResult(ResultBase):
         self.backend = backend or self.app.backend
         self.task_name = task_name
         self.parent = parent
+        self._cache = None
 
     def as_tuple(self):
         parent = self.parent
@@ -153,13 +154,21 @@ class AsyncResult(ResultBase):
             on_interval = self._maybe_reraise_parent_error
             on_interval()
 
-        return self.backend.wait_for(
-            self.id, timeout=timeout,
-            propagate=propagate,
-            interval=interval,
-            on_interval=on_interval,
-            no_ack=no_ack,
-        )
+        if self._cache:
+            if propagate:
+                self.maybe_reraise()
+            return self.result
+
+        try:
+            return self.backend.wait_for(
+                self.id, timeout=timeout,
+                propagate=propagate,
+                interval=interval,
+                on_interval=on_interval,
+                no_ack=no_ack,
+            )
+        finally:
+            self._get_task_meta()  # update self._cache
     wait = get  # deprecated alias to :meth:`get`.
 
     def _maybe_reraise_parent_error(self):
@@ -298,6 +307,9 @@ class AsyncResult(ResultBase):
     def __reduce_args__(self):
         return self.id, self.backend, self.task_name, None, self.parent
 
+    def __del__(self):
+        self._cache = None
+
     @cached_property
     def graph(self):
         return self.build_graph()
@@ -308,22 +320,42 @@ class AsyncResult(ResultBase):
 
     @property
     def children(self):
-        children = self.backend.get_children(self.id)
+        return self._get_task_meta().get('children')
+
+    def _get_task_meta(self):
+        if self._cache is None:
+            meta = self.backend.get_task_meta(self.id)
+            if meta:
+                state = meta['status']
+                if state == states.SUCCESS or state in states.PROPAGATE_STATES:
+                    self._set_cache(meta)
+                    return self._set_cache(meta)
+            return meta
+        return self._cache
+
+    def _set_cache(self, d):
+        state, children = d['status'], d.get('children')
+        if state in states.EXCEPTION_STATES:
+            d['result'] = self.backend.exception_to_python(d['result'])
         if children:
-            return [result_from_tuple(child, self.app) for child in children]
+            d['children'] = [
+                result_from_tuple(child, self.app) for child in children
+            ]
+        self._cache = d
+        return d
 
     @property
     def result(self):
         """When the task has been executed, this contains the return value.
         If the task raised an exception, this will be the exception
         instance."""
-        return self.backend.get_result(self.id)
+        return self._get_task_meta()['result']
     info = result
 
     @property
     def traceback(self):
         """Get the traceback of a failed task."""
-        return self.backend.get_traceback(self.id)
+        return self._get_task_meta().get('traceback')
 
     @property
     def state(self):
@@ -355,7 +387,7 @@ class AsyncResult(ResultBase):
                 then contains the tasks return value.
 
         """
-        return self.backend.get_status(self.id)
+        return self._get_task_meta()['status']
     status = state
 
     @property
@@ -802,6 +834,10 @@ class EagerResult(AsyncResult):
         self._state = state
         self._traceback = traceback
 
+    def _get_task_meta(self):
+        return {'task_id': self.id, 'result': self._result, 'status':
+                self._state, 'traceback': self._traceback}
+
     def __reduce__(self):
         return self.__class__, self.__reduce_args__()
 

+ 7 - 6
celery/tests/backends/test_amqp.py

@@ -183,29 +183,30 @@ class test_AMQPBackend(AppCase):
     def test_backlog_limit_exceeded(self):
         with self._result_context() as (results, backend, Message):
             for i in range(1001):
-                results.put(Message(status=states.RECEIVED))
+                results.put(Message(task_id='id', status=states.RECEIVED))
             with self.assertRaises(backend.BacklogLimitExceeded):
                 backend.get_task_meta('id')
 
     def test_poll_result(self):
         with self._result_context() as (results, backend, Message):
+            tid = uuid()
             # FFWD's to the latest state.
             state_messages = [
-                Message(status=states.RECEIVED, seq=1),
-                Message(status=states.STARTED, seq=2),
-                Message(status=states.FAILURE, seq=3),
+                Message(task_id=tid, status=states.RECEIVED, seq=1),
+                Message(task_id=tid, status=states.STARTED, seq=2),
+                Message(task_id=tid, status=states.FAILURE, seq=3),
             ]
             for state_message in state_messages:
                 results.put(state_message)
-            r1 = backend.get_task_meta(uuid())
+            r1 = backend.get_task_meta(tid)
             self.assertDictContainsSubset(
                 {'status': states.FAILURE, 'seq': 3}, r1,
                 'FFWDs to the last state',
             )
 
             # Caches last known state.
-            results.put(Message())
             tid = uuid()
+            results.put(Message(task_id=tid))
             backend.get_task_meta(tid)
             self.assertIn(tid, backend._cache, 'Caches last known state')
 

+ 6 - 7
celery/tests/tasks/test_result.py

@@ -66,15 +66,15 @@ class test_AsyncResult(AppCase):
     def test_children(self):
         x = self.app.AsyncResult('1')
         children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x._cache = {'children': children, 'status': states.SUCCESS}
         x.backend = Mock()
-        x.backend.get_children.return_value = children
-        x.backend.READY_STATES = states.READY_STATES
         self.assertTrue(x.children)
         self.assertEqual(len(x.children), 3)
 
     def test_propagates_for_parent(self):
         x = self.app.AsyncResult(uuid())
         x.backend = Mock()
+        x.backend.get_task_meta.return_value = {}
         x.parent = EagerResult(uuid(), KeyError('foo'), states.FAILURE)
         with self.assertRaises(KeyError):
             x.get(propagate=True)
@@ -89,10 +89,11 @@ class test_AsyncResult(AppCase):
         x = self.app.AsyncResult(tid)
         child = [self.app.AsyncResult(uuid()).as_tuple()
                  for i in range(10)]
-        x.backend._cache[tid] = {'children': child}
+        x._cache = {'children': child}
         self.assertTrue(x.children)
         self.assertEqual(len(x.children), 10)
 
+        x._cache = {'status': states.SUCCESS}
         x.backend._cache[tid] = {'result': None}
         self.assertIsNone(x.children)
 
@@ -122,13 +123,11 @@ class test_AsyncResult(AppCase):
 
     def test_iterdeps(self):
         x = self.app.AsyncResult('1')
-        x.backend._cache['1'] = {'status': states.SUCCESS, 'result': None}
         c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+        x._cache = {'status': states.SUCCESS, 'result': None, 'children': c}
         for child in c:
             child.backend = Mock()
             child.backend.get_children.return_value = []
-        x.backend.get_children = Mock()
-        x.backend.get_children.return_value = c
         it = x.iterdeps()
         self.assertListEqual(list(it), [
             (None, x),
@@ -136,7 +135,7 @@ class test_AsyncResult(AppCase):
             (x, c[1]),
             (x, c[2]),
         ])
-        x.backend._cache.pop('1')
+        x._cache = None
         x.ready = Mock()
         x.ready.return_value = False
         with self.assertRaises(IncompleteStream):