Explorar o código

Windows: Process pool: terminate_process: Also terminate the tasks process children

This is not possible to do in the task process itself on Windows, because
os.kill calls TerminateProcess which can not be handled by any process.

Closes #384.
Miguel Hernandez Martos %!s(int64=14) %!d(string=hai) anos
pai
achega
a5cce0b2b8
Modificáronse 3 ficheiros con 118 adicións e 2 borrados
  1. 1 0
      AUTHORS
  2. 10 2
      celery/concurrency/processes/__init__.py
  3. 107 0
      celery/concurrency/processes/_win.py

+ 1 - 0
AUTHORS

@@ -69,4 +69,5 @@ Ordered by date of first contribution:
   Marcin Kuźmiński <marcin@python-works.com>
   Adriano Petrich <petrich@gmail.com>
   David Strauss <david@davidstrauss.net>
+  Miguel Hernandez Martos <enlavin@gmail.com>
   Jannis Leidel <jannis@leidel.info>

+ 10 - 2
celery/concurrency/processes/__init__.py

@@ -3,12 +3,20 @@
 Process Pools.
 
 """
-import os
+import platform
 import signal as _signal
 
+from os import kill as _kill
+
 from celery.concurrency.base import BasePool
 from celery.concurrency.processes.pool import Pool, RUN
 
+if platform.system() == "Windows":
+    # On Windows os.kill calls TerminateProcess which cannot be
+    # handled by # any process, so this is needed to terminate the task
+    # *and its children* (if any).
+    from celery.concurrency.processes._win import kill_processtree as _kill
+
 
 class TaskPool(BasePool):
     """Process Pool for processing tasks in parallel.
@@ -51,7 +59,7 @@ class TaskPool(BasePool):
             self._pool = None
 
     def terminate_job(self, pid, signal=None):
-        os.kill(pid, signal or _signal.SIGTERM)
+        kill(pid, signal or _signal.SIGTERM)
 
     def grow(self, n=1):
         return self._pool.grow(n)

+ 107 - 0
celery/concurrency/processes/_win.py

@@ -0,0 +1,107 @@
+import os
+
+__all__ = ["get_processtree_pids", "kill_processtree"]
+
+# psutil is painfully slow in win32. So to avoid adding big
+# dependencies like pywin32 a ctypes based solution is preferred
+
+# Code based on the winappdbg project http://winappdbg.sourceforge.net/
+# (BSD License)
+from ctypes import byref, sizeof, windll, Structure, WinError, POINTER
+from ctypes.wintypes import DWORD, c_size_t, LONG, c_char, c_void_p
+
+ERROR_NO_MORE_FILES = 18
+INVALID_HANDLE_VALUE = c_void_p(-1).value
+
+
+class PROCESSENTRY32(Structure):
+    _fields_ = [
+        ('dwSize',              DWORD),
+        ('cntUsage',            DWORD),
+        ('th32ProcessID',       DWORD),
+        ('th32DefaultHeapID',   c_size_t),
+        ('th32ModuleID',        DWORD),
+        ('cntThreads',          DWORD),
+        ('th32ParentProcessID', DWORD),
+        ('pcPriClassBase',      LONG),
+        ('dwFlags',             DWORD),
+        ('szExeFile',           c_char * 260),
+    ]
+LPPROCESSENTRY32 = POINTER(PROCESSENTRY32)
+
+
+def CreateToolhelp32Snapshot(dwFlags=2, th32ProcessID=0):
+    hSnapshot = windll.kernel32.CreateToolhelp32Snapshot(dwFlags,
+                                                         th32ProcessID)
+    if hSnapshot == INVALID_HANDLE_VALUE:
+        raise WinError()
+    return hSnapshot
+
+
+def Process32First(hSnapshot):
+    pe = PROCESSENTRY32()
+    pe.dwSize = sizeof(PROCESSENTRY32)
+    success = windll.kernel32.Process32First(hSnapshot, byref(pe))
+    if not success:
+        if windll.kernel32.GetLastError() == ERROR_NO_MORE_FILES:
+            return None
+        raise WinError()
+    return pe
+
+
+def Process32Next(hSnapshot, pe=None):
+    if pe is None:
+        pe = PROCESSENTRY32()
+    pe.dwSize = sizeof(PROCESSENTRY32)
+    success = windll.kernel32.Process32Next(hSnapshot, byref(pe))
+    if not success:
+        if windll.kernel32.GetLastError() == ERROR_NO_MORE_FILES:
+            return None
+        raise WinError()
+    return pe
+
+
+def get_all_processes_pids():
+    """Return a dictionary with all processes pids as keys and their
+       parents as value. Ignore processes with no parents.
+    """
+    h = CreateToolhelp32Snapshot()
+    parents = {}
+    pe = Process32First(h)
+    while pe:
+        if pe.th32ParentProcessID:
+            parents[pe.th32ProcessID] = pe.th32ParentProcessID
+        pe = Process32Next(h, pe)
+
+    return parents
+
+
+def get_processtree_pids(pid, include_parent=True):
+    """Return a list with all the pids of a process tree"""
+    parents = get_all_processes_pids()
+    all_pids = parents.keys()
+    pids = set([pid])
+    while True:
+        pids_new = pids.copy()
+
+        for _pid in all_pids:
+            if parents[_pid] in pids:
+                pids_new.add(_pid)
+
+        if pids_new == pids:
+            break
+
+        pids = pids_new.copy()
+
+    if not include_parent:
+        pids.remove(pid)
+
+    return list(pids)
+
+
+def kill_processtree(pid, signum):
+    """Kill a process and all its descendants"""
+    family_pids = get_processtree_pids(pid)
+
+    for _pid in family_pids:
+        os.kill(_pid, signum)