Browse Source

Use imp.find_module to ensure the config module exists, so import errors in the config module is not silenced

Ask Solem 14 năm trước cách đây
mục cha
commit
3de4753768
2 tập tin đã thay đổi với 39 bổ sung11 xóa
  1. 3 1
      celery/loaders/default.py
  2. 36 10
      celery/utils/__init__.py

+ 3 - 1
celery/loaders/default.py

@@ -8,6 +8,7 @@ from importlib import import_module
 from celery.datastructures import AttributeDict
 from celery.exceptions import NotConfigured
 from celery.loaders.base import BaseLoader
+from celery.utils import find_module
 
 DEFAULT_CONFIG_MODULE = "celeryconfig"
 
@@ -24,13 +25,14 @@ class Loader(BaseLoader):
         configname = os.environ.get("CELERY_CONFIG_MODULE",
                                      DEFAULT_CONFIG_MODULE)
         try:
-            celeryconfig = self.import_from_cwd(configname)
+            find_module(configname)
         except ImportError:
             warnings.warn(NotConfigured(
                 "No %r module found! Please make sure it exists and "
                 "is available to Python." % (configname, )))
             return self.setup_settings({})
         else:
+            celeryconfig = self.import_from_cwd(configname)
             usercfg = dict((key, getattr(celeryconfig, key))
                             for key in dir(celeryconfig)
                                 if self.wanted_module_item(key))

+ 36 - 10
celery/utils/__init__.py

@@ -1,12 +1,16 @@
+from __future__ import absolute_import, with_statement
+
 import os
 import sys
 import operator
+import imp as _imp
 import importlib
 import logging
 import threading
 import traceback
 import warnings
 
+from contextlib import contextmanager
 from functools import partial, wraps
 from inspect import getargspec
 from itertools import islice
@@ -358,6 +362,37 @@ def textindent(t, indent=0):
         return "\n".join(" " * indent + p for p in t.split("\n"))
 
 
+@contextmanager
+def cwd_in_path():
+    cwd = os.getcwd()
+    if cwd in sys.path:
+        yield
+    else:
+        sys.path.insert(0, cwd)
+        try:
+            yield cwd
+        finally:
+            try:
+                sys.path.remove(cwd)
+            except ValueError:
+                pass
+
+
+def find_module(module, path=None, imp=None):
+    if imp is None:
+        imp = importlib.import_module
+    with cwd_in_path():
+        if "." in module:
+            last = None
+            parts = module.split(".")
+            for i, part in enumerate(parts[:-1]):
+                path = imp(part).__path__
+                last = _imp.find_module(parts[i+1], path)
+            return last
+        return _imp.find_module(module)
+
+
+
 def import_from_cwd(module, imp=None):
     """Import module, but make sure it finds modules
     located in the current directory.
@@ -367,17 +402,8 @@ def import_from_cwd(module, imp=None):
     """
     if imp is None:
         imp = importlib.import_module
-    cwd = os.getcwd()
-    if cwd in sys.path:
-        return imp(module)
-    sys.path.insert(0, cwd)
-    try:
+    with cwd_in_path():
         return imp(module)
-    finally:
-        try:
-            sys.path.remove(cwd)
-        except ValueError:
-            pass
 
 
 def cry():