Browse Source

Can now remotely terminate the worker process executing a task.

A `terminate` argument has been added to the revoke remote control command.
Default signal is TERM, but can be changed using the `signal` argument.

Terminating a task also revokes it.

Example::

    >>> from celery.task import control

    >>> control.revoke(task_id, terminate=True)
    >>> control.revoke(task_id, terminate=True, signal="KILL")
    >>> control.revoke(task_id, terminate=True, signal="SIGKILL")

Signal can be the uppercase name of any signal defined in :mod:`signal`
(Python Standard Library).
Ask Solem 14 years ago
parent
commit
3b28b859db

+ 4 - 0
celery/concurrency/base.py

@@ -41,6 +41,10 @@ class BasePool(object):
     def on_apply(self, *args, **kwargs):
         pass
 
+    def terminate_job(self, pid, signal=None):
+        """Terminate the pool process with process id `pid`.
+
+        This is the method ``TaskRequest.terminate`` calls on the pool,
+        so the base class exposes it to make pools without support fail
+        with a clear error.
+
+        :param pid: Process id of the pool process to signal.
+        :keyword signal: Signal to send (implementation defined default).
+        :raises NotImplementedError: always, in this base implementation.
+
+        """
+        raise NotImplementedError(
+                "%s does not implement terminate_job" % (self.__class__, ))
+
+    def kill_job(self, pid):
+        """Alias of :meth:`terminate_job` called without an explicit signal."""
+        return self.terminate_job(pid)
+
     def stop(self):
         self._state = self.CLOSE
         self.on_stop()

+ 6 - 0
celery/concurrency/processes/__init__.py

@@ -3,6 +3,9 @@
 Process Pools.
 
 """
+import os
+import signal as _signal
+
 from celery.concurrency.base import BasePool
 from celery.concurrency.processes.pool import Pool, RUN
 
@@ -47,6 +50,9 @@ class TaskPool(BasePool):
             self._pool.terminate()
             self._pool = None
 
+    def terminate_job(self, pid, signal=None):
+        """Send `signal` to the process `pid`, SIGTERM if none is given."""
+        signum = signal if signal else _signal.SIGTERM
+        os.kill(pid, signum)
+
     def grow(self, n=1):
         return self._pool.grow(n)
 

+ 9 - 0
celery/platforms.py

@@ -274,6 +274,15 @@ def set_effective_user(uid=None, gid=None):
         gid and setegid(gid)
 
 
+def get_signal(signal_name):
+    """Return the signal number for the signal called `signal_name`.
+
+    The name must be an uppercase string and may be given with or
+    without the ``SIG`` prefix (``"TERM"`` and ``"SIGTERM"`` are the same).
+
+    :raises TypeError: if `signal_name` is not an uppercase string.
+
+    """
+    is_upper_string = (isinstance(signal_name, basestring)
+                       and signal_name.isupper())
+    if not is_upper_string:
+        raise TypeError("signal name must be uppercase string.")
+    prefix = "" if signal_name.startswith("SIG") else "SIG"
+    return getattr(signal, prefix + signal_name)
+
+
 def reset_signal(signal_name):
     """Reset signal to the default signal handler.
 

+ 9 - 2
celery/task/control.py

@@ -105,13 +105,18 @@ class Control(object):
         return self.app.with_default_connection(_do_discard)(
                 connection=connection, connect_timeout=connect_timeout)
 
-    def revoke(self, task_id, destination=None, **kwargs):
+    def revoke(self, task_id, destination=None, terminate=False,
+            signal="SIGTERM", **kwargs):
         """Revoke a task by id.
 
         If a task is revoked, the workers will ignore the task and
         not execute it after all.
 
         :param task_id: Id of the task to revoke.
+        :keyword terminate: Also terminate the process currently working
+            on the task (if any).
+        :keyword signal: Name of signal to send to process if terminate.
+            Default is TERM.
         :keyword destination: If set, a list of the hosts to send the
             command to, when empty broadcast to all workers.
         :keyword connection: Custom broker connection to use, if not set,
@@ -124,7 +129,9 @@ class Control(object):
 
         """
         return self.broadcast("revoke", destination=destination,
-                              arguments={"task_id": task_id}, **kwargs)
+                              arguments={"task_id": task_id,
+                                         "terminate": terminate,
+                                         "signal": signal}, **kwargs)
 
     def ping(self, destination=None, timeout=1, **kwargs):
         """Ping workers.

+ 14 - 3
celery/worker/control/builtins.py

@@ -1,6 +1,8 @@
 import sys
+
 from datetime import datetime
 
+from celery.platforms import get_signal
 from celery.registry import tasks
 from celery.utils import timeutils
 from celery.worker import state
@@ -11,11 +13,20 @@ TASK_INFO_FIELDS = ("exchange", "routing_key", "rate_limit")
 
 
 @Panel.register
-def revoke(panel, task_id, **kwargs):
+def revoke(panel, task_id, terminate=False, signal=None, **kwargs):
     """Revoke task by task id."""
     revoked.add(task_id)
-    panel.logger.warn("Task %s revoked" % (task_id, ))
-    return {"ok": "task %s revoked" % (task_id, )}
+    # The task is always added to the revoked set; with `terminate` the
+    # matching active request (if any) is also asked to terminate itself.
+    action = "revoked"
+    if terminate:
+        # `signal` is None when the command carries no signal name;
+        # fall back to TERM instead of letting get_signal() raise TypeError.
+        signum = get_signal(signal or "TERM")
+        for request in state.active_requests:
+            if request.task_id == task_id:
+                action = "terminated (%s)" % (signum, )
+                request.terminate(panel.consumer.pool, signal=signum)
+                break
+
+    panel.logger.warn("Task %s %s." % (task_id, action))
+    return {"ok": "task %s %s" % (task_id, action)}
 
 
 @Panel.register

+ 16 - 0
celery/worker/job.py

@@ -230,7 +230,11 @@ class TaskRequest(object):
     #: Timestamp set when the task is started.
     time_start = None
 
+    #: Process id of the worker processing this task (if any).
+    worker_pid = None
+
     _already_revoked = False
+    _terminate_on_ack = None
 
     def __init__(self, task_name, task_id, args, kwargs,
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
@@ -379,6 +383,14 @@ class TaskRequest(object):
             if self._store_errors:
                 self.task.backend.mark_as_revoked(self.task_id)
 
+    def terminate(self, pool, signal=None):
+        """Terminate the pool process working on this task.
+
+        If the task has been accepted (``time_start`` is set) the job is
+        terminated right away through ``pool.terminate_job``.  Otherwise
+        the request is stored in ``_terminate_on_ack`` so that
+        :meth:`on_accepted` can call this method again once the process
+        id is known.
+
+        :param pool: Pool providing ``terminate_job(pid, signal)``.
+        :keyword signal: Signal passed on to ``pool.terminate_job``.
+
+        """
+        # Check for acceptance first: on_accepted() calls back in here
+        # while `_terminate_on_ack` is still set, and that call must
+        # reach the pool instead of being treated as already pending.
+        if self.time_start:
+            return pool.terminate_job(self.worker_pid, signal)
+        # Not accepted yet; keep the first pending request if there is one.
+        if self._terminate_on_ack is None:
+            self._terminate_on_ack = (True, pool, signal)
+
     def revoked(self):
         """If revoked, skip task and mark state."""
         if self._already_revoked:
@@ -400,6 +412,7 @@ class TaskRequest(object):
 
     def on_accepted(self, pid, time_accepted):
         """Handler called when task is accepted by worker pool."""
+        self.worker_pid = pid
         self.time_start = time_accepted
         state.task_accepted(self)
         if not self.task.acks_late:
@@ -407,6 +420,9 @@ class TaskRequest(object):
         self.send_event("task-started", uuid=self.task_id, pid=pid)
         self.logger.debug("Task accepted: %s[%s] pid:%r" % (
             self.task_name, self.task_id, pid))
+        if self._terminate_on_ack is not None:
+            _, pool, signal = self._terminate_on_ack
+            self.terminate(pool, signal)
 
     def on_timeout(self, soft):
         """Handler called if the task times out."""