| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 | from __future__ import absolute_import, unicode_literalsimport pytestfrom contextlib import contextmanagerfrom amqp import ChannelErrorfrom case import Mock, mock, patchfrom kombu import Connection, Producer, Queue, Exchangefrom kombu.transport.virtual import QoSfrom celery.contrib.migrate import (    StopFiltering,    State,    migrate_task,    migrate_tasks,    filter_callback,    _maybe_queue,    filter_status,    move_by_taskmap,    move_by_idmap,    move_task_by_id,    start_filter,    task_id_in,    task_id_eq,    expand_dest,    move,)from celery.utils.encoding import bytes_t, ensure_bytes# hack to ignore error at shutdownQoS.restore_at_shutdown = Falsedef Message(body, exchange='exchange', routing_key='rkey',            compression=None, content_type='application/json',            content_encoding='utf-8'):    return Mock(        attrs={            'body': body,            'delivery_info': {                'exchange': exchange,                'routing_key': routing_key,            },            'headers': {                'compression': compression,            },            'content_type': content_type,            'content_encoding': content_encoding,            'properties': {}        },    )class test_State:    def test_strtotal(self):        x = State()        assert x.strtotal == '?'        x.total_apx = 100        assert x.strtotal == '100'    def test_repr(self):        x = State()        assert repr(x)        x.filtered = 'foo'        assert repr(x)class test_move:    @contextmanager    def move_context(self, **kwargs):        with patch('celery.contrib.migrate.start_filter') as start:            with patch('celery.contrib.migrate.republish') as republish:                pred = Mock(name='predicate')                move(pred, app=self.app,                     connection=self.app.connection(), **kwargs)                start.assert_called()                callback = start.call_args[0][2]                yield callback, pred, republish    def msgpair(self, **kwargs):        body = dict({'task': 'add', 'id': 'id'}, **kwargs)        return body, Message(body)    def test_move(self):        with self.move_context() as (callback, pred, republish):            pred.return_value = None            body, message = self.msgpair()            callback(body, message)            message.ack.assert_not_called()            republish.assert_not_called()            pred.return_value = 'foo'            callback(body, message)            message.ack.assert_called_with()            republish.assert_called()    def test_move_transform(self):        trans = Mock(name='transform')        trans.return_value = Queue('bar')        with self.move_context(transform=trans) as (callback, pred, republish):            pred.return_value = 'foo'            body, message = self.msgpair()            with patch('celery.contrib.migrate.maybe_declare') as maybed:                callback(body, message)                trans.assert_called_with('foo')                maybed.assert_called()                republish.assert_called()    def test_limit(self):        with self.move_context(limit=1) as (callback, pred, republish):            pred.return_value = 'foo'            body, message = self.msgpair()            with pytest.raises(StopFiltering):                callback(body, message)            republish.assert_called()    def test_callback(self):        cb = Mock()        with self.move_context(callback=cb) as (callback, pred, republish):            pred.return_value = 'foo'            body, message = self.msgpair()            callback(body, message)            republish.assert_called()            cb.assert_called()class test_start_filter:    def test_start(self):        with patch('celery.contrib.migrate.eventloop') as evloop:            app = Mock()            filt = Mock(name='filter')            conn = Connection('memory://')            evloop.side_effect = StopFiltering()            app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}            consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')            consumer.queues = list(app.amqp.queues.values())            consumer.channel = conn.default_channel            consumer.__enter__ = Mock(name='consumer.__enter__')            consumer.__exit__ = Mock(name='consumer.__exit__')            consumer.callbacks = []            def register_callback(x):                consumer.callbacks.append(x)            consumer.register_callback = register_callback            start_filter(app, conn, filt,                         queues='foo,bar', ack_messages=True)            body = {'task': 'add', 'id': 'id'}            for callback in consumer.callbacks:                callback(body, Message(body))            consumer.callbacks[:] = []            cb = Mock(name='callback=')            start_filter(app, conn, filt, tasks='add,mul', callback=cb)            for callback in consumer.callbacks:                callback(body, Message(body))            cb.assert_called()            on_declare_queue = Mock()            start_filter(app, conn, filt, tasks='add,mul', queues='foo',                         on_declare_queue=on_declare_queue)            on_declare_queue.assert_called()            start_filter(app, conn, filt, queues=['foo', 'bar'])            consumer.callbacks[:] = []            state = State()            start_filter(app, conn, filt,                         tasks='add,mul', callback=cb, state=state, limit=1)            stop_filtering_raised = False            for callback in consumer.callbacks:                try:                    callback(body, Message(body))                except StopFiltering:                    stop_filtering_raised = True            assert state.count            assert stop_filtering_raisedclass test_filter_callback:    def test_filter(self):        callback = Mock()        filt = filter_callback(callback, ['add', 'mul'])        t1 = {'task': 'add'}        t2 = {'task': 'div'}        message = Mock()        filt(t2, message)        callback.assert_not_called()        filt(t1, message)        callback.assert_called_with(t1, message)def test_task_id_in():    assert task_id_in(['A'], {'id': 'A'}, Mock())    assert not task_id_in(['A'], {'id': 'B'}, Mock())def test_task_id_eq():    assert task_id_eq('A', {'id': 'A'}, Mock())    assert not task_id_eq('A', {'id': 'B'}, Mock())def test_expand_dest():    assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')    assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')def test_maybe_queue():    app = Mock()    app.amqp.queues = {'foo': 313}    assert _maybe_queue(app, 'foo') == 313    assert _maybe_queue(app, Queue('foo')) == Queue('foo')def test_filter_status():    with mock.stdouts() as (stdout, stderr):        filter_status(State(), {'id': '1', 'task': 'add'}, Mock())        assert stdout.getvalue()def test_move_by_taskmap():    with patch('celery.contrib.migrate.move') as move:        move_by_taskmap({'add': Queue('foo')})        move.assert_called()        cb = move.call_args[0][0]        assert cb({'task': 'add'}, Mock())def test_move_by_idmap():    with patch('celery.contrib.migrate.move') as move:        move_by_idmap({'123f': Queue('foo')})        move.assert_called()        cb = move.call_args[0][0]        assert cb({'id': '123f'}, Mock())def test_move_task_by_id():    with patch('celery.contrib.migrate.move') as move:        move_task_by_id('123f', Queue('foo'))        move.assert_called()        cb = move.call_args[0][0]        assert cb({'id': '123f'}, Mock()) == Queue('foo')class test_migrate_task:    def test_removes_compression_header(self):        x = Message('foo', compression='zlib')        producer = Mock()        migrate_task(producer, x.body, x)        producer.publish.assert_called()        args, kwargs = producer.publish.call_args        assert isinstance(args[0], bytes_t)        assert 'compression' not in kwargs['headers']        assert kwargs['compression'] == 'zlib'        assert kwargs['content_type'] == 'application/json'        assert kwargs['content_encoding'] == 'utf-8'        assert kwargs['exchange'] == 'exchange'        assert kwargs['routing_key'] == 'rkey'class test_migrate_tasks:    def test_migrate(self, app, name='testcelery'):        x = Connection('memory://foo')        y = Connection('memory://foo')        # use separate state        x.default_channel.queues = {}        y.default_channel.queues = {}        ex = Exchange(name, 'direct')        q = Queue(name, exchange=ex, routing_key=name)        q(x.default_channel).declare()        Producer(x).publish('foo', exchange=name, routing_key=name)        Producer(x).publish('bar', exchange=name, routing_key=name)        Producer(x).publish('baz', exchange=name, routing_key=name)        assert x.default_channel.queues        assert not y.default_channel.queues        migrate_tasks(x, y, accept=['text/plain'], app=app)        yq = q(y.default_channel)        assert yq.get().body == ensure_bytes('foo')        assert yq.get().body == ensure_bytes('bar')        assert yq.get().body == ensure_bytes('baz')        Producer(x).publish('foo', exchange=name, routing_key=name)        callback = Mock()        migrate_tasks(x, y,                      callback=callback, accept=['text/plain'], app=app)        callback.assert_called()        migrate = Mock()        Producer(x).publish('baz', exchange=name, routing_key=name)        migrate_tasks(x, y, callback=callback,                      migrate=migrate, accept=['text/plain'], app=app)        migrate.assert_called()        with patch('kombu.transport.virtual.Channel.queue_declare') as qd:            def effect(*args, **kwargs):                if kwargs.get('passive'):                    raise ChannelError('some channel error')                return 0, 3, 0            qd.side_effect = effect            migrate_tasks(x, y, app=app)        x = Connection('memory://')        x.default_channel.queues = {}        y.default_channel.queues = {}        callback = Mock()        migrate_tasks(x, y,                      callback=callback, accept=['text/plain'], app=app)        callback.assert_not_called()
 |