Przeglądaj źródła

Added celery.utils.get_cls_by_name, and refactored the backend loader to use it.

Ask Solem 15 lat temu
rodzic
commit
c90a4f3919
2 zmienionych plików z 17 dodań i 16 usunięć
  1. 2 16
      celery/backends/__init__.py
  2. 15 0
      celery/utils/__init__.py

+ 2 - 16
celery/backends/__init__.py

@@ -1,9 +1,7 @@
-import importlib
-
 from billiard.utils.functional import curry
-from carrot.utils import rpartition
 
 from celery import conf
+from celery.utils import get_cls_by_name
 
 BACKEND_ALIASES = {
     "amqp": "celery.backends.amqp.AMQPBackend",
@@ -18,22 +16,10 @@ BACKEND_ALIASES = {
 _backend_cache = {}
 
 
-def resolve_backend(backend):
-    backend = BACKEND_ALIASES.get(backend, backend)
-    backend_module_name, _, backend_cls_name = rpartition(backend, ".")
-    return backend_module_name, backend_cls_name
-
-
-def _get_backend_cls(backend):
-    backend_module_name, backend_cls_name = resolve_backend(backend)
-    backend_module = importlib.import_module(backend_module_name)
-    return getattr(backend_module, backend_cls_name)
-
-
 def get_backend_cls(backend):
     """Get backend class by name/alias"""
     if backend not in _backend_cache:
-        _backend_cache[backend] = _get_backend_cls(backend)
+        _backend_cache[backend] = get_cls_by_name(backend, BACKEND_ALIASES)
     return _backend_cache[backend]
 
 

+ 15 - 0
celery/utils/__init__.py

@@ -10,10 +10,12 @@ try:
     import ctypes
 except ImportError:
     ctypes = None
+import importlib
 from uuid import UUID, uuid4, _uuid_generate_random
 from inspect import getargspec
 from itertools import islice
 
+from carrot.utils import rpartition
 from billiard.utils.functional import curry
 
 from celery.utils.compat import all, any, defaultdict
@@ -188,3 +190,16 @@ def timedelta_seconds(delta):
     if delta.days < 0:
         return 0
     return delta.days * 86400 + delta.seconds + (delta.microseconds / 10e5)
+
+
+def get_cls_by_name(name, aliases={}):
+    name = aliases.get(name) or name
+    module_name, _, cls_name = rpartition(name, ".")
+    module = importlib.import_module(module_name)
+    return getattr(module, cls_name)
+
+
+def instantiate(name, aliases={}, *args, **kwargs):
+    return _get_cls_by_name(name, aliases)(*args, **kwargs)
+
+