Browse Source

Task revoked with tests

Ask Solem 12 years ago
parent
commit
8451c989d8
5 changed files with 66 additions and 25 deletions
  1. 1 1
      celery/signals.py
  2. 12 0
      celery/tests/utils.py
  3. 39 19
      celery/tests/worker/test_request.py
  4. 9 3
      celery/worker/job.py
  5. 5 2
      docs/userguide/signals.rst

+ 1 - 1
celery/signals.py

@@ -23,7 +23,7 @@ task_postrun = Signal(providing_args=[
 task_success = Signal(providing_args=['result'])
 task_failure = Signal(providing_args=[
     'task_id', 'exception', 'args', 'kwargs', 'traceback', 'einfo'])
-task_revoked = Signal(providing_args=['terminated', 'signal'])
+task_revoked = Signal(providing_args=['terminated', 'signum', 'expired'])
 celeryd_init = Signal(providing_args=['instance'])
 worker_init = Signal(providing_args=[])
 worker_process_init = Signal(providing_args=[])

+ 12 - 0
celery/tests/utils.py

@@ -531,3 +531,15 @@ def patch_settings(app=None, **config):
 
     for key, value in prev.iteritems():
         setattr(app.conf, key, value)
+
+
+@contextmanager
+def assert_signal_called(signal, **expected):
+    handler = Mock()
+    call_handler = partial(handler)
+    signal.connect(call_handler)
+    try:
+        yield handler
+    finally:
+        signal.disconnect(call_handler)
+    handler.assert_called_with(signal=signal, **expected)

+ 39 - 19
celery/tests/worker/test_request.py

@@ -4,6 +4,7 @@ from __future__ import with_statement
 
 import anyjson
 import os
+import signal
 import sys
 import time
 
@@ -29,6 +30,7 @@ from celery.task.trace import (
     build_tracer,
 )
 from celery.result import AsyncResult
+from celery.signals import task_revoked
 from celery.task import task as task_dec
 from celery.task.base import Task
 from celery.utils import uuid
@@ -36,7 +38,7 @@ from celery.worker import job as module
 from celery.worker.job import Request, TaskRequest
 from celery.worker.state import revoked
 
-from celery.tests.utils import Case
+from celery.tests.utils import Case, assert_signal_called
 
 scratch = {'ACK': False}
 some_kwargs_scratchpad = {}
@@ -288,28 +290,37 @@ class test_TaskRequest(Case):
 
     def test_terminate__task_started(self):
         pool = Mock()
+        signum = signal.SIGKILL
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
-        tw.time_start = time.time()
-        tw.worker_pid = 313
-        tw.terminate(pool, signal='KILL')
-        pool.terminate_job.assert_called_with(tw.worker_pid, 'KILL')
+        with assert_signal_called(task_revoked, sender=tw.task,
+                                  terminated=True,
+                                  expired=False,
+                                  signum=signum):
+            tw.time_start = time.time()
+            tw.worker_pid = 313
+            tw.terminate(pool, signal='KILL')
+            pool.terminate_job.assert_called_with(tw.worker_pid, signum)
 
     def test_terminate__task_reserved(self):
         pool = Mock()
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
         tw.time_start = None
         tw.terminate(pool, signal='KILL')
-        self.assertFalse(pool.terminate_job.call_count)
+        self.assertFalse(pool.terminate_job.called)
         self.assertTupleEqual(tw._terminate_on_ack, (pool, 'KILL'))
         tw.terminate(pool, signal='KILL')
 
     def test_revoked_expires_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
                          expires=datetime.utcnow() - timedelta(days=1))
-        tw.revoked()
-        self.assertIn(tw.id, revoked)
-        self.assertEqual(mytask.backend.get_status(tw.id),
-                         states.REVOKED)
+        with assert_signal_called(task_revoked, sender=tw.task,
+                                  terminated=False,
+                                  expired=True,
+                                  signum=None):
+            tw.revoked()
+            self.assertIn(tw.id, revoked)
+            self.assertEqual(mytask.backend.get_status(tw.id),
+                             states.REVOKED)
 
     def test_revoked_expires_not_expired(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
@@ -388,10 +399,14 @@ class test_TaskRequest(Case):
 
     def test_revoked(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
-        revoked.add(tw.id)
-        self.assertTrue(tw.revoked())
-        self.assertTrue(tw._already_revoked)
-        self.assertTrue(tw.acknowledged)
+        with assert_signal_called(task_revoked, sender=tw.task,
+                                  terminated=False,
+                                  expired=False,
+                                  signum=None):
+            revoked.add(tw.id)
+            self.assertTrue(tw.revoked())
+            self.assertTrue(tw._already_revoked)
+            self.assertTrue(tw.acknowledged)
 
     def test_execute_does_not_execute_revoked(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
@@ -434,12 +449,17 @@ class test_TaskRequest(Case):
             mytask.acks_late = False
 
     def test_on_accepted_terminates(self):
-        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        signum = signal.SIGKILL
         pool = Mock()
-        tw.terminate(pool, signal='KILL')
-        self.assertFalse(pool.terminate_job.call_count)
-        tw.on_accepted(pid=314, time_accepted=time.time())
-        pool.terminate_job.assert_called_with(314, 'KILL')
+        tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
+        with assert_signal_called(task_revoked, sender=tw.task,
+                                  terminated=True,
+                                  expired=False,
+                                  signum=signum):
+            tw.terminate(pool, signal='KILL')
+            self.assertFalse(pool.terminate_job.call_count)
+            tw.on_accepted(pid=314, time_accepted=time.time())
+            pool.terminate_job.assert_called_with(314, signum)
 
     def test_on_success_acks_early(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})

+ 9 - 3
celery/worker/job.py

@@ -23,6 +23,7 @@ from celery import exceptions
 from celery import signals
 from celery.app import app_or_default
 from celery.datastructures import ExceptionInfo
+from celery.platforms import signals as _signals
 from celery.task.trace import (
     trace_task,
     trace_task_ret,
@@ -228,26 +229,31 @@ class Request(object):
             revoked_tasks.add(self.id)
             if self.store_errors:
                 self.task.backend.mark_as_revoked(self.id)
+            return True
 
     def terminate(self, pool, signal=None):
         if self.time_start:
+            signal = _signals.signum(signal or 'TERM')
             pool.terminate_job(self.worker_pid, signal)
-            send_revoked(self.task, terminated=True, signal=signal)
+            send_revoked(self.task, signum=signal,
+                         terminated=True, expired=False)
         else:
             self._terminate_on_ack = pool, signal
 
     def revoked(self):
         """If revoked, skip task and mark state."""
+        expired = False
         if self._already_revoked:
             return True
         if self.expires:
-            self.maybe_expire()
+            expired = self.maybe_expire()
         if self.id in revoked_tasks:
             warn('Skipping revoked task: %s[%s]', self.name, self.id)
             self.send_event('task-revoked', uuid=self.id)
             self.acknowledge()
             self._already_revoked = True
-            send_revoked(self.task, terminated=False)
+            send_revoked(self.task, terminated=False,
+                         signum=None, expired=expired)
             return True
         return False
 

+ 5 - 2
docs/userguide/signals.rst

@@ -189,12 +189,15 @@ Sender is the task class revoked/terminated.
 Provides arguments:
 
 * terminated
-    :const:`True` if the task was terminated.
+    Set to :const:`True` if the task was terminated.
 
-* signal
+* signum
     Signal number used to terminate the task. If this is :const:`None` and
     terminated is :const:`True` then :sig:`TERM` should be assumed.
 
+* expired
+  Set to :const:`True` if the task expired.
+
 Worker Signals
 --------------