|
@@ -64,10 +64,9 @@ class AMQPBackend(BaseDictBackend):
|
|
|
self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
|
|
|
|
|
|
def _create_binding(self, task_id):
|
|
|
- name = task_id.replace("-", "")
|
|
|
- return Queue(name=name,
|
|
|
+ return Queue(name=task_id,
|
|
|
exchange=self.exchange,
|
|
|
- routing_key=name,
|
|
|
+ routing_key=task_id,
|
|
|
durable=self.persistent,
|
|
|
auto_delete=self.auto_delete)
|
|
|
|
|
@@ -76,12 +75,11 @@ class AMQPBackend(BaseDictBackend):
|
|
|
binding(self.channel).declare()
|
|
|
|
|
|
return Producer(self.channel, exchange=self.exchange,
|
|
|
- routing_key=task_id.replace("-", ""),
|
|
|
+ routing_key=task_id,
|
|
|
serializer=self.serializer)
|
|
|
|
|
|
- def _create_consumer(self, task_id):
|
|
|
- binding = self._create_binding(task_id)
|
|
|
- return Consumer(self.channel, [binding], no_ack=True)
|
|
|
+ def _create_consumer(self, bindings):
|
|
|
+ return Consumer(self.channel, bindings, no_ack=True)
|
|
|
|
|
|
def store_result(self, task_id, result, status, traceback=None,
|
|
|
max_retries=20, retry_delay=0.2):
|
|
@@ -143,33 +141,64 @@ class AMQPBackend(BaseDictBackend):
|
|
|
return self._cache[task_id] # use previously received state.
|
|
|
return {"status": states.PENDING, "result": None}
|
|
|
|
|
|
- def consume(self, task_id, timeout=None):
|
|
|
- results = []
|
|
|
+ def drain_events(self, consumer, timeout=None):
|
|
|
+ wait = self.connection.drain_events
|
|
|
+ results = {}
|
|
|
|
|
|
def callback(meta, message):
|
|
|
if meta["status"] in states.READY_STATES:
|
|
|
- results.append(meta)
|
|
|
+ results[message.delivery_info["routing_key"]] = meta
|
|
|
|
|
|
- wait = self.connection.drain_events
|
|
|
- consumer = self._create_consumer(task_id)
|
|
|
consumer.register_callback(callback)
|
|
|
|
|
|
+ time_start = time.time()
|
|
|
+ while 1:
|
|
|
+ # Total time spent may exceed a single call to wait()
|
|
|
+ if timeout and time.time() - time_start >= timeout:
|
|
|
+ raise socket.timeout()
|
|
|
+ wait(timeout=timeout)
|
|
|
+ if results:
|
|
|
+ # Got event on the wanted channel.
|
|
|
+ break
|
|
|
+
|
|
|
+ self._cache.update(results)
|
|
|
+ return results
|
|
|
+
|
|
|
+ def consume(self, task_id, timeout=None):
|
|
|
+ binding = self._create_binding(task_id)
|
|
|
+ consumer = self._create_consumer(binding)
|
|
|
+ consumer.consume()
|
|
|
+ try:
|
|
|
+ return self.drain_events(consumer, timeout=timeout).values()[0]
|
|
|
+ finally:
|
|
|
+ consumer.cancel()
|
|
|
+
|
|
|
+ def get_many(self, task_ids, timeout=None):
|
|
|
+ bindings = [self._create_binding(task_id) for task_id in task_ids]
|
|
|
+ consumer = self._create_consumer(bindings)
|
|
|
consumer.consume()
|
|
|
+ ids = set(task_ids)
|
|
|
+ results = {}
|
|
|
+ cached_ids = set()
|
|
|
+ for task_id in ids:
|
|
|
+ try:
|
|
|
+ cached = self._cache[task_id]
|
|
|
+ except KeyError:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ if cached["status"] in states.READY_STATES:
|
|
|
+ results[task_id] = cached
|
|
|
+ cached_ids.add(task_id)
|
|
|
+ ids ^= cached_ids
|
|
|
try:
|
|
|
- time_start = time.time()
|
|
|
- while True:
|
|
|
- # Total time spent may exceed a single call to wait()
|
|
|
- if timeout and time.time() - time_start >= timeout:
|
|
|
- raise socket.timeout()
|
|
|
- wait(timeout=timeout)
|
|
|
- if results:
|
|
|
- # Got event on the wanted channel.
|
|
|
- break
|
|
|
+ while ids:
|
|
|
+ r = self.drain_events(consumer, timeout=timeout)
|
|
|
+ results.update(r)
|
|
|
+ ids ^= set(r.keys())
|
|
|
finally:
|
|
|
consumer.cancel()
|
|
|
|
|
|
- self._cache[task_id] = results[0]
|
|
|
- return results[0]
|
|
|
+ return results
|
|
|
|
|
|
def close(self):
|
|
|
if self._channel is not None:
|