Browse Source

Complex LRUCache implementation

Ask Solem 14 năm trước cách đây
mục cha
commit
8f96b21689

+ 2 - 2
celery/__init__.py

@@ -24,10 +24,10 @@ def Celery(*args, **kwargs):
     return App(*args, **kwargs)
 
 if not os.environ.get("CELERY_NO_EVAL", False):
-    from celery.local import LocalProxy
+    from celery.local import Proxy
 
     def _get_current_app():
         from celery.app import current_app
         return current_app()
 
-    current_app = LocalProxy(_get_current_app)
+    current_app = Proxy(_get_current_app)

+ 2 - 2
celery/backends/__init__.py

@@ -1,5 +1,5 @@
 from celery import current_app
-from celery.local import LocalProxy
+from celery.local import Proxy
 from celery.utils import get_cls_by_name
 
 BACKEND_ALIASES = {
@@ -32,4 +32,4 @@ def get_backend_cls(backend=None, loader=None):
 
 
 # deprecate this.
-default_backend = LocalProxy(lambda: current_app.backend)
+default_backend = Proxy(lambda: current_app.backend)

+ 12 - 6
celery/backends/base.py

@@ -4,11 +4,11 @@ import time
 from datetime import timedelta
 
 from celery import states
+from celery.datastructures import LRUCache
 from celery.exceptions import TimeoutError, TaskRevokedError
 from celery.utils import timeutils
 from celery.utils.serialization import pickle, get_pickled_exception
 from celery.utils.serialization import get_pickleable_exception
-from celery.datastructures import LocalCache
 
 
 class BaseBackend(object):
@@ -175,7 +175,7 @@ class BaseDictBackend(BaseBackend):
 
     def __init__(self, *args, **kwargs):
         super(BaseDictBackend, self).__init__(*args, **kwargs)
-        self._cache = LocalCache(limit=kwargs.get("max_cached_results") or
+        self._cache = LRUCache(limit=kwargs.get("max_cached_results") or
                                  self.app.conf.CELERY_MAX_CACHED_RESULTS)
 
     def store_result(self, task_id, result, status, traceback=None, **kwargs):
@@ -208,8 +208,11 @@ class BaseDictBackend(BaseBackend):
             return meta["result"]
 
     def get_task_meta(self, task_id, cache=True):
-        if cache and task_id in self._cache:
-            return self._cache[task_id]
+        if cache:
+            try:
+                return self._cache[task_id]
+            except KeyError:
+                pass
 
         meta = self._get_task_meta_for(task_id)
         if cache and meta.get("status") == states.SUCCESS:
@@ -224,8 +227,11 @@ class BaseDictBackend(BaseBackend):
                                                         cache=False)
 
     def get_taskset_meta(self, taskset_id, cache=True):
-        if cache and taskset_id in self._cache:
-            return self._cache[taskset_id]
+        if cache:
+            try:
+                return self._cache[taskset_id]
+            except KeyError:
+                pass
 
         meta = self._restore_taskset(taskset_id)
         if cache and meta is not None:

+ 7 - 5
celery/backends/cache.py

@@ -2,7 +2,7 @@ from kombu.utils import cached_property
 
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
-from celery.datastructures import LocalCache
+from celery.datastructures import LRUCache
 
 _imp = [None]
 
@@ -36,7 +36,7 @@ def get_best_memcache(*args, **kwargs):
 class DummyClient(object):
 
     def __init__(self, *args, **kwargs):
-        self.cache = LocalCache(5000)
+        self.cache = LRUCache(limit=5000)
 
     def get(self, key, *args, **kwargs):
         return self.cache.get(key)
@@ -59,6 +59,7 @@ backends = {"memcache": lambda: get_best_memcache,
 
 
 class CacheBackend(KeyValueStoreBackend):
+    servers = None
 
     def __init__(self, expires=None, backend=None, options={}, **kwargs):
         super(CacheBackend, self).__init__(self, **kwargs)
@@ -66,10 +67,11 @@ class CacheBackend(KeyValueStoreBackend):
         self.options = dict(self.app.conf.CELERY_CACHE_BACKEND_OPTIONS,
                             **options)
 
-        backend = backend or self.app.conf.CELERY_CACHE_BACKEND
-        self.backend, _, servers = backend.partition("://")
+        self.backend = backend or self.app.conf.CELERY_CACHE_BACKEND
+        if self.backend:
+            self.backend, _, servers = self.backend.partition("://")
+            self.servers = servers.rstrip('/').split(";")
         self.expires = self.prepare_expires(expires, type=int)
-        self.servers = servers.rstrip('/').split(";")
         try:
             self.Client = backends[self.backend]()
         except KeyError:

+ 99 - 7
celery/datastructures.py

@@ -12,11 +12,12 @@ from __future__ import absolute_import
 
 import time
 import traceback
+import weakref
 
 from itertools import chain
 from Queue import Empty
 
-from celery.utils.compat import OrderedDict
+from celery.utils.compat import MutableMapping
 
 
 class AttributeDictMixin(object):
@@ -289,21 +290,112 @@ class LimitedSet(object):
         return self.chronologically[0]
 
 
-class LocalCache(OrderedDict):
+class DLL(object):
+    """Doubly Linked List."""
+    __slots__ = ("PREV", "NEXT", "value", "__weakref__")
+
+    def __init__(self, value=None):
+        self.PREV, self.NEXT, self.value = None, None, value
+
+    def __repr__(self):
+        return "<DLL: %r>" % (self.value, )
+
+    def iterate(self, sentinel=None):
+        node = self
+        while node is not sentinel:
+            yield node.value
+            node = node.NEXT
+
+
+class LRUCache(dict, MutableMapping):
     """Dictionary with a finite number of keys.
 
     Older items expires first.
 
     """
 
-    def __init__(self, limit=None):
-        super(LocalCache, self).__init__()
+    def __init__(self, items=None, limit=None):
+        dict.__init__(self)
         self.limit = limit
+        self._root = DLL()
+        self._map = {}
+        self.clear()
+        if items:
+            self.update(items)
+
+    def clear(self):
+        root = self._root
+        root.PREV = root.NEXT = root
+        self._map.clear()
+        dict.clear(self)
+
+    def _move_to_head(self, node):
+        root = self._root
+        prev = root.NEXT
+        if node is not prev:
+            nref = weakref.proxy(node)
+            node.NEXT, root.NEXT.PREV = root.NEXT, nref
+            root.NEXT, node.PREV = nref, root
+
+    def __getitem__(self, key):
+        value = self._get(key)
+        self._move_to_head(self._map[key])
+        return value
 
     def __setitem__(self, key, value):
-        while len(self) >= self.limit:
-            self.popitem(last=False)
-        super(LocalCache, self).__setitem__(key, value)
+        # remove least recently used key.
+        if self.limit and len(self) >= self.limit:
+            del(self[self._root.PREV.value])
+
+        if key in self._map:
+            node = self._map[key]
+            node.PREV.NEXT = node.NEXT
+            node.NEXT.PREV = node.PREV
+        else:
+            node = self._map[key] = DLL()
+            node.value = key
+        self._move_to_head(node)
+
+        dict.__setitem__(self, key, value)
+
+    def __delitem__(self, key):
+        dict.__delitem__(self, key)
+        node = self._map.pop(key)
+        node.PREV.NEXT = node.NEXT
+        node.NEXT.PREV = node.PREV
+
+    def _get(self, key):
+        return dict.__getitem__(self, key)
+
+    def __iter__(self):
+        return self._root.NEXT.iterate(self._root)
+
+    def keys(self):
+        return list(iter(self))
+
+    def iterkeys(self):
+        return iter(self)
+
+    def items(self):
+        return [(key, self._get(key)) for key in self]
+
+    def iteritems(self):
+        for key in self:
+            yield (key, self._get(key))
+
+    def __reduce__(self):
+        return (self.__class__, (self.items(), self.limit))
+
+    def __copy__(self):
+        fun, args = self.__reduce__()
+        return fun(*args)
+
+    get = MutableMapping.get
+    pop = MutableMapping.pop
+    update = MutableMapping.update
+    popitem = MutableMapping.popitem
+    setdefault = MutableMapping.setdefault
+
 
 
 class TokenBucket(object):

+ 2 - 2
celery/events/dumper.py

@@ -3,10 +3,10 @@ import sys
 from datetime import datetime
 
 from celery.app import app_or_default
-from celery.datastructures import LocalCache
+from celery.datastructures import LRUCache
 
 
-TASK_NAMES = LocalCache(0xFFF)
+TASK_NAMES = LRUCache(limit=0xFFF)
 
 HUMAN_TYPES = {"worker-offline": "shutdown",
                "worker-online": "started",

+ 3 - 3
celery/events/state.py

@@ -6,7 +6,7 @@ import heapq
 from threading import Lock
 
 from celery import states
-from celery.datastructures import AttributeDict, LocalCache
+from celery.datastructures import AttributeDict, LRUCache
 from celery.utils import kwdict
 
 #: Hartbeat expiry time in seconds.  The worker will be considered offline
@@ -169,8 +169,8 @@ class State(object):
 
     def __init__(self, callback=None,
             max_workers_in_memory=5000, max_tasks_in_memory=10000):
-        self.workers = LocalCache(max_workers_in_memory)
-        self.tasks = LocalCache(max_tasks_in_memory)
+        self.workers = LRUCache(limit=max_workers_in_memory)
+        self.tasks = LRUCache(limit=max_tasks_in_memory)
         self.event_callback = callback
         self.group_handlers = {"worker": self.worker_event,
                                "task": self.task_event}

+ 3 - 3
celery/local.py

@@ -6,12 +6,12 @@ def try_import(module):
         pass
 
 
-class LocalProxy(object):
-    """Code stolen from werkzeug.local.LocalProxy."""
+class Proxy(object):
+    """Code stolen from werkzeug.local.Proxy."""
     __slots__ = ('__local', '__dict__', '__name__')
 
     def __init__(self, local, name=None):
-        object.__setattr__(self, '_LocalProxy__local', local)
+        object.__setattr__(self, '_Proxy__local', local)
         object.__setattr__(self, '__name__', name)
 
     def _get_current_object(self):

+ 9 - 6
celery/log.py

@@ -14,6 +14,7 @@ except ImportError:
 
 from celery import signals
 from celery import current_app
+from celery.local import Proxy
 from celery.utils import LOG_LEVELS, isatty
 from celery.utils.compat import LoggerAdapter
 from celery.utils.compat import WatchedFileHandler
@@ -208,12 +209,14 @@ class Logging(object):
         return logger
 
 
-setup_logging_subsystem = current_app.log.setup_logging_subsystem
-get_default_logger = current_app.log.get_default_logger
-setup_logger = current_app.log.setup_logger
-setup_task_logger = current_app.log.setup_task_logger
-get_task_logger = current_app.log.get_task_logger
-redirect_stdouts_to_logger = current_app.log.redirect_stdouts_to_logger
+get_default_logger = Proxy(lambda: current_app.log.get_default_logger)
+setup_logger = Proxy(lambda: current_app.log.setup_logger)
+setup_task_logger = Proxy(lambda: current_app.log.setup_task_logger)
+get_task_logger = Proxy(lambda: current_app.log.get_task_logger)
+setup_logging_subsystem = Proxy(
+            lambda: current_app.log.setup_logging_subsystem)
+redirect_stdouts_to_logger = Proxy(
+            lambda: current_app.log.redirect_stdouts_to_logger)
 
 
 class LoggingProxy(object):

+ 7 - 6
celery/messaging.py

@@ -1,8 +1,9 @@
 from celery import current_app
+from celery.local import Proxy
 
-TaskPublisher = current_app.amqp.TaskPublisher
-ConsumerSet = current_app.amqp.ConsumerSet
-TaskConsumer = current_app.amqp.TaskConsumer
-establish_connection = current_app.broker_connection
-with_connection = current_app.with_default_connection
-get_consumer_set = current_app.amqp.get_task_consumer
+TaskPublisher = Proxy(lambda: current_app.amqp.TaskPublisher)
+ConsumerSet = Proxy(lambda: current_app.amqp.ConsumerSet)
+TaskConsumer = Proxy(lambda: current_app.amqp.TaskConsumer)
+establish_connection = Proxy(lambda: current_app.broker_connection)
+with_connection = Proxy(lambda: current_app.with_default_connection)
+get_consumer_set = Proxy(lambda: current_app.amqp.get_task_consumer)

+ 3 - 3
celery/tests/test_app/test_routes.py

@@ -124,14 +124,14 @@ class test_lookup_route(unittest.TestCase):
 class test_prepare(unittest.TestCase):
 
     def test_prepare(self):
-        from celery.datastructures import LocalCache
+        from celery.datastructures import LRUCache
         o = object()
         R = [{"foo": "bar"},
-                  "celery.datastructures.LocalCache",
+                  "celery.datastructures.LRUCache",
                   o]
         p = routes.prepare(R)
         self.assertIsInstance(p[0], routes.MapRoute)
-        self.assertIsInstance(maybe_promise(p[1]), LocalCache)
+        self.assertIsInstance(maybe_promise(p[1]), LRUCache)
         self.assertIs(p[2], o)
 
         self.assertEqual(routes.prepare(o), [o])

+ 2 - 0
celery/tests/test_backends/__init__.py

@@ -8,6 +8,8 @@ from celery.backends.cache import CacheBackend
 class TestBackends(unittest.TestCase):
 
     def test_get_backend_aliases(self):
+        from celery import current_app
+        print("CACHE BACKEND: %r" % (current_app.conf.CELERY_CACHE_BACKEND, ))
         expects = [("amqp", AMQPBackend),
                    ("cache", CacheBackend)]
         for expect_name, expect_cls in expects:

+ 11 - 11
celery/tests/test_task/test_result.py

@@ -6,7 +6,7 @@ from celery.result import AsyncResult, EagerResult, TaskSetResult
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
-from celery.tests.utils import unittest
+from celery.tests.utils import AppCase
 from celery.tests.utils import skip_if_quick
 
 
@@ -33,9 +33,9 @@ def make_mock_taskset(size=10):
     return [AsyncResult(task["id"]) for task in tasks]
 
 
-class TestAsyncResult(unittest.TestCase):
+class TestAsyncResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.task1 = mock_task("task1", states.SUCCESS, "the")
         self.task2 = mock_task("task2", states.SUCCESS, "quick")
         self.task3 = mock_task("task3", states.FAILURE, KeyError("brown"))
@@ -182,9 +182,9 @@ class SimpleBackend(object):
             return ((id, {"result": i}) for i, id in enumerate(self.ids))
 
 
-class TestTaskSetResult(unittest.TestCase):
+class TestTaskSetResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.size = 10
         self.ts = TaskSetResult(gen_unique_id(), make_mock_taskset(self.size))
 
@@ -299,9 +299,9 @@ class TestTaskSetResult(unittest.TestCase):
         self.assertEqual(self.ts.completed_count(), self.ts.total)
 
 
-class TestPendingAsyncResult(unittest.TestCase):
+class TestPendingAsyncResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.task = AsyncResult(gen_unique_id())
 
     def test_result(self):
@@ -310,7 +310,7 @@ class TestPendingAsyncResult(unittest.TestCase):
 
 class TestFailedTaskSetResult(TestTaskSetResult):
 
-    def setUp(self):
+    def setup(self):
         self.size = 11
         subtasks = make_mock_taskset(10)
         failed = mock_task("ts11", states.FAILURE, KeyError("Baz"))
@@ -348,9 +348,9 @@ class TestFailedTaskSetResult(TestTaskSetResult):
         self.assertTrue(self.ts.failed())
 
 
-class TestTaskSetPending(unittest.TestCase):
+class TestTaskSetPending(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.ts = TaskSetResult(gen_unique_id(), [
                                         AsyncResult(gen_unique_id()),
                                         AsyncResult(gen_unique_id())])
@@ -378,7 +378,7 @@ class RaisingTask(Task):
         raise KeyError("xy")
 
 
-class TestEagerResult(unittest.TestCase):
+class TestEagerResult(AppCase):
 
     def test_wait_raises(self):
         res = RaisingTask.apply(args=[3, 3])

+ 4 - 4
celery/tests/test_utils/test_datastructures.py

@@ -2,7 +2,7 @@ import sys
 from celery.tests.utils import unittest
 from Queue import Queue
 
-from celery.datastructures import ExceptionInfo, LocalCache
+from celery.datastructures import ExceptionInfo, LRUCache
 from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
@@ -133,15 +133,15 @@ class test_LimitedSet(unittest.TestCase):
         self.assertIn("LimitedSet(", repr(s))
 
 
-class test_LocalCache(unittest.TestCase):
+class test_LRUCache(unittest.TestCase):
 
     def test_expires(self):
         limit = 100
-        x = LocalCache(limit=limit)
+        x = LRUCache(limit=limit)
         slots = list(range(limit * 2))
         for i in slots:
             x[i] = i
-        self.assertListEqual(x.keys(), slots[limit:])
+        self.assertListEqual(x.keys(), list(reversed(slots[limit:])))
 
 
 class test_AttributeDict(unittest.TestCase):

+ 3 - 1
celery/tests/utils.py

@@ -41,7 +41,9 @@ class AppCase(unittest.TestCase):
 
     def setUp(self):
         from celery.app import current_app
-        self.app = self._current_app = current_app()
+        app = self.app = self._current_app = current_app()
+        app.backend.client.cache.clear()
+        app.backend._cache.clear()
         self.setup()
 
     def tearDown(self):