|
@@ -12,10 +12,13 @@ from collections import defaultdict
|
|
|
from importlib import import_module
|
|
|
from threading import Event
|
|
|
|
|
|
-from celery.datastructures import DependencyGraph
|
|
|
-from celery.utils.imports import instantiate
|
|
|
-from celery.utils.log import get_logger
|
|
|
-from celery.utils.threads import default_socket_timeout
|
|
|
+from kombu.common import ignore_errors
|
|
|
+from kombu.utils import symbol_by_name
|
|
|
+
|
|
|
+from .datastructures import DependencyGraph
|
|
|
+from .utils.imports import instantiate
|
|
|
+from .utils.log import get_logger
|
|
|
+from .utils.threads import default_socket_timeout
|
|
|
|
|
|
try:
|
|
|
from greenlet import GreenletExit
|
|
@@ -32,10 +35,11 @@ CLOSE = 0x2
|
|
|
TERMINATE = 0x3
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
+debug = logger.debug
|
|
|
|
|
|
|
|
|
-def qualname(c):
|
|
|
- return '.'.join([c.namespace.name, c.name.capitalize()])
|
|
|
+def _pre(ns, fmt):
|
|
|
+ return '| {0}: {1}'.format(ns.name, fmt)
|
|
|
|
|
|
|
|
|
class Namespace(object):
|
|
@@ -54,8 +58,7 @@ class Namespace(object):
|
|
|
name = None
|
|
|
state = None
|
|
|
started = 0
|
|
|
-
|
|
|
- _unclaimed = defaultdict(dict)
|
|
|
+ default_steps = set()
|
|
|
|
|
|
def __init__(self, name=None, app=None, on_start=None,
|
|
|
on_close=None, on_stopped=None):
|
|
@@ -71,13 +74,11 @@ class Namespace(object):
|
|
|
self.state = RUN
|
|
|
if self.on_start:
|
|
|
self.on_start()
|
|
|
- for i, step in enumerate(parent.steps):
|
|
|
- if step:
|
|
|
- logger.debug('Starting %s...', qualname(step))
|
|
|
- self.started = i + 1
|
|
|
- print('STARTING: %r' % (step.start, ))
|
|
|
- step.start(parent)
|
|
|
- logger.debug('%s OK!', qualname(step))
|
|
|
+ for i, step in enumerate(filter(None, parent.steps)):
|
|
|
+ self._debug('Starting %s', step.name)
|
|
|
+ self.started = i + 1
|
|
|
+ step.start(parent)
|
|
|
+ debug('^-- substep ok')
|
|
|
|
|
|
def close(self, parent):
|
|
|
if self.on_close:
|
|
@@ -91,7 +92,7 @@ class Namespace(object):
|
|
|
with default_socket_timeout(SHUTDOWN_SOCKET_TIMEOUT): # Issue 975
|
|
|
for step in reversed(parent.steps):
|
|
|
if step:
|
|
|
- logger.debug('%s %s...', description, qualname(step))
|
|
|
+ self._debug('%s %s...', description, step.name)
|
|
|
fun = getattr(step, attr, None)
|
|
|
if fun:
|
|
|
fun(parent)
|
|
@@ -124,16 +125,6 @@ class Namespace(object):
|
|
|
except IGNORE_ERRORS:
|
|
|
pass
|
|
|
|
|
|
- def modules(self):
|
|
|
- """Subclasses can override this to return a
|
|
|
- list of modules to import before steps are claimed."""
|
|
|
- return []
|
|
|
-
|
|
|
- def load_modules(self):
|
|
|
- """Will load the steps modules this namespace depends on."""
|
|
|
- for m in self.modules():
|
|
|
- self.import_module(m)
|
|
|
-
|
|
|
def apply(self, parent, **kwargs):
|
|
|
"""Apply the steps in this namespace to an object.
|
|
|
|
|
@@ -144,11 +135,9 @@ class Namespace(object):
|
|
|
will also be added the the objects ``steps`` attribute.
|
|
|
|
|
|
"""
|
|
|
- self._debug('Loading modules.')
|
|
|
- self.load_modules()
|
|
|
- self._debug('Claiming steps.')
|
|
|
- self.steps = self._claim()
|
|
|
- self._debug('Building boot step graph.')
|
|
|
+ self._debug('Loading boot-steps.')
|
|
|
+ self.steps = self.claim_steps()
|
|
|
+ self._debug('Building graph.')
|
|
|
self.boot_steps = [self.bind_step(name, parent, **kwargs)
|
|
|
for name in self._finalize_boot_steps()]
|
|
|
self._debug('New boot order: {%s}',
|
|
@@ -183,14 +172,23 @@ class Namespace(object):
|
|
|
for obj in G:
|
|
|
if obj != last.name:
|
|
|
G.add_edge(last.name, obj)
|
|
|
- return G.topsort()
|
|
|
+ try:
|
|
|
+ return G.topsort()
|
|
|
+ except KeyError as exc:
|
|
|
+ raise KeyError('unknown boot-step: %s' % exc)
|
|
|
|
|
|
- def _claim(self):
|
|
|
- return self._unclaimed[self.name]
|
|
|
+ def claim_steps(self):
|
|
|
+ return dict(self.load_step(step) for step in self._unclaimed_steps())
|
|
|
+
|
|
|
+ def _unclaimed_steps(self):
|
|
|
+ return set(self.default_steps) | self.app.steps[self.name]
|
|
|
+
|
|
|
+ def load_step(self, step):
|
|
|
+ step = symbol_by_name(step)
|
|
|
+ return step.name, step
|
|
|
|
|
|
def _debug(self, msg, *args):
|
|
|
- return logger.debug('[%s] ' + msg,
|
|
|
- *(self.name.capitalize(), ) + args)
|
|
|
+ return debug(_pre(self, msg), *args)
|
|
|
|
|
|
|
|
|
def _prepare_requires(req):
|
|
@@ -203,21 +201,12 @@ class StepType(type):
|
|
|
"""Metaclass for steps."""
|
|
|
|
|
|
def __new__(cls, name, bases, attrs):
|
|
|
- abstract = attrs.pop('abstract', False)
|
|
|
- if not abstract:
|
|
|
- try:
|
|
|
- cname = attrs['name']
|
|
|
- except KeyError:
|
|
|
- raise NotImplementedError('Steps must be named')
|
|
|
- namespace = attrs.get('namespace', None)
|
|
|
- if not namespace:
|
|
|
- attrs['namespace'], _, attrs['name'] = cname.partition('.')
|
|
|
+ module = attrs.get('__module__')
|
|
|
+ qname = '.'.join([module, name]) if module else name
|
|
|
+ attrs['name'] = attrs.get('name') or qname
|
|
|
attrs['requires'] = tuple(_prepare_requires(req)
|
|
|
for req in attrs.get('requires', ()))
|
|
|
- cls = super(StepType, cls).__new__(cls, name, bases, attrs)
|
|
|
- if not abstract:
|
|
|
- Namespace._unclaimed[cls.namespace][cls.name] = cls
|
|
|
- return cls
|
|
|
+ return super(StepType, cls).__new__(cls, name, bases, attrs)
|
|
|
|
|
|
|
|
|
class Step(object):
|
|
@@ -273,8 +262,8 @@ class Step(object):
|
|
|
step should be created."""
|
|
|
return self.enabled
|
|
|
|
|
|
- def instantiate(self, qualname, *args, **kwargs):
|
|
|
- return instantiate(qualname, *args, **kwargs)
|
|
|
+ def instantiate(self, name, *args, **kwargs):
|
|
|
+ return instantiate(name, *args, **kwargs)
|
|
|
|
|
|
def include(self, parent):
|
|
|
if self.include_if(parent):
|
|
@@ -302,3 +291,27 @@ class StartStopStep(Step):
|
|
|
def include(self, parent):
|
|
|
if super(StartStopStep, self).include(parent):
|
|
|
parent.steps.append(self)
|
|
|
+
|
|
|
+
|
|
|
+class ConsumerStep(StartStopStep):
|
|
|
+ abstract = True
|
|
|
+ requires = ('Connection', )
|
|
|
+ consumers = None
|
|
|
+
|
|
|
+ def get_consumers(self, channel):
|
|
|
+ raise NotImplementedError('missing get_consumers')
|
|
|
+
|
|
|
+ def start(self, c):
|
|
|
+ self.consumers = self.get_consumers(c.connection)
|
|
|
+ for consumer in self.consumers or []:
|
|
|
+ consumer.consume()
|
|
|
+
|
|
|
+ def stop(self, c):
|
|
|
+ for consumer in self.consumers or []:
|
|
|
+ ignore_errors(c.connection, consumer.cancel)
|
|
|
+
|
|
|
+ def shutdown(self, c):
|
|
|
+ self.stop(c)
|
|
|
+ for consumer in self.consumers or []:
|
|
|
+ if consumer.channel:
|
|
|
+ ignore_errors(c.connection, consumer.channel.close)
|