# migrate.py (14 KB)
  1. # -*- coding: utf-8 -*-
  2. """Message migration tools (Broker <-> Broker)."""
  3. from __future__ import absolute_import, print_function, unicode_literals
  4. import socket
  5. from functools import partial
  6. from itertools import cycle, islice
  7. from kombu import Queue, eventloop
  8. from kombu.common import maybe_declare
  9. from kombu.utils.encoding import ensure_bytes
  10. from celery.app import app_or_default
  11. from celery.five import python_2_unicode_compatible, string, string_t
  12. from celery.utils.nodenames import worker_direct
  13. from celery.utils.text import str_to_list
# Public names exported by ``from <this module> import *``.
# NOTE(review): ``move_direct_by_idmap`` and ``move_direct_by_taskmap`` are
# defined at module level in this file but are not listed here -- confirm
# whether that omission is intentional.
__all__ = (
    'StopFiltering', 'State', 'republish', 'migrate_task',
    'migrate_tasks', 'move', 'task_id_eq', 'task_id_in',
    'start_filter', 'move_task_by_id', 'move_by_idmap',
    'move_by_taskmap', 'move_direct', 'move_direct_by_id',
)
# Progress line template used by ``filter_status``.  It is formatted with
# ``state`` (reads ``state.filtered`` and ``state.strtotal``) and ``body``
# (reads ``body['task']`` and ``body['id']``).  The backslash at the end of
# each source line is a line continuation inside the literal, so the result
# is a single line with no embedded newlines.
MOVING_PROGRESS_FMT = """\
Moving task {state.filtered}/{state.strtotal}: \
{body[task]}[{body[id]}]\
"""
  24. class StopFiltering(Exception):
  25. """Semi-predicate used to signal filter stop."""
  26. @python_2_unicode_compatible
  27. class State(object):
  28. """Migration progress state."""
  29. count = 0
  30. filtered = 0
  31. total_apx = 0
  32. @property
  33. def strtotal(self):
  34. if not self.total_apx:
  35. return '?'
  36. return string(self.total_apx)
  37. def __repr__(self):
  38. if self.filtered:
  39. return '^{0.filtered}'.format(self)
  40. return '{0.count}/{0.strtotal}'.format(self)
  41. def republish(producer, message, exchange=None, routing_key=None,
  42. remove_props=['application_headers',
  43. 'content_type',
  44. 'content_encoding',
  45. 'headers']):
  46. """Republish message."""
  47. body = ensure_bytes(message.body) # use raw message body.
  48. info, headers, props = (message.delivery_info,
  49. message.headers, message.properties)
  50. exchange = info['exchange'] if exchange is None else exchange
  51. routing_key = info['routing_key'] if routing_key is None else routing_key
  52. ctype, enc = message.content_type, message.content_encoding
  53. # remove compression header, as this will be inserted again
  54. # when the message is recompressed.
  55. compression = headers.pop('compression', None)
  56. for key in remove_props:
  57. props.pop(key, None)
  58. producer.publish(ensure_bytes(body), exchange=exchange,
  59. routing_key=routing_key, compression=compression,
  60. headers=headers, content_type=ctype,
  61. content_encoding=enc, **props)
  62. def migrate_task(producer, body_, message, queues=None):
  63. """Migrate single task message."""
  64. info = message.delivery_info
  65. queues = {} if queues is None else queues
  66. republish(producer, message,
  67. exchange=queues.get(info['exchange']),
  68. routing_key=queues.get(info['routing_key']))
  69. def filter_callback(callback, tasks):
  70. def filtered(body, message):
  71. if tasks and body['task'] not in tasks:
  72. return
  73. return callback(body, message)
  74. return filtered
  75. def migrate_tasks(source, dest, migrate=migrate_task, app=None,
  76. queues=None, **kwargs):
  77. """Migrate tasks from one broker to another."""
  78. app = app_or_default(app)
  79. queues = prepare_queues(queues)
  80. producer = app.amqp.Producer(dest, auto_declare=False)
  81. migrate = partial(migrate, producer, queues=queues)
  82. def on_declare_queue(queue):
  83. new_queue = queue(producer.channel)
  84. new_queue.name = queues.get(queue.name, queue.name)
  85. if new_queue.routing_key == queue.name:
  86. new_queue.routing_key = queues.get(queue.name,
  87. new_queue.routing_key)
  88. if new_queue.exchange.name == queue.name:
  89. new_queue.exchange.name = queues.get(queue.name, queue.name)
  90. new_queue.declare()
  91. return start_filter(app, source, migrate, queues=queues,
  92. on_declare_queue=on_declare_queue, **kwargs)
  93. def _maybe_queue(app, q):
  94. if isinstance(q, string_t):
  95. return app.amqp.queues[q]
  96. return q
  97. def move(predicate, connection=None, exchange=None, routing_key=None,
  98. source=None, app=None, callback=None, limit=None, transform=None,
  99. **kwargs):
  100. """Find tasks by filtering them and move the tasks to a new queue.
  101. Arguments:
  102. predicate (Callable): Filter function used to decide the messages
  103. to move. Must accept the standard signature of ``(body, message)``
  104. used by Kombu consumer callbacks. If the predicate wants the
  105. message to be moved it must return either:
  106. 1) a tuple of ``(exchange, routing_key)``, or
  107. 2) a :class:`~kombu.entity.Queue` instance, or
  108. 3) any other true value means the specified
  109. ``exchange`` and ``routing_key`` arguments will be used.
  110. connection (kombu.Connection): Custom connection to use.
  111. source: List[Union[str, kombu.Queue]]: Optional list of source
  112. queues to use instead of the default (queues
  113. in :setting:`task_queues`). This list can also contain
  114. :class:`~kombu.entity.Queue` instances.
  115. exchange (str, kombu.Exchange): Default destination exchange.
  116. routing_key (str): Default destination routing key.
  117. limit (int): Limit number of messages to filter.
  118. callback (Callable): Callback called after message moved,
  119. with signature ``(state, body, message)``.
  120. transform (Callable): Optional function to transform the return
  121. value (destination) of the filter function.
  122. Also supports the same keyword arguments as :func:`start_filter`.
  123. To demonstrate, the :func:`move_task_by_id` operation can be implemented
  124. like this:
  125. .. code-block:: python
  126. def is_wanted_task(body, message):
  127. if body['id'] == wanted_id:
  128. return Queue('foo', exchange=Exchange('foo'),
  129. routing_key='foo')
  130. move(is_wanted_task)
  131. or with a transform:
  132. .. code-block:: python
  133. def transform(value):
  134. if isinstance(value, string_t):
  135. return Queue(value, Exchange(value), value)
  136. return value
  137. move(is_wanted_task, transform=transform)
  138. Note:
  139. The predicate may also return a tuple of ``(exchange, routing_key)``
  140. to specify the destination to where the task should be moved,
  141. or a :class:`~kombu.entitiy.Queue` instance.
  142. Any other true value means that the task will be moved to the
  143. default exchange/routing_key.
  144. """
  145. app = app_or_default(app)
  146. queues = [_maybe_queue(app, queue) for queue in source or []] or None
  147. with app.connection_or_acquire(connection, pool=False) as conn:
  148. producer = app.amqp.Producer(conn)
  149. state = State()
  150. def on_task(body, message):
  151. ret = predicate(body, message)
  152. if ret:
  153. if transform:
  154. ret = transform(ret)
  155. if isinstance(ret, Queue):
  156. maybe_declare(ret, conn.default_channel)
  157. ex, rk = ret.exchange.name, ret.routing_key
  158. else:
  159. ex, rk = expand_dest(ret, exchange, routing_key)
  160. republish(producer, message,
  161. exchange=ex, routing_key=rk)
  162. message.ack()
  163. state.filtered += 1
  164. if callback:
  165. callback(state, body, message)
  166. if limit and state.filtered >= limit:
  167. raise StopFiltering()
  168. return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
  169. def expand_dest(ret, exchange, routing_key):
  170. try:
  171. ex, rk = ret
  172. except (TypeError, ValueError):
  173. ex, rk = exchange, routing_key
  174. return ex, rk
  175. def task_id_eq(task_id, body, message):
  176. """Return true if task id equals task_id'."""
  177. return body['id'] == task_id
  178. def task_id_in(ids, body, message):
  179. """Return true if task id is member of set ids'."""
  180. return body['id'] in ids
  181. def prepare_queues(queues):
  182. if isinstance(queues, string_t):
  183. queues = queues.split(',')
  184. if isinstance(queues, list):
  185. queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
  186. for q in queues)
  187. if queues is None:
  188. queues = {}
  189. return queues
  190. class Filterer(object):
  191. def __init__(self, app, conn, filter,
  192. limit=None, timeout=1.0,
  193. ack_messages=False, tasks=None, queues=None,
  194. callback=None, forever=False, on_declare_queue=None,
  195. consume_from=None, state=None, accept=None, **kwargs):
  196. self.app = app
  197. self.conn = conn
  198. self.filter = filter
  199. self.limit = limit
  200. self.timeout = timeout
  201. self.ack_messages = ack_messages
  202. self.tasks = set(str_to_list(tasks) or [])
  203. self.queues = prepare_queues(queues)
  204. self.callback = callback
  205. self.forever = forever
  206. self.on_declare_queue = on_declare_queue
  207. self.consume_from = [
  208. _maybe_queue(self.app, q)
  209. for q in consume_from or list(self.queues)
  210. ]
  211. self.state = state or State()
  212. self.accept = accept
  213. def start(self):
  214. # start migrating messages.
  215. with self.prepare_consumer(self.create_consumer()):
  216. try:
  217. for _ in eventloop(self.conn, # pragma: no cover
  218. timeout=self.timeout,
  219. ignore_timeouts=self.forever):
  220. pass
  221. except socket.timeout:
  222. pass
  223. except StopFiltering:
  224. pass
  225. return self.state
  226. def update_state(self, body, message):
  227. self.state.count += 1
  228. if self.limit and self.state.count >= self.limit:
  229. raise StopFiltering()
  230. def ack_message(self, body, message):
  231. message.ack()
  232. def create_consumer(self):
  233. return self.app.amqp.TaskConsumer(
  234. self.conn,
  235. queues=self.consume_from,
  236. accept=self.accept,
  237. )
  238. def prepare_consumer(self, consumer):
  239. filter = self.filter
  240. update_state = self.update_state
  241. ack_message = self.ack_message
  242. if self.tasks:
  243. filter = filter_callback(filter, self.tasks)
  244. update_state = filter_callback(update_state, self.tasks)
  245. ack_message = filter_callback(ack_message, self.tasks)
  246. consumer.register_callback(filter)
  247. consumer.register_callback(update_state)
  248. if self.ack_messages:
  249. consumer.register_callback(self.ack_message)
  250. if self.callback is not None:
  251. callback = partial(self.callback, self.state)
  252. if self.tasks:
  253. callback = filter_callback(callback, self.tasks)
  254. consumer.register_callback(callback)
  255. self.declare_queues(consumer)
  256. return consumer
  257. def declare_queues(self, consumer):
  258. # declare all queues on the new broker.
  259. for queue in consumer.queues:
  260. if self.queues and queue.name not in self.queues:
  261. continue
  262. if self.on_declare_queue is not None:
  263. self.on_declare_queue(queue)
  264. try:
  265. _, mcount, _ = queue(
  266. consumer.channel).queue_declare(passive=True)
  267. if mcount:
  268. self.state.total_apx += mcount
  269. except self.conn.channel_errors:
  270. pass
  271. def start_filter(app, conn, filter, limit=None, timeout=1.0,
  272. ack_messages=False, tasks=None, queues=None,
  273. callback=None, forever=False, on_declare_queue=None,
  274. consume_from=None, state=None, accept=None, **kwargs):
  275. """Filter tasks."""
  276. return Filterer(
  277. app, conn, filter,
  278. limit=limit,
  279. timeout=timeout,
  280. ack_messages=ack_messages,
  281. tasks=tasks,
  282. queues=queues,
  283. callback=callback,
  284. forever=forever,
  285. on_declare_queue=on_declare_queue,
  286. consume_from=consume_from,
  287. state=state,
  288. accept=accept,
  289. **kwargs).start()
  290. def move_task_by_id(task_id, dest, **kwargs):
  291. """Find a task by id and move it to another queue.
  292. Arguments:
  293. task_id (str): Id of task to find and move.
  294. dest: (str, kombu.Queue): Destination queue.
  295. **kwargs (Any): Also supports the same keyword
  296. arguments as :func:`move`.
  297. """
  298. return move_by_idmap({task_id: dest}, **kwargs)
  299. def move_by_idmap(map, **kwargs):
  300. """Move tasks by matching from a ``task_id: queue`` mapping.
  301. Where ``queue`` is a queue to move the task to.
  302. Example:
  303. >>> move_by_idmap({
  304. ... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue('name'),
  305. ... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue('name'),
  306. ... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue('name')},
  307. ... queues=['hipri'])
  308. """
  309. def task_id_in_map(body, message):
  310. return map.get(body['id'])
  311. # adding the limit means that we don't have to consume any more
  312. # when we've found everything.
  313. return move(task_id_in_map, limit=len(map), **kwargs)
  314. def move_by_taskmap(map, **kwargs):
  315. """Move tasks by matching from a ``task_name: queue`` mapping.
  316. ``queue`` is the queue to move the task to.
  317. Example:
  318. >>> move_by_taskmap({
  319. ... 'tasks.add': Queue('name'),
  320. ... 'tasks.mul': Queue('name'),
  321. ... })
  322. """
  323. def task_name_in_map(body, message):
  324. return map.get(body['task']) # <- name of task
  325. return move(task_name_in_map, **kwargs)
  326. def filter_status(state, body, message, **kwargs):
  327. print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs))
# Variants of the move functions with ``transform`` preset to
# ``worker_direct``, so every destination returned by the predicate is
# passed through ``worker_direct`` before the message is republished.
move_direct = partial(move, transform=worker_direct)
move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)