amqp.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import os
  2. from datetime import datetime, timedelta
  3. from UserDict import UserDict
  4. from celery import routes
  5. from celery import signals
  6. from celery.utils import gen_unique_id, mitemgetter, textindent
  7. from kombu import compat as messaging
  8. from kombu import BrokerConnection
  9. MSG_OPTIONS = ("mandatory", "priority", "immediate",
  10. "routing_key", "serializer", "delivery_mode",
  11. "compression")
  12. QUEUE_FORMAT = """
  13. . %(name)s -> exchange:%(exchange)s (%(exchange_type)s) \
  14. binding:%(binding_key)s
  15. """
  16. BROKER_FORMAT = """\
  17. %(transport)s://%(userid)s@%(hostname)s%(port)s%(virtual_host)s\
  18. """
  19. get_msg_options = mitemgetter(*MSG_OPTIONS)
  20. extract_msg_options = lambda d: dict(zip(MSG_OPTIONS, get_msg_options(d)))
  21. _queues_declared = False
  22. _exchanges_declared = set()
  23. class Queues(UserDict):
  24. def __init__(self, queues):
  25. self.data = {}
  26. for queue_name, options in (queues or {}).items():
  27. self.add(queue_name, **options)
  28. def add(self, queue, exchange=None, routing_key=None,
  29. exchange_type="direct", **options):
  30. q = self[queue] = self.options(exchange, routing_key,
  31. exchange_type, **options)
  32. return q
  33. def options(self, exchange, routing_key,
  34. exchange_type="direct", **options):
  35. return dict(options, routing_key=routing_key,
  36. binding_key=routing_key,
  37. exchange=exchange,
  38. exchange_type=exchange_type)
  39. def format(self, indent=0):
  40. """Format routing table into string for log dumps."""
  41. format = lambda **queue: QUEUE_FORMAT.strip() % queue
  42. info = "\n".join(format(name=name, **config)
  43. for name, config in self.items())
  44. return textindent(info, indent=indent)
  45. def select_subset(self, wanted, create_missing=True):
  46. acc = {}
  47. for queue in wanted:
  48. try:
  49. options = self[queue]
  50. except KeyError:
  51. if not create_missing:
  52. raise
  53. options = self.options(queue, queue)
  54. acc[queue] = options
  55. self.data.clear()
  56. self.data.update(acc)
  57. @classmethod
  58. def with_defaults(cls, queues, default_exchange, default_exchange_type):
  59. for opts in queues.values():
  60. opts.setdefault("exchange", default_exchange),
  61. opts.setdefault("exchange_type", default_exchange_type)
  62. opts.setdefault("binding_key", default_exchange)
  63. opts.setdefault("routing_key", opts.get("binding_key"))
  64. return cls(queues)
  65. class TaskPublisher(messaging.Publisher):
  66. auto_declare = False
  67. def declare(self):
  68. if self.exchange not in _exchanges_declared:
  69. super(TaskPublisher, self).declare()
  70. _exchanges_declared.add(self.exchange)
  71. def delay_task(self, task_name, task_args=None, task_kwargs=None,
  72. countdown=None, eta=None, task_id=None, taskset_id=None,
  73. expires=None, exchange=None, exchange_type=None, **kwargs):
  74. """Delay task for execution by the celery nodes."""
  75. task_id = task_id or gen_unique_id()
  76. task_args = task_args or []
  77. task_kwargs = task_kwargs or {}
  78. now = None
  79. if countdown: # Convert countdown to ETA.
  80. now = datetime.now()
  81. eta = now + timedelta(seconds=countdown)
  82. if not isinstance(task_args, (list, tuple)):
  83. raise ValueError("task args must be a list or tuple")
  84. if not isinstance(task_kwargs, dict):
  85. raise ValueError("task kwargs must be a dictionary")
  86. if isinstance(expires, int):
  87. now = now or datetime.now()
  88. expires = now + timedelta(seconds=expires)
  89. message_data = {
  90. "task": task_name,
  91. "id": task_id,
  92. "args": task_args or [],
  93. "kwargs": task_kwargs or {},
  94. "retries": kwargs.get("retries", 0),
  95. "eta": eta and eta.isoformat(),
  96. "expires": expires and expires.isoformat(),
  97. }
  98. if taskset_id:
  99. message_data["taskset"] = taskset_id
  100. # custom exchange passed, need to declare it.
  101. if exchange and exchange not in _exchanges_declared:
  102. exchange_type = exchange_type or self.exchange_type
  103. self.backend.exchange_declare(exchange=exchange,
  104. type=exchange_type,
  105. durable=self.durable,
  106. auto_delete=self.auto_delete)
  107. self.send(message_data, exchange=exchange,
  108. **extract_msg_options(kwargs))
  109. signals.task_sent.send(sender=task_name, **message_data)
  110. return task_id
  111. class AMQP(object):
  112. BrokerConnection = BrokerConnection
  113. Publisher = messaging.Publisher
  114. Consumer = messaging.Consumer
  115. _queues = None
  116. def __init__(self, app):
  117. self.app = app
  118. def ConsumerSet(self, *args, **kwargs):
  119. return messaging.ConsumerSet(*args, **kwargs)
  120. def Queues(self, queues):
  121. return Queues.with_defaults(queues,
  122. self.app.conf.CELERY_DEFAULT_EXCHANGE,
  123. self.app.conf.CELERY_DEFAULT_EXCHANGE_TYPE)
  124. def Router(self, queues=None, create_missing=None):
  125. return routes.Router(self.app.conf.CELERY_ROUTES,
  126. queues or self.app.conf.CELERY_QUEUES,
  127. self.app.either("CELERY_CREATE_MISSING_QUEUES",
  128. create_missing),
  129. app=self.app)
  130. def TaskConsumer(self, *args, **kwargs):
  131. default_queue_name, default_queue = self.get_default_queue()
  132. defaults = dict({"queue": default_queue_name}, **default_queue)
  133. defaults["routing_key"] = defaults.pop("binding_key", None)
  134. return self.Consumer(*args,
  135. **self.app.merge(defaults, kwargs))
  136. def TaskPublisher(self, *args, **kwargs):
  137. _, default_queue = self.get_default_queue()
  138. defaults = {"exchange": default_queue["exchange"],
  139. "exchange_type": default_queue["exchange_type"],
  140. "routing_key": self.app.conf.CELERY_DEFAULT_ROUTING_KEY,
  141. "serializer": self.app.conf.CELERY_TASK_SERIALIZER}
  142. publisher = TaskPublisher(*args,
  143. **self.app.merge(defaults, kwargs))
  144. # Make sure all queues are declared.
  145. global _queues_declared
  146. if not _queues_declared:
  147. self.get_task_consumer(publisher.connection).close()
  148. _queues_declared = True
  149. publisher.declare()
  150. return publisher
  151. def get_task_consumer(self, connection, queues=None, **kwargs):
  152. return self.ConsumerSet(connection, from_dict=queues or self.queues,
  153. **kwargs)
  154. def get_default_queue(self):
  155. q = self.app.conf.CELERY_DEFAULT_QUEUE
  156. return q, self.queues[q]
  157. def get_broker_info(self, broker_connection=None):
  158. if broker_connection is None:
  159. broker_connection = self.app.broker_connection()
  160. info = broker_connection.info()
  161. port = info["port"]
  162. if port:
  163. info["port"] = ":%s" % (port, )
  164. vhost = info["virtual_host"]
  165. if not vhost.startswith("/"):
  166. info["virtual_host"] = "/" + vhost
  167. return info
  168. def format_broker_info(self, info=None):
  169. """Get message broker connection info string for log dumps."""
  170. return BROKER_FORMAT % self.get_broker_info()
  171. def _get_queues(self):
  172. if self._queues is None:
  173. c = self.app.conf
  174. self._queues = self.Queues(c.CELERY_QUEUES)
  175. return self._queues
  176. def _set_queues(self, queues):
  177. self._queues = self.Queues(queues)
  178. queues = property(_get_queues, _set_queues)