Browse Source

Merged branch join_native: TaskSetResult.join_native uses a backend-optimized solution to retrieve results for more than one task

Ask Solem 14 years ago
parent
commit
da6dc4bb4c
2 changed files with 65 additions and 23 deletions
  1. 52 23
      celery/backends/amqp.py
  2. 13 0
      celery/result.py

+ 52 - 23
celery/backends/amqp.py

@@ -64,10 +64,9 @@ class AMQPBackend(BaseDictBackend):
             self.queue_arguments["x-expires"] = int(self.expires * 1000.0)
 
     def _create_binding(self, task_id):
-        name = task_id.replace("-", "")
-        return Queue(name=name,
+        return Queue(name=task_id,
                      exchange=self.exchange,
-                     routing_key=name,
+                     routing_key=task_id,
                      durable=self.persistent,
                      auto_delete=self.auto_delete)
 
@@ -76,12 +75,11 @@ class AMQPBackend(BaseDictBackend):
         binding(self.channel).declare()
 
         return Producer(self.channel, exchange=self.exchange,
-                        routing_key=task_id.replace("-", ""),
+                        routing_key=task_id,
                         serializer=self.serializer)
 
-    def _create_consumer(self, task_id):
-        binding = self._create_binding(task_id)
-        return Consumer(self.channel, [binding], no_ack=True)
+    def _create_consumer(self, bindings):
+        return Consumer(self.channel, bindings, no_ack=True)
 
     def store_result(self, task_id, result, status, traceback=None,
             max_retries=20, retry_delay=0.2):
@@ -143,33 +141,64 @@ class AMQPBackend(BaseDictBackend):
             return self._cache[task_id]     # use previously received state.
         return {"status": states.PENDING, "result": None}
 
-    def consume(self, task_id, timeout=None):
-        results = []
+    def drain_events(self, consumer, timeout=None):
+        wait = self.connection.drain_events
+        results = {}
 
         def callback(meta, message):
             if meta["status"] in states.READY_STATES:
-                results.append(meta)
+                results[message.delivery_info["routing_key"]] = meta
 
-        wait = self.connection.drain_events
-        consumer = self._create_consumer(task_id)
         consumer.register_callback(callback)
 
+        time_start = time.time()
+        while 1:
+            # Total time spent may exceed a single call to wait()
+            if timeout and time.time() - time_start >= timeout:
+                raise socket.timeout()
+            wait(timeout=timeout)
+            if results:
+                # Got event on the wanted channel.
+                break
+
+        self._cache.update(results)
+        return results
+
+    def consume(self, task_id, timeout=None):
+        binding = self._create_binding(task_id)
+        consumer = self._create_consumer(binding)
+        consumer.consume()
+        try:
+            return self.drain_events(consumer, timeout=timeout).values()[0]
+        finally:
+            consumer.cancel()
+
+    def get_many(self, task_ids, timeout=None):
+        bindings = [self._create_binding(task_id) for task_id in task_ids]
+        consumer = self._create_consumer(bindings)
         consumer.consume()
+        ids = set(task_ids)
+        results = {}
+        cached_ids = set()
+        for task_id in ids:
+            try:
+                cached = self._cache[task_id]
+            except KeyError:
+                pass
+            else:
+                if cached["status"] in states.READY_STATES:
+                    results[task_id] = cached
+                    cached_ids.add(task_id)
+        ids ^= cached_ids
         try:
-            time_start = time.time()
-            while True:
-                # Total time spent may exceed a single call to wait()
-                if timeout and time.time() - time_start >= timeout:
-                    raise socket.timeout()
-                wait(timeout=timeout)
-                if results:
-                    # Got event on the wanted channel.
-                    break
+            while ids:
+                r = self.drain_events(consumer, timeout=timeout)
+                results.update(r)
+                ids ^= set(r.keys())
         finally:
             consumer.cancel()
 
-        self._cache[task_id] = results[0]
-        return results[0]
+        return results
 
     def close(self):
         if self._channel is not None:

+ 13 - 0
celery/result.py

@@ -320,6 +320,19 @@ class TaskSetResult(object):
                 elif result.status in states.PROPAGATE_STATES:
                     raise result.result
 
+    def join_native(self, timeout=None, propagate=True):
+        backend = self.subtasks[0].backend
+        results = PositionQueue(length=self.total)
+        ids = [subtask.task_id for subtask in self.subtasks]
+
+        states = backend.get_many(ids, timeout=timeout)
+
+        for task_id, meta in states.items():
+            index = self.subtasks.index(task_id)
+            results[index] = meta["result"]
+
+        return list(results)
+
     def join(self, timeout=None, propagate=True):
         """Gather the results of all tasks in the taskset,
         and returns a list ordered by the order of the set.