Forráskód Böngészése

Result: Async: Buffer results received out of band

Ask Solem 8 éve
szülő
commit
7a61ab73ed

+ 3 - 1
celery/backends/amqp.py

@@ -135,7 +135,9 @@ class AMQPBackend(base.Backend, AsyncBackendMixin):
         })
         self.result_consumer = self.ResultConsumer(
             self, self.app, self.accept,
-            self._pending_results, self._weak_pending_results)
+            self._pending_results, self._weak_pending_results,
+            self._pending_messages,
+        )
         if register_after_fork is not None:
             register_after_fork(self, _on_after_fork_cleanup_backend)
 

+ 17 - 6
celery/backends/async.py

@@ -18,7 +18,7 @@ from kombu.utils import cached_property
 
 from celery import states
 from celery.exceptions import TimeoutError
-from celery.five import monotonic
+from celery.five import Empty, monotonic
 
 drainers = {}
 
@@ -131,6 +131,13 @@ class AsyncBackendMixin(object):
             yield node.id, node._cache
 
     def add_pending_result(self, result, weak=False):
+        try:
+            meta = self._pending_messages.pop(result.id)
+        except Empty:
+            pass
+        else:
+            result._maybe_set_cache(meta)
+            return result
         if weak:
             dest, alt = self._weak_pending_results, self._pending_results
         else:
@@ -175,13 +182,15 @@ class AsyncBackendMixin(object):
 
 class BaseResultConsumer(object):
 
-    def __init__(self, backend, app, accept, pending_results,
-                 weak_pending_results):
+    def __init__(self, backend, app, accept,
+                 pending_results, weak_pending_results,
+                 pending_messages):
         self.backend = backend
         self.app = app
         self.accept = accept
         self._pending_results = pending_results
         self._weak_pending_results = weak_pending_results
+        self._pending_messages = pending_messages
         self.on_message = None
         self.buckets = WeakKeyDictionary()
         self.drainer = drainers[detect_environment()](self)
@@ -240,13 +249,15 @@ class BaseResultConsumer(object):
         if self.on_message:
             self.on_message(meta)
         if meta['status'] in states.READY_STATES:
+            task_id = meta['task_id']
             try:
-                result = self._pending_results[meta['task_id']]
+                result = self._pending_results[task_id]
             except KeyError:
                 try:
-                    result = self._weak_pending_results[meta['task_id']]
+                    result = self._weak_pending_results[task_id]
                 except KeyError:
-                    return
+                    # send to BufferMapping
+                    self._pending_messages.append(task_id, meta)
             result._maybe_set_cache(meta)
             buckets = self.buckets
             try:

+ 4 - 0
celery/backends/base.py

@@ -30,6 +30,7 @@ from kombu.utils.url import maybe_sanitize_url
 from celery import states
 from celery import current_app, group, maybe_signature
 from celery.app import current_task
+from celery.datastructures import BufferMapping
 from celery.exceptions import ChordError, TimeoutError, TaskRevokedError
 from celery.five import items
 from celery.result import (
@@ -50,6 +51,8 @@ PY3 = sys.version_info >= (3, 0)
 
 logger = get_logger(__name__)
 
+MESSAGE_BUFFER_MAX = 8192
+
 
 def unpickle_backend(cls, args, kwargs):
     """Return an unpickled backend."""
@@ -111,6 +114,7 @@ class Backend(object):
         )
         self._pending_results = {}
         self._weak_pending_results = WeakValueDictionary()
+        self._pending_messages = BufferMapping(MESSAGE_BUFFER_MAX)
         self.url = url
 
     def as_uri(self, include_password=False):

+ 3 - 1
celery/backends/redis.py

@@ -147,7 +147,9 @@ class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
             else ((), ()))
         self.result_consumer = self.ResultConsumer(
             self, self.app, self.accept,
-            self._pending_results, self._weak_pending_results)
+            self._pending_results, self._weak_pending_results,
+            self._pending_messages,
+        )
 
     def _params_from_url(self, url, defaults):
         scheme, host, port, user, password, path, query = _parse_url(url)

+ 208 - 3
celery/datastructures.py

@@ -12,16 +12,17 @@ import sys
 import time
 
 from collections import (
-    Callable, Mapping, MutableMapping, MutableSet, defaultdict,
+    Callable, Mapping, MutableMapping, MutableSet, Sequence,
+    OrderedDict as _OrderedDict, defaultdict, deque,
 )
 from heapq import heapify, heappush, heappop
-from itertools import chain
+from itertools import chain, count
 
 from billiard.einfo import ExceptionInfo  # noqa
 from kombu.utils.encoding import safe_str, bytes_to_str
 from kombu.utils.limits import TokenBucket  # noqa
 
-from celery.five import items, python_2_unicode_compatible, values
+from celery.five import Empty, items, python_2_unicode_compatible, values
 from celery.utils.functional import LRUCache, first, uniq  # noqa
 from celery.utils.text import match_case
 
@@ -815,3 +816,207 @@ class LimitedSet(object):
         """Compute how much is heap bigger than data [percents]."""
         return len(self._heap) * 100 / max(len(self._data), 1) - 100
 MutableSet.register(LimitedSet)
+
+
+if not hasattr(_OrderedDict, 'move_to_end'):
+
+    class OrderedDict(_OrderedDict):
+        def move_to_end(self, key, last=True):
+            link = self._OrderedDict__map[key]
+            link_prev = link[0]
+            link_next = link[1]
+            link_prev[1] = link_next
+            link_next[0] = link_prev
+            root = self._OrderedDict__root
+            if last:
+                last = root[0]
+                link[0] = last
+                link[1] = root
+                last[1] = root[0] = link
+            else:
+                first = root[1]
+                link[0] = root
+                link[1] = first
+                root[1] = first[0] = link
+
+        def _LRUkey(self):
+            return self._OrderedDict__root[1][2]
+
+else:  # pragma: no cover
+
+    class OrderedDict(object):  # noqa
+
+        def _LRUkey(self):
+            return self._OrderedDict__root.next.key
+
+
+class Evictable(object):
+
+    Empty = Empty
+
+    def evict(self):
+        """Force evict until maxsize is enforced."""
+        self._evict(range=count)
+
+    def _evict(self, limit=100, range=range):
+        try:
+            [self._evict1() for _ in range(limit)]
+        except IndexError:
+            pass
+
+    def _evict1(self):
+        if self._evictcount <= self.maxsize:
+            raise IndexError()
+        try:
+            self._pop_to_evict()
+        except self.Empty:
+            raise IndexError()
+
+
+class Messagebuffer(Evictable):
+
+    Empty = Empty
+
+    def __init__(self, maxsize, iterable=None, deque=deque):
+        self.maxsize = maxsize
+        self.data = deque(iterable or [])
+        self._append = self.data.append
+        self._pop = self.data.popleft
+        self._len = self.data.__len__
+        self._extend = self.data.extend
+
+    def append(self, item):
+        self._append(item)
+        self.maxsize and self._evict()
+
+    def extend(self, it):
+        self._extend(it)
+        self.maxsize and self._evict()
+
+    def pop(self, *default):
+        try:
+            return self._pop()
+        except IndexError:
+            if default:
+                return default[0]
+            raise self.Empty()
+
+    def _pop_to_evict(self):
+        return self.pop()
+
+    def __repr__(self):
+        return '<{0}: {1}/{2}>'.format(
+            type(self).__name__, len(self), self.maxsize,
+        )
+
+    def __iter__(self):
+        while 1:
+            try:
+                yield self._pop()
+            except IndexError:
+                break
+
+    def __len__(self):
+        return self._len()
+
+    def __contains__(self, item):
+        return item in self.data
+
+    def __reversed__(self):
+        return reversed(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]
+
+    @property
+    def _evictcount(self):
+        return len(self)
+Sequence.register(Messagebuffer)
+
+
+class BufferMapping(OrderedDict, Evictable):
+
+    Buffer = Messagebuffer
+    Empty = Empty
+
+    maxsize = None
+
+    def __init__(self, maxsize, iterable=None, bufmaxsize=1000):
+        self.maxsize = maxsize
+        self.bufmaxsize = 1000
+        super(BufferMapping, self).__init__(iterable or ())
+        self.total = sum(len(buf) for buf in values(self))
+
+    def pop(self, key, *default):
+        item, throw = None, False
+        try:
+            buf = self[key]
+        except KeyError:
+            throw = True
+        else:
+            try:
+                item = buf.pop()
+                self.total -= 1
+            except self.Empty:
+                throw = True
+            else:
+                self.move_to_end(key)  # least recently used.
+
+        if throw:
+            if default:
+                return default[0]
+            raise self.Empty()
+        return item
+
+    def get_or_create(self, key):
+        try:
+            return self[key]
+        except KeyError:
+            buf = self[key] = self.Buffer(maxsize=self.bufmaxsize)
+            return buf
+
+    def discard(self, key, *default):
+        super(BufferMapping, self).pop(key, *default)
+
+    def _LRUpop(self, *default):
+        return self[self._LRUkey()].pop(*default)
+
+    def _pop_to_evict(self):
+        for i in range(100):
+            key = self._LRUkey()
+            buf = self[key]
+            try:
+                buf.pop()
+            except (IndexError, self.Empty):
+                # buffer empty, remove it from mapping.
+                self.discard(key)
+            else:
+                # we removed one item
+                self.total -= 1
+                # if buffer is empty now, remove it from mapping.
+                if not len(buf):
+                    self.discard(key)
+                else:
+                    # move to least recently used.
+                    self.move_to_end(key)
+                break
+
+    def append(self, key, item):
+        self.get_or_create(key).append(item)
+        self.total += 1
+        self.move_to_end(key)   # least recently used.
+        self.maxsize and self._evict()
+
+    def extend(self, key, it):
+        self.get_or_create(key).extend(it)
+        self.total += len(it)
+        self.maxsize and self._evict()
+
+    def __repr__(self):
+        return '<{0}: {1}/{2}>'.format(
+            type(self).__name__, self.total, self.maxsize,
+        )
+
+    @property
+    def _evictcount(self):
+        return self.total

+ 117 - 2
celery/tests/utils/test_datastructures.py

@@ -9,11 +9,13 @@ from billiard.einfo import ExceptionInfo
 from time import time
 
 from celery.datastructures import (
-    LimitedSet,
     AttributeDict,
-    DictAttribute,
+    BufferMapping,
     ConfigurationView,
     DependencyGraph,
+    DictAttribute,
+    LimitedSet,
+    Messagebuffer,
 )
 from celery.five import WhateverIO, items
 from celery.utils.objects import Bunch
@@ -419,3 +421,116 @@ class test_DependencyGraph(Case):
         s = WhateverIO()
         self.graph1().to_dot(s)
         self.assertTrue(s.getvalue())
+
+
+class test_Messagebuffer(Case):
+
+    def assert_size_and_first(self, buf, size, expected_first_item):
+        self.assertEqual(len(buf), size)
+        self.assertEqual(buf.pop(), expected_first_item)
+
+    def test_append_limited(self):
+        b = Messagebuffer(10)
+        for i in range(20):
+            b.append(i)
+        self.assert_size_and_first(b, 10, 10)
+
+    def test_append_unlimited(self):
+        b = Messagebuffer(None)
+        for i in range(20):
+            b.append(i)
+        self.assert_size_and_first(b, 20, 0)
+
+    def test_extend_limited(self):
+        b = Messagebuffer(10)
+        b.extend(list(range(20)))
+        self.assert_size_and_first(b, 10, 10)
+
+    def test_extend_unlimited(self):
+        b = Messagebuffer(None)
+        b.extend(list(range(20)))
+        self.assert_size_and_first(b, 20, 0)
+
+    def test_extend_eviction_time_limited(self):
+        b = Messagebuffer(3000)
+        b.extend(range(10000))
+        self.assertGreater(len(b), 3000)
+        b.evict()
+        self.assertEqual(len(b), 3000)
+
+    def test_pop_empty_with_default(self):
+        b = Messagebuffer(10)
+        sentinel = object()
+        self.assertIs(b.pop(sentinel), sentinel)
+
+    def test_pop_empty_no_default(self):
+        b = Messagebuffer(10)
+        with self.assertRaises(b.Empty):
+            b.pop()
+
+    def test_repr(self):
+        self.assertTrue(repr(Messagebuffer(10, [1, 2, 3])))
+
+    def test_iter(self):
+        b = Messagebuffer(10, list(range(10)))
+        self.assertEqual(len(b), 10)
+        for i, item in enumerate(b):
+            self.assertEqual(item, i)
+        self.assertEqual(len(b), 0)
+
+    def test_contains(self):
+        b = Messagebuffer(10, list(range(10)))
+        self.assertIn(5, b)
+
+    def test_reversed(self):
+        self.assertEqual(
+            list(reversed(Messagebuffer(10, list(range(10))))),
+            list(reversed(range(10))),
+        )
+
+    def test_getitem(self):
+        b = Messagebuffer(10, list(range(10)))
+        for i in range(10):
+            self.assertEqual(b[i], i)
+
+
+class test_BufferMapping(Case):
+
+    def test_append_limited(self):
+        b = BufferMapping(10)
+        for i in range(20):
+            b.append(i, i)
+        self.assert_size_and_first(b, 10, 10)
+
+    def assert_size_and_first(self, buf, size, expected_first_item):
+        self.assertEqual(buf.total, size)
+        self.assertEqual(buf._LRUpop(), expected_first_item)
+
+    def test_append_unlimited(self):
+        b = BufferMapping(None)
+        for i in range(20):
+            b.append(i, i)
+        self.assert_size_and_first(b, 20, 0)
+
+    def test_extend_limited(self):
+        b = BufferMapping(10)
+        b.extend(1, list(range(20)))
+        self.assert_size_and_first(b, 10, 10)
+
+    def test_extend_unlimited(self):
+        b = BufferMapping(None)
+        b.extend(1, list(range(20)))
+        self.assert_size_and_first(b, 20, 0)
+
+    def test_pop_empty_with_default(self):
+        b = BufferMapping(10)
+        sentinel = object()
+        self.assertIs(b.pop(1, sentinel), sentinel)
+
+    def test_pop_empty_no_default(self):
+        b = BufferMapping(10)
+        with self.assertRaises(b.Empty):
+            b.pop(1)
+
+    def test_repr(self):
+        self.assertTrue(repr(Messagebuffer(10, [1, 2, 3])))