|
@@ -34,12 +34,11 @@ class AMQPBackend(BaseDictBackend):
|
|
|
|
|
|
"""
|
|
|
|
|
|
- _connection = None
|
|
|
- _channel = None
|
|
|
+ _pool = None
|
|
|
|
|
|
def __init__(self, connection=None, exchange=None, exchange_type=None,
|
|
|
persistent=None, serializer=None, auto_delete=True,
|
|
|
- expires=None, **kwargs):
|
|
|
+ expires=None, connection_max=None, **kwargs):
|
|
|
super(AMQPBackend, self).__init__(**kwargs)
|
|
|
conf = self.app.conf
|
|
|
self._connection = connection
|
|
@@ -68,6 +67,8 @@ class AMQPBackend(BaseDictBackend):
|
|
|
# x-expires must be a signed-int, or long describing
|
|
|
# the expiry time in milliseconds.
|
|
|
self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
|
|
|
+ self.connection_max = (connection_max or
|
|
|
+ conf.CELERY_AMQP_TASK_RESULT_CONNECTION_MAX)
|
|
|
|
|
|
def _create_binding(self, task_id):
|
|
|
name = task_id.replace("-", "")
|
|
@@ -77,16 +78,16 @@ class AMQPBackend(BaseDictBackend):
|
|
|
durable=self.persistent,
|
|
|
auto_delete=self.auto_delete)
|
|
|
|
|
|
- def _create_producer(self, task_id):
|
|
|
+ def _create_producer(self, task_id, channel):
|
|
|
binding = self._create_binding(task_id)
|
|
|
- binding(self.channel).declare()
|
|
|
+ binding(channel).declare()
|
|
|
|
|
|
- return Producer(self.channel, exchange=self.exchange,
|
|
|
+ return Producer(channel, exchange=self.exchange,
|
|
|
routing_key=task_id.replace("-", ""),
|
|
|
serializer=self.serializer)
|
|
|
|
|
|
- def _create_consumer(self, bindings):
|
|
|
- return Consumer(self.channel, bindings, no_ack=True)
|
|
|
+ def _create_consumer(self, bindings, channel):
|
|
|
+ return Consumer(channel, bindings, no_ack=True)
|
|
|
|
|
|
def store_result(self, task_id, result, status, traceback=None,
|
|
|
max_retries=20, retry_delay=0.2):
|
|
@@ -99,17 +100,22 @@ class AMQPBackend(BaseDictBackend):
|
|
|
"traceback": traceback}
|
|
|
|
|
|
for i in range(max_retries + 1):
|
|
|
+ conn = self.pool.acquire(block=True)
|
|
|
+ channel = conn.channel()
|
|
|
try:
|
|
|
- self._create_producer(task_id).publish(meta)
|
|
|
- except Exception, exc:
|
|
|
- if not max_retries:
|
|
|
- raise
|
|
|
- self._channel = None
|
|
|
- self._connection = None
|
|
|
- warnings.warn(AMQResultWarning(
|
|
|
- "Error sending result %s: %r" % (task_id, exc)))
|
|
|
- time.sleep(retry_delay)
|
|
|
- break
|
|
|
+ try:
|
|
|
+ self._create_producer(task_id, channel).publish(meta)
|
|
|
+ except Exception, exc:
|
|
|
+ if not max_retries:
|
|
|
+ raise
|
|
|
+ warnings.warn(AMQResultWarning(
|
|
|
+ "Error sending result %s: %r" % (task_id, exc)))
|
|
|
+ time.sleep(retry_delay)
|
|
|
+ else:
|
|
|
+ break
|
|
|
+ finally:
|
|
|
+ channel.close()
|
|
|
+ conn.release()
|
|
|
|
|
|
return result
|
|
|
|
|
@@ -138,18 +144,28 @@ class AMQPBackend(BaseDictBackend):
|
|
|
return self.wait_for(task_id, timeout, cache)
|
|
|
|
|
|
def poll(self, task_id):
|
|
|
- binding = self._create_binding(task_id)(self.channel)
|
|
|
- result = binding.get()
|
|
|
- if result:
|
|
|
- binding.delete(if_unused=True, if_empty=True, nowait=True)
|
|
|
- payload = self._cache[task_id] = result.payload
|
|
|
- return payload
|
|
|
- elif task_id in self._cache:
|
|
|
- return self._cache[task_id] # use previously received state.
|
|
|
- return {"status": states.PENDING, "result": None}
|
|
|
+ conn = self.pool.acquire(block=True)
|
|
|
+ channel = conn.channel()
|
|
|
+ try:
|
|
|
+ binding = self._create_binding(task_id)(channel)
|
|
|
+ result = binding.get()
|
|
|
+ if result:
|
|
|
+ try:
|
|
|
+ binding.delete(if_unused=True, if_empty=True, nowait=True)
|
|
|
+ except conn.channel_errors:
|
|
|
+ pass
|
|
|
+ payload = self._cache[task_id] = result.payload
|
|
|
+ return payload
|
|
|
+ elif task_id in self._cache:
|
|
|
+ # use previously received state.
|
|
|
+ return self._cache[task_id]
|
|
|
+ return {"status": states.PENDING, "result": None}
|
|
|
+ finally:
|
|
|
+ channel.close()
|
|
|
+ conn.release()
|
|
|
|
|
|
def drain_events(self, consumer, timeout=None):
|
|
|
- wait = self.connection.drain_events
|
|
|
+ wait = consumer.channel.connection.drain_events
|
|
|
results = {}
|
|
|
|
|
|
def callback(meta, message):
|
|
@@ -173,60 +189,61 @@ class AMQPBackend(BaseDictBackend):
|
|
|
return results
|
|
|
|
|
|
def consume(self, task_id, timeout=None):
|
|
|
- binding = self._create_binding(task_id)
|
|
|
- consumer = self._create_consumer(binding)
|
|
|
- consumer.consume()
|
|
|
+ conn = self.pool.acquire(block=True)
|
|
|
+ channel = conn.channel()
|
|
|
try:
|
|
|
- return self.drain_events(consumer, timeout=timeout).values()[0]
|
|
|
+ binding = self._create_binding(task_id)
|
|
|
+ consumer = self._create_consumer(binding, channel)
|
|
|
+ consumer.consume()
|
|
|
+ try:
|
|
|
+ return self.drain_events(consumer, timeout=timeout).values()[0]
|
|
|
+ finally:
|
|
|
+ consumer.cancel()
|
|
|
finally:
|
|
|
- consumer.cancel()
|
|
|
+ channel.release()
|
|
|
+ conn.release()
|
|
|
|
|
|
def get_many(self, task_ids, timeout=None):
|
|
|
- bindings = [self._create_binding(task_id) for task_id in task_ids]
|
|
|
- consumer = self._create_consumer(bindings)
|
|
|
- consumer.consume()
|
|
|
- ids = set(task_ids)
|
|
|
- results = {}
|
|
|
- cached_ids = set()
|
|
|
- for task_id in ids:
|
|
|
- try:
|
|
|
- cached = self._cache[task_id]
|
|
|
- except KeyError:
|
|
|
- pass
|
|
|
- else:
|
|
|
- if cached["status"] in states.READY_STATES:
|
|
|
- results[task_id] = cached
|
|
|
- cached_ids.add(task_id)
|
|
|
- ids ^= cached_ids
|
|
|
+ conn = self.pool.acquire(block=True)
|
|
|
+ channel = conn.channel()
|
|
|
try:
|
|
|
- while ids:
|
|
|
- r = self.drain_events(consumer, timeout=timeout)
|
|
|
- results.update(r)
|
|
|
- ids ^= set(r.keys())
|
|
|
+ bindings = [self._create_binding(task_id) for task_id in task_ids]
|
|
|
+ consumer = self._create_consumer(bindings, channel)
|
|
|
+ consumer.consume()
|
|
|
+ ids = set(task_ids)
|
|
|
+ cached_ids = set()
|
|
|
+ for task_id in ids:
|
|
|
+ try:
|
|
|
+ cached = self._cache[task_id]
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ if cached["status"] in states.READY_STATES:
|
|
|
+ yield task_id, cached
|
|
|
+ cached_ids.add(task_id)
|
|
|
+ ids ^= cached_ids
|
|
|
+ try:
|
|
|
+ while ids:
|
|
|
+ r = self.drain_events(consumer, timeout=timeout)
|
|
|
+ ids ^= set(r.keys())
|
|
|
+ for ready_id, ready_meta in r.items():
|
|
|
+ yield ready_id, ready_meta
|
|
|
+ finally:
|
|
|
+ consumer.cancel()
|
|
|
finally:
|
|
|
- consumer.cancel()
|
|
|
-
|
|
|
- return results
|
|
|
+ channel.close()
|
|
|
+ conn.release()
|
|
|
|
|
|
def close(self):
|
|
|
- if self._channel is not None:
|
|
|
- self._channel.close()
|
|
|
- self._channel = None
|
|
|
- if self._connection is not None:
|
|
|
- self._connection.close()
|
|
|
- self._connection = None
|
|
|
-
|
|
|
- @property
|
|
|
- def connection(self):
|
|
|
- if not self._connection:
|
|
|
- self._connection = self.app.broker_connection()
|
|
|
- return self._connection
|
|
|
+ if self._pool is not None:
|
|
|
+ self._pool.close()
|
|
|
+ self._pool = None
|
|
|
|
|
|
@property
|
|
|
- def channel(self):
|
|
|
- if not self._channel:
|
|
|
- self._channel = self.connection.channel()
|
|
|
- return self._channel
|
|
|
+ def pool(self):
|
|
|
+ if not self._pool:
|
|
|
+ self._pool = self.app.broker_connection().Pool(self.connection_max)
|
|
|
+ return self._pool
|
|
|
|
|
|
def reload_task_result(self, task_id):
|
|
|
raise NotImplementedError(
|