bootsteps.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # -*- coding: utf-8 -*-
  2. """A directed acyclic graph of reusable components."""
  3. from collections import deque
  4. from threading import Event
  5. from kombu.common import ignore_errors
  6. from kombu.utils import symbol_by_name
  7. from kombu.utils.encoding import bytes_to_str
  8. from .utils.graph import DependencyGraph, GraphFormatter
  9. from .utils.imports import instantiate, qualname
  10. from .utils.log import get_logger
  11. try:
  12. from greenlet import GreenletExit
  13. except ImportError: # pragma: no cover
  14. IGNORE_ERRORS = ()
  15. else:
  16. IGNORE_ERRORS = (GreenletExit,)
  17. __all__ = ['Blueprint', 'Step', 'StartStopStep', 'ConsumerStep']
  18. #: States
  19. RUN = 0x1
  20. CLOSE = 0x2
  21. TERMINATE = 0x3
  22. logger = get_logger(__name__)
  23. def _pre(ns, fmt):
  24. return '| {0}: {1}'.format(ns.alias, fmt)
  25. def _label(s):
  26. return s.name.rsplit('.', 1)[-1]
  27. class StepFormatter(GraphFormatter):
  28. """Graph formatter for :class:`Blueprint`."""
  29. blueprint_prefix = '⧉'
  30. conditional_prefix = '∘'
  31. blueprint_scheme = {
  32. 'shape': 'parallelogram',
  33. 'color': 'slategray4',
  34. 'fillcolor': 'slategray3',
  35. }
  36. def label(self, step):
  37. return step and '{0}{1}'.format(
  38. self._get_prefix(step),
  39. bytes_to_str(
  40. (step.label or _label(step)).encode('utf-8', 'ignore')),
  41. )
  42. def _get_prefix(self, step):
  43. if step.last:
  44. return self.blueprint_prefix
  45. if step.conditional:
  46. return self.conditional_prefix
  47. return ''
  48. def node(self, obj, **attrs):
  49. scheme = self.blueprint_scheme if obj.last else self.node_scheme
  50. return self.draw_node(obj, scheme, attrs)
  51. def edge(self, a, b, **attrs):
  52. if a.last:
  53. attrs.update(arrowhead='none', color='darkseagreen3')
  54. return self.draw_edge(a, b, self.edge_scheme, attrs)
  55. class Blueprint:
  56. """Blueprint containing bootsteps that can be applied to objects.
  57. Arguments:
  58. steps Sequence[Union[str, Step]]: List of steps.
  59. name (str): Set explicit name for this blueprint.
  60. app (~@Celery): Set the Celery app for this blueprint.
  61. on_start (Callable): Optional callback applied after blueprint start.
  62. on_close (Callable): Optional callback applied before blueprint close.
  63. on_stopped (Callable): Optional callback applied after
  64. blueprint stopped.
  65. """
  66. GraphFormatter = StepFormatter
  67. name = None
  68. state = None
  69. started = 0
  70. default_steps = set()
  71. state_to_name = {
  72. 0: 'initializing',
  73. RUN: 'running',
  74. CLOSE: 'closing',
  75. TERMINATE: 'terminating',
  76. }
  77. def __init__(self, steps=None, name=None, app=None,
  78. on_start=None, on_close=None, on_stopped=None):
  79. self.app = app
  80. self.name = name or self.name or qualname(type(self))
  81. self.types = set(steps or []) | set(self.default_steps)
  82. self.on_start = on_start
  83. self.on_close = on_close
  84. self.on_stopped = on_stopped
  85. self.shutdown_complete = Event()
  86. self.steps = {}
  87. def start(self, parent):
  88. self.state = RUN
  89. if self.on_start:
  90. self.on_start()
  91. for i, step in enumerate(s for s in parent.steps if s is not None):
  92. self._debug('Starting %s', step.alias)
  93. self.started = i + 1
  94. step.start(parent)
  95. logger.debug('^-- substep ok')
  96. def human_state(self):
  97. return self.state_to_name[self.state or 0]
  98. def info(self, parent):
  99. info = {}
  100. for step in parent.steps:
  101. info.update(step.info(parent) or {})
  102. return info
  103. def close(self, parent):
  104. if self.on_close:
  105. self.on_close()
  106. self.send_all(parent, 'close', 'closing', reverse=False)
  107. def restart(self, parent, method='stop',
  108. description='restarting', propagate=False):
  109. self.send_all(parent, method, description, propagate=propagate)
  110. def send_all(self, parent, method,
  111. description=None, reverse=True, propagate=True, args=()):
  112. description = description or method.replace('_', ' ')
  113. steps = reversed(parent.steps) if reverse else parent.steps
  114. for step in steps:
  115. if step:
  116. fun = getattr(step, method, None)
  117. if fun is not None:
  118. self._debug('%s %s...',
  119. description.capitalize(), step.alias)
  120. try:
  121. fun(parent, *args)
  122. except Exception as exc:
  123. if propagate:
  124. raise
  125. logger.error(
  126. 'Error on %s %s: %r',
  127. description, step.alias, exc, exc_info=1,
  128. )
  129. def stop(self, parent, close=True, terminate=False):
  130. what = 'terminating' if terminate else 'stopping'
  131. if self.state in (CLOSE, TERMINATE):
  132. return
  133. if self.state != RUN or self.started != len(parent.steps):
  134. # Not fully started, can safely exit.
  135. self.state = TERMINATE
  136. self.shutdown_complete.set()
  137. return
  138. self.close(parent)
  139. self.state = CLOSE
  140. self.restart(
  141. parent, 'terminate' if terminate else 'stop',
  142. description=what, propagate=False,
  143. )
  144. if self.on_stopped:
  145. self.on_stopped()
  146. self.state = TERMINATE
  147. self.shutdown_complete.set()
  148. def join(self, timeout=None):
  149. try:
  150. # Will only get here if running green,
  151. # makes sure all greenthreads have exited.
  152. self.shutdown_complete.wait(timeout=timeout)
  153. except IGNORE_ERRORS:
  154. pass
  155. def apply(self, parent, **kwargs):
  156. """Apply the steps in this blueprint to an object.
  157. This will apply the ``__init__`` and ``include`` methods
  158. of each step, with the object as argument::
  159. step = Step(obj)
  160. ...
  161. step.include(obj)
  162. For :class:`StartStopStep` the services created
  163. will also be added to the objects ``steps`` attribute.
  164. """
  165. self._debug('Preparing bootsteps.')
  166. order = self.order = []
  167. steps = self.steps = self.claim_steps()
  168. self._debug('Building graph...')
  169. for S in self._finalize_steps(steps):
  170. step = S(parent, **kwargs)
  171. steps[step.name] = step
  172. order.append(step)
  173. self._debug('New boot order: {%s}',
  174. ', '.join(s.alias for s in self.order))
  175. for step in order:
  176. step.include(parent)
  177. return self
  178. def connect_with(self, other):
  179. self.graph.adjacent.update(other.graph.adjacent)
  180. self.graph.add_edge(type(other.order[0]), type(self.order[-1]))
  181. def __getitem__(self, name):
  182. return self.steps[name]
  183. def _find_last(self):
  184. return next((C for C in self.steps.values() if C.last), None)
  185. def _firstpass(self, steps):
  186. for step in steps.values():
  187. step.requires = [symbol_by_name(dep) for dep in step.requires]
  188. stream = deque(step.requires for step in steps.values())
  189. while stream:
  190. for node in stream.popleft():
  191. node = symbol_by_name(node)
  192. if node.name not in self.steps:
  193. steps[node.name] = node
  194. stream.append(node.requires)
  195. def _finalize_steps(self, steps):
  196. last = self._find_last()
  197. self._firstpass(steps)
  198. it = ((C, C.requires) for C in steps.values())
  199. G = self.graph = DependencyGraph(
  200. it, formatter=self.GraphFormatter(root=last),
  201. )
  202. if last:
  203. for obj in G:
  204. if obj != last:
  205. G.add_edge(last, obj)
  206. try:
  207. return G.topsort()
  208. except KeyError as exc:
  209. raise KeyError('unknown bootstep: %s' % exc)
  210. def claim_steps(self):
  211. return dict(self.load_step(step) for step in self._all_steps())
  212. def _all_steps(self):
  213. return self.types | self.app.steps[self.name.lower()]
  214. def load_step(self, step):
  215. step = symbol_by_name(step)
  216. return step.name, step
  217. def _debug(self, msg, *args):
  218. return logger.debug(_pre(self, msg), *args)
  219. @property
  220. def alias(self):
  221. return _label(self)
  222. class StepType(type):
  223. """Meta-class for steps."""
  224. def __new__(cls, name, bases, attrs):
  225. module = attrs.get('__module__')
  226. qname = '{0}.{1}'.format(module, name) if module else name
  227. attrs.update(
  228. __qualname__=qname,
  229. name=attrs.get('name') or qname,
  230. )
  231. return super().__new__(cls, name, bases, attrs)
  232. def __str__(self):
  233. return self.name
  234. def __repr__(self):
  235. return 'step:{0.name}{{{0.requires!r}}}'.format(self)
  236. class Step(metaclass=StepType):
  237. """A Bootstep.
  238. The :meth:`__init__` method is called when the step
  239. is bound to a parent object, and can as such be used
  240. to initialize attributes in the parent object at
  241. parent instantiation-time.
  242. """
  243. #: Optional step name, will use ``qualname`` if not specified.
  244. name = None
  245. #: Optional short name used for graph outputs and in logs.
  246. label = None
  247. #: Set this to true if the step is enabled based on some condition.
  248. conditional = False
  249. #: List of other steps that that must be started before this step.
  250. #: Note that all dependencies must be in the same blueprint.
  251. requires = ()
  252. #: This flag is reserved for the workers Consumer,
  253. #: since it is required to always be started last.
  254. #: There can only be one object marked last
  255. #: in every blueprint.
  256. last = False
  257. #: This provides the default for :meth:`include_if`.
  258. enabled = True
  259. def __init__(self, parent, **kwargs):
  260. pass
  261. def include_if(self, parent):
  262. """An optional predicate that decides whether this
  263. step should be created."""
  264. return self.enabled
  265. def instantiate(self, name, *args, **kwargs):
  266. return instantiate(name, *args, **kwargs)
  267. def _should_include(self, parent):
  268. if self.include_if(parent):
  269. return True, self.create(parent)
  270. return False, None
  271. def include(self, parent):
  272. return self._should_include(parent)[0]
  273. def create(self, parent):
  274. """Create the step."""
  275. pass
  276. def __repr__(self):
  277. return '<step: {0.alias}>'.format(self)
  278. @property
  279. def alias(self):
  280. return self.label or _label(self)
  281. def info(self, obj):
  282. pass
  283. class StartStopStep(Step):
  284. #: Optional obj created by the :meth:`create` method.
  285. #: This is used by :class:`StartStopStep` to keep the
  286. #: original service object.
  287. obj = None
  288. def start(self, parent):
  289. if self.obj:
  290. return self.obj.start()
  291. def stop(self, parent):
  292. if self.obj:
  293. return self.obj.stop()
  294. def close(self, parent):
  295. pass
  296. def terminate(self, parent):
  297. if self.obj:
  298. return getattr(self.obj, 'terminate', self.obj.stop)()
  299. def include(self, parent):
  300. inc, ret = self._should_include(parent)
  301. if inc:
  302. self.obj = ret
  303. parent.steps.append(self)
  304. return inc
  305. class ConsumerStep(StartStopStep):
  306. requires = ('celery.worker.consumer:Connection',)
  307. consumers = None
  308. def get_consumers(self, channel):
  309. raise NotImplementedError('missing get_consumers')
  310. def start(self, c):
  311. channel = c.connection.channel()
  312. self.consumers = self.get_consumers(channel)
  313. for consumer in self.consumers or []:
  314. consumer.consume()
  315. def stop(self, c):
  316. self._close(c, True)
  317. def shutdown(self, c):
  318. self._close(c, False)
  319. def _close(self, c, cancel_consumers=True):
  320. channels = set()
  321. for consumer in self.consumers or []:
  322. if cancel_consumers:
  323. ignore_errors(c.connection, consumer.cancel)
  324. if consumer.channel:
  325. channels.add(consumer.channel)
  326. for channel in channels:
  327. ignore_errors(c.connection, channel.close)