
Merge branch 'lrucache'

Conflicts:
	celery/__init__.py
	celery/backends/__init__.py
	celery/backends/base.py
	celery/backends/cache.py
	celery/datastructures.py
	celery/events/dumper.py
	celery/events/state.py
	celery/log.py
	celery/messaging.py
	celery/tests/test_task/test_result.py
	celery/tests/utils.py
Ask Solem committed 13 years ago
Parent commit: 0d0dbae0ce

+ 2 - 2
celery/__init__.py

@@ -26,10 +26,10 @@ def Celery(*args, **kwargs):
     return App(*args, **kwargs)
 
 if not os.environ.get("CELERY_NO_EVAL", False):
-    from .local import LocalProxy
+    from .local import Proxy
 
     def _get_current_app():
         from .app import current_app
         return current_app()
 
-    current_app = LocalProxy(_get_current_app)
+    current_app = Proxy(_get_current_app)

+ 10 - 13
celery/backends/__init__.py

@@ -1,8 +1,9 @@
 from __future__ import absolute_import
 
 from .. import current_app
-from ..local import LocalProxy
+from ..local import Proxy
 from ..utils import get_cls_by_name
+from ..utils.functional import memoize
 
 BACKEND_ALIASES = {
     "amqp": "celery.backends.amqp.AMQPBackend",
@@ -15,23 +16,19 @@ BACKEND_ALIASES = {
     "disabled": "celery.backends.base.DisabledBackend",
 }
 
-_backend_cache = {}
-
 
+@memoize(100)
 def get_backend_cls(backend=None, loader=None):
     """Get backend class by name/alias"""
     backend = backend or "disabled"
     loader = loader or current_app.loader
-    if backend not in _backend_cache:
-        aliases = dict(BACKEND_ALIASES, **loader.override_backends)
-        try:
-            _backend_cache[backend] = get_cls_by_name(backend, aliases)
-        except ValueError, exc:
-            raise ValueError("Unknown result backend: %r.  "
-                             "Did you spell it correctly?  (%s)" % (backend,
-                                                                    exc))
-    return _backend_cache[backend]
+    aliases = dict(BACKEND_ALIASES, **loader.override_backends)
+    try:
+        return get_cls_by_name(backend, aliases)
+    except ValueError, exc:
+        raise ValueError("Unknown result backend: %r.  "
+                         "Did you spell it correctly?  (%s)" % (backend, exc))
 
 
 # deprecate this.
-default_backend = LocalProxy(lambda: current_app.backend)
+default_backend = Proxy(lambda: current_app.backend)
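
With the manual `_backend_cache` dict gone, the caching behaviour of `get_backend_cls` now comes entirely from the new `memoize` decorator (added in celery/utils/functional.py further down). A minimal sketch of the effect, mirroring the updated test in celery/tests/test_backends/__init__.py:

    from celery.backends import get_backend_cls

    get_backend_cls.clear()      # reset the LRU cache and the hit/miss counters
    get_backend_cls("amqp")      # first lookup: counted as a miss
    get_backend_cls("amqp")      # same arguments: served from the cache, a hit
    assert get_backend_cls.misses == 1
    assert get_backend_cls.hits == 1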

+ 7 - 6
celery/backends/base.py

@@ -9,7 +9,7 @@ from datetime import timedelta
 from kombu import serialization
 
 from .. import states
-from ..datastructures import LocalCache
+from ..datastructures import LRUCache
 from ..exceptions import TimeoutError, TaskRevokedError
 from ..utils import timeutils
 from ..utils.serialization import (get_pickled_exception,
@@ -212,7 +212,7 @@ class BaseDictBackend(BaseBackend):
 
     def __init__(self, *args, **kwargs):
         super(BaseDictBackend, self).__init__(*args, **kwargs)
-        self._cache = LocalCache(limit=kwargs.get("max_cached_results") or
+        self._cache = LRUCache(limit=kwargs.get("max_cached_results") or
                                  self.app.conf.CELERY_MAX_CACHED_RESULTS)
 
     def store_result(self, task_id, result, status, traceback=None, **kwargs):
@@ -245,11 +245,11 @@ class BaseDictBackend(BaseBackend):
             return meta["result"]
 
     def get_task_meta(self, task_id, cache=True):
-        if cache and task_id in self._cache:
+        if cache:
             try:
                 return self._cache[task_id]
             except KeyError:
-                pass   # backend emptied in the meantime
+                pass
 
         meta = self._get_task_meta_for(task_id)
         if cache and meta.get("status") == states.SUCCESS:
@@ -264,11 +264,11 @@ class BaseDictBackend(BaseBackend):
                                                         cache=False)
 
     def get_taskset_meta(self, taskset_id, cache=True):
-        if cache and taskset_id in self._cache:
+        if cache:
             try:
                 return self._cache[taskset_id]
             except KeyError:
-                pass  # backend emptied in the meantime
+                pass
 
         meta = self._restore_taskset(taskset_id)
         if cache and meta is not None:
@@ -387,6 +387,7 @@ class KeyValueStoreBackend(BaseDictBackend):
 
 
 class DisabledBackend(BaseBackend):
+    _cache = {}   # tests need this attribute so they can reset the cache.
 
     def store_result(self, *args, **kwargs):
         pass

+ 7 - 5
celery/backends/cache.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import
 
-from ..datastructures import LocalCache
+from ..datastructures import LRUCache
 from ..exceptions import ImproperlyConfigured
 from ..utils import cached_property
 
@@ -38,7 +38,7 @@ def get_best_memcache(*args, **kwargs):
 class DummyClient(object):
 
     def __init__(self, *args, **kwargs):
-        self.cache = LocalCache(5000)
+        self.cache = LRUCache(limit=5000)
 
     def get(self, key, *args, **kwargs):
         return self.cache.get(key)
@@ -61,6 +61,7 @@ backends = {"memcache": lambda: get_best_memcache,
 
 
 class CacheBackend(KeyValueStoreBackend):
+    servers = None
 
     def __init__(self, expires=None, backend=None, options={}, **kwargs):
         super(CacheBackend, self).__init__(self, **kwargs)
@@ -68,10 +69,11 @@ class CacheBackend(KeyValueStoreBackend):
         self.options = dict(self.app.conf.CELERY_CACHE_BACKEND_OPTIONS,
                             **options)
 
-        backend = backend or self.app.conf.CELERY_CACHE_BACKEND
-        self.backend, _, servers = backend.partition("://")
+        self.backend = backend or self.app.conf.CELERY_CACHE_BACKEND
+        if self.backend:
+            self.backend, _, servers = self.backend.partition("://")
+            self.servers = servers.rstrip('/').split(";")
         self.expires = self.prepare_expires(expires, type=int)
-        self.servers = servers.rstrip('/').split(";")
         try:
             self.Client = backends[self.backend]()
         except KeyError:
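
The backend string is now only parsed when one is actually set (the new `servers = None` class attribute covers the unset case); the parsing itself is unchanged. A quick sketch of what the partition/split above does to a typical CELERY_CACHE_BACKEND value, using plain string operations rather than the class itself:

    url = "memcache://127.0.0.1:11211;127.0.0.2:11211/"
    backend, _, servers = url.partition("://")
    assert backend == "memcache"
    assert servers.rstrip("/").split(";") == ["127.0.0.1:11211",
                                              "127.0.0.2:11211"]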

+ 22 - 14
celery/datastructures.py

@@ -18,7 +18,7 @@ from itertools import chain
 from Queue import Empty
 from threading import RLock
 
-from .utils.compat import OrderedDict
+from .utils.compat import UserDict, OrderedDict
 
 
 class AttributeDictMixin(object):
@@ -291,27 +291,35 @@ class LimitedSet(object):
         return self.chronologically[0]
 
 
-class LocalCache(OrderedDict):
-    """Dictionary with a finite number of keys.
+class LRUCache(UserDict):
+    """LRU Cache implementation using a doubly linked list to track access.
 
-    Older items expires first.
+    :keyword limit: The maximum number of keys to keep in the cache.
+        When a new key is inserted and the limit has been exceeded,
+        the *Least Recently Used* key will be discarded from the
+        cache.
 
     """
 
     def __init__(self, limit=None):
-        super(LocalCache, self).__init__()
         self.limit = limit
-        self.lock = RLock()
+        self.mutex = RLock()
+        self.data = OrderedDict()
+
+    def __getitem__(self, key):
+        with self.mutex:
+            # pop and re-insert so the key moves to the end of the
+            # OrderedDict, marking it as the most recently used.
+            value = self[key] = self.data.pop(key)
+            return value
 
     def __setitem__(self, key, value):
-        with self.lock:
-            while len(self) >= self.limit:
-                self.popitem(last=False)
-            super(LocalCache, self).__setitem__(key, value)
-
-    def pop(self, key, *args):
-        with self.lock:
-            super(LocalCache, self).pop(key, *args)
+        with self.mutex:
+            if self.limit and len(self.data) >= self.limit:
+                # at capacity: discard the least recently used key,
+                # i.e. the first entry in the OrderedDict.
+                self.data.pop(iter(self.data).next())
+            self.data[key] = value
+
+    def __iter__(self):
+        return self.data.iterkeys()
 
 
 class TokenBucket(object):
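
A short usage sketch of the new class (Python 2, matching the code above); the read access in the middle is what distinguishes LRU eviction from the old insertion-order expiry of LocalCache:

    from celery.datastructures import LRUCache

    cache = LRUCache(limit=3)
    cache["a"], cache["b"], cache["c"] = 1, 2, 3
    cache["a"]                  # reading "a" marks it as most recently used
    cache["d"] = 4              # over the limit: "b", the LRU key, is evicted
    assert cache.keys() == ["c", "a", "d"]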

+ 2 - 2
celery/events/dumper.py

@@ -5,10 +5,10 @@ import sys
 from datetime import datetime
 
 from ..app import app_or_default
-from ..datastructures import LocalCache
+from ..datastructures import LRUCache
 
 
-TASK_NAMES = LocalCache(0xFFF)
+TASK_NAMES = LRUCache(limit=0xFFF)
 
 HUMAN_TYPES = {"worker-offline": "shutdown",
                "worker-online": "started",

+ 3 - 3
celery/events/state.py

@@ -7,7 +7,7 @@ import heapq
 from threading import Lock
 
 from .. import states
-from ..datastructures import AttributeDict, LocalCache
+from ..datastructures import AttributeDict, LRUCache
 from ..utils import kwdict
 
 #: Heartbeat expiry time in seconds.  The worker will be considered offline
@@ -173,8 +173,8 @@ class State(object):
 
     def __init__(self, callback=None,
             max_workers_in_memory=5000, max_tasks_in_memory=10000):
-        self.workers = LocalCache(max_workers_in_memory)
-        self.tasks = LocalCache(max_tasks_in_memory)
+        self.workers = LRUCache(limit=max_workers_in_memory)
+        self.tasks = LRUCache(limit=max_tasks_in_memory)
         self.event_callback = callback
         self.group_handlers = {"worker": self.worker_event,
                                "task": self.task_event}

+ 3 - 3
celery/local.py

@@ -6,12 +6,12 @@ def try_import(module):
         pass
 
 
-class LocalProxy(object):
-    """Code stolen from werkzeug.local.LocalProxy."""
+class Proxy(object):
+    """Code stolen from werkzeug.local.Proxy."""
     __slots__ = ('__local', '__dict__', '__name__')
 
     def __init__(self, local, name=None):
-        object.__setattr__(self, '_LocalProxy__local', local)
+        object.__setattr__(self, '_Proxy__local', local)
         object.__setattr__(self, '__name__', name)
 
     def _get_current_object(self):
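
The rename is mechanical, but the behaviour is worth spelling out: a Proxy wraps a callable and only invokes it when the proxied object is actually needed, which is what lets module-level names such as current_app resolve lazily. A minimal sketch using only what the class is shown to provide here (`_get_current_object`, plus the attribute forwarding relied on elsewhere in this commit); `_get_config` is a hypothetical stand-in:

    from celery.local import Proxy

    def _get_config():
        # stand-in for something expensive such as _get_current_app()
        return {"BROKER_HOST": "localhost"}

    config = Proxy(_get_config)             # nothing is evaluated yet
    real = config._get_current_object()     # first real use invokes _get_config()
    assert real["BROKER_HOST"] == "localhost"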

+ 9 - 6
celery/log.py

@@ -14,6 +14,7 @@ except ImportError:
 
 from . import current_app
 from . import signals
+from .local import Proxy
 from .utils import LOG_LEVELS, isatty
 from .utils.compat import LoggerAdapter, WatchedFileHandler
 from .utils.encoding import safe_str
@@ -219,12 +220,14 @@ class Logging(object):
         return logger
 
 
-setup_logging_subsystem = current_app.log.setup_logging_subsystem
-get_default_logger = current_app.log.get_default_logger
-setup_logger = current_app.log.setup_logger
-setup_task_logger = current_app.log.setup_task_logger
-get_task_logger = current_app.log.get_task_logger
-redirect_stdouts_to_logger = current_app.log.redirect_stdouts_to_logger
+get_default_logger = Proxy(lambda: current_app.log.get_default_logger)
+setup_logger = Proxy(lambda: current_app.log.setup_logger)
+setup_task_logger = Proxy(lambda: current_app.log.setup_task_logger)
+get_task_logger = Proxy(lambda: current_app.log.get_task_logger)
+setup_logging_subsystem = Proxy(
+            lambda: current_app.log.setup_logging_subsystem)
+redirect_stdouts_to_logger = Proxy(
+            lambda: current_app.log.redirect_stdouts_to_logger)
 
 
 class LoggingProxy(object):
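
Why the module-level shortcuts became proxies: the old direct assignments captured bound methods of whatever app was current when celery.log was first imported, while a Proxy re-resolves current_app.log on every use. Assuming Proxy forwards calls the way werkzeug's LocalProxy does (which the drop-in use here and in celery/messaging.py relies on), the shortcuts keep working as plain callables:

    from celery.log import get_default_logger

    # the lambda inside the Proxy is only evaluated here, against the app
    # that is current *now*, not the one current at import time.
    logger = get_default_logger()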

+ 7 - 6
celery/messaging.py

@@ -1,8 +1,9 @@
 from . import current_app
+from .local import Proxy
 
-TaskPublisher = current_app.amqp.TaskPublisher
-ConsumerSet = current_app.amqp.ConsumerSet
-TaskConsumer = current_app.amqp.TaskConsumer
-establish_connection = current_app.broker_connection
-with_connection = current_app.with_default_connection
-get_consumer_set = current_app.amqp.get_task_consumer
+TaskPublisher = Proxy(lambda: current_app.amqp.TaskPublisher)
+ConsumerSet = Proxy(lambda: current_app.amqp.ConsumerSet)
+TaskConsumer = Proxy(lambda: current_app.amqp.TaskConsumer)
+establish_connection = Proxy(lambda: current_app.broker_connection)
+with_connection = Proxy(lambda: current_app.with_default_connection)
+get_consumer_set = Proxy(lambda: current_app.amqp.get_task_consumer)

+ 3 - 3
celery/tests/test_app/test_routes.py

@@ -124,14 +124,14 @@ class test_lookup_route(unittest.TestCase):
 class test_prepare(unittest.TestCase):
 
     def test_prepare(self):
-        from celery.datastructures import LocalCache
+        from celery.datastructures import LRUCache
         o = object()
         R = [{"foo": "bar"},
-                  "celery.datastructures.LocalCache",
+                  "celery.datastructures.LRUCache",
                   o]
         p = routes.prepare(R)
         self.assertIsInstance(p[0], routes.MapRoute)
-        self.assertIsInstance(maybe_promise(p[1]), LocalCache)
+        self.assertIsInstance(maybe_promise(p[1]), LRUCache)
         self.assertIs(p[2], o)
 
         self.assertEqual(routes.prepare(o), [o])

+ 6 - 3
celery/tests/test_backends/__init__.py

@@ -10,6 +10,7 @@ from celery.backends.cache import CacheBackend
 class TestBackends(unittest.TestCase):
 
     def test_get_backend_aliases(self):
+        from celery import current_app
         expects = [("amqp", AMQPBackend),
                    ("cache", CacheBackend)]
         for expect_name, expect_cls in expects:
@@ -17,11 +18,13 @@ class TestBackends(unittest.TestCase):
                                   expect_cls)
 
     def test_get_backend_cache(self):
-        backends._backend_cache = {}
+        backends.get_backend_cls.clear()
+        hits = backends.get_backend_cls.hits
+        misses = backends.get_backend_cls.misses
         backends.get_backend_cls("amqp")
-        self.assertIn("amqp", backends._backend_cache)
+        self.assertEqual(backends.get_backend_cls.misses, misses + 1)
         amqp_backend = backends.get_backend_cls("amqp")
-        self.assertIs(amqp_backend, backends._backend_cache["amqp"])
+        self.assertEqual(backends.get_backend_cls.hits, hits + 1)
 
     def test_unknown_backend(self):
         with self.assertRaises(ValueError):

+ 12 - 12
celery/tests/test_task/test_result.py

@@ -6,7 +6,7 @@ from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
 from celery.exceptions import TimeoutError
 from celery.task.base import Task
 
-from celery.tests.utils import unittest
+from celery.tests.utils import AppCase
 from celery.tests.utils import skip_if_quick
 
 
@@ -33,9 +33,9 @@ def make_mock_taskset(size=10):
     return [AsyncResult(task["id"]) for task in tasks]
 
 
-class TestAsyncResult(unittest.TestCase):
+class TestAsyncResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.task1 = mock_task("task1", states.SUCCESS, "the")
         self.task2 = mock_task("task2", states.SUCCESS, "quick")
         self.task3 = mock_task("task3", states.FAILURE, KeyError("brown"))
@@ -144,7 +144,7 @@ class TestAsyncResult(unittest.TestCase):
         self.assertFalse(AsyncResult(uuid()).ready())
 
 
-class test_ResultSet(unittest.TestCase):
+class test_ResultSet(AppCase):
 
     def test_add_discard(self):
         x = ResultSet([])
@@ -208,9 +208,9 @@ class SimpleBackend(object):
             return ((id, {"result": i}) for i, id in enumerate(self.ids))
 
 
-class TestTaskSetResult(unittest.TestCase):
+class TestTaskSetResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.size = 10
         self.ts = TaskSetResult(uuid(), make_mock_taskset(self.size))
 
@@ -325,9 +325,9 @@ class TestTaskSetResult(unittest.TestCase):
         self.assertEqual(self.ts.completed_count(), self.ts.total)
 
 
-class TestPendingAsyncResult(unittest.TestCase):
+class TestPendingAsyncResult(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.task = AsyncResult(uuid())
 
     def test_result(self):
@@ -336,7 +336,7 @@ class TestPendingAsyncResult(unittest.TestCase):
 
 class TestFailedTaskSetResult(TestTaskSetResult):
 
-    def setUp(self):
+    def setup(self):
         self.size = 11
         subtasks = make_mock_taskset(10)
         failed = mock_task("ts11", states.FAILURE, KeyError("Baz"))
@@ -374,9 +374,9 @@ class TestFailedTaskSetResult(TestTaskSetResult):
         self.assertTrue(self.ts.failed())
 
 
-class TestTaskSetPending(unittest.TestCase):
+class TestTaskSetPending(AppCase):
 
-    def setUp(self):
+    def setup(self):
         self.ts = TaskSetResult(uuid(), [
                                         AsyncResult(uuid()),
                                         AsyncResult(uuid())])
@@ -404,7 +404,7 @@ class RaisingTask(Task):
         raise KeyError("xy")
 
 
-class TestEagerResult(unittest.TestCase):
+class TestEagerResult(AppCase):
 
     def test_wait_raises(self):
         res = RaisingTask.apply(args=[3, 3])

+ 22 - 5
celery/tests/test_utils/test_datastructures.py

@@ -2,7 +2,7 @@ import sys
 from celery.tests.utils import unittest
 from Queue import Queue
 
-from celery.datastructures import ExceptionInfo, LocalCache
+from celery.datastructures import ExceptionInfo, LRUCache
 from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
@@ -133,15 +133,32 @@ class test_LimitedSet(unittest.TestCase):
         self.assertIn("LimitedSet(", repr(s))
 
 
-class test_LocalCache(unittest.TestCase):
+class test_LRUCache(unittest.TestCase):
 
     def test_expires(self):
         limit = 100
-        x = LocalCache(limit=limit)
-        slots = list(range(limit * 2))
+        x = LRUCache(limit=limit)
+        slots = list(xrange(limit * 2))
         for i in slots:
             x[i] = i
-        self.assertListEqual(x.keys(), slots[limit:])
+        self.assertListEqual(x.keys(), list(slots[limit:]))
+
+    def test_least_recently_used(self):
+        x = LRUCache(3)
+
+        x[1], x[2], x[3] = 1, 2, 3
+        self.assertEqual(x.keys(), [1, 2, 3])
+
+        x[4], x[5] = 4, 5
+        self.assertEqual(x.keys(), [3, 4, 5])
+
+        # access 3, which makes it the last used key.
+        x[3]
+        x[6] = 6
+        self.assertEqual(x.keys(), [5, 3, 6])
+
+        x[7] = 7
+        self.assertEqual(x.keys(), [3, 6, 7])
 
 
 class test_AttributeDict(unittest.TestCase):

+ 6 - 1
celery/tests/utils.py

@@ -58,7 +58,12 @@ class AppCase(unittest.TestCase):
 
     def setUp(self):
         from ..app import current_app
-        self.app = self._current_app = current_app()
+        from ..backends.cache import CacheBackend, DummyClient
+        app = self.app = self._current_app = current_app()
+        if isinstance(app.backend, CacheBackend):
+            if isinstance(app.backend.client, DummyClient):
+                app.backend.client.cache.clear()
+        app.backend._cache.clear()
         self.setup()
 
     def tearDown(self):

+ 42 - 0
celery/utils/functional.py

@@ -0,0 +1,42 @@
+from __future__ import absolute_import, with_statement
+
+from functools import wraps
+from threading import Lock
+
+from celery.datastructures import LRUCache
+
+KEYWORD_MARK = object()
+
+
+def memoize(maxsize=None, Cache=LRUCache):
+
+    def _memoize(fun):
+        mutex = Lock()
+        cache = Cache(limit=maxsize)
+
+        @wraps(fun)
+        def _M(*args, **kwargs):
+            key = args + (KEYWORD_MARK, ) + tuple(sorted(kwargs.iteritems()))
+            try:
+                with mutex:
+                    value = cache[key]
+            except KeyError:
+                value = fun(*args, **kwargs)
+                _M.misses += 1
+                with mutex:
+                    cache[key] = value
+            else:
+                _M.hits += 1
+            return value
+
+        def clear():
+            """Clear the cache and reset cache statistics."""
+            cache.clear()
+            _M.hits = _M.misses = 0
+
+        _M.hits = _M.misses = 0
+        _M.clear = clear
+        _M.original_func = fun
+        return _M
+
+    return _memoize
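
A small usage sketch of the new decorator (the `expensive` function is hypothetical; Python 2 as above). Keyword arguments are folded into the cache key via KEYWORD_MARK, and the hits/misses counters plus clear() are exactly what the backend tests above rely on:

    from celery.utils.functional import memoize

    @memoize(maxsize=2)
    def expensive(x, y=0):
        return x + y

    expensive(1)                # miss: computed and cached
    expensive(1)                # hit: returned from the LRU cache
    expensive(2, y=3)           # miss: kwargs are part of the cache key
    assert (expensive.hits, expensive.misses) == (1, 2)
    expensive.clear()           # resets both the cache and the counters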

+ 2 - 1
celery/worker/__init__.py

@@ -255,13 +255,14 @@ class WorkController(object):
             raise
         except SystemExit:
             self.stop()
+            raise
         except:
             self.stop()
             try:
                 raise
             except TypeError:
                 # eventlet borks here saying that the exception is None(?)
-                pass
+                sys.exit()
 
     def process_task(self, request):
         """Process task by sending it to the pool of workers."""