|
@@ -4,6 +4,7 @@ from __future__ import with_statement
|
|
|
|
|
|
import anyjson
|
|
|
import os
|
|
|
+import signal
|
|
|
import sys
|
|
|
import time
|
|
|
|
|
@@ -29,6 +30,7 @@ from celery.task.trace import (
|
|
|
build_tracer,
|
|
|
)
|
|
|
from celery.result import AsyncResult
|
|
|
+from celery.signals import task_revoked
|
|
|
from celery.task import task as task_dec
|
|
|
from celery.task.base import Task
|
|
|
from celery.utils import uuid
|
|
@@ -36,7 +38,7 @@ from celery.worker import job as module
|
|
|
from celery.worker.job import Request, TaskRequest
|
|
|
from celery.worker.state import revoked
|
|
|
|
|
|
-from celery.tests.utils import Case
|
|
|
+from celery.tests.utils import Case, assert_signal_called
|
|
|
|
|
|
scratch = {'ACK': False}
|
|
|
some_kwargs_scratchpad = {}
|
|
@@ -288,28 +290,37 @@ class test_TaskRequest(Case):
|
|
|
|
|
|
def test_terminate__task_started(self):
|
|
|
pool = Mock()
|
|
|
+ signum = signal.SIGKILL
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
|
- tw.time_start = time.time()
|
|
|
- tw.worker_pid = 313
|
|
|
- tw.terminate(pool, signal='KILL')
|
|
|
- pool.terminate_job.assert_called_with(tw.worker_pid, 'KILL')
|
|
|
+ with assert_signal_called(task_revoked, sender=tw.task,
|
|
|
+ terminated=True,
|
|
|
+ expired=False,
|
|
|
+ signum=signum):
|
|
|
+ tw.time_start = time.time()
|
|
|
+ tw.worker_pid = 313
|
|
|
+ tw.terminate(pool, signal='KILL')
|
|
|
+ pool.terminate_job.assert_called_with(tw.worker_pid, signum)
|
|
|
|
|
|
def test_terminate__task_reserved(self):
|
|
|
pool = Mock()
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
|
tw.time_start = None
|
|
|
tw.terminate(pool, signal='KILL')
|
|
|
- self.assertFalse(pool.terminate_job.call_count)
|
|
|
+ self.assertFalse(pool.terminate_job.called)
|
|
|
self.assertTupleEqual(tw._terminate_on_ack, (pool, 'KILL'))
|
|
|
tw.terminate(pool, signal='KILL')
|
|
|
|
|
|
def test_revoked_expires_expired(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
|
|
|
expires=datetime.utcnow() - timedelta(days=1))
|
|
|
- tw.revoked()
|
|
|
- self.assertIn(tw.id, revoked)
|
|
|
- self.assertEqual(mytask.backend.get_status(tw.id),
|
|
|
- states.REVOKED)
|
|
|
+ with assert_signal_called(task_revoked, sender=tw.task,
|
|
|
+ terminated=False,
|
|
|
+ expired=True,
|
|
|
+ signum=None):
|
|
|
+ tw.revoked()
|
|
|
+ self.assertIn(tw.id, revoked)
|
|
|
+ self.assertEqual(mytask.backend.get_status(tw.id),
|
|
|
+ states.REVOKED)
|
|
|
|
|
|
def test_revoked_expires_not_expired(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'},
|
|
@@ -388,10 +399,14 @@ class test_TaskRequest(Case):
|
|
|
|
|
|
def test_revoked(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
|
- revoked.add(tw.id)
|
|
|
- self.assertTrue(tw.revoked())
|
|
|
- self.assertTrue(tw._already_revoked)
|
|
|
- self.assertTrue(tw.acknowledged)
|
|
|
+ with assert_signal_called(task_revoked, sender=tw.task,
|
|
|
+ terminated=False,
|
|
|
+ expired=False,
|
|
|
+ signum=None):
|
|
|
+ revoked.add(tw.id)
|
|
|
+ self.assertTrue(tw.revoked())
|
|
|
+ self.assertTrue(tw._already_revoked)
|
|
|
+ self.assertTrue(tw.acknowledged)
|
|
|
|
|
|
def test_execute_does_not_execute_revoked(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
@@ -434,12 +449,17 @@ class test_TaskRequest(Case):
|
|
|
mytask.acks_late = False
|
|
|
|
|
|
def test_on_accepted_terminates(self):
|
|
|
- tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
|
+ signum = signal.SIGKILL
|
|
|
pool = Mock()
|
|
|
- tw.terminate(pool, signal='KILL')
|
|
|
- self.assertFalse(pool.terminate_job.call_count)
|
|
|
- tw.on_accepted(pid=314, time_accepted=time.time())
|
|
|
- pool.terminate_job.assert_called_with(314, 'KILL')
|
|
|
+ tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|
|
|
+ with assert_signal_called(task_revoked, sender=tw.task,
|
|
|
+ terminated=True,
|
|
|
+ expired=False,
|
|
|
+ signum=signum):
|
|
|
+ tw.terminate(pool, signal='KILL')
|
|
|
+ self.assertFalse(pool.terminate_job.call_count)
|
|
|
+ tw.on_accepted(pid=314, time_accepted=time.time())
|
|
|
+ pool.terminate_job.assert_called_with(314, signum)
|
|
|
|
|
|
def test_on_success_acks_early(self):
|
|
|
tw = TaskRequest(mytask.name, uuid(), [1], {'f': 'x'})
|