migrate.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.contrib.migrate
  4. ~~~~~~~~~~~~~~~~~~~~~~
  5. Migration tools.
  6. """
  7. from __future__ import absolute_import
  8. from __future__ import with_statement
  9. import socket
  10. from functools import partial
  11. from itertools import cycle, islice
  12. from kombu import eventloop
  13. from kombu.exceptions import StdChannelError
  14. from kombu.utils.encoding import ensure_bytes
  15. from celery.app import app_or_default
  16. class State(object):
  17. count = 0
  18. total_apx = 0
  19. @property
  20. def strtotal(self):
  21. if not self.total_apx:
  22. return u'?'
  23. return unicode(self.total_apx)
  24. def __repr__(self):
  25. return '%s/%s' % (self.count, self.strtotal)
  26. def republish(producer, message, exchange=None, routing_key=None,
  27. remove_props=['application_headers',
  28. 'content_type',
  29. 'content_encoding',
  30. 'headers']):
  31. body = ensure_bytes(message.body) # use raw message body.
  32. info, headers, props = (message.delivery_info,
  33. message.headers, message.properties)
  34. exchange = info['exchange'] if exchange is None else exchange
  35. routing_key = info['routing_key'] if routing_key is None else routing_key
  36. ctype, enc = message.content_type, message.content_encoding
  37. # remove compression header, as this will be inserted again
  38. # when the message is recompressed.
  39. compression = headers.pop('compression', None)
  40. for key in remove_props:
  41. props.pop(key, None)
  42. producer.publish(ensure_bytes(body), exchange=exchange,
  43. routing_key=routing_key, compression=compression,
  44. headers=headers, content_type=ctype,
  45. content_encoding=enc, **props)
  46. def migrate_task(producer, body_, message, queues=None):
  47. info = message.delivery_info
  48. queues = {} if queues is None else queues
  49. republish(producer, message,
  50. exchange=queues.get(info['exchange']),
  51. routing_key=queues.get(info['routing_key']))
  52. def filter_callback(callback, tasks):
  53. def filtered(body, message):
  54. if tasks and message.payload['task'] not in tasks:
  55. return
  56. return callback(body, message)
  57. return filtered
  58. def migrate_tasks(source, dest, migrate=migrate_task, app=None,
  59. queues=None, **kwargs):
  60. app = app_or_default(app)
  61. queues = prepare_queues(queues)
  62. producer = app.amqp.TaskProducer(dest)
  63. migrate = partial(migrate, producer, queues=queues)
  64. def on_declare_queue(queue):
  65. new_queue = queue(producer.channel)
  66. new_queue.name = queues.get(queue.name, queue.name)
  67. if new_queue.routing_key == queue.name:
  68. new_queue.routing_key = queues.get(queue.name,
  69. new_queue.routing_key)
  70. if new_queue.exchange.name == queue.name:
  71. new_queue.exchange.name = queues.get(queue.name, queue.name)
  72. new_queue.declare()
  73. return start_filter(app, source, migrate, queues=queues,
  74. on_declare_queue=on_declare_queue, **kwargs)
  75. def move_tasks(conn, predicate, exchange, routing_key, app=None, **kwargs):
  76. app = app_or_default(app)
  77. producer = app.amqp.TaskProducer(conn)
  78. def on_task(body, message):
  79. if predicate(body, message):
  80. republish(producer, message,
  81. exchange=exchange, routing_key=routing_key)
  82. message.ack()
  83. return start_filter(app, conn, on_task, **kwargs)
  84. def move_task_by_id(conn, task_id, exchange, routing_key, **kwargs):
  85. def predicate(body, message):
  86. if body['id'] == task_id:
  87. return True
  88. return move_tasks(conn, predicate, exchange, routing_key, **kwargs)
  89. def prepare_queues(queues):
  90. if isinstance(queues, basestring):
  91. queues = queues.split(',')
  92. if isinstance(queues, list):
  93. queues = dict(tuple(islice(cycle(q.split(':')), None, 2))
  94. for q in queues)
  95. if queues is None:
  96. queues = {}
  97. def start_filter(app, conn, filter, limit=None, timeout=1.0,
  98. ack_messages=False, migrate=migrate_task, tasks=None, queues=None,
  99. callback=None, forever=False, on_declare_queue=None, **kwargs):
  100. state = State()
  101. queues = prepare_queues(queues)
  102. if isinstance(tasks, basestring):
  103. tasks = set(tasks.split(','))
  104. if tasks is None:
  105. tasks = set([])
  106. def update_state(body, message):
  107. state.count += 1
  108. def ack_message(body, message):
  109. message.ack()
  110. consumer = app.amqp.TaskConsumer(conn)
  111. if tasks:
  112. filter = filter_callback(filter, tasks)
  113. update_state = filter_callback(update_state, tasks)
  114. ack_message = filter_callback(ack_message, tasks)
  115. consumer.register_callback(filter)
  116. consumer.register_callback(update_state)
  117. if ack_messages:
  118. consumer.register_callback(ack_message)
  119. if callback is not None:
  120. callback = partial(callback, state)
  121. if tasks:
  122. callback = filter_callback(callback, tasks)
  123. consumer.register_callback(callback)
  124. # declare all queues on the new broker.
  125. for queue in consumer.queues:
  126. if queues and queue.name not in queues:
  127. continue
  128. if on_declare_queue is not None:
  129. on_declare_queue(queue)
  130. try:
  131. _, mcount, _ = queue(consumer.channel).queue_declare(passive=True)
  132. if mcount:
  133. state.total_apx += mcount
  134. except conn.channel_errors + (StdChannelError, ):
  135. pass
  136. # start migrating messages.
  137. with consumer:
  138. try:
  139. for _ in eventloop(conn, limit=limit, # pragma: no cover
  140. timeout=timeout, ignore_timeouts=forever):
  141. pass
  142. except socket.timeout:
  143. pass
  144. return state