Browse Source

More test coverage for 2.3 release

Ask Solem 13 years ago
parent
commit
32204e65d1

+ 4 - 2
celery/tests/test_app/test_app.py

@@ -155,12 +155,14 @@ class test_App(unittest.TestCase):
                                        "userid": "guest",
                                        "password": "guest",
                                        "virtual_host": "/"},
-                                      self.app.broker_connection().info())
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
         self.app.conf.BROKER_PORT = 1978
         self.app.conf.BROKER_VHOST = "foo"
         self.assertDictContainsSubset({"port": 1978,
                                        "virtual_host": "foo"},
-                                      self.app.broker_connection().info())
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
         conn = self.app.broker_connection(virtual_host="/value")
         self.assertDictContainsSubset({"virtual_host": "/value"},
                                       conn.info())

+ 0 - 24
celery/tests/test_task/test_task.py

@@ -491,30 +491,6 @@ class TestPeriodicTask(unittest.TestCase):
             MyPeriodic().remaining_estimate(datetime.now()),
             timedelta)
 
-    def test_timedelta_seconds_returns_0_on_negative_time(self):
-        delta = timedelta(days=-2)
-        self.assertEqual(MyPeriodic().timedelta_seconds(delta), 0)
-
-    def test_timedelta_seconds(self):
-        deltamap = ((timedelta(seconds=1), 1),
-                    (timedelta(seconds=27), 27),
-                    (timedelta(minutes=3), 3 * 60),
-                    (timedelta(hours=4), 4 * 60 * 60),
-                    (timedelta(days=3), 3 * 86400))
-        for delta, seconds in deltamap:
-            self.assertEqual(MyPeriodic().timedelta_seconds(delta), seconds)
-
-    def test_delta_resolution(self):
-        D = timeutils.delta_resolution
-
-        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
-        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
-                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
-                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
-                    (timedelta(seconds=2), dt))
-        for delta, shoulda in deltamap:
-            self.assertEqual(D(dt, delta), shoulda)
-
     def test_is_due_not_due(self):
         due, remaining = MyPeriodic().is_due(datetime.now())
         self.assertFalse(due)

+ 0 - 6
celery/tests/test_utils/test_utils.py

@@ -31,12 +31,6 @@ class test_chunks(unittest.TestCase):
 
 class test_utils(unittest.TestCase):
 
-    def test_maybe_iso8601_datetime(self):
-        from celery.utils.timeutils import maybe_iso8601
-        from datetime import datetime
-        now = datetime.now()
-        self.assertIs(maybe_iso8601(now), now)
-
     def test_get_full_cls_name(self):
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")

+ 28 - 0
celery/tests/test_utils/test_utils_encoding.py

@@ -0,0 +1,28 @@
+from celery.utils import encoding
+
+from celery.tests.utils import unittest
+
+
+class test_encoding(unittest.TestCase):
+
+    def test_smart_str(self):
+        self.assertTrue(encoding.safe_str(object()))
+        self.assertTrue(encoding.safe_str("foo"))
+        self.assertTrue(encoding.safe_str(u"foo"))
+
+        class foo(unicode):
+
+            def encode(self, *args, **kwargs):
+                raise UnicodeDecodeError("foo")
+
+        self.assertIn("<Unrepresentable", encoding.safe_str(foo()))
+
+    def test_safe_repr(self):
+        self.assertTrue(encoding.safe_repr(object()))
+
+        class foo(object):
+            def __repr__(self):
+                raise ValueError("foo")
+
+        self.assertTrue(encoding.safe_repr(foo()))
+

+ 0 - 19
celery/tests/test_utils/test_utils_info.py

@@ -2,7 +2,6 @@ from celery.tests.utils import unittest
 
 from celery import Celery
 from celery.utils import textindent
-from celery.utils.timeutils import humanize_seconds
 
 RANDTEXT = """\
 The quick brown
@@ -34,24 +33,6 @@ QUEUE_FORMAT2 = """. queue2:      exchange:exchange2 (type2) binding:bind2"""
 
 class TestInfo(unittest.TestCase):
 
-    def test_humanize_seconds(self):
-        t = ((4 * 60 * 60 * 24, "4 days"),
-             (1 * 60 * 60 * 24, "1 day"),
-             (4 * 60 * 60, "4 hours"),
-             (1 * 60 * 60, "1 hour"),
-             (4 * 60, "4 minutes"),
-             (1 * 60, "1 minute"),
-             (4, "4.00 seconds"),
-             (1, "1.00 second"),
-             (4.3567631221, "4.36 seconds"),
-             (0, "now"))
-
-        for seconds, human in t:
-            self.assertEqual(humanize_seconds(seconds), human)
-
-        self.assertEqual(humanize_seconds(4, prefix="about "),
-                          "about 4.00 seconds")
-
     def test_textindent(self):
         self.assertEqual(textindent(RANDTEXT, 4), RANDTEXT_RES)
 

+ 62 - 0
celery/tests/test_utils/test_utils_timeutils.py

@@ -0,0 +1,62 @@
+from datetime import datetime, timedelta
+
+from celery.utils import timeutils
+
+from celery.tests.utils import unittest
+
+
+class test_timeutils(unittest.TestCase):
+
+    def test_delta_resolution(self):
+        D = timeutils.delta_resolution
+
+        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
+        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
+                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
+                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
+                    (timedelta(seconds=2), dt))
+        for delta, shoulda in deltamap:
+            self.assertEqual(D(dt, delta), shoulda)
+
+    def test_timedelta_seconds(self):
+        deltamap = ((timedelta(seconds=1), 1),
+                    (timedelta(seconds=27), 27),
+                    (timedelta(minutes=3), 3 * 60),
+                    (timedelta(hours=4), 4 * 60 * 60),
+                    (timedelta(days=3), 3 * 86400))
+        for delta, seconds in deltamap:
+            self.assertEqual(timeutils.timedelta_seconds(delta), seconds)
+
+    def test_timedelta_seconds_returns_0_on_negative_time(self):
+        delta = timedelta(days=-2)
+        self.assertEqual(timeutils.timedelta_seconds(delta), 0)
+
+    def test_humanize_seconds(self):
+        t = ((4 * 60 * 60 * 24, "4 days"),
+             (1 * 60 * 60 * 24, "1 day"),
+             (4 * 60 * 60, "4 hours"),
+             (1 * 60 * 60, "1 hour"),
+             (4 * 60, "4 minutes"),
+             (1 * 60, "1 minute"),
+             (4, "4.00 seconds"),
+             (1, "1.00 second"),
+             (4.3567631221, "4.36 seconds"),
+             (0, "now"))
+
+        for seconds, human in t:
+            self.assertEqual(timeutils.humanize_seconds(seconds), human)
+
+        self.assertEqual(timeutils.humanize_seconds(4, prefix="about "),
+                          "about 4.00 seconds")
+
+    def test_maybe_iso8601_datetime(self):
+        now = datetime.now()
+        self.assertIs(timeutils.maybe_iso8601(now), now)
+
+    def test_maybe_timdelta(self):
+        D = timeutils.maybe_timedelta
+
+        for i in (30, 30.6):
+            self.assertEquals(D(i), timedelta(seconds=i))
+
+        self.assertEqual(D(timedelta(days=2)), timedelta(days=2))

+ 156 - 1
celery/tests/test_worker/test_worker.py

@@ -21,7 +21,7 @@ from celery.worker import WorkController
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
@@ -181,6 +181,21 @@ class test_QoS(unittest.TestCase):
         qos.increment()
         self.assertEqual(qos.value, 0)
 
+    def test_consumer_decrement_eventually(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 9)
+        qos.value = 0
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 0)
+
+    def test_set(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.set(12)
+        self.assertEqual(qos.prev, 12)
+        qos.set(qos.prev)
 
 class test_Consumer(unittest.TestCase):
 
@@ -205,6 +220,12 @@ class test_Consumer(unittest.TestCase):
         info = l.info
         self.assertTrue(info["broker"])
 
+    def test_start_when_closed(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = CLOSE
+        l.start()
+
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
@@ -338,6 +359,46 @@ class test_Consumer(unittest.TestCase):
         l.heart.stop()
         l.priority_timer.stop()
 
+    def test_consume_messages_ignores_socket_timeout(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.timeout(10)
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l.connection = Connection()
+        l.task_consumer = Mock()
+        l.connection.obj = l
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        l.consume_messages()
+
+    def test_consume_messages_when_socket_error(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.error("foo")
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = RUN
+        c = l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        with self.assertRaises(socket.error):
+            l.consume_messages()
+
+        l._state = CLOSE
+        l.connection = c
+        l.consume_messages()
+
     def test_consume_messages(self):
 
         class Connection(current_app.broker_connection().__class__):
@@ -395,6 +456,7 @@ class test_Consumer(unittest.TestCase):
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.event_dispatcher = Mock()
+        l.enabled = False
         l.receive_message(m.decode(), m)
         l.eta_schedule.stop()
 
@@ -407,6 +469,26 @@ class test_Consumer(unittest.TestCase):
         self.assertTrue(l.task_consumer.qos.call_count)
         l.eta_schedule.stop()
 
+    def test_on_control(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.pidbox_node = Mock()
+        l.reset_pidbox_node = Mock()
+
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = KeyError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = ValueError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+        l.reset_pidbox_node.assert_called_with()
+
     def test_revoke(self):
         ready_queue = FastQueue()
         l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
@@ -432,6 +514,26 @@ class test_Consumer(unittest.TestCase):
         self.assertRaises(Empty, self.ready_queue.get_nowait)
         self.assertTrue(self.eta_schedule.empty())
 
+    def test_receieve_message_ack_raises(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        l.connection_errors = (socket.error, )
+        l.logger = Mock()
+        m.ack = Mock()
+        m.ack.side_effect = socket.error("foo")
+        with catch_warnings(record=True) as log:
+            self.assertFalse(l.receive_message(m.decode(), m))
+            self.assertTrue(log)
+            self.assertIn("unknown message", log[0].message.args[0])
+        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        self.assertTrue(self.eta_schedule.empty())
+        m.ack.assert_called_with()
+        self.assertTrue(l.logger.critical.call_count)
+
     def test_receieve_message_eta(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
@@ -463,6 +565,59 @@ class test_Consumer(unittest.TestCase):
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.ready_queue.get_nowait)
 
+    def test_reset_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+        chan = l.pidbox_node.channel = Mock()
+        l.connection = Mock()
+        chan.close.side_effect = socket.error("foo")
+        l.connection_errors = (socket.error, )
+        l.reset_pidbox_node()
+        chan.close.assert_called_with()
+
+    def test_reset_pidbox_node_green(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pool = Mock()
+        l.pool.is_green = True
+        l.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+
+    def test__green_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+
+        connections = []
+
+        class Connection(object):
+
+            def __init__(self, obj):
+                connections.append(self)
+                self.obj = obj
+                self.closed = False
+
+            def channel(self):
+                return Mock()
+
+            def drain_events(self):
+                self.obj.connection = None
+
+            def close(self):
+                self.closed = True
+
+        l.connection = Mock()
+        l._open_connection = lambda: Connection(obj=l)
+        l._green_pidbox_node()
+
+        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
+        self.assertTrue(l.broadcast_consumer)
+        l.broadcast_consumer.consume.assert_called_with()
+
+        self.assertIsNone(l.connection)
+        self.assertTrue(connections[0].closed)
+
     def test_start__consume_messages(self):
 
         class _QoS(object):

+ 33 - 0
celery/tests/test_worker/test_worker_autoscale.py

@@ -1,7 +1,10 @@
 import logging
+import os
 
 from time import time
 
+from mock import Mock, patch
+
 from celery.concurrency.base import BasePool
 from celery.worker import state
 from celery.worker import autoscale
@@ -17,6 +20,7 @@ class Object(object):
 
 class MockPool(BasePool):
     shrink_raises_exception = False
+    shrink_raises_ValueError = False
 
     def __init__(self, *args, **kwargs):
         super(MockPool, self).__init__(*args, **kwargs)
@@ -29,6 +33,8 @@ class MockPool(BasePool):
     def shrink(self, n=1):
         if self.shrink_raises_exception:
             raise KeyError("foo")
+        if self.shrink_raises_ValueError:
+            raise ValueError("foo")
         self._pool._processes -= n
 
     @property
@@ -100,3 +106,30 @@ class test_Autoscaler(unittest.TestCase):
         x._last_action = time() - 10000
         x.pool.shrink_raises_exception = True
         x.scale_down(1)
+
+    def test_shrink_raises_ValueError(self):
+        x = autoscale.Autoscaler(self.pool, 10, 3, logger=logger)
+        x.logger = Mock()
+        x.scale_up(3)
+        x._last_action = time() - 10000
+        x.pool.shrink_raises_ValueError = True
+        x.scale_down(1)
+        self.assertTrue(x.logger.debug.call_count)
+
+    @patch("os._exit")
+    def test_thread_crash(self, _exit):
+
+        class _Autoscaler(autoscale.Autoscaler):
+
+            def scale(self):
+                self._shutdown.set()
+                raise OSError("foo")
+
+        x = _Autoscaler(self.pool, 10, 3, logger=logger)
+        x.logger = Mock()
+        x.run()
+        _exit.assert_called_with(1)
+        self.assertTrue(x.logger.error.call_count)
+
+
+

+ 1 - 0
celery/tests/test_worker/test_worker_heartbeat.py

@@ -51,6 +51,7 @@ class TestHeart(unittest.TestCase):
         self.assertTrue(h.tref)
         h.stop()
         self.assertIsNone(h.tref)
+        h.stop()
 
     @sleepdeprived
     def test_run_manages_cycle(self):

+ 47 - 1
celery/tests/test_worker/test_worker_job.py

@@ -23,7 +23,8 @@ from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.utils import gen_unique_id
 from celery.worker.job import (WorkerTaskTrace, TaskRequest,
-                               InvalidTaskError, execute_and_trace)
+                               InvalidTaskError, execute_and_trace,
+                               default_encode)
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings
@@ -71,6 +72,26 @@ def mytask_raising(i, **kwargs):
     raise KeyError(i)
 
 
+class test_default_encode(unittest.TestCase):
+
+    def test_jython(self):
+        prev, sys.platform = sys.platform, "java 1.6.1"
+        try:
+            self.assertEqual(default_encode("foo"), "foo")
+        finally:
+            sys.platform = prev
+
+    def test_cython(self):
+        prev, sys.platform = sys.platform, "darwin"
+        gfe, sys.getfilesystemencoding = sys.getfilesystemencoding, \
+                                         lambda: "utf-8"
+        try:
+            self.assertEqual(default_encode("foo"), "foo")
+        finally:
+            sys.platform = prev
+            sys.getfilesystemencoding = gfe
+
+
 class test_RetryTaskError(unittest.TestCase):
 
     def test_retry_task_error(self):
@@ -197,6 +218,23 @@ class test_TaskRequest(unittest.TestCase):
         tw.on_failure(einfo)
         self.assertIn("task-retried", tw.eventer.sent)
 
+    def test_terminate__task_started(self):
+        pool = Mock()
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = time.time()
+        tw.worker_pid = 313
+        tw.terminate(pool, signal="KILL")
+        pool.terminate_job.assert_called_with(tw.worker_pid, "KILL")
+
+    def test_terminate__task_reserved(self):
+        pool = Mock()
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = None
+        tw.terminate(pool, signal="KILL")
+        self.assertFalse(pool.terminate_job.call_count)
+        self.assertTupleEqual(tw._terminate_on_ack, (True, pool, "KILL"))
+        tw.terminate(pool, signal="KILL")
+
     def test_revoked_expires_expired(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.expires = datetime.now() - timedelta(days=1)
@@ -314,6 +352,14 @@ class test_TaskRequest(unittest.TestCase):
         finally:
             mytask.acks_late = False
 
+    def test_on_accepted_terminates(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        pool = Mock()
+        tw.terminate(pool, signal="KILL")
+        self.assertFalse(pool.terminate_job.call_count)
+        tw.on_accepted(pid=314, time_accepted=time.time())
+        pool.terminate_job.assert_called_with(314, "KILL")
+
     def test_on_success_acks_early(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.time_start = 1

+ 20 - 1
celery/tests/test_worker/test_worker_mediator.py

@@ -2,7 +2,7 @@ from celery.tests.utils import unittest
 
 from Queue import Queue
 
-from mock import Mock
+from mock import Mock, patch
 
 from celery.utils import gen_unique_id
 from celery.worker.mediator import Mediator
@@ -53,6 +53,25 @@ class test_Mediator(unittest.TestCase):
 
         self.assertEqual(got["value"], "George Costanza")
 
+    @patch("os._exit")
+    def test_mediator_crash(self, _exit):
+        ms = [None]
+
+        class _Mediator(Mediator):
+
+            def move(self):
+                try:
+                    raise KeyError("foo")
+                finally:
+                    ms[0]._shutdown.set()
+
+        ready_queue = Queue()
+        ms[0] = m = _Mediator(ready_queue, None)
+        ready_queue.put(MockTask("George Constanza"))
+        m.run()
+
+        self.assertTrue(_exit.call_count)
+
     def test_mediator_move_exception(self):
         ready_queue = Queue()
 

+ 5 - 1
celery/utils/encoding.py

@@ -11,6 +11,10 @@ def default_encoding():
 def safe_str(s, errors="replace"):
     if not isinstance(s, basestring):
         return safe_repr(s, errors)
+    return _safe_str(s, errors)
+
+
+def _safe_str(s, errors="replace"):
     encoding = default_encoding()
     try:
         if isinstance(s, unicode):
@@ -24,4 +28,4 @@ def safe_repr(o, errors="replace"):
     try:
         return repr(o)
     except Exception:
-        return safe_str(o, errors)
+        return _safe_str(o, errors)

+ 4 - 4
celery/worker/consumer.py

@@ -135,24 +135,24 @@ class QoS(object):
         self.value = initial_value
 
     def increment(self, n=1):
-        """Increment the current prefetch count value by one."""
+        """Increment the current prefetch count value by n."""
         with self._mutex:
             if self.value:
                 new_value = self.value + max(n, 0)
                 self.value = self.set(new_value)
-            return self.value
+        return self.value
 
     def _sub(self, n=1):
         assert self.value - n > 1
         self.value -= n
 
     def decrement(self, n=1):
-        """Decrement the current prefetch count value by one."""
+        """Decrement the current prefetch count value by n."""
         with self._mutex:
             if self.value:
                 self._sub(n)
                 self.set(self.value)
-            return self.value
+        return self.value
 
     def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.

+ 3 - 9
celery/worker/job.py

@@ -17,7 +17,7 @@ from celery.datastructures import ExceptionInfo
 from celery.execute.trace import TaskTrace
 from celery.utils import (noop, kwdict, fun_takes_kwargs,
                           get_symbol_by_name, truncate_text)
-from celery.utils.encoding import safe_repr, safe_str
+from celery.utils.encoding import safe_repr, safe_str, default_encoding
 from celery.utils.timeutils import maybe_iso8601
 from celery.worker import state
 
@@ -52,11 +52,7 @@ class InvalidTaskError(Exception):
 
 
 def default_encode(obj):
-    if sys.platform.startswith("java"):
-        coding = "utf-8"
-    else:
-        coding = sys.getfilesystemencoding()
-    return unicode(obj, coding)
+    return unicode(obj, default_encoding())
 
 
 class WorkerTaskTrace(TaskTrace):
@@ -401,9 +397,7 @@ class TaskRequest(object):
                 self.task.backend.mark_as_revoked(self.task_id)
 
     def terminate(self, pool, signal=None):
-        if self._terminate_on_ack is not None:
-            return
-        elif self.time_start:
+        if self.time_start:
             return pool.terminate_job(self.worker_pid, signal)
         else:
             self._terminate_on_ack = (True, pool, signal)

+ 1 - 1
celery/worker/state.py

@@ -51,7 +51,7 @@ def task_ready(request):
     reserved_requests.discard(request)
 
 
-if os.environ.get("CELERY_BENCH"):
+if os.environ.get("CELERY_BENCH"):  # pragma: no cover
     from time import time
 
     all_count = 0