test_migrate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from __future__ import absolute_import, unicode_literals
  2. from contextlib import contextmanager
  3. import pytest
  4. from amqp import ChannelError
  5. from case import Mock, mock, patch
  6. from kombu import Connection, Exchange, Producer, Queue
  7. from kombu.transport.virtual import QoS
  8. from celery.contrib.migrate import (State, StopFiltering, _maybe_queue,
  9. expand_dest, filter_callback,
  10. filter_status, migrate_task,
  11. migrate_tasks, move, move_by_idmap,
  12. move_by_taskmap, move_task_by_id,
  13. start_filter, task_id_eq, task_id_in)
  14. from celery.utils.encoding import bytes_t, ensure_bytes
# HACK: switch off the ``restore_at_shutdown`` flag on kombu's virtual
# transport ``QoS`` class for this whole module, to ignore an error that
# otherwise shows up at shutdown.
QoS.restore_at_shutdown = False
  17. def Message(body, exchange='exchange', routing_key='rkey',
  18. compression=None, content_type='application/json',
  19. content_encoding='utf-8'):
  20. return Mock(
  21. attrs={
  22. 'body': body,
  23. 'delivery_info': {
  24. 'exchange': exchange,
  25. 'routing_key': routing_key,
  26. },
  27. 'headers': {
  28. 'compression': compression,
  29. },
  30. 'content_type': content_type,
  31. 'content_encoding': content_encoding,
  32. 'properties': {}
  33. },
  34. )
  35. class test_State:
  36. def test_strtotal(self):
  37. x = State()
  38. assert x.strtotal == '?'
  39. x.total_apx = 100
  40. assert x.strtotal == '100'
  41. def test_repr(self):
  42. x = State()
  43. assert repr(x)
  44. x.filtered = 'foo'
  45. assert repr(x)
  46. class test_move:
  47. @contextmanager
  48. def move_context(self, **kwargs):
  49. with patch('celery.contrib.migrate.start_filter') as start:
  50. with patch('celery.contrib.migrate.republish') as republish:
  51. pred = Mock(name='predicate')
  52. move(pred, app=self.app,
  53. connection=self.app.connection(), **kwargs)
  54. start.assert_called()
  55. callback = start.call_args[0][2]
  56. yield callback, pred, republish
  57. def msgpair(self, **kwargs):
  58. body = dict({'task': 'add', 'id': 'id'}, **kwargs)
  59. return body, Message(body)
  60. def test_move(self):
  61. with self.move_context() as (callback, pred, republish):
  62. pred.return_value = None
  63. body, message = self.msgpair()
  64. callback(body, message)
  65. message.ack.assert_not_called()
  66. republish.assert_not_called()
  67. pred.return_value = 'foo'
  68. callback(body, message)
  69. message.ack.assert_called_with()
  70. republish.assert_called()
  71. def test_move_transform(self):
  72. trans = Mock(name='transform')
  73. trans.return_value = Queue('bar')
  74. with self.move_context(transform=trans) as (callback, pred, republish):
  75. pred.return_value = 'foo'
  76. body, message = self.msgpair()
  77. with patch('celery.contrib.migrate.maybe_declare') as maybed:
  78. callback(body, message)
  79. trans.assert_called_with('foo')
  80. maybed.assert_called()
  81. republish.assert_called()
  82. def test_limit(self):
  83. with self.move_context(limit=1) as (callback, pred, republish):
  84. pred.return_value = 'foo'
  85. body, message = self.msgpair()
  86. with pytest.raises(StopFiltering):
  87. callback(body, message)
  88. republish.assert_called()
  89. def test_callback(self):
  90. cb = Mock()
  91. with self.move_context(callback=cb) as (callback, pred, republish):
  92. pred.return_value = 'foo'
  93. body, message = self.msgpair()
  94. callback(body, message)
  95. republish.assert_called()
  96. cb.assert_called()
  97. class test_start_filter:
  98. def test_start(self):
  99. with patch('celery.contrib.migrate.eventloop') as evloop:
  100. app = Mock()
  101. filt = Mock(name='filter')
  102. conn = Connection('memory://')
  103. evloop.side_effect = StopFiltering()
  104. app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
  105. consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
  106. consumer.queues = list(app.amqp.queues.values())
  107. consumer.channel = conn.default_channel
  108. consumer.__enter__ = Mock(name='consumer.__enter__')
  109. consumer.__exit__ = Mock(name='consumer.__exit__')
  110. consumer.callbacks = []
  111. def register_callback(x):
  112. consumer.callbacks.append(x)
  113. consumer.register_callback = register_callback
  114. start_filter(app, conn, filt,
  115. queues='foo,bar', ack_messages=True)
  116. body = {'task': 'add', 'id': 'id'}
  117. for callback in consumer.callbacks:
  118. callback(body, Message(body))
  119. consumer.callbacks[:] = []
  120. cb = Mock(name='callback=')
  121. start_filter(app, conn, filt, tasks='add,mul', callback=cb)
  122. for callback in consumer.callbacks:
  123. callback(body, Message(body))
  124. cb.assert_called()
  125. on_declare_queue = Mock()
  126. start_filter(app, conn, filt, tasks='add,mul', queues='foo',
  127. on_declare_queue=on_declare_queue)
  128. on_declare_queue.assert_called()
  129. start_filter(app, conn, filt, queues=['foo', 'bar'])
  130. consumer.callbacks[:] = []
  131. state = State()
  132. start_filter(app, conn, filt,
  133. tasks='add,mul', callback=cb, state=state, limit=1)
  134. stop_filtering_raised = False
  135. for callback in consumer.callbacks:
  136. try:
  137. callback(body, Message(body))
  138. except StopFiltering:
  139. stop_filtering_raised = True
  140. assert state.count
  141. assert stop_filtering_raised
  142. class test_filter_callback:
  143. def test_filter(self):
  144. callback = Mock()
  145. filt = filter_callback(callback, ['add', 'mul'])
  146. t1 = {'task': 'add'}
  147. t2 = {'task': 'div'}
  148. message = Mock()
  149. filt(t2, message)
  150. callback.assert_not_called()
  151. filt(t1, message)
  152. callback.assert_called_with(t1, message)
  153. def test_task_id_in():
  154. assert task_id_in(['A'], {'id': 'A'}, Mock())
  155. assert not task_id_in(['A'], {'id': 'B'}, Mock())
  156. def test_task_id_eq():
  157. assert task_id_eq('A', {'id': 'A'}, Mock())
  158. assert not task_id_eq('A', {'id': 'B'}, Mock())
  159. def test_expand_dest():
  160. assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')
  161. assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')
  162. def test_maybe_queue():
  163. app = Mock()
  164. app.amqp.queues = {'foo': 313}
  165. assert _maybe_queue(app, 'foo') == 313
  166. assert _maybe_queue(app, Queue('foo')) == Queue('foo')
  167. def test_filter_status():
  168. with mock.stdouts() as (stdout, stderr):
  169. filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
  170. assert stdout.getvalue()
  171. def test_move_by_taskmap():
  172. with patch('celery.contrib.migrate.move') as move:
  173. move_by_taskmap({'add': Queue('foo')})
  174. move.assert_called()
  175. cb = move.call_args[0][0]
  176. assert cb({'task': 'add'}, Mock())
  177. def test_move_by_idmap():
  178. with patch('celery.contrib.migrate.move') as move:
  179. move_by_idmap({'123f': Queue('foo')})
  180. move.assert_called()
  181. cb = move.call_args[0][0]
  182. assert cb({'id': '123f'}, Mock())
  183. def test_move_task_by_id():
  184. with patch('celery.contrib.migrate.move') as move:
  185. move_task_by_id('123f', Queue('foo'))
  186. move.assert_called()
  187. cb = move.call_args[0][0]
  188. assert cb({'id': '123f'}, Mock()) == Queue('foo')
  189. class test_migrate_task:
  190. def test_removes_compression_header(self):
  191. x = Message('foo', compression='zlib')
  192. producer = Mock()
  193. migrate_task(producer, x.body, x)
  194. producer.publish.assert_called()
  195. args, kwargs = producer.publish.call_args
  196. assert isinstance(args[0], bytes_t)
  197. assert 'compression' not in kwargs['headers']
  198. assert kwargs['compression'] == 'zlib'
  199. assert kwargs['content_type'] == 'application/json'
  200. assert kwargs['content_encoding'] == 'utf-8'
  201. assert kwargs['exchange'] == 'exchange'
  202. assert kwargs['routing_key'] == 'rkey'
  203. class test_migrate_tasks:
  204. def test_migrate(self, app, name='testcelery'):
  205. connection_kwargs = {
  206. 'transport_options': {'polling_interval': 0.01}
  207. }
  208. x = Connection('memory://foo', **connection_kwargs)
  209. y = Connection('memory://foo', **connection_kwargs)
  210. # use separate state
  211. x.default_channel.queues = {}
  212. y.default_channel.queues = {}
  213. ex = Exchange(name, 'direct')
  214. q = Queue(name, exchange=ex, routing_key=name)
  215. q(x.default_channel).declare()
  216. Producer(x).publish('foo', exchange=name, routing_key=name)
  217. Producer(x).publish('bar', exchange=name, routing_key=name)
  218. Producer(x).publish('baz', exchange=name, routing_key=name)
  219. assert x.default_channel.queues
  220. assert not y.default_channel.queues
  221. migrate_tasks(x, y, accept=['text/plain'], app=app)
  222. yq = q(y.default_channel)
  223. assert yq.get().body == ensure_bytes('foo')
  224. assert yq.get().body == ensure_bytes('bar')
  225. assert yq.get().body == ensure_bytes('baz')
  226. Producer(x).publish('foo', exchange=name, routing_key=name)
  227. callback = Mock()
  228. migrate_tasks(x, y,
  229. callback=callback, accept=['text/plain'], app=app)
  230. callback.assert_called()
  231. migrate = Mock()
  232. Producer(x).publish('baz', exchange=name, routing_key=name)
  233. migrate_tasks(x, y, callback=callback,
  234. migrate=migrate, accept=['text/plain'], app=app)
  235. migrate.assert_called()
  236. with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
  237. def effect(*args, **kwargs):
  238. if kwargs.get('passive'):
  239. raise ChannelError('some channel error')
  240. return 0, 3, 0
  241. qd.side_effect = effect
  242. migrate_tasks(x, y, app=app)
  243. x = Connection('memory://', **connection_kwargs)
  244. x.default_channel.queues = {}
  245. y.default_channel.queues = {}
  246. callback = Mock()
  247. migrate_tasks(x, y,
  248. callback=callback, accept=['text/plain'], app=app)
  249. callback.assert_not_called()