Browse Source

More test coverage for 2.3 release

Ask Solem 14 years ago
parent
commit
32204e65d1

+ 4 - 2
celery/tests/test_app/test_app.py

@@ -155,12 +155,14 @@ class test_App(unittest.TestCase):
                                        "userid": "guest",
                                        "userid": "guest",
                                        "password": "guest",
                                        "password": "guest",
                                        "virtual_host": "/"},
                                        "virtual_host": "/"},
-                                      self.app.broker_connection().info())
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
         self.app.conf.BROKER_PORT = 1978
         self.app.conf.BROKER_PORT = 1978
         self.app.conf.BROKER_VHOST = "foo"
         self.app.conf.BROKER_VHOST = "foo"
         self.assertDictContainsSubset({"port": 1978,
         self.assertDictContainsSubset({"port": 1978,
                                        "virtual_host": "foo"},
                                        "virtual_host": "foo"},
-                                      self.app.broker_connection().info())
+                                      self.app.broker_connection(
+                                          transport="amqplib").info())
         conn = self.app.broker_connection(virtual_host="/value")
         conn = self.app.broker_connection(virtual_host="/value")
         self.assertDictContainsSubset({"virtual_host": "/value"},
         self.assertDictContainsSubset({"virtual_host": "/value"},
                                       conn.info())
                                       conn.info())

+ 0 - 24
celery/tests/test_task/test_task.py

@@ -491,30 +491,6 @@ class TestPeriodicTask(unittest.TestCase):
             MyPeriodic().remaining_estimate(datetime.now()),
             MyPeriodic().remaining_estimate(datetime.now()),
             timedelta)
             timedelta)
 
 
-    def test_timedelta_seconds_returns_0_on_negative_time(self):
-        delta = timedelta(days=-2)
-        self.assertEqual(MyPeriodic().timedelta_seconds(delta), 0)
-
-    def test_timedelta_seconds(self):
-        deltamap = ((timedelta(seconds=1), 1),
-                    (timedelta(seconds=27), 27),
-                    (timedelta(minutes=3), 3 * 60),
-                    (timedelta(hours=4), 4 * 60 * 60),
-                    (timedelta(days=3), 3 * 86400))
-        for delta, seconds in deltamap:
-            self.assertEqual(MyPeriodic().timedelta_seconds(delta), seconds)
-
-    def test_delta_resolution(self):
-        D = timeutils.delta_resolution
-
-        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
-        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
-                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
-                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
-                    (timedelta(seconds=2), dt))
-        for delta, shoulda in deltamap:
-            self.assertEqual(D(dt, delta), shoulda)
-
     def test_is_due_not_due(self):
     def test_is_due_not_due(self):
         due, remaining = MyPeriodic().is_due(datetime.now())
         due, remaining = MyPeriodic().is_due(datetime.now())
         self.assertFalse(due)
         self.assertFalse(due)

+ 0 - 6
celery/tests/test_utils/test_utils.py

@@ -31,12 +31,6 @@ class test_chunks(unittest.TestCase):
 
 
 class test_utils(unittest.TestCase):
 class test_utils(unittest.TestCase):
 
 
-    def test_maybe_iso8601_datetime(self):
-        from celery.utils.timeutils import maybe_iso8601
-        from datetime import datetime
-        now = datetime.now()
-        self.assertIs(maybe_iso8601(now), now)
-
     def test_get_full_cls_name(self):
     def test_get_full_cls_name(self):
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")
         self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")

+ 28 - 0
celery/tests/test_utils/test_utils_encoding.py

@@ -0,0 +1,28 @@
+from celery.utils import encoding
+
+from celery.tests.utils import unittest
+
+
+class test_encoding(unittest.TestCase):
+
+    def test_smart_str(self):
+        self.assertTrue(encoding.safe_str(object()))
+        self.assertTrue(encoding.safe_str("foo"))
+        self.assertTrue(encoding.safe_str(u"foo"))
+
+        class foo(unicode):
+
+            def encode(self, *args, **kwargs):
+                raise UnicodeDecodeError("foo")
+
+        self.assertIn("<Unrepresentable", encoding.safe_str(foo()))
+
+    def test_safe_repr(self):
+        self.assertTrue(encoding.safe_repr(object()))
+
+        class foo(object):
+            def __repr__(self):
+                raise ValueError("foo")
+
+        self.assertTrue(encoding.safe_repr(foo()))
+

+ 0 - 19
celery/tests/test_utils/test_utils_info.py

@@ -2,7 +2,6 @@ from celery.tests.utils import unittest
 
 
 from celery import Celery
 from celery import Celery
 from celery.utils import textindent
 from celery.utils import textindent
-from celery.utils.timeutils import humanize_seconds
 
 
 RANDTEXT = """\
 RANDTEXT = """\
 The quick brown
 The quick brown
@@ -34,24 +33,6 @@ QUEUE_FORMAT2 = """. queue2:      exchange:exchange2 (type2) binding:bind2"""
 
 
 class TestInfo(unittest.TestCase):
 class TestInfo(unittest.TestCase):
 
 
-    def test_humanize_seconds(self):
-        t = ((4 * 60 * 60 * 24, "4 days"),
-             (1 * 60 * 60 * 24, "1 day"),
-             (4 * 60 * 60, "4 hours"),
-             (1 * 60 * 60, "1 hour"),
-             (4 * 60, "4 minutes"),
-             (1 * 60, "1 minute"),
-             (4, "4.00 seconds"),
-             (1, "1.00 second"),
-             (4.3567631221, "4.36 seconds"),
-             (0, "now"))
-
-        for seconds, human in t:
-            self.assertEqual(humanize_seconds(seconds), human)
-
-        self.assertEqual(humanize_seconds(4, prefix="about "),
-                          "about 4.00 seconds")
-
     def test_textindent(self):
     def test_textindent(self):
         self.assertEqual(textindent(RANDTEXT, 4), RANDTEXT_RES)
         self.assertEqual(textindent(RANDTEXT, 4), RANDTEXT_RES)
 
 

+ 62 - 0
celery/tests/test_utils/test_utils_timeutils.py

@@ -0,0 +1,62 @@
+from datetime import datetime, timedelta
+
+from celery.utils import timeutils
+
+from celery.tests.utils import unittest
+
+
+class test_timeutils(unittest.TestCase):
+
+    def test_delta_resolution(self):
+        D = timeutils.delta_resolution
+
+        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
+        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
+                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
+                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
+                    (timedelta(seconds=2), dt))
+        for delta, shoulda in deltamap:
+            self.assertEqual(D(dt, delta), shoulda)
+
+    def test_timedelta_seconds(self):
+        deltamap = ((timedelta(seconds=1), 1),
+                    (timedelta(seconds=27), 27),
+                    (timedelta(minutes=3), 3 * 60),
+                    (timedelta(hours=4), 4 * 60 * 60),
+                    (timedelta(days=3), 3 * 86400))
+        for delta, seconds in deltamap:
+            self.assertEqual(timeutils.timedelta_seconds(delta), seconds)
+
+    def test_timedelta_seconds_returns_0_on_negative_time(self):
+        delta = timedelta(days=-2)
+        self.assertEqual(timeutils.timedelta_seconds(delta), 0)
+
+    def test_humanize_seconds(self):
+        t = ((4 * 60 * 60 * 24, "4 days"),
+             (1 * 60 * 60 * 24, "1 day"),
+             (4 * 60 * 60, "4 hours"),
+             (1 * 60 * 60, "1 hour"),
+             (4 * 60, "4 minutes"),
+             (1 * 60, "1 minute"),
+             (4, "4.00 seconds"),
+             (1, "1.00 second"),
+             (4.3567631221, "4.36 seconds"),
+             (0, "now"))
+
+        for seconds, human in t:
+            self.assertEqual(timeutils.humanize_seconds(seconds), human)
+
+        self.assertEqual(timeutils.humanize_seconds(4, prefix="about "),
+                          "about 4.00 seconds")
+
+    def test_maybe_iso8601_datetime(self):
+        now = datetime.now()
+        self.assertIs(timeutils.maybe_iso8601(now), now)
+
+    def test_maybe_timdelta(self):
+        D = timeutils.maybe_timedelta
+
+        for i in (30, 30.6):
+            self.assertEquals(D(i), timedelta(seconds=i))
+
+        self.assertEqual(D(timedelta(days=2)), timedelta(days=2))

+ 156 - 1
celery/tests/test_worker/test_worker.py

@@ -21,7 +21,7 @@ from celery.worker import WorkController
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.job import TaskRequest
 from celery.worker.consumer import Consumer as MainConsumer
 from celery.worker.consumer import Consumer as MainConsumer
-from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX
+from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 from celery.utils.timer2 import Timer
 
 
@@ -181,6 +181,21 @@ class test_QoS(unittest.TestCase):
         qos.increment()
         qos.increment()
         self.assertEqual(qos.value, 0)
         self.assertEqual(qos.value, 0)
 
 
+    def test_consumer_decrement_eventually(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 9)
+        qos.value = 0
+        qos.decrement_eventually()
+        self.assertEqual(qos.value, 0)
+
+    def test_set(self):
+        consumer = Mock()
+        qos = QoS(consumer, 10, current_app.log.get_default_logger())
+        qos.set(12)
+        self.assertEqual(qos.prev, 12)
+        qos.set(qos.prev)
 
 
 class test_Consumer(unittest.TestCase):
 class test_Consumer(unittest.TestCase):
 
 
@@ -205,6 +220,12 @@ class test_Consumer(unittest.TestCase):
         info = l.info
         info = l.info
         self.assertTrue(info["broker"])
         self.assertTrue(info["broker"])
 
 
+    def test_start_when_closed(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = CLOSE
+        l.start()
+
     def test_connection(self):
     def test_connection(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
                            send_events=False)
@@ -338,6 +359,46 @@ class test_Consumer(unittest.TestCase):
         l.heart.stop()
         l.heart.stop()
         l.priority_timer.stop()
         l.priority_timer.stop()
 
 
+    def test_consume_messages_ignores_socket_timeout(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.timeout(10)
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l.connection = Connection()
+        l.task_consumer = Mock()
+        l.connection.obj = l
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        l.consume_messages()
+
+    def test_consume_messages_when_socket_error(self):
+
+        class Connection(current_app.broker_connection().__class__):
+            obj = None
+
+            def drain_events(self, **kwargs):
+                self.obj.connection = None
+                raise socket.error("foo")
+
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                            send_events=False)
+        l._state = RUN
+        c = l.connection = Connection()
+        l.connection.obj = l
+        l.task_consumer = Mock()
+        l.qos = QoS(l.task_consumer, 10, l.logger)
+        with self.assertRaises(socket.error):
+            l.consume_messages()
+
+        l._state = CLOSE
+        l.connection = c
+        l.consume_messages()
+
     def test_consume_messages(self):
     def test_consume_messages(self):
 
 
         class Connection(current_app.broker_connection().__class__):
         class Connection(current_app.broker_connection().__class__):
@@ -395,6 +456,7 @@ class test_Consumer(unittest.TestCase):
         l.task_consumer = Mock()
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.qos = QoS(l.task_consumer, l.initial_prefetch_count, l.logger)
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
+        l.enabled = False
         l.receive_message(m.decode(), m)
         l.receive_message(m.decode(), m)
         l.eta_schedule.stop()
         l.eta_schedule.stop()
 
 
@@ -407,6 +469,26 @@ class test_Consumer(unittest.TestCase):
         self.assertTrue(l.task_consumer.qos.call_count)
         self.assertTrue(l.task_consumer.qos.call_count)
         l.eta_schedule.stop()
         l.eta_schedule.stop()
 
 
+    def test_on_control(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                             send_events=False)
+        l.pidbox_node = Mock()
+        l.reset_pidbox_node = Mock()
+
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = KeyError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+
+        l.pidbox_node = Mock()
+        l.pidbox_node.handle_message.side_effect = ValueError("foo")
+        l.on_control("foo", "bar")
+        l.pidbox_node.handle_message.assert_called_with("foo", "bar")
+        l.reset_pidbox_node.assert_called_with()
+
     def test_revoke(self):
     def test_revoke(self):
         ready_queue = FastQueue()
         ready_queue = FastQueue()
         l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
         l = MyKombuConsumer(ready_queue, self.eta_schedule, self.logger,
@@ -432,6 +514,26 @@ class test_Consumer(unittest.TestCase):
         self.assertRaises(Empty, self.ready_queue.get_nowait)
         self.assertRaises(Empty, self.ready_queue.get_nowait)
         self.assertTrue(self.eta_schedule.empty())
         self.assertTrue(self.eta_schedule.empty())
 
 
+    def test_receieve_message_ack_raises(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        backend = Mock()
+        m = create_message(backend, args=[2, 4, 8], kwargs={})
+
+        l.event_dispatcher = Mock()
+        l.connection_errors = (socket.error, )
+        l.logger = Mock()
+        m.ack = Mock()
+        m.ack.side_effect = socket.error("foo")
+        with catch_warnings(record=True) as log:
+            self.assertFalse(l.receive_message(m.decode(), m))
+            self.assertTrue(log)
+            self.assertIn("unknown message", log[0].message.args[0])
+        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        self.assertTrue(self.eta_schedule.empty())
+        m.ack.assert_called_with()
+        self.assertTrue(l.logger.critical.call_count)
+
     def test_receieve_message_eta(self):
     def test_receieve_message_eta(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
                           send_events=False)
                           send_events=False)
@@ -463,6 +565,59 @@ class test_Consumer(unittest.TestCase):
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.ready_queue.get_nowait)
         self.assertRaises(Empty, self.ready_queue.get_nowait)
 
 
+    def test_reset_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+        chan = l.pidbox_node.channel = Mock()
+        l.connection = Mock()
+        chan.close.side_effect = socket.error("foo")
+        l.connection_errors = (socket.error, )
+        l.reset_pidbox_node()
+        chan.close.assert_called_with()
+
+    def test_reset_pidbox_node_green(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pool = Mock()
+        l.pool.is_green = True
+        l.reset_pidbox_node()
+        l.pool.spawn_n.assert_called_with(l._green_pidbox_node)
+
+    def test__green_pidbox_node(self):
+        l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
+                          send_events=False)
+        l.pidbox_node = Mock()
+
+        connections = []
+
+        class Connection(object):
+
+            def __init__(self, obj):
+                connections.append(self)
+                self.obj = obj
+                self.closed = False
+
+            def channel(self):
+                return Mock()
+
+            def drain_events(self):
+                self.obj.connection = None
+
+            def close(self):
+                self.closed = True
+
+        l.connection = Mock()
+        l._open_connection = lambda: Connection(obj=l)
+        l._green_pidbox_node()
+
+        l.pidbox_node.listen.assert_called_with(callback=l.on_control)
+        self.assertTrue(l.broadcast_consumer)
+        l.broadcast_consumer.consume.assert_called_with()
+
+        self.assertIsNone(l.connection)
+        self.assertTrue(connections[0].closed)
+
     def test_start__consume_messages(self):
     def test_start__consume_messages(self):
 
 
         class _QoS(object):
         class _QoS(object):

+ 33 - 0
celery/tests/test_worker/test_worker_autoscale.py

@@ -1,7 +1,10 @@
 import logging
 import logging
+import os
 
 
 from time import time
 from time import time
 
 
+from mock import Mock, patch
+
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.worker import state
 from celery.worker import state
 from celery.worker import autoscale
 from celery.worker import autoscale
@@ -17,6 +20,7 @@ class Object(object):
 
 
 class MockPool(BasePool):
 class MockPool(BasePool):
     shrink_raises_exception = False
     shrink_raises_exception = False
+    shrink_raises_ValueError = False
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         super(MockPool, self).__init__(*args, **kwargs)
         super(MockPool, self).__init__(*args, **kwargs)
@@ -29,6 +33,8 @@ class MockPool(BasePool):
     def shrink(self, n=1):
     def shrink(self, n=1):
         if self.shrink_raises_exception:
         if self.shrink_raises_exception:
             raise KeyError("foo")
             raise KeyError("foo")
+        if self.shrink_raises_ValueError:
+            raise ValueError("foo")
         self._pool._processes -= n
         self._pool._processes -= n
 
 
     @property
     @property
@@ -100,3 +106,30 @@ class test_Autoscaler(unittest.TestCase):
         x._last_action = time() - 10000
         x._last_action = time() - 10000
         x.pool.shrink_raises_exception = True
         x.pool.shrink_raises_exception = True
         x.scale_down(1)
         x.scale_down(1)
+
+    def test_shrink_raises_ValueError(self):
+        x = autoscale.Autoscaler(self.pool, 10, 3, logger=logger)
+        x.logger = Mock()
+        x.scale_up(3)
+        x._last_action = time() - 10000
+        x.pool.shrink_raises_ValueError = True
+        x.scale_down(1)
+        self.assertTrue(x.logger.debug.call_count)
+
+    @patch("os._exit")
+    def test_thread_crash(self, _exit):
+
+        class _Autoscaler(autoscale.Autoscaler):
+
+            def scale(self):
+                self._shutdown.set()
+                raise OSError("foo")
+
+        x = _Autoscaler(self.pool, 10, 3, logger=logger)
+        x.logger = Mock()
+        x.run()
+        _exit.assert_called_with(1)
+        self.assertTrue(x.logger.error.call_count)
+
+
+

+ 1 - 0
celery/tests/test_worker/test_worker_heartbeat.py

@@ -51,6 +51,7 @@ class TestHeart(unittest.TestCase):
         self.assertTrue(h.tref)
         self.assertTrue(h.tref)
         h.stop()
         h.stop()
         self.assertIsNone(h.tref)
         self.assertIsNone(h.tref)
+        h.stop()
 
 
     @sleepdeprived
     @sleepdeprived
     def test_run_manages_cycle(self):
     def test_run_manages_cycle(self):

+ 47 - 1
celery/tests/test_worker/test_worker_job.py

@@ -23,7 +23,8 @@ from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
 from celery.worker.job import (WorkerTaskTrace, TaskRequest,
 from celery.worker.job import (WorkerTaskTrace, TaskRequest,
-                               InvalidTaskError, execute_and_trace)
+                               InvalidTaskError, execute_and_trace,
+                               default_encode)
 from celery.worker.state import revoked
 from celery.worker.state import revoked
 
 
 from celery.tests.compat import catch_warnings
 from celery.tests.compat import catch_warnings
@@ -71,6 +72,26 @@ def mytask_raising(i, **kwargs):
     raise KeyError(i)
     raise KeyError(i)
 
 
 
 
+class test_default_encode(unittest.TestCase):
+
+    def test_jython(self):
+        prev, sys.platform = sys.platform, "java 1.6.1"
+        try:
+            self.assertEqual(default_encode("foo"), "foo")
+        finally:
+            sys.platform = prev
+
+    def test_cython(self):
+        prev, sys.platform = sys.platform, "darwin"
+        gfe, sys.getfilesystemencoding = sys.getfilesystemencoding, \
+                                         lambda: "utf-8"
+        try:
+            self.assertEqual(default_encode("foo"), "foo")
+        finally:
+            sys.platform = prev
+            sys.getfilesystemencoding = gfe
+
+
 class test_RetryTaskError(unittest.TestCase):
 class test_RetryTaskError(unittest.TestCase):
 
 
     def test_retry_task_error(self):
     def test_retry_task_error(self):
@@ -197,6 +218,23 @@ class test_TaskRequest(unittest.TestCase):
         tw.on_failure(einfo)
         tw.on_failure(einfo)
         self.assertIn("task-retried", tw.eventer.sent)
         self.assertIn("task-retried", tw.eventer.sent)
 
 
+    def test_terminate__task_started(self):
+        pool = Mock()
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = time.time()
+        tw.worker_pid = 313
+        tw.terminate(pool, signal="KILL")
+        pool.terminate_job.assert_called_with(tw.worker_pid, "KILL")
+
+    def test_terminate__task_reserved(self):
+        pool = Mock()
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = None
+        tw.terminate(pool, signal="KILL")
+        self.assertFalse(pool.terminate_job.call_count)
+        self.assertTupleEqual(tw._terminate_on_ack, (True, pool, "KILL"))
+        tw.terminate(pool, signal="KILL")
+
     def test_revoked_expires_expired(self):
     def test_revoked_expires_expired(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.expires = datetime.now() - timedelta(days=1)
         tw.expires = datetime.now() - timedelta(days=1)
@@ -314,6 +352,14 @@ class test_TaskRequest(unittest.TestCase):
         finally:
         finally:
             mytask.acks_late = False
             mytask.acks_late = False
 
 
+    def test_on_accepted_terminates(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        pool = Mock()
+        tw.terminate(pool, signal="KILL")
+        self.assertFalse(pool.terminate_job.call_count)
+        tw.on_accepted(pid=314, time_accepted=time.time())
+        pool.terminate_job.assert_called_with(314, "KILL")
+
     def test_on_success_acks_early(self):
     def test_on_success_acks_early(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.time_start = 1
         tw.time_start = 1

+ 20 - 1
celery/tests/test_worker/test_worker_mediator.py

@@ -2,7 +2,7 @@ from celery.tests.utils import unittest
 
 
 from Queue import Queue
 from Queue import Queue
 
 
-from mock import Mock
+from mock import Mock, patch
 
 
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
 from celery.worker.mediator import Mediator
 from celery.worker.mediator import Mediator
@@ -53,6 +53,25 @@ class test_Mediator(unittest.TestCase):
 
 
         self.assertEqual(got["value"], "George Costanza")
         self.assertEqual(got["value"], "George Costanza")
 
 
+    @patch("os._exit")
+    def test_mediator_crash(self, _exit):
+        ms = [None]
+
+        class _Mediator(Mediator):
+
+            def move(self):
+                try:
+                    raise KeyError("foo")
+                finally:
+                    ms[0]._shutdown.set()
+
+        ready_queue = Queue()
+        ms[0] = m = _Mediator(ready_queue, None)
+        ready_queue.put(MockTask("George Constanza"))
+        m.run()
+
+        self.assertTrue(_exit.call_count)
+
     def test_mediator_move_exception(self):
     def test_mediator_move_exception(self):
         ready_queue = Queue()
         ready_queue = Queue()
 
 

+ 5 - 1
celery/utils/encoding.py

@@ -11,6 +11,10 @@ def default_encoding():
 def safe_str(s, errors="replace"):
 def safe_str(s, errors="replace"):
     if not isinstance(s, basestring):
     if not isinstance(s, basestring):
         return safe_repr(s, errors)
         return safe_repr(s, errors)
+    return _safe_str(s, errors)
+
+
+def _safe_str(s, errors="replace"):
     encoding = default_encoding()
     encoding = default_encoding()
     try:
     try:
         if isinstance(s, unicode):
         if isinstance(s, unicode):
@@ -24,4 +28,4 @@ def safe_repr(o, errors="replace"):
     try:
     try:
         return repr(o)
         return repr(o)
     except Exception:
     except Exception:
-        return safe_str(o, errors)
+        return _safe_str(o, errors)

+ 4 - 4
celery/worker/consumer.py

@@ -135,24 +135,24 @@ class QoS(object):
         self.value = initial_value
         self.value = initial_value
 
 
     def increment(self, n=1):
     def increment(self, n=1):
-        """Increment the current prefetch count value by one."""
+        """Increment the current prefetch count value by n."""
         with self._mutex:
         with self._mutex:
             if self.value:
             if self.value:
                 new_value = self.value + max(n, 0)
                 new_value = self.value + max(n, 0)
                 self.value = self.set(new_value)
                 self.value = self.set(new_value)
-            return self.value
+        return self.value
 
 
     def _sub(self, n=1):
     def _sub(self, n=1):
         assert self.value - n > 1
         assert self.value - n > 1
         self.value -= n
         self.value -= n
 
 
     def decrement(self, n=1):
     def decrement(self, n=1):
-        """Decrement the current prefetch count value by one."""
+        """Decrement the current prefetch count value by n."""
         with self._mutex:
         with self._mutex:
             if self.value:
             if self.value:
                 self._sub(n)
                 self._sub(n)
                 self.set(self.value)
                 self.set(self.value)
-            return self.value
+        return self.value
 
 
     def decrement_eventually(self, n=1):
     def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.
         """Decrement the value, but do not update the qos.

+ 3 - 9
celery/worker/job.py

@@ -17,7 +17,7 @@ from celery.datastructures import ExceptionInfo
 from celery.execute.trace import TaskTrace
 from celery.execute.trace import TaskTrace
 from celery.utils import (noop, kwdict, fun_takes_kwargs,
 from celery.utils import (noop, kwdict, fun_takes_kwargs,
                           get_symbol_by_name, truncate_text)
                           get_symbol_by_name, truncate_text)
-from celery.utils.encoding import safe_repr, safe_str
+from celery.utils.encoding import safe_repr, safe_str, default_encoding
 from celery.utils.timeutils import maybe_iso8601
 from celery.utils.timeutils import maybe_iso8601
 from celery.worker import state
 from celery.worker import state
 
 
@@ -52,11 +52,7 @@ class InvalidTaskError(Exception):
 
 
 
 
 def default_encode(obj):
 def default_encode(obj):
-    if sys.platform.startswith("java"):
-        coding = "utf-8"
-    else:
-        coding = sys.getfilesystemencoding()
-    return unicode(obj, coding)
+    return unicode(obj, default_encoding())
 
 
 
 
 class WorkerTaskTrace(TaskTrace):
 class WorkerTaskTrace(TaskTrace):
@@ -401,9 +397,7 @@ class TaskRequest(object):
                 self.task.backend.mark_as_revoked(self.task_id)
                 self.task.backend.mark_as_revoked(self.task_id)
 
 
     def terminate(self, pool, signal=None):
     def terminate(self, pool, signal=None):
-        if self._terminate_on_ack is not None:
-            return
-        elif self.time_start:
+        if self.time_start:
             return pool.terminate_job(self.worker_pid, signal)
             return pool.terminate_job(self.worker_pid, signal)
         else:
         else:
             self._terminate_on_ack = (True, pool, signal)
             self._terminate_on_ack = (True, pool, signal)

+ 1 - 1
celery/worker/state.py

@@ -51,7 +51,7 @@ def task_ready(request):
     reserved_requests.discard(request)
     reserved_requests.discard(request)
 
 
 
 
-if os.environ.get("CELERY_BENCH"):
+if os.environ.get("CELERY_BENCH"):  # pragma: no cover
     from time import time
     from time import time
 
 
     all_count = 0
     all_count = 0