app.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from __future__ import absolute_import, unicode_literals
  2. import weakref
  3. from contextlib import contextmanager
  4. from copy import deepcopy
  5. from kombu.utils.imports import symbol_by_name
  6. from celery import Celery
  7. from celery import _state
  8. DEFAULT_TEST_CONFIG = {
  9. 'worker_hijack_root_logger': False,
  10. 'worker_log_color': False,
  11. 'accept_content': {'json'},
  12. 'enable_utc': True,
  13. 'timezone': 'UTC',
  14. 'broker_url': 'memory://',
  15. 'result_backend': 'cache+memory://'
  16. }
  17. class Trap(object):
  18. """Trap that pretends to be an app but raises an exception instead.
  19. This to protect from code that does not properly pass app instances,
  20. then falls back to the current_app.
  21. """
  22. def __getattr__(self, name):
  23. raise RuntimeError('Test depends on current_app')
  24. class UnitLogging(symbol_by_name(Celery.log_cls)):
  25. """Sets up logging for the test application."""
  26. def __init__(self, *args, **kwargs):
  27. super(UnitLogging, self).__init__(*args, **kwargs)
  28. self.already_setup = True
  29. def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
  30. log=UnitLogging, backend=None, broker=None, **kwargs):
  31. """App used for testing."""
  32. from . import tasks # noqa
  33. config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
  34. if broker is not None:
  35. config.pop('broker_url', None)
  36. if backend is not None:
  37. config.pop('result_backend', None)
  38. log = None if enable_logging else log
  39. test_app = Celery(
  40. name or 'celery.tests',
  41. set_as_current=set_as_current,
  42. log=log,
  43. broker=broker,
  44. backend=backend,
  45. **kwargs)
  46. test_app.add_defaults(config)
  47. return test_app
  48. @contextmanager
  49. def set_trap(app):
  50. trap = Trap()
  51. prev_tls = _state._tls
  52. _state.set_default_app(trap)
  53. class NonTLS(object):
  54. current_app = trap
  55. _state._tls = NonTLS()
  56. yield
  57. _state._tls = prev_tls
  58. @contextmanager
  59. def setup_default_app(app, use_trap=False):
  60. prev_current_app = _state.get_current_app()
  61. prev_default_app = _state.default_app
  62. prev_finalizers = set(_state._on_app_finalizers)
  63. prev_apps = weakref.WeakSet(_state._apps)
  64. if use_trap:
  65. with set_trap(app):
  66. yield
  67. else:
  68. yield
  69. _state.set_default_app(prev_default_app)
  70. _state._tls.current_app = prev_current_app
  71. if app is not prev_current_app:
  72. app.close()
  73. _state._on_app_finalizers = prev_finalizers
  74. _state._apps = prev_apps