Ver Fonte

Events: State: Fixes pickleability (cc @mher)

Ask Solem há 8 anos atrás
pai
commit
f4ba39891f
2 ficheiros alterados com 40 adições e 6 exclusões
  1. 36 5
      celery/events/state.py
  2. 4 1
      celery/tests/events/test_state.py

+ 36 - 5
celery/events/state.py

@@ -35,7 +35,7 @@ from kombu.utils import cached_property
 
 from celery import states
 from celery.five import items, python_2_unicode_compatible, values
-from celery.utils.functional import LRUCache, memoize
+from celery.utils.functional import LRUCache, memoize, pass1
 from celery.utils.log import get_logger
 
 __all__ = ['Worker', 'Task', 'State', 'heartbeat_expires']
@@ -285,10 +285,17 @@ class Task(object):
         'root_id', 'parent_id',
     )
 
-    def __init__(self, uuid=None, cluster_state=None, **kwargs):
+    def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
         self.uuid = uuid
-        self.children = WeakSet()
         self.cluster_state = cluster_state
+        self.children = WeakSet(
+            self.cluster_state.tasks.get(task_id)
+            for task_id in children or ()
+            if task_id in self.cluster_state.tasks
+        )
+        self._serializer_handlers = {
+            'children': self._serializable_children,
+        }
         if kwargs:
             self.__dict__.update(kwargs)
 
@@ -339,10 +346,14 @@ class Task(object):
 
     def as_dict(self):
         get = object.__getattribute__
+        handler = self._serializer_handlers.get
         return {
-            k: get(self, k) for k in self._fields
+            k: handler(k, pass1)(get(self, k)) for k in self._fields
         }
 
+    def _serializable_children(self, value):
+        return [task.id for task in self.children]
+
     def __reduce__(self):
         return _depickle_task, (self.__class__, self.as_dict())
 
@@ -378,7 +389,8 @@ class State(object):
     def __init__(self, callback=None,
                  workers=None, tasks=None, taskheap=None,
                  max_workers_in_memory=5000, max_tasks_in_memory=10000,
-                 on_node_join=None, on_node_leave=None):
+                 on_node_join=None, on_node_leave=None,
+                 tasks_by_type=None, tasks_by_worker=None):
         self.event_callback = callback
         self.workers = (LRUCache(max_workers_in_memory)
                         if workers is None else workers)
@@ -394,10 +406,18 @@ class State(object):
         self._seen_types = set()
         self._tasks_to_resolve = {}
         self.rebuild_taskheap()
+
+        # type: Mapping[TaskName, WeakSet[Task]]
         self.tasks_by_type = CallableDefaultdict(
             self._tasks_by_type, WeakSet)
+        self.tasks_by_type.update(
+            _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
+
+        # type: Mapping[Hostname, WeakSet[Task]]
         self.tasks_by_worker = CallableDefaultdict(
             self._tasks_by_worker, WeakSet)
+        self.tasks_by_worker.update(
+            _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
 
     @cached_property
     def _event(self):
@@ -674,4 +694,15 @@ class State(object):
             self.event_callback, self.workers, self.tasks, None,
             self.max_workers_in_memory, self.max_tasks_in_memory,
             self.on_node_join, self.on_node_leave,
+            _serialize_Task_WeakSet_Mapping(self.tasks_by_type),
+            _serialize_Task_WeakSet_Mapping(self.tasks_by_worker),
         )
+
+
+def _serialize_Task_WeakSet_Mapping(mapping):
+    return {name: [t.id for t in tasks] for name, tasks in items(mapping)}
+
+
+def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
+    return {name: WeakSet(tasks[i] for i in ids if i in tasks)
+            for name, ids in items(mapping or {})}

+ 4 - 1
celery/tests/events/test_state.py

@@ -353,7 +353,10 @@ class test_State(AppCase):
         self.assertTrue(repr(State()))
 
     def test_pickleable(self):
-        self.assertTrue(pickle.loads(pickle.dumps(State())))
+        state = State()
+        r = ev_logical_clock_ordering(state)
+        r.play()
+        self.assertTrue(pickle.loads(pickle.dumps(state)))
 
     def test_task_logical_clock_ordering(self):
         state = State()