base.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.loaders.base
  4. ~~~~~~~~~~~~~~~~~~~
  5. Loader base class.
  6. """
  7. from __future__ import absolute_import
  8. import anyjson
  9. import imp
  10. import importlib
  11. import os
  12. import re
  13. from datetime import datetime
  14. from itertools import imap
  15. from kombu.utils import cached_property
  16. from kombu.utils.encoding import safe_str
  17. from celery.datastructures import DictAttribute
  18. from celery.exceptions import ImproperlyConfigured
  19. from celery.utils.imports import import_from_cwd, symbol_by_name
  20. from celery.utils.functional import maybe_list
  # Modules that every loader imports in addition to the ones named in the
  # configuration (see BaseLoader.import_default_modules); empty by default.
  BUILTIN_MODULES = frozenset()

  # Error message template for a missing configuration environment
  # variable; the ``{0!r}`` placeholder is meant to receive the variable's
  # name (per the message wording).
  ERROR_ENVVAR_NOT_SET = (
      'The environment variable {0!r} is not set,\n'
      'and as such the configuration could not be loaded.\n'
      'Please set this variable and make it point to\n'
      'a configuration module.'
  )

  # Re-entrancy flag for the module-level autodiscover_tasks(): while it is
  # True, a nested call returns early instead of searching packages again.
  _RACE_PROTECTION = False
  class BaseLoader(object):
      """The base class for loaders.

      Loaders handle:

      * Reading celery client/worker configurations.

      * What happens when a task starts?
        See :meth:`on_task_init`.

      * What happens when the worker starts?
        See :meth:`on_worker_init`.

      * What happens when the worker shuts down?
        See :meth:`on_worker_shutdown`.

      * What modules are imported to find tasks?
        See :meth:`import_default_modules`.

      """
      builtin_modules = BUILTIN_MODULES
      configured = False
      error_envvar_not_set = ERROR_ENVVAR_NOT_SET
      override_backends = {}
      worker_initialized = False

      _conf = None

      def __init__(self, app=None, **kwargs):
          from celery.app import app_or_default
          self.app = app_or_default(app)
          # Names of modules imported through import_task_module() or
          # found by autodiscover_tasks().
          self.task_modules = set()

      def now(self, utc=True):
          """Return the current time as a naive :class:`~datetime.datetime`.

          :keyword utc: Return UTC time if true (the default),
              local time otherwise.

          """
          if utc:
              return datetime.utcnow()
          return datetime.now()

      def on_task_init(self, task_id, task):
          """This method is called before a task is executed."""
          pass

      def on_process_cleanup(self):
          """This method is called after a task is executed."""
          pass

      def on_worker_init(self):
          """This method is called when the worker (:program:`celery worker`)
          starts."""
          pass

      def on_worker_shutdown(self):
          """This method is called when the worker (:program:`celery worker`)
          shuts down."""
          pass

      def on_worker_process_init(self):
          """This method is called when a child process starts."""
          pass

      def import_task_module(self, module):
          """Import the module named *module* and record the name in
          :attr:`task_modules`; returns the result of
          :meth:`import_from_cwd`."""
          self.task_modules.add(module)
          return self.import_from_cwd(module)

      def import_module(self, module, package=None):
          """Import *module*; thin wrapper around
          :func:`importlib.import_module`."""
          return importlib.import_module(module, package=package)

      def import_from_cwd(self, module, imp=None, package=None):
          """Import *module* via :func:`celery.utils.imports.import_from_cwd`.

          :keyword imp: Import function to use; defaults to
              :meth:`import_module`.
          :keyword package: Passed through for relative imports.

          """
          return import_from_cwd(module,
                                 self.import_module if imp is None else imp,
                                 package=package)

      def import_default_modules(self):
          """Import every module named in ``CELERY_IMPORTS``,
          ``CELERY_INCLUDE`` and :attr:`builtin_modules`.

          Each distinct name is imported once (the names are merged into a
          set). Returns a list with the result of
          :meth:`import_task_module` for every name, in no particular order.

          """
          return [self.import_task_module(m)
                  for m in set(maybe_list(self.app.conf.CELERY_IMPORTS))
                  | set(maybe_list(self.app.conf.CELERY_INCLUDE))
                  | self.builtin_modules]

      def init_worker(self):
          """Import the default modules and run :meth:`on_worker_init`.

          Only the first call does anything; the flag is set before the
          imports run, so a failing import is not retried by later calls.

          """
          if not self.worker_initialized:
              self.worker_initialized = True
              self.import_default_modules()
              self.on_worker_init()

      def shutdown_worker(self):
          """Run the :meth:`on_worker_shutdown` hook."""
          self.on_worker_shutdown()

      def init_worker_process(self):
          """Run the :meth:`on_worker_process_init` hook."""
          self.on_worker_process_init()

      def config_from_envvar(self, variable_name, silent=False):
          """Load configuration from the module named by the environment
          variable *variable_name*.

          :keyword silent: If true, return :const:`False` instead of raising
              when the variable is unset/empty or the module cannot be
              loaded.
          :raises ImproperlyConfigured: if the variable is unset or empty
              and *silent* is false.
          :returns: The result of :meth:`config_from_object`, or
              :const:`False` if the variable is unset and *silent* is true.

          """
          module_name = os.environ.get(variable_name)
          if not module_name:
              if silent:
                  return False
              # Name the *variable* that is missing: its value is None or ''
              # here, which would make the message useless.
              raise ImproperlyConfigured(
                  self.error_envvar_not_set.format(variable_name))
          return self.config_from_object(module_name, silent=silent)

      def config_from_object(self, obj, silent=False):
          """Use *obj* as the loader configuration.

          *obj* may be an object/module, a mapping, or a string naming one:
          a string containing a dot is resolved with :func:`symbol_by_name`,
          any other string is imported as a module. Objects without
          item access (no ``__getitem__``) are wrapped in
          :class:`DictAttribute`.

          :keyword silent: If true, return :const:`False` instead of
              re-raising :exc:`ImportError`/:exc:`AttributeError` while
              resolving a string.
          :returns: :const:`True` once the configuration has been set.

          """
          if isinstance(obj, basestring):
              try:
                  if '.' in obj:
                      obj = symbol_by_name(obj, imp=self.import_from_cwd)
                  else:
                      obj = self.import_from_cwd(obj)
              except (ImportError, AttributeError):
                  if silent:
                      return False
                  raise
          if not hasattr(obj, '__getitem__'):
              obj = DictAttribute(obj)
          self._conf = obj
          return True

      def cmdline_config_parser(
              self, args, namespace='celery',
              re_type=re.compile(r'\((\w+)\)'),
              extra_types={'json': anyjson.loads},
              override_types={'tuple': 'json',
                              'list': 'json',
                              'dict': 'json'}):
          """Parse ``key=value`` command-line arguments into a config dict.

          Keys are case-insensitive and written ``ns.key`` or ``ns_key``;
          a leading ``.`` or ``_`` (``.key`` / ``_key``) selects the default
          *namespace*. A value prefixed with ``(type)`` is converted with
          the matching converter from ``Option.typemap`` merged with
          *extra_types* (``tuple``/``list``/``dict`` are mapped to ``json``
          by *override_types*); otherwise the value is converted by the
          known option's ``to_python``.

          :returns: dict mapping upper-case ``NS_KEY`` names to values.
          :raises ValueError: if ``to_python`` rejects a value (the message
              is prefixed with the key name).

          """
          # The mapping defaults above are only read, never mutated.
          from celery.app.defaults import Option, NAMESPACES
          namespace = namespace.upper()
          typemap = dict(Option.typemap, **extra_types)

          def getarg(arg):
              """Parse a single configuration definition from
              the command-line."""
              # find key/value: ns.key=value|ns_key=value (case insensitive)
              key, value = arg.split('=', 1)
              key = key.upper().replace('.', '_')

              # find namespace: .key=value|_key=value expands to the
              # default namespace.
              if key[0] == '_':
                  ns, key = namespace, key[1:]
              else:
                  # find namespace part of key
                  ns, key = key.split('_', 1)

              ns_key = (ns + '_' if ns else '') + key

              # (type)value makes cast to custom type.
              cast = re_type.match(value)
              if cast:
                  type_ = cast.groups()[0]
                  type_ = override_types.get(type_, type_)
                  value = value[len(cast.group()):]
                  value = typemap[type_](value)
              else:
                  try:
                      value = NAMESPACES[ns][key].to_python(value)
                  except ValueError as exc:
                      # display key name in error message.
                      raise ValueError('{0!r}: {1}'.format(ns_key, exc))
              return ns_key, value
          return dict(imap(getarg, args))

      def mail_admins(self, subject, body, fail_silently=False,
                      sender=None, to=None, host=None, port=None,
                      user=None, password=None, timeout=None,
                      use_ssl=False, use_tls=False):
          """Send an e-mail using the ``Message``/``Mailer`` classes of
          :mod:`celery.utils.mail`.

          *subject* and *body* are passed through :func:`safe_str`; the
          other arguments are forwarded unchanged.

          """
          message = self.mail.Message(sender=sender, to=to,
                                      subject=safe_str(subject),
                                      body=safe_str(body))
          mailer = self.mail.Mailer(host=host, port=port,
                                    user=user, password=password,
                                    timeout=timeout, use_ssl=use_ssl,
                                    use_tls=use_tls)
          mailer.send(message, fail_silently=fail_silently)

      def read_configuration(self):
          """Return the loader configuration; the base implementation
          returns an empty dict."""
          return {}

      def autodiscover_tasks(self, packages, related_name='tasks'):
          """Import ``<package>.<related_name>`` for each of *packages*
          where it exists, and record the imported module names in
          :attr:`task_modules`."""
          # The module-level helper can return None instead of a list when
          # it is re-entered; treat that as "nothing found" rather than
          # failing with TypeError while iterating.
          found = autodiscover_tasks(packages, related_name) or ()
          self.task_modules.update(mod.__name__ for mod in found if mod)

      @property
      def conf(self):
          """Loader configuration.

          Read lazily with :meth:`read_configuration` on first access,
          unless already set by :meth:`config_from_object`.

          """
          if self._conf is None:
              self._conf = self.read_configuration()
          return self._conf

      @cached_property
      def mail(self):
          """The :mod:`celery.utils.mail` module, imported on first access."""
          return self.import_module('celery.utils.mail')
  def autodiscover_tasks(packages, related_name='tasks'):
      """Find ``<package>.<related_name>`` for every package in *packages*.

      Returns a list holding the result of :func:`find_related_module` for
      each package (None entries where nothing was found).

      A module-level flag guards against re-entrant calls, presumably so that
      importing a discovered module which triggers autodiscovery again does
      not recurse. A nested call returns an empty tuple, which keeps the
      result iterable for callers. The flag is a plain global, not a lock.
      """
      global _RACE_PROTECTION

      if _RACE_PROTECTION:
          return ()
      _RACE_PROTECTION = True
      try:
          return [find_related_module(pkg, related_name) for pkg in packages]
      finally:
          _RACE_PROTECTION = False
  def find_related_module(package, related_name):
      """Given a package name and a module name, import and return
      ``<package>.<related_name>`` if it exists.

      Returns None when *package* has no ``__path__`` (it is a plain module,
      not a package) or when :func:`imp.find_module` cannot locate
      *related_name* inside it. An :exc:`ImportError` raised while importing
      *package* itself is not caught.
      """
      try:
          pkg_path = importlib.import_module(package).__path__
      except AttributeError:
          return None
      try:
          fp, _pathname, _description = imp.find_module(related_name, pkg_path)
      except ImportError:
          return None
      # imp.find_module() returns an open file object for source/bytecode
      # modules and leaves closing it to the caller. We only need to know the
      # module exists, so close it instead of leaking the descriptor.
      if fp is not None:
          fp.close()
      return importlib.import_module('{0}.{1}'.format(package, related_name))