| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080 | 
							- from __future__ import absolute_import, print_function, unicode_literals
 
- import os
 
- import socket
 
- import sys
 
- from collections import deque
 
- from datetime import datetime, timedelta
 
- from functools import partial
 
- from threading import Event
 
- import pytest
 
- from amqp import ChannelError
 
- from case import Mock, patch, skip
 
- from kombu import Connection
 
- from kombu.common import QoS, ignore_errors
 
- from kombu.transport.base import Message
 
- from kombu.transport.memory import Transport
 
- from kombu.utils.uuid import uuid
 
- from celery.bootsteps import CLOSE, RUN, TERMINATE, StartStopStep
 
- from celery.concurrency.base import BasePool
 
- from celery.exceptions import (ImproperlyConfigured, InvalidTaskError,
 
-                                TaskRevokedError, WorkerShutdown,
 
-                                WorkerTerminate)
 
- from celery.five import Empty
 
- from celery.five import Queue as FastQueue
 
- from celery.five import range
 
- from celery.platforms import EX_FAILURE
 
- from celery.utils.nodenames import worker_direct
 
- from celery.utils.serialization import pickle
 
- from celery.utils.timer2 import Timer
 
- from celery.worker import components, consumer, state
 
- from celery.worker import worker as worker_module
 
- from celery.worker.consumer import Consumer
 
- from celery.worker.pidbox import gPidbox
 
- from celery.worker.request import Request
 
def MockStep(step=None):
    """Return a step-like object labelled for blueprint tests.

    If *step* is given it is decorated in place: its ``blueprint`` attribute
    is replaced by a fresh Mock.  Otherwise a brand new ``Mock`` is created
    (whose ``blueprint`` attribute Mock supplies automatically).  In both
    cases ``step.blueprint.name`` becomes ``'MockNS'`` and ``step.name``
    embeds ``id(step)`` so every step gets a distinct name.
    """
    if step is not None:
        step.blueprint = Mock(name='step.blueprint')
    else:
        step = Mock(name='step')
    step.blueprint.name = 'MockNS'
    step.name = 'MockStep(%s)' % (id(step),)
    return step
 
def mock_event_dispatcher():
    """Return a Mock standing in for an event dispatcher.

    The mock reports membership of the ``'worker'`` group and carries an
    empty ``deque`` as its ``_outbound_buffer``.
    """
    dispatcher = Mock(name='event_dispatcher')
    # Attributes the consumer code under test reads from the dispatcher.
    dispatcher.groups = ['worker']
    dispatcher._outbound_buffer = deque()
    return dispatcher
 
def find_step(obj, typ):
    """Return the step registered under ``typ.name`` in ``obj.blueprint.steps``.

    Lookup errors from the ``steps`` mapping propagate unchanged (``KeyError``
    for a plain dict).
    """
    registered = obj.blueprint.steps
    wanted = typ.name
    return registered[wanted]
 
def create_message(channel, **data):
    """Build a pickle-serialized kombu ``Message`` whose body is *data*.

    A fresh ``id`` (from ``uuid()``) is added when the caller did not pass
    one.  The message advertises consumer tag ``'mock'`` and accepts only
    the ``application/x-python-serialize`` content type.  Test-only helper:
    the body is pickled.
    """
    if 'id' not in data:
        data['id'] = uuid()
    payload = pickle.dumps(dict(data))
    message = Message(
        body=payload,
        channel=channel,
        content_type='application/x-python-serialize',
        content_encoding='binary',
        delivery_info={'consumer_tag': 'mock'},
    )
    message.accept = ['application/x-python-serialize']
    return message
 
class ConsumerCase:
    """Mixin with message helpers shared by the consumer/worker test classes."""

    def create_task_message(self, channel, *args, **kwargs):
        """Build a task message via ``self.TaskMessage`` and bind it to *channel*.

        ``self.TaskMessage`` is not defined in this class; presumably it is
        supplied by the test harness (TODO confirm).  The returned message's
        ``delivery_info`` is set to consumer tag ``'mock'``.
        """
        message = self.TaskMessage(*args, **kwargs)
        message.channel, message.delivery_info = (
            channel, {'consumer_tag': 'mock'})
        return message
 
class test_Consumer(ConsumerCase):
    """Tests for :class:`celery.worker.consumer.Consumer`.

    ``self.app`` and ``self.TaskMessage`` are not defined in this class;
    presumably they are injected into each test instance by the suite's
    fixtures (TODO confirm).
    """

    def setup(self):
        # LoopConsumer hands ``self.buffer.put`` to the Consumer as its
        # buffer, so accepted requests can be read back from this queue.
        self.buffer = FastQueue()
        self.timer = Timer()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        self.timer.stop()

    def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None,
                     without_mingle=True, without_gossip=True,
                     without_heartbeat=True, **kwargs):
        """Return a Consumer wired with mocks but keeping its real ``loop``.

        Defaults: *buffer* -> ``self.buffer.put``, *timer* -> ``self.timer``,
        *app* -> ``self.app``; mingle, gossip and heartbeat are disabled.
        Remaining keyword arguments are forwarded to :class:`Consumer`.
        """
        if controller is None:
            controller = Mock(name='.controller')
        buffer = buffer if buffer is not None else self.buffer.put
        timer = timer if timer is not None else self.timer
        app = app if app is not None else self.app
        c = Consumer(
            buffer,
            timer=timer,
            app=app,
            controller=controller,
            without_mingle=without_mingle,
            without_gossip=without_gossip,
            without_heartbeat=without_heartbeat,
            **kwargs
        )
        c.task_consumer = Mock(name='.task_consumer')
        c.qos = QoS(c.task_consumer.qos, 10)
        c.connection = Mock(name='.connection')
        # Overrides the *controller* given to Consumer() with a real
        # WorkController that points back at this consumer.
        c.controller = c.app.WorkController()
        c.heart = Mock(name='.heart')
        c.controller.consumer = c
        c.pool = c.controller.pool = Mock(name='.controller.pool')
        c.node = Mock(name='.node')
        c.event_dispatcher = mock_event_dispatcher()
        return c

    def NoopConsumer(self, *args, **kwargs):
        """Like :meth:`LoopConsumer`, but with ``loop`` replaced by a Mock."""
        c = self.LoopConsumer(*args, **kwargs)
        c.loop = Mock(name='.loop')
        return c

    def test_info(self):
        c = self.NoopConsumer()
        c.connection.info.return_value = {'foo': 'bar'}
        c.controller.pool.info.return_value = [Mock(), Mock()]
        info = c.controller.stats()
        assert info['prefetch_count'] == 10
        assert info['broker']

    def test_start_when_closed(self):
        c = self.NoopConsumer()
        c.blueprint.state = CLOSE
        c.start()

    def test_connection(self):
        c = self.NoopConsumer()

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)

        c.blueprint.state = RUN
        c.event_dispatcher = None
        c.blueprint.restart(c)
        assert c.connection

        c.blueprint.state = RUN
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

        c.blueprint.start(c)
        assert isinstance(c.connection, Connection)
        c.blueprint.restart(c)

        c.stop()
        c.shutdown()
        assert c.connection is None
        assert c.task_consumer is None

    def test_close_connection(self):
        c = self.NoopConsumer()
        c.blueprint.state = RUN
        step = find_step(c, consumer.Connection)
        connection = c.connection
        step.shutdown(c)
        connection.close.assert_called()
        assert c.connection is None

    def test_close_connection__heart_shutdown(self):
        c = self.NoopConsumer()
        event_dispatcher = c.event_dispatcher
        heart = c.heart
        c.event_dispatcher.enabled = True
        c.blueprint.state = RUN
        Events = find_step(c, consumer.Events)
        Events.shutdown(c)
        Heart = find_step(c, consumer.Heart)
        Heart.shutdown(c)
        event_dispatcher.close.assert_called()
        heart.stop.assert_called_with()

    @patch('celery.worker.consumer.consumer.warn')
    def test_receive_message_unknown(self, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='.channel')
        m = create_message(channel, unknown={'baz': '!!!'})

        callback = self._get_on_message(c)
        callback(m)
        warn.assert_called()

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        to_timestamp.side_effect = OverflowError()
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(2, 2), kwargs={},
            eta=datetime.now().isoformat(),
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        # Even though the ETA cannot be converted, the message must not be
        # left unacknowledged.
        assert m.acknowledged

    @patch('celery.worker.consumer.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        # kwargs is deliberately not a mapping.
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=(1, 2), kwargs='foobarbaz', id=1)
        c.update_strategies()
        strat = c.strategies[self.foo_task.name] = Mock(name='strategy')
        strat.side_effect = InvalidTaskError()

        callback = self._get_on_message(c)
        callback(m)
        error.assert_called()
        assert 'Received invalid task message' in error.call_args[0][0]

    @patch('celery.worker.consumer.consumer.crit')
    def test_on_decode_error(self, crit):
        c = self.LoopConsumer()

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        c.on_decode_error(message, KeyError('foo'))
        assert message.ack.call_count
        assert "Can't decode message body" in crit.call_args[0][0]

    def _get_on_message(self, c):
        """Run ``c.loop`` once and return the callback set on ``task_consumer``.

        ``drain_events`` is made to raise WorkerShutdown, so the loop exits
        on its first drain; the ``on_message`` attribute then found on the
        (freshly mocked) ``c.task_consumer`` is returned.
        """
        if c.qos is None:
            c.qos = Mock()
        c.task_consumer = Mock()
        c.event_dispatcher = mock_event_dispatcher()
        c.connection = Mock(name='.connection')
        c.connection.get_heartbeat_interval.return_value = 0
        c.connection.drain_events.side_effect = WorkerShutdown()

        with pytest.raises(WorkerShutdown):
            c.loop(*c.loop_args())
        assert c.task_consumer.on_message
        return c.task_consumer.on_message

    def test_receieve_message(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)

        in_bucket = self.buffer.get_nowait()
        assert isinstance(in_bucket, Request)
        assert in_bucket.name == self.foo_task.name
        assert in_bucket.execute() == 2 * 4 * 8
        assert self.timer.empty()

    def test_start_channel_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.channel_errors = (KeyError,)
        try:
            with pytest.raises(KeyError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_start_connection_error(self):
        c = self.NoopConsumer(task_events=False, pool=BasePool())
        c.loop.on_nth_call_do_raise(KeyError('foo'), SyntaxError('bar'))
        c.connection_errors = (KeyError,)
        try:
            with pytest.raises(SyntaxError):
                c.start()
        finally:
            if c.timer:
                c.timer.stop()

    def test_loop_ignores_socket_timeout(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        # NOTE(review): NoopConsumer replaces ``loop`` with a Mock, so the
        # drain_events override above is never reached by the call below.
        c = self.NoopConsumer()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        c.loop(*c.loop_args())

    def test_loop_when_socket_error(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        conn = c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.qos = QoS(c.task_consumer.qos, 10)
        with pytest.raises(socket.error):
            c.loop(*c.loop_args())

        # Once the blueprint is closed the loop must return without raising.
        c.blueprint.state = CLOSE
        c.connection = conn
        c.loop(*c.loop_args())

    def test_loop(self):

        class Connection(self.app.connection_for_read().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

            @property
            def supports_heartbeats(self):
                return False

        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.obj = c
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.qos = QoS(c.task_consumer.qos, 10)

        c.loop(*c.loop_args())
        c.loop(*c.loop_args())
        assert c.task_consumer.consume.call_count
        c.task_consumer.qos.assert_called_with(prefetch_count=10)
        assert c.qos.value == 10
        c.qos.decrement_eventually()
        assert c.qos.value == 9
        c.qos.update()
        assert c.qos.value == 9
        c.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        c = self.NoopConsumer()
        c.connection_errors = (AttributeError, KeyError,)
        c.channel_errors = (SyntaxError,)
        ignore_errors(c, Mock(side_effect=AttributeError('foo')))
        ignore_errors(c, Mock(side_effect=KeyError('foo')))
        ignore_errors(c, Mock(side_effect=SyntaxError('foo')))
        with pytest.raises(IndexError):
            ignore_errors(c, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        c = self.NoopConsumer()
        c.qos = QoS(None, 10)
        task = Mock(name='task', id='1234213')
        qos = c.qos.value
        c.apply_eta_task(task)
        assert task in state.reserved_requests
        assert c.qos.value == qos - 1
        assert self.buffer.get_nowait() is task

    def test_receieve_message_eta_isoformat(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        m = self.create_task_message(
            Mock(), self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        c.qos = QoS(c.task_consumer.qos, 1)
        current_pcount = c.qos.value
        c.event_dispatcher.enabled = False
        c.update_strategies()
        callback = self._get_on_message(c)
        callback(m)
        c.timer.stop()
        c.timer.join(1)

        items = [entry[2] for entry in self.timer.queue]
        found = any(item.args[0].name == self.foo_task.name for item in items)
        assert found
        # Scheduling the ETA task is expected to raise the prefetch count.
        assert c.qos.value > current_pcount

    def test_pidbox_callback(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        con.reset.assert_called()

    def test_revoke(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        task_id = uuid()
        t = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={}, id=task_id,
        )
        state.revoked.add(task_id)

        callback = self._get_on_message(c)
        callback(t)
        assert self.buffer.empty()

    def test_receieve_message_not_registered(self):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        c.steps.pop()
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, 'x.X.31x', args=[2, 4, 8], kwargs={},
        )

        callback = self._get_on_message(c)
        assert not callback(m)
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()

    @patch('celery.worker.consumer.consumer.warn')
    @patch('celery.worker.consumer.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        c = self.LoopConsumer()
        c.blueprint.state = RUN
        channel = Mock(name='channel')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
        )
        m.headers = None

        c.update_strategies()
        c.connection_errors = (socket.error,)
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(c)
        assert not callback(m)
        warn.assert_called()
        with pytest.raises(Empty):
            self.buffer.get_nowait()
        assert self.timer.empty()
        m.reject_log_error.assert_called_with(logger, c.connection_errors)

    def test_receive_message_eta(self):
        # Set C_DEBUG_TEST in the environment to trace progress on stderr.
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        c = self.LoopConsumer()
        pp('-CREATE MYKOMBUCONSUMER')
        c.steps.pop()
        channel = Mock(name='channel')
        pp('+ CREATE MESSAGE')
        m = self.create_task_message(
            channel, self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 1')
            p = c.app.conf.broker_connection_retry
            c.app.conf.broker_connection_retry = False
            pp('+ BLUEPRINT START 2')
            c.blueprint.start(c)
            pp('- BLUEPRINT START 2')
            c.app.conf.broker_connection_retry = p
            pp('+ BLUEPRINT RESTART')
            c.blueprint.restart(c)
            pp('- BLUEPRINT RESTART')
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(c)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            c.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                c.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                pass

        in_hold = c.timer.queue[0]
        assert len(in_hold) == 3
        eta, priority, entry = in_hold
        task = entry.args[0]
        assert isinstance(task, Request)
        assert task.name == self.foo_task.name
        assert task.execute() == 2 * 4 * 8
        with pytest.raises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        c = self.NoopConsumer()
        con = find_step(c, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        chan.close.side_effect = socket.error('foo')
        c.connection_errors = (socket.error,)
        con.reset()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        con = find_step(c, consumer.Control)
        assert isinstance(con.box, gPidbox)
        con.start(c)
        c.pool.spawn_n.assert_called_with(con.box.loop, c)

    def test_green_pidbox_node(self):
        c = self.NoopConsumer(pool=Mock(is_green=True))
        controller = find_step(c, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out; the next clears obj.connection and
                # sets the pidbox's _node_shutdown event.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        c.connection_for_read = lambda: Connection(obj=c)
        controller.box.loop(c)

        controller.box.node.listen.assert_called()
        assert controller.box.consumer
        controller.box.consumer.consume.assert_called_with()

        assert c.connection is None
        assert connections[0].closed

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.functional.sleep')
    def test_connect_errback(self, sleep, connect):
        c = self.NoopConsumer()
        # ``Transport`` is a module-level class shared by every test, so the
        # override below must be undone afterwards to avoid leaking state.
        had_own_attr = 'connection_errors' in vars(Transport)
        previous = Transport.connection_errors
        Transport.connection_errors = (ChannelError,)
        try:
            connect.on_nth_call_do(ChannelError('error'), n=1)
            c.connect()
            connect.assert_called_with()
        finally:
            if had_own_attr:
                Transport.connection_errors = previous
            else:
                del Transport.connection_errors

    def test_stop_pidbox_node(self):
        c = self.NoopConsumer()
        cont = find_step(c, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(c)

    def test_start__loop(self):

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        init_callback = Mock(name='init_callback')
        c = self.NoopConsumer(init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.iterations = 0

        def raises_KeyError(*args, **kwargs):
            c.iterations += 1
            if c.qos.prev != c.qos.value:
                c.qos.update()
            if c.iterations >= 2:
                raise KeyError('foo')

        c.loop = raises_KeyError
        with pytest.raises(KeyError):
            c.start()
        assert c.iterations == 2
        assert c.qos.prev == c.qos.value

        init_callback.reset_mock()
        c = self.NoopConsumer(task_events=False, init_callback=init_callback)
        c.qos = _QoS()
        c.connection = Connection(self.app.conf.broker_url)
        c.connection.get_heartbeat_interval = Mock(return_value=None)
        c.loop = Mock(side_effect=socket.error('foo'))
        with pytest.raises(socket.error):
            c.start()
        c.loop.assert_called()

    def test_reset_connection_with_no_node(self):
        c = self.NoopConsumer()
        c.steps.pop()
        c.blueprint.start(c)
 
- class test_WorkController(ConsumerCase):
 
-     def setup(self):
 
-         self.worker = self.create_worker()
 
-         self._logger = worker_module.logger
 
-         self._comp_logger = components.logger
 
-         self.logger = worker_module.logger = Mock()
 
-         self.comp_logger = components.logger = Mock()
 
-         @self.app.task(shared=False)
 
-         def foo_task(x, y, z):
 
-             return x * y * z
 
-         self.foo_task = foo_task
 
-     def teardown(self):
 
-         worker_module.logger = self._logger
 
-         components.logger = self._comp_logger
 
-     def create_worker(self, **kw):
 
-         worker = self.app.WorkController(concurrency=1, loglevel=0, **kw)
 
-         worker.blueprint.shutdown_complete.set()
 
-         return worker
 
-     def test_on_consumer_ready(self):
 
-         self.worker.on_consumer_ready(Mock())
 
-     def test_setup_queues_worker_direct(self):
 
-         self.app.conf.worker_direct = True
 
-         self.app.amqp.__dict__['queues'] = Mock()
 
-         self.worker.setup_queues({})
 
-         self.app.amqp.queues.select_add.assert_called_with(
 
-             worker_direct(self.worker.hostname),
 
-         )
 
-     def test_setup_queues__missing_queue(self):
 
-         self.app.amqp.queues.select = Mock(name='select')
 
-         self.app.amqp.queues.deselect = Mock(name='deselect')
 
-         self.app.amqp.queues.select.side_effect = KeyError()
 
-         self.app.amqp.queues.deselect.side_effect = KeyError()
 
-         with pytest.raises(ImproperlyConfigured):
 
-             self.worker.setup_queues('x,y', exclude='foo,bar')
 
-         self.app.amqp.queues.select = Mock(name='select')
 
-         with pytest.raises(ImproperlyConfigured):
 
-             self.worker.setup_queues('x,y', exclude='foo,bar')
 
-     def test_send_worker_shutdown(self):
 
-         with patch('celery.signals.worker_shutdown') as ws:
 
-             self.worker._send_worker_shutdown()
 
-             ws.send.assert_called_with(sender=self.worker)
 
-     @skip.todo('unstable test')
 
-     def test_process_shutdown_on_worker_shutdown(self):
 
-         from celery.concurrency.prefork import process_destructor
 
-         from celery.concurrency.asynpool import Worker
 
-         with patch('celery.signals.worker_process_shutdown') as ws:
 
-             with patch('os._exit') as _exit:
 
-                 worker = Worker(None, None, on_exit=process_destructor)
 
-                 worker._do_exit(22, 3.1415926)
 
-                 ws.send.assert_called_with(
 
-                     sender=None, pid=22, exitcode=3.1415926,
 
-                 )
 
-                 _exit.assert_called_with(3.1415926)
 
-     def test_process_task_revoked_release_semaphore(self):
 
-         self.worker._quick_release = Mock()
 
-         req = Mock()
 
-         req.execute_using_pool.side_effect = TaskRevokedError
 
-         self.worker._process_task(req)
 
-         self.worker._quick_release.assert_called_with()
 
-         delattr(self.worker, '_quick_release')
 
-         self.worker._process_task(req)
 
-     def test_shutdown_no_blueprint(self):
 
-         self.worker.blueprint = None
 
-         self.worker._shutdown()
 
-     @patch('celery.worker.worker.create_pidlock')
 
-     def test_use_pidfile(self, create_pidlock):
 
-         create_pidlock.return_value = Mock()
 
-         worker = self.create_worker(pidfile='pidfilelockfilepid')
 
-         worker.steps = []
 
-         worker.start()
 
-         create_pidlock.assert_called()
 
-         worker.stop()
 
-         worker.pidlock.release.assert_called()
 
-     def test_attrs(self):
 
-         worker = self.worker
 
-         assert worker.timer is not None
 
-         assert isinstance(worker.timer, Timer)
 
-         assert worker.pool is not None
 
-         assert worker.consumer is not None
 
-         assert worker.steps
 
-     def test_with_embedded_beat(self):
 
-         worker = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
 
-         assert worker.beat
 
-         assert worker.beat in [w.obj for w in worker.steps]
 
-     def test_with_autoscaler(self):
 
-         worker = self.create_worker(
 
-             autoscale=[10, 3], send_events=False,
 
-             timer_cls='celery.utils.timer2.Timer',
 
-         )
 
-         assert worker.autoscaler
 
-     def test_dont_stop_or_terminate(self):
 
-         worker = self.app.WorkController(concurrency=1, loglevel=0)
 
-         worker.stop()
 
-         assert worker.blueprint.state != CLOSE
 
-         worker.terminate()
 
-         assert worker.blueprint.state != CLOSE
 
-         sigsafe, worker.pool.signal_safe = worker.pool.signal_safe, False
 
-         try:
 
-             worker.blueprint.state = RUN
 
-             worker.stop(in_sighandler=True)
 
-             assert worker.blueprint.state != CLOSE
 
-             worker.terminate(in_sighandler=True)
 
-             assert worker.blueprint.state != CLOSE
 
-         finally:
 
-             worker.pool.signal_safe = sigsafe
 
-     def test_on_timer_error(self):
 
-         worker = self.app.WorkController(concurrency=1, loglevel=0)
 
-         try:
 
-             raise KeyError('foo')
 
-         except KeyError as exc:
 
-             components.Timer(worker).on_timer_error(exc)
 
-             msg, args = self.comp_logger.error.call_args[0]
 
-             assert 'KeyError' in msg % args
 
-     def test_on_timer_tick(self):
 
-         worker = self.app.WorkController(concurrency=1, loglevel=10)
 
-         components.Timer(worker).on_timer_tick(30.0)
 
-         xargs = self.comp_logger.debug.call_args[0]
 
-         fmt, arg = xargs[0], xargs[1]
 
-         assert arg == 30.0
 
-         assert 'Next ETA %s secs' in fmt
 
-     def test_process_task(self):
 
-         worker = self.worker
 
-         worker.pool = Mock()
 
-         channel = Mock()
 
-         m = self.create_task_message(
 
-             channel, self.foo_task.name,
 
-             args=[4, 8, 10], kwargs={},
 
-         )
 
-         task = Request(m, app=self.app)
 
-         worker._process_task(task)
 
-         assert worker.pool.apply_async.call_count == 1
 
-         worker.pool.stop()
 
-     def test_process_task_raise_base(self):
 
-         worker = self.worker
 
-         worker.pool = Mock()
 
-         worker.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
 
-         channel = Mock()
 
-         m = self.create_task_message(
 
-             channel, self.foo_task.name,
 
-             args=[4, 8, 10], kwargs={},
 
-         )
 
-         task = Request(m, app=self.app)
 
-         worker.steps = []
 
-         worker.blueprint.state = RUN
 
-         with pytest.raises(KeyboardInterrupt):
 
-             worker._process_task(task)
 
-     def test_process_task_raise_WorkerTerminate(self):
 
-         worker = self.worker
 
-         worker.pool = Mock()
 
-         worker.pool.apply_async.side_effect = WorkerTerminate()
 
-         channel = Mock()
 
-         m = self.create_task_message(
 
-             channel, self.foo_task.name,
 
-             args=[4, 8, 10], kwargs={},
 
-         )
 
-         task = Request(m, app=self.app)
 
-         worker.steps = []
 
-         worker.blueprint.state = RUN
 
-         with pytest.raises(SystemExit):
 
-             worker._process_task(task)
 
-     def test_process_task_raise_regular(self):
 
-         worker = self.worker
 
-         worker.pool = Mock()
 
-         worker.pool.apply_async.side_effect = KeyError('some exception')
 
-         channel = Mock()
 
-         m = self.create_task_message(
 
-             channel, self.foo_task.name,
 
-             args=[4, 8, 10], kwargs={},
 
-         )
 
-         task = Request(m, app=self.app)
 
-         with pytest.raises(KeyError):
 
-             worker._process_task(task)
 
-         worker.pool.stop()
 
-     def test_start_catches_base_exceptions(self):
 
-         worker1 = self.create_worker()
 
-         worker1.blueprint.state = RUN
 
-         stc = MockStep()
 
-         stc.start.side_effect = WorkerTerminate()
 
-         worker1.steps = [stc]
 
-         worker1.start()
 
-         stc.start.assert_called_with(worker1)
 
-         assert stc.terminate.call_count
 
-         worker2 = self.create_worker()
 
-         worker2.blueprint.state = RUN
 
-         sec = MockStep()
 
-         sec.start.side_effect = WorkerShutdown()
 
-         sec.terminate = None
 
-         worker2.steps = [sec]
 
-         worker2.start()
 
-         assert sec.stop.call_count
 
-     def test_statedb(self):
 
-         from celery.worker import state
 
-         Persistent = state.Persistent
 
-         state.Persistent = Mock()
 
-         try:
 
-             worker = self.create_worker(statedb='statefilename')
 
-             assert worker._persistence
 
-         finally:
 
-             state.Persistent = Persistent
 
-     def test_process_task_sem(self):
 
-         worker = self.worker
 
-         worker._quick_acquire = Mock()
 
-         req = Mock()
 
-         worker._process_task_sem(req)
 
-         worker._quick_acquire.assert_called_with(worker._process_task, req)
 
-     def test_signal_consumer_close(self):
 
-         worker = self.worker
 
-         worker.consumer = Mock()
 
-         worker.signal_consumer_close()
 
-         worker.consumer.close.assert_called_with()
 
-         worker.consumer.close.side_effect = AttributeError()
 
-         worker.signal_consumer_close()
 
-     def test_rusage__no_resource(self):
 
-         from celery.worker import worker
 
-         prev, worker.resource = worker.resource, None
 
-         try:
 
-             self.worker.pool = Mock(name='pool')
 
-             with pytest.raises(NotImplementedError):
 
-                 self.worker.rusage()
 
-             self.worker.stats()
 
-         finally:
 
-             worker.resource = prev
 
-     def test_repr(self):
 
-         assert repr(self.worker)
 
-     def test_str(self):
 
-         assert str(self.worker) == self.worker.hostname
 
-     def test_start__stop(self):
 
-         worker = self.worker
 
-         worker.blueprint.shutdown_complete.set()
 
-         worker.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
 
-         worker.blueprint.state = RUN
 
-         worker.blueprint.started = 4
 
-         for w in worker.steps:
 
-             w.start = Mock()
 
-             w.close = Mock()
 
-             w.stop = Mock()
 
-         worker.start()
 
-         for w in worker.steps:
 
-             w.start.assert_called()
 
-         worker.consumer = Mock()
 
-         worker.stop(exitcode=3)
 
-         for stopstep in worker.steps:
 
-             stopstep.close.assert_called()
 
-             stopstep.stop.assert_called()
 
-         # Doesn't close pool if no pool.
 
-         worker.start()
 
-         worker.pool = None
 
-         worker.stop()
 
-         # test that stop of None is not attempted
 
-         worker.steps[-1] = None
 
-         worker.start()
 
-         worker.stop()
 
-     def test_start__KeyboardInterrupt(self):
 
-         worker = self.worker
 
-         worker.blueprint = Mock(name='blueprint')
 
-         worker.blueprint.start.side_effect = KeyboardInterrupt()
 
-         worker.stop = Mock(name='stop')
 
-         worker.start()
 
-         worker.stop.assert_called_with(exitcode=EX_FAILURE)
 
-     def test_register_with_event_loop(self):
 
-         worker = self.worker
 
-         hub = Mock(name='hub')
 
-         worker.blueprint = Mock(name='blueprint')
 
-         worker.register_with_event_loop(hub)
 
-         worker.blueprint.send_all.assert_called_with(
 
-             worker, 'register_with_event_loop', args=(hub,),
 
-             description='hub.register',
 
-         )
 
-     def test_step_raises(self):
 
-         worker = self.worker
 
-         step = Mock()
 
-         worker.steps = [step]
 
-         step.start.side_effect = TypeError()
 
-         worker.stop = Mock()
 
-         worker.start()
 
-         worker.stop.assert_called_with(exitcode=EX_FAILURE)
 
-     def test_state(self):
 
-         assert self.worker.state
 
-     def test_start__terminate(self):
 
-         worker = self.worker
 
-         worker.blueprint.shutdown_complete.set()
 
-         worker.blueprint.started = 5
 
-         worker.blueprint.state = RUN
 
-         worker.steps = [MockStep() for _ in range(5)]
 
-         worker.start()
 
-         for w in worker.steps[:3]:
 
-             w.start.assert_called()
 
-         assert worker.blueprint.started == len(worker.steps)
 
-         assert worker.blueprint.state == RUN
 
-         worker.terminate()
 
-         for step in worker.steps:
 
-             step.terminate.assert_called()
 
-         worker.blueprint.state = TERMINATE
 
-         worker.terminate()
 
-     def test_Hub_create(self):
 
-         w = Mock()
 
-         x = components.Hub(w)
 
-         x.create(w)
 
-         assert w.timer.max_interval
 
-     def test_Pool_create_threaded(self):
 
-         w = Mock()
 
-         w._conninfo.connection_errors = w._conninfo.channel_errors = ()
 
-         w.pool_cls = Mock()
 
-         w.use_eventloop = False
 
-         pool = components.Pool(w)
 
-         pool.create(w)
 
-     def test_Pool_pool_no_sem(self):
 
-         w = Mock()
 
-         w.pool_cls.uses_semaphore = False
 
-         components.Pool(w).create(w)
 
-         assert w.process_task is w._process_task
 
-     def test_Pool_create(self):
 
-         from kombu.asynchronous.semaphore import LaxBoundedSemaphore
 
-         w = Mock()
 
-         w._conninfo.connection_errors = w._conninfo.channel_errors = ()
 
-         w.hub = Mock()
 
-         PoolImp = Mock()
 
-         poolimp = PoolImp.return_value = Mock()
 
-         poolimp._pool = [Mock(), Mock()]
 
-         poolimp._cache = {}
 
-         poolimp._fileno_to_inq = {}
 
-         poolimp._fileno_to_outq = {}
 
-         from celery.concurrency.prefork import TaskPool as _TaskPool
 
-         class MockTaskPool(_TaskPool):
 
-             Pool = PoolImp
 
-             @property
 
-             def timers(self):
 
-                 return {Mock(): 30}
 
-         w.pool_cls = MockTaskPool
 
-         w.use_eventloop = True
 
-         w.consumer.restart_count = -1
 
-         pool = components.Pool(w)
 
-         pool.create(w)
 
-         pool.register_with_event_loop(w, w.hub)
 
-         if sys.platform != 'win32':
 
-             assert isinstance(w.semaphore, LaxBoundedSemaphore)
 
-             P = w.pool
 
-             P.start()
 
 
  |