imports.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from __future__ import absolute_import
  2. from __future__ import with_statement
  3. import imp as _imp
  4. import importlib
  5. import os
  6. import sys
  7. from contextlib import contextmanager
  8. from .compat import reload
  9. class NotAPackage(Exception):
  10. pass
  11. if sys.version_info >= (3, 3): # pragma: no cover
  12. def qualname(obj):
  13. return obj.__qualname__
  14. else:
  15. def qualname(obj): # noqa
  16. if not hasattr(obj, "__name__") and hasattr(obj, "__class__"):
  17. return qualname(obj.__class__)
  18. return '.'.join([obj.__module__, obj.__name__])
  19. def symbol_by_name(name, aliases={}, imp=None, package=None,
  20. sep='.', default=None, **kwargs):
  21. """Get symbol by qualified name.
  22. The name should be the full dot-separated path to the class::
  23. modulename.ClassName
  24. Example::
  25. celery.concurrency.processes.TaskPool
  26. ^- class name
  27. or using ':' to separate module and symbol::
  28. celery.concurrency.processes:TaskPool
  29. If `aliases` is provided, a dict containing short name/long name
  30. mappings, the name is looked up in the aliases first.
  31. Examples:
  32. >>> symbol_by_name("celery.concurrency.processes.TaskPool")
  33. <class 'celery.concurrency.processes.TaskPool'>
  34. >>> symbol_by_name("default", {
  35. ... "default": "celery.concurrency.processes.TaskPool"})
  36. <class 'celery.concurrency.processes.TaskPool'>
  37. # Does not try to look up non-string names.
  38. >>> from celery.concurrency.processes import TaskPool
  39. >>> symbol_by_name(TaskPool) is TaskPool
  40. True
  41. """
  42. if imp is None:
  43. imp = importlib.import_module
  44. if not isinstance(name, basestring):
  45. return name # already a class
  46. name = aliases.get(name) or name
  47. sep = ':' if ':' in name else sep
  48. module_name, _, cls_name = name.rpartition(sep)
  49. if not module_name and package:
  50. module_name = package
  51. try:
  52. try:
  53. module = imp(module_name, package=package, **kwargs)
  54. except ValueError, exc:
  55. raise ValueError, ValueError(
  56. "Couldn't import %r: %s" % (name, exc)), sys.exc_info()[2]
  57. return getattr(module, cls_name)
  58. except (ImportError, AttributeError):
  59. if default is None:
  60. raise
  61. return default
  62. def instantiate(name, *args, **kwargs):
  63. """Instantiate class by name.
  64. See :func:`symbol_by_name`.
  65. """
  66. return symbol_by_name(name)(*args, **kwargs)
  67. @contextmanager
  68. def cwd_in_path():
  69. cwd = os.getcwd()
  70. if cwd in sys.path:
  71. yield
  72. else:
  73. sys.path.insert(0, cwd)
  74. try:
  75. yield cwd
  76. finally:
  77. try:
  78. sys.path.remove(cwd)
  79. except ValueError: # pragma: no cover
  80. pass
  81. def find_module(module, path=None, imp=None):
  82. """Version of :func:`imp.find_module` supporting dots."""
  83. if imp is None:
  84. imp = importlib.import_module
  85. with cwd_in_path():
  86. if "." in module:
  87. last = None
  88. parts = module.split(".")
  89. for i, part in enumerate(parts[:-1]):
  90. mpart = imp(".".join(parts[:i + 1]))
  91. try:
  92. path = mpart.__path__
  93. except AttributeError:
  94. raise NotAPackage(module)
  95. last = _imp.find_module(parts[i + 1], path)
  96. return last
  97. return _imp.find_module(module)
  98. def import_from_cwd(module, imp=None, package=None):
  99. """Import module, but make sure it finds modules
  100. located in the current directory.
  101. Modules located in the current directory has
  102. precedence over modules located in `sys.path`.
  103. """
  104. if imp is None:
  105. imp = importlib.import_module
  106. with cwd_in_path():
  107. return imp(module, package=package)
  108. def reload_from_cwd(module, reloader=None):
  109. if reloader is None:
  110. reloader = reload
  111. with cwd_in_path():
  112. return reloader(module)
  113. def module_file(module):
  114. name = module.__file__
  115. return name[:-1] if name.endswith(".pyc") else name