فهرست منبع

Use imp.find_module to ensure the config module exists, so import errors in the config module is not silenced

Ask Solem 14 سال پیش
والد
کامیت
3de4753768
2فایلهای تغییر یافته به همراه39 افزوده شده و 11 حذف شده
  1. 3 1
      celery/loaders/default.py
  2. 36 10
      celery/utils/__init__.py

+ 3 - 1
celery/loaders/default.py

@@ -8,6 +8,7 @@ from importlib import import_module
 from celery.datastructures import AttributeDict
 from celery.exceptions import NotConfigured
 from celery.loaders.base import BaseLoader
+from celery.utils import find_module
 
 DEFAULT_CONFIG_MODULE = "celeryconfig"
 
@@ -24,13 +25,14 @@ class Loader(BaseLoader):
         configname = os.environ.get("CELERY_CONFIG_MODULE",
                                      DEFAULT_CONFIG_MODULE)
         try:
-            celeryconfig = self.import_from_cwd(configname)
+            find_module(configname)
         except ImportError:
             warnings.warn(NotConfigured(
                 "No %r module found! Please make sure it exists and "
                 "is available to Python." % (configname, )))
             return self.setup_settings({})
         else:
+            celeryconfig = self.import_from_cwd(configname)
             usercfg = dict((key, getattr(celeryconfig, key))
                             for key in dir(celeryconfig)
                                 if self.wanted_module_item(key))

+ 36 - 10
celery/utils/__init__.py

@@ -1,12 +1,16 @@
+from __future__ import absolute_import, with_statement
+
 import os
 import sys
 import operator
+import imp as _imp
 import importlib
 import logging
 import threading
 import traceback
 import warnings
 
+from contextlib import contextmanager
 from functools import partial, wraps
 from inspect import getargspec
 from itertools import islice
@@ -358,6 +362,37 @@ def textindent(t, indent=0):
         return "\n".join(" " * indent + p for p in t.split("\n"))
 
 
+@contextmanager
+def cwd_in_path():
+    cwd = os.getcwd()
+    if cwd in sys.path:
+        yield
+    else:
+        sys.path.insert(0, cwd)
+        try:
+            yield cwd
+        finally:
+            try:
+                sys.path.remove(cwd)
+            except ValueError:
+                pass
+
+
+def find_module(module, path=None, imp=None):
+    if imp is None:
+        imp = importlib.import_module
+    with cwd_in_path():
+        if "." in module:
+            last = None
+            parts = module.split(".")
+            for i, part in enumerate(parts[:-1]):
+                path = imp(part).__path__
+                last = _imp.find_module(parts[i+1], path)
+            return last
+        return _imp.find_module(module)
+
+
+
 def import_from_cwd(module, imp=None):
     """Import module, but make sure it finds modules
     located in the current directory.
@@ -367,17 +402,8 @@ def import_from_cwd(module, imp=None):
     """
     if imp is None:
         imp = importlib.import_module
-    cwd = os.getcwd()
-    if cwd in sys.path:
-        return imp(module)
-    sys.path.insert(0, cwd)
-    try:
+    with cwd_in_path():
         return imp(module)
-    finally:
-        try:
-            sys.path.remove(cwd)
-        except ValueError:
-            pass
 
 
 def cry():