Browse Source

[events][state] Adds Task.parent, Task.root and Task.children links to related tasks

Ask Solem 9 years ago
parent
commit
d81f4406e8
2 changed files with 68 additions and 11 deletions
  1. 45 11
      celery/events/state.py
  2. 23 0
      celery/tests/events/test_state.py

+ 45 - 11
celery/events/state.py

@@ -27,7 +27,7 @@ from decimal import Decimal
 from itertools import islice
 from operator import itemgetter
 from time import time
-from weakref import ref
+from weakref import WeakSet, ref
 
 from kombu.clocks import timetuple
 from kombu.utils import cached_property
@@ -216,7 +216,8 @@ class Task(object):
         'succeeded', 'failed', 'retried', 'revoked', 'args', 'kwargs',
         'eta', 'expires', 'retries', 'worker', 'result', 'exception',
         'timestamp', 'runtime', 'traceback', 'exchange', 'routing_key',
-        'clock', 'client', 'root_id', 'parent_id',
+        'clock', 'client', 'root', 'root_id', 'parent', 'parent_id',
+        'children',
     )
     if not PYPY:  # pragma: no cover
         __slots__ = ('__dict__', '__weakref__')
@@ -243,11 +244,12 @@ class Task(object):
         'root_id', 'parent_id',
     )
 
-    def __init__(self, uuid=None, **kwargs):
+    def __init__(self, uuid=None, cluster_state=None, **kwargs):
         self.uuid = uuid
+        self.children = WeakSet()
+        self.cluster_state = cluster_state
         if kwargs:
-            for k, v in items(kwargs):
-                setattr(self, k, v)
+            self.__dict__.update(kwargs)
 
     def event(self, type_, timestamp=None, local_received=None, fields=None,
               precedence=states.precedence, items=items, dict=dict,
@@ -285,13 +287,11 @@ class Task(object):
                 fields = {
                     k: v for k, v in items(fields) if k in keep
                 }
-            for key, value in items(fields):
-                setattr(self, key, value)
+            self.__dict__.update(fields)
         else:
             self.state = state
             self.timestamp = timestamp
-            for key, value in items(fields):
-                setattr(self, key, value)
+            self.__dict__.update(fields)
 
     def info(self, fields=None, extra=[]):
         """Information about this task suitable for on-screen display."""
@@ -317,6 +317,10 @@ class Task(object):
     def __reduce__(self):
         return _depickle_task, (self.__class__, self.as_dict())
 
+    @property
+    def id(self):
+        return self.uuid
+
     @property
     def origin(self):
         return self.client if self.worker is None else self.worker.id
@@ -325,6 +329,14 @@ class Task(object):
     def ready(self):
         return self.state in states.READY_STATES
 
+    @cached_property
+    def parent(self):
+        return self.parent_id and self.cluster_state.tasks[self.parent_id]
+
+    @cached_property
+    def root(self):
+        return self.root_id and self.cluster_state.tasks[self.root_id]
+
 
 class State(object):
     """Records clusters state."""
@@ -351,6 +363,7 @@ class State(object):
         self._mutex = threading.Lock()
         self.handlers = {}
         self._seen_types = set()
+        self._tasks_to_resolve = {}
         self.rebuild_taskheap()
 
     @cached_property
@@ -412,7 +425,7 @@ class State(object):
         try:
             return self.tasks[uuid], False
         except KeyError:
-            task = self.tasks[uuid] = self.Task(uuid)
+            task = self.tasks[uuid] = self.Task(uuid, cluster_state=self)
             return task, True
 
     def event(self, event):
@@ -491,7 +504,7 @@ class State(object):
                 try:
                     task, created = get_task(uuid), False
                 except KeyError:
-                    task = tasks[uuid] = Task(uuid)
+                    task = tasks[uuid] = Task(uuid, cluster_state=self)
                 if is_client_event:
                     task.client = hostname
                 else:
@@ -523,9 +536,30 @@ class State(object):
                 task_name = task.name
                 if task_name is not None:
                     add_type(task_name)
+                if task.parent_id:
+                    try:
+                        parent_task = self.tasks[task.parent_id]
+                    except KeyError:
+                        self._add_pending_task_child(task)
+                    else:
+                        parent_task.children.add(task)
+                try:
+                    _children = self._tasks_to_resolve.pop(uuid)
+                except KeyError:
+                    pass
+                else:
+                    task.children.update(_children)
+
                 return (task, created), subject
         return _event
 
+    def _add_pending_task_child(self, task):
+        try:
+            ch = self._tasks_to_resolve[task.parent_id]
+        except KeyError:
+            ch = self._tasks_to_resolve[task.parent_id] = WeakSet()
+        ch.add(task)
+
     def rebuild_taskheap(self, timetuple=timetuple):
         heap = self._taskheap[:] = [
             timetuple(t.clock, t.timestamp, t.origin, ref(t))

+ 23 - 0
celery/tests/events/test_state.py

@@ -93,6 +93,7 @@ class ev_task_states(replay):
 
     def setup(self):
         tid = self.tid = uuid()
+        tid2 = self.tid2 = uuid()
         self.events = [
             Event('task-received', uuid=tid, name='task1',
                   args='(2, 2)', kwargs="{'foo': 'bar'}",
@@ -106,6 +107,11 @@ class ev_task_states(replay):
             Event('task-succeeded', uuid=tid, result='4',
                   runtime=0.1234, hostname='utest1'),
             Event('foo-bar'),
+
+            Event('task-received', uuid=tid2, name='task2',
+                  args='(4, 4)', kwargs="{'foo': 'bar'}",
+                  retries=0, eta=None, parent_id=tid, root_id=tid,
+                  hostname='utest1'),
         ]
 
 
@@ -499,6 +505,23 @@ class test_State(AppCase):
         self.assertEqual(task.result, '4')
         self.assertEqual(task.runtime, 0.1234)
 
+        # children, parent, root
+        r.play()
+        self.assertIn(r.tid2, r.state.tasks)
+        task2 = r.state.tasks[r.tid2]
+
+        self.assertIs(task2.parent, task)
+        self.assertIs(task2.root, task)
+        self.assertIn(task2, task.children)
+
+    def test_task_children_set_if_received_in_wrong_order(self):
+        r = ev_task_states(State())
+        r.events.insert(0, r.events.pop())
+        r.play()
+        self.assertIn(r.state.tasks[r.tid2], r.state.tasks[r.tid].children)
+        self.assertIs(r.state.tasks[r.tid2].root, r.state.tasks[r.tid])
+        self.assertIs(r.state.tasks[r.tid2].parent, r.state.tasks[r.tid])
+
     def assertStateEmpty(self, state):
         self.assertFalse(state.tasks)
         self.assertFalse(state.workers)