|
@@ -46,12 +46,122 @@ class NoCacheQueue(Queue):
|
|
|
can_cache_declaration = False
|
|
|
|
|
|
|
|
|
+class ResultConsumer(object):
|
|
|
+ Consumer = Consumer
|
|
|
+
|
|
|
+ def __init__(self, backend, app, accept, pending_results):
|
|
|
+ self.backend = backend
|
|
|
+ self.app = app
|
|
|
+ self.accept = accept
|
|
|
+ self._pending_results = pending_results
|
|
|
+ self._consumer = None
|
|
|
+ self._conn = None
|
|
|
+ self.on_message = None
|
|
|
+ self.bucket = None
|
|
|
+
|
|
|
+ def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
|
|
|
+ wait = self.drain_events
|
|
|
+ with self.app.pool.acquire_channel(block=True) as (conn, channel):
|
|
|
+ binding = self.backend._create_binding(task_id)
|
|
|
+ with self.Consumer(channel, binding,
|
|
|
+ no_ack=no_ack, accept=self.accept) as consumer:
|
|
|
+ while 1:
|
|
|
+ try:
|
|
|
+ return wait(
|
|
|
+ conn, consumer, timeout, on_interval)[task_id]
|
|
|
+ except KeyError:
|
|
|
+ continue
|
|
|
+
|
|
|
+ def wait_for_pending(self, result,
|
|
|
+ callback=None, propagate=True, **kwargs):
|
|
|
+ for _ in self._wait_for_pending(result, **kwargs):
|
|
|
+ pass
|
|
|
+ return result.maybe_throw(callback=callback, propagate=propagate)
|
|
|
+
|
|
|
+ def _wait_for_pending(self, result, timeout=None, interval=0.5,
|
|
|
+ no_ack=True, on_interval=None, callback=None,
|
|
|
+ on_message=None, propagate=True):
|
|
|
+ prev_on_m, self.on_message = self.on_message, on_message
|
|
|
+ try:
|
|
|
+ for _ in self.drain_events_until(
|
|
|
+ result.on_ready, timeout=timeout,
|
|
|
+ on_interval=on_interval):
|
|
|
+ yield
|
|
|
+ except socket.timeout:
|
|
|
+ raise TimeoutError('The operation timed out.')
|
|
|
+ finally:
|
|
|
+ self.on_message = prev_on_m
|
|
|
+
|
|
|
+ def collect_for_pending(self, result, bucket=None, **kwargs):
|
|
|
+ prev_bucket, self.bucket = self.bucket, bucket
|
|
|
+ try:
|
|
|
+ for _ in self._wait_for_pending(result, **kwargs):
|
|
|
+ yield
|
|
|
+ finally:
|
|
|
+ self.bucket = prev_bucket
|
|
|
+
|
|
|
+ def start(self, initial_queue, no_ack=True):
|
|
|
+ self._conn = self.app.connection()
|
|
|
+ self._consumer = self.Consumer(
|
|
|
+ self._conn.default_channel, [initial_queue],
|
|
|
+ callbacks=[self.on_state_change], no_ack=no_ack,
|
|
|
+ accept=self.accept)
|
|
|
+ self._consumer.consume()
|
|
|
+
|
|
|
+ def stop(self):
|
|
|
+ try:
|
|
|
+ self._consumer.cancel()
|
|
|
+ finally:
|
|
|
+ self._connection.close()
|
|
|
+
|
|
|
+ def consume_from(self, queue):
|
|
|
+ if self._consumer is None:
|
|
|
+ return self.start(queue)
|
|
|
+ if not self._consumer.consuming_from(queue):
|
|
|
+ self._consumer.add_queue(queue)
|
|
|
+ self._consumer.consume()
|
|
|
+
|
|
|
+ def cancel_for(self, queue):
|
|
|
+ self._consumer.cancel_by_queue(queue)
|
|
|
+
|
|
|
+ def on_state_change(self, meta, message):
|
|
|
+ if self.on_message:
|
|
|
+ self.on_message(meta)
|
|
|
+ if meta['status'] in states.READY_STATES:
|
|
|
+ try:
|
|
|
+ result = self._pending_results[meta['task_id']]
|
|
|
+ except KeyError:
|
|
|
+ return
|
|
|
+ result._maybe_set_cache(meta)
|
|
|
+ if self.bucket is not None:
|
|
|
+ self.bucket.append(result)
|
|
|
+
|
|
|
+ def drain_events_until(self, p, timeout=None, on_interval=None,
|
|
|
+ monotonic=monotonic, wait=None):
|
|
|
+ wait = wait or self._conn.drain_events
|
|
|
+ time_start = monotonic()
|
|
|
+
|
|
|
+ while 1:
|
|
|
+ # Total time spent may exceed a single call to wait()
|
|
|
+ if timeout and monotonic() - time_start >= timeout:
|
|
|
+ raise socket.timeout()
|
|
|
+ try:
|
|
|
+ yield wait(timeout=1)
|
|
|
+ except socket.timeout:
|
|
|
+ pass
|
|
|
+ if on_interval:
|
|
|
+ on_interval()
|
|
|
+ if p.ready: # got event on the wanted channel.
|
|
|
+ break
|
|
|
+
|
|
|
+
|
|
|
class AMQPBackend(BaseBackend):
|
|
|
"""Publishes results by sending messages."""
|
|
|
Exchange = Exchange
|
|
|
Queue = NoCacheQueue
|
|
|
Consumer = Consumer
|
|
|
Producer = Producer
|
|
|
+ ResultConsumer = ResultConsumer
|
|
|
|
|
|
BacklogLimitExceeded = BacklogLimitExceeded
|
|
|
|
|
@@ -83,6 +193,8 @@ class AMQPBackend(BaseBackend):
|
|
|
self.queue_arguments = dictfilter({
|
|
|
'x-expires': maybe_s_to_ms(self.expires),
|
|
|
})
|
|
|
+ self.result_consumer = self.ResultConsumer(
|
|
|
+ self, self.app, self.accept, self._pending_results)
|
|
|
|
|
|
def _create_exchange(self, name, type='direct', delivery_mode=2):
|
|
|
return self.Exchange(name=name,
|
|
@@ -136,22 +248,6 @@ class AMQPBackend(BaseBackend):
|
|
|
def on_reply_declare(self, task_id):
|
|
|
return [self._create_binding(task_id)]
|
|
|
|
|
|
- def wait_for(self, task_id, timeout=None, cache=True,
|
|
|
- no_ack=True, on_interval=None,
|
|
|
- READY_STATES=states.READY_STATES,
|
|
|
- PROPAGATE_STATES=states.PROPAGATE_STATES,
|
|
|
- **kwargs):
|
|
|
- cached_meta = self._cache.get(task_id)
|
|
|
- if cache and cached_meta and \
|
|
|
- cached_meta['status'] in READY_STATES:
|
|
|
- return cached_meta
|
|
|
- else:
|
|
|
- try:
|
|
|
- return self.consume(task_id, timeout=timeout, no_ack=no_ack,
|
|
|
- on_interval=on_interval)
|
|
|
- except socket.timeout:
|
|
|
- raise TimeoutError('The operation timed out.')
|
|
|
-
|
|
|
def get_task_meta(self, task_id, backlog_limit=1000):
|
|
|
# Polling and using basic_get
|
|
|
with self.app.pool.acquire_channel(block=True) as (_, channel):
|
|
@@ -189,50 +285,37 @@ class AMQPBackend(BaseBackend):
|
|
|
return {'status': states.PENDING, 'result': None}
|
|
|
poll = get_task_meta # XXX compat
|
|
|
|
|
|
- def drain_events(self, connection, consumer,
|
|
|
- timeout=None, on_interval=None, now=monotonic, wait=None):
|
|
|
- wait = wait or connection.drain_events
|
|
|
- results = {}
|
|
|
+ def wait_for_pending(self, result, timeout=None, interval=0.5,
|
|
|
+ no_ack=True, on_interval=None, on_message=None,
|
|
|
+ callback=None, propagate=True):
|
|
|
+ return self.result_consumer.wait_for_pending(
|
|
|
+ result, timeout=timeout, interval=interval,
|
|
|
+ no_ack=no_ack, on_interval=on_interval,
|
|
|
+ callback=callback, on_message=on_message, propagate=propagate,
|
|
|
+ )
|
|
|
|
|
|
- def callback(meta, message):
|
|
|
- if meta['status'] in states.READY_STATES:
|
|
|
- results[meta['task_id']] = self.meta_from_decoded(meta)
|
|
|
+ def collect_for_pending(self, result, bucket=None, timeout=None,
|
|
|
+ interval=0.5, no_ack=True, on_interval=None,
|
|
|
+ on_message=None, callback=None, propagate=True):
|
|
|
+ return self.result_consumer.collect_for_pending(
|
|
|
+ result, bucket=bucket, timeout=timeout, interval=interval,
|
|
|
+ no_ack=no_ack, on_interval=on_interval,
|
|
|
+ callback=callback, on_message=on_message, propagate=propagate,
|
|
|
+ )
|
|
|
|
|
|
- consumer.callbacks[:] = [callback]
|
|
|
- time_start = now()
|
|
|
+ def add_pending_result(self, result):
|
|
|
+ if result.id not in self._pending_results:
|
|
|
+ self._pending_results[result.id] = result
|
|
|
+ self.result_consumer.consume_from(self._create_binding(result.id))
|
|
|
|
|
|
- while 1:
|
|
|
- # Total time spent may exceed a single call to wait()
|
|
|
- if timeout and now() - time_start >= timeout:
|
|
|
- raise socket.timeout()
|
|
|
- try:
|
|
|
- wait(timeout=1)
|
|
|
- except socket.timeout:
|
|
|
- pass
|
|
|
- if on_interval:
|
|
|
- on_interval()
|
|
|
- if results: # got event on the wanted channel.
|
|
|
- break
|
|
|
- self._cache.update(results)
|
|
|
- return results
|
|
|
-
|
|
|
- def consume(self, task_id, timeout=None, no_ack=True, on_interval=None):
|
|
|
- wait = self.drain_events
|
|
|
- with self.app.pool.acquire_channel(block=True) as (conn, channel):
|
|
|
- binding = self._create_binding(task_id)
|
|
|
- with self.Consumer(channel, binding,
|
|
|
- no_ack=no_ack, accept=self.accept) as consumer:
|
|
|
- while 1:
|
|
|
- try:
|
|
|
- return wait(
|
|
|
- conn, consumer, timeout, on_interval)[task_id]
|
|
|
- except KeyError:
|
|
|
- continue
|
|
|
+ def remove_pending_result(self, result):
|
|
|
+ self._pending_results.pop(result.id, None)
|
|
|
+ # XXX cancel queue after result consumed
|
|
|
|
|
|
def _many_bindings(self, ids):
|
|
|
return [self._create_binding(task_id) for task_id in ids]
|
|
|
|
|
|
- def get_many(self, task_ids, timeout=None, no_ack=True,
|
|
|
+ def xxx_get_many(self, task_ids, timeout=None, no_ack=True,
|
|
|
on_message=None, on_interval=None,
|
|
|
now=monotonic, getfields=itemgetter('status', 'task_id'),
|
|
|
READY_STATES=states.READY_STATES,
|