소스 검색

Tests passing

Ask Solem 9 년 전
부모
커밋
072ad1937f
5개의 변경된 파일35개의 추가작업 그리고 156개의 파일을 삭제
  1. 4 0
      celery/backends/async.py
  2. 4 0
      celery/backends/base.py
  3. 17 6
      celery/result.py
  4. 0 141
      celery/tests/backends/test_amqp.py
  5. 10 9
      celery/tests/tasks/test_result.py

+ 4 - 0
celery/backends/async.py

@@ -140,6 +140,10 @@ class AsyncBackendMixin(object):
             callback=callback, on_message=on_message, propagate=propagate,
         )
 
+    @property
+    def is_async(self):
+        return True
+
 
 class BaseResultConsumer(object):
 

+ 4 - 0
celery/backends/base.py

@@ -427,6 +427,10 @@ class SyncBackendMixin(object):
     def remove_pending_result(self, result):
         return result
 
+    @property
+    def is_async(self):
+        return False
+
 
 class BaseBackend(Backend, SyncBackendMixin):
     pass

+ 17 - 6
celery/result.py

@@ -168,7 +168,7 @@ class AsyncResult(ResultBase):
 
         if self._cache:
             if propagate:
-                self.maybe_throw()
+                self.maybe_throw(callback=callback)
             return self.result
 
         self.backend.add_pending_result(self)
@@ -178,6 +178,7 @@ class AsyncResult(ResultBase):
             on_interval=_on_interval,
             no_ack=no_ack,
             propagate=propagate,
+            callback=callback,
         )
     wait = get  # deprecated alias to :meth:`get`.
 
@@ -436,9 +437,10 @@ class ResultSet(ResultBase):
         self._app = app
         self._cache = None
         self.results = results
-        self._on_full = ready_barrier or barrier(self.results)
-        self._on_full.then(promise(self._on_ready))
         self.on_ready = promise()
+        self._on_full = ready_barrier
+        if self._on_full:
+            self._on_full.then(promise(self.on_ready))
 
     def add(self, result):
         """Add :class:`AsyncResult` as a new member of the set.
@@ -448,12 +450,14 @@ class ResultSet(ResultBase):
         """
         if result not in self.results:
             self.results.append(result)
-            self.ready.add(result)
+            if self._on_full:
+                self._on_full.add(result)
 
     def _on_ready(self):
         self.backend.remove_pending_result(self)
-        self._cache = [r.get() for r in self.results]
-        self.on_ready(self)
+        if self.backend.is_async:
+            self._cache = [r.get() for r in self.results]
+            self.on_ready(self)
 
     def remove(self, result):
         """Remove result from the set; it must be a member.
@@ -867,9 +871,16 @@ class EagerResult(AsyncResult):
         return self.on_ready.then(callback, on_error)
 
     def _get_task_meta(self):
+        return self._cache
+
+    @property
+    def _cache(self):
         return {'task_id': self.id, 'result': self._result, 'status':
                 self._state, 'traceback': self._traceback}
 
+    def __del__(self):
+        pass
+
     def __reduce__(self):
         return self.__class__, self.__reduce_args__()
 

+ 0 - 141
celery/tests/backends/test_amqp.py

@@ -239,31 +239,6 @@ class test_AMQPBackend(AppCase):
                 'Returns cache if no new states',
             )
 
-    def test_wait_for(self):
-        b = self.create_backend()
-
-        tid = uuid()
-        with self.assertRaises(TimeoutError):
-            b.wait_for(tid, timeout=0.1)
-        b.store_result(tid, None, states.STARTED)
-        with self.assertRaises(TimeoutError):
-            b.wait_for(tid, timeout=0.1)
-        b.store_result(tid, None, states.RETRY)
-        with self.assertRaises(TimeoutError):
-            b.wait_for(tid, timeout=0.1)
-        b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42)
-        b.store_result(tid, 56, states.SUCCESS)
-        self.assertEqual(b.wait_for(tid, timeout=1)['result'], 42,
-                         'result is cached')
-        self.assertEqual(b.wait_for(tid, timeout=1, cache=False)['result'], 56)
-        b.store_result(tid, KeyError('foo'), states.FAILURE)
-        res = b.wait_for(tid, timeout=1, cache=False)
-        self.assertEqual(res['status'], states.FAILURE)
-        b.store_result(tid, KeyError('foo'), states.PENDING)
-        with self.assertRaises(TimeoutError):
-            b.wait_for(tid, timeout=0.01, cache=False)
-
     def test_drain_events_decodes_exceptions_in_meta(self):
         tid = uuid()
         b = self.create_backend(serializer="json")
@@ -276,122 +251,6 @@ class test_AMQPBackend(AppCase):
         self.assertEqual(cm.exception.__class__.__name__, "RuntimeError")
         self.assertEqual(str(cm.exception), "aap")
 
-    def test_drain_events_remaining_timeouts(self):
-        class Connection(object):
-            def drain_events(self, timeout=None):
-                pass
-
-        b = self.create_backend()
-        with self.app.pool.acquire_channel(block=False) as (_, channel):
-            binding = b._create_binding(uuid())
-            consumer = b.Consumer(channel, binding, no_ack=True)
-            callback = Mock()
-            with self.assertRaises(socket.timeout):
-                b.drain_events(Connection(), consumer, timeout=0.1,
-                               on_interval=callback)
-                callback.assert_called_with()
-
-    def test_get_many(self):
-        b = self.create_backend(max_cached_results=10)
-
-        tids = []
-        for i in range(10):
-            tid = uuid()
-            b.store_result(tid, i, states.SUCCESS)
-            tids.append(tid)
-
-        res = list(b.get_many(tids, timeout=1))
-        expected_results = [
-            (_tid, {'status': states.SUCCESS,
-                    'result': i,
-                    'traceback': None,
-                    'task_id': _tid,
-                    'children': None})
-            for i, _tid in enumerate(tids)
-        ]
-        self.assertEqual(sorted(res), sorted(expected_results))
-        self.assertDictEqual(b._cache[res[0][0]], res[0][1])
-        cached_res = list(b.get_many(tids, timeout=1))
-        self.assertEqual(sorted(cached_res), sorted(expected_results))
-
-        # times out when not ready in cache (this shouldn't happen)
-        b._cache[res[0][0]]['status'] = states.RETRY
-        with self.assertRaises(socket.timeout):
-            list(b.get_many(tids, timeout=0.01))
-
-        # times out when result not yet ready
-        with self.assertRaises(socket.timeout):
-            tids = [uuid()]
-            b.store_result(tids[0], i, states.PENDING)
-            list(b.get_many(tids, timeout=0.01))
-
-    def test_get_many_on_message(self):
-        b = self.create_backend(max_cached_results=10)
-
-        tids = []
-        for i in range(10):
-            tid = uuid()
-            b.store_result(tid, '', states.PENDING)
-            b.store_result(tid, 'comment_%i_1' % i, states.STARTED)
-            b.store_result(tid, 'comment_%i_2' % i, states.STARTED)
-            b.store_result(tid, 'final result %i' % i, states.SUCCESS)
-            tids.append(tid)
-
-        expected_messages = {}
-        for i, _tid in enumerate(tids):
-            expected_messages[_tid] = []
-            expected_messages[_tid].append((states.PENDING, ''))
-            expected_messages[_tid].append(
-                (states.STARTED, 'comment_%i_1' % i),
-            )
-            expected_messages[_tid].append(
-                (states.STARTED, 'comment_%i_2' % i),
-            )
-            expected_messages[_tid].append(
-                (states.SUCCESS, 'final result %i' % i),
-            )
-
-        on_message_results = {}
-
-        def on_message(body):
-            if not body['task_id'] in on_message_results:
-                on_message_results[body['task_id']] = []
-            on_message_results[body['task_id']].append(
-                (body['status'], body['result']),
-            )
-
-        list(b.get_many(tids, timeout=1, on_message=on_message))
-        self.assertEqual(sorted(on_message_results), sorted(expected_messages))
-
-    def test_get_many_raises_outer_block(self):
-
-        class Backend(AMQPBackend):
-
-            def Consumer(*args, **kwargs):
-                raise KeyError('foo')
-
-        b = Backend(self.app)
-        with self.assertRaises(KeyError):
-            next(b.get_many(['id1']))
-
-    def test_get_many_raises_inner_block(self):
-        with patch('kombu.connection.Connection.drain_events') as drain:
-            drain.side_effect = KeyError('foo')
-            b = AMQPBackend(self.app)
-            with self.assertRaises(KeyError):
-                next(b.get_many(['id1']))
-
-    def test_consume_raises_inner_block(self):
-        with patch('kombu.connection.Connection.drain_events') as drain:
-
-            def se(*args, **kwargs):
-                drain.side_effect = ValueError()
-                raise KeyError('foo')
-            drain.side_effect = se
-            b = AMQPBackend(self.app)
-            with self.assertRaises(ValueError):
-                next(b.consume('id1'))
-
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         app = self.app

+ 10 - 9
celery/tests/tasks/test_result.py

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 from contextlib import contextmanager
 
 from celery import states
+from celery.backends.base import SyncBackendMixin
 from celery.exceptions import (
     ImproperlyConfigured, IncompleteStream, TimeoutError,
 )
@@ -100,17 +101,15 @@ class test_AsyncResult(AppCase):
         x = self.app.AsyncResult(uuid())
         x.backend = Mock(name='backend')
         x.backend.get_task_meta.return_value = {}
-        x.backend.wait_for.return_value = {
-            'status': states.SUCCESS, 'result': 84,
-        }
+        x.backend.wait_for_pending.return_value = 84
         x.parent = EagerResult(uuid(), KeyError('foo'), states.FAILURE)
         with self.assertRaises(KeyError):
             x.get(propagate=True)
-        self.assertFalse(x.backend.wait_for.called)
+        self.assertFalse(x.backend.wait_for_pending.called)
 
         x.parent = EagerResult(uuid(), 42, states.SUCCESS)
         self.assertEqual(x.get(propagate=True), 84)
-        self.assertTrue(x.backend.wait_for.called)
+        self.assertTrue(x.backend.wait_for_pending.called)
 
     def test_get_children(self):
         tid = uuid()
@@ -477,7 +476,7 @@ class MockAsyncResultSuccess(AsyncResult):
         return self.result
 
 
-class SimpleBackend(object):
+class SimpleBackend(SyncBackendMixin):
         ids = []
 
         def __init__(self, ids=[]):
@@ -676,10 +675,12 @@ class test_GroupResult(AppCase):
     def test_failed(self):
         self.assertFalse(self.ts.failed())
 
-    def test_maybe_reraise(self):
+    def test_maybe_throw(self):
         self.ts.results = [Mock(name='r1')]
-        self.ts.maybe_reraise()
-        self.ts.results[0].maybe_reraise.assert_called_with()
+        self.ts.maybe_throw()
+        self.ts.results[0].maybe_throw.assert_called_with(
+            callback=None, propagate=True,
+        )
 
     def test_join__on_message(self):
         with self.assertRaises(ImproperlyConfigured):