migrate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # -*- coding: utf-8 -*-
  2. """Message migration tools (Broker <-> Broker)."""
  3. from __future__ import absolute_import, print_function, unicode_literals
  4. import socket
  5. from functools import partial
  6. from itertools import cycle, islice
  7. from kombu import eventloop, Queue
  8. from kombu.common import maybe_declare
  9. from kombu.utils.encoding import ensure_bytes
  10. from celery.app import app_or_default
  11. from celery.five import python_2_unicode_compatible, string, string_t
  12. from celery.utils.nodenames import worker_direct
# Names exported by a star-import of this module.
# NOTE(review): ``move_direct_by_idmap`` and ``move_direct_by_taskmap`` are
# defined at the bottom of the module but are not listed here -- confirm
# whether that omission is intentional.
__all__ = [
    'StopFiltering', 'State', 'republish', 'migrate_task',
    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
    'start_filter', 'move_task_by_id', 'move_by_idmap',
    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
]

# Progress line template rendered by ``filter_status``.  It is formatted with
# ``state`` (a ``State`` instance) and ``body`` (a mapping providing the
# ``task`` and ``id`` keys).  The backslashes at the line ends suppress the
# newlines, so the rendered text is a single line without a trailing newline.
MOVING_PROGRESS_FMT = """\
Moving task {state.filtered}/{state.strtotal}: \
{body[task]}[{body[id]}]\
"""
  23. class StopFiltering(Exception):
  24. pass
  25. @python_2_unicode_compatible
  26. class State(object):
  27. count = 0
  28. filtered = 0
  29. total_apx = 0
  30. @property
  31. def strtotal(self):
  32. if not self.total_apx:
  33. return '?'
  34. return string(self.total_apx)
  35. def __repr__(self):
  36. if self.filtered:
  37. return '^{0.filtered}'.format(self)
  38. return '{0.count}/{0.strtotal}'.format(self)
  39. def republish(producer, message, exchange=None, routing_key=None,
  40. remove_props=['application_headers',
  41. 'content_type',
  42. 'content_encoding',
  43. 'headers']):
  44. body = ensure_bytes(message.body) # use raw message body.
  45. info, headers, props = (message.delivery_info,
  46. message.headers, message.properties)
  47. exchange = info['exchange'] if exchange is None else exchange
  48. routing_key = info['routing_key'] if routing_key is None else routing_key
  49. ctype, enc = message.content_type, message.content_encoding
  50. # remove compression header, as this will be inserted again
  51. # when the message is recompressed.
  52. compression = headers.pop('compression', None)
  53. for key in remove_props:
  54. props.pop(key, None)
  55. producer.publish(ensure_bytes(body), exchange=exchange,
  56. routing_key=routing_key, compression=compression,
  57. headers=headers, content_type=ctype,
  58. content_encoding=enc, **props)
  59. def migrate_task(producer, body_, message, queues=None):
  60. info = message.delivery_info
  61. queues = {} if queues is None else queues
  62. republish(producer, message,
  63. exchange=queues.get(info['exchange']),
  64. routing_key=queues.get(info['routing_key']))
  65. def filter_callback(callback, tasks):
  66. def filtered(body, message):
  67. if tasks and body['task'] not in tasks:
  68. return
  69. return callback(body, message)
  70. return filtered
  71. def migrate_tasks(source, dest, migrate=migrate_task, app=None,
  72. queues=None, **kwargs):
  73. app = app_or_default(app)
  74. queues = prepare_queues(queues)
  75. producer = app.amqp.Producer(dest)
  76. migrate = partial(migrate, producer, queues=queues)
  77. def on_declare_queue(queue):
  78. new_queue = queue(producer.channel)
  79. new_queue.name = queues.get(queue.name, queue.name)
  80. if new_queue.routing_key == queue.name:
  81. new_queue.routing_key = queues.get(queue.name,
  82. new_queue.routing_key)
  83. if new_queue.exchange.name == queue.name:
  84. new_queue.exchange.name = queues.get(queue.name, queue.name)
  85. new_queue.declare()
  86. return start_filter(app, source, migrate, queues=queues,
  87. on_declare_queue=on_declare_queue, **kwargs)
  88. def _maybe_queue(app, q):
  89. if isinstance(q, string_t):
  90. return app.amqp.queues[q]
  91. return q
  92. def move(predicate, connection=None, exchange=None, routing_key=None,
  93. source=None, app=None, callback=None, limit=None, transform=None,
  94. **kwargs):
  95. """Find tasks by filtering them and move the tasks to a new queue.
  96. Arguments:
  97. predicate (Callable): Filter function used to decide which messages
  98. to move. Must accept the standard signature of ``(body, message)``
  99. used by Kombu consumer callbacks. If the predicate wants the
  100. message to be moved it must return either:
  101. 1) a tuple of ``(exchange, routing_key)``, or
  102. 2) a :class:`~kombu.entity.Queue` instance, or
  103. 3) any other true value which means the specified
  104. ``exchange`` and ``routing_key`` arguments will be used.
  105. connection (kombu.Connection): Custom connection to use.
  106. source: List[Union[str, kombu.Queue]]: Optional list of source
  107. queues to use instead of the default (which is the queues
  108. in :setting:`task_queues`). This list can also contain
  109. :class:`~kombu.entity.Queue` instances.
  110. exchange (str, kombu.Exchange): Default destination exchange.
  111. routing_key (str): Default destination routing key.
  112. limit (int): Limit number of messages to filter.
  113. callback (Callable): Callback called after message moved,
  114. with signature ``(state, body, message)``.
  115. transform (Callable): Optional function to transform the return
  116. value (destination) of the filter function.
  117. Also supports the same keyword arguments as :func:`start_filter`.
  118. To demonstrate, the :func:`move_task_by_id` operation can be implemented
  119. like this:
  120. .. code-block:: python
  121. def is_wanted_task(body, message):
  122. if body['id'] == wanted_id:
  123. return Queue('foo', exchange=Exchange('foo'),
  124. routing_key='foo')
  125. move(is_wanted_task)
  126. or with a transform:
  127. .. code-block:: python
  128. def transform(value):
  129. if isinstance(value, string_t):
  130. return Queue(value, Exchange(value), value)
  131. return value
  132. move(is_wanted_task, transform=transform)
  133. Note:
  134. The predicate may also return a tuple of ``(exchange, routing_key)``
  135. to specify the destination to where the task should be moved,
  136. or a :class:`~kombu.entitiy.Queue` instance.
  137. Any other true value means that the task will be moved to the
  138. default exchange/routing_key.
  139. """
  140. app = app_or_default(app)
  141. queues = [_maybe_queue(app, queue) for queue in source or []] or None
  142. with app.connection_or_acquire(connection, pool=False) as conn:
  143. producer = app.amqp.Producer(conn)
  144. state = State()
  145. def on_task(body, message):
  146. ret = predicate(body, message)
  147. if ret:
  148. if transform:
  149. ret = transform(ret)
  150. if isinstance(ret, Queue):
  151. maybe_declare(ret, conn.default_channel)
  152. ex, rk = ret.exchange.name, ret.routing_key
  153. else:
  154. ex, rk = expand_dest(ret, exchange, routing_key)
  155. republish(producer, message,
  156. exchange=ex, routing_key=rk)
  157. message.ack()
  158. state.filtered += 1
  159. if callback:
  160. callback(state, body, message)
  161. if limit and state.filtered >= limit:
  162. raise StopFiltering()
  163. return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
  164. def expand_dest(ret, exchange, routing_key):
  165. try:
  166. ex, rk = ret
  167. except (TypeError, ValueError):
  168. ex, rk = exchange, routing_key
  169. return ex, rk
  170. def task_id_eq(task_id, body, message):
  171. return body['id'] == task_id
  172. def task_id_in(ids, body, message):
  173. return body['id'] in ids
  174. def prepare_queues(queues):
  175. if isinstance(queues, string_t):
  176. queues = queues.split(',')
  177. if isinstance(queues, list):
  178. queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
  179. for q in queues)
  180. if queues is None:
  181. queues = {}
  182. return queues
  183. def start_filter(app, conn, filter, limit=None, timeout=1.0,
  184. ack_messages=False, tasks=None, queues=None,
  185. callback=None, forever=False, on_declare_queue=None,
  186. consume_from=None, state=None, accept=None, **kwargs):
  187. state = state or State()
  188. queues = prepare_queues(queues)
  189. consume_from = [_maybe_queue(app, q)
  190. for q in consume_from or list(queues)]
  191. if isinstance(tasks, string_t):
  192. tasks = set(tasks.split(','))
  193. if tasks is None:
  194. tasks = set()
  195. def update_state(body, message):
  196. state.count += 1
  197. if limit and state.count >= limit:
  198. raise StopFiltering()
  199. def ack_message(body, message):
  200. message.ack()
  201. consumer = app.amqp.TaskConsumer(conn, queues=consume_from, accept=accept)
  202. if tasks:
  203. filter = filter_callback(filter, tasks)
  204. update_state = filter_callback(update_state, tasks)
  205. ack_message = filter_callback(ack_message, tasks)
  206. consumer.register_callback(filter)
  207. consumer.register_callback(update_state)
  208. if ack_messages:
  209. consumer.register_callback(ack_message)
  210. if callback is not None:
  211. callback = partial(callback, state)
  212. if tasks:
  213. callback = filter_callback(callback, tasks)
  214. consumer.register_callback(callback)
  215. # declare all queues on the new broker.
  216. for queue in consumer.queues:
  217. if queues and queue.name not in queues:
  218. continue
  219. if on_declare_queue is not None:
  220. on_declare_queue(queue)
  221. try:
  222. _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
  223. if mcount:
  224. state.total_apx += mcount
  225. except conn.channel_errors:
  226. pass
  227. # start migrating messages.
  228. with consumer:
  229. try:
  230. for _ in eventloop(conn, # pragma: no cover
  231. timeout=timeout, ignore_timeouts=forever):
  232. pass
  233. except socket.timeout:
  234. pass
  235. except StopFiltering:
  236. pass
  237. return state
  238. def move_task_by_id(task_id, dest, **kwargs):
  239. """Find a task by id and move it to another queue.
  240. Arguments:
  241. task_id (str): Id of task to find and move.
  242. dest: (str, kombu.Queue): Destination queue.
  243. **kwargs (Any): Also supports the same keyword
  244. arguments as :func:`move`.
  245. """
  246. return move_by_idmap({task_id: dest}, **kwargs)
  247. def move_by_idmap(map, **kwargs):
  248. """Moves tasks by matching from a ``task_id: queue`` mapping,
  249. where ``queue`` is a queue to move the task to.
  250. Example:
  251. >>> move_by_idmap({
  252. ... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
  253. ... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
  254. ... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
  255. ... queues=['hipri'])
  256. """
  257. def task_id_in_map(body, message):
  258. return map.get(body['id'])
  259. # adding the limit means that we don't have to consume any more
  260. # when we've found everything.
  261. return move(task_id_in_map, limit=len(map), **kwargs)
  262. def move_by_taskmap(map, **kwargs):
  263. """Moves tasks by matching from a ``task_name: queue`` mapping,
  264. where ``queue`` is the queue to move the task to.
  265. Example:
  266. >>> move_by_taskmap({
  267. ... 'tasks.add': Queue('name'),
  268. ... 'tasks.mul': Queue('name'),
  269. ... })
  270. """
  271. def task_name_in_map(body, message):
  272. return map.get(body['task']) # <- name of task
  273. return move(task_name_in_map, **kwargs)
  274. def filter_status(state, body, message, **kwargs):
  275. print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs))
# Variants of the move helpers with ``transform=worker_direct`` pre-bound,
# so every true value returned by the predicate / found in the mapping is
# passed through ``worker_direct`` before being used as the destination.
# NOTE(review): presumably ``worker_direct`` turns a worker hostname into
# that worker's direct queue -- confirm against celery.utils.nodenames.
# NOTE(review): the last two names are not listed in ``__all__``.
move_direct = partial(move, transform=worker_direct)
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)