__init__.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from __future__ import absolute_import
  2. import logging
  3. import os
  4. import sys
  5. import warnings
  6. from importlib import import_module
  7. try:
  8. WindowsError = WindowsError # noqa
  9. except NameError:
  10. class WindowsError(Exception):
  11. pass
  12. config_module = os.environ.setdefault(
  13. 'CELERY_TEST_CONFIG_MODULE', 'celery.tests.config',
  14. )
  15. os.environ.setdefault('CELERY_CONFIG_MODULE', config_module)
  16. os.environ['CELERY_LOADER'] = 'default'
  17. os.environ['EVENTLET_NOPATCH'] = 'yes'
  18. os.environ['GEVENT_NOPATCH'] = 'yes'
  19. os.environ['KOMBU_DISABLE_LIMIT_PROTECTION'] = 'yes'
  20. os.environ['CELERY_BROKER_URL'] = 'memory://'
  21. def setup():
  22. if os.environ.get('COVER_ALL_MODULES') or '--with-coverage3' in sys.argv:
  23. from celery.tests.utils import catch_warnings
  24. with catch_warnings(record=True):
  25. import_all_modules()
  26. warnings.resetwarnings()
  27. def teardown():
  28. # Don't want SUBDEBUG log messages at finalization.
  29. try:
  30. from multiprocessing.util import get_logger
  31. except ImportError:
  32. pass
  33. else:
  34. get_logger().setLevel(logging.WARNING)
  35. # Make sure test database is removed.
  36. import os
  37. if os.path.exists('test.db'):
  38. try:
  39. os.remove('test.db')
  40. except WindowsError:
  41. pass
  42. # Make sure there are no remaining threads at shutdown.
  43. import threading
  44. remaining_threads = [thread for thread in threading.enumerate()
  45. if thread.getName() != 'MainThread']
  46. if remaining_threads:
  47. sys.stderr.write(
  48. '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
  49. remaining_threads))
  50. def find_distribution_modules(name=__name__, file=__file__):
  51. current_dist_depth = len(name.split('.')) - 1
  52. current_dist = os.path.join(os.path.dirname(file),
  53. *([os.pardir] * current_dist_depth))
  54. abs = os.path.abspath(current_dist)
  55. dist_name = os.path.basename(abs)
  56. for dirpath, dirnames, filenames in os.walk(abs):
  57. package = (dist_name + dirpath[len(abs):]).replace('/', '.')
  58. if '__init__.py' in filenames:
  59. yield package
  60. for filename in filenames:
  61. if filename.endswith('.py') and filename != '__init__.py':
  62. yield '.'.join([package, filename])[:-3]
  63. def import_all_modules(name=__name__, file=__file__,
  64. skip=['celery.decorators', 'celery.contrib.batches']):
  65. for module in find_distribution_modules(name, file):
  66. if module not in skip:
  67. try:
  68. import_module(module)
  69. except ImportError:
  70. pass