Explorar el Código

AMQP Backend: Get many must prepare exceptions

Ask Solem hace 11 años
padre
commit
5dbd4694a2
Se han modificado 2 ficheros con 15 adiciones y 9 borrados
  1. 7 3
      celery/backends/amqp.py
  2. 8 6
      celery/backends/base.py

+ 7 - 3
celery/backends/amqp.py

@@ -114,8 +114,8 @@ class AMQPBackend(BaseBackend):
             return self.rkey(task_id), request.correlation_id or task_id
         return self.rkey(task_id), task_id
 
-    def _store_result(self, task_id, result, status,
-                      traceback=None, request=None, **kwargs):
+    def store_result(self, task_id, result, status,
+                     traceback=None, request=None, **kwargs):
         """Send task return value and status."""
         routing_key, correlation_id = self.destination_for(task_id, request)
         if not routing_key:
@@ -230,7 +230,8 @@ class AMQPBackend(BaseBackend):
 
     def get_many(self, task_ids, timeout=None,
                  now=monotonic, getfields=itemgetter('status', 'task_id'),
-                 READY_STATES=states.READY_STATES, **kwargs):
+                 READY_STATES=states.READY_STATES,
+                 PROPAGATE_STATES=states.PROPAGATE_STATES, **kwargs):
         with self.app.pool.acquire_channel(block=True) as (conn, channel):
             ids = set(task_ids)
             cached_ids = set()
@@ -248,11 +249,14 @@ class AMQPBackend(BaseBackend):
             results = deque()
             push_result = results.append
             push_cache = self._cache.__setitem__
+            to_exception = self.exception_to_python
 
             def on_message(message):
                 body = message.decode()
                 state, uid = getfields(body)
                 if state in READY_STATES:
+                    if state in PROPAGATE_STATES:
+                        body['result'] = to_exception(body['result'])
                     push_result(body) \
                         if uid in task_ids else push_cache(uid, body)
 

+ 8 - 6
celery/backends/base.py

@@ -133,8 +133,8 @@ class BaseBackend(object):
         """Convert serialized exception to Python exception."""
         if self.serializer in EXCEPTION_ABLE_CODECS:
             return get_pickled_exception(exc)
-        return create_exception_cls(from_utf8(exc['exc_type']),
-                                    sys.modules[__name__])(exc['exc_message'])
+        return create_exception_cls(
+            from_utf8(exc['exc_type']), __name__)(exc['exc_message'])
 
     def prepare_value(self, result):
         """Prepare value for storage."""
@@ -379,17 +379,19 @@ class KeyValueStoreBackend(BaseBackend):
                         for i, value in enumerate(values)
                         if value is not None)
 
-    def get_many(self, task_ids, timeout=None, interval=0.5):
+    def get_many(self, task_ids, timeout=None, interval=0.5,
+                 READY_STATES=states.READY_STATES):
         interval = 0.5 if interval is None else interval
         ids = task_ids if isinstance(task_ids, set) else set(task_ids)
         cached_ids = set()
+        cache = self._cache
         for task_id in ids:
             try:
-                cached = self._cache[task_id]
+                cached = cache[task_id]
             except KeyError:
                 pass
             else:
-                if cached['status'] in states.READY_STATES:
+                if cached['status'] in READY_STATES:
                     yield bytes_to_str(task_id), cached
                     cached_ids.add(task_id)
 
@@ -399,7 +401,7 @@ class KeyValueStoreBackend(BaseBackend):
             keys = list(ids)
             r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                  for k in keys]), keys)
-            self._cache.update(r)
+            cache.update(r)
             ids.difference_update(set(bytes_to_str(v) for v in r))
             for key, value in items(r):
                 yield bytes_to_str(key), value