Explorar o código

Use @cached_property

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
331166fa22

+ 68 - 21
celery/app/amqp.py

@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 """
 celery.app.amqp
 ===============
@@ -10,21 +11,27 @@ AMQ related functionality.
 """
 from datetime import datetime, timedelta
 
+from kombu import BrokerConnection
+from kombu import compat as messaging
+
 from celery import routes
 from celery import signals
-from celery.utils import gen_unique_id, textindent
+from celery.utils import gen_unique_id, textindent, cached_property
 from celery.utils.compat import UserDict
 
-from kombu import compat as messaging
-from kombu import BrokerConnection
-
+#: List of known options to a Kombu producers send method.
+#: Used to extract the message related options out of any `dict`.
 MSG_OPTIONS = ("mandatory", "priority", "immediate",
                "routing_key", "serializer", "delivery_mode",
                "compression")
+
+#: Human readable queue declaration.
 QUEUE_FORMAT = """
 . %(name)s -> exchange:%(exchange)s (%(exchange_type)s) \
 binding:%(binding_key)s
 """
+
+#: Broker connection info -> URI
 BROKER_FORMAT = """\
 %(transport)s://%(userid)s@%(hostname)s%(port)s%(virtual_host)s\
 """
@@ -43,6 +50,14 @@ def extract_msg_options(options, keep=MSG_OPTIONS):
 
 
 class Queues(UserDict):
+    """Queue name⇒ declaration mapping.
+
+    Celery will consult this mapping to find the options
+    for any queue by name.
+
+    :param queues: Initial mapping.
+
+    """
 
     def __init__(self, queues):
         self.data = {}
@@ -51,12 +66,23 @@ class Queues(UserDict):
 
     def add(self, queue, exchange=None, routing_key=None,
             exchange_type="direct", **options):
+        """Add new queue.
+
+        :param queue: Name of the queue.
+        :keyword exchange: Name of the exchange.
+        :keyword routing_key: Binding key.
+        :keyword exchange_type: Type of exchange.
+        :keyword \*\*options: Additional declaration options.
+
+        """
         q = self[queue] = self.options(exchange, routing_key,
                                        exchange_type, **options)
         return q
 
     def options(self, exchange, routing_key,
             exchange_type="direct", **options):
+        """Creates new option mapping for queue, with required
+        keys present."""
         return dict(options, routing_key=routing_key,
                              binding_key=routing_key,
                              exchange=exchange,
@@ -64,12 +90,23 @@ class Queues(UserDict):
 
     def format(self, indent=0):
         """Format routing table into string for log dumps."""
-        format = lambda **queue: QUEUE_FORMAT.strip() % queue
-        info = "\n".join(format(name=name, **config)
+        info = "\n".join(QUEUE_FORMAT.strip() % dict(name=name, **config)
                                 for name, config in self.items())
         return textindent(info, indent=indent)
 
     def select_subset(self, wanted, create_missing=True):
+        """Select subset of the currently defined queues.
+
+        Does not return anything: queues not in `wanted` will
+        be discarded in-place.
+
+        :param wanted: List of wanted queue names.
+        :keyword create_missing: By default any unknown queues will be
+                                 added automatically, but if disabled
+                                 the occurrence of unknown queues
+                                 in `wanted` will raise :exc:`KeyError`.
+
+        """
         acc = {}
         for queue in wanted:
             try:
@@ -84,7 +121,8 @@ class Queues(UserDict):
 
     @classmethod
     def with_defaults(cls, queues, default_exchange, default_exchange_type):
-
+        """Alternate constructor that adds default exchange and
+        exchange type information to queues that does not have any."""
         for opts in queues.values():
             opts.setdefault("exchange", default_exchange),
             opts.setdefault("exchange_type", default_exchange_type)
@@ -105,7 +143,7 @@ class TaskPublisher(messaging.Publisher):
             countdown=None, eta=None, task_id=None, taskset_id=None,
             expires=None, exchange=None, exchange_type=None,
             event_dispatcher=None, **kwargs):
-        """Delay task for execution by the celery nodes."""
+        """Send task message."""
 
         task_id = task_id or gen_unique_id()
         task_args = task_args or []
@@ -167,20 +205,20 @@ class AMQP(object):
     BrokerConnection = BrokerConnection
     Publisher = messaging.Publisher
     Consumer = messaging.Consumer
-    _queues = None
+    ConsumerSet = messaging.ConsumerSet
 
     def __init__(self, app):
         self.app = app
 
-    def ConsumerSet(self, *args, **kwargs):
-        return messaging.ConsumerSet(*args, **kwargs)
-
     def Queues(self, queues):
+        """Create new :class:`Queues` instance, using queue defaults
+        from the current configuration."""
         return Queues.with_defaults(queues,
                                     self.app.conf.CELERY_DEFAULT_EXCHANGE,
                                     self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
 
     def Router(self, queues=None, create_missing=None):
+        """Returns the current task router."""
         return routes.Router(self.app.conf.CELERY_ROUTES,
                              queues or self.app.conf.CELERY_QUEUES,
                              self.app.either("CELERY_CREATE_MISSING_QUEUES",
@@ -188,6 +226,7 @@ class AMQP(object):
                              app=self.app)
 
     def TaskConsumer(self, *args, **kwargs):
+        """Returns consumer for a single task queue."""
         default_queue_name, default_queue = self.get_default_queue()
         defaults = dict({"queue": default_queue_name}, **default_queue)
         defaults["routing_key"] = defaults.pop("binding_key", None)
@@ -195,6 +234,11 @@ class AMQP(object):
                              **self.app.merge(defaults, kwargs))
 
     def TaskPublisher(self, *args, **kwargs):
+        """Returns publisher used to send tasks.
+
+        You should use `app.send_task` instead.
+
+        """
         _, default_queue = self.get_default_queue()
         defaults = {"exchange": default_queue["exchange"],
                     "exchange_type": default_queue["exchange_type"],
@@ -213,14 +257,20 @@ class AMQP(object):
         return publisher
 
     def get_task_consumer(self, connection, queues=None, **kwargs):
+        """Return consumer configured to consume from all known task
+        queues."""
         return self.ConsumerSet(connection, from_dict=queues or self.queues,
                                 **kwargs)
 
     def get_default_queue(self):
+        """Returns `(queue_name, queue_options)` tuple for the queue
+        configured to be default (:setting:`CELERY_DEFAULT_QUEUE`)."""
         q = self.app.conf.CELERY_DEFAULT_QUEUE
         return q, self.queues[q]
 
     def get_broker_info(self, broker_connection=None):
+        """Returns information about the current broker connection
+        as a `dict`."""
         if broker_connection is None:
             broker_connection = self.app.broker_connection()
         info = broker_connection.info()
@@ -236,13 +286,10 @@ class AMQP(object):
         """Get message broker connection info string for log dumps."""
         return BROKER_FORMAT % self.get_broker_info()
 
-    def _get_queues(self):
-        if self._queues is None:
-            c = self.app.conf
-            self._queues = self.Queues(c.CELERY_QUEUES)
-        return self._queues
-
-    def _set_queues(self, queues):
-        self._queues = self.Queues(queues)
+    @cached_property
+    def queues(self):
+        return self.Queues(self.app.conf.CELERY_QUEUES)
 
-    queues = property(_get_queues, _set_queues)
+    @queues.setter
+    def queues(self, value):
+        return self.Queues(value)

+ 22 - 43
celery/app/base.py

@@ -16,7 +16,7 @@ from datetime import timedelta
 from celery import routes
 from celery.app.defaults import DEFAULTS
 from celery.datastructures import ConfigurationView
-from celery.utils import noop, isatty
+from celery.utils import noop, isatty, cached_property
 from celery.utils.functional import wraps
 
 
@@ -31,13 +31,6 @@ class BaseApp(object):
         self.main = main
         self.loader_cls = loader or "app"
         self.backend_cls = backend
-        self._amqp = None
-        self._backend = None
-        self._conf = None
-        self._control = None
-        self._loader = None
-        self._log = None
-        self._events = None
         self.set_as_current = set_as_current
         self.on_init()
 
@@ -55,7 +48,7 @@ class BaseApp(object):
             >>> celery.config_from_object(celeryconfig)
 
         """
-        self._conf = None
+        del(self.conf)
         return self.loader.config_from_object(obj, silent=silent)
 
     def config_from_envvar(self, variable_name, silent=False):
@@ -68,7 +61,7 @@ class BaseApp(object):
             >>> celery.config_from_envvar("CELERY_CONFIG_MODULE")
 
         """
-        self._conf = None
+        del(self.conf)
         return self.loader.config_from_envvar(variable_name, silent=silent)
 
     def config_from_cmdline(self, argv, namespace="celery"):
@@ -271,71 +264,57 @@ class BaseApp(object):
         return self.post_config_merge(ConfigurationView(
                     self.pre_config_merge(self.loader.conf), DEFAULTS))
 
-    @property
+    @cached_property
     def amqp(self):
         """Sending/receiving messages.
 
         See :class:`~celery.app.amqp.AMQP`.
 
         """
-        if self._amqp is None:
-            from celery.app.amqp import AMQP
-            self._amqp = AMQP(self)
-        return self._amqp
+        from celery.app.amqp import AMQP
+        return AMQP(self)
 
-    @property
+    @cached_property
     def backend(self):
         """Storing/retreiving task state.
 
         See :class:`~celery.backend.base.BaseBackend`.
 
         """
-        if self._backend is None:
-            self._backend = self._get_backend()
-        return self._backend
+        return self._get_backend()
 
-    @property
+    @cached_property
     def loader(self):
         """Current loader."""
-        if self._loader is None:
-            from celery.loaders import get_loader_cls
-            self._loader = get_loader_cls(self.loader_cls)(app=self)
-        return self._loader
+        from celery.loaders import get_loader_cls
+        return get_loader_cls(self.loader_cls)(app=self)
 
-    @property
+    @cached_property
     def conf(self):
         """Current configuration (dict and attribute access)."""
-        if self._conf is None:
-            self._conf = self._get_config()
-        return self._conf
+        return self._get_config()
 
-    @property
+    @cached_property
     def control(self):
         """Controlling worker nodes.
 
         See :class:`~celery.task.control.Control`.
 
         """
-        if self._control is None:
-            from celery.task.control import Control
-            self._control = Control(app=self)
-        return self._control
+        from celery.task.control import Control
+        return Control(app=self)
 
-    @property
+    @cached_property
     def log(self):
         """Logging utilities.
 
         See :class:`~celery.log.Logging`.
 
         """
-        if self._log is None:
-            from celery.log import Logging
-            self._log = Logging(app=self)
-        return self._log
+        from celery.log import Logging
+        return Logging(app=self)
 
-    @property
+    @cached_property
     def events(self):
-        if self._events is None:
-            from celery.events import Events
-            self._events = Events(app=self)
-        return self._events
+        from celery.events import Events
+        return Events(app=self)

+ 5 - 6
celery/backends/amqp.py

@@ -12,6 +12,7 @@ from celery import states
 from celery.backends.base import BaseDictBackend
 from celery.exceptions import TimeoutError
 from celery.utils import timeutils
+from celery.utils import cached_property
 
 
 def repair_uuid(s):
@@ -248,12 +249,6 @@ class AMQPBackend(BaseDictBackend):
             self._pool.close()
             self._pool = None
 
-    @property
-    def pool(self):
-        if not self._pool:
-            self._pool = self.app.broker_connection().Pool(self.connection_max)
-        return self._pool
-
     def reload_task_result(self, task_id):
         raise NotImplementedError(
                 "reload_task_result is not supported by this backend.")
@@ -272,3 +267,7 @@ class AMQPBackend(BaseDictBackend):
         """Get the result of a taskset."""
         raise NotImplementedError(
                 "restore_taskset is not supported by this backend.")
+
+    @cached_property
+    def pool(self):
+        return self.app.broker_connection().Pool(self.connection_max)

+ 3 - 5
celery/backends/cache.py

@@ -4,6 +4,7 @@ from kombu.utils import partition
 
 from celery.backends.base import KeyValueStoreBackend
 from celery.exceptions import ImproperlyConfigured
+from celery.utils import cached_property
 from celery.utils import timeutils
 from celery.datastructures import LocalCache
 
@@ -48,7 +49,6 @@ backends = {"memcache": get_best_memcache,
 
 
 class CacheBackend(KeyValueStoreBackend):
-    _client = None
 
     def __init__(self, expires=None, backend=None, options={}, **kwargs):
         super(CacheBackend, self).__init__(self, **kwargs)
@@ -80,8 +80,6 @@ class CacheBackend(KeyValueStoreBackend):
     def delete(self, key):
         return self.client.delete(key)
 
-    @property
+    @cached_property
     def client(self):
-        if self._client is None:
-            self._client = self.Client(self.servers, **self.options)
-        return self._client
+        return self.Client(self.servers, **self.options)

+ 12 - 21
celery/beat.py

@@ -9,6 +9,7 @@ import sys
 import threading
 import traceback
 import multiprocessing
+
 from datetime import datetime
 
 from celery import platforms
@@ -16,7 +17,7 @@ from celery import registry
 from celery.app import app_or_default
 from celery.log import SilenceRepeated
 from celery.schedules import maybe_schedule
-from celery.utils import instantiate
+from celery.utils import cached_property, instantiate
 from celery.utils.compat import UserDict
 from celery.utils.timeutils import humanize_seconds
 
@@ -132,9 +133,6 @@ class Scheduler(UserDict):
     """
     Entry = ScheduleEntry
 
-    _connection = None
-    _publisher = None
-
     def __init__(self, schedule=None, logger=None, max_interval=None,
             app=None, Publisher=None, lazy=False, **kwargs):
         UserDict.__init__(self)
@@ -165,18 +163,6 @@ class Scheduler(UserDict):
                                                        result.task_id))
         return next_time_to_run
 
-    @property
-    def connection(self):
-        if self._connection is None:
-            self._connection = self.app.broker_connection()
-        return self._connection
-
-    @property
-    def publisher(self):
-        if self._publisher is None:
-            self._publisher = self.Publisher(connection=self.connection)
-        return self._publisher
-
     def tick(self):
         """Run a tick, that is one iteration of the scheduler.
 
@@ -262,6 +248,14 @@ class Scheduler(UserDict):
     def get_schedule(self):
         return self.data
 
+    @cached_property
+    def connection(self):
+        return self.app.broker_connection()
+
+    @cached_property
+    def publisher(self):
+        return self.Publisher(connection=self.connection)
+
     @property
     def schedule(self):
         return self.get_schedule()
@@ -317,7 +311,6 @@ class Service(object):
         self.schedule_filename = schedule_filename or \
                                     self.app.conf.CELERYBEAT_SCHEDULE_FILENAME
 
-        self._scheduler = None
         self._shutdown = threading.Event()
         self._stopped = threading.Event()
         silence = self.max_interval < 60 and 10 or 1
@@ -366,11 +359,9 @@ class Service(object):
             scheduler.update_from_dict(self.schedule)
         return scheduler
 
-    @property
+    @cached_property
     def scheduler(self):
-        if self._scheduler is None:
-            self._scheduler = self.get_scheduler()
-        return self._scheduler
+        return self.get_scheduler()
 
 
 class _Threaded(threading.Thread):

+ 3 - 6
celery/loaders/base.py

@@ -4,7 +4,7 @@ import warnings
 
 import anyjson
 
-from celery.utils import import_from_cwd as _import_from_cwd
+from celery.utils import cached_property, import_from_cwd as _import_from_cwd
 
 BUILTIN_MODULES = ["celery.task"]
 
@@ -25,7 +25,6 @@ class BaseLoader(object):
         * What modules are imported to find tasks?
 
     """
-    _conf_cache = None
     worker_initialized = False
     override_backends = {}
     configured = False
@@ -130,9 +129,7 @@ class BaseLoader(object):
                 "Mail could not be sent: %r %r" % (
                     exc, {"To": to, "Subject": subject})))
 
-    @property
+    @cached_property
     def conf(self):
         """Loader configuration."""
-        if not self._conf_cache:
-            self._conf_cache = self.read_configuration()
-        return self._conf_cache
+        return self.read_configuration()

+ 1 - 9
celery/task/http.py

@@ -77,7 +77,7 @@ class MutableURL(object):
     """
     def __init__(self, url):
         self.parts = urlparse(url)
-        self._query = dict(parse_qsl(self.parts[4]))
+        self.query = dict(parse_qsl(self.parts[4]))
 
     def __str__(self):
         scheme, netloc, path, params, query, fragment = self.parts
@@ -93,14 +93,6 @@ class MutableURL(object):
     def __repr__(self):
         return "<%s: %s>" % (self.__class__.__name__, str(self))
 
-    def _get_query(self):
-        return self._query
-
-    def _set_query(self, query):
-        self._query = query
-
-    query = property(_get_query, _set_query)
-
 
 class HttpDispatch(object):
     """Make task HTTP request and collect the task result.

+ 1 - 1
celery/tests/test_loaders.py

@@ -43,7 +43,7 @@ class TestLoaderBase(unittest.TestCase):
 
     def test_conf_property(self):
         self.assertEqual(self.loader.conf["foo"], "bar")
-        self.assertEqual(self.loader._conf_cache["foo"], "bar")
+        self.assertEqual(self.loader.__dict__["conf"]["foo"], "bar")
         self.assertEqual(self.loader.conf["foo"], "bar")
 
     def test_import_default_modules(self):

+ 2 - 2
celery/tests/test_task_control.py

@@ -104,10 +104,10 @@ class test_Broadcast(unittest.TestCase):
     def setUp(self):
         self.app = app_or_default()
         self.control = Control(app=self.app)
-        self.app._control = self.control
+        self.app.control = self.control
 
     def tearDown(self):
-        self.app._control = None
+        del(self.app.control)
 
     def test_discard_all(self):
         self.control.discard_all()

+ 63 - 0
celery/utils/__init__.py

@@ -342,3 +342,66 @@ def import_from_cwd(module, imp=None):
             sys.path.remove(cwd)
         except ValueError:
             pass
+
+
+class cached_property(object):
+    """Property descriptor that caches the return value
+    of the get function.
+
+    *Examples*
+
+    .. code-block:: python
+
+        @cached_property
+        def connection(self):
+            return Connection()
+
+        @connection.setter  # Prepares stored value
+        def connection(self, value):
+            if value is None:
+                raise TypeError("Connection must be a connection")
+            return value
+
+        @connection.deleter
+        def connection(self):
+            # Additional action to do at del(self.attr)
+            print("Next access will give a new connection")
+
+    """
+
+    def __init__(self, fget=None, fset=None, fdel=None, doc=None):
+        self.__get = fget
+        self.__set = fset
+        self.__del = fdel
+        self.__doc__ = doc or fget.__doc__
+        self.__name__ = fget.__name__
+        self.__module__ = fget.__module__
+
+    def __get__(self, obj, type=None):
+        if not obj:
+            return self
+        try:
+            return obj.__dict__[self.__name__]
+        except KeyError:
+            value = obj.__dict__[self.__name__] = self.__get(obj)
+            return value
+
+    def __set__(self, obj, value):
+        if not obj:
+            return self
+        if self.__set is not None:
+            value = self.__set(obj, value)
+        obj.__dict__[self.__name__] = value
+
+    def __delete__(self, obj):
+        if not obj:
+            return self
+        if self.__del is not None:
+            self.__del(obj)
+        del(obj.__dict__[self.__name__])
+
+    def setter(self, fset):
+        return self.__class__(self.__get, fset, self.__del)
+
+    def deleter(self, fdel):
+        return self.__class__(self.__get, self.__set, fdel)

+ 8 - 8
celery/worker/state.py

@@ -1,5 +1,6 @@
 import shelve
 
+from celery.utils import cached_property
 from celery.utils.compat import defaultdict
 from celery.datastructures import LimitedSet
 
@@ -42,7 +43,7 @@ def task_ready(request):
 
 class Persistent(object):
     storage = shelve
-    _open = None
+    _is_open = False
 
     def __init__(self, filename):
         self.filename = filename
@@ -66,16 +67,15 @@ class Persistent(object):
         return self.storage.open(self.filename)
 
     def close(self):
-        if self._open:
-            self._open.close()
-            self._open = None
+        if self._is_open:
+            self.db.close()
+            self._is_open = False
 
     def _load(self):
         self.merge(self.db)
         self.close()
 
-    @property
+    @cached_property
     def db(self):
-        if self._open is None:
-            self._open = self.open()
-        return self._open
+        self._is_open = True
+        return self.open()