migrate.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.contrib.migrate
  4. ~~~~~~~~~~~~~~~~~~~~~~
  5. Migration tools.
  6. """
  7. from __future__ import absolute_import, print_function, unicode_literals
  8. import socket
  9. from functools import partial
  10. from itertools import cycle, islice
  11. from kombu import eventloop, Queue
  12. from kombu.common import maybe_declare
  13. from kombu.exceptions import StdChannelError
  14. from kombu.utils.encoding import ensure_bytes
  15. from celery.app import app_or_default
  16. from celery.five import string, string_t
  17. from celery.utils import worker_direct
  18. MOVING_PROGRESS_FMT = """\
  19. Moving task {state.filtered}/{state.strtotal}: \
  20. {body[task]}[{body[id]}]\
  21. """
  22. class StopFiltering(Exception):
  23. pass
  24. class State(object):
  25. count = 0
  26. filtered = 0
  27. total_apx = 0
  28. @property
  29. def strtotal(self):
  30. if not self.total_apx:
  31. return '?'
  32. return string(self.total_apx)
  33. def __repr__(self):
  34. if self.filtered:
  35. return '^{0.filtered}'.format(self)
  36. return '{0.count}/{0.strtotal}'.format(self)
  37. def republish(producer, message, exchange=None, routing_key=None,
  38. remove_props=['application_headers',
  39. 'content_type',
  40. 'content_encoding',
  41. 'headers']):
  42. body = ensure_bytes(message.body) # use raw message body.
  43. info, headers, props = (message.delivery_info,
  44. message.headers, message.properties)
  45. exchange = info['exchange'] if exchange is None else exchange
  46. routing_key = info['routing_key'] if routing_key is None else routing_key
  47. ctype, enc = message.content_type, message.content_encoding
  48. # remove compression header, as this will be inserted again
  49. # when the message is recompressed.
  50. compression = headers.pop('compression', None)
  51. for key in remove_props:
  52. props.pop(key, None)
  53. producer.publish(ensure_bytes(body), exchange=exchange,
  54. routing_key=routing_key, compression=compression,
  55. headers=headers, content_type=ctype,
  56. content_encoding=enc, **props)
  57. def migrate_task(producer, body_, message, queues=None):
  58. info = message.delivery_info
  59. queues = {} if queues is None else queues
  60. republish(producer, message,
  61. exchange=queues.get(info['exchange']),
  62. routing_key=queues.get(info['routing_key']))
  63. def filter_callback(callback, tasks):
  64. def filtered(body, message):
  65. if tasks and body['task'] not in tasks:
  66. return
  67. return callback(body, message)
  68. return filtered
  69. def migrate_tasks(source, dest, migrate=migrate_task, app=None,
  70. queues=None, **kwargs):
  71. app = app_or_default(app)
  72. queues = prepare_queues(queues)
  73. producer = app.amqp.TaskProducer(dest)
  74. migrate = partial(migrate, producer, queues=queues)
  75. def on_declare_queue(queue):
  76. new_queue = queue(producer.channel)
  77. new_queue.name = queues.get(queue.name, queue.name)
  78. if new_queue.routing_key == queue.name:
  79. new_queue.routing_key = queues.get(queue.name,
  80. new_queue.routing_key)
  81. if new_queue.exchange.name == queue.name:
  82. new_queue.exchange.name = queues.get(queue.name, queue.name)
  83. new_queue.declare()
  84. return start_filter(app, source, migrate, queues=queues,
  85. on_declare_queue=on_declare_queue, **kwargs)
  86. def _maybe_queue(app, q):
  87. if isinstance(q, string_t):
  88. return app.amqp.queues[q]
  89. return q
  90. def move(predicate, connection=None, exchange=None, routing_key=None,
  91. source=None, app=None, callback=None, limit=None, transform=None,
  92. **kwargs):
  93. """Find tasks by filtering them and move the tasks to a new queue.
  94. :param predicate: Filter function used to decide which messages
  95. to move. Must accept the standard signature of ``(body, message)``
  96. used by Kombu consumer callbacks. If the predicate wants the message
  97. to be moved it must return either:
  98. 1) a tuple of ``(exchange, routing_key)``, or
  99. 2) a :class:`~kombu.entity.Queue` instance, or
  100. 3) any other true value which means the specified
  101. ``exchange`` and ``routing_key`` arguments will be used.
  102. :keyword connection: Custom connection to use.
  103. :keyword source: Optional list of source queues to use instead of the
  104. default (which is the queues in :setting:`CELERY_QUEUES`).
  105. This list can also contain new :class:`~kombu.entity.Queue` instances.
  106. :keyword exchange: Default destination exchange.
  107. :keyword routing_key: Default destination routing key.
  108. :keyword limit: Limit number of messages to filter.
  109. :keyword callback: Callback called after message moved,
  110. with signature ``(state, body, message)``.
  111. :keyword transform: Optional function to transform the return
  112. value (destination) of the filter function.
  113. Also supports the same keyword arguments as :func:`start_filter`.
  114. To demonstrate, the :func:`move_task_by_id` operation can be implemented
  115. like this:
  116. .. code-block:: python
  117. def is_wanted_task(body, message):
  118. if body['id'] == wanted_id:
  119. return Queue('foo', exchange=Exchange('foo'),
  120. routing_key='foo')
  121. move(is_wanted_task)
  122. or with a transform:
  123. .. code-block:: python
  124. def transform(value):
  125. if isinstance(value, string_t):
  126. return Queue(value, Exchange(value), value)
  127. return value
  128. move(is_wanted_task, transform=transform)
  129. The predicate may also return a tuple of ``(exchange, routing_key)``
  130. to specify the destination to where the task should be moved,
  131. or a :class:`~kombu.entitiy.Queue` instance.
  132. Any other true value means that the task will be moved to the
  133. default exchange/routing_key.
  134. """
  135. app = app_or_default(app)
  136. queues = [_maybe_queue(app, queue) for queue in source or []] or None
  137. with app.connection_or_acquire(connection, pool=False) as conn:
  138. producer = app.amqp.TaskProducer(conn)
  139. state = State()
  140. def on_task(body, message):
  141. ret = predicate(body, message)
  142. if ret:
  143. if transform:
  144. ret = transform(ret)
  145. if isinstance(ret, Queue):
  146. maybe_declare(ret, conn.default_channel)
  147. ex, rk = ret.exchange.name, ret.routing_key
  148. else:
  149. ex, rk = expand_dest(ret, exchange, routing_key)
  150. republish(producer, message,
  151. exchange=ex, routing_key=rk)
  152. message.ack()
  153. state.filtered += 1
  154. if callback:
  155. callback(state, body, message)
  156. if limit and state.filtered >= limit:
  157. raise StopFiltering()
  158. return start_filter(app, conn, on_task, consume_from=queues, **kwargs)
  159. def expand_dest(ret, exchange, routing_key):
  160. try:
  161. ex, rk = ret
  162. except (TypeError, ValueError):
  163. ex, rk = exchange, routing_key
  164. return ex, rk
  165. def task_id_eq(task_id, body, message):
  166. return body['id'] == task_id
  167. def task_id_in(ids, body, message):
  168. return body['id'] in ids
  169. def prepare_queues(queues):
  170. if isinstance(queues, string_t):
  171. queues = queues.split(',')
  172. if isinstance(queues, list):
  173. queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
  174. for q in queues)
  175. if queues is None:
  176. queues = {}
  177. return queues
  178. def start_filter(app, conn, filter, limit=None, timeout=1.0,
  179. ack_messages=False, tasks=None, queues=None,
  180. callback=None, forever=False, on_declare_queue=None,
  181. consume_from=None, state=None, accept=None, **kwargs):
  182. state = state or State()
  183. queues = prepare_queues(queues)
  184. consume_from = [_maybe_queue(app, q)
  185. for q in consume_from or queues.keys()]
  186. if isinstance(tasks, string_t):
  187. tasks = set(tasks.split(','))
  188. if tasks is None:
  189. tasks = set([])
  190. def update_state(body, message):
  191. state.count += 1
  192. if limit and state.count >= limit:
  193. raise StopFiltering()
  194. def ack_message(body, message):
  195. message.ack()
  196. consumer = app.amqp.TaskConsumer(conn, queues=consume_from, accept=accept)
  197. if tasks:
  198. filter = filter_callback(filter, tasks)
  199. update_state = filter_callback(update_state, tasks)
  200. ack_message = filter_callback(ack_message, tasks)
  201. consumer.register_callback(filter)
  202. consumer.register_callback(update_state)
  203. if ack_messages:
  204. consumer.register_callback(ack_message)
  205. if callback is not None:
  206. callback = partial(callback, state)
  207. if tasks:
  208. callback = filter_callback(callback, tasks)
  209. consumer.register_callback(callback)
  210. # declare all queues on the new broker.
  211. for queue in consumer.queues:
  212. if queues and queue.name not in queues:
  213. continue
  214. if on_declare_queue is not None:
  215. on_declare_queue(queue)
  216. try:
  217. _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
  218. if mcount:
  219. state.total_apx += mcount
  220. except conn.channel_errors + (StdChannelError, ):
  221. pass
  222. # start migrating messages.
  223. with consumer:
  224. try:
  225. for _ in eventloop(conn, # pragma: no cover
  226. timeout=timeout, ignore_timeouts=forever):
  227. pass
  228. except socket.timeout:
  229. pass
  230. except StopFiltering:
  231. pass
  232. return state
  233. def move_task_by_id(task_id, dest, **kwargs):
  234. """Find a task by id and move it to another queue.
  235. :param task_id: Id of task to move.
  236. :param dest: Destination queue.
  237. Also supports the same keyword arguments as :func:`move`.
  238. """
  239. return move_by_idmap({task_id: dest}, **kwargs)
  240. def move_by_idmap(map, **kwargs):
  241. """Moves tasks by matching from a ``task_id: queue`` mapping,
  242. where ``queue`` is a queue to move the task to.
  243. Example::
  244. >>> reroute_idmap({
  245. ... '5bee6e82-f4ac-468e-bd3d-13e8600250bc': Queue(...),
  246. ... 'ada8652d-aef3-466b-abd2-becdaf1b82b3': Queue(...),
  247. ... '3a2b140d-7db1-41ba-ac90-c36a0ef4ab1f': Queue(...)},
  248. ... queues=['hipri'])
  249. """
  250. def task_id_in_map(body, message):
  251. return map.get(body['id'])
  252. # adding the limit means that we don't have to consume any more
  253. # when we've found everything.
  254. return move(task_id_in_map, limit=len(map), **kwargs)
  255. def move_by_taskmap(map, **kwargs):
  256. """Moves tasks by matching from a ``task_name: queue`` mapping,
  257. where ``queue`` is the queue to move the task to.
  258. Example::
  259. >>> reroute_idmap({
  260. ... 'tasks.add': Queue(...),
  261. ... 'tasks.mul': Queue(...),
  262. ... })
  263. """
  264. def task_name_in_map(body, message):
  265. return map.get(body['task']) # <- name of task
  266. return move(task_name_in_map, **kwargs)
  267. def filter_status(state, body, message, **kwargs):
  268. print(MOVING_PROGRESS_FMT.format(state=state, body=body, **kwargs))
  269. move_direct = partial(move, transform=worker_direct)
  270. move_direct_by_id = partial(move_task_by_id, transform=worker_direct)
  271. move_direct_by_idmap = partial(move_by_idmap, transform=worker_direct)
  272. move_direct_by_taskmap = partial(move_by_taskmap, transform=worker_direct)