test_migrate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import pytest
  2. from contextlib import contextmanager
  3. from amqp import ChannelError
  4. from case import Mock, mock, patch
  5. from kombu import Connection, Producer, Queue, Exchange
  6. from kombu.transport.virtual import QoS
  7. from celery.contrib.migrate import (
  8. StopFiltering,
  9. State,
  10. migrate_task,
  11. migrate_tasks,
  12. filter_callback,
  13. _maybe_queue,
  14. filter_status,
  15. move_by_taskmap,
  16. move_by_idmap,
  17. move_task_by_id,
  18. start_filter,
  19. task_id_in,
  20. task_id_eq,
  21. expand_dest,
  22. move,
  23. )
  24. from celery.utils.encoding import bytes_t, ensure_bytes
# Hack to ignore an error at shutdown: switch off kombu virtual-transport
# ``QoS.restore_at_shutdown``. This is a class attribute set at import time,
# so it applies process-wide, not only to the tests in this module.
# NOTE(review): presumably this keeps channels from trying to restore
# unacknowledged messages when they are closed at exit -- confirm in kombu.
QoS.restore_at_shutdown = False
  27. def Message(body, exchange='exchange', routing_key='rkey',
  28. compression=None, content_type='application/json',
  29. content_encoding='utf-8'):
  30. return Mock(
  31. attrs={
  32. 'body': body,
  33. 'delivery_info': {
  34. 'exchange': exchange,
  35. 'routing_key': routing_key,
  36. },
  37. 'headers': {
  38. 'compression': compression,
  39. },
  40. 'content_type': content_type,
  41. 'content_encoding': content_encoding,
  42. 'properties': {}
  43. },
  44. )
  45. class test_State:
  46. def test_strtotal(self):
  47. x = State()
  48. assert x.strtotal == '?'
  49. x.total_apx = 100
  50. assert x.strtotal == '100'
  51. def test_repr(self):
  52. x = State()
  53. assert repr(x)
  54. x.filtered = 'foo'
  55. assert repr(x)
  56. class test_move:
  57. @contextmanager
  58. def move_context(self, **kwargs):
  59. with patch('celery.contrib.migrate.start_filter') as start:
  60. with patch('celery.contrib.migrate.republish') as republish:
  61. pred = Mock(name='predicate')
  62. move(pred, app=self.app,
  63. connection=self.app.connection(), **kwargs)
  64. start.assert_called()
  65. callback = start.call_args[0][2]
  66. yield callback, pred, republish
  67. def msgpair(self, **kwargs):
  68. body = dict({'task': 'add', 'id': 'id'}, **kwargs)
  69. return body, Message(body)
  70. def test_move(self):
  71. with self.move_context() as (callback, pred, republish):
  72. pred.return_value = None
  73. body, message = self.msgpair()
  74. callback(body, message)
  75. message.ack.assert_not_called()
  76. republish.assert_not_called()
  77. pred.return_value = 'foo'
  78. callback(body, message)
  79. message.ack.assert_called_with()
  80. republish.assert_called()
  81. def test_move_transform(self):
  82. trans = Mock(name='transform')
  83. trans.return_value = Queue('bar')
  84. with self.move_context(transform=trans) as (callback, pred, republish):
  85. pred.return_value = 'foo'
  86. body, message = self.msgpair()
  87. with patch('celery.contrib.migrate.maybe_declare') as maybed:
  88. callback(body, message)
  89. trans.assert_called_with('foo')
  90. maybed.assert_called()
  91. republish.assert_called()
  92. def test_limit(self):
  93. with self.move_context(limit=1) as (callback, pred, republish):
  94. pred.return_value = 'foo'
  95. body, message = self.msgpair()
  96. with pytest.raises(StopFiltering):
  97. callback(body, message)
  98. republish.assert_called()
  99. def test_callback(self):
  100. cb = Mock()
  101. with self.move_context(callback=cb) as (callback, pred, republish):
  102. pred.return_value = 'foo'
  103. body, message = self.msgpair()
  104. callback(body, message)
  105. republish.assert_called()
  106. cb.assert_called()
  107. class test_start_filter:
  108. def test_start(self):
  109. with patch('celery.contrib.migrate.eventloop') as evloop:
  110. app = Mock()
  111. filt = Mock(name='filter')
  112. conn = Connection('memory://')
  113. evloop.side_effect = StopFiltering()
  114. app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
  115. consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
  116. consumer.queues = list(app.amqp.queues.values())
  117. consumer.channel = conn.default_channel
  118. consumer.__enter__ = Mock(name='consumer.__enter__')
  119. consumer.__exit__ = Mock(name='consumer.__exit__')
  120. consumer.callbacks = []
  121. def register_callback(x):
  122. consumer.callbacks.append(x)
  123. consumer.register_callback = register_callback
  124. start_filter(app, conn, filt,
  125. queues='foo,bar', ack_messages=True)
  126. body = {'task': 'add', 'id': 'id'}
  127. for callback in consumer.callbacks:
  128. callback(body, Message(body))
  129. consumer.callbacks[:] = []
  130. cb = Mock(name='callback=')
  131. start_filter(app, conn, filt, tasks='add,mul', callback=cb)
  132. for callback in consumer.callbacks:
  133. callback(body, Message(body))
  134. cb.assert_called()
  135. on_declare_queue = Mock()
  136. start_filter(app, conn, filt, tasks='add,mul', queues='foo',
  137. on_declare_queue=on_declare_queue)
  138. on_declare_queue.assert_called()
  139. start_filter(app, conn, filt, queues=['foo', 'bar'])
  140. consumer.callbacks[:] = []
  141. state = State()
  142. start_filter(app, conn, filt,
  143. tasks='add,mul', callback=cb, state=state, limit=1)
  144. stop_filtering_raised = False
  145. for callback in consumer.callbacks:
  146. try:
  147. callback(body, Message(body))
  148. except StopFiltering:
  149. stop_filtering_raised = True
  150. assert state.count
  151. assert stop_filtering_raised
  152. class test_filter_callback:
  153. def test_filter(self):
  154. callback = Mock()
  155. filt = filter_callback(callback, ['add', 'mul'])
  156. t1 = {'task': 'add'}
  157. t2 = {'task': 'div'}
  158. message = Mock()
  159. filt(t2, message)
  160. callback.assert_not_called()
  161. filt(t1, message)
  162. callback.assert_called_with(t1, message)
  163. def test_task_id_in():
  164. assert task_id_in(['A'], {'id': 'A'}, Mock())
  165. assert not task_id_in(['A'], {'id': 'B'}, Mock())
  166. def test_task_id_eq():
  167. assert task_id_eq('A', {'id': 'A'}, Mock())
  168. assert not task_id_eq('A', {'id': 'B'}, Mock())
  169. def test_expand_dest():
  170. assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')
  171. assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')
  172. def test_maybe_queue():
  173. app = Mock()
  174. app.amqp.queues = {'foo': 313}
  175. assert _maybe_queue(app, 'foo') == 313
  176. assert _maybe_queue(app, Queue('foo')) == Queue('foo')
  177. def test_filter_status():
  178. with mock.stdouts() as (stdout, stderr):
  179. filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
  180. assert stdout.getvalue()
  181. def test_move_by_taskmap():
  182. with patch('celery.contrib.migrate.move') as move:
  183. move_by_taskmap({'add': Queue('foo')})
  184. move.assert_called()
  185. cb = move.call_args[0][0]
  186. assert cb({'task': 'add'}, Mock())
  187. def test_move_by_idmap():
  188. with patch('celery.contrib.migrate.move') as move:
  189. move_by_idmap({'123f': Queue('foo')})
  190. move.assert_called()
  191. cb = move.call_args[0][0]
  192. assert cb({'id': '123f'}, Mock())
  193. def test_move_task_by_id():
  194. with patch('celery.contrib.migrate.move') as move:
  195. move_task_by_id('123f', Queue('foo'))
  196. move.assert_called()
  197. cb = move.call_args[0][0]
  198. assert cb({'id': '123f'}, Mock()) == Queue('foo')
  199. class test_migrate_task:
  200. def test_removes_compression_header(self):
  201. x = Message('foo', compression='zlib')
  202. producer = Mock()
  203. migrate_task(producer, x.body, x)
  204. producer.publish.assert_called()
  205. args, kwargs = producer.publish.call_args
  206. assert isinstance(args[0], bytes_t)
  207. assert 'compression' not in kwargs['headers']
  208. assert kwargs['compression'] == 'zlib'
  209. assert kwargs['content_type'] == 'application/json'
  210. assert kwargs['content_encoding'] == 'utf-8'
  211. assert kwargs['exchange'] == 'exchange'
  212. assert kwargs['routing_key'] == 'rkey'
  213. class test_migrate_tasks:
  214. def test_migrate(self, app, name='testcelery'):
  215. x = Connection('memory://foo')
  216. y = Connection('memory://foo')
  217. # use separate state
  218. x.default_channel.queues = {}
  219. y.default_channel.queues = {}
  220. ex = Exchange(name, 'direct')
  221. q = Queue(name, exchange=ex, routing_key=name)
  222. q(x.default_channel).declare()
  223. Producer(x).publish('foo', exchange=name, routing_key=name)
  224. Producer(x).publish('bar', exchange=name, routing_key=name)
  225. Producer(x).publish('baz', exchange=name, routing_key=name)
  226. assert x.default_channel.queues
  227. assert not y.default_channel.queues
  228. migrate_tasks(x, y, accept=['text/plain'], app=app)
  229. yq = q(y.default_channel)
  230. assert yq.get().body == ensure_bytes('foo')
  231. assert yq.get().body == ensure_bytes('bar')
  232. assert yq.get().body == ensure_bytes('baz')
  233. Producer(x).publish('foo', exchange=name, routing_key=name)
  234. callback = Mock()
  235. migrate_tasks(x, y,
  236. callback=callback, accept=['text/plain'], app=app)
  237. callback.assert_called()
  238. migrate = Mock()
  239. Producer(x).publish('baz', exchange=name, routing_key=name)
  240. migrate_tasks(x, y, callback=callback,
  241. migrate=migrate, accept=['text/plain'], app=app)
  242. migrate.assert_called()
  243. with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
  244. def effect(*args, **kwargs):
  245. if kwargs.get('passive'):
  246. raise ChannelError('some channel error')
  247. return 0, 3, 0
  248. qd.side_effect = effect
  249. migrate_tasks(x, y, app=app)
  250. x = Connection('memory://')
  251. x.default_channel.queues = {}
  252. y.default_channel.queues = {}
  253. callback = Mock()
  254. migrate_tasks(x, y,
  255. callback=callback, accept=['text/plain'], app=app)
  256. callback.assert_not_called()