amqp.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import os
  2. from datetime import datetime, timedelta
  3. from UserDict import UserDict
  4. from celery import routes
  5. from celery import signals
  6. from celery.utils import gen_unique_id, mitemgetter, textindent
  7. from kombu import compat as messaging
  8. from kombu import BrokerConnection
  9. MSG_OPTIONS = ("mandatory", "priority", "immediate",
  10. "routing_key", "serializer", "delivery_mode",
  11. "compression")
  12. QUEUE_FORMAT = """
  13. . %(name)s -> exchange:%(exchange)s (%(exchange_type)s) \
  14. binding:%(binding_key)s
  15. """
  16. BROKER_FORMAT = """\
  17. %(transport_cls)s://%(userid)s@%(hostname)s%(port)s%(virtual_host)s\
  18. """
  19. get_msg_options = mitemgetter(*MSG_OPTIONS)
  20. extract_msg_options = lambda d: dict(zip(MSG_OPTIONS, get_msg_options(d)))
  21. _queues_declared = False
  22. _exchanges_declared = set()
  23. class Queues(UserDict):
  24. def __init__(self, queues):
  25. self.data = {}
  26. for queue_name, options in (queues or {}).items():
  27. self.add(queue_name, **options)
  28. def add(self, queue, exchange=None, routing_key=None,
  29. exchange_type="direct", **options):
  30. q = self[queue] = self.options(exchange, routing_key,
  31. exchange_type, **options)
  32. return q
  33. def options(self, exchange, routing_key,
  34. exchange_type="direct", **options):
  35. return dict(options, routing_key=routing_key,
  36. binding_key=routing_key,
  37. exchange=exchange,
  38. exchange_type=exchange_type)
  39. def format(self, indent=0):
  40. """Format routing table into string for log dumps."""
  41. format = lambda **queue: QUEUE_FORMAT.strip() % queue
  42. info = "\n".join(format(name=name, **config)
  43. for name, config in self.items())
  44. return textindent(info, indent=indent)
  45. def select_subset(self, wanted, create_missing=True):
  46. acc = {}
  47. for queue in wanted:
  48. try:
  49. options = self[queue]
  50. except KeyError:
  51. if not create_missing:
  52. raise
  53. options = self.options(queue, queue)
  54. acc[queue] = options
  55. self.data.clear()
  56. self.data.update(acc)
  57. @classmethod
  58. def with_defaults(cls, queues, default_exchange, default_exchange_type):
  59. def _defaults(opts):
  60. opts.setdefault("exchange", default_exchange),
  61. opts.setdefault("exchange_type", default_exchange_type)
  62. opts.setdefault("binding_key", default_exchange)
  63. opts.setdefault("routing_key", opts.get("binding_key"))
  64. return opts
  65. map(_defaults, queues.values())
  66. return cls(queues)
  67. class TaskPublisher(messaging.Publisher):
  68. auto_declare = False
  69. def declare(self):
  70. if self.exchange not in _exchanges_declared:
  71. super(TaskPublisher, self).declare()
  72. _exchanges_declared.add(self.exchange)
  73. def delay_task(self, task_name, task_args=None, task_kwargs=None,
  74. countdown=None, eta=None, task_id=None, taskset_id=None,
  75. expires=None, exchange=None, exchange_type=None, **kwargs):
  76. """Delay task for execution by the celery nodes."""
  77. task_id = task_id or gen_unique_id()
  78. task_args = task_args or []
  79. task_kwargs = task_kwargs or {}
  80. now = None
  81. if countdown: # Convert countdown to ETA.
  82. now = datetime.now()
  83. eta = now + timedelta(seconds=countdown)
  84. if not isinstance(task_args, (list, tuple)):
  85. raise ValueError("task args must be a list or tuple")
  86. if not isinstance(task_kwargs, dict):
  87. raise ValueError("task kwargs must be a dictionary")
  88. if isinstance(expires, int):
  89. now = now or datetime.now()
  90. expires = now + timedelta(seconds=expires)
  91. message_data = {
  92. "task": task_name,
  93. "id": task_id,
  94. "args": task_args or [],
  95. "kwargs": task_kwargs or {},
  96. "retries": kwargs.get("retries", 0),
  97. "eta": eta and eta.isoformat(),
  98. "expires": expires and expires.isoformat(),
  99. }
  100. if taskset_id:
  101. message_data["taskset"] = taskset_id
  102. # custom exchange passed, need to declare it.
  103. if exchange and exchange not in _exchanges_declared:
  104. exchange_type = exchange_type or self.exchange_type
  105. self.backend.exchange_declare(exchange=exchange,
  106. exchange_type=exchange_type,
  107. durable=self.durable,
  108. auto_delete=self.auto_delete)
  109. self.send(message_data, exchange=exchange,
  110. **extract_msg_options(kwargs))
  111. signals.task_sent.send(sender=task_name, **message_data)
  112. return task_id
  113. class AMQP(object):
  114. BrokerConnection = BrokerConnection
  115. Publisher = messaging.Publisher
  116. Consumer = messaging.Consumer
  117. _queues = None
  118. def __init__(self, app):
  119. self.app = app
  120. def ConsumerSet(self, *args, **kwargs):
  121. return messaging.ConsumerSet(*args, **kwargs)
  122. def Queues(self, queues):
  123. return Queues.with_defaults(queues,
  124. self.app.conf.CELERY_DEFAULT_EXCHANGE,
  125. self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
  126. def Router(self, queues=None, create_missing=None):
  127. return routes.Router(self.app.conf.CELERY_ROUTES,
  128. queues or self.app.conf.CELERY_QUEUES,
  129. self.app.either("CELERY_CREATE_MISSING_QUEUES",
  130. create_missing),
  131. app=self.app)
  132. def TaskConsumer(self, *args, **kwargs):
  133. default_queue_name, default_queue = self.get_default_queue()
  134. defaults = dict({"queue": default_queue_name}, **default_queue)
  135. defaults["routing_key"] = defaults.pop("binding_key", None)
  136. return self.Consumer(*args,
  137. **self.app.merge(defaults, kwargs))
  138. def TaskPublisher(self, *args, **kwargs):
  139. _, default_queue = self.get_default_queue()
  140. defaults = {"exchange": default_queue["exchange"],
  141. "exchange_type": default_queue["exchange_type"],
  142. "routing_key": self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
  143. "serializer": self.app.conf.CELERY_TASK_SERIALIZER}
  144. publisher = TaskPublisher(*args,
  145. **self.app.merge(defaults, kwargs))
  146. # Make sure all queues are declared.
  147. global _queues_declared
  148. if not _queues_declared:
  149. self.get_task_consumer(publisher.connection).close()
  150. _queues_declared = True
  151. publisher.declare()
  152. return publisher
  153. def get_task_consumer(self, connection, queues=None, **kwargs):
  154. return self.ConsumerSet(connection, from_dict=queues or self.queues,
  155. **kwargs)
  156. def get_default_queue(self):
  157. q = self.app.conf.CELERY_DEFAULT_QUEUE
  158. return q, self.queues[q]
  159. def get_broker_info(self, broker_connection=None):
  160. if broker_connection is None:
  161. broker_connection = self.app.broker_connection()
  162. info = broker_connection.info()
  163. port = info["port"]
  164. if port:
  165. info["port"] = ":%s" % (port, )
  166. vhost = info["virtual_host"]
  167. if not vhost.startswith("/"):
  168. info["virtual_host"] = "/" + vhost
  169. return info
  170. def format_broker_info(self, info=None):
  171. """Get message broker connection info string for log dumps."""
  172. return BROKER_FORMAT % self.get_broker_info()
  173. def _get_queues(self):
  174. if self._queues is None:
  175. c = self.app.conf
  176. self._queues = self.Queues(c.CELERY_QUEUES)
  177. return self._queues
  178. def _set_queues(self, queues):
  179. self._queues = self.Queues(queues)
  180. queues = property(_get_queues, _set_queues)