123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- # -*- coding: utf-8 -*-
- """Utilities related to importing modules and symbols by name."""
- import imp as _imp
- import importlib
- import os
- import sys
- import warnings
- from contextlib import contextmanager
- from imp import reload
- from types import ModuleType
- from typing import Any, Callable, Iterator, Optional
- from kombu.utils.imports import symbol_by_name
# Billiard exports MP_MAIN_FILE when execv is enabled.  It holds the file
# name of the original ``__main__`` module; gen_task_name() compares module
# files against it so that tasks defined in that script get the module part
# of their name rewritten to ``App.main``.
MP_MAIN_FILE = os.getenv('MP_MAIN_FILE')

# Public API of this module.
__all__ = [
    'NotAPackage',
    'qualname',
    'instantiate',
    'symbol_by_name',
    'cwd_in_path',
    'find_module',
    'import_from_cwd',
    'reload_from_cwd',
    'module_file',
    'gen_task_name',
]
class NotAPackage(Exception):
    """Raised when importing a package, but it's not a package.

    :func:`find_module` raises this when a parent component of a dotted
    module name turns out to have no ``__path__``.
    """
def qualname(obj: Any) -> str:
    """Return object name.

    Objects without a ``__name__`` (instances) are described by their
    class.  ``__qualname__`` is used when available, falling back to
    ``__name__`` for objects that lack it (module objects, for one).
    An undotted name is prefixed with the object's ``__module__`` when the
    object has one; an already dotted qualified name is returned as is.

    Arguments:
        obj: Class, function, module or instance to name.

    Returns:
        str: e.g. ``'collections.OrderedDict'`` or ``'json.dumps'``.
    """
    if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
        obj = obj.__class__
    q = getattr(obj, '__qualname__', None)
    if q is None:
        # Without this fallback ``'.' in None`` below raises TypeError.
        q = obj.__name__
    if '.' not in q:
        module = getattr(obj, '__module__', None)
        # Only prefix when there is a module name to join with.
        if module:
            q = '.'.join((module, q))
    return q
def instantiate(name: Any, *args, **kwargs) -> Any:
    """Instantiate class by name.

    *name* is resolved with :func:`symbol_by_name`, and the resulting
    object is called with the remaining positional and keyword arguments.

    See Also:
        :func:`symbol_by_name`.
    """
    factory = symbol_by_name(name)
    return factory(*args, **kwargs)
@contextmanager
def cwd_in_path() -> Iterator:
    """Context adding the current working directory to sys.path.

    The directory is inserted at the front of :data:`sys.path` for the
    duration of the ``with`` block and removed again afterwards.

    Yields:
        The directory that was added, or ``None`` when nothing was added:
        either it was already on :data:`sys.path`, or the current working
        directory no longer exists.
    """
    try:
        cwd = os.getcwd()
    except FileNotFoundError:
        # The process' working directory was deleted; there is nothing
        # meaningful to put on sys.path, so don't crash the import.
        cwd = None
    if not cwd or cwd in sys.path:
        # Nothing to add: leave sys.path untouched.
        yield
        return
    sys.path.insert(0, cwd)
    try:
        yield cwd
    finally:
        try:
            sys.path.remove(cwd)
        except ValueError:  # pragma: no cover
            # Already removed by code running inside the block.
            pass
def find_module(module: str,
                path: Optional[list] = None,
                imp: Optional[Callable] = None) -> tuple:
    """Version of :func:`imp.find_module` supporting dots.

    For a dotted name such as ``'proj.tasks'`` every parent package is
    imported (with the current directory on :data:`sys.path`).  The final
    component is then located with :func:`imp.find_module` inside the
    ``__path__`` of its parent.

    Arguments:
        module: Module name, dotted or not.
        path: List of directories to search for an undotted *module*
            (``None`` means the default search path).  Dotted names
            always search the parent package's ``__path__`` instead.
        imp: Callable used to import parent packages.
            Defaults to :func:`importlib.import_module`.

    Returns:
        tuple: The result of :func:`imp.find_module` for the final name
        component.

    Raises:
        NotAPackage: If a parent component has no ``__path__``.
        ImportError: If a parent or the module itself cannot be found.
    """
    if imp is None:
        imp = importlib.import_module
    with cwd_in_path():
        if '.' not in module:
            # Honour the caller-supplied search path for top-level names.
            return _imp.find_module(module, path)
        parts = module.split('.')
        for depth in range(1, len(parts)):
            parent = imp('.'.join(parts[:depth]))
            try:
                path = parent.__path__
            except AttributeError:
                raise NotAPackage(module)
        # Only the final component goes through imp.find_module.  The
        # parents were already resolved by importing them above, and
        # calling find_module on them as well could open file objects
        # that are never closed when a parent is not a package.
        return _imp.find_module(parts[-1], path)
def import_from_cwd(module: str,
                    imp: Optional[Callable] = None,
                    package: Optional[str] = None) -> ModuleType:
    """Import *module* while the current directory is on :data:`sys.path`.

    Modules located in the current directory take precedence over modules
    located elsewhere on :data:`sys.path`.

    Arguments:
        module: Name of the module to import.
        imp: Import function, called as ``imp(module, package=package)``.
            Defaults to :func:`importlib.import_module`.
        package: Passed through to *imp* unchanged.
    """
    importer = importlib.import_module if imp is None else imp
    with cwd_in_path():
        return importer(module, package=package)
def reload_from_cwd(module: ModuleType,
                    reloader: Optional[Callable] = None) -> Any:
    """Reload *module* while the current directory is on :data:`sys.path`.

    Arguments:
        module: The module object to reload.
        reloader: Callable performing the reload.  Defaults to
            :func:`importlib.reload`, which is what ``imp.reload``
            delegates to.

    Returns:
        Whatever *reloader* returns.
    """
    do_reload = importlib.reload if reloader is None else reloader
    with cwd_in_path():
        return do_reload(module)
def module_file(module: ModuleType) -> Optional[str]:
    """Return the correct original file name of a module.

    If ``module.__file__`` names a byte-compiled file (``.pyc``, or the
    legacy ``.pyo``), the trailing character is dropped so the result
    names the ``.py`` source it was compiled from.

    Arguments:
        module: The module to inspect.

    Returns:
        The module's file name, or ``None`` when ``module.__file__`` is
        ``None``.

    Raises:
        AttributeError: If the module has no ``__file__`` attribute.
    """
    name = module.__file__
    # ``__file__`` may be None, which has no ``endswith``.
    if name and name.endswith(('.pyc', '.pyo')):
        return name[:-1]
    return name
def gen_task_name(app: Any, name: str, module_name: str) -> str:
    """Generate task name from name/module pair.

    Arguments:
        app: The app; only its ``main`` attribute is read.
        name: The task's own name.
        module_name: Name of the module defining the task.  A falsy
            value is treated as ``'__main__'``.

    Returns:
        ``'<app.main>.<name>'`` when the task belongs to ``__main__`` and
        ``app.main`` is set; otherwise ``'<module>.<name>'`` with empty
        parts omitted.
    """
    module_name = module_name or '__main__'
    # Fix for manage.py shell_plus (Issue #366): the module may be
    # missing from sys.modules, in which case the name is used as given.
    module = sys.modules.get(module_name)

    if module is not None:
        module_name = module.__name__
        # If the task module is used as the __main__ script, rewrite the
        # module part of the task name to match App.main.  MP_MAIN_FILE is
        # set by billiard when execv is enabled.  Modules without a
        # ``__file__`` (e.g. created at runtime) cannot be that script, so
        # they must not raise AttributeError here.
        if MP_MAIN_FILE and getattr(module, '__file__', None) == MP_MAIN_FILE:
            module_name = '__main__'
    if module_name == '__main__' and app.main:
        return '.'.join([app.main, name])
    return '.'.join(p for p in (module_name, name) if p)
def load_extension_class_names(namespace):
    """Yield ``(name, 'module:attr')`` pairs for entry points in *namespace*.

    Entry points are enumerated with ``pkg_resources.iter_entry_points``.
    If :mod:`pkg_resources` cannot be imported, nothing is yielded.
    """
    try:
        from pkg_resources import iter_entry_points
    except ImportError:  # pragma: no cover
        return
    for entry_point in iter_entry_points(namespace):
        # NOTE(review): only the first attribute is kept, so an entry point
        # such as 'mod:Outer.Inner' yields 'mod:Outer'; confirm intended.
        target = '{0}:{1}'.format(entry_point.module_name, entry_point.attrs[0])
        yield entry_point.name, target
def load_extension_classes(namespace):
    """Yield ``(name, cls)`` for every loadable extension in *namespace*.

    Each ``'module:attr'`` string from :func:`load_extension_class_names`
    is resolved with :func:`symbol_by_name`.  An extension whose loading
    raises :exc:`ImportError` or :exc:`SyntaxError` is reported with
    :func:`warnings.warn` and skipped.
    """
    for ext_name, target in load_extension_class_names(namespace):
        try:
            ext_cls = symbol_by_name(target)
        except (ImportError, SyntaxError) as exc:
            message = 'Cannot load {0} extension {1!r}: {2!r}'.format(
                namespace, target, exc)
            warnings.warn(message)
            continue
        yield ext_name, ext_cls
|