Browse Source

Worker: inspect active and friends must copy active_requests when using threads. Closes #2567

Ask Solem 9 years ago
parent
commit
2d6d660ee6
3 changed files with 21 additions and 8 deletions
  1. 5 2
      celery/tests/worker/test_control.py
  2. 10 5
      celery/worker/control.py
  3. 6 1
      celery/worker/pidbox.py

+ 5 - 2
celery/tests/worker/test_control.py

@@ -126,6 +126,7 @@ class test_ControlPanel(AppCase):
     def create_state(self, **kwargs):
         kwargs.setdefault('app', self.app)
         kwargs.setdefault('hostname', hostname)
+        kwargs.setdefault('tset', set)
         return AttributeDict(kwargs)
 
     def create_panel(self, **kwargs):
@@ -481,14 +482,16 @@ class test_ControlPanel(AppCase):
     def test_revoke_terminate(self):
         request = Mock()
         request.id = tid = uuid()
+        state = self.create_state()
+        state.consumer = Mock()
         worker_state.reserved_requests.add(request)
         try:
-            r = control.revoke(Mock(), tid, terminate=True)
+            r = control.revoke(state, tid, terminate=True)
             self.assertIn(tid, revoked)
             self.assertTrue(request.terminate.call_count)
             self.assertIn('terminate:', r['ok'])
             # unknown task id only revokes
-            r = control.revoke(Mock(), uuid(), terminate=True)
+            r = control.revoke(state, uuid(), terminate=True)
             self.assertIn('tasks unknown', r['ok'])
         finally:
             worker_state.reserved_requests.discard(request)

+ 10 - 5
celery/worker/control.py

@@ -54,10 +54,12 @@ def query_task(state, ids, **kwargs):
     ids = maybe_list(ids)
     return dict({
         req.id: ('reserved', req.info())
-        for req in _find_requests_by_id(ids, worker_state.reserved_requests)
+        for req in _find_requests_by_id(
+            ids, state.tset(worker_state.reserved_requests))
     }, **{
         req.id: ('active', req.info())
-        for req in _find_requests_by_id(ids, worker_state.active_requests)
+        for req in _find_requests_by_id(
+            ids, state.tset(worker_state.active_requests))
     })
 
 
@@ -76,7 +78,7 @@ def revoke(state, task_id, terminate=False, signal=None, **kwargs):
         # so need to consume the items first, then terminate after.
         requests = set(_find_requests_by_id(
             task_ids,
-            worker_state.reserved_requests,
+            state.tset(worker_state.reserved_requests),
         ))
         for request in requests:
             if request.id not in terminated:
@@ -197,7 +199,10 @@ def dump_schedule(state, safe=False, **kwargs):
 
 @Panel.register
 def dump_reserved(state, safe=False, **kwargs):
-    reserved = worker_state.reserved_requests - worker_state.active_requests
+    reserved = (
+        state.tset(worker_state.reserved_requests) -
+        state.tset(worker_state.active_requests)
+    )
     if not reserved:
         return []
     return [request.info(safe=safe) for request in reserved]
@@ -206,7 +211,7 @@ def dump_reserved(state, safe=False, **kwargs):
 @Panel.register
 def dump_active(state, safe=False, **kwargs):
     return [request.info(safe=safe)
-            for request in worker_state.active_requests]
+            for request in state.tset(worker_state.active_requests)]
 
 
 @Panel.register

+ 6 - 1
celery/worker/pidbox.py

@@ -7,6 +7,7 @@ from kombu.common import ignore_errors
 from kombu.utils.encoding import safe_str
 
 from celery.datastructures import AttributeDict
+from celery.utils.functional import pass1
 from celery.utils.log import get_logger
 
 from . import control
@@ -26,7 +27,11 @@ class Pidbox(object):
         self.node = c.app.control.mailbox.Node(
             safe_str(c.hostname),
             handlers=control.Panel.data,
-            state=AttributeDict(app=c.app, hostname=c.hostname, consumer=c),
+            state=AttributeDict(
+                app=c.app,
+                hostname=c.hostname,
+                consumer=c,
+                tset=pass1 if c.controller.use_eventloop else set),
         )
         self._forward_clock = self.c.app.clock.forward