Kaynağa Gözat

Use with self.assertRaises

Ask Solem 13 yıl önce
ebeveyn
işleme
faffab30d3
61 değiştirilmiş dosya ile 386 ekleme ve 221 silme
  1. 2 0
      celery/tests/compat.py
  2. 3 1
      celery/tests/test_app/__init__.py
  3. 2 2
      celery/tests/test_app/test_app_amqp.py
  4. 3 2
      celery/tests/test_app/test_beat.py
  5. 1 0
      celery/tests/test_app/test_celery.py
  6. 9 9
      celery/tests/test_app/test_loaders.py
  7. 2 1
      celery/tests/test_app/test_log.py
  8. 5 1
      celery/tests/test_app/test_routes.py
  9. 2 2
      celery/tests/test_backends/__init__.py
  10. 33 26
      celery/tests/test_backends/test_amqp.py
  11. 35 34
      celery/tests/test_backends/test_base.py
  12. 3 2
      celery/tests/test_backends/test_cache.py
  13. 5 2
      celery/tests/test_backends/test_database.py
  14. 2 0
      celery/tests/test_backends/test_pyredis_compat.py
  15. 3 3
      celery/tests/test_backends/test_redis.py
  16. 2 0
      celery/tests/test_backends/test_redis_unit.py
  17. 3 1
      celery/tests/test_backends/test_tyrant.py
  18. 4 2
      celery/tests/test_bin/__init__.py
  19. 6 3
      celery/tests/test_bin/test_celerybeat.py
  20. 21 17
      celery/tests/test_bin/test_celeryd.py
  21. 2 0
      celery/tests/test_bin/test_celeryev.py
  22. 1 0
      celery/tests/test_compat/test_decorators.py
  23. 2 0
      celery/tests/test_compat/test_messaging.py
  24. 2 0
      celery/tests/test_concurrency/__init__.py
  25. 2 0
      celery/tests/test_concurrency/test_concurrency_eventlet.py
  26. 4 1
      celery/tests/test_concurrency/test_concurrency_processes.py
  27. 2 0
      celery/tests/test_concurrency/test_concurrency_solo.py
  28. 2 0
      celery/tests/test_concurrency/test_pool.py
  29. 8 4
      celery/tests/test_events/__init__.py
  30. 2 0
      celery/tests/test_events/test_events_cursesmon.py
  31. 4 1
      celery/tests/test_events/test_events_snapshot.py
  32. 3 1
      celery/tests/test_events/test_events_state.py
  33. 11 4
      celery/tests/test_slow/test_buckets.py
  34. 54 33
      celery/tests/test_task/__init__.py
  35. 2 0
      celery/tests/test_task/test_chord.py
  36. 3 2
      celery/tests/test_task/test_context.py
  37. 6 3
      celery/tests/test_task/test_execute_trace.py
  38. 7 3
      celery/tests/test_task/test_registry.py
  39. 26 15
      celery/tests/test_task/test_result.py
  40. 2 1
      celery/tests/test_task/test_states.py
  41. 2 1
      celery/tests/test_task/test_task_abortable.py
  42. 1 0
      celery/tests/test_task/test_task_builtins.py
  43. 6 2
      celery/tests/test_task/test_task_control.py
  44. 9 4
      celery/tests/test_task/test_task_http.py
  45. 1 0
      celery/tests/test_task/test_task_sets.py
  46. 3 1
      celery/tests/test_utils/__init__.py
  47. 12 5
      celery/tests/test_utils/test_datastructures.py
  48. 2 1
      celery/tests/test_utils/test_pickle.py
  49. 1 0
      celery/tests/test_utils/test_serialization.py
  50. 1 0
      celery/tests/test_utils/test_timer2.py
  51. 2 0
      celery/tests/test_utils/test_utils_encoding.py
  52. 2 1
      celery/tests/test_utils/test_utils_info.py
  53. 2 1
      celery/tests/test_utils/test_utils_timeutils.py
  54. 25 13
      celery/tests/test_worker/__init__.py
  55. 2 1
      celery/tests/test_worker/test_worker_autoscale.py
  56. 9 5
      celery/tests/test_worker/test_worker_control.py
  57. 2 1
      celery/tests/test_worker/test_worker_heartbeat.py
  58. 7 6
      celery/tests/test_worker/test_worker_job.py
  59. 2 1
      celery/tests/test_worker/test_worker_mediator.py
  60. 2 1
      celery/tests/test_worker/test_worker_revoke.py
  61. 2 1
      celery/tests/test_worker/test_worker_state.py

+ 2 - 0
celery/tests/compat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 
 

+ 3 - 1
celery/tests/test_app/__init__.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import os
@@ -264,7 +265,8 @@ class test_defaults(unittest.TestCase):
             self.assertFalse(defaults.str_to_bool(s))
         for s in ("true", "yes", "1"):
             self.assertTrue(defaults.str_to_bool(s))
-        self.assertRaises(TypeError, defaults.str_to_bool, "unsure")
+        with self.assertRaises(TypeError):
+            defaults.str_to_bool("unsure")
 
 
 class test_debugging_utils(unittest.TestCase):

+ 2 - 2
celery/tests/test_app/test_app_amqp.py

@@ -1,10 +1,10 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 from mock import Mock
 
-from celery.tests.utils import AppCase
-
 from celery.app.amqp import MSG_OPTIONS, extract_msg_options
+from celery.tests.utils import AppCase
 
 
 class TestMsgOptions(AppCase):

+ 3 - 2
celery/tests/test_app/test_beat.py

@@ -1,8 +1,8 @@
+from __future__ import absolute_import
+
 import logging
-from celery.tests.utils import unittest
 
 from datetime import datetime, timedelta
-
 from nose import SkipTest
 
 from celery import beat
@@ -11,6 +11,7 @@ from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.task.base import Task
 from celery.utils import uuid
+from celery.tests.utils import unittest
 
 
 class Object(object):

+ 1 - 0
celery/tests/test_app/test_celery.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from celery.tests.utils import unittest
 
 import celery

+ 9 - 9
celery/tests/test_app/test_loaders.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import os
@@ -137,8 +138,8 @@ class TestLoaderBase(unittest.TestCase):
             self.assertIsInstance(warning, MockMail.SendmailWarning)
             self.assertIn("KeyError", warning.args[0])
 
-            self.assertRaises(KeyError, self.loader.mail_admins,
-                              fail_silently=False, **opts)
+            with self.assertRaises(KeyError):
+                self.loader.mail_admins(fail_silently=False, **opts)
 
     def test_mail_admins(self):
         MockMail.Mailer.raise_on_send = False
@@ -154,8 +155,8 @@ class TestLoaderBase(unittest.TestCase):
         self.assertIs(loader.mail, mail)
 
     def test_cmdline_config_ValueError(self):
-        self.assertRaises(ValueError, self.loader.cmdline_config_parser,
-                         ["broker.port=foobar"])
+        with self.assertRaises(ValueError):
+            self.loader.cmdline_config_parser(["broker.port=foobar"])
 
 
 class TestDefaultLoader(unittest.TestCase):
@@ -231,17 +232,16 @@ class test_AppLoader(unittest.TestCase):
     def test_config_from_envvar(self, key="CELERY_HARNESS_CFG1"):
         self.assertFalse(self.loader.config_from_envvar("HDSAJIHWIQHEWQU",
                                                         silent=True))
-        self.assertRaises(ImproperlyConfigured,
-                          self.loader.config_from_envvar, "HDSAJIHWIQHEWQU",
-                          silent=False)
+        with self.assertRaises(ImproperlyConfigured):
+            self.loader.config_from_envvar("HDSAJIHWIQHEWQU", silent=False)
         os.environ[key] = __name__ + ".object_config"
         self.assertTrue(self.loader.config_from_envvar(key))
         self.assertEqual(self.loader.conf["FOO"], 1)
         self.assertEqual(self.loader.conf["BAR"], 2)
 
         os.environ[key] = "unknown_asdwqe.asdwqewqe"
-        self.assertRaises(ImportError,
-                          self.loader.config_from_envvar, key, silent=False)
+        with self.assertRaises(ImportError):
+            self.loader.config_from_envvar(key, silent=False)
         self.assertFalse(self.loader.config_from_envvar(key, silent=True))
 
         os.environ[key] = __name__ + ".dict_config"

+ 2 - 1
celery/tests/test_app/test_log.py

@@ -1,8 +1,8 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys
 import logging
-from celery.tests.utils import unittest
 from tempfile import mktemp
 
 from celery import log
@@ -12,6 +12,7 @@ from celery.log import (setup_logger, setup_task_logger,
                         setup_logging_subsystem)
 from celery.utils import uuid
 from celery.utils.compat import _CompatLoggerAdapter
+from celery.tests.utils import unittest
 from celery.tests.utils import (override_stdouts, wrap_logger,
                                 get_handlers, set_handlers)
 

+ 5 - 1
celery/tests/test_app/test_routes.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 from functools import wraps
 
 from celery import routes
@@ -65,7 +68,8 @@ class test_MapRoute(unittest.TestCase):
     def test_expand_route_not_found(self):
         expand = E(current_app.conf.CELERY_QUEUES)
         route = routes.MapRoute({"a": {"queue": "x"}})
-        self.assertRaises(QueueNotFound, expand, route.route_for_task("a"))
+        with self.assertRaises(QueueNotFound):
+            expand(route.route_for_task("a"))
 
 
 class test_lookup_route(unittest.TestCase):

+ 2 - 2
celery/tests/test_backends/__init__.py

@@ -1,10 +1,10 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
-from celery.tests.utils import unittest
-
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
+from celery.tests.utils import unittest
 
 
 class TestBackends(unittest.TestCase):

+ 33 - 26
celery/tests/test_backends/test_amqp.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import socket
@@ -111,11 +112,11 @@ class test_AMQPBackend(unittest.TestCase):
             Producer = _Producer
 
         backend = Backend()
-        self.assertRaises(KeyError, backend.store_result,
-                          "foo", "bar", "STARTED", max_retries=None)
+        with self.assertRaises(KeyError):
+            backend.store_result("foo", "bar", "STARTED", max_retries=None)
 
-        self.assertRaises(KeyError, backend.store_result,
-                          "foo", "bar", "STARTED", max_retries=10)
+        with self.assertRaises(KeyError):
+            backend.store_result("foo", "bar", "STARTED", max_retries=10)
 
     def assertState(self, retval, state):
         self.assertEqual(retval["status"], state)
@@ -182,11 +183,14 @@ class test_AMQPBackend(unittest.TestCase):
         b = self.create_backend()
 
         tid = uuid()
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, None, states.STARTED)
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, None, states.RETRY)
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, 42, states.SUCCESS)
         self.assertEqual(b.wait_for(tid, timeout=1), 42)
         b.store_result(tid, 56, states.SUCCESS)
@@ -194,7 +198,8 @@ class test_AMQPBackend(unittest.TestCase):
                          "result is cached")
         self.assertEqual(b.wait_for(tid, timeout=1, cache=False), 56)
         b.store_result(tid, KeyError("foo"), states.FAILURE)
-        self.assertRaises(KeyError, b.wait_for, tid, timeout=1, cache=False)
+        with self.assertRaises(KeyError):
+            b.wait_for(tid, timeout=1, cache=False)
 
     def test_drain_events_remaining_timeouts(self):
 
@@ -207,8 +212,8 @@ class test_AMQPBackend(unittest.TestCase):
         with current_app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(uuid())
             consumer = b._create_consumer(binding, channel)
-            self.assertRaises(socket.timeout, b.drain_events,
-                              Connection(), consumer, timeout=0.1)
+            with self.assertRaises(socket.timeout):
+                b.drain_events(Connection(), consumer, timeout=0.1)
 
     def test_get_many(self):
         b = self.create_backend()
@@ -230,8 +235,8 @@ class test_AMQPBackend(unittest.TestCase):
         cached_res = list(b.get_many(tids, timeout=1))
         self.assertEqual(sorted(cached_res), sorted(expected_results))
         b._cache[res[0][0]]["status"] = states.RETRY
-        self.assertRaises(socket.timeout, list,
-                          b.get_many(tids, timeout=0.01))
+        with self.assertRaises(socket.timeout):
+            list(b.get_many(tids, timeout=0.01))
 
     def test_test_get_many_raises_outer_block(self):
 
@@ -241,7 +246,8 @@ class test_AMQPBackend(unittest.TestCase):
                 raise KeyError("foo")
 
         b = Backend()
-        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+        with self.assertRaises(KeyError):
+            b.get_many(["id1"]).next()
 
     def test_test_get_many_raises_inner_block(self):
 
@@ -251,7 +257,8 @@ class test_AMQPBackend(unittest.TestCase):
                 raise KeyError("foo")
 
         b = Backend()
-        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+        with self.assertRaises(KeyError):
+            b.get_many(["id1"]).next()
 
     def test_no_expires(self):
         b = self.create_backend(expires=None)
@@ -260,8 +267,8 @@ class test_AMQPBackend(unittest.TestCase):
         app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
         try:
             b = self.create_backend(expires=None)
-            self.assertRaises(KeyError, b.queue_arguments.__getitem__,
-                              "x-expires")
+            with self.assertRaises(KeyError):
+                b.queue_arguments["x-expires"]
         finally:
             app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
 
@@ -269,21 +276,21 @@ class test_AMQPBackend(unittest.TestCase):
         self.create_backend().process_cleanup()
 
     def test_reload_task_result(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().reload_task_result, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().reload_task_result("x")
 
     def test_reload_taskset_result(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().reload_taskset_result, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().reload_taskset_result("x")
 
     def test_save_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().save_taskset, "x", "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().save_taskset("x", "x")
 
     def test_restore_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().restore_taskset, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().restore_taskset("x")
 
     def test_delete_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().delete_taskset, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().delete_taskset("x")

+ 35 - 34
celery/tests/test_backends/test_base.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys
@@ -48,52 +49,52 @@ class test_serialization(unittest.TestCase):
 class test_BaseBackend_interface(unittest.TestCase):
 
     def test_get_status(self):
-        self.assertRaises(NotImplementedError,
-                b.get_status, "SOMExx-N0Nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_status("SOMExx-N0Nex1stant-IDxx-")
 
     def test__forget(self):
-        self.assertRaises(NotImplementedError,
-                b.forget, "SOMExx-N0Nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.forget("SOMExx-N0Nex1stant-IDxx-")
 
     def test_store_result(self):
-        self.assertRaises(NotImplementedError,
-                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
+        with self.assertRaises(NotImplementedError):
+            b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
 
     def test_mark_as_started(self):
-        self.assertRaises(NotImplementedError,
-                b.mark_as_started, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.mark_as_started("SOMExx-N0nex1stant-IDxx-")
 
     def test_reload_task_result(self):
-        self.assertRaises(NotImplementedError,
-                b.reload_task_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.reload_task_result("SOMExx-N0nex1stant-IDxx-")
 
     def test_reload_taskset_result(self):
-        self.assertRaises(NotImplementedError,
-                b.reload_taskset_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.reload_taskset_result("SOMExx-N0nex1stant-IDxx-")
 
     def test_get_result(self):
-        self.assertRaises(NotImplementedError,
-                b.get_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_result("SOMExx-N0nex1stant-IDxx-")
 
     def test_restore_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.restore_taskset, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.restore_taskset("SOMExx-N0nex1stant-IDxx-")
 
     def test_delete_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.delete_taskset, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.delete_taskset("SOMExx-N0nex1stant-IDxx-")
 
     def test_save_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.save_taskset, "SOMExx-N0nex1stant-IDxx-", "blergh")
+        with self.assertRaises(NotImplementedError):
+            b.save_taskset("SOMExx-N0nex1stant-IDxx-", "blergh")
 
     def test_get_traceback(self):
-        self.assertRaises(NotImplementedError,
-                b.get_traceback, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_traceback("SOMExx-N0nex1stant-IDxx-")
 
     def test_forget(self):
-        self.assertRaises(NotImplementedError,
-                b.forget, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.forget("SOMExx-N0nex1stant-IDxx-")
 
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
         from celery.registry import tasks
@@ -277,27 +278,27 @@ class test_KeyValueStoreBackend(unittest.TestCase):
 class test_KeyValueStoreBackend_interface(unittest.TestCase):
 
     def test_get(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().get,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().get("a")
 
     def test_set(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().set,
-                "a", 1)
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().set("a", 1)
 
     def test_cleanup(self):
         self.assertFalse(KeyValueStoreBackend().cleanup())
 
     def test_delete(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().delete,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().delete("a")
 
     def test_mget(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().mget,
-                ["a"])
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().mget(["a"])
 
     def test_forget(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().forget,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().forget("a")
 
 
 class test_DisabledBackend(unittest.TestCase):

+ 3 - 2
celery/tests/test_backends/test_cache.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys
@@ -82,8 +83,8 @@ class test_CacheBackend(unittest.TestCase):
         self.assertEqual(tb.expires, 10)
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
-        self.assertRaises(ImproperlyConfigured,
-                          CacheBackend, backend="unknown://")
+        with self.assertRaises(ImproperlyConfigured):
+            CacheBackend(backend="unknown://")
 
 
 class MyClient(DummyClient):

+ 5 - 2
celery/tests/test_backends/test_database.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys
@@ -43,7 +44,8 @@ class test_DatabaseBackend(unittest.TestCase):
     def test_missing_SQLAlchemy_raises_ImproperlyConfigured(self):
         with mask_modules("sqlalchemy"):
             from celery.backends.database import _sqlalchemy_installed
-            self.assertRaises(ImproperlyConfigured, _sqlalchemy_installed)
+            with self.assertRaises(ImproperlyConfigured):
+                _sqlalchemy_installed()
 
     def test_pickle_hack_for_sqla_05(self):
         import sqlalchemy as sa
@@ -66,7 +68,8 @@ class test_DatabaseBackend(unittest.TestCase):
         conf = app_or_default().conf
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         try:
-            self.assertRaises(ImproperlyConfigured, DatabaseBackend)
+            with self.assertRaises(ImproperlyConfigured):
+                DatabaseBackend()
         finally:
             conf.CELERY_RESULT_DBURI = prev
 

+ 2 - 0
celery/tests/test_backends/test_pyredis_compat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 
 from celery.exceptions import ImproperlyConfigured

+ 3 - 3
celery/tests/test_backends/test_redis.py

@@ -3,7 +3,6 @@ from __future__ import with_statement
 
 import sys
 import socket
-from celery.tests.utils import unittest
 
 from nose import SkipTest
 
@@ -13,8 +12,8 @@ from celery import states
 from celery.utils import uuid
 from celery.backends import redis
 from celery.backends.redis import RedisBackend
-
 from celery.tests.utils import mask_modules
+from celery.tests.utils import unittest
 
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg_emitted = False
@@ -112,6 +111,7 @@ class TestRedisBackendNoRedis(unittest.TestCase):
         prev = redis.RedisBackend.redis
         redis.RedisBackend.redis = None
         try:
-            self.assertRaises(ImproperlyConfigured, redis.RedisBackend)
+            with self.assertRaises(ImproperlyConfigured):
+                redis.RedisBackend()
         finally:
             redis.RedisBackend.redis = prev

+ 2 - 0
celery/tests/test_backends/test_redis_unit.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import timedelta
 
 from mock import Mock, patch

+ 3 - 1
celery/tests/test_backends/test_tyrant.py

@@ -1,6 +1,7 @@
+from __future__ import absolute_import
+
 import sys
 import socket
-from celery.tests.utils import unittest
 
 from nose import SkipTest
 
@@ -10,6 +11,7 @@ from celery import states
 from celery.utils import uuid
 from celery.backends import tyrant
 from celery.backends.tyrant import TyrantBackend
+from celery.tests.utils import unittest
 
 _no_tyrant_msg = "* Tokyo Tyrant %s. Will not execute related tests."
 _no_tyrant_msg_emitted = False

+ 4 - 2
celery/tests/test_bin/__init__.py

@@ -1,7 +1,8 @@
+from __future__ import absolute_import
+
 import os
 
 from celery.bin.base import Command
-
 from celery.tests.utils import AppCase
 
 
@@ -35,7 +36,8 @@ class test_Command(AppCase):
         self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
 
     def test_run_interface(self):
-        self.assertRaises(NotImplementedError, Command().run)
+        with self.assertRaises(NotImplementedError):
+            Command().run()
 
     def test_execute_from_commandline(self):
         cmd = MockCommand()

+ 6 - 3
celery/tests/test_bin/test_celerybeat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import logging
 import sys
 
@@ -98,8 +100,8 @@ class test_Beat(AppCase):
         clock = MockService()
         MockService.in_sync = False
         handlers = self.psig(b.install_sync_handler, clock)
-        self.assertRaises(SystemExit, handlers["SIGINT"],
-                          "SIGINT", object())
+        with self.assertRaises(SystemExit):
+            handlers["SIGINT"]("SIGINT", object())
         self.assertTrue(MockService.in_sync)
         MockService.in_sync = False
 
@@ -112,7 +114,8 @@ class test_Beat(AppCase):
         b = beatapp.Beat()
         b.redirect_stdouts = False
         b.setup_logging()
-        self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
 
     @redirect_stdouts
     def test_logs_errors(self, stdout, stderr):

+ 21 - 17
celery/tests/test_bin/test_celeryd.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import logging
@@ -209,8 +210,8 @@ class test_Worker(AppCase):
             self.assertNotIn("celery", app.amqp.queues.consume_from)
 
             c.CELERY_CREATE_MISSING_QUEUES = False
-            self.assertRaises(ImproperlyConfigured,
-                    self.Worker(queues=["image"]).init_queues)
+            with self.assertRaises(ImproperlyConfigured):
+                self.Worker(queues=["image"]).init_queues()
             c.CELERY_CREATE_MISSING_QUEUES = True
             worker = self.Worker(queues=["image"])
             worker.init_queues()
@@ -241,7 +242,8 @@ class test_Worker(AppCase):
 
     @disable_stdouts
     def test_unknown_loglevel(self):
-        self.assertRaises(SystemExit, self.Worker, loglevel="ALIEN")
+        with self.assertRaises(SystemExit):
+            self.Worker(loglevel="ALIEN")
         worker1 = self.Worker(loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
@@ -299,7 +301,8 @@ class test_Worker(AppCase):
         worker = self.Worker()
         worker.redirect_stdouts = False
         worker.redirect_stdouts_to_logger()
-        self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
 
     def test_redirect_stdouts_already_handled(self):
         logging_setup = [False]
@@ -313,7 +316,8 @@ class test_Worker(AppCase):
             worker.app.log.__class__._setup = False
             worker.redirect_stdouts_to_logger()
             self.assertTrue(logging_setup[0])
-            self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+            with self.assertRaises(AttributeError):
+                sys.stdout.logger
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
 
@@ -470,14 +474,14 @@ class test_signal_handlers(AppCase):
 
         p, platforms.signals = platforms.signals, Signals()
         try:
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                              "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertTrue(worker.stopped)
         finally:
             platforms.signals = p
 
-        self.assertRaises(SystemExit, next_handlers["SIGINT"],
-                          "SIGINT", object())
+        with self.assertRaises(SystemExit):
+            next_handlers["SIGINT"]("SIGINT", object())
         self.assertTrue(worker.terminated)
 
     @disable_stdouts
@@ -489,8 +493,8 @@ class test_signal_handlers(AppCase):
         try:
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_int_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                            "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertFalse(worker.stopped)
         finally:
             process.name = name
@@ -510,8 +514,8 @@ class test_signal_handlers(AppCase):
         try:
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_int_again_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                            "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertFalse(worker.terminated)
         finally:
             process.name = name
@@ -520,8 +524,8 @@ class test_signal_handlers(AppCase):
     def test_worker_term_handler(self):
         worker = self._Worker()
         handlers = self.psig(cd.install_worker_term_handler, worker)
-        self.assertRaises(SystemExit, handlers["SIGTERM"],
-                          "SIGTERM", object())
+        with self.assertRaises(SystemExit):
+            handlers["SIGTERM"]("SIGTERM", object())
         self.assertTrue(worker.stopped)
 
     def test_worker_cry_handler(self):
@@ -552,8 +556,8 @@ class test_signal_handlers(AppCase):
         try:
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGTERM"],
-                          "SIGTERM", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGTERM"]("SIGTERM", object())
             self.assertFalse(worker.stopped)
         finally:
             process.name = name

+ 2 - 0
celery/tests/test_bin/test_celeryev.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 
 from celery.app import app_or_default

+ 1 - 0
celery/tests/test_compat/test_decorators.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import warnings

+ 2 - 0
celery/tests/test_compat/test_messaging.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery import messaging
 from celery.tests.utils import unittest
 

+ 2 - 0
celery/tests/test_concurrency/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 
 from itertools import count

+ 2 - 0
celery/tests/test_concurrency/test_concurrency_eventlet.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 import sys
 

+ 4 - 1
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import signal
 import sys
 
@@ -185,7 +187,8 @@ class test_TaskPool(unittest.TestCase):
     def test_on_ready_exit_exception(self):
         pool = TaskPool(10)
         exc = to_excinfo(SystemExit("foo"))
-        self.assertRaises(SystemExit, pool.on_ready, [], [], exc)
+        with self.assertRaises(SystemExit):
+            pool.on_ready([], [], exc)
 
     def test_apply_async(self):
         pool = TaskPool(10)

+ 2 - 0
celery/tests/test_concurrency/test_concurrency_solo.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import operator
 
 from celery.concurrency import solo

+ 2 - 0
celery/tests/test_concurrency/test_pool.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import time
 import logging

+ 8 - 4
celery/tests/test_events/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import socket
 
 from celery import events
@@ -54,7 +56,8 @@ class TestEventDispatcher(unittest.TestCase):
         eventer.enabled = True
         eventer.publisher.raise_on_publish = True
         eventer.buffer_while_offline = False
-        self.assertRaises(KeyError, eventer.send, "Event X")
+        with self.assertRaises(KeyError):
+            eventer.send("Event X")
         eventer.buffer_while_offline = True
         for ev in evs:
             eventer.send(ev)
@@ -142,10 +145,11 @@ class TestEventReceiver(unittest.TestCase):
             self.assertTrue(consumer.queues)
             self.assertEqual(consumer.callbacks[0], r._receive)
 
-            self.assertRaises(socket.timeout, it.next)
+            with self.assertRaises(socket.timeout):
+                it.next()
 
-            self.assertRaises(socket.timeout,
-                              r.capture, timeout=0.00001)
+            with self.assertRaises(socket.timeout):
+                r.capture(timeout=0.00001)
         finally:
             connection.close()
 

+ 2 - 0
celery/tests/test_events/test_events_cursesmon.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 
 from celery.tests.utils import unittest

+ 4 - 1
celery/tests/test_events/test_events_snapshot.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery.app import app_or_default
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
@@ -117,4 +119,5 @@ class test_evcam(unittest.TestCase):
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         self.MockReceiver.raise_keyboard_interrupt = True
-        self.assertRaises(SystemExit, evcam, Polaroid, timer=timer)
+        with self.assertRaises(SystemExit):
+            evcam(Polaroid, timer=timer)

+ 3 - 1
celery/tests/test_events/test_events_state.py

@@ -1,5 +1,6 @@
+from __future__ import absolute_import
+
 from time import time
-from celery.tests.utils import unittest
 
 from itertools import count
 
@@ -7,6 +8,7 @@ from celery import states
 from celery.events import Event
 from celery.events.state import State, Worker, Task, HEARTBEAT_EXPIRE
 from celery.utils import uuid
+from celery.tests.utils import unittest
 
 
 class replay(object):

+ 11 - 4
celery/tests/test_slow/test_buckets.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import sys
 import time
 
@@ -42,7 +45,8 @@ class test_TokenBucketQueue(unittest.TestCase):
     @skip_if_disabled
     def empty_queue_yields_QueueEmpty(self):
         x = buckets.TokenBucketQueue(fill_rate=10)
-        self.assertRaises(buckets.Empty, x.get)
+        with self.assertRaises(buckets.Empty):
+            x.get()
 
     @skip_if_disabled
     def test_bucket__put_get(self):
@@ -73,7 +77,8 @@ class test_TokenBucketQueue(unittest.TestCase):
         time.sleep(0.1)
         # Not yet ready for another token
         x.put("The lazy dog")
-        self.assertRaises(x.RateLimitExceeded, x.get)
+        with self.assertRaises(x.RateLimitExceeded):
+            x.get()
 
     @skip_if_disabled
     def test_expected_time(self):
@@ -132,7 +137,8 @@ class test_TaskBucket(unittest.TestCase):
     @skip_if_disabled
     def test_get_nowait(self):
         x = buckets.TaskBucket(task_registry=self.registry)
-        self.assertRaises(buckets.Empty, x.get_nowait)
+        with self.assertRaises(buckets.Empty):
+            x.get_nowait()
 
     @skip_if_disabled
     def test_refresh(self):
@@ -194,7 +200,8 @@ class test_TaskBucket(unittest.TestCase):
     @skip_if_disabled
     def test_on_empty_buckets__get_raises_empty(self):
         b = buckets.TaskBucket(task_registry=self.registry)
-        self.assertRaises(buckets.Empty, b.get, block=False)
+        with self.assertRaises(buckets.Empty):
+            b.get(block=False)
         self.assertEqual(b.qsize(), 0)
 
     @skip_if_disabled

+ 54 - 33
celery/tests/test_task/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 from functools import wraps
 
@@ -150,8 +152,8 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(RetryTaskNoArgs.iterations, 4)
 
     def test_retry_kwargs_can_be_empty(self):
-        self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
-                            args=[4, 4], kwargs=None)
+        with self.assertRaises(RetryTaskError):
+            RetryTaskMockApply.retry(args=[4, 4], kwargs=None)
 
     def test_retry_not_eager(self):
         RetryTaskMockApply.request.called_directly = False
@@ -164,7 +166,8 @@ class TestTaskRetries(unittest.TestCase):
             RetryTaskMockApply.applied = 0
 
         try:
-            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
+            with self.assertRaises(RetryTaskError):
+                RetryTaskMockApply.retry(
                     args=[4, 4], kwargs={"task_retries": 0},
                     exc=exc, throw=True)
             self.assertTrue(RetryTaskMockApply.applied)
@@ -182,23 +185,23 @@ class TestTaskRetries(unittest.TestCase):
         RetryTaskCustomExc.max_retries = 2
         RetryTaskCustomExc.iterations = 0
         result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        self.assertRaises(MyCustomException,
-                          result.get)
+        with self.assertRaises(MyCustomException):
+            result.get()
         self.assertEqual(RetryTaskCustomExc.iterations, 3)
 
     def test_max_retries_exceeded(self):
         RetryTask.max_retries = 2
         RetryTask.iterations = 0
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
+        with self.assertRaises(RetryTask.MaxRetriesExceededError):
+            result.get()
         self.assertEqual(RetryTask.iterations, 3)
 
         RetryTask.max_retries = 1
         RetryTask.iterations = 0
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
+        with self.assertRaises(RetryTask.MaxRetriesExceededError):
+            result.get()
         self.assertEqual(RetryTask.iterations, 2)
 
 
@@ -255,15 +258,16 @@ class TestCeleryTasks(unittest.TestCase):
         class IncompleteTask(task.Task):
             name = "c.unittest.t.itask"
 
-        self.assertRaises(NotImplementedError, IncompleteTask().run)
+        with self.assertRaises(NotImplementedError):
+            IncompleteTask().run()
 
     def test_task_kwargs_must_be_dictionary(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          [], "str")
+        with self.assertRaises(ValueError):
+            IncrementCounterTask.apply_async([], "str")
 
     def test_task_args_must_be_list(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          "str", {})
+        with self.assertRaises(ValueError):
+            IncrementCounterTask.apply_async("str", {})
 
     def test_regular_task(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
@@ -280,7 +284,8 @@ class TestCeleryTasks(unittest.TestCase):
 
         t1 = T1()
         consumer = t1.get_consumer()
-        self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
+        with self.assertRaises(NotImplementedError):
+            consumer.receive("foo", "foo")
         consumer.discard_all()
         self.assertIsNone(consumer.fetch())
 
@@ -466,12 +471,14 @@ class TestTaskSet(unittest.TestCase):
 class TestTaskApply(unittest.TestCase):
 
     def test_apply_throw(self):
-        self.assertRaises(KeyError, RaisingTask.apply, throw=True)
+        with self.assertRaises(KeyError):
+            RaisingTask.apply(throw=True)
 
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
         RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         try:
-            self.assertRaises(KeyError, RaisingTask.apply)
+            with self.assertRaises(KeyError):
+                RaisingTask.apply()
         finally:
             RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
 
@@ -496,7 +503,8 @@ class TestTaskApply(unittest.TestCase):
         self.assertTrue(f.ready())
         self.assertFalse(f.successful())
         self.assertTrue(f.traceback)
-        self.assertRaises(KeyError, f.get)
+        with self.assertRaises(KeyError):
+            f.get()
 
 
 class MyPeriodic(task.PeriodicTask):
@@ -506,8 +514,8 @@ class MyPeriodic(task.PeriodicTask):
 class TestPeriodicTask(unittest.TestCase):
 
     def test_must_have_run_every(self):
-        self.assertRaises(NotImplementedError, type, "Foo",
-            (task.PeriodicTask, ), {"__module__": __name__})
+        with self.assertRaises(NotImplementedError):
+            type("Foo", (task.PeriodicTask, ), {"__module__": __name__})
 
     def test_remaining_estimate(self):
         self.assertIsInstance(
@@ -610,23 +618,28 @@ class test_crontab_parser(unittest.TestCase):
                     20, 25, 30, 35, 40, 45, 50, 55]))
 
     def test_parse_errors_on_empty_string(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('')
 
     def test_parse_errors_on_empty_group(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '1,,2')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('1,,2')
 
     def test_parse_errors_on_empty_steps(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '*/')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('*/')
 
     def test_parse_errors_on_negative_number(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '-20')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('-20')
 
     def test_expand_cronspec_eats_iterables(self):
         self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
                          set([1, 2, 3]))
 
     def test_expand_cronspec_invalid_type(self):
-        self.assertRaises(TypeError, crontab._expand_cronspec, object(), 100)
+        with self.assertRaises(TypeError):
+            crontab._expand_cronspec(object(), 100)
 
     def test_repr(self):
         self.assertIn("*", repr(crontab("*")))
@@ -720,8 +733,10 @@ class test_crontab_is_due(unittest.TestCase):
         self.assertEqual(c.minute, set([30, 40, 50]))
 
     def test_crontab_spec_invalid_minute(self):
-        self.assertRaises(ValueError, crontab, minute=60)
-        self.assertRaises(ValueError, crontab, minute='0-100')
+        with self.assertRaises(ValueError):
+            crontab(minute=60)
+        with self.assertRaises(ValueError):
+            crontab(minute='0-100')
 
     def test_crontab_spec_hour_formats(self):
         c = crontab(hour=6)
@@ -732,8 +747,10 @@ class test_crontab_is_due(unittest.TestCase):
         self.assertEqual(c.hour, set([4, 8, 12]))
 
     def test_crontab_spec_invalid_hour(self):
-        self.assertRaises(ValueError, crontab, hour=24)
-        self.assertRaises(ValueError, crontab, hour='0-30')
+        with self.assertRaises(ValueError):
+            crontab(hour=24)
+        with self.assertRaises(ValueError):
+            crontab(hour='0-30')
 
     def test_crontab_spec_dow_formats(self):
         c = crontab(day_of_week=5)
@@ -760,10 +777,14 @@ class test_crontab_is_due(unittest.TestCase):
                 break
 
     def test_crontab_spec_invalid_dow(self):
-        self.assertRaises(ValueError, crontab, day_of_week='fooday-barday')
-        self.assertRaises(ValueError, crontab, day_of_week='1,4,foo')
-        self.assertRaises(ValueError, crontab, day_of_week='7')
-        self.assertRaises(ValueError, crontab, day_of_week='12')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='fooday-barday')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='1,4,foo')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='7')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='12')
 
     def test_every_minute_execution_is_due(self):
         last_ran = self.now - timedelta(seconds=61)

+ 2 - 0
celery/tests/test_task/test_chord.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from mock import patch
 
 from celery import current_app

+ 3 - 2
celery/tests/test_task/test_context.py

@@ -1,9 +1,10 @@
 # -*- coding: utf-8 -*-"
-import threading
+from __future__ import absolute_import
 
-from celery.tests.utils import unittest
+import threading
 
 from celery.task.base import Context
+from celery.tests.utils import unittest
 
 
 # Retreive the values of all context attributes as a

+ 6 - 3
celery/tests/test_task/test_execute_trace.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import operator
 
 from celery import states
@@ -20,7 +22,8 @@ class test_TraceInfo(unittest.TestCase):
         self.assertEqual(info.retval, 4)
 
     def test_trace_SystemExit(self):
-        self.assertRaises(SystemExit, trace, raises, (SystemExit(), ), {})
+        with self.assertRaises(SystemExit):
+            trace(raises, (SystemExit(), ), {})
 
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError("foo", "bar")
@@ -35,5 +38,5 @@ class test_TraceInfo(unittest.TestCase):
         self.assertIs(info.retval, exc)
 
     def test_trace_exception_propagate(self):
-        self.assertRaises(KeyError, trace, raises, (KeyError("foo"), ), {},
-                          propagate=True)
+        with self.assertRaises(KeyError):
+            trace(raises, (KeyError("foo"), ), {}, propagate=True)

+ 7 - 3
celery/tests/test_task/test_registry.py

@@ -1,7 +1,9 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
+from __future__ import with_statement
 
 from celery import registry
 from celery.task import Task, PeriodicTask
+from celery.tests.utils import unittest
 
 
 class TestTask(Task):
@@ -22,12 +24,14 @@ class TestPeriodicTask(PeriodicTask):
 class TestTaskRegistry(unittest.TestCase):
 
     def assertRegisterUnregisterCls(self, r, task):
-        self.assertRaises(r.NotRegistered, r.unregister, task)
+        with self.assertRaises(r.NotRegistered):
+            r.unregister(task)
         r.register(task)
         self.assertIn(task.name, r)
 
     def assertRegisterUnregisterFunc(self, r, task, task_name):
-        self.assertRaises(r.NotRegistered, r.unregister, task_name)
+        with self.assertRaises(r.NotRegistered):
+            r.unregister(task_name)
         r.register(task, task_name)
         self.assertIn(task_name, r)
 

+ 26 - 15
celery/tests/test_task/test_result.py

@@ -121,21 +121,25 @@ class TestAsyncResult(AppCase):
 
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok2_res.get(), "quick")
-        self.assertRaises(KeyError, nok_res.get)
+        with self.assertRaises(KeyError):
+            nok_res.get()
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertEqual(ok_res.info, "the")
 
     def test_get_timeout(self):
         res = AsyncResult(self.task4["id"])             # has RETRY status
-        self.assertRaises(TimeoutError, res.get, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            res.get(timeout=0.1)
 
         pending_res = AsyncResult(uuid())
-        self.assertRaises(TimeoutError, pending_res.get, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            pending_res.get(timeout=0.1)
 
     @skip_if_quick
     def test_get_timeout_longer(self):
         res = AsyncResult(self.task4["id"])             # has RETRY status
-        self.assertRaises(TimeoutError, res.get, timeout=1)
+        with self.assertRaises(TimeoutError):
+            res.get(timeout=1)
 
     def test_ready(self):
         oks = (AsyncResult(self.task1["id"]),
@@ -224,7 +228,8 @@ class TestTaskSetResult(AppCase):
         ar = MockAsyncResultFailure(uuid())
         ts = TaskSetResult(uuid(), [ar])
         it = iter(ts)
-        self.assertRaises(KeyError, it.next)
+        with self.assertRaises(KeyError):
+            it.next()
 
     def test_forget(self):
         subs = [MockAsyncResultSuccess(uuid()),
@@ -245,14 +250,14 @@ class TestTaskSetResult(AppCase):
                 MockAsyncResultSuccess(uuid())]
         ts = TaskSetResult(uuid(), subs)
         ts.save()
-        self.assertRaises(AttributeError, ts.save, backend=object())
+        with self.assertRaises(AttributeError):
+            ts.save(backend=object())
         self.assertEqual(TaskSetResult.restore(ts.taskset_id).subtasks,
                          ts.subtasks)
         ts.delete()
         self.assertIsNone(TaskSetResult.restore(ts.taskset_id))
-        self.assertRaises(AttributeError,
-                          TaskSetResult.restore, ts.taskset_id,
-                          backend=object())
+        with self.assertRaises(AttributeError):
+            TaskSetResult.restore(ts.taskset_id, backend=object())
 
     def test_join_native(self):
         backend = SimpleBackend()
@@ -292,7 +297,8 @@ class TestTaskSetResult(AppCase):
         ar2 = MockAsyncResultSuccess(uuid())
         ar3 = AsyncResult(uuid())
         ts = TaskSetResult(uuid(), [ar, ar2, ar3])
-        self.assertRaises(TimeoutError, ts.join, timeout=0.0000001)
+        with self.assertRaises(TimeoutError):
+            ts.join(timeout=0.0000001)
 
     def test_itersubtasks(self):
 
@@ -367,10 +373,12 @@ class TestFailedTaskSetResult(TestTaskSetResult):
         def consume():
             return list(it)
 
-        self.assertRaises(KeyError, consume)
+        with self.assertRaises(KeyError):
+            consume()
 
     def test_join(self):
-        self.assertRaises(KeyError, self.ts.join)
+        with self.assertRaises(KeyError):
+            self.ts.join()
 
     def test_successful(self):
         self.assertFalse(self.ts.successful())
@@ -396,11 +404,13 @@ class TestTaskSetPending(AppCase):
         self.assertTrue(self.ts.waiting())
 
     def x_join(self):
-        self.assertRaises(TimeoutError, self.ts.join, timeout=0.001)
+        with self.assertRaises(TimeoutError):
+            self.ts.join(timeout=0.001)
 
     @skip_if_quick
     def x_join_longer(self):
-        self.assertRaises(TimeoutError, self.ts.join, timeout=1)
+        with self.assertRaises(TimeoutError):
+            self.ts.join(timeout=1)
 
 
 class RaisingTask(Task):
@@ -413,7 +423,8 @@ class TestEagerResult(AppCase):
 
     def test_wait_raises(self):
         res = RaisingTask.apply(args=[3, 3])
-        self.assertRaises(KeyError, res.wait)
+        with self.assertRaises(KeyError):
+            res.wait()
 
     def test_wait(self):
         res = EagerResult("x", "x", states.RETRY)

+ 2 - 1
celery/tests/test_task/test_states.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery.states import state
 from celery import states
+from celery.tests.utils import unittest
 
 
 class test_state_precedence(unittest.TestCase):

+ 2 - 1
celery/tests/test_task/test_task_abortable.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
+from celery.tests.utils import unittest
 
 
 class MyAbortableTask(AbortableTask):

+ 1 - 0
celery/tests/test_task/test_task_builtins.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import warnings

+ 6 - 2
celery/tests/test_task/test_task_control.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 from functools import wraps
 
 from kombu.pidbox import Mailbox
@@ -135,8 +138,9 @@ class test_Broadcast(unittest.TestCase):
 
     @with_mock_broadcast
     def test_broadcast_validate(self):
-        self.assertRaises(ValueError, self.control.broadcast, "foobarbaz2",
-                          destination="foo")
+        with self.assertRaises(ValueError):
+            self.control.broadcast("foobarbaz2",
+                                   destination="foo")
 
     @with_mock_broadcast
     def test_rate_limit(self):

+ 9 - 4
celery/tests/test_task/test_task_http.py

@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import logging
@@ -112,7 +113,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(fail_response("Invalid moon alignment")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.RemoteExecuteError, d.dispatch)
+            with self.assertRaises(http.RemoteExecuteError):
+                d.dispatch()
 
     def test_dispatch_empty_response(self):
         logger = logging.getLogger("celery.unittest")
@@ -120,7 +122,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(_response("")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.InvalidResponseError, d.dispatch)
+            with self.assertRaises(http.InvalidResponseError):
+                d.dispatch()
 
     def test_dispatch_non_json(self):
         logger = logging.getLogger("celery.unittest")
@@ -128,7 +131,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(_response("{'#{:'''")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.InvalidResponseError, d.dispatch)
+            with self.assertRaises(http.InvalidResponseError):
+                d.dispatch()
 
     def test_dispatch_unknown_status(self):
         logger = logging.getLogger("celery.unittest")
@@ -136,7 +140,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(unknown_response()):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.UnknownStatusError, d.dispatch)
+            with self.assertRaises(http.UnknownStatusError):
+                d.dispatch()
 
     def test_dispatch_POST(self):
         logger = logging.getLogger("celery.unittest")

+ 1 - 0
celery/tests/test_task/test_task_sets.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import anyjson

+ 3 - 1
celery/tests/test_utils/__init__.py

@@ -1,8 +1,10 @@
+from __future__ import absolute_import
+
 import pickle
-from celery.tests.utils import unittest
 
 from celery import utils
 from celery.utils import promise, mpromise, maybe_promise
+from celery.tests.utils import unittest
 
 
 def double(x):

+ 12 - 5
celery/tests/test_utils/test_datastructures.py

@@ -1,11 +1,14 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import sys
-from celery.tests.utils import unittest
 from Queue import Queue
 
 from celery.datastructures import ExceptionInfo, LRUCache
 from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
+from celery.tests.utils import unittest
 
 
 class Object(object):
@@ -21,7 +24,8 @@ class test_DictAttribute(unittest.TestCase):
         self.assertEqual(x["foo"], x.obj.foo)
         self.assertEqual(x.get("foo"), "The quick brown fox")
         self.assertIsNone(x.get("bar"))
-        self.assertRaises(KeyError, x.__getitem__, "bar")
+        with self.assertRaises(KeyError):
+            x["bar"]
 
     def test_setdefault(self):
         x = DictAttribute(Object())
@@ -96,11 +100,13 @@ class test_utilities(unittest.TestCase):
     def test_consume_queue(self):
         x = Queue()
         it = consume_queue(x)
-        self.assertRaises(StopIteration, it.next)
+        with self.assertRaises(StopIteration):
+            it.next()
         x.put("foo")
         it = consume_queue(x)
         self.assertEqual(it.next(), "foo")
-        self.assertRaises(StopIteration, it.next)
+        with self.assertRaises(StopIteration):
+            it.next()
 
 
 class test_LimitedSet(unittest.TestCase):
@@ -166,6 +172,7 @@ class test_AttributeDict(unittest.TestCase):
     def test_getattr__setattr(self):
         x = AttributeDict({"foo": "bar"})
         self.assertEqual(x["foo"], "bar")
-        self.assertRaises(AttributeError, getattr, x, "bar")
+        with self.assertRaises(AttributeError):
+            x.bar
         x.bar = "foo"
         self.assertEqual(x["bar"], "foo")

+ 2 - 1
celery/tests/test_utils/test_pickle.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery.utils.serialization import pickle
+from celery.tests.utils import unittest
 
 
 class RegularException(Exception):

+ 1 - 0
celery/tests/test_utils/test_serialization.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys

+ 1 - 0
celery/tests/test_utils/test_timer2.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import sys

+ 2 - 0
celery/tests/test_utils/test_utils_encoding.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 
 from nose import SkipTest

+ 2 - 1
celery/tests/test_utils/test_utils_info.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery import Celery
 from celery.utils import textindent
+from celery.tests.utils import unittest
 
 RANDTEXT = """\
 The quick brown

+ 2 - 1
celery/tests/test_utils/test_utils_timeutils.py

@@ -1,7 +1,8 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 
 from celery.utils import timeutils
-
 from celery.tests.utils import unittest
 
 

+ 25 - 13
celery/tests/test_worker/__init__.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import socket
@@ -112,7 +113,8 @@ class test_QoS(unittest.TestCase):
         self.assertEqual(qos.increment(-30), 14)
         self.assertEqual(qos.decrement(7), 7)
         self.assertEqual(qos.decrement(), 6)
-        self.assertRaises(AssertionError, qos.decrement, 10)
+        with self.assertRaises(AssertionError):
+            qos.decrement(10)
 
     def test_qos_disabled_increment_decrement(self):
         qos = self._QoS(0)
@@ -358,7 +360,8 @@ class test_Consumer(unittest.TestCase):
         l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False, pool=BasePool())
         l.connection_errors = (KeyError, )
-        self.assertRaises(SyntaxError, l.start)
+        with self.assertRaises(SyntaxError):
+            l.start()
         l.heart.stop()
         l.priority_timer.stop()
 
@@ -433,8 +436,8 @@ class test_Consumer(unittest.TestCase):
         l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
         l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
         l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
-        self.assertRaises(IndexError, l.maybe_conn_error,
-                Mock(side_effect=IndexError("foo")))
+        with self.assertRaises(IndexError):
+            l.maybe_conn_error(Mock(side_effect=IndexError("foo")))
 
     def test_apply_eta_task(self):
         from celery.worker import state
@@ -514,7 +517,8 @@ class test_Consumer(unittest.TestCase):
 
         l.event_dispatcher = Mock()
         self.assertFalse(l.receive_message(m.decode(), m))
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
         self.assertTrue(self.eta_schedule.empty())
 
     def test_receieve_message_ack_raises(self):
@@ -532,7 +536,8 @@ class test_Consumer(unittest.TestCase):
             self.assertFalse(l.receive_message(m.decode(), m))
             self.assertTrue(log)
             self.assertIn("unknown message", log[0].message.args[0])
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
         self.assertTrue(self.eta_schedule.empty())
         m.ack.assert_called_with()
         self.assertTrue(l.logger.critical.call_count)
@@ -566,7 +571,8 @@ class test_Consumer(unittest.TestCase):
         self.assertIsInstance(task, TaskRequest)
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
 
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
@@ -654,7 +660,8 @@ class test_Consumer(unittest.TestCase):
                 raise KeyError("foo")
 
         l.consume_messages = raises_KeyError
-        self.assertRaises(KeyError, l.start)
+        with self.assertRaises(KeyError):
+            l.start()
         self.assertTrue(init_callback.call_count)
         self.assertEqual(l.iterations, 1)
         self.assertEqual(l.qos.prev, l.qos.value)
@@ -667,7 +674,8 @@ class test_Consumer(unittest.TestCase):
         l.broadcast_consumer = Mock()
         l.connection = BrokerConnection()
         l.consume_messages = Mock(side_effect=socket.error("foo"))
-        self.assertRaises(socket.error, l.start)
+        with self.assertRaises(socket.error):
+            l.start()
         self.assertTrue(init_callback.call_count)
         self.assertTrue(l.consume_messages.call_count)
 
@@ -799,7 +807,8 @@ class test_WorkController(AppCase):
         task = TaskRequest.from_message(m, m.decode())
         worker.components = []
         worker._state = worker.RUN
-        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
+        with self.assertRaises(KeyboardInterrupt):
+            worker.process_task(task)
         self.assertEqual(worker._state, worker.TERMINATE)
 
     def test_process_task_raise_SystemTerminate(self):
@@ -812,7 +821,8 @@ class test_WorkController(AppCase):
         task = TaskRequest.from_message(m, m.decode())
         worker.components = []
         worker._state = worker.RUN
-        self.assertRaises(SystemExit, worker.process_task, task)
+        with self.assertRaises(SystemExit):
+            worker.process_task(task)
         self.assertEqual(worker._state, worker.TERMINATE)
 
     def test_process_task_raise_regular(self):
@@ -831,7 +841,8 @@ class test_WorkController(AppCase):
         stc = Mock()
         stc.start.side_effect = SystemTerminate()
         worker1.components = [stc]
-        self.assertRaises(SystemExit, worker1.start)
+        with self.assertRaises(SystemExit):
+            worker1.start()
         self.assertTrue(stc.terminate.call_count)
 
         worker2 = self.create_worker()
@@ -839,7 +850,8 @@ class test_WorkController(AppCase):
         sec.start.side_effect = SystemExit()
         sec.terminate = None
         worker2.components = [sec]
-        self.assertRaises(SystemExit, worker2.start)
+        with self.assertRaises(SystemExit):
+            worker2.start()
         self.assertTrue(sec.stop.call_count)
 
     def test_state_db(self):

+ 2 - 1
celery/tests/test_worker/test_worker_autoscale.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import logging
 
 from time import time
@@ -7,7 +9,6 @@ from mock import Mock, patch
 from celery.concurrency.base import BasePool
 from celery.worker import state
 from celery.worker import autoscale
-
 from celery.tests.utils import unittest, sleepdeprived
 
 logger = logging.getLogger("celery.tests.autoscale")

+ 9 - 5
celery/tests/test_worker/test_worker_control.py

@@ -1,25 +1,27 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import socket
-from celery.tests.utils import unittest
 
 from datetime import datetime, timedelta
 
 from kombu import pidbox
 from mock import Mock
 
-from celery.utils.timer2 import Timer
-
 from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.registry import tasks
 from celery.task import PingTask
 from celery.utils import uuid
+from celery.utils.timer2 import Timer
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker import state
 from celery.worker.state import revoked
 from celery.worker.control import builtins
 from celery.worker.control.registry import Panel
+from celery.tests.utils import unittest
 
 hostname = socket.gethostname()
 
@@ -295,7 +297,8 @@ class test_ControlPanel(unittest.TestCase):
                                 "rate_limit": "1000/s"})
 
     def test_unexposed_command(self):
-        self.assertRaises(KeyError, self.panel.handle, "foo", arguments={})
+        with self.assertRaises(KeyError):
+            self.panel.handle("foo", arguments={})
 
     def test_revoke_with_name(self):
         tid = uuid()
@@ -353,7 +356,8 @@ class test_ControlPanel(unittest.TestCase):
     def test_shutdown(self):
         m = {"method": "shutdown",
              "destination": hostname}
-        self.assertRaises(SystemExit, self.panel.dispatch_from_message, m)
+        with self.assertRaises(SystemExit):
+            self.panel.dispatch_from_message(m)
 
     def test_panel_reply(self):
 

+ 2 - 1
celery/tests/test_worker/test_worker_heartbeat.py

@@ -1,5 +1,6 @@
-from celery.worker.heartbeat import Heart
+from __future__ import absolute_import
 
+from celery.worker.heartbeat import Heart
 from celery.tests.utils import unittest, sleepdeprived
 
 

+ 7 - 6
celery/tests/test_worker/test_worker_job.py

@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import
 from __future__ import with_statement
 
 import anyjson
@@ -132,8 +133,8 @@ class test_WorkerTaskTrace(unittest.TestCase):
         mytask.backend = Mock()
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         try:
-            self.assertRaises(SystemExit,
-                    jail, uuid(), mytask.name, [2], {})
+            with self.assertRaises(SystemExit):
+                jail(uuid(), mytask.name, [2], {})
         finally:
             mytask.backend = backend
 
@@ -416,8 +417,8 @@ class test_TaskRequest(unittest.TestCase):
 
     def test_from_message_invalid_kwargs(self):
         body = dict(task="foo", id=1, args=(), kwargs="foo")
-        self.assertRaises(InvalidTaskError,
-                          TaskRequest.from_message, None, body)
+        with self.assertRaises(InvalidTaskError):
+            TaskRequest.from_message(None, body)
 
     def test_on_timeout(self):
 
@@ -542,8 +543,8 @@ class test_TaskRequest(unittest.TestCase):
         m = Message(None, body=anyjson.serialize(body), backend="foo",
                           content_type="application/json",
                           content_encoding="utf-8")
-        self.assertRaises(NotRegistered, TaskRequest.from_message,
-                          m, m.decode())
+        with self.assertRaises(NotRegistered):
+            TaskRequest.from_message(m, m.decode())
 
     def test_execute(self):
         tid = uuid()

+ 2 - 1
celery/tests/test_worker/test_worker_mediator.py

@@ -1,4 +1,4 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from Queue import Queue
 
@@ -7,6 +7,7 @@ from mock import Mock, patch
 from celery.utils import uuid
 from celery.worker.mediator import Mediator
 from celery.worker.state import revoked as revoked_tasks
+from celery.tests.utils import unittest
 
 
 class MockTask(object):

+ 2 - 1
celery/tests/test_worker/test_worker_revoke.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery.worker import state
+from celery.tests.utils import unittest
 
 
 class TestRevokeRegistry(unittest.TestCase):

+ 2 - 1
celery/tests/test_worker/test_worker_state.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 from celery.datastructures import LimitedSet
 from celery.worker import state
+from celery.tests.utils import unittest
 
 
 class StateResetCase(unittest.TestCase):