Procházet zdrojové kódy

Events: State: tasks_by_type, tasks_by_worker can now be used as dict index

Ask Solem před 9 lety
rodič
revize
48871fe571
Změněny 2 soubory: 56 přidání a 6 odebrání
  1. 50 6
      celery/events/state.py
  2. 6 0
      celery/tests/events/test_state.py

+ 50 - 6
celery/events/state.py

@@ -22,6 +22,7 @@ import bisect
 import sys
 import threading
 
+from collections import Callable, defaultdict
 from datetime import datetime
 from decimal import Decimal
 from itertools import islice
@@ -75,6 +76,32 @@ TASK_EVENT_TO_STATE = {
 }
 
 
+class CallableDefaultdict(defaultdict):
+    """:class:`~collections.defaultdict` with configurable __call__.
+
+    Instances behave like a regular ``defaultdict`` for indexing, and
+    additionally forward any call made on the instance to the function
+    given as the first constructor argument.
+
+    We use this for backwards compatibility in State.tasks_by_type
+    etc, which used to be a method but is now an index instead.
+
+    So you can do::
+
+        >>> add_tasks = state.tasks_by_type['proj.tasks.add']
+
+    while still supporting the method call::
+
+        >>> add_tasks = list(state.tasks_by_type(
+        ...     'proj.tasks.add', reverse=True))
+
+    """
+
+    def __init__(self, fun, *args, **kwargs):
+        """Remember ``fun`` and initialize the underlying defaultdict.
+
+        ``fun`` is the callable invoked by :meth:`__call__`; it is kept
+        on the instance as ``self.fun``.  All remaining positional and
+        keyword arguments are passed on unchanged to ``defaultdict``
+        (so the first of them, if given, is the default factory).
+        """
+        self.fun = fun
+        super(CallableDefaultdict, self).__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        """Call ``self.fun`` with the given arguments, return its result."""
+        return self.fun(*args, **kwargs)
+# Register the class with the ``Callable`` abstract base class.
+Callable.register(CallableDefaultdict)
+
+
 @memoize(maxsize=1000, keyfun=lambda a, _: a[0])
 def _warn_drift(hostname, drift, local_received, timestamp):
     # we use memoize here so the warning is only logged once per hostname
@@ -367,6 +394,10 @@ class State(object):
         self._seen_types = set()
         self._tasks_to_resolve = {}
         self.rebuild_taskheap()
+        self.tasks_by_type = CallableDefaultdict(
+            self._tasks_by_type, WeakSet)
+        self.tasks_by_worker = CallableDefaultdict(
+            self._tasks_by_worker, WeakSet)
 
     @cached_property
     def _event(self):
@@ -463,6 +494,9 @@ class State(object):
         # avoid updating LRU entry at getitem
         get_worker, get_task = workers.data.__getitem__, tasks.data.__getitem__
 
+        get_task_by_type_set = self.tasks_by_type.__getitem__
+        get_task_by_worker_set = self.tasks_by_worker.__getitem__
+
         def _event(event,
                    timetuple=timetuple, KeyError=KeyError,
                    insort=bisect.insort, created=True):
@@ -504,14 +538,15 @@ class State(object):
                 # task-sent event is sent by client, not worker
                 is_client_event = subject == 'sent'
                 try:
-                    task, created = get_task(uuid), False
+                    task, task_created = get_task(uuid), False
                 except KeyError:
                     task = tasks[uuid] = Task(uuid, cluster_state=self)
+                    task_created = True
                 if is_client_event:
                     task.client = hostname
                 else:
                     try:
-                        worker, created = get_worker(hostname), False
+                        worker = get_worker(hostname)
                     except KeyError:
                         worker = workers[hostname] = Worker(hostname)
                     task.worker = worker
@@ -538,6 +573,9 @@ class State(object):
                 task_name = task.name
                 if task_name is not None:
                     add_type(task_name)
+                    if task_created:  # add to tasks_by_type index
+                        get_task_by_type_set(task_name).add(task)
+                        get_task_by_worker_set(hostname).add(task)
                 if task.parent_id:
                     try:
                         parent_task = self.tasks[task.parent_id]
@@ -552,7 +590,7 @@ class State(object):
                 else:
                     task.children.update(_children)
 
-                return (task, created), subject
+                return (task, task_created), subject
         return _event
 
     def _add_pending_task_child(self, task):
@@ -592,10 +630,13 @@ class State(object):
                     seen.add(uuid)
     tasks_by_timestamp = tasks_by_time
 
-    def tasks_by_type(self, name, limit=None, reverse=True):
+    def _tasks_by_type(self, name, limit=None, reverse=True):
         """Get all tasks by type.
 
-        Return a list of ``(uuid, Task)`` tuples.
+        This is slower than accessing :attr:`tasks_by_type`,
+        but will be ordered by time.
+
+        Return a generator giving ``(uuid, Task)`` tuples.
 
         """
         return islice(
@@ -604,9 +645,12 @@ class State(object):
             0, limit,
         )
 
-    def tasks_by_worker(self, hostname, limit=None, reverse=True):
+    def _tasks_by_worker(self, hostname, limit=None, reverse=True):
         """Get all tasks by worker.
 
+        This is slower than accessing :attr:`tasks_by_worker`,
+        but will be ordered by time.
+
         """
         return islice(
             ((uuid, task) for uuid, task in self.tasks_by_time(reverse=reverse)

+ 6 - 0
celery/tests/events/test_state.py

@@ -593,6 +593,9 @@ class test_State(AppCase):
         self.assertEqual(len(list(r.state.tasks_by_type('task1'))), 10)
         self.assertEqual(len(list(r.state.tasks_by_type('task2'))), 10)
 
+        self.assertEqual(len(r.state.tasks_by_type['task1']), 10)
+        self.assertEqual(len(r.state.tasks_by_type['task2']), 10)
+
     def test_alive_workers(self):
         r = ev_snapshot(State())
         r.play()
@@ -604,6 +607,9 @@ class test_State(AppCase):
         self.assertEqual(len(list(r.state.tasks_by_worker('utest1'))), 10)
         self.assertEqual(len(list(r.state.tasks_by_worker('utest2'))), 10)
 
+        self.assertEqual(len(r.state.tasks_by_worker['utest1']), 10)
+        self.assertEqual(len(r.state.tasks_by_worker['utest2']), 10)
+
     def test_survives_unknown_worker_event(self):
         s = State()
         s.event({