imports.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # -*- coding: utf-8 -*-
  2. """Utilities related to importing modules and symbols by name."""
  3. import imp as _imp
  4. import importlib
  5. import os
  6. import sys
  7. import warnings
  8. from contextlib import contextmanager
  9. from imp import reload
  10. from types import ModuleType
  11. from typing import Any, Callable, Iterator, Optional
  12. from kombu.utils.imports import symbol_by_name
  13. #: Billiard sets this when execv is enabled.
  14. #: We use it to find out the name of the original ``__main__``
  15. #: module, so that we can properly rewrite the name of the
  16. #: task to be that of ``App.main``.
  17. MP_MAIN_FILE = os.environ.get('MP_MAIN_FILE')
  18. __all__ = [
  19. 'NotAPackage', 'qualname', 'instantiate', 'symbol_by_name',
  20. 'cwd_in_path', 'find_module', 'import_from_cwd',
  21. 'reload_from_cwd', 'module_file', 'gen_task_name',
  22. ]
  23. class NotAPackage(Exception):
  24. """Raised when importing a package, but it's not a package."""
  25. def qualname(obj: Any) -> str:
  26. """Return object name."""
  27. if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
  28. obj = obj.__class__
  29. q = getattr(obj, '__qualname__', None)
  30. if '.' not in q:
  31. q = '.'.join((obj.__module__, q))
  32. return q
  33. def instantiate(name: Any, *args, **kwargs) -> Any:
  34. """Instantiate class by name.
  35. See Also:
  36. :func:`symbol_by_name`.
  37. """
  38. return symbol_by_name(name)(*args, **kwargs)
  39. @contextmanager
  40. def cwd_in_path() -> Iterator:
  41. """Context adding the current working directory to sys.path."""
  42. cwd = os.getcwd()
  43. if cwd in sys.path:
  44. yield
  45. else:
  46. sys.path.insert(0, cwd)
  47. try:
  48. yield cwd
  49. finally:
  50. try:
  51. sys.path.remove(cwd)
  52. except ValueError: # pragma: no cover
  53. pass
  54. def find_module(module: str,
  55. path: Optional[str]=None,
  56. imp: Optional[Callable]=None) -> ModuleType:
  57. """Version of :func:`imp.find_module` supporting dots."""
  58. if imp is None:
  59. imp = importlib.import_module
  60. with cwd_in_path():
  61. if '.' in module:
  62. last = None
  63. parts = module.split('.')
  64. for i, part in enumerate(parts[:-1]):
  65. mpart = imp('.'.join(parts[:i + 1]))
  66. try:
  67. path = mpart.__path__
  68. except AttributeError:
  69. raise NotAPackage(module)
  70. last = _imp.find_module(parts[i + 1], path)
  71. return last
  72. return _imp.find_module(module)
  73. def import_from_cwd(module: str,
  74. imp: Optional[Callable]=None,
  75. package: Optional[str]=None) -> ModuleType:
  76. """Import module, temporarily including modules in the current directory.
  77. Modules located in the current directory has
  78. precedence over modules located in `sys.path`.
  79. """
  80. if imp is None:
  81. imp = importlib.import_module
  82. with cwd_in_path():
  83. return imp(module, package=package)
  84. def reload_from_cwd(module: ModuleType,
  85. reloader: Optional[Callable]=None) -> Any:
  86. """Reload module (ensuring that CWD is in sys.path)."""
  87. if reloader is None:
  88. reloader = reload
  89. with cwd_in_path():
  90. return reloader(module)
  91. def module_file(module: ModuleType) -> str:
  92. """Return the correct original file name of a module."""
  93. name = module.__file__
  94. return name[:-1] if name.endswith('.pyc') else name
  95. def gen_task_name(app: Any, name: str, module_name: str) -> str:
  96. """Generate task name from name/module pair."""
  97. module_name = module_name or '__main__'
  98. try:
  99. module = sys.modules[module_name]
  100. except KeyError:
  101. # Fix for manage.py shell_plus (Issue #366)
  102. module = None
  103. if module is not None:
  104. module_name = module.__name__
  105. # - If the task module is used as the __main__ script
  106. # - we need to rewrite the module part of the task name
  107. # - to match App.main.
  108. if MP_MAIN_FILE and module.__file__ == MP_MAIN_FILE:
  109. # - see comment about :envvar:`MP_MAIN_FILE` above.
  110. module_name = '__main__'
  111. if module_name == '__main__' and app.main:
  112. return '.'.join([app.main, name])
  113. return '.'.join(p for p in (module_name, name) if p)
  114. def load_extension_class_names(namespace):
  115. try:
  116. from pkg_resources import iter_entry_points
  117. except ImportError: # pragma: no cover
  118. return
  119. for ep in iter_entry_points(namespace):
  120. yield ep.name, ':'.join([ep.module_name, ep.attrs[0]])
  121. def load_extension_classes(namespace):
  122. for name, class_name in load_extension_class_names(namespace):
  123. try:
  124. cls = symbol_by_name(class_name)
  125. except (ImportError, SyntaxError) as exc:
  126. warnings.warn(
  127. 'Cannot load {0} extension {1!r}: {2!r}'.format(
  128. namespace, class_name, exc))
  129. else:
  130. yield name, cls