Parcourir la source

Chords now works when result serializer is not pickle (e.g. JSON)

Ask Solem il y a 13 ans
Parent
commit
c2aeb14536

+ 1 - 1
celery/app/base.py

@@ -22,7 +22,6 @@ from functools import wraps
 from kombu.clocks import LamportClock
 
 from celery import platforms
-from celery.backends import get_backend_by_url
 from celery.exceptions import AlwaysEagerIgnored
 from celery.loaders import get_loader_cls
 from celery.local import PromiseProxy, maybe_evaluate
@@ -457,6 +456,7 @@ class Celery(object):
         return bugreport(self)
 
     def _get_backend(self):
+        from celery.backends import get_backend_by_url
         backend, url = get_backend_by_url(
                 self.backend_cls or self.conf.CELERY_RESULT_BACKEND,
                 self.loader)

+ 10 - 2
celery/backends/base.py

@@ -14,6 +14,7 @@ from celery import states
 from celery.app import current_task
 from celery.datastructures import LRUCache
 from celery.exceptions import TimeoutError, TaskRevokedError
+from celery.result import from_serializable
 from celery.utils import timeutils
 from celery.utils.serialization import (
         get_pickled_exception,
@@ -403,7 +404,7 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def _save_taskset(self, taskset_id, result):
         self.set(self.get_key_for_taskset(taskset_id),
-                 self.encode({"result": result}))
+                 self.encode({"result": result.serializable()}))
         return result
 
     def _delete_taskset(self, taskset_id):
@@ -419,8 +420,15 @@ class KeyValueStoreBackend(BaseDictBackend):
     def _restore_taskset(self, taskset_id):
         """Get task metadata for a task by id."""
         meta = self.get(self.get_key_for_taskset(taskset_id))
+        # previously this was always pickled, but later this
+        # was extended to support other serializers, so the
+        # structure is kind of weird.
         if meta:
-            return self.decode(meta)
+            meta = self.decode(meta)
+            result = meta["result"]
+            if isinstance(result, (list, tuple)):
+                return {"result": from_serializable(result)}
+            return meta
 
 
 class DisabledBackend(BaseBackend):

+ 35 - 9
celery/result.py

@@ -33,13 +33,21 @@ def _unpickle_result(task_id, task_name):
 
 
 def from_serializable(r):
-    id, nodes = r
-    if nodes:
-        return TaskSetResult(id, map(AsyncResult(nodes)))
-    return AsyncResult(id)
+    # earlier backends may just pickle, so check if
+    # result is already prepared.
+    if not isinstance(r, ResultBase):
+        id, nodes = r
+        if nodes:
+            return TaskSetResult(id, [AsyncResult(id) for id, _ in nodes])
+        return AsyncResult(id)
+    return r
 
 
-class AsyncResult(object):
+class ResultBase(object):
+    """Base class for all results"""
+
+
+class AsyncResult(ResultBase):
     """Query task state.
 
     :param id: see :attr:`id`.
@@ -190,12 +198,12 @@ class AsyncResult(object):
         return hash(self.id)
 
     def __repr__(self):
-        return "<AsyncResult: %s>" % self.id
+        return "<%s: %s>" % (self.__class__.__name__, self.id)
 
     def __eq__(self, other):
-        if isinstance(other, self.__class__):
+        if isinstance(other, AsyncResult):
             return self.id == other.id
-        return other == self.id
+        return NotImplemented
 
     def __copy__(self):
         return self.__class__(self.id, backend=self.backend)
@@ -288,7 +296,7 @@ class AsyncResult(object):
 BaseAsyncResult = AsyncResult  # for backwards compatibility.
 
 
-class ResultSet(object):
+class ResultSet(ResultBase):
     """Working with more than one result.
 
     :param results: List of result instances.
@@ -531,6 +539,15 @@ class ResultSet(object):
     def __len__(self):
         return len(self.results)
 
+    def __eq__(self, other):
+        if isinstance(other, ResultSet):
+            return other.results == self.results
+        return NotImplemented
+
+    def __repr__(self):
+        return "<%s: %r>" % (self.__class__.__name__,
+                             [r.id for r in self.results])
+
     @property
     def total(self):
         """Deprecated: Use ``len(r)``."""
@@ -594,6 +611,15 @@ class TaskSetResult(ResultSet):
     def __reduce__(self):
         return (TaskSetResult, (self.id, self.results))
 
+    def __eq__(self, other):
+        if isinstance(other, TaskSetResult):
+            return other.id == self.id and other.results == self.results
+        return NotImplemented
+
+    def __repr__(self):
+        return "<%s: %s %r>" % (self.__class__.__name__, self.id,
+                                [r.id for r in self.results])
+
     def serializable(self):
         return self.id, [r.serializable() for r in self.results]
 

+ 3 - 4
celery/task/sets.py

@@ -28,7 +28,7 @@ class subtask(AttributeDict):
     """Class that wraps the arguments and execution options
     for a single task invocation.
 
-    Used as the parts in a :class:`TaskSet` or to safely
+    Used as the parts in a :class:`group` or to safely
     pass tasks around as callbacks.
 
     :param task: Either a task class/instance, or the name of a task.
@@ -127,7 +127,6 @@ class subtask(AttributeDict):
 
 
 def maybe_subtask(t):
-    print("SUBTASK: %r" % (subtask, ))
     if not isinstance(t, subtask):
         return subtask(t)
     return t
@@ -172,7 +171,7 @@ class group(UserList):
 
     def apply_async(self, connection=None, connect_timeout=None,
             publisher=None, taskset_id=None):
-        """Apply taskset."""
+        """Apply group."""
         app = self.app
 
         if app.conf.CELERY_ALWAYS_EAGER:
@@ -198,7 +197,7 @@ class group(UserList):
                 for task in self.tasks]
 
     def apply(self, taskset_id=None):
-        """Applies the taskset locally by blocking until all tasks return."""
+        """Applies the group locally by blocking until all tasks return."""
         setid = taskset_id or uuid()
         return self.app.TaskSetResult(setid, self._sync_results(setid))
 

+ 6 - 3
celery/tests/test_backends/test_base.py

@@ -8,7 +8,7 @@ from mock import Mock
 from nose import SkipTest
 
 from celery import current_app
-from celery.result import AsyncResult
+from celery.result import AsyncResult, TaskSetResult
 from celery.utils import serialization
 from celery.utils.serialization import subclass_exception
 from celery.utils.serialization import \
@@ -268,8 +268,11 @@ class test_KeyValueStoreBackend(Case):
 
     def test_save_restore_delete_taskset(self):
         tid = uuid()
-        self.b.save_taskset(tid, "Hello world")
-        self.assertEqual(self.b.restore_taskset(tid), "Hello world")
+        tsr = TaskSetResult(tid, [AsyncResult(uuid()) for _ in range(10)])
+        self.b.save_taskset(tid, tsr)
+        stored = self.b.restore_taskset(tid)
+        print(stored)
+        self.assertEqual(self.b.restore_taskset(tid), tsr)
         self.b.delete_taskset(tid)
         self.assertIsNone(self.b.restore_taskset(tid))