Browse Source

Smarter pickling of AsyncResult's

Ask Solem 14 years ago
parent
commit
40ba910ea0
4 changed files with 35 additions and 11 deletions
  1. 2 1
      celery/app/base.py
  2. 4 0
      celery/registry.py
  3. 26 4
      celery/result.py
  4. 3 6
      celery/task/base.py

+ 2 - 1
celery/app/base.py

@@ -116,10 +116,11 @@ class BaseApp(object):
         return self.with_default_connection(_do_publish)(
                 connection=connection, connect_timeout=connect_timeout)
 
-    def AsyncResult(self, task_id, backend=None):
+    def AsyncResult(self, task_id, backend=None, task_name=None):
         """Create :class:`celery.result.BaseAsyncResult` instance."""
         from celery.result import BaseAsyncResult
         return BaseAsyncResult(task_id, app=self,
+                               task_name=task_name,
                                backend=backend or self.backend)
 
     def TaskSetResult(self, taskset_id, results, **kwargs):

+ 4 - 0
celery/registry.py

@@ -76,3 +76,7 @@ class TaskRegistry(UserDict):
 
 """
 tasks = TaskRegistry()
+
+
+def _unpickle_task(name):
+    return tasks[name]

+ 26 - 4
celery/result.py

@@ -9,9 +9,14 @@ from celery import states
 from celery.app import app_or_default
 from celery.datastructures import PositionQueue
 from celery.exceptions import TimeoutError
+from celery.registry import _unpickle_task
 from celery.utils.compat import any, all
 
 
+def _unpickle_result(task_id, task_name):
+    return _unpickle_task(task_name).AsyncResult(task_id)
+
+
 class BaseAsyncResult(object):
     """Base class for pending result, supports custom task result backend.
 
@@ -29,11 +34,19 @@ class BaseAsyncResult(object):
     #: The task result backend to use.
     backend = None
 
-    def __init__(self, task_id, backend, app=None):
+    def __init__(self, task_id, backend, task_name=None, app=None):
         self.task_id = task_id
         self.backend = backend
+        self.task_name = task_name
         self.app = app_or_default(app)
 
+    def __reduce__(self):
+        if self.task_name:
+            return (_unpickle_result, (self.task_id, self.task_name))
+        else:
+            return (self.__class__, (self.task_id, self.backend,
+                                     None, self.app))
+
     def forget(self):
         """Forget about (and possibly remove the result of) this task."""
         self.backend.forget(self.task_id)
@@ -172,10 +185,12 @@ class AsyncResult(BaseAsyncResult):
     #: Task result store backend to use.
     backend = None
 
-    def __init__(self, task_id, backend=None, app=None):
+    def __init__(self, task_id, backend=None, task_name=None, app=None):
         app = app_or_default(app)
         backend = backend or app.backend
-        super(AsyncResult, self).__init__(task_id, backend, app=app)
+        super(AsyncResult, self).__init__(task_id, backend,
+                                          task_name=task_name, app=app)
+
 
 
 class TaskSetResult(object):
@@ -201,6 +216,13 @@ class TaskSetResult(object):
         self.subtasks = subtasks
         self.app = app_or_default(app)
 
+    def __reduce__(self):
+        return (self.__class__, (self.taskset_id, self.subtasks))
+
+    @classmethod
+    def _unpickle(cls, taskset_id, subtasks):
+        return cls(taskset_id, subtasks)
+
     def itersubtasks(self):
         """Taskset subtask iterator.
 
@@ -357,7 +379,7 @@ class TaskSetResult(object):
     def restore(self, taskset_id, backend=None):
         """Restore previously saved taskset result."""
         if backend is None:
-            backend = self.app.backend
+            backend = app_or_default().backend
         return backend.restore_taskset(taskset_id)
 
     @property

+ 3 - 6
celery/task/base.py

@@ -6,7 +6,7 @@ from celery.app import app_or_default
 from celery.datastructures import ExceptionInfo
 from celery.exceptions import MaxRetriesExceededError, RetryTaskError
 from celery.execute.trace import TaskTrace
-from celery.registry import tasks
+from celery.registry import tasks, _unpickle_task
 from celery.result import EagerResult
 from celery.schedules import maybe_schedule
 from celery.utils import mattrgetter, gen_unique_id, fun_takes_kwargs
@@ -44,10 +44,6 @@ _default_context = {"logfile": None,
                     "delivery_info": None}
 
 
-def _unpickle_task(name):
-    return tasks[name]
-
-
 class Context(threading.local):
 
     def update(self, d, **kwargs):
@@ -609,7 +605,8 @@ class BaseTask(object):
         :param task_id: Task id to get result for.
 
         """
-        return self.app.AsyncResult(task_id, backend=self.backend)
+        return self.app.AsyncResult(task_id, backend=self.backend,
+                                             task_name=self.name)
 
     def update_state(self, task_id=None, state=None, meta=None):
         """Update task state.