|
- from __future__ import absolute_import, print_function, unicode_literals
- import os
- import socket
- import sys
- from collections import deque
- from datetime import datetime, timedelta
- from functools import partial
- from threading import Event
- import pytest
- from amqp import ChannelError
- from case import Mock, patch, skip
- from kombu import Connection
- from kombu.common import QoS, ignore_errors
- from kombu.transport.base import Message
- from kombu.transport.memory import Transport
- from kombu.utils.uuid import uuid
- from celery.bootsteps import CLOSE, RUN, TERMINATE, StartStopStep
- from celery.concurrency.base import BasePool
- from celery.exceptions import (ImproperlyConfigured, InvalidTaskError,
- TaskRevokedError, WorkerShutdown,
- WorkerTerminate)
- from celery.five import Empty
- from celery.five import Queue as FastQueue
- from celery.five import range
- from celery.platforms import EX_FAILURE
- from celery.utils.nodenames import worker_direct
- from celery.utils.serialization import pickle
- from celery.utils.timer2 import Timer
- from celery.worker import components, consumer, state
- from celery.worker import worker as worker_module
- from celery.worker.consumer import Consumer
- from celery.worker.pidbox import gPidbox
- from celery.worker.request import Request
def MockStep(step=None):
    """Return an object that looks enough like a bootstep for the tests.

    With no argument a fresh ``Mock`` is created. Otherwise the given *step*
    is decorated in place by attaching a mocked ``blueprint`` to it, and the
    same object is returned. In both cases ``blueprint.name`` is set to
    ``'MockNS'`` and ``name`` embeds the object's ``id()``.
    """
    if step is not None:
        step.blueprint = Mock(name='step.blueprint')
    else:
        # A new Mock auto-creates ``blueprint`` on first attribute access.
        step = Mock(name='step')
    blueprint = step.blueprint
    blueprint.name = 'MockNS'
    step.name = 'MockStep(%s)' % id(step)
    return step
def mock_event_dispatcher():
    """Return a ``Mock`` standing in for the worker's event dispatcher.

    It reports membership of the ``'worker'`` event group. ``_outbound_buffer``
    is a real, empty ``deque`` rather than an auto-created Mock attribute.
    """
    dispatcher = Mock(name='event_dispatcher')
    dispatcher.configure_mock(
        groups=['worker'],
        _outbound_buffer=deque(),
    )
    return dispatcher
def find_step(obj, typ):
    """Return the step registered in ``obj.blueprint.steps`` under ``typ.name``.

    Raises ``KeyError`` if no step with that name is registered.
    """
    registry = obj.blueprint.steps
    key = typ.name
    return registry[key]
def create_message(channel, **data):
    """Build a kombu ``Message`` whose body is the pickled *data* mapping.

    An ``id`` (a fresh uuid) is added to *data* unless the caller supplied
    one. The message uses the ``application/x-python-serialize`` content
    type, lists that type as acceptable, and carries a ``'mock'`` consumer
    tag in its delivery info.
    """
    data.setdefault('id', uuid())
    payload = pickle.dumps(dict(data))
    message = Message(
        body=payload,
        channel=channel,
        content_type='application/x-python-serialize',
        content_encoding='binary',
        delivery_info={'consumer_tag': 'mock'},
    )
    message.accept = ['application/x-python-serialize']
    return message
class ConsumerCase(object):
    """Mixin with helpers shared by the consumer and worker test classes.

    Derives from ``object`` so that on Python 2 it (and every test class
    built on it) is a new-style class, like the other classes in this module.
    ``self.TaskMessage`` is not defined here; it is expected to be provided
    on the test instance by the test harness.
    """

    def create_task_message(self, channel, *args, **kwargs):
        """Create a task message via ``self.TaskMessage`` bound to *channel*.

        All positional and keyword arguments are forwarded unchanged to
        ``self.TaskMessage``. The result has its ``channel`` set to *channel*
        and its ``delivery_info`` set to a ``'mock'`` consumer tag.
        """
        m = self.TaskMessage(*args, **kwargs)
        m.channel = channel
        m.delivery_info = {'consumer_tag': 'mock'}
        return m
class test_Consumer(ConsumerCase):
    """Tests for :class:`celery.worker.consumer.Consumer`.

    ``self.app`` and ``self.TaskMessage`` are not created in this class; they
    are expected to be provided on the instance by the test harness.
    """

    def setup(self):
        self.buffer = FastQueue()
        self.timer = Timer()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        self.timer.stop()

    def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
                     without_mingle=True, without_gossip=True,
                     without_heartbeat=True, **kwargs):
        """Create a Consumer whose collaborators are mostly mocks.

        Unspecified ``buffer``, ``timer`` and ``app`` default to this test's
        queue ``put`` method, timer and app. Mingle, gossip and heartbeat are
        disabled by default. Extra keyword arguments go to ``Consumer``.
        """
        if controller is None:
            controller = Mock(name='.controller')
        buffer = buffer if buffer is not None else self.buffer.put
        timer = timer if timer is not None else self.timer
        app = app if app is not None else self.app
        c = Consumer(
            buffer,
            timer=timer,
            app=app,
            controller=controller,
            without_mingle=without_mingle,
            without_gossip=without_gossip,
            without_heartbeat=without_heartbeat,
            **kwargs
        )
        c.task_consumer = Mock(name='.task_consumer')
        c.qos = QoS(c.task_consumer.qos, 10)
        c.connection = Mock(name='.connection')
        c.controller = c.app.WorkController()
        c.heart = Mock(name='.heart')
        c.controller.consumer = c
        c.pool = c.controller.pool = Mock(name='.controller.pool')
        c.node = Mock(name='.node')
        c.event_dispatcher = mock_event_dispatcher()
        return c

    def NoopConsumer(self, *args, **kwargs):
        """Like :meth:`LoopConsumer`, but with ``loop`` replaced by a Mock."""
        c = self.LoopConsumer(*args, **kwargs)
        c.loop = Mock(name='.loop')
        return c

    def test_info(self):
        c = self.NoopConsumer()
        c.connection.info.return_value = {'foo': 'bar'}
        c.controller.pool.info.return_value = [Mock(), Mock()]
        info = c.controller.stats()
        assert info['prefetch_count'] == 10
        assert info['broker']

    def test_start_when_closed(self):
        c = self.NoopConsumer()
        c.blueprint.state = CLOSE
        c.start()

    def test_connection(self):
        c = self.NoopConsumer()

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)

        c.blueprint.state = RUN
        c.event_dispatcher = None
        c.blueprint.restart(c)
        assert c.connection

        c.blueprint.state = RUN
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)
        c.blueprint.restart(c)

        c.stop()
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

    def test_close_connection(self):
        c = self.NoopConsumer()
        c.blueprint.state = RUN
        step = find_step(c, consumer.Connection)
        connection = c.connection
        step.shutdown(c)
        connection.close.assert_called()
        assert c.connection is None

    def test_close_connection__heart_shutdown(self):
        c = self.NoopConsumer()
        event_dispatcher = c.event_dispatcher
        heart = c.heart
        c.event_dispatcher.enabled = True
        c.blueprint.state = RUN
        Events = find_step(c, consumer.Events)
        Events.shutdown(c)
        Heart = find_step(c, consumer.Heart)
        Heart.shutdown(c)
        event_dispatcher.close.assert_called()
        heart.stop.assert_called_with()

    @patch('celery.worker.consumer.consumer.warn')
    def test_receive_message_unknown(self, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='.channeol')
        m = create_message(channel, unknown={'baz': '!!!'})

        callback = self._get_on_message(c)
        callback(m)
        warn.assert_called()

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        to_timestamp.side_effect = OverflowError()
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        # ``args`` is a real sequence matching foo_task's (x, y, z)
        # signature; ``('2, 2')`` would only be a parenthesised string.
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=datetime.now().isoformat(),
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        assert m.acknowledged

    @patch('celery.worker.consumer.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1)
        c.update_strategies()
        strat = c.strategies[self.foo_task.name] = Mock(name='strategy')
        strat.side_effect = InvalidTaskError()

        callback = self._get_on_message(c)
        callback(m)
        error.assert_called()
        assert 'Received invalid task message' in error.call_args[0][0]

    @patch('celery.worker.consumer.consumer.crit')
    def test_on_decode_error(self, crit):
        c = self.LoopConsumer()

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        c.on_decode_error(message, KeyError('foo'))
        assert message.ack.call_count
        assert "Can't decode message body" in crit.call_args[0][0]

    def _get_on_message(self, c):
        """Run ``c.loop`` once and return the installed ``on_message`` callback.

        ``drain_events`` is made to raise ``WorkerShutdown``, so the loop
        exits after setup. The callback is read from the mocked
        ``task_consumer``.
        """
        if c.qos is None:
            c.qos = Mock()
        c.task_consumer = Mock()
        c.event_dispatcher = mock_event_dispatcher()
        c.connection = Mock(name='.connection')
        c.connection.get_heartbeat_interval.return_value = 0
        c.connection.drain_events.side_effect = WorkerShutdown()

        with pytest.raises(WorkerShutdown):
            c.loop(*c.loop_args())
        assert c.task_consumer.on_message
        return c.task_consumer.on_message

    def test_receieve_message(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)

        in_bucket = self.buffer.get_nowait()
        assert isinstance(in_bucket, Request)
        assert in_bucket.name == self.foo_task.name
        assert in_bucket.execute() == 2 * 4 * 8
        assert self.timer.empty()

    def test_start_channel_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.channel_errors = (KeyError,)
        try:
            with pytest.raises(KeyError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_start_connection_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.connection_errors = (KeyError,)
        try:
            with pytest.raises(SyntaxError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_loop_ignores_socket_timeout(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        c = self.NoopConsumer()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        c.loop(*c.loop_args())

    def test_loop_when_socket_error(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        conn = c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            c.loop(*c.loop_args())

        c.blueprint.state = CLOSE
        c.connection = conn
        c.loop(*c.loop_args())

    def test_loop(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

            @property
            def supports_heartbeats(self):
                return False

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.qos = QoS(c.task_consumer.qos, 10)

        c.loop(*c.loop_args())
        c.loop(*c.loop_args())
        assert c.task_consumer.consume.call_count
        c.task_consumer.qos.assert_called_with(prefetch_count=10)
        assert c.qos.value == 10
        c.qos.decrement_eventually()
        assert c.qos.value == 9
        c.qos.update()
        assert c.qos.value == 9
        c.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        c = self.NoopConsumer()
        c.connection_errors = (AttributeError, KeyError,)
        c.channel_errors = (SyntaxError,)
        ignore_errors(c, Mock(side_effect=AttributeError('foo')))
        ignore_errors(c, Mock(side_effect=KeyError('foo')))
        ignore_errors(c, Mock(side_effect=SyntaxError('foo')))
        with pytest.raises(IndexError):
            ignore_errors(c, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        c = self.NoopConsumer()
        c.qos = QoS(None, 10)
        task = Mock(name='task', id='1234213')
        qos = c.qos.value
        c.apply_eta_task(task)
        assert task in state.reserved_requests
        assert c.qos.value == qos - 1
        assert self.buffer.get_nowait() is task

    def test_receieve_message_eta_isoformat(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        c.qos = QoS(c.task_consumer.qos, 1)
        current_pcount = c.qos.value
        c.event_dispatcher.enabled = False
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        c.timer.stop()
        c.timer.join(1)

        items = [entry[2] for entry in self.timer.queue]
        assert any(item.args[0].name == self.foo_task.name for item in items)
        assert c.qos.value > current_pcount

    def test_pidbox_callback(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        con.reset.assert_called()

    def test_revoke(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        id = uuid()
        t = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={}, id=id,
        )

        state.revoked.add(id)

        callback = self._get_on_message(c)
        callback(t)
        assert self.buffer.empty()

    def test_receieve_message_not_registered(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
        )

        callback = self._get_on_message(c)
        assert not callback(m)
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()

    @patch('celery.worker.consumer.consumer.warn')
    @patch('celery.worker.consumer.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        m.headers = None

        c.update_strategies()
        c.connection_errors = (socket.error,)
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(c)
        assert not callback(m)
        warn.assert_called()
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()
        m.reject_log_error.assert_called_with(logger, c.connection_errors)

    def test_receive_message_eta(self):
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        c = self.LoopConsumer()
        pp('-CREATE MYKOMBUCONSUMER')
        c.steps.pop()
        channel = Mock(name='channel')
        pp('+ CREATE MESSAGE')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 1')
            p = c.app.conf.broker_connection_retry
            c.app.conf.broker_connection_retry = False
            # Restore the setting even if the second start raises.
            try:
                pp('+ BLUEPRINT START 2')
                c.blueprint.start(c)
                pp('- BLUEPRINT START 2')
            finally:
                c.app.conf.broker_connection_retry = p
            pp('+ BLUEPRINT RESTART')
            c.blueprint.restart(c)
            pp('- BLUEPRINT RESTART')
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(c)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            c.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                c.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                pass

        in_hold = c.timer.queue[0]
        assert len(in_hold) == 3
        eta, priority, entry = in_hold
        task = entry.args[0]
        assert isinstance(task, Request)
        assert task.name == self.foo_task.name
        assert task.execute() == 2 * 4 * 8
        with pytest.raises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        chan.close.side_effect = socket.error('foo')
        c.connection_errors = (socket.error,)
        con.reset()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        con = find_step(c, consumer.Control)
        assert isinstance(con.box, gPidbox)
        con.start(c)
        c.pool.spawn_n.assert_called_with(con.box.loop, c)

    def test_green_pidbox_node(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        controller = find_step(c, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call simulates a timeout; the second ends the loop.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        c.connection_for_read = lambda: Connection(obj=c)
        controller = find_step(c, consumer.Control)
        controller.box.loop(c)

        controller.box.node.listen.assert_called()
        assert controller.box.consumer
        controller.box.consumer.consume.assert_called_with()

        assert c.connection is None
        assert connections[0].closed

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.functional.sleep')
    def test_connect_errback(self, sleep, connect):
        c = self.NoopConsumer()
        # Override the memory transport's class attribute only for the
        # duration of this test; put the original definition back (or
        # remove the override) so other tests are not affected.
        missing = object()
        prev = vars(Transport).get('connection_errors', missing)
        Transport.connection_errors = (ChannelError,)
        try:
            connect.on_nth_call_do(ChannelError('error'), n=1)
            c.connect()
            connect.assert_called_with()
        finally:
            if prev is missing:
                del Transport.connection_errors
            else:
                Transport.connection_errors = prev

    def test_stop_pidbox_node(self):
        c = self.NoopConsumer()
        cont = find_step(c, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(c)

    def test_start__loop(self):

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        init_callback = Mock(name='init_callback')
        c = self.NoopConsumer(init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.iterations = 0

        def raises_KeyError(*args, **kwargs):
            c.iterations += 1
            if c.qos.prev != c.qos.value:
                c.qos.update()
            if c.iterations >= 2:
                raise KeyError('foo')

        c.loop = raises_KeyError
        with pytest.raises(KeyError):
            c.start()
        assert c.iterations == 2
        assert c.qos.prev == c.qos.value

        init_callback.reset_mock()
        c = self.NoopConsumer(task_events=False, init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.loop = Mock(side_effect=socket.error('foo'))
        with pytest.raises(socket.error):
            c.start()
        c.loop.assert_called()

    def test_reset_connection_with_no_node(self):
        c = self.NoopConsumer()
        c.steps.pop()
        c.blueprint.start(c)
class test_WorkController(ConsumerCase):
    """Tests for the worker's ``WorkController`` and its components.

    ``setup`` swaps the ``logger`` objects of ``celery.worker.worker`` and
    ``celery.worker.components`` for Mocks, and ``teardown`` restores them.
    ``self.app`` is expected to be provided by the test harness.
    """

    def setup(self):
        self.worker = self.create_worker()
        self._logger = worker_module.logger
        self._comp_logger = components.logger
        self.logger = worker_module.logger = Mock()
        self.comp_logger = components.logger = Mock()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        worker_module.logger = self._logger
        components.logger = self._comp_logger

    def create_worker(self, **kw):
        """Create a single-process WorkController with shutdown pre-completed.

        ``blueprint.shutdown_complete`` is set up front.
        """
        worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
        worker.blueprint.shutdown_complete.set()
        return worker

    def test_on_consumer_ready(self):
        self.worker.on_consumer_ready(Mock())

    def test_setup_queues_worker_direct(self):
        self.app.conf.worker_direct = True
        self.app.amqp.__dict__['queues'] = Mock()
        self.worker.setup_queues({})
        self.app.amqp.queues.select_add.assert_called_with(
            worker_direct(self.worker.hostname),
        )

    def test_setup_queues__missing_queue(self):
        self.app.amqp.queues.select = Mock(name='select')
        self.app.amqp.queues.deselect = Mock(name='deselect')
        self.app.amqp.queues.select.side_effect = KeyError()
        self.app.amqp.queues.deselect.side_effect = KeyError()
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')
        self.app.amqp.queues.select = Mock(name='select')
        with pytest.raises(ImproperlyConfigured):
            self.worker.setup_queues('x,y', exclude='foo,bar')

    def test_send_worker_shutdown(self):
        with patch('celery.signals.worker_shutdown') as ws:
            self.worker._send_worker_shutdown()
            ws.send.assert_called_with(sender=self.worker)

    @skip.todo('unstable test')
    def test_process_shutdown_on_worker_shutdown(self):
        from celery.concurrency.prefork import process_destructor
        from celery.concurrency.asynpool import Worker
        with patch('celery.signals.worker_process_shutdown') as ws:
            with patch('os._exit') as _exit:
                worker = Worker(None, None, on_exit=process_destructor)
                worker._do_exit(22, 3.1415926)
                ws.send.assert_called_with(
                    sender=None, pid=22, exitcode=3.1415926,
                )
                _exit.assert_called_with(3.1415926)

    def test_process_task_revoked_release_semaphore(self):
        self.worker._quick_release = Mock()
        req = Mock()
        req.execute_using_pool.side_effect = TaskRevokedError
        self.worker._process_task(req)
        self.worker._quick_release.assert_called_with()

        delattr(self.worker, '_quick_release')
        self.worker._process_task(req)

    def test_shutdown_no_blueprint(self):
        self.worker.blueprint = None
        self.worker._shutdown()

    @patch('celery.worker.worker.create_pidlock')
    def test_use_pidfile(self, create_pidlock):
        create_pidlock.return_value = Mock()
        worker = self.create_worker(pidfile='pidfilelockfilepid')
        worker.steps = []
        worker.start()
        create_pidlock.assert_called()
        worker.stop()
        worker.pidlock.release.assert_called()

    def test_attrs(self):
        worker = self.worker
        assert worker.timer is not None
        assert isinstance(worker.timer, Timer)
        assert worker.pool is not None
        assert worker.consumer is not None
        assert worker.steps

    def test_with_embedded_beat(self):
        worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
        assert worker.beat
        assert worker.beat in [w.obj for w in worker.steps]

    def test_with_autoscaler(self):
        worker = self.create_worker(
            autoscale=[10, 3], send_events=False,
            timer_cls='celery.utils.timer2.Timer',
        )
        assert worker.autoscaler

    def test_dont_stop_or_terminate(self):
        worker = self.app.WorkController(concurrency=1, loglevel=0)
        worker.stop()
        assert worker.blueprint.state != CLOSE
        worker.terminate()
        assert worker.blueprint.state != CLOSE

        sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
        try:
            worker.blueprint.state = RUN
            worker.stop(in_sighandler=True)
            assert worker.blueprint.state != CLOSE
            worker.terminate(in_sighandler=True)
            assert worker.blueprint.state != CLOSE
        finally:
            worker.pool.signal_safe = sigsafe

    def test_on_timer_error(self):
        worker = self.app.WorkController(concurrency=1, loglevel=0)

        try:
            raise KeyError('foo')
        except KeyError as exc:
            components.Timer(worker).on_timer_error(exc)
            msg, args = self.comp_logger.error.call_args[0]
            assert 'KeyError' in msg % args

    def test_on_timer_tick(self):
        worker = self.app.WorkController(concurrency=1, loglevel=10)

        components.Timer(worker).on_timer_tick(30.0)
        fmt, arg = self.comp_logger.debug.call_args[0][:2]
        assert arg == 30.0
        assert 'Next ETA %s secs' in fmt

    def test_process_task(self):
        worker = self.worker
        worker.pool = Mock()
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker._process_task(task)
        assert worker.pool.apply_async.call_count == 1
        worker.pool.stop()

    def test_process_task_raise_base(self):
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker.steps = []
        worker.blueprint.state = RUN
        with pytest.raises(KeyboardInterrupt):
            worker._process_task(task)

    def test_process_task_raise_WorkerTerminate(self):
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = WorkerTerminate()
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        worker.steps = []
        worker.blueprint.state = RUN
        with pytest.raises(SystemExit):
            worker._process_task(task)

    def test_process_task_raise_regular(self):
        worker = self.worker
        worker.pool = Mock()
        worker.pool.apply_async.side_effect = KeyError('some exception')
        channel = Mock()
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[4, 8, 10], kwargs={},
        )
        task = Request(m, app=self.app)
        with pytest.raises(KeyError):
            worker._process_task(task)
        worker.pool.stop()

    def test_start_catches_base_exceptions(self):
        worker1 = self.create_worker()
        worker1.blueprint.state = RUN
        stc = MockStep()
        stc.start.side_effect = WorkerTerminate()
        worker1.steps = [stc]
        worker1.start()
        stc.start.assert_called_with(worker1)
        assert stc.terminate.call_count

        worker2 = self.create_worker()
        worker2.blueprint.state = RUN
        sec = MockStep()
        sec.start.side_effect = WorkerShutdown()
        sec.terminate = None
        worker2.steps = [sec]
        worker2.start()
        assert sec.stop.call_count

    def test_statedb(self):
        # ``state`` is the module-level ``celery.worker.state`` import.
        Persistent = state.Persistent

        state.Persistent = Mock()
        try:
            worker = self.create_worker(statedb='statefilename')
            assert worker._persistence
        finally:
            state.Persistent = Persistent

    def test_process_task_sem(self):
        worker = self.worker
        worker._quick_acquire = Mock()

        req = Mock()
        worker._process_task_sem(req)
        worker._quick_acquire.assert_called_with(worker._process_task, req)

    def test_signal_consumer_close(self):
        worker = self.worker
        worker.consumer = Mock()

        worker.signal_consumer_close()
        worker.consumer.close.assert_called_with()

        worker.consumer.close.side_effect = AttributeError()
        worker.signal_consumer_close()

    def test_rusage__no_resource(self):
        # Use the module-level alias rather than a local ``worker`` name,
        # which would be easy to confuse with ``self.worker``.
        prev, worker_module.resource = worker_module.resource, None
        try:
            self.worker.pool = Mock(name='pool')
            with pytest.raises(NotImplementedError):
                self.worker.rusage()
            self.worker.stats()
        finally:
            worker_module.resource = prev

    def test_repr(self):
        assert repr(self.worker)

    def test_str(self):
        assert str(self.worker) == self.worker.hostname

    def test_start__stop(self):
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
        worker.blueprint.state = RUN
        worker.blueprint.started = 4
        for w in worker.steps:
            w.start = Mock()
            w.close = Mock()
            w.stop = Mock()

        worker.start()
        for w in worker.steps:
            w.start.assert_called()
        worker.consumer = Mock()
        worker.stop(exitcode=3)
        for stopstep in worker.steps:
            stopstep.close.assert_called()
            stopstep.stop.assert_called()

        # Doesn't close pool if no pool.
        worker.start()
        worker.pool = None
        worker.stop()

        # test that stop of None is not attempted
        worker.steps[-1] = None
        worker.start()
        worker.stop()

    def test_start__KeyboardInterrupt(self):
        worker = self.worker
        worker.blueprint = Mock(name='blueprint')
        worker.blueprint.start.side_effect = KeyboardInterrupt()
        worker.stop = Mock(name='stop')
        worker.start()
        worker.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_register_with_event_loop(self):
        worker = self.worker
        hub = Mock(name='hub')
        worker.blueprint = Mock(name='blueprint')
        worker.register_with_event_loop(hub)
        worker.blueprint.send_all.assert_called_with(
            worker, 'register_with_event_loop', args=(hub,),
            description='hub.register',
        )

    def test_step_raises(self):
        worker = self.worker
        step = Mock()
        worker.steps = [step]
        step.start.side_effect = TypeError()
        worker.stop = Mock()
        worker.start()
        worker.stop.assert_called_with(exitcode=EX_FAILURE)

    def test_state(self):
        assert self.worker.state

    def test_start__terminate(self):
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.blueprint.started = 5
        worker.blueprint.state = RUN
        worker.steps = [MockStep() for _ in range(5)]
        worker.start()
        for w in worker.steps[:3]:
            w.start.assert_called()
        assert worker.blueprint.started == len(worker.steps)
        assert worker.blueprint.state == RUN
        worker.terminate()
        for step in worker.steps:
            step.terminate.assert_called()
        worker.blueprint.state = TERMINATE
        worker.terminate()

    def test_Hub_create(self):
        w = Mock()
        x = components.Hub(w)
        x.create(w)
        assert w.timer.max_interval

    def test_Pool_create_threaded(self):
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.pool_cls = Mock()
        w.use_eventloop = False
        pool = components.Pool(w)
        pool.create(w)

    def test_Pool_pool_no_sem(self):
        w = Mock()
        w.pool_cls.uses_semaphore = False
        components.Pool(w).create(w)
        assert w.process_task is w._process_task

    def test_Pool_create(self):
        from kombu.asynchronous.semaphore import LaxBoundedSemaphore
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.hub = Mock()

        PoolImp = Mock()
        poolimp = PoolImp.return_value = Mock()
        poolimp._pool = [Mock(), Mock()]
        poolimp._cache = {}
        poolimp._fileno_to_inq = {}
        poolimp._fileno_to_outq = {}

        from celery.concurrency.prefork import TaskPool as _TaskPool

        class MockTaskPool(_TaskPool):
            Pool = PoolImp

            @property
            def timers(self):
                return {Mock(): 30}

        w.pool_cls = MockTaskPool
        w.use_eventloop = True
        w.consumer.restart_count = -1
        pool = components.Pool(w)
        pool.create(w)
        pool.register_with_event_loop(w, w.hub)
        if sys.platform != 'win32':
            assert isinstance(w.semaphore, LaxBoundedSemaphore)
            P = w.pool
            P.start()
|