|
- from __future__ import absolute_import, unicode_literals
- from contextlib import contextmanager
- from mock import patch
- from amqp import ChannelError
- from kombu import Connection, Producer, Queue, Exchange
- from kombu.transport.virtual import QoS
- from celery.contrib.migrate import (
- StopFiltering,
- State,
- migrate_task,
- migrate_tasks,
- filter_callback,
- _maybe_queue,
- filter_status,
- move_by_taskmap,
- move_by_idmap,
- move_task_by_id,
- start_filter,
- task_id_in,
- task_id_eq,
- expand_dest,
- move,
- )
- from celery.utils.encoding import bytes_t, ensure_bytes
- from celery.tests.case import AppCase, Mock, override_stdouts
# Hack to ignore an error at shutdown.
# NOTE: this assigns a class attribute on kombu's virtual-transport QoS at
# import time, so it affects every QoS instance in the process, not only the
# tests in this module. Presumably it stops QoS from restoring unacked
# messages when channels are torn down -- confirm against the installed
# kombu version.
QoS.restore_at_shutdown = False
def Message(body, exchange='exchange', routing_key='rkey',
            compression=None, content_type='application/json',
            content_encoding='utf-8'):
    """Build a fake kombu-style message for the migrate tests.

    Every value is handed to ``Mock`` through its ``attrs`` keyword:
    ``body``, ``delivery_info`` (exchange and routing key), ``headers``
    (holding only ``compression``), ``content_type``, ``content_encoding``
    and an empty ``properties`` dict.
    """
    delivery_info = {
        'exchange': exchange,
        'routing_key': routing_key,
    }
    headers = {
        'compression': compression,
    }
    message_attrs = {
        'body': body,
        'delivery_info': delivery_info,
        'headers': headers,
        'content_type': content_type,
        'content_encoding': content_encoding,
        'properties': {},
    }
    return Mock(attrs=message_attrs)
class test_State(AppCase):
    """Tests for :class:`celery.contrib.migrate.State`."""

    def test_strtotal(self):
        state = State()
        # Before an approximate total is known, a placeholder is shown.
        self.assertEqual(state.strtotal, '?')
        state.total_apx = 100
        # Once set, the approximate total is rendered as a string.
        self.assertEqual(state.strtotal, '100')

    def test_repr(self):
        state = State()
        # repr() must be non-empty both before and after ``filtered`` is set.
        self.assertTrue(repr(state))
        state.filtered = 'foo'
        self.assertTrue(repr(state))
class test_move(AppCase):
    """Tests for :func:`celery.contrib.migrate.move`."""

    @contextmanager
    def move_context(self, **kwargs):
        """Run ``move`` with ``start_filter`` and ``republish`` patched out.

        Yields ``(callback, predicate, republish)``. ``callback`` is the third
        positional argument that ``move`` handed to ``start_filter``.
        """
        with patch('celery.contrib.migrate.start_filter') as start:
            with patch('celery.contrib.migrate.republish') as republish:
                predicate = Mock(name='predicate')
                move(predicate, app=self.app,
                     connection=self.app.connection(), **kwargs)
                self.assertTrue(start.called)
                start_args = start.call_args[0]
                yield start_args[2], predicate, republish

    def msgpair(self, **kwargs):
        """Return a ``(body, message)`` pair for a fake ``add`` task."""
        body = {'task': 'add', 'id': 'id'}
        body.update(kwargs)
        return body, Message(body)

    def test_move(self):
        with self.move_context() as ctx:
            callback, predicate, republish = ctx
            body, message = self.msgpair()

            # A falsy predicate result: no ack and no republish.
            predicate.return_value = None
            callback(body, message)
            self.assertFalse(message.ack.called)
            self.assertFalse(republish.called)

            # A truthy predicate result: the message is acked and republished.
            predicate.return_value = 'foo'
            callback(body, message)
            message.ack.assert_called_with()
            self.assertTrue(republish.called)

    def test_move_transform(self):
        transform = Mock(name='transform')
        transform.return_value = Queue('bar')
        with self.move_context(transform=transform) as ctx:
            callback, predicate, republish = ctx
            predicate.return_value = 'foo'
            body, message = self.msgpair()
            with patch('celery.contrib.migrate.maybe_declare') as declare:
                callback(body, message)
                # The predicate's result is fed to ``transform``.
                transform.assert_called_with('foo')
                self.assertTrue(declare.called)
                self.assertTrue(republish.called)

    def test_limit(self):
        # With limit=1, moving the first message raises StopFiltering.
        with self.move_context(limit=1) as ctx:
            callback, predicate, republish = ctx
            predicate.return_value = 'foo'
            body, message = self.msgpair()
            with self.assertRaises(StopFiltering):
                callback(body, message)
            self.assertTrue(republish.called)

    def test_callback(self):
        on_moved = Mock()
        with self.move_context(callback=on_moved) as ctx:
            callback, predicate, republish = ctx
            predicate.return_value = 'foo'
            body, message = self.msgpair()
            callback(body, message)
            # Both the republish and the user-supplied callback fire.
            self.assertTrue(republish.called)
            self.assertTrue(on_moved.called)
class test_start_filter(AppCase):
    """Tests for :func:`celery.contrib.migrate.start_filter`."""

    def test_start(self):
        """Drive ``start_filter`` against a fully mocked task consumer.

        ``app.amqp.TaskConsumer`` returns a ``Mock`` whose
        ``register_callback`` records callbacks in ``consumer.callbacks``.
        The test then calls those callbacks by hand with fake messages
        instead of consuming from a broker.
        """
        with patch('celery.contrib.migrate.eventloop') as evloop:
            app = Mock()
            filt = Mock(name='filter')
            # NOTE(review): this connection is never closed by the test.
            conn = Connection('memory://')
            # The patched event loop raises StopFiltering as soon as it is
            # called. The start_filter calls below return normally, so
            # presumably start_filter treats that as the end of filtering --
            # confirm against celery.contrib.migrate.
            evloop.side_effect = StopFiltering()
            app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
            consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
            consumer.queues = list(app.amqp.queues.values())
            consumer.channel = conn.default_channel
            # Context-manager hooks; presumably start_filter enters the
            # consumer with a ``with`` statement.
            consumer.__enter__ = Mock(name='consumer.__enter__')
            consumer.__exit__ = Mock(name='consumer.__exit__')
            consumer.callbacks = []

            def register_callback(x):
                # Record every callback start_filter registers so the test
                # can invoke it directly below.
                consumer.callbacks.append(x)
            consumer.register_callback = register_callback

            # Queues passed as a comma-separated string, with acking enabled.
            start_filter(app, conn, filt,
                         queues='foo,bar', ack_messages=True)
            body = {'task': 'add', 'id': 'id'}
            for callback in consumer.callbacks:
                callback(body, Message(body))
            # Drop the callbacks captured from the previous call.
            consumer.callbacks[:] = []
            # NOTE(review): the trailing '=' in this mock name looks like a
            # typo; it only affects the mock's repr.
            cb = Mock(name='callback=')
            # A message whose task ('add') is in ``tasks`` must reach ``cb``.
            start_filter(app, conn, filt, tasks='add,mul', callback=cb)
            for callback in consumer.callbacks:
                callback(body, Message(body))
            self.assertTrue(cb.called)

            on_declare_queue = Mock()
            start_filter(app, conn, filt, tasks='add,mul', queues='foo',
                         on_declare_queue=on_declare_queue)
            self.assertTrue(on_declare_queue.called)
            # Queues passed as a list; only checked for not raising.
            start_filter(app, conn, filt, queues=['foo', 'bar'])
            consumer.callbacks[:] = []
            # With limit=1, a processed message must update ``state`` and
            # make a registered callback raise StopFiltering.
            state = State()
            start_filter(app, conn, filt,
                         tasks='add,mul', callback=cb, state=state, limit=1)
            stop_filtering_raised = False
            for callback in consumer.callbacks:
                try:
                    callback(body, Message(body))
                except StopFiltering:
                    stop_filtering_raised = True
            self.assertTrue(state.count)
            self.assertTrue(stop_filtering_raised)
class test_filter_callback(AppCase):
    """Tests for :func:`celery.contrib.migrate.filter_callback`."""

    def test_filter(self):
        inner = Mock()
        filtered = filter_callback(inner, ['add', 'mul'])
        wanted = {'task': 'add'}
        unwanted = {'task': 'div'}
        message = Mock()

        # A task name outside the allowed list never reaches the callback.
        filtered(unwanted, message)
        self.assertFalse(inner.called)

        # A listed task name is forwarded with the same arguments.
        filtered(wanted, message)
        inner.assert_called_with(wanted, message)
class test_utils(AppCase):
    """Tests for the small helper functions in :mod:`celery.contrib.migrate`."""

    def _predicate_passed_to(self, mocked_move):
        """Assert ``move`` was called and return its first positional arg."""
        self.assertTrue(mocked_move.called)
        return mocked_move.call_args[0][0]

    def test_task_id_in(self):
        # Truthy when the message id is in the list, falsy otherwise.
        for task_id, check in (('A', self.assertTrue),
                               ('B', self.assertFalse)):
            check(task_id_in(['A'], {'id': task_id}, Mock()))

    def test_task_id_eq(self):
        # Truthy when the message id equals the given id, falsy otherwise.
        for task_id, check in (('A', self.assertTrue),
                               ('B', self.assertFalse)):
            check(task_id_eq('A', {'id': task_id}, Mock()))

    def test_expand_dest(self):
        # Without an explicit destination the fallback pair is returned;
        # an explicit tuple is returned as given.
        for ret, expected in ((None, ('foo', 'bar')),
                              (('b', 'x'), ('b', 'x'))):
            self.assertEqual(expand_dest(ret, 'foo', 'bar'), expected)

    def test_maybe_queue(self):
        fake_app = Mock()
        fake_app.amqp.queues = {'foo': 313}
        # A name is looked up in app.amqp.queues; a Queue comes back equal.
        self.assertEqual(_maybe_queue(fake_app, 'foo'), 313)
        self.assertEqual(_maybe_queue(fake_app, Queue('foo')), Queue('foo'))

    def test_filter_status(self):
        # filter_status must write something to stdout.
        with override_stdouts() as (stdout, _):
            filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
            self.assertTrue(stdout.getvalue())

    def test_move_by_taskmap(self):
        with patch('celery.contrib.migrate.move') as mocked_move:
            move_by_taskmap({'add': Queue('foo')})
            predicate = self._predicate_passed_to(mocked_move)
            self.assertTrue(predicate({'task': 'add'}, Mock()))

    def test_move_by_idmap(self):
        with patch('celery.contrib.migrate.move') as mocked_move:
            move_by_idmap({'123f': Queue('foo')})
            predicate = self._predicate_passed_to(mocked_move)
            self.assertTrue(predicate({'id': '123f'}, Mock()))

    def test_move_task_by_id(self):
        with patch('celery.contrib.migrate.move') as mocked_move:
            move_task_by_id('123f', Queue('foo'))
            predicate = self._predicate_passed_to(mocked_move)
            self.assertEqual(predicate({'id': '123f'}, Mock()), Queue('foo'))
class test_migrate_task(AppCase):
    """Tests for :func:`celery.contrib.migrate.migrate_task`."""

    def test_removes_compression_header(self):
        message = Message('foo', compression='zlib')
        producer = Mock()
        migrate_task(producer, message.body, message)
        self.assertTrue(producer.publish.called)

        args, kwargs = producer.publish.call_args
        # The body is re-published as bytes.
        self.assertIsInstance(args[0], bytes_t)
        # Compression is no longer a header; it becomes a publish keyword.
        self.assertNotIn('compression', kwargs['headers'])
        for key, expected in (('compression', 'zlib'),
                              ('content_type', 'application/json'),
                              ('content_encoding', 'utf-8'),
                              ('exchange', 'exchange'),
                              ('routing_key', 'rkey')):
            self.assertEqual(kwargs[key], expected)
class test_migrate_tasks(AppCase):
    """End-to-end tests for :func:`celery.contrib.migrate.migrate_tasks`."""

    def _memory_connection(self, url):
        """Return a new ``Connection`` to *url*, released when the test ends.

        ``release`` is registered with ``addCleanup``, so the connection is
        closed even if an assertion fails part-way through the test.
        """
        conn = Connection(url)
        self.addCleanup(conn.release)
        return conn

    def test_migrate(self, name='testcelery'):
        x = self._memory_connection('memory://foo')
        y = self._memory_connection('memory://foo')
        # use separate state
        x.default_channel.queues = {}
        y.default_channel.queues = {}

        ex = Exchange(name, 'direct')
        q = Queue(name, exchange=ex, routing_key=name)
        q(x.default_channel).declare()
        for payload in ('foo', 'bar', 'baz'):
            Producer(x).publish(payload, exchange=name, routing_key=name)
        self.assertTrue(x.default_channel.queues)
        self.assertFalse(y.default_channel.queues)

        # Everything published on x arrives on y, in publish order.
        migrate_tasks(x, y, accept=['text/plain'], app=self.app)
        yq = q(y.default_channel)
        for payload in ('foo', 'bar', 'baz'):
            self.assertEqual(yq.get().body, ensure_bytes(payload))

        # The user callback is invoked for a migrated message.
        Producer(x).publish('foo', exchange=name, routing_key=name)
        callback = Mock()
        migrate_tasks(x, y,
                      callback=callback, accept=['text/plain'], app=self.app)
        self.assertTrue(callback.called)

        # A custom ``migrate`` function is used when supplied.
        migrate = Mock()
        Producer(x).publish('baz', exchange=name, routing_key=name)
        migrate_tasks(x, y, callback=callback,
                      migrate=migrate, accept=['text/plain'], app=self.app)
        self.assertTrue(migrate.called)

        # Passive queue declarations raise ChannelError; migrate_tasks is
        # only checked for not raising here.
        with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
            def effect(*args, **kwargs):
                if kwargs.get('passive'):
                    raise ChannelError('some channel error')
                return 0, 3, 0
            qd.side_effect = effect
            migrate_tasks(x, y, app=self.app)

        # An empty source connection: the callback must never fire. The
        # previous ``x`` is still released by its own cleanup.
        x = self._memory_connection('memory://')
        x.default_channel.queues = {}
        y.default_channel.queues = {}
        callback = Mock()
        migrate_tasks(x, y,
                      callback=callback, accept=['text/plain'], app=self.app)
        self.assertFalse(callback.called)
|