imports.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.utils.imports
  4. ~~~~~~~~~~~~~~~~~~~~
  5. Utilities related to importing modules and symbols by name.
  6. """
  7. from __future__ import absolute_import
  8. import imp as _imp
  9. import importlib
  10. import os
  11. import sys
  12. from contextlib import contextmanager
  13. from kombu.utils import symbol_by_name
  14. from celery.five import reload
  15. __all__ = [
  16. 'NotAPackage', 'qualname', 'instantiate', 'symbol_by_name', 'cwd_in_path',
  17. 'find_module', 'import_from_cwd', 'reload_from_cwd', 'module_file',
  18. ]
  19. class NotAPackage(Exception):
  20. pass
  21. if sys.version_info > (3, 3): # pragma: no cover
  22. def qualname(obj):
  23. if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
  24. obj = obj.__class__
  25. q = getattr(obj, '__qualname__', None)
  26. if '.' not in q:
  27. q = '.'.join((obj.__module__, q))
  28. return q
  29. else:
  30. def qualname(obj): # noqa
  31. if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
  32. obj = obj.__class__
  33. return '.'.join((obj.__module__, obj.__name__))
  34. def instantiate(name, *args, **kwargs):
  35. """Instantiate class by name.
  36. See :func:`symbol_by_name`.
  37. """
  38. return symbol_by_name(name)(*args, **kwargs)
  39. @contextmanager
  40. def cwd_in_path():
  41. cwd = os.getcwd()
  42. if cwd in sys.path:
  43. yield
  44. else:
  45. sys.path.insert(0, cwd)
  46. try:
  47. yield cwd
  48. finally:
  49. try:
  50. sys.path.remove(cwd)
  51. except ValueError: # pragma: no cover
  52. pass
  53. def find_module(module, path=None, imp=None):
  54. """Version of :func:`imp.find_module` supporting dots."""
  55. if imp is None:
  56. imp = importlib.import_module
  57. with cwd_in_path():
  58. if '.' in module:
  59. last = None
  60. parts = module.split('.')
  61. for i, part in enumerate(parts[:-1]):
  62. mpart = imp('.'.join(parts[:i + 1]))
  63. try:
  64. path = mpart.__path__
  65. except AttributeError:
  66. raise NotAPackage(module)
  67. last = _imp.find_module(parts[i + 1], path)
  68. return last
  69. return _imp.find_module(module)
  70. def import_from_cwd(module, imp=None, package=None):
  71. """Import module, but make sure it finds modules
  72. located in the current directory.
  73. Modules located in the current directory has
  74. precedence over modules located in `sys.path`.
  75. """
  76. if imp is None:
  77. imp = importlib.import_module
  78. with cwd_in_path():
  79. return imp(module, package=package)
  80. def reload_from_cwd(module, reloader=None):
  81. if reloader is None:
  82. reloader = reload
  83. with cwd_in_path():
  84. return reloader(module)
  85. def module_file(module):
  86. """Return the correct original file name of a module."""
  87. name = module.__file__
  88. return name[:-1] if name.endswith('.pyc') else name