amqp.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. from datetime import datetime, timedelta
  2. from UserDict import UserDict
  3. from carrot.connection import BrokerConnection
  4. from carrot import messaging
  5. from celery import routes
  6. from celery import signals
  7. from celery.utils import gen_unique_id, mitemgetter, textindent
  8. MSG_OPTIONS = ("mandatory", "priority", "immediate",
  9. "routing_key", "serializer", "delivery_mode")
  10. QUEUE_FORMAT = """
  11. . %(name)s -> exchange:%(exchange)s (%(exchange_type)s) \
  12. binding:%(binding_key)s
  13. """
  14. BROKER_FORMAT = "%(carrot_backend)s://%(userid)s@%(host)s%(port)s%(vhost)s"
  15. get_msg_options = mitemgetter(*MSG_OPTIONS)
  16. extract_msg_options = lambda d: dict(zip(MSG_OPTIONS, get_msg_options(d)))
  17. _queues_declared = False
  18. _exchanges_declared = set()
  19. class Queues(UserDict):
  20. def __init__(self, queues):
  21. self.data = {}
  22. for queue_name, options in (queues or {}).items():
  23. self.add(queue_name, **options)
  24. def add(self, queue, exchange=None, routing_key=None,
  25. exchange_type="direct", **options):
  26. q = self[queue] = self.options(exchange, routing_key,
  27. exchange_type, **options)
  28. return q
  29. def options(self, exchange, routing_key,
  30. exchange_type="direct", **options):
  31. return dict(options, routing_key=routing_key,
  32. binding_key=routing_key,
  33. exchange=exchange,
  34. exchange_type=exchange_type)
  35. def format(self, indent=0):
  36. """Format routing table into string for log dumps."""
  37. format = lambda **queue: QUEUE_FORMAT.strip() % queue
  38. info = "\n".join(format(name=name, **config)
  39. for name, config in self.items())
  40. return textindent(info, indent=indent)
  41. def select_subset(self, wanted, create_missing=True):
  42. acc = {}
  43. for queue in wanted:
  44. try:
  45. options = self[queue]
  46. except KeyError:
  47. if not create_missing:
  48. raise
  49. options = self.options(queue, queue)
  50. acc[queue] = options
  51. self.data.clear()
  52. self.data.update(acc)
  53. @classmethod
  54. def with_defaults(cls, queues, default_exchange, default_exchange_type):
  55. def _defaults(opts):
  56. opts.setdefault("exchange", default_exchange),
  57. opts.setdefault("exchange_type", default_exchange_type)
  58. opts.setdefault("binding_key", default_exchange)
  59. opts.setdefault("routing_key", opts.get("binding_key"))
  60. return opts
  61. map(_defaults, queues.values())
  62. return cls(queues)
  63. class TaskPublisher(messaging.Publisher):
  64. auto_declare = False
  65. def declare(self):
  66. if self.exchange not in _exchanges_declared:
  67. super(TaskPublisher, self).declare()
  68. _exchanges_declared.add(self.exchange)
  69. def delay_task(self, task_name, task_args=None, task_kwargs=None,
  70. countdown=None, eta=None, task_id=None, taskset_id=None,
  71. expires=None, exchange=None, exchange_type=None, **kwargs):
  72. """Delay task for execution by the celery nodes."""
  73. task_id = task_id or gen_unique_id()
  74. task_args = task_args or []
  75. task_kwargs = task_kwargs or {}
  76. now = None
  77. if countdown: # Convert countdown to ETA.
  78. now = datetime.now()
  79. eta = now + timedelta(seconds=countdown)
  80. if not isinstance(task_args, (list, tuple)):
  81. raise ValueError("task args must be a list or tuple")
  82. if not isinstance(task_kwargs, dict):
  83. raise ValueError("task kwargs must be a dictionary")
  84. if isinstance(expires, int):
  85. now = now or datetime.now()
  86. expires = now + timedelta(seconds=expires)
  87. message_data = {
  88. "task": task_name,
  89. "id": task_id,
  90. "args": task_args or [],
  91. "kwargs": task_kwargs or {},
  92. "retries": kwargs.get("retries", 0),
  93. "eta": eta and eta.isoformat(),
  94. "expires": expires and expires.isoformat(),
  95. }
  96. if taskset_id:
  97. message_data["taskset"] = taskset_id
  98. # custom exchange passed, need to declare it.
  99. if exchange and exchange not in _exchanges_declared:
  100. exchange_type = exchange_type or self.exchange_type
  101. self.backend.exchange_declare(exchange=exchange,
  102. exchange_type=exchange_type,
  103. durable=self.durable,
  104. auto_delete=self.auto_delete)
  105. self.send(message_data, exchange=exchange,
  106. **extract_msg_options(kwargs))
  107. signals.task_sent.send(sender=task_name, **message_data)
  108. return task_id
  109. class ConsumerSet(messaging.ConsumerSet):
  110. """ConsumerSet with an optional decode error callback.
  111. For more information see :class:`carrot.messaging.ConsumerSet`.
  112. .. attribute:: on_decode_error
  113. Callback called if a message had decoding errors.
  114. The callback is called with the signature::
  115. callback(message, exception)
  116. """
  117. on_decode_error = None
  118. def _receive_callback(self, raw_message):
  119. message = self.backend.message_to_python(raw_message)
  120. if self.auto_ack and not message.acknowledged:
  121. message.ack()
  122. try:
  123. decoded = message.decode()
  124. except Exception, exc:
  125. if self.on_decode_error:
  126. return self.on_decode_error(message, exc)
  127. else:
  128. raise
  129. self.receive(decoded, message)
  130. class AMQP(object):
  131. BrokerConnection = BrokerConnection
  132. Publisher = messaging.Publisher
  133. Consumer = messaging.Consumer
  134. ConsumerSet = ConsumerSet
  135. _queues = None
  136. def __init__(self, app):
  137. self.app = app
  138. def Queues(self, queues):
  139. return Queues.with_defaults(queues,
  140. self.app.conf.CELERY_DEFAULT_EXCHANGE,
  141. self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
  142. def Router(self, queues=None, create_missing=None):
  143. return routes.Router(self.app.conf.CELERY_ROUTES,
  144. queues or self.app.conf.CELERY_QUEUES,
  145. self.app.either("CELERY_CREATE_MISSING_QUEUES",
  146. create_missing),
  147. app=self.app)
  148. def TaskConsumer(self, *args, **kwargs):
  149. default_queue_name, default_queue = self.get_default_queue()
  150. defaults = dict({"queue": default_queue_name}, **default_queue)
  151. defaults["routing_key"] = defaults.pop("binding_key", None)
  152. return self.Consumer(*args,
  153. **self.app.merge(defaults, kwargs))
  154. def TaskPublisher(self, *args, **kwargs):
  155. _, default_queue = self.get_default_queue()
  156. defaults = {"exchange": default_queue["exchange"],
  157. "exchange_type": default_queue["exchange_type"],
  158. "routing_key": self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
  159. "serializer": self.app.conf.CELERY_TASK_SERIALIZER}
  160. publisher = TaskPublisher(*args,
  161. **self.app.merge(defaults, kwargs))
  162. # Make sure all queues are declared.
  163. global _queues_declared
  164. if not _queues_declared:
  165. consumers = self.get_consumer_set(publisher.connection)
  166. consumers.close()
  167. _queues_declared = True
  168. publisher.declare()
  169. return publisher
  170. def get_consumer_set(self, connection, queues=None, **options):
  171. queues = queues or self.queues
  172. cset = self.ConsumerSet(connection)
  173. for queue_name, queue_options in queues.items():
  174. queue_options = dict(queue_options)
  175. queue_options["routing_key"] = queue_options.pop("binding_key",
  176. None)
  177. consumer = self.Consumer(connection, queue=queue_name,
  178. backend=cset.backend, **queue_options)
  179. cset.consumers.append(consumer)
  180. return cset
  181. def get_default_queue(self):
  182. q = self.app.conf.CELERY_DEFAULT_QUEUE
  183. return q, self.queues[q]
  184. def get_broker_info(self):
  185. broker_connection = self.app.broker_connection()
  186. carrot_backend = broker_connection.backend_cls
  187. if carrot_backend and not isinstance(carrot_backend, str):
  188. carrot_backend = carrot_backend.__name__
  189. carrot_backend = carrot_backend or "amqp"
  190. port = broker_connection.port or \
  191. broker_connection.get_backend_cls().default_port
  192. port = port and ":%s" % port or ""
  193. vhost = broker_connection.virtual_host
  194. if not vhost.startswith("/"):
  195. vhost = "/" + vhost
  196. return {"carrot_backend": carrot_backend,
  197. "userid": broker_connection.userid,
  198. "host": broker_connection.hostname,
  199. "port": port,
  200. "vhost": vhost}
  201. def format_broker_info(self, info=None):
  202. """Get message broker connection info string for log dumps."""
  203. return BROKER_FORMAT % self.get_broker_info()
  204. def _get_queues(self):
  205. if self._queues is None:
  206. c = self.app.conf
  207. self._queues = self.Queues(c.CELERY_QUEUES)
  208. return self._queues
  209. def _set_queues(self, queues):
  210. self._queues = self.Queues(queues)
  211. queues = property(_get_queues, _set_queues)