test_migrate.py 11 KB


  1. from __future__ import absolute_import, unicode_literals
  2. from contextlib import contextmanager
  3. from mock import patch
  4. from amqp import ChannelError
  5. from kombu import Connection, Producer, Queue, Exchange
  6. from kombu.transport.virtual import QoS
  7. from celery.contrib.migrate import (
  8. StopFiltering,
  9. State,
  10. migrate_task,
  11. migrate_tasks,
  12. filter_callback,
  13. _maybe_queue,
  14. filter_status,
  15. move_by_taskmap,
  16. move_by_idmap,
  17. move_task_by_id,
  18. start_filter,
  19. task_id_in,
  20. task_id_eq,
  21. expand_dest,
  22. move,
  23. )
  24. from celery.utils.encoding import bytes_t, ensure_bytes
  25. from celery.tests.case import AppCase, Mock, override_stdouts
  # Hack: switch off kombu's QoS restore-at-shutdown so that tearing down the
  # channels opened by these tests does not produce errors at shutdown.
  setattr(QoS, 'restore_at_shutdown', False)
  def Message(body, exchange='exchange', routing_key='rkey',
              compression=None, content_type='application/json',
              content_encoding='utf-8'):
      """Return a fake message object for the migrate helpers.

      The message is a ``Mock`` constructed with ``attrs=`` so that it carries
      ``body``, ``delivery_info`` (``exchange`` / ``routing_key``),
      ``headers`` (``compression``), ``content_type``, ``content_encoding``
      and an always-empty ``properties`` dict.
      """
      delivery_info = {'exchange': exchange, 'routing_key': routing_key}
      headers = {'compression': compression}
      fields = dict(
          body=body,
          delivery_info=delivery_info,
          headers=headers,
          content_type=content_type,
          content_encoding=content_encoding,
          properties={},
      )
      return Mock(attrs=fields)
  class test_State(AppCase):
      """Tests for the migration progress ``State`` object."""

      def test_strtotal(self):
          # An unknown total is shown as '?', a known one as its digits.
          state = State()
          self.assertEqual(state.strtotal, '?')
          state.total_apx = 100
          self.assertEqual(state.strtotal, '100')

      def test_repr(self):
          # repr() must be non-empty both before and after ``filtered``
          # has been assigned.
          state = State()
          self.assertTrue(repr(state))
          state.filtered = 'foo'
          self.assertTrue(repr(state))
  class test_move(AppCase):
      """Tests for ``move()`` with ``start_filter`` and ``republish`` patched."""

      @contextmanager
      def move_context(self, **kwargs):
          """Run ``move()`` and yield ``(callback, predicate, republish)``.

          ``callback`` is the third positional argument that ``move()`` passed
          to the patched ``start_filter``; ``predicate`` is the mock predicate
          handed to ``move()``; ``republish`` is the patched republish mock.
          Extra keyword arguments are forwarded to ``move()``.
          """
          with patch('celery.contrib.migrate.start_filter') as start:
              with patch('celery.contrib.migrate.republish') as republish:
                  predicate = Mock(name='predicate')
                  conn = self.app.connection()
                  move(predicate, app=self.app, connection=conn, **kwargs)
                  self.assertTrue(start.called)
                  positional = start.call_args[0]
                  yield positional[2], predicate, republish

      def msgpair(self, **kwargs):
          """Return ``(body, message)`` for an 'add' task; kwargs extend body."""
          body = {'task': 'add', 'id': 'id'}
          body.update(kwargs)
          return body, Message(body)

      def test_move(self):
          with self.move_context() as (callback, pred, republish):
              body, message = self.msgpair()

              # Predicate rejects the message: nothing is acked or moved.
              pred.return_value = None
              callback(body, message)
              self.assertFalse(message.ack.called)
              self.assertFalse(republish.called)

              # Predicate accepts the message: it is acked and republished.
              pred.return_value = 'foo'
              callback(body, message)
              message.ack.assert_called_with()
              self.assertTrue(republish.called)

      def test_move_transform(self):
          transform = Mock(name='transform')
          transform.return_value = Queue('bar')
          ctx = self.move_context(transform=transform)
          with ctx as (callback, pred, republish):
              pred.return_value = 'foo'
              body, message = self.msgpair()
              with patch('celery.contrib.migrate.maybe_declare') as declare:
                  callback(body, message)
              # The predicate's result is passed through the transform and
              # the resulting queue is declared before republishing.
              transform.assert_called_with('foo')
              self.assertTrue(declare.called)
              self.assertTrue(republish.called)

      def test_limit(self):
          with self.move_context(limit=1) as (callback, pred, republish):
              pred.return_value = 'foo'
              body, message = self.msgpair()
              # Reaching the limit aborts filtering after republishing.
              with self.assertRaises(StopFiltering):
                  callback(body, message)
              self.assertTrue(republish.called)

      def test_callback(self):
          user_callback = Mock()
          ctx = self.move_context(callback=user_callback)
          with ctx as (callback, pred, republish):
              pred.return_value = 'foo'
              body, message = self.msgpair()
              callback(body, message)
              self.assertTrue(republish.called)
              self.assertTrue(user_callback.called)
  class test_start_filter(AppCase):
      """Tests for ``start_filter()`` driven through a fake task consumer."""

      def _install_consumer(self, app, conn):
          """Give *app* a fake ``TaskConsumer`` and return that consumer.

          The consumer lists the app's ``foo`` and ``bar`` queues, uses the
          default channel of *conn*, and appends every callback passed to
          ``register_callback`` to ``consumer.callbacks`` so the test can
          invoke them by hand.
          """
          app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
          consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
          consumer.queues = list(app.amqp.queues.values())
          consumer.channel = conn.default_channel
          consumer.__enter__ = Mock(name='consumer.__enter__')
          consumer.__exit__ = Mock(name='consumer.__exit__')
          consumer.callbacks = []

          def register_callback(x):
              consumer.callbacks.append(x)
          consumer.register_callback = register_callback
          return consumer

      def test_start(self):
          body = {'task': 'add', 'id': 'id'}
          conn = Connection('memory://')
          try:
              with patch('celery.contrib.migrate.eventloop') as evloop:
                  # Calling the patched event loop raises StopFiltering
                  # instead of consuming from the broker.
                  evloop.side_effect = StopFiltering()
                  app = Mock()
                  filt = Mock(name='filter')
                  consumer = self._install_consumer(app, conn)

                  # Acknowledging run over a comma-separated queue list.
                  start_filter(app, conn, filt,
                               queues='foo,bar', ack_messages=True)
                  for callback in consumer.callbacks:
                      callback(body, Message(body))

                  # Task-name filtering with a user callback.
                  consumer.callbacks[:] = []
                  cb = Mock(name='callback')
                  start_filter(app, conn, filt, tasks='add,mul', callback=cb)
                  for callback in consumer.callbacks:
                      callback(body, Message(body))
                  self.assertTrue(cb.called)

                  # on_declare_queue must be called.
                  on_declare_queue = Mock()
                  start_filter(app, conn, filt, tasks='add,mul', queues='foo',
                               on_declare_queue=on_declare_queue)
                  self.assertTrue(on_declare_queue.called)

                  # Queues may also be passed as a list (smoke test only).
                  start_filter(app, conn, filt, queues=['foo', 'bar'])

                  # With limit=1 one of the registered callbacks raises
                  # StopFiltering, and the caller-supplied state is updated.
                  consumer.callbacks[:] = []
                  state = State()
                  start_filter(app, conn, filt,
                               tasks='add,mul', callback=cb, state=state,
                               limit=1)
                  stop_filtering_raised = False
                  for callback in consumer.callbacks:
                      try:
                          callback(body, Message(body))
                      except StopFiltering:
                          stop_filtering_raised = True
                  self.assertTrue(state.count)
                  self.assertTrue(stop_filtering_raised)
          finally:
              # Release the in-memory connection whose default channel the
              # fake consumer borrowed, so the test does not leak it.
              conn.release()
  class test_filter_callback(AppCase):
      """``filter_callback`` only forwards bodies whose task is listed."""

      def test_filter(self):
          inner = Mock()
          wrapped = filter_callback(inner, ['add', 'mul'])
          message = Mock()
          accepted, rejected = {'task': 'add'}, {'task': 'div'}

          # 'div' is not in the task list, so the inner callback is skipped.
          wrapped(rejected, message)
          self.assertFalse(inner.called)

          # 'add' is listed, so the call is forwarded unchanged.
          wrapped(accepted, message)
          inner.assert_called_with(accepted, message)
  class test_utils(AppCase):
      """Tests for the small helper functions in ``celery.contrib.migrate``."""

      def _apply_move_predicate(self, mover, mover_args, body):
          """Call *mover* with ``move`` patched, then apply its predicate.

          Asserts that *mover* delegated to ``move`` and returns the result of
          calling the predicate (``move``'s first positional argument) on
          *body* and a fresh mock message.
          """
          with patch('celery.contrib.migrate.move') as move_mock:
              mover(*mover_args)
              self.assertTrue(move_mock.called)
              predicate = move_mock.call_args[0][0]
              return predicate(body, Mock())

      def test_task_id_in(self):
          self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
          self.assertFalse(task_id_in(['A'], {'id': 'B'}, Mock()))

      def test_task_id_eq(self):
          self.assertTrue(task_id_eq('A', {'id': 'A'}, Mock()))
          self.assertFalse(task_id_eq('A', {'id': 'B'}, Mock()))

      def test_expand_dest(self):
          # Without an explicit destination the defaults are returned.
          defaults = ('foo', 'bar')
          self.assertEqual(expand_dest(None, *defaults), defaults)
          self.assertEqual(expand_dest(('b', 'x'), *defaults), ('b', 'x'))

      def test_maybe_queue(self):
          app = Mock()
          app.amqp.queues = {'foo': 313}
          # Names are looked up in app.amqp.queues; Queue objects pass through.
          self.assertEqual(_maybe_queue(app, 'foo'), 313)
          self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))

      def test_filter_status(self):
          with override_stdouts() as (stdout, _stderr):
              filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
              self.assertTrue(stdout.getvalue())

      def test_move_by_taskmap(self):
          result = self._apply_move_predicate(
              move_by_taskmap, ({'add': Queue('foo')},), {'task': 'add'},
          )
          self.assertTrue(result)

      def test_move_by_idmap(self):
          result = self._apply_move_predicate(
              move_by_idmap, ({'123f': Queue('foo')},), {'id': '123f'},
          )
          self.assertTrue(result)

      def test_move_task_by_id(self):
          result = self._apply_move_predicate(
              move_task_by_id, ('123f', Queue('foo')), {'id': '123f'},
          )
          self.assertEqual(result, Queue('foo'))
  class test_migrate_task(AppCase):
      """``migrate_task`` republishes a message through the given producer."""

      def test_removes_compression_header(self):
          message = Message('foo', compression='zlib')
          producer = Mock()
          migrate_task(producer, message.body, message)
          self.assertTrue(producer.publish.called)

          args, kwargs = producer.publish.call_args
          self.assertIsInstance(args[0], bytes_t)
          # 'compression' is dropped from the headers and passed as a
          # publish() keyword argument instead.
          self.assertNotIn('compression', kwargs['headers'])
          expected = (
              ('compression', 'zlib'),
              ('content_type', 'application/json'),
              ('content_encoding', 'utf-8'),
              ('exchange', 'exchange'),
              ('routing_key', 'rkey'),
          )
          for key, value in expected:
              self.assertEqual(kwargs[key], value)
  class test_migrate_tasks(AppCase):
      """End-to-end tests for ``migrate_tasks`` over in-memory connections."""

      def test_migrate(self, name='testcelery'):
          """Migrate messages from one in-memory connection to another.

          :param name: Used as the exchange name, queue name and routing key.
          """
          opened = []

          def connect(url):
              # Track every connection so all of them are released at the end,
              # including the source connection that is replaced below.
              conn = Connection(url)
              opened.append(conn)
              return conn

          try:
              x = connect('memory://foo')
              y = connect('memory://foo')
              # use separate state
              x.default_channel.queues = {}
              y.default_channel.queues = {}
              ex = Exchange(name, 'direct')
              q = Queue(name, exchange=ex, routing_key=name)
              q(x.default_channel).declare()
              Producer(x).publish('foo', exchange=name, routing_key=name)
              Producer(x).publish('bar', exchange=name, routing_key=name)
              Producer(x).publish('baz', exchange=name, routing_key=name)
              self.assertTrue(x.default_channel.queues)
              self.assertFalse(y.default_channel.queues)

              # Messages arrive at the destination in publish order.
              migrate_tasks(x, y, accept=['text/plain'], app=self.app)
              yq = q(y.default_channel)
              self.assertEqual(yq.get().body, ensure_bytes('foo'))
              self.assertEqual(yq.get().body, ensure_bytes('bar'))
              self.assertEqual(yq.get().body, ensure_bytes('baz'))

              # A user callback is invoked for migrated messages.
              Producer(x).publish('foo', exchange=name, routing_key=name)
              callback = Mock()
              migrate_tasks(x, y,
                            callback=callback, accept=['text/plain'],
                            app=self.app)
              self.assertTrue(callback.called)

              # A custom ``migrate`` function is used when supplied.
              migrate = Mock()
              Producer(x).publish('baz', exchange=name, routing_key=name)
              migrate_tasks(x, y, callback=callback,
                            migrate=migrate, accept=['text/plain'],
                            app=self.app)
              self.assertTrue(migrate.called)

              # A ChannelError from a passive queue declaration must not
              # propagate out of migrate_tasks.
              with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
                  def effect(*args, **kwargs):
                      if kwargs.get('passive'):
                          raise ChannelError('some channel error')
                      return 0, 3, 0
                  qd.side_effect = effect
                  migrate_tasks(x, y, app=self.app)

              # With a fresh, empty source the callback is never invoked.
              x = connect('memory://')
              x.default_channel.queues = {}
              y.default_channel.queues = {}
              callback = Mock()
              migrate_tasks(x, y,
                            callback=callback, accept=['text/plain'],
                            app=self.app)
              self.assertFalse(callback.called)
          finally:
              for conn in opened:
                  conn.release()