|
@@ -35,7 +35,7 @@ from kombu.utils import cached_property
|
|
|
|
|
|
from celery import states
|
|
|
from celery.five import items, python_2_unicode_compatible, values
|
|
|
-from celery.utils.functional import LRUCache, memoize
|
|
|
+from celery.utils.functional import LRUCache, memoize, pass1
|
|
|
from celery.utils.log import get_logger
|
|
|
|
|
|
__all__ = ['Worker', 'Task', 'State', 'heartbeat_expires']
|
|
@@ -285,10 +285,17 @@ class Task(object):
|
|
|
'root_id', 'parent_id',
|
|
|
)
|
|
|
|
|
|
- def __init__(self, uuid=None, cluster_state=None, **kwargs):
|
|
|
+ def __init__(self, uuid=None, cluster_state=None, children=None, **kwargs):
|
|
|
self.uuid = uuid
|
|
|
- self.children = WeakSet()
|
|
|
self.cluster_state = cluster_state
|
|
|
+ self.children = WeakSet(
|
|
|
+ self.cluster_state.tasks.get(task_id)
|
|
|
+ for task_id in children or ()
|
|
|
+ if task_id in self.cluster_state.tasks
|
|
|
+ )
|
|
|
+ self._serializer_handlers = {
|
|
|
+ 'children': self._serializable_children,
|
|
|
+ }
|
|
|
if kwargs:
|
|
|
self.__dict__.update(kwargs)
|
|
|
|
|
@@ -339,10 +346,14 @@ class Task(object):
|
|
|
|
|
|
def as_dict(self):
|
|
|
get = object.__getattribute__
|
|
|
+ handler = self._serializer_handlers.get
|
|
|
return {
|
|
|
- k: get(self, k) for k in self._fields
|
|
|
+ k: handler(k, pass1)(get(self, k)) for k in self._fields
|
|
|
}
|
|
|
|
|
|
+ def _serializable_children(self, value):
|
|
|
+ return [task.id for task in self.children]
|
|
|
+
|
|
|
def __reduce__(self):
|
|
|
return _depickle_task, (self.__class__, self.as_dict())
|
|
|
|
|
@@ -378,7 +389,8 @@ class State(object):
|
|
|
def __init__(self, callback=None,
|
|
|
workers=None, tasks=None, taskheap=None,
|
|
|
max_workers_in_memory=5000, max_tasks_in_memory=10000,
|
|
|
- on_node_join=None, on_node_leave=None):
|
|
|
+ on_node_join=None, on_node_leave=None,
|
|
|
+ tasks_by_type=None, tasks_by_worker=None):
|
|
|
self.event_callback = callback
|
|
|
self.workers = (LRUCache(max_workers_in_memory)
|
|
|
if workers is None else workers)
|
|
@@ -394,10 +406,18 @@ class State(object):
|
|
|
self._seen_types = set()
|
|
|
self._tasks_to_resolve = {}
|
|
|
self.rebuild_taskheap()
|
|
|
+
|
|
|
+ # type: Mapping[TaskName, WeakSet[Task]]
|
|
|
self.tasks_by_type = CallableDefaultdict(
|
|
|
self._tasks_by_type, WeakSet)
|
|
|
+ self.tasks_by_type.update(
|
|
|
+ _deserialize_Task_WeakSet_Mapping(tasks_by_type, self.tasks))
|
|
|
+
|
|
|
+ # type: Mapping[Hostname, WeakSet[Task]]
|
|
|
self.tasks_by_worker = CallableDefaultdict(
|
|
|
self._tasks_by_worker, WeakSet)
|
|
|
+ self.tasks_by_worker.update(
|
|
|
+ _deserialize_Task_WeakSet_Mapping(tasks_by_worker, self.tasks))
|
|
|
|
|
|
@cached_property
|
|
|
def _event(self):
|
|
@@ -674,4 +694,15 @@ class State(object):
|
|
|
self.event_callback, self.workers, self.tasks, None,
|
|
|
self.max_workers_in_memory, self.max_tasks_in_memory,
|
|
|
self.on_node_join, self.on_node_leave,
|
|
|
+ _serialize_Task_WeakSet_Mapping(self.tasks_by_type),
|
|
|
+ _serialize_Task_WeakSet_Mapping(self.tasks_by_worker),
|
|
|
)
|
|
|
+
|
|
|
+
|
|
|
+def _serialize_Task_WeakSet_Mapping(mapping):
|
|
|
+ return {name: [t.id for t in tasks] for name, tasks in items(mapping)}
|
|
|
+
|
|
|
+
|
|
|
+def _deserialize_Task_WeakSet_Mapping(mapping, tasks):
|
|
|
+ return {name: WeakSet(tasks[i] for i in ids if i in tasks)
|
|
|
+ for name, ids in items(mapping or {})}
|