app.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. """Create Celery app instances used for testing."""
  2. import weakref
  3. from contextlib import contextmanager
  4. from copy import deepcopy
  5. from kombu.utils.imports import symbol_by_name
  6. from celery import Celery
  7. from celery import _state
  8. #: Contains the default configuration values for the test app.
  9. DEFAULT_TEST_CONFIG = {
  10. 'worker_hijack_root_logger': False,
  11. 'worker_log_color': False,
  12. 'accept_content': {'json'},
  13. 'enable_utc': True,
  14. 'timezone': 'UTC',
  15. 'broker_url': 'memory://',
  16. 'result_backend': 'cache+memory://'
  17. }
  18. class Trap(object):
  19. """Trap that pretends to be an app but raises an exception instead.
  20. This to protect from code that does not properly pass app instances,
  21. then falls back to the current_app.
  22. """
  23. def __getattr__(self, name):
  24. raise RuntimeError('Test depends on current_app')
  25. class UnitLogging(symbol_by_name(Celery.log_cls)):
  26. """Sets up logging for the test application."""
  27. def __init__(self, *args, **kwargs):
  28. super(UnitLogging, self).__init__(*args, **kwargs)
  29. self.already_setup = True
  30. def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
  31. log=UnitLogging, backend=None, broker=None, **kwargs):
  32. """App used for testing."""
  33. from . import tasks # noqa
  34. config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
  35. if broker is not None:
  36. config.pop('broker_url', None)
  37. if backend is not None:
  38. config.pop('result_backend', None)
  39. log = None if enable_logging else log
  40. test_app = Celery(
  41. name or 'celery.tests',
  42. set_as_current=set_as_current,
  43. log=log,
  44. broker=broker,
  45. backend=backend,
  46. **kwargs)
  47. test_app.add_defaults(config)
  48. return test_app
  49. @contextmanager
  50. def set_trap(app):
  51. """Contextmanager that installs the trap app.
  52. The trap means that anything trying to use the current or default app
  53. will raise an exception.
  54. """
  55. trap = Trap()
  56. prev_tls = _state._tls
  57. _state.set_default_app(trap)
  58. class NonTLS(object):
  59. current_app = trap
  60. _state._tls = NonTLS()
  61. yield
  62. _state._tls = prev_tls
  63. @contextmanager
  64. def setup_default_app(app, use_trap=False):
  65. """Setup default app for testing.
  66. Ensures state is clean after the test returns.
  67. """
  68. prev_current_app = _state.get_current_app()
  69. prev_default_app = _state.default_app
  70. prev_finalizers = set(_state._on_app_finalizers)
  71. prev_apps = weakref.WeakSet(_state._apps)
  72. if use_trap:
  73. with set_trap(app):
  74. yield
  75. else:
  76. yield
  77. _state.set_default_app(prev_default_app)
  78. _state._tls.current_app = prev_current_app
  79. if app is not prev_current_app:
  80. app.close()
  81. _state._on_app_finalizers = prev_finalizers
  82. _state._apps = prev_apps