test_migrate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. from __future__ import absolute_import, unicode_literals
  2. from contextlib import contextmanager
  3. import pytest
  4. from amqp import ChannelError
  5. from case import Mock, mock, patch
  6. from celery.contrib.migrate import (State, StopFiltering, _maybe_queue,
  7. expand_dest, filter_callback,
  8. filter_status, migrate_task,
  9. migrate_tasks, move, move_by_idmap,
  10. move_by_taskmap, move_task_by_id,
  11. start_filter, task_id_eq, task_id_in)
  12. from celery.utils.encoding import bytes_t, ensure_bytes
  13. from kombu import Connection, Exchange, Producer, Queue
  14. from kombu.transport.virtual import QoS
# HACK: ignore an error at shutdown by flipping a class attribute on kombu's
# virtual-transport QoS at import time. Because this assigns to the class
# itself, it affects every QoS instance in the process that does not
# override the attribute, not only the ones created by these tests.
# NOTE(review): presumably this stops unacknowledged messages from being
# restored when channels close -- confirm against kombu's QoS docs, and
# consider scoping it to a fixture instead of a module-level patch.
QoS.restore_at_shutdown = False
  17. def Message(body, exchange='exchange', routing_key='rkey',
  18. compression=None, content_type='application/json',
  19. content_encoding='utf-8'):
  20. return Mock(
  21. attrs={
  22. 'body': body,
  23. 'delivery_info': {
  24. 'exchange': exchange,
  25. 'routing_key': routing_key,
  26. },
  27. 'headers': {
  28. 'compression': compression,
  29. },
  30. 'content_type': content_type,
  31. 'content_encoding': content_encoding,
  32. 'properties': {}
  33. },
  34. )
  35. class test_State:
  36. def test_strtotal(self):
  37. x = State()
  38. assert x.strtotal == '?'
  39. x.total_apx = 100
  40. assert x.strtotal == '100'
  41. def test_repr(self):
  42. x = State()
  43. assert repr(x)
  44. x.filtered = 'foo'
  45. assert repr(x)
  46. class test_move:
  47. @contextmanager
  48. def move_context(self, **kwargs):
  49. with patch('celery.contrib.migrate.start_filter') as start:
  50. with patch('celery.contrib.migrate.republish') as republish:
  51. pred = Mock(name='predicate')
  52. move(pred, app=self.app,
  53. connection=self.app.connection(), **kwargs)
  54. start.assert_called()
  55. callback = start.call_args[0][2]
  56. yield callback, pred, republish
  57. def msgpair(self, **kwargs):
  58. body = dict({'task': 'add', 'id': 'id'}, **kwargs)
  59. return body, Message(body)
  60. def test_move(self):
  61. with self.move_context() as (callback, pred, republish):
  62. pred.return_value = None
  63. body, message = self.msgpair()
  64. callback(body, message)
  65. message.ack.assert_not_called()
  66. republish.assert_not_called()
  67. pred.return_value = 'foo'
  68. callback(body, message)
  69. message.ack.assert_called_with()
  70. republish.assert_called()
  71. def test_move_transform(self):
  72. trans = Mock(name='transform')
  73. trans.return_value = Queue('bar')
  74. with self.move_context(transform=trans) as (callback, pred, republish):
  75. pred.return_value = 'foo'
  76. body, message = self.msgpair()
  77. with patch('celery.contrib.migrate.maybe_declare') as maybed:
  78. callback(body, message)
  79. trans.assert_called_with('foo')
  80. maybed.assert_called()
  81. republish.assert_called()
  82. def test_limit(self):
  83. with self.move_context(limit=1) as (callback, pred, republish):
  84. pred.return_value = 'foo'
  85. body, message = self.msgpair()
  86. with pytest.raises(StopFiltering):
  87. callback(body, message)
  88. republish.assert_called()
  89. def test_callback(self):
  90. cb = Mock()
  91. with self.move_context(callback=cb) as (callback, pred, republish):
  92. pred.return_value = 'foo'
  93. body, message = self.msgpair()
  94. callback(body, message)
  95. republish.assert_called()
  96. cb.assert_called()
  97. class test_start_filter:
  98. def test_start(self):
  99. with patch('celery.contrib.migrate.eventloop') as evloop:
  100. app = Mock()
  101. filt = Mock(name='filter')
  102. conn = Connection('memory://')
  103. evloop.side_effect = StopFiltering()
  104. app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
  105. consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
  106. consumer.queues = list(app.amqp.queues.values())
  107. consumer.channel = conn.default_channel
  108. consumer.__enter__ = Mock(name='consumer.__enter__')
  109. consumer.__exit__ = Mock(name='consumer.__exit__')
  110. consumer.callbacks = []
  111. def register_callback(x):
  112. consumer.callbacks.append(x)
  113. consumer.register_callback = register_callback
  114. start_filter(app, conn, filt,
  115. queues='foo,bar', ack_messages=True)
  116. body = {'task': 'add', 'id': 'id'}
  117. for callback in consumer.callbacks:
  118. callback(body, Message(body))
  119. consumer.callbacks[:] = []
  120. cb = Mock(name='callback=')
  121. start_filter(app, conn, filt, tasks='add,mul', callback=cb)
  122. for callback in consumer.callbacks:
  123. callback(body, Message(body))
  124. cb.assert_called()
  125. on_declare_queue = Mock()
  126. start_filter(app, conn, filt, tasks='add,mul', queues='foo',
  127. on_declare_queue=on_declare_queue)
  128. on_declare_queue.assert_called()
  129. start_filter(app, conn, filt, queues=['foo', 'bar'])
  130. consumer.callbacks[:] = []
  131. state = State()
  132. start_filter(app, conn, filt,
  133. tasks='add,mul', callback=cb, state=state, limit=1)
  134. stop_filtering_raised = False
  135. for callback in consumer.callbacks:
  136. try:
  137. callback(body, Message(body))
  138. except StopFiltering:
  139. stop_filtering_raised = True
  140. assert state.count
  141. assert stop_filtering_raised
  142. class test_filter_callback:
  143. def test_filter(self):
  144. callback = Mock()
  145. filt = filter_callback(callback, ['add', 'mul'])
  146. t1 = {'task': 'add'}
  147. t2 = {'task': 'div'}
  148. message = Mock()
  149. filt(t2, message)
  150. callback.assert_not_called()
  151. filt(t1, message)
  152. callback.assert_called_with(t1, message)
  153. def test_task_id_in():
  154. assert task_id_in(['A'], {'id': 'A'}, Mock())
  155. assert not task_id_in(['A'], {'id': 'B'}, Mock())
  156. def test_task_id_eq():
  157. assert task_id_eq('A', {'id': 'A'}, Mock())
  158. assert not task_id_eq('A', {'id': 'B'}, Mock())
  159. def test_expand_dest():
  160. assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')
  161. assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')
  162. def test_maybe_queue():
  163. app = Mock()
  164. app.amqp.queues = {'foo': 313}
  165. assert _maybe_queue(app, 'foo') == 313
  166. assert _maybe_queue(app, Queue('foo')) == Queue('foo')
  167. def test_filter_status():
  168. with mock.stdouts() as (stdout, stderr):
  169. filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
  170. assert stdout.getvalue()
  171. def test_move_by_taskmap():
  172. with patch('celery.contrib.migrate.move') as move:
  173. move_by_taskmap({'add': Queue('foo')})
  174. move.assert_called()
  175. cb = move.call_args[0][0]
  176. assert cb({'task': 'add'}, Mock())
  177. def test_move_by_idmap():
  178. with patch('celery.contrib.migrate.move') as move:
  179. move_by_idmap({'123f': Queue('foo')})
  180. move.assert_called()
  181. cb = move.call_args[0][0]
  182. assert cb({'id': '123f'}, Mock())
  183. def test_move_task_by_id():
  184. with patch('celery.contrib.migrate.move') as move:
  185. move_task_by_id('123f', Queue('foo'))
  186. move.assert_called()
  187. cb = move.call_args[0][0]
  188. assert cb({'id': '123f'}, Mock()) == Queue('foo')
  189. class test_migrate_task:
  190. def test_removes_compression_header(self):
  191. x = Message('foo', compression='zlib')
  192. producer = Mock()
  193. migrate_task(producer, x.body, x)
  194. producer.publish.assert_called()
  195. args, kwargs = producer.publish.call_args
  196. assert isinstance(args[0], bytes_t)
  197. assert 'compression' not in kwargs['headers']
  198. assert kwargs['compression'] == 'zlib'
  199. assert kwargs['content_type'] == 'application/json'
  200. assert kwargs['content_encoding'] == 'utf-8'
  201. assert kwargs['exchange'] == 'exchange'
  202. assert kwargs['routing_key'] == 'rkey'
  203. class test_migrate_tasks:
  204. def test_migrate(self, app, name='testcelery'):
  205. connection_kwargs = {
  206. 'transport_options': {'polling_interval': 0.01}
  207. }
  208. x = Connection('memory://foo', **connection_kwargs)
  209. y = Connection('memory://foo', **connection_kwargs)
  210. # use separate state
  211. x.default_channel.queues = {}
  212. y.default_channel.queues = {}
  213. ex = Exchange(name, 'direct')
  214. q = Queue(name, exchange=ex, routing_key=name)
  215. q(x.default_channel).declare()
  216. Producer(x).publish('foo', exchange=name, routing_key=name)
  217. Producer(x).publish('bar', exchange=name, routing_key=name)
  218. Producer(x).publish('baz', exchange=name, routing_key=name)
  219. assert x.default_channel.queues
  220. assert not y.default_channel.queues
  221. migrate_tasks(x, y, accept=['text/plain'], app=app)
  222. yq = q(y.default_channel)
  223. assert yq.get().body == ensure_bytes('foo')
  224. assert yq.get().body == ensure_bytes('bar')
  225. assert yq.get().body == ensure_bytes('baz')
  226. Producer(x).publish('foo', exchange=name, routing_key=name)
  227. callback = Mock()
  228. migrate_tasks(x, y,
  229. callback=callback, accept=['text/plain'], app=app)
  230. callback.assert_called()
  231. migrate = Mock()
  232. Producer(x).publish('baz', exchange=name, routing_key=name)
  233. migrate_tasks(x, y, callback=callback,
  234. migrate=migrate, accept=['text/plain'], app=app)
  235. migrate.assert_called()
  236. with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
  237. def effect(*args, **kwargs):
  238. if kwargs.get('passive'):
  239. raise ChannelError('some channel error')
  240. return 0, 3, 0
  241. qd.side_effect = effect
  242. migrate_tasks(x, y, app=app)
  243. x = Connection('memory://', **connection_kwargs)
  244. x.default_channel.queues = {}
  245. y.default_channel.queues = {}
  246. callback = Mock()
  247. migrate_tasks(x, y,
  248. callback=callback, accept=['text/plain'], app=app)
  249. callback.assert_not_called()