Selaa lähdekoodia

Worker: Control: O(1) lookup of tasks by id for query_tasks, etc.

Ask Solem 8 vuotta sitten
vanhempi
commit
90c6a684ad
2 muutettua tiedostoa jossa 25 lisäystä ja 28 poistoa
  1. 3 3
      celery/tests/worker/test_control.py
  2. 22 25
      celery/worker/control.py

+ 3 - 3
celery/tests/worker/test_control.py

@@ -489,7 +489,7 @@ class test_ControlPanel(AppCase):
         request.id = tid = uuid()
         state = self.create_state()
         state.consumer = Mock()
-        worker_state.reserved_requests.add(request)
+        worker_state.task_reserved(request)
         try:
             r = control.revoke(state, tid, terminate=True)
             self.assertIn(tid, revoked)
@@ -499,7 +499,7 @@ class test_ControlPanel(AppCase):
             r = control.revoke(state, uuid(), terminate=True)
             self.assertIn('tasks unknown', r['ok'])
         finally:
-            worker_state.reserved_requests.discard(request)
+            worker_state.task_ready(request)
 
     def test_autoscale(self):
         self.panel.state.consumer = Mock()
@@ -645,7 +645,7 @@ class test_ControlPanel(AppCase):
             TaskMessage(self.mytask.name, args=(2, 2)),
             app=self.app,
         )
-        worker_state.reserved_requests.add(req1)
+        worker_state.task_reserved(req1)
         try:
             self.assertFalse(panel.handle('query_task', {'ids': {'1daa'}}))
             ret = panel.handle('query_task', {'ids': {req1.id}})

+ 22 - 25
celery/worker/control.py

@@ -48,28 +48,31 @@ class Panel(UserDict):
         return method
 
 
-def _find_requests_by_id(ids, requests):
-    found, total = 0, len(ids)
-    for request in requests:
-        if request.id in ids:
-            yield request
-            found += 1
-            if found >= total:
-                break
+def _find_requests_by_id(ids,
+                         get_request=worker_state.requests.__getitem__):
+    for task_id in ids:
+        try:
+            yield get_request(task_id)
+        except KeyError:
+            pass
+
+
+def _state_of_task(request,
+                   is_active=worker_state.active_requests.__contains__,
+                   is_reserved=worker_state.reserved_requests.__contains__):
+    if is_active(request):
+        return 'active'
+    elif is_reserved(request):
+        return 'reserved'
+    return 'ready'
 
 
 @Panel.register
 def query_task(state, ids, **kwargs):
-    ids = maybe_list(ids)
-    return dict({
-        req.id: ('reserved', req.info())
-        for req in _find_requests_by_id(
-            ids, state.tset(worker_state.reserved_requests))
-    }, **{
-        req.id: ('active', req.info())
-        for req in _find_requests_by_id(
-            ids, state.tset(worker_state.active_requests))
-    })
+    return {
+        req.id: (_state_of_task(req), req.info())
+        for req in _find_requests_by_id(maybe_list(ids))
+    }
 
 
 @Panel.register
@@ -83,13 +86,7 @@ def revoke(state, task_id, terminate=False, signal=None, **kwargs):
     revoked.update(task_ids)
     if terminate:
         signum = _signals.signum(signal or TERM_SIGNAME)
-        # reserved_requests changes size during iteration
-        # so need to consume the items first, then terminate after.
-        requests = set(_find_requests_by_id(
-            task_ids,
-            state.tset(worker_state.reserved_requests),
-        ))
-        for request in requests:
+        for request in _find_requests_by_id(task_ids):
             if request.id not in terminated:
                 terminated.add(request.id)
                 logger.info('Terminating %s (%s)', request.id, signum)