Browse Source

Use with self.assertRaises

Ask Solem 13 years ago
parent
commit
faffab30d3
61 changed files with 386 additions and 221 deletions
  1. 2 0
      celery/tests/compat.py
  2. 3 1
      celery/tests/test_app/__init__.py
  3. 2 2
      celery/tests/test_app/test_app_amqp.py
  4. 3 2
      celery/tests/test_app/test_beat.py
  5. 1 0
      celery/tests/test_app/test_celery.py
  6. 9 9
      celery/tests/test_app/test_loaders.py
  7. 2 1
      celery/tests/test_app/test_log.py
  8. 5 1
      celery/tests/test_app/test_routes.py
  9. 2 2
      celery/tests/test_backends/__init__.py
  10. 33 26
      celery/tests/test_backends/test_amqp.py
  11. 35 34
      celery/tests/test_backends/test_base.py
  12. 3 2
      celery/tests/test_backends/test_cache.py
  13. 5 2
      celery/tests/test_backends/test_database.py
  14. 2 0
      celery/tests/test_backends/test_pyredis_compat.py
  15. 3 3
      celery/tests/test_backends/test_redis.py
  16. 2 0
      celery/tests/test_backends/test_redis_unit.py
  17. 3 1
      celery/tests/test_backends/test_tyrant.py
  18. 4 2
      celery/tests/test_bin/__init__.py
  19. 6 3
      celery/tests/test_bin/test_celerybeat.py
  20. 21 17
      celery/tests/test_bin/test_celeryd.py
  21. 2 0
      celery/tests/test_bin/test_celeryev.py
  22. 1 0
      celery/tests/test_compat/test_decorators.py
  23. 2 0
      celery/tests/test_compat/test_messaging.py
  24. 2 0
      celery/tests/test_concurrency/__init__.py
  25. 2 0
      celery/tests/test_concurrency/test_concurrency_eventlet.py
  26. 4 1
      celery/tests/test_concurrency/test_concurrency_processes.py
  27. 2 0
      celery/tests/test_concurrency/test_concurrency_solo.py
  28. 2 0
      celery/tests/test_concurrency/test_pool.py
  29. 8 4
      celery/tests/test_events/__init__.py
  30. 2 0
      celery/tests/test_events/test_events_cursesmon.py
  31. 4 1
      celery/tests/test_events/test_events_snapshot.py
  32. 3 1
      celery/tests/test_events/test_events_state.py
  33. 11 4
      celery/tests/test_slow/test_buckets.py
  34. 54 33
      celery/tests/test_task/__init__.py
  35. 2 0
      celery/tests/test_task/test_chord.py
  36. 3 2
      celery/tests/test_task/test_context.py
  37. 6 3
      celery/tests/test_task/test_execute_trace.py
  38. 7 3
      celery/tests/test_task/test_registry.py
  39. 26 15
      celery/tests/test_task/test_result.py
  40. 2 1
      celery/tests/test_task/test_states.py
  41. 2 1
      celery/tests/test_task/test_task_abortable.py
  42. 1 0
      celery/tests/test_task/test_task_builtins.py
  43. 6 2
      celery/tests/test_task/test_task_control.py
  44. 9 4
      celery/tests/test_task/test_task_http.py
  45. 1 0
      celery/tests/test_task/test_task_sets.py
  46. 3 1
      celery/tests/test_utils/__init__.py
  47. 12 5
      celery/tests/test_utils/test_datastructures.py
  48. 2 1
      celery/tests/test_utils/test_pickle.py
  49. 1 0
      celery/tests/test_utils/test_serialization.py
  50. 1 0
      celery/tests/test_utils/test_timer2.py
  51. 2 0
      celery/tests/test_utils/test_utils_encoding.py
  52. 2 1
      celery/tests/test_utils/test_utils_info.py
  53. 2 1
      celery/tests/test_utils/test_utils_timeutils.py
  54. 25 13
      celery/tests/test_worker/__init__.py
  55. 2 1
      celery/tests/test_worker/test_worker_autoscale.py
  56. 9 5
      celery/tests/test_worker/test_worker_control.py
  57. 2 1
      celery/tests/test_worker/test_worker_heartbeat.py
  58. 7 6
      celery/tests/test_worker/test_worker_job.py
  59. 2 1
      celery/tests/test_worker/test_worker_mediator.py
  60. 2 1
      celery/tests/test_worker/test_worker_revoke.py
  61. 2 1
      celery/tests/test_worker/test_worker_state.py

+ 2 - 0
celery/tests/compat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import sys
 
 
 
 

+ 3 - 1
celery/tests/test_app/__init__.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import os
 import os
@@ -264,7 +265,8 @@ class test_defaults(unittest.TestCase):
             self.assertFalse(defaults.str_to_bool(s))
             self.assertFalse(defaults.str_to_bool(s))
         for s in ("true", "yes", "1"):
         for s in ("true", "yes", "1"):
             self.assertTrue(defaults.str_to_bool(s))
             self.assertTrue(defaults.str_to_bool(s))
-        self.assertRaises(TypeError, defaults.str_to_bool, "unsure")
+        with self.assertRaises(TypeError):
+            defaults.str_to_bool("unsure")
 
 
 
 
 class test_debugging_utils(unittest.TestCase):
 class test_debugging_utils(unittest.TestCase):

+ 2 - 2
celery/tests/test_app/test_app_amqp.py

@@ -1,10 +1,10 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 from mock import Mock
 from mock import Mock
 
 
-from celery.tests.utils import AppCase
-
 from celery.app.amqp import MSG_OPTIONS, extract_msg_options
 from celery.app.amqp import MSG_OPTIONS, extract_msg_options
+from celery.tests.utils import AppCase
 
 
 
 
 class TestMsgOptions(AppCase):
 class TestMsgOptions(AppCase):

+ 3 - 2
celery/tests/test_app/test_beat.py

@@ -1,8 +1,8 @@
+from __future__ import absolute_import
+
 import logging
 import logging
-from celery.tests.utils import unittest
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery import beat
 from celery import beat
@@ -11,6 +11,7 @@ from celery.result import AsyncResult
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.task.base import Task
 from celery.task.base import Task
 from celery.utils import uuid
 from celery.utils import uuid
+from celery.tests.utils import unittest
 
 
 
 
 class Object(object):
 class Object(object):

+ 1 - 0
celery/tests/test_app/test_celery.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
 
 
 import celery
 import celery

+ 9 - 9
celery/tests/test_app/test_loaders.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import os
 import os
@@ -137,8 +138,8 @@ class TestLoaderBase(unittest.TestCase):
             self.assertIsInstance(warning, MockMail.SendmailWarning)
             self.assertIsInstance(warning, MockMail.SendmailWarning)
             self.assertIn("KeyError", warning.args[0])
             self.assertIn("KeyError", warning.args[0])
 
 
-            self.assertRaises(KeyError, self.loader.mail_admins,
-                              fail_silently=False, **opts)
+            with self.assertRaises(KeyError):
+                self.loader.mail_admins(fail_silently=False, **opts)
 
 
     def test_mail_admins(self):
     def test_mail_admins(self):
         MockMail.Mailer.raise_on_send = False
         MockMail.Mailer.raise_on_send = False
@@ -154,8 +155,8 @@ class TestLoaderBase(unittest.TestCase):
         self.assertIs(loader.mail, mail)
         self.assertIs(loader.mail, mail)
 
 
     def test_cmdline_config_ValueError(self):
     def test_cmdline_config_ValueError(self):
-        self.assertRaises(ValueError, self.loader.cmdline_config_parser,
-                         ["broker.port=foobar"])
+        with self.assertRaises(ValueError):
+            self.loader.cmdline_config_parser(["broker.port=foobar"])
 
 
 
 
 class TestDefaultLoader(unittest.TestCase):
 class TestDefaultLoader(unittest.TestCase):
@@ -231,17 +232,16 @@ class test_AppLoader(unittest.TestCase):
     def test_config_from_envvar(self, key="CELERY_HARNESS_CFG1"):
     def test_config_from_envvar(self, key="CELERY_HARNESS_CFG1"):
         self.assertFalse(self.loader.config_from_envvar("HDSAJIHWIQHEWQU",
         self.assertFalse(self.loader.config_from_envvar("HDSAJIHWIQHEWQU",
                                                         silent=True))
                                                         silent=True))
-        self.assertRaises(ImproperlyConfigured,
-                          self.loader.config_from_envvar, "HDSAJIHWIQHEWQU",
-                          silent=False)
+        with self.assertRaises(ImproperlyConfigured):
+            self.loader.config_from_envvar("HDSAJIHWIQHEWQU", silent=False)
         os.environ[key] = __name__ + ".object_config"
         os.environ[key] = __name__ + ".object_config"
         self.assertTrue(self.loader.config_from_envvar(key))
         self.assertTrue(self.loader.config_from_envvar(key))
         self.assertEqual(self.loader.conf["FOO"], 1)
         self.assertEqual(self.loader.conf["FOO"], 1)
         self.assertEqual(self.loader.conf["BAR"], 2)
         self.assertEqual(self.loader.conf["BAR"], 2)
 
 
         os.environ[key] = "unknown_asdwqe.asdwqewqe"
         os.environ[key] = "unknown_asdwqe.asdwqewqe"
-        self.assertRaises(ImportError,
-                          self.loader.config_from_envvar, key, silent=False)
+        with self.assertRaises(ImportError):
+            self.loader.config_from_envvar(key, silent=False)
         self.assertFalse(self.loader.config_from_envvar(key, silent=True))
         self.assertFalse(self.loader.config_from_envvar(key, silent=True))
 
 
         os.environ[key] = __name__ + ".dict_config"
         os.environ[key] = __name__ + ".dict_config"

+ 2 - 1
celery/tests/test_app/test_log.py

@@ -1,8 +1,8 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys
 import logging
 import logging
-from celery.tests.utils import unittest
 from tempfile import mktemp
 from tempfile import mktemp
 
 
 from celery import log
 from celery import log
@@ -12,6 +12,7 @@ from celery.log import (setup_logger, setup_task_logger,
                         setup_logging_subsystem)
                         setup_logging_subsystem)
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.utils.compat import _CompatLoggerAdapter
 from celery.utils.compat import _CompatLoggerAdapter
+from celery.tests.utils import unittest
 from celery.tests.utils import (override_stdouts, wrap_logger,
 from celery.tests.utils import (override_stdouts, wrap_logger,
                                 get_handlers, set_handlers)
                                 get_handlers, set_handlers)
 
 

+ 5 - 1
celery/tests/test_app/test_routes.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 from functools import wraps
 from functools import wraps
 
 
 from celery import routes
 from celery import routes
@@ -65,7 +68,8 @@ class test_MapRoute(unittest.TestCase):
     def test_expand_route_not_found(self):
     def test_expand_route_not_found(self):
         expand = E(current_app.conf.CELERY_QUEUES)
         expand = E(current_app.conf.CELERY_QUEUES)
         route = routes.MapRoute({"a": {"queue": "x"}})
         route = routes.MapRoute({"a": {"queue": "x"}})
-        self.assertRaises(QueueNotFound, expand, route.route_for_task("a"))
+        with self.assertRaises(QueueNotFound):
+            expand(route.route_for_task("a"))
 
 
 
 
 class test_lookup_route(unittest.TestCase):
 class test_lookup_route(unittest.TestCase):

+ 2 - 2
celery/tests/test_backends/__init__.py

@@ -1,10 +1,10 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
-from celery.tests.utils import unittest
-
 from celery import backends
 from celery import backends
 from celery.backends.amqp import AMQPBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.cache import CacheBackend
 from celery.backends.cache import CacheBackend
+from celery.tests.utils import unittest
 
 
 
 
 class TestBackends(unittest.TestCase):
 class TestBackends(unittest.TestCase):

+ 33 - 26
celery/tests/test_backends/test_amqp.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import socket
 import socket
@@ -111,11 +112,11 @@ class test_AMQPBackend(unittest.TestCase):
             Producer = _Producer
             Producer = _Producer
 
 
         backend = Backend()
         backend = Backend()
-        self.assertRaises(KeyError, backend.store_result,
-                          "foo", "bar", "STARTED", max_retries=None)
+        with self.assertRaises(KeyError):
+            backend.store_result("foo", "bar", "STARTED", max_retries=None)
 
 
-        self.assertRaises(KeyError, backend.store_result,
-                          "foo", "bar", "STARTED", max_retries=10)
+        with self.assertRaises(KeyError):
+            backend.store_result("foo", "bar", "STARTED", max_retries=10)
 
 
     def assertState(self, retval, state):
     def assertState(self, retval, state):
         self.assertEqual(retval["status"], state)
         self.assertEqual(retval["status"], state)
@@ -182,11 +183,14 @@ class test_AMQPBackend(unittest.TestCase):
         b = self.create_backend()
         b = self.create_backend()
 
 
         tid = uuid()
         tid = uuid()
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, None, states.STARTED)
         b.store_result(tid, None, states.STARTED)
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, None, states.RETRY)
         b.store_result(tid, None, states.RETRY)
-        self.assertRaises(TimeoutError, b.wait_for, tid, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            b.wait_for(tid, timeout=0.1)
         b.store_result(tid, 42, states.SUCCESS)
         b.store_result(tid, 42, states.SUCCESS)
         self.assertEqual(b.wait_for(tid, timeout=1), 42)
         self.assertEqual(b.wait_for(tid, timeout=1), 42)
         b.store_result(tid, 56, states.SUCCESS)
         b.store_result(tid, 56, states.SUCCESS)
@@ -194,7 +198,8 @@ class test_AMQPBackend(unittest.TestCase):
                          "result is cached")
                          "result is cached")
         self.assertEqual(b.wait_for(tid, timeout=1, cache=False), 56)
         self.assertEqual(b.wait_for(tid, timeout=1, cache=False), 56)
         b.store_result(tid, KeyError("foo"), states.FAILURE)
         b.store_result(tid, KeyError("foo"), states.FAILURE)
-        self.assertRaises(KeyError, b.wait_for, tid, timeout=1, cache=False)
+        with self.assertRaises(KeyError):
+            b.wait_for(tid, timeout=1, cache=False)
 
 
     def test_drain_events_remaining_timeouts(self):
     def test_drain_events_remaining_timeouts(self):
 
 
@@ -207,8 +212,8 @@ class test_AMQPBackend(unittest.TestCase):
         with current_app.pool.acquire_channel(block=False) as (_, channel):
         with current_app.pool.acquire_channel(block=False) as (_, channel):
             binding = b._create_binding(uuid())
             binding = b._create_binding(uuid())
             consumer = b._create_consumer(binding, channel)
             consumer = b._create_consumer(binding, channel)
-            self.assertRaises(socket.timeout, b.drain_events,
-                              Connection(), consumer, timeout=0.1)
+            with self.assertRaises(socket.timeout):
+                b.drain_events(Connection(), consumer, timeout=0.1)
 
 
     def test_get_many(self):
     def test_get_many(self):
         b = self.create_backend()
         b = self.create_backend()
@@ -230,8 +235,8 @@ class test_AMQPBackend(unittest.TestCase):
         cached_res = list(b.get_many(tids, timeout=1))
         cached_res = list(b.get_many(tids, timeout=1))
         self.assertEqual(sorted(cached_res), sorted(expected_results))
         self.assertEqual(sorted(cached_res), sorted(expected_results))
         b._cache[res[0][0]]["status"] = states.RETRY
         b._cache[res[0][0]]["status"] = states.RETRY
-        self.assertRaises(socket.timeout, list,
-                          b.get_many(tids, timeout=0.01))
+        with self.assertRaises(socket.timeout):
+            list(b.get_many(tids, timeout=0.01))
 
 
     def test_test_get_many_raises_outer_block(self):
     def test_test_get_many_raises_outer_block(self):
 
 
@@ -241,7 +246,8 @@ class test_AMQPBackend(unittest.TestCase):
                 raise KeyError("foo")
                 raise KeyError("foo")
 
 
         b = Backend()
         b = Backend()
-        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+        with self.assertRaises(KeyError):
+            b.get_many(["id1"]).next()
 
 
     def test_test_get_many_raises_inner_block(self):
     def test_test_get_many_raises_inner_block(self):
 
 
@@ -251,7 +257,8 @@ class test_AMQPBackend(unittest.TestCase):
                 raise KeyError("foo")
                 raise KeyError("foo")
 
 
         b = Backend()
         b = Backend()
-        self.assertRaises(KeyError, b.get_many(["id1"]).next)
+        with self.assertRaises(KeyError):
+            b.get_many(["id1"]).next()
 
 
     def test_no_expires(self):
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         b = self.create_backend(expires=None)
@@ -260,8 +267,8 @@ class test_AMQPBackend(unittest.TestCase):
         app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
         app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = None
         try:
         try:
             b = self.create_backend(expires=None)
             b = self.create_backend(expires=None)
-            self.assertRaises(KeyError, b.queue_arguments.__getitem__,
-                              "x-expires")
+            with self.assertRaises(KeyError):
+                b.queue_arguments["x-expires"]
         finally:
         finally:
             app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
             app.conf.CELERY_AMQP_TASK_RESULT_EXPIRES = prev
 
 
@@ -269,21 +276,21 @@ class test_AMQPBackend(unittest.TestCase):
         self.create_backend().process_cleanup()
         self.create_backend().process_cleanup()
 
 
     def test_reload_task_result(self):
     def test_reload_task_result(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().reload_task_result, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().reload_task_result("x")
 
 
     def test_reload_taskset_result(self):
     def test_reload_taskset_result(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().reload_taskset_result, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().reload_taskset_result("x")
 
 
     def test_save_taskset(self):
     def test_save_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().save_taskset, "x", "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().save_taskset("x", "x")
 
 
     def test_restore_taskset(self):
     def test_restore_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().restore_taskset, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().restore_taskset("x")
 
 
     def test_delete_taskset(self):
     def test_delete_taskset(self):
-        self.assertRaises(NotImplementedError,
-                          self.create_backend().delete_taskset, "x")
+        with self.assertRaises(NotImplementedError):
+            self.create_backend().delete_taskset("x")

+ 35 - 34
celery/tests/test_backends/test_base.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys
@@ -48,52 +49,52 @@ class test_serialization(unittest.TestCase):
 class test_BaseBackend_interface(unittest.TestCase):
 class test_BaseBackend_interface(unittest.TestCase):
 
 
     def test_get_status(self):
     def test_get_status(self):
-        self.assertRaises(NotImplementedError,
-                b.get_status, "SOMExx-N0Nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_status("SOMExx-N0Nex1stant-IDxx-")
 
 
     def test__forget(self):
     def test__forget(self):
-        self.assertRaises(NotImplementedError,
-                b.forget, "SOMExx-N0Nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.forget("SOMExx-N0Nex1stant-IDxx-")
 
 
     def test_store_result(self):
     def test_store_result(self):
-        self.assertRaises(NotImplementedError,
-                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
+        with self.assertRaises(NotImplementedError):
+            b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
 
 
     def test_mark_as_started(self):
     def test_mark_as_started(self):
-        self.assertRaises(NotImplementedError,
-                b.mark_as_started, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.mark_as_started("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_reload_task_result(self):
     def test_reload_task_result(self):
-        self.assertRaises(NotImplementedError,
-                b.reload_task_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.reload_task_result("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_reload_taskset_result(self):
     def test_reload_taskset_result(self):
-        self.assertRaises(NotImplementedError,
-                b.reload_taskset_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.reload_taskset_result("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_get_result(self):
     def test_get_result(self):
-        self.assertRaises(NotImplementedError,
-                b.get_result, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_result("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_restore_taskset(self):
     def test_restore_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.restore_taskset, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.restore_taskset("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_delete_taskset(self):
     def test_delete_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.delete_taskset, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.delete_taskset("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_save_taskset(self):
     def test_save_taskset(self):
-        self.assertRaises(NotImplementedError,
-                b.save_taskset, "SOMExx-N0nex1stant-IDxx-", "blergh")
+        with self.assertRaises(NotImplementedError):
+            b.save_taskset("SOMExx-N0nex1stant-IDxx-", "blergh")
 
 
     def test_get_traceback(self):
     def test_get_traceback(self):
-        self.assertRaises(NotImplementedError,
-                b.get_traceback, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.get_traceback("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_forget(self):
     def test_forget(self):
-        self.assertRaises(NotImplementedError,
-                b.forget, "SOMExx-N0nex1stant-IDxx-")
+        with self.assertRaises(NotImplementedError):
+            b.forget("SOMExx-N0nex1stant-IDxx-")
 
 
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
     def test_on_chord_apply(self, unlock="celery.chord_unlock"):
         from celery.registry import tasks
         from celery.registry import tasks
@@ -277,27 +278,27 @@ class test_KeyValueStoreBackend(unittest.TestCase):
 class test_KeyValueStoreBackend_interface(unittest.TestCase):
 class test_KeyValueStoreBackend_interface(unittest.TestCase):
 
 
     def test_get(self):
     def test_get(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().get,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().get("a")
 
 
     def test_set(self):
     def test_set(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().set,
-                "a", 1)
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().set("a", 1)
 
 
     def test_cleanup(self):
     def test_cleanup(self):
         self.assertFalse(KeyValueStoreBackend().cleanup())
         self.assertFalse(KeyValueStoreBackend().cleanup())
 
 
     def test_delete(self):
     def test_delete(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().delete,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().delete("a")
 
 
     def test_mget(self):
     def test_mget(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().mget,
-                ["a"])
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().mget(["a"])
 
 
     def test_forget(self):
     def test_forget(self):
-        self.assertRaises(NotImplementedError, KeyValueStoreBackend().forget,
-                "a")
+        with self.assertRaises(NotImplementedError):
+            KeyValueStoreBackend().forget("a")
 
 
 
 
 class test_DisabledBackend(unittest.TestCase):
 class test_DisabledBackend(unittest.TestCase):

+ 3 - 2
celery/tests/test_backends/test_cache.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys
@@ -82,8 +83,8 @@ class test_CacheBackend(unittest.TestCase):
         self.assertEqual(tb.expires, 10)
         self.assertEqual(tb.expires, 10)
 
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
     def test_unknown_backend_raises_ImproperlyConfigured(self):
-        self.assertRaises(ImproperlyConfigured,
-                          CacheBackend, backend="unknown://")
+        with self.assertRaises(ImproperlyConfigured):
+            CacheBackend(backend="unknown://")
 
 
 
 
 class MyClient(DummyClient):
 class MyClient(DummyClient):

+ 5 - 2
celery/tests/test_backends/test_database.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys
@@ -43,7 +44,8 @@ class test_DatabaseBackend(unittest.TestCase):
     def test_missing_SQLAlchemy_raises_ImproperlyConfigured(self):
     def test_missing_SQLAlchemy_raises_ImproperlyConfigured(self):
         with mask_modules("sqlalchemy"):
         with mask_modules("sqlalchemy"):
             from celery.backends.database import _sqlalchemy_installed
             from celery.backends.database import _sqlalchemy_installed
-            self.assertRaises(ImproperlyConfigured, _sqlalchemy_installed)
+            with self.assertRaises(ImproperlyConfigured):
+                _sqlalchemy_installed()
 
 
     def test_pickle_hack_for_sqla_05(self):
     def test_pickle_hack_for_sqla_05(self):
         import sqlalchemy as sa
         import sqlalchemy as sa
@@ -66,7 +68,8 @@ class test_DatabaseBackend(unittest.TestCase):
         conf = app_or_default().conf
         conf = app_or_default().conf
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
         try:
         try:
-            self.assertRaises(ImproperlyConfigured, DatabaseBackend)
+            with self.assertRaises(ImproperlyConfigured):
+                DatabaseBackend()
         finally:
         finally:
             conf.CELERY_RESULT_DBURI = prev
             conf.CELERY_RESULT_DBURI = prev
 
 

+ 2 - 0
celery/tests/test_backends/test_pyredis_compat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured

+ 3 - 3
celery/tests/test_backends/test_redis.py

@@ -3,7 +3,6 @@ from __future__ import with_statement
 
 
 import sys
 import sys
 import socket
 import socket
-from celery.tests.utils import unittest
 
 
 from nose import SkipTest
 from nose import SkipTest
 
 
@@ -13,8 +12,8 @@ from celery import states
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.backends import redis
 from celery.backends import redis
 from celery.backends.redis import RedisBackend
 from celery.backends.redis import RedisBackend
-
 from celery.tests.utils import mask_modules
 from celery.tests.utils import mask_modules
+from celery.tests.utils import unittest
 
 
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg_emitted = False
 _no_redis_msg_emitted = False
@@ -112,6 +111,7 @@ class TestRedisBackendNoRedis(unittest.TestCase):
         prev = redis.RedisBackend.redis
         prev = redis.RedisBackend.redis
         redis.RedisBackend.redis = None
         redis.RedisBackend.redis = None
         try:
         try:
-            self.assertRaises(ImproperlyConfigured, redis.RedisBackend)
+            with self.assertRaises(ImproperlyConfigured):
+                redis.RedisBackend()
         finally:
         finally:
             redis.RedisBackend.redis = prev
             redis.RedisBackend.redis = prev

+ 2 - 0
celery/tests/test_backends/test_redis_unit.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import timedelta
 from datetime import timedelta
 
 
 from mock import Mock, patch
 from mock import Mock, patch

+ 3 - 1
celery/tests/test_backends/test_tyrant.py

@@ -1,6 +1,7 @@
+from __future__ import absolute_import
+
 import sys
 import sys
 import socket
 import socket
-from celery.tests.utils import unittest
 
 
 from nose import SkipTest
 from nose import SkipTest
 
 
@@ -10,6 +11,7 @@ from celery import states
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.backends import tyrant
 from celery.backends import tyrant
 from celery.backends.tyrant import TyrantBackend
 from celery.backends.tyrant import TyrantBackend
+from celery.tests.utils import unittest
 
 
 _no_tyrant_msg = "* Tokyo Tyrant %s. Will not execute related tests."
 _no_tyrant_msg = "* Tokyo Tyrant %s. Will not execute related tests."
 _no_tyrant_msg_emitted = False
 _no_tyrant_msg_emitted = False

+ 4 - 2
celery/tests/test_bin/__init__.py

@@ -1,7 +1,8 @@
+from __future__ import absolute_import
+
 import os
 import os
 
 
 from celery.bin.base import Command
 from celery.bin.base import Command
-
 from celery.tests.utils import AppCase
 from celery.tests.utils import AppCase
 
 
 
 
@@ -35,7 +36,8 @@ class test_Command(AppCase):
         self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
         self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
 
 
     def test_run_interface(self):
     def test_run_interface(self):
-        self.assertRaises(NotImplementedError, Command().run)
+        with self.assertRaises(NotImplementedError):
+            Command().run()
 
 
     def test_execute_from_commandline(self):
     def test_execute_from_commandline(self):
         cmd = MockCommand()
         cmd = MockCommand()

+ 6 - 3
celery/tests/test_bin/test_celerybeat.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import logging
 import logging
 import sys
 import sys
 
 
@@ -98,8 +100,8 @@ class test_Beat(AppCase):
         clock = MockService()
         clock = MockService()
         MockService.in_sync = False
         MockService.in_sync = False
         handlers = self.psig(b.install_sync_handler, clock)
         handlers = self.psig(b.install_sync_handler, clock)
-        self.assertRaises(SystemExit, handlers["SIGINT"],
-                          "SIGINT", object())
+        with self.assertRaises(SystemExit):
+            handlers["SIGINT"]("SIGINT", object())
         self.assertTrue(MockService.in_sync)
         self.assertTrue(MockService.in_sync)
         MockService.in_sync = False
         MockService.in_sync = False
 
 
@@ -112,7 +114,8 @@ class test_Beat(AppCase):
         b = beatapp.Beat()
         b = beatapp.Beat()
         b.redirect_stdouts = False
         b.redirect_stdouts = False
         b.setup_logging()
         b.setup_logging()
-        self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
 
 
     @redirect_stdouts
     @redirect_stdouts
     def test_logs_errors(self, stdout, stderr):
     def test_logs_errors(self, stdout, stderr):

+ 21 - 17
celery/tests/test_bin/test_celeryd.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import logging
 import logging
@@ -209,8 +210,8 @@ class test_Worker(AppCase):
             self.assertNotIn("celery", app.amqp.queues.consume_from)
             self.assertNotIn("celery", app.amqp.queues.consume_from)
 
 
             c.CELERY_CREATE_MISSING_QUEUES = False
             c.CELERY_CREATE_MISSING_QUEUES = False
-            self.assertRaises(ImproperlyConfigured,
-                    self.Worker(queues=["image"]).init_queues)
+            with self.assertRaises(ImproperlyConfigured):
+                self.Worker(queues=["image"]).init_queues()
             c.CELERY_CREATE_MISSING_QUEUES = True
             c.CELERY_CREATE_MISSING_QUEUES = True
             worker = self.Worker(queues=["image"])
             worker = self.Worker(queues=["image"])
             worker.init_queues()
             worker.init_queues()
@@ -241,7 +242,8 @@ class test_Worker(AppCase):
 
 
     @disable_stdouts
     @disable_stdouts
     def test_unknown_loglevel(self):
     def test_unknown_loglevel(self):
-        self.assertRaises(SystemExit, self.Worker, loglevel="ALIEN")
+        with self.assertRaises(SystemExit):
+            self.Worker(loglevel="ALIEN")
         worker1 = self.Worker(loglevel=0xFFFF)
         worker1 = self.Worker(loglevel=0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
         self.assertEqual(worker1.loglevel, 0xFFFF)
 
 
@@ -299,7 +301,8 @@ class test_Worker(AppCase):
         worker = self.Worker()
         worker = self.Worker()
         worker.redirect_stdouts = False
         worker.redirect_stdouts = False
         worker.redirect_stdouts_to_logger()
         worker.redirect_stdouts_to_logger()
-        self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+        with self.assertRaises(AttributeError):
+            sys.stdout.logger
 
 
     def test_redirect_stdouts_already_handled(self):
     def test_redirect_stdouts_already_handled(self):
         logging_setup = [False]
         logging_setup = [False]
@@ -313,7 +316,8 @@ class test_Worker(AppCase):
             worker.app.log.__class__._setup = False
             worker.app.log.__class__._setup = False
             worker.redirect_stdouts_to_logger()
             worker.redirect_stdouts_to_logger()
             self.assertTrue(logging_setup[0])
             self.assertTrue(logging_setup[0])
-            self.assertRaises(AttributeError, getattr, sys.stdout, "logger")
+            with self.assertRaises(AttributeError):
+                sys.stdout.logger
         finally:
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
             signals.setup_logging.disconnect(on_logging_setup)
 
 
@@ -470,14 +474,14 @@ class test_signal_handlers(AppCase):
 
 
         p, platforms.signals = platforms.signals, Signals()
         p, platforms.signals = platforms.signals, Signals()
         try:
         try:
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                              "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertTrue(worker.stopped)
             self.assertTrue(worker.stopped)
         finally:
         finally:
             platforms.signals = p
             platforms.signals = p
 
 
-        self.assertRaises(SystemExit, next_handlers["SIGINT"],
-                          "SIGINT", object())
+        with self.assertRaises(SystemExit):
+            next_handlers["SIGINT"]("SIGINT", object())
         self.assertTrue(worker.terminated)
         self.assertTrue(worker.terminated)
 
 
     @disable_stdouts
     @disable_stdouts
@@ -489,8 +493,8 @@ class test_signal_handlers(AppCase):
         try:
         try:
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_int_handler, worker)
             handlers = self.psig(cd.install_worker_int_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                            "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertFalse(worker.stopped)
             self.assertFalse(worker.stopped)
         finally:
         finally:
             process.name = name
             process.name = name
@@ -510,8 +514,8 @@ class test_signal_handlers(AppCase):
         try:
         try:
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_int_again_handler, worker)
             handlers = self.psig(cd.install_worker_int_again_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGINT"],
-                            "SIGINT", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGINT"]("SIGINT", object())
             self.assertFalse(worker.terminated)
             self.assertFalse(worker.terminated)
         finally:
         finally:
             process.name = name
             process.name = name
@@ -520,8 +524,8 @@ class test_signal_handlers(AppCase):
     def test_worker_term_handler(self):
     def test_worker_term_handler(self):
         worker = self._Worker()
         worker = self._Worker()
         handlers = self.psig(cd.install_worker_term_handler, worker)
         handlers = self.psig(cd.install_worker_term_handler, worker)
-        self.assertRaises(SystemExit, handlers["SIGTERM"],
-                          "SIGTERM", object())
+        with self.assertRaises(SystemExit):
+            handlers["SIGTERM"]("SIGTERM", object())
         self.assertTrue(worker.stopped)
         self.assertTrue(worker.stopped)
 
 
     def test_worker_cry_handler(self):
     def test_worker_cry_handler(self):
@@ -552,8 +556,8 @@ class test_signal_handlers(AppCase):
         try:
         try:
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_handler, worker)
             handlers = self.psig(cd.install_worker_term_handler, worker)
-            self.assertRaises(SystemExit, handlers["SIGTERM"],
-                          "SIGTERM", object())
+            with self.assertRaises(SystemExit):
+                handlers["SIGTERM"]("SIGTERM", object())
             self.assertFalse(worker.stopped)
             self.assertFalse(worker.stopped)
         finally:
         finally:
             process.name = name
             process.name = name

+ 2 - 0
celery/tests/test_bin/test_celeryev.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery.app import app_or_default
 from celery.app import app_or_default

+ 1 - 0
celery/tests/test_compat/test_decorators.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import warnings
 import warnings

+ 2 - 0
celery/tests/test_compat/test_messaging.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery import messaging
 from celery import messaging
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
 
 

+ 2 - 0
celery/tests/test_concurrency/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 import os
 
 
 from itertools import count
 from itertools import count

+ 2 - 0
celery/tests/test_concurrency/test_concurrency_eventlet.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import os
 import os
 import sys
 import sys
 
 

+ 4 - 1
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import signal
 import signal
 import sys
 import sys
 
 
@@ -185,7 +187,8 @@ class test_TaskPool(unittest.TestCase):
     def test_on_ready_exit_exception(self):
     def test_on_ready_exit_exception(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         exc = to_excinfo(SystemExit("foo"))
         exc = to_excinfo(SystemExit("foo"))
-        self.assertRaises(SystemExit, pool.on_ready, [], [], exc)
+        with self.assertRaises(SystemExit):
+            pool.on_ready([], [], exc)
 
 
     def test_apply_async(self):
     def test_apply_async(self):
         pool = TaskPool(10)
         pool = TaskPool(10)

+ 2 - 0
celery/tests/test_concurrency/test_concurrency_solo.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import operator
 import operator
 
 
 from celery.concurrency import solo
 from celery.concurrency import solo

+ 2 - 0
celery/tests/test_concurrency/test_pool.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import sys
 import time
 import time
 import logging
 import logging

+ 8 - 4
celery/tests/test_events/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import socket
 import socket
 
 
 from celery import events
 from celery import events
@@ -54,7 +56,8 @@ class TestEventDispatcher(unittest.TestCase):
         eventer.enabled = True
         eventer.enabled = True
         eventer.publisher.raise_on_publish = True
         eventer.publisher.raise_on_publish = True
         eventer.buffer_while_offline = False
         eventer.buffer_while_offline = False
-        self.assertRaises(KeyError, eventer.send, "Event X")
+        with self.assertRaises(KeyError):
+            eventer.send("Event X")
         eventer.buffer_while_offline = True
         eventer.buffer_while_offline = True
         for ev in evs:
         for ev in evs:
             eventer.send(ev)
             eventer.send(ev)
@@ -142,10 +145,11 @@ class TestEventReceiver(unittest.TestCase):
             self.assertTrue(consumer.queues)
             self.assertTrue(consumer.queues)
             self.assertEqual(consumer.callbacks[0], r._receive)
             self.assertEqual(consumer.callbacks[0], r._receive)
 
 
-            self.assertRaises(socket.timeout, it.next)
+            with self.assertRaises(socket.timeout):
+                it.next()
 
 
-            self.assertRaises(socket.timeout,
-                              r.capture, timeout=0.00001)
+            with self.assertRaises(socket.timeout):
+                r.capture(timeout=0.00001)
         finally:
         finally:
             connection.close()
             connection.close()
 
 

+ 2 - 0
celery/tests/test_events/test_events_cursesmon.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest

+ 4 - 1
celery/tests/test_events/test_events_snapshot.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from celery.app import app_or_default
 from celery.app import app_or_default
 from celery.events import Events
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
 from celery.events.snapshot import Polaroid, evcam
@@ -117,4 +119,5 @@ class test_evcam(unittest.TestCase):
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer)
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         evcam(Polaroid, timer=timer, loglevel="CRITICAL")
         self.MockReceiver.raise_keyboard_interrupt = True
         self.MockReceiver.raise_keyboard_interrupt = True
-        self.assertRaises(SystemExit, evcam, Polaroid, timer=timer)
+        with self.assertRaises(SystemExit):
+            evcam(Polaroid, timer=timer)

+ 3 - 1
celery/tests/test_events/test_events_state.py

@@ -1,5 +1,6 @@
+from __future__ import absolute_import
+
 from time import time
 from time import time
-from celery.tests.utils import unittest
 
 
 from itertools import count
 from itertools import count
 
 
@@ -7,6 +8,7 @@ from celery import states
 from celery.events import Event
 from celery.events import Event
 from celery.events.state import State, Worker, Task, HEARTBEAT_EXPIRE
 from celery.events.state import State, Worker, Task, HEARTBEAT_EXPIRE
 from celery.utils import uuid
 from celery.utils import uuid
+from celery.tests.utils import unittest
 
 
 
 
 class replay(object):
 class replay(object):

+ 11 - 4
celery/tests/test_slow/test_buckets.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import sys
 import sys
 import time
 import time
 
 
@@ -42,7 +45,8 @@ class test_TokenBucketQueue(unittest.TestCase):
     @skip_if_disabled
     @skip_if_disabled
     def empty_queue_yields_QueueEmpty(self):
     def empty_queue_yields_QueueEmpty(self):
         x = buckets.TokenBucketQueue(fill_rate=10)
         x = buckets.TokenBucketQueue(fill_rate=10)
-        self.assertRaises(buckets.Empty, x.get)
+        with self.assertRaises(buckets.Empty):
+            x.get()
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_bucket__put_get(self):
     def test_bucket__put_get(self):
@@ -73,7 +77,8 @@ class test_TokenBucketQueue(unittest.TestCase):
         time.sleep(0.1)
         time.sleep(0.1)
         # Not yet ready for another token
         # Not yet ready for another token
         x.put("The lazy dog")
         x.put("The lazy dog")
-        self.assertRaises(x.RateLimitExceeded, x.get)
+        with self.assertRaises(x.RateLimitExceeded):
+            x.get()
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_expected_time(self):
     def test_expected_time(self):
@@ -132,7 +137,8 @@ class test_TaskBucket(unittest.TestCase):
     @skip_if_disabled
     @skip_if_disabled
     def test_get_nowait(self):
     def test_get_nowait(self):
         x = buckets.TaskBucket(task_registry=self.registry)
         x = buckets.TaskBucket(task_registry=self.registry)
-        self.assertRaises(buckets.Empty, x.get_nowait)
+        with self.assertRaises(buckets.Empty):
+            x.get_nowait()
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_refresh(self):
     def test_refresh(self):
@@ -194,7 +200,8 @@ class test_TaskBucket(unittest.TestCase):
     @skip_if_disabled
     @skip_if_disabled
     def test_on_empty_buckets__get_raises_empty(self):
     def test_on_empty_buckets__get_raises_empty(self):
         b = buckets.TaskBucket(task_registry=self.registry)
         b = buckets.TaskBucket(task_registry=self.registry)
-        self.assertRaises(buckets.Empty, b.get, block=False)
+        with self.assertRaises(buckets.Empty):
+            b.get(block=False)
         self.assertEqual(b.qsize(), 0)
         self.assertEqual(b.qsize(), 0)
 
 
     @skip_if_disabled
     @skip_if_disabled

+ 54 - 33
celery/tests/test_task/__init__.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from functools import wraps
 from functools import wraps
 
 
@@ -150,8 +152,8 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(RetryTaskNoArgs.iterations, 4)
         self.assertEqual(RetryTaskNoArgs.iterations, 4)
 
 
     def test_retry_kwargs_can_be_empty(self):
     def test_retry_kwargs_can_be_empty(self):
-        self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
-                            args=[4, 4], kwargs=None)
+        with self.assertRaises(RetryTaskError):
+            RetryTaskMockApply.retry(args=[4, 4], kwargs=None)
 
 
     def test_retry_not_eager(self):
     def test_retry_not_eager(self):
         RetryTaskMockApply.request.called_directly = False
         RetryTaskMockApply.request.called_directly = False
@@ -164,7 +166,8 @@ class TestTaskRetries(unittest.TestCase):
             RetryTaskMockApply.applied = 0
             RetryTaskMockApply.applied = 0
 
 
         try:
         try:
-            self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
+            with self.assertRaises(RetryTaskError):
+                RetryTaskMockApply.retry(
                     args=[4, 4], kwargs={"task_retries": 0},
                     args=[4, 4], kwargs={"task_retries": 0},
                     exc=exc, throw=True)
                     exc=exc, throw=True)
             self.assertTrue(RetryTaskMockApply.applied)
             self.assertTrue(RetryTaskMockApply.applied)
@@ -182,23 +185,23 @@ class TestTaskRetries(unittest.TestCase):
         RetryTaskCustomExc.max_retries = 2
         RetryTaskCustomExc.max_retries = 2
         RetryTaskCustomExc.iterations = 0
         RetryTaskCustomExc.iterations = 0
         result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
         result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
-        self.assertRaises(MyCustomException,
-                          result.get)
+        with self.assertRaises(MyCustomException):
+            result.get()
         self.assertEqual(RetryTaskCustomExc.iterations, 3)
         self.assertEqual(RetryTaskCustomExc.iterations, 3)
 
 
     def test_max_retries_exceeded(self):
     def test_max_retries_exceeded(self):
         RetryTask.max_retries = 2
         RetryTask.max_retries = 2
         RetryTask.iterations = 0
         RetryTask.iterations = 0
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
+        with self.assertRaises(RetryTask.MaxRetriesExceededError):
+            result.get()
         self.assertEqual(RetryTask.iterations, 3)
         self.assertEqual(RetryTask.iterations, 3)
 
 
         RetryTask.max_retries = 1
         RetryTask.max_retries = 1
         RetryTask.iterations = 0
         RetryTask.iterations = 0
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
         result = RetryTask.apply([0xFF, 0xFFFF], {"care": False})
-        self.assertRaises(RetryTask.MaxRetriesExceededError,
-                          result.get)
+        with self.assertRaises(RetryTask.MaxRetriesExceededError):
+            result.get()
         self.assertEqual(RetryTask.iterations, 2)
         self.assertEqual(RetryTask.iterations, 2)
 
 
 
 
@@ -255,15 +258,16 @@ class TestCeleryTasks(unittest.TestCase):
         class IncompleteTask(task.Task):
         class IncompleteTask(task.Task):
             name = "c.unittest.t.itask"
             name = "c.unittest.t.itask"
 
 
-        self.assertRaises(NotImplementedError, IncompleteTask().run)
+        with self.assertRaises(NotImplementedError):
+            IncompleteTask().run()
 
 
     def test_task_kwargs_must_be_dictionary(self):
     def test_task_kwargs_must_be_dictionary(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          [], "str")
+        with self.assertRaises(ValueError):
+            IncrementCounterTask.apply_async([], "str")
 
 
     def test_task_args_must_be_list(self):
     def test_task_args_must_be_list(self):
-        self.assertRaises(ValueError, IncrementCounterTask.apply_async,
-                          "str", {})
+        with self.assertRaises(ValueError):
+            IncrementCounterTask.apply_async("str", {})
 
 
     def test_regular_task(self):
     def test_regular_task(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
@@ -280,7 +284,8 @@ class TestCeleryTasks(unittest.TestCase):
 
 
         t1 = T1()
         t1 = T1()
         consumer = t1.get_consumer()
         consumer = t1.get_consumer()
-        self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
+        with self.assertRaises(NotImplementedError):
+            consumer.receive("foo", "foo")
         consumer.discard_all()
         consumer.discard_all()
         self.assertIsNone(consumer.fetch())
         self.assertIsNone(consumer.fetch())
 
 
@@ -466,12 +471,14 @@ class TestTaskSet(unittest.TestCase):
 class TestTaskApply(unittest.TestCase):
 class TestTaskApply(unittest.TestCase):
 
 
     def test_apply_throw(self):
     def test_apply_throw(self):
-        self.assertRaises(KeyError, RaisingTask.apply, throw=True)
+        with self.assertRaises(KeyError):
+            RaisingTask.apply(throw=True)
 
 
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
     def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
         RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
         try:
         try:
-            self.assertRaises(KeyError, RaisingTask.apply)
+            with self.assertRaises(KeyError):
+                RaisingTask.apply()
         finally:
         finally:
             RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
             RaisingTask.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = False
 
 
@@ -496,7 +503,8 @@ class TestTaskApply(unittest.TestCase):
         self.assertTrue(f.ready())
         self.assertTrue(f.ready())
         self.assertFalse(f.successful())
         self.assertFalse(f.successful())
         self.assertTrue(f.traceback)
         self.assertTrue(f.traceback)
-        self.assertRaises(KeyError, f.get)
+        with self.assertRaises(KeyError):
+            f.get()
 
 
 
 
 class MyPeriodic(task.PeriodicTask):
 class MyPeriodic(task.PeriodicTask):
@@ -506,8 +514,8 @@ class MyPeriodic(task.PeriodicTask):
 class TestPeriodicTask(unittest.TestCase):
 class TestPeriodicTask(unittest.TestCase):
 
 
     def test_must_have_run_every(self):
     def test_must_have_run_every(self):
-        self.assertRaises(NotImplementedError, type, "Foo",
-            (task.PeriodicTask, ), {"__module__": __name__})
+        with self.assertRaises(NotImplementedError):
+            type("Foo", (task.PeriodicTask, ), {"__module__": __name__})
 
 
     def test_remaining_estimate(self):
     def test_remaining_estimate(self):
         self.assertIsInstance(
         self.assertIsInstance(
@@ -610,23 +618,28 @@ class test_crontab_parser(unittest.TestCase):
                     20, 25, 30, 35, 40, 45, 50, 55]))
                     20, 25, 30, 35, 40, 45, 50, 55]))
 
 
     def test_parse_errors_on_empty_string(self):
     def test_parse_errors_on_empty_string(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('')
 
 
     def test_parse_errors_on_empty_group(self):
     def test_parse_errors_on_empty_group(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '1,,2')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('1,,2')
 
 
     def test_parse_errors_on_empty_steps(self):
     def test_parse_errors_on_empty_steps(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '*/')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('*/')
 
 
     def test_parse_errors_on_negative_number(self):
     def test_parse_errors_on_negative_number(self):
-        self.assertRaises(ParseException, crontab_parser(60).parse, '-20')
+        with self.assertRaises(ParseException):
+            crontab_parser(60).parse('-20')
 
 
     def test_expand_cronspec_eats_iterables(self):
     def test_expand_cronspec_eats_iterables(self):
         self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
         self.assertEqual(crontab._expand_cronspec(iter([1, 2, 3]), 100),
                          set([1, 2, 3]))
                          set([1, 2, 3]))
 
 
     def test_expand_cronspec_invalid_type(self):
     def test_expand_cronspec_invalid_type(self):
-        self.assertRaises(TypeError, crontab._expand_cronspec, object(), 100)
+        with self.assertRaises(TypeError):
+            crontab._expand_cronspec(object(), 100)
 
 
     def test_repr(self):
     def test_repr(self):
         self.assertIn("*", repr(crontab("*")))
         self.assertIn("*", repr(crontab("*")))
@@ -720,8 +733,10 @@ class test_crontab_is_due(unittest.TestCase):
         self.assertEqual(c.minute, set([30, 40, 50]))
         self.assertEqual(c.minute, set([30, 40, 50]))
 
 
     def test_crontab_spec_invalid_minute(self):
     def test_crontab_spec_invalid_minute(self):
-        self.assertRaises(ValueError, crontab, minute=60)
-        self.assertRaises(ValueError, crontab, minute='0-100')
+        with self.assertRaises(ValueError):
+            crontab(minute=60)
+        with self.assertRaises(ValueError):
+            crontab(minute='0-100')
 
 
     def test_crontab_spec_hour_formats(self):
     def test_crontab_spec_hour_formats(self):
         c = crontab(hour=6)
         c = crontab(hour=6)
@@ -732,8 +747,10 @@ class test_crontab_is_due(unittest.TestCase):
         self.assertEqual(c.hour, set([4, 8, 12]))
         self.assertEqual(c.hour, set([4, 8, 12]))
 
 
     def test_crontab_spec_invalid_hour(self):
     def test_crontab_spec_invalid_hour(self):
-        self.assertRaises(ValueError, crontab, hour=24)
-        self.assertRaises(ValueError, crontab, hour='0-30')
+        with self.assertRaises(ValueError):
+            crontab(hour=24)
+        with self.assertRaises(ValueError):
+            crontab(hour='0-30')
 
 
     def test_crontab_spec_dow_formats(self):
     def test_crontab_spec_dow_formats(self):
         c = crontab(day_of_week=5)
         c = crontab(day_of_week=5)
@@ -760,10 +777,14 @@ class test_crontab_is_due(unittest.TestCase):
                 break
                 break
 
 
     def test_crontab_spec_invalid_dow(self):
     def test_crontab_spec_invalid_dow(self):
-        self.assertRaises(ValueError, crontab, day_of_week='fooday-barday')
-        self.assertRaises(ValueError, crontab, day_of_week='1,4,foo')
-        self.assertRaises(ValueError, crontab, day_of_week='7')
-        self.assertRaises(ValueError, crontab, day_of_week='12')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='fooday-barday')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='1,4,foo')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='7')
+        with self.assertRaises(ValueError):
+            crontab(day_of_week='12')
 
 
     def test_every_minute_execution_is_due(self):
     def test_every_minute_execution_is_due(self):
         last_ran = self.now - timedelta(seconds=61)
         last_ran = self.now - timedelta(seconds=61)

+ 2 - 0
celery/tests/test_task/test_chord.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 from mock import patch
 from mock import patch
 
 
 from celery import current_app
 from celery import current_app

+ 3 - 2
celery/tests/test_task/test_context.py

@@ -1,9 +1,10 @@
 # -*- coding: utf-8 -*-"
 # -*- coding: utf-8 -*-"
-import threading
+from __future__ import absolute_import
 
 
-from celery.tests.utils import unittest
+import threading
 
 
 from celery.task.base import Context
 from celery.task.base import Context
+from celery.tests.utils import unittest
 
 
 
 
 # Retreive the values of all context attributes as a
 # Retreive the values of all context attributes as a

+ 6 - 3
celery/tests/test_task/test_execute_trace.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import operator
 import operator
 
 
 from celery import states
 from celery import states
@@ -20,7 +22,8 @@ class test_TraceInfo(unittest.TestCase):
         self.assertEqual(info.retval, 4)
         self.assertEqual(info.retval, 4)
 
 
     def test_trace_SystemExit(self):
     def test_trace_SystemExit(self):
-        self.assertRaises(SystemExit, trace, raises, (SystemExit(), ), {})
+        with self.assertRaises(SystemExit):
+            trace(raises, (SystemExit(), ), {})
 
 
     def test_trace_RetryTaskError(self):
     def test_trace_RetryTaskError(self):
         exc = RetryTaskError("foo", "bar")
         exc = RetryTaskError("foo", "bar")
@@ -35,5 +38,5 @@ class test_TraceInfo(unittest.TestCase):
         self.assertIs(info.retval, exc)
         self.assertIs(info.retval, exc)
 
 
     def test_trace_exception_propagate(self):
     def test_trace_exception_propagate(self):
-        self.assertRaises(KeyError, trace, raises, (KeyError("foo"), ), {},
-                          propagate=True)
+        with self.assertRaises(KeyError):
+            trace(raises, (KeyError("foo"), ), {}, propagate=True)

+ 7 - 3
celery/tests/test_task/test_registry.py

@@ -1,7 +1,9 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
+from __future__ import with_statement
 
 
 from celery import registry
 from celery import registry
 from celery.task import Task, PeriodicTask
 from celery.task import Task, PeriodicTask
+from celery.tests.utils import unittest
 
 
 
 
 class TestTask(Task):
 class TestTask(Task):
@@ -22,12 +24,14 @@ class TestPeriodicTask(PeriodicTask):
 class TestTaskRegistry(unittest.TestCase):
 class TestTaskRegistry(unittest.TestCase):
 
 
     def assertRegisterUnregisterCls(self, r, task):
     def assertRegisterUnregisterCls(self, r, task):
-        self.assertRaises(r.NotRegistered, r.unregister, task)
+        with self.assertRaises(r.NotRegistered):
+            r.unregister(task)
         r.register(task)
         r.register(task)
         self.assertIn(task.name, r)
         self.assertIn(task.name, r)
 
 
     def assertRegisterUnregisterFunc(self, r, task, task_name):
     def assertRegisterUnregisterFunc(self, r, task, task_name):
-        self.assertRaises(r.NotRegistered, r.unregister, task_name)
+        with self.assertRaises(r.NotRegistered):
+            r.unregister(task_name)
         r.register(task, task_name)
         r.register(task, task_name)
         self.assertIn(task_name, r)
         self.assertIn(task_name, r)
 
 

+ 26 - 15
celery/tests/test_task/test_result.py

@@ -121,21 +121,25 @@ class TestAsyncResult(AppCase):
 
 
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok2_res.get(), "quick")
         self.assertEqual(ok2_res.get(), "quick")
-        self.assertRaises(KeyError, nok_res.get)
+        with self.assertRaises(KeyError):
+            nok_res.get()
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertIsInstance(nok2_res.result, KeyError)
         self.assertEqual(ok_res.info, "the")
         self.assertEqual(ok_res.info, "the")
 
 
     def test_get_timeout(self):
     def test_get_timeout(self):
         res = AsyncResult(self.task4["id"])             # has RETRY status
         res = AsyncResult(self.task4["id"])             # has RETRY status
-        self.assertRaises(TimeoutError, res.get, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            res.get(timeout=0.1)
 
 
         pending_res = AsyncResult(uuid())
         pending_res = AsyncResult(uuid())
-        self.assertRaises(TimeoutError, pending_res.get, timeout=0.1)
+        with self.assertRaises(TimeoutError):
+            pending_res.get(timeout=0.1)
 
 
     @skip_if_quick
     @skip_if_quick
     def test_get_timeout_longer(self):
     def test_get_timeout_longer(self):
         res = AsyncResult(self.task4["id"])             # has RETRY status
         res = AsyncResult(self.task4["id"])             # has RETRY status
-        self.assertRaises(TimeoutError, res.get, timeout=1)
+        with self.assertRaises(TimeoutError):
+            res.get(timeout=1)
 
 
     def test_ready(self):
     def test_ready(self):
         oks = (AsyncResult(self.task1["id"]),
         oks = (AsyncResult(self.task1["id"]),
@@ -224,7 +228,8 @@ class TestTaskSetResult(AppCase):
         ar = MockAsyncResultFailure(uuid())
         ar = MockAsyncResultFailure(uuid())
         ts = TaskSetResult(uuid(), [ar])
         ts = TaskSetResult(uuid(), [ar])
         it = iter(ts)
         it = iter(ts)
-        self.assertRaises(KeyError, it.next)
+        with self.assertRaises(KeyError):
+            it.next()
 
 
     def test_forget(self):
     def test_forget(self):
         subs = [MockAsyncResultSuccess(uuid()),
         subs = [MockAsyncResultSuccess(uuid()),
@@ -245,14 +250,14 @@ class TestTaskSetResult(AppCase):
                 MockAsyncResultSuccess(uuid())]
                 MockAsyncResultSuccess(uuid())]
         ts = TaskSetResult(uuid(), subs)
         ts = TaskSetResult(uuid(), subs)
         ts.save()
         ts.save()
-        self.assertRaises(AttributeError, ts.save, backend=object())
+        with self.assertRaises(AttributeError):
+            ts.save(backend=object())
         self.assertEqual(TaskSetResult.restore(ts.taskset_id).subtasks,
         self.assertEqual(TaskSetResult.restore(ts.taskset_id).subtasks,
                          ts.subtasks)
                          ts.subtasks)
         ts.delete()
         ts.delete()
         self.assertIsNone(TaskSetResult.restore(ts.taskset_id))
         self.assertIsNone(TaskSetResult.restore(ts.taskset_id))
-        self.assertRaises(AttributeError,
-                          TaskSetResult.restore, ts.taskset_id,
-                          backend=object())
+        with self.assertRaises(AttributeError):
+            TaskSetResult.restore(ts.taskset_id, backend=object())
 
 
     def test_join_native(self):
     def test_join_native(self):
         backend = SimpleBackend()
         backend = SimpleBackend()
@@ -292,7 +297,8 @@ class TestTaskSetResult(AppCase):
         ar2 = MockAsyncResultSuccess(uuid())
         ar2 = MockAsyncResultSuccess(uuid())
         ar3 = AsyncResult(uuid())
         ar3 = AsyncResult(uuid())
         ts = TaskSetResult(uuid(), [ar, ar2, ar3])
         ts = TaskSetResult(uuid(), [ar, ar2, ar3])
-        self.assertRaises(TimeoutError, ts.join, timeout=0.0000001)
+        with self.assertRaises(TimeoutError):
+            ts.join(timeout=0.0000001)
 
 
     def test_itersubtasks(self):
     def test_itersubtasks(self):
 
 
@@ -367,10 +373,12 @@ class TestFailedTaskSetResult(TestTaskSetResult):
         def consume():
         def consume():
             return list(it)
             return list(it)
 
 
-        self.assertRaises(KeyError, consume)
+        with self.assertRaises(KeyError):
+            consume()
 
 
     def test_join(self):
     def test_join(self):
-        self.assertRaises(KeyError, self.ts.join)
+        with self.assertRaises(KeyError):
+            self.ts.join()
 
 
     def test_successful(self):
     def test_successful(self):
         self.assertFalse(self.ts.successful())
         self.assertFalse(self.ts.successful())
@@ -396,11 +404,13 @@ class TestTaskSetPending(AppCase):
         self.assertTrue(self.ts.waiting())
         self.assertTrue(self.ts.waiting())
 
 
     def x_join(self):
     def x_join(self):
-        self.assertRaises(TimeoutError, self.ts.join, timeout=0.001)
+        with self.assertRaises(TimeoutError):
+            self.ts.join(timeout=0.001)
 
 
     @skip_if_quick
     @skip_if_quick
     def x_join_longer(self):
     def x_join_longer(self):
-        self.assertRaises(TimeoutError, self.ts.join, timeout=1)
+        with self.assertRaises(TimeoutError):
+            self.ts.join(timeout=1)
 
 
 
 
 class RaisingTask(Task):
 class RaisingTask(Task):
@@ -413,7 +423,8 @@ class TestEagerResult(AppCase):
 
 
     def test_wait_raises(self):
     def test_wait_raises(self):
         res = RaisingTask.apply(args=[3, 3])
         res = RaisingTask.apply(args=[3, 3])
-        self.assertRaises(KeyError, res.wait)
+        with self.assertRaises(KeyError):
+            res.wait()
 
 
     def test_wait(self):
     def test_wait(self):
         res = EagerResult("x", "x", states.RETRY)
         res = EagerResult("x", "x", states.RETRY)

+ 2 - 1
celery/tests/test_task/test_states.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery.states import state
 from celery.states import state
 from celery import states
 from celery import states
+from celery.tests.utils import unittest
 
 
 
 
 class test_state_precedence(unittest.TestCase):
 class test_state_precedence(unittest.TestCase):

+ 2 - 1
celery/tests/test_task/test_task_abortable.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
+from celery.tests.utils import unittest
 
 
 
 
 class MyAbortableTask(AbortableTask):
 class MyAbortableTask(AbortableTask):

+ 1 - 0
celery/tests/test_task/test_task_builtins.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import warnings
 import warnings

+ 6 - 2
celery/tests/test_task/test_task_control.py

@@ -1,3 +1,6 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 from functools import wraps
 from functools import wraps
 
 
 from kombu.pidbox import Mailbox
 from kombu.pidbox import Mailbox
@@ -135,8 +138,9 @@ class test_Broadcast(unittest.TestCase):
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_broadcast_validate(self):
     def test_broadcast_validate(self):
-        self.assertRaises(ValueError, self.control.broadcast, "foobarbaz2",
-                          destination="foo")
+        with self.assertRaises(ValueError):
+            self.control.broadcast("foobarbaz2",
+                                   destination="foo")
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_rate_limit(self):
     def test_rate_limit(self):

+ 9 - 4
celery/tests/test_task/test_task_http.py

@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import logging
 import logging
@@ -112,7 +113,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(fail_response("Invalid moon alignment")):
         with mock_urlopen(fail_response("Invalid moon alignment")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.RemoteExecuteError, d.dispatch)
+            with self.assertRaises(http.RemoteExecuteError):
+                d.dispatch()
 
 
     def test_dispatch_empty_response(self):
     def test_dispatch_empty_response(self):
         logger = logging.getLogger("celery.unittest")
         logger = logging.getLogger("celery.unittest")
@@ -120,7 +122,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(_response("")):
         with mock_urlopen(_response("")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.InvalidResponseError, d.dispatch)
+            with self.assertRaises(http.InvalidResponseError):
+                d.dispatch()
 
 
     def test_dispatch_non_json(self):
     def test_dispatch_non_json(self):
         logger = logging.getLogger("celery.unittest")
         logger = logging.getLogger("celery.unittest")
@@ -128,7 +131,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(_response("{'#{:'''")):
         with mock_urlopen(_response("{'#{:'''")):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.InvalidResponseError, d.dispatch)
+            with self.assertRaises(http.InvalidResponseError):
+                d.dispatch()
 
 
     def test_dispatch_unknown_status(self):
     def test_dispatch_unknown_status(self):
         logger = logging.getLogger("celery.unittest")
         logger = logging.getLogger("celery.unittest")
@@ -136,7 +140,8 @@ class TestHttpDispatch(unittest.TestCase):
         with mock_urlopen(unknown_response()):
         with mock_urlopen(unknown_response()):
             d = http.HttpDispatch("http://example.com/mul", "GET", {
             d = http.HttpDispatch("http://example.com/mul", "GET", {
                                     "x": 10, "y": 10}, logger)
                                     "x": 10, "y": 10}, logger)
-            self.assertRaises(http.UnknownStatusError, d.dispatch)
+            with self.assertRaises(http.UnknownStatusError):
+                d.dispatch()
 
 
     def test_dispatch_POST(self):
     def test_dispatch_POST(self):
         logger = logging.getLogger("celery.unittest")
         logger = logging.getLogger("celery.unittest")

+ 1 - 0
celery/tests/test_task/test_task_sets.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import anyjson
 import anyjson

+ 3 - 1
celery/tests/test_utils/__init__.py

@@ -1,8 +1,10 @@
+from __future__ import absolute_import
+
 import pickle
 import pickle
-from celery.tests.utils import unittest
 
 
 from celery import utils
 from celery import utils
 from celery.utils import promise, mpromise, maybe_promise
 from celery.utils import promise, mpromise, maybe_promise
+from celery.tests.utils import unittest
 
 
 
 
 def double(x):
 def double(x):

+ 12 - 5
celery/tests/test_utils/test_datastructures.py

@@ -1,11 +1,14 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import sys
 import sys
-from celery.tests.utils import unittest
 from Queue import Queue
 from Queue import Queue
 
 
 from celery.datastructures import ExceptionInfo, LRUCache
 from celery.datastructures import ExceptionInfo, LRUCache
 from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
 from celery.datastructures import ConfigurationView
+from celery.tests.utils import unittest
 
 
 
 
 class Object(object):
 class Object(object):
@@ -21,7 +24,8 @@ class test_DictAttribute(unittest.TestCase):
         self.assertEqual(x["foo"], x.obj.foo)
         self.assertEqual(x["foo"], x.obj.foo)
         self.assertEqual(x.get("foo"), "The quick brown fox")
         self.assertEqual(x.get("foo"), "The quick brown fox")
         self.assertIsNone(x.get("bar"))
         self.assertIsNone(x.get("bar"))
-        self.assertRaises(KeyError, x.__getitem__, "bar")
+        with self.assertRaises(KeyError):
+            x["bar"]
 
 
     def test_setdefault(self):
     def test_setdefault(self):
         x = DictAttribute(Object())
         x = DictAttribute(Object())
@@ -96,11 +100,13 @@ class test_utilities(unittest.TestCase):
     def test_consume_queue(self):
     def test_consume_queue(self):
         x = Queue()
         x = Queue()
         it = consume_queue(x)
         it = consume_queue(x)
-        self.assertRaises(StopIteration, it.next)
+        with self.assertRaises(StopIteration):
+            it.next()
         x.put("foo")
         x.put("foo")
         it = consume_queue(x)
         it = consume_queue(x)
         self.assertEqual(it.next(), "foo")
         self.assertEqual(it.next(), "foo")
-        self.assertRaises(StopIteration, it.next)
+        with self.assertRaises(StopIteration):
+            it.next()
 
 
 
 
 class test_LimitedSet(unittest.TestCase):
 class test_LimitedSet(unittest.TestCase):
@@ -166,6 +172,7 @@ class test_AttributeDict(unittest.TestCase):
     def test_getattr__setattr(self):
     def test_getattr__setattr(self):
         x = AttributeDict({"foo": "bar"})
         x = AttributeDict({"foo": "bar"})
         self.assertEqual(x["foo"], "bar")
         self.assertEqual(x["foo"], "bar")
-        self.assertRaises(AttributeError, getattr, x, "bar")
+        with self.assertRaises(AttributeError):
+            x.bar
         x.bar = "foo"
         x.bar = "foo"
         self.assertEqual(x["bar"], "foo")
         self.assertEqual(x["bar"], "foo")

+ 2 - 1
celery/tests/test_utils/test_pickle.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery.utils.serialization import pickle
 from celery.utils.serialization import pickle
+from celery.tests.utils import unittest
 
 
 
 
 class RegularException(Exception):
 class RegularException(Exception):

+ 1 - 0
celery/tests/test_utils/test_serialization.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys

+ 1 - 0
celery/tests/test_utils/test_timer2.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import sys
 import sys

+ 2 - 0
celery/tests/test_utils/test_utils_encoding.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import sys
 import sys
 
 
 from nose import SkipTest
 from nose import SkipTest

+ 2 - 1
celery/tests/test_utils/test_utils_info.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery import Celery
 from celery import Celery
 from celery.utils import textindent
 from celery.utils import textindent
+from celery.tests.utils import unittest
 
 
 RANDTEXT = """\
 RANDTEXT = """\
 The quick brown
 The quick brown

+ 2 - 1
celery/tests/test_utils/test_utils_timeutils.py

@@ -1,7 +1,8 @@
+from __future__ import absolute_import
+
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
 from celery.utils import timeutils
 from celery.utils import timeutils
-
 from celery.tests.utils import unittest
 from celery.tests.utils import unittest
 
 
 
 

+ 25 - 13
celery/tests/test_worker/__init__.py

@@ -1,3 +1,4 @@
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import socket
 import socket
@@ -112,7 +113,8 @@ class test_QoS(unittest.TestCase):
         self.assertEqual(qos.increment(-30), 14)
         self.assertEqual(qos.increment(-30), 14)
         self.assertEqual(qos.decrement(7), 7)
         self.assertEqual(qos.decrement(7), 7)
         self.assertEqual(qos.decrement(), 6)
         self.assertEqual(qos.decrement(), 6)
-        self.assertRaises(AssertionError, qos.decrement, 10)
+        with self.assertRaises(AssertionError):
+            qos.decrement(10)
 
 
     def test_qos_disabled_increment_decrement(self):
     def test_qos_disabled_increment_decrement(self):
         qos = self._QoS(0)
         qos = self._QoS(0)
@@ -358,7 +360,8 @@ class test_Consumer(unittest.TestCase):
         l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
         l = MockConsumer(self.ready_queue, self.eta_schedule, self.logger,
                              send_events=False, pool=BasePool())
                              send_events=False, pool=BasePool())
         l.connection_errors = (KeyError, )
         l.connection_errors = (KeyError, )
-        self.assertRaises(SyntaxError, l.start)
+        with self.assertRaises(SyntaxError):
+            l.start()
         l.heart.stop()
         l.heart.stop()
         l.priority_timer.stop()
         l.priority_timer.stop()
 
 
@@ -433,8 +436,8 @@ class test_Consumer(unittest.TestCase):
         l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
         l.maybe_conn_error(Mock(side_effect=AttributeError("foo")))
         l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
         l.maybe_conn_error(Mock(side_effect=KeyError("foo")))
         l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
         l.maybe_conn_error(Mock(side_effect=SyntaxError("foo")))
-        self.assertRaises(IndexError, l.maybe_conn_error,
-                Mock(side_effect=IndexError("foo")))
+        with self.assertRaises(IndexError):
+            l.maybe_conn_error(Mock(side_effect=IndexError("foo")))
 
 
     def test_apply_eta_task(self):
     def test_apply_eta_task(self):
         from celery.worker import state
         from celery.worker import state
@@ -514,7 +517,8 @@ class test_Consumer(unittest.TestCase):
 
 
         l.event_dispatcher = Mock()
         l.event_dispatcher = Mock()
         self.assertFalse(l.receive_message(m.decode(), m))
         self.assertFalse(l.receive_message(m.decode(), m))
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
         self.assertTrue(self.eta_schedule.empty())
         self.assertTrue(self.eta_schedule.empty())
 
 
     def test_receieve_message_ack_raises(self):
     def test_receieve_message_ack_raises(self):
@@ -532,7 +536,8 @@ class test_Consumer(unittest.TestCase):
             self.assertFalse(l.receive_message(m.decode(), m))
             self.assertFalse(l.receive_message(m.decode(), m))
             self.assertTrue(log)
             self.assertTrue(log)
             self.assertIn("unknown message", log[0].message.args[0])
             self.assertIn("unknown message", log[0].message.args[0])
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
         self.assertTrue(self.eta_schedule.empty())
         self.assertTrue(self.eta_schedule.empty())
         m.ack.assert_called_with()
         m.ack.assert_called_with()
         self.assertTrue(l.logger.critical.call_count)
         self.assertTrue(l.logger.critical.call_count)
@@ -566,7 +571,8 @@ class test_Consumer(unittest.TestCase):
         self.assertIsInstance(task, TaskRequest)
         self.assertIsInstance(task, TaskRequest)
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertEqual(task.execute(), 2 * 4 * 8)
-        self.assertRaises(Empty, self.ready_queue.get_nowait)
+        with self.assertRaises(Empty):
+            self.ready_queue.get_nowait()
 
 
     def test_reset_pidbox_node(self):
     def test_reset_pidbox_node(self):
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
         l = MyKombuConsumer(self.ready_queue, self.eta_schedule, self.logger,
@@ -654,7 +660,8 @@ class test_Consumer(unittest.TestCase):
                 raise KeyError("foo")
                 raise KeyError("foo")
 
 
         l.consume_messages = raises_KeyError
         l.consume_messages = raises_KeyError
-        self.assertRaises(KeyError, l.start)
+        with self.assertRaises(KeyError):
+            l.start()
         self.assertTrue(init_callback.call_count)
         self.assertTrue(init_callback.call_count)
         self.assertEqual(l.iterations, 1)
         self.assertEqual(l.iterations, 1)
         self.assertEqual(l.qos.prev, l.qos.value)
         self.assertEqual(l.qos.prev, l.qos.value)
@@ -667,7 +674,8 @@ class test_Consumer(unittest.TestCase):
         l.broadcast_consumer = Mock()
         l.broadcast_consumer = Mock()
         l.connection = BrokerConnection()
         l.connection = BrokerConnection()
         l.consume_messages = Mock(side_effect=socket.error("foo"))
         l.consume_messages = Mock(side_effect=socket.error("foo"))
-        self.assertRaises(socket.error, l.start)
+        with self.assertRaises(socket.error):
+            l.start()
         self.assertTrue(init_callback.call_count)
         self.assertTrue(init_callback.call_count)
         self.assertTrue(l.consume_messages.call_count)
         self.assertTrue(l.consume_messages.call_count)
 
 
@@ -799,7 +807,8 @@ class test_WorkController(AppCase):
         task = TaskRequest.from_message(m, m.decode())
         task = TaskRequest.from_message(m, m.decode())
         worker.components = []
         worker.components = []
         worker._state = worker.RUN
         worker._state = worker.RUN
-        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
+        with self.assertRaises(KeyboardInterrupt):
+            worker.process_task(task)
         self.assertEqual(worker._state, worker.TERMINATE)
         self.assertEqual(worker._state, worker.TERMINATE)
 
 
     def test_process_task_raise_SystemTerminate(self):
     def test_process_task_raise_SystemTerminate(self):
@@ -812,7 +821,8 @@ class test_WorkController(AppCase):
         task = TaskRequest.from_message(m, m.decode())
         task = TaskRequest.from_message(m, m.decode())
         worker.components = []
         worker.components = []
         worker._state = worker.RUN
         worker._state = worker.RUN
-        self.assertRaises(SystemExit, worker.process_task, task)
+        with self.assertRaises(SystemExit):
+            worker.process_task(task)
         self.assertEqual(worker._state, worker.TERMINATE)
         self.assertEqual(worker._state, worker.TERMINATE)
 
 
     def test_process_task_raise_regular(self):
     def test_process_task_raise_regular(self):
@@ -831,7 +841,8 @@ class test_WorkController(AppCase):
         stc = Mock()
         stc = Mock()
         stc.start.side_effect = SystemTerminate()
         stc.start.side_effect = SystemTerminate()
         worker1.components = [stc]
         worker1.components = [stc]
-        self.assertRaises(SystemExit, worker1.start)
+        with self.assertRaises(SystemExit):
+            worker1.start()
         self.assertTrue(stc.terminate.call_count)
         self.assertTrue(stc.terminate.call_count)
 
 
         worker2 = self.create_worker()
         worker2 = self.create_worker()
@@ -839,7 +850,8 @@ class test_WorkController(AppCase):
         sec.start.side_effect = SystemExit()
         sec.start.side_effect = SystemExit()
         sec.terminate = None
         sec.terminate = None
         worker2.components = [sec]
         worker2.components = [sec]
-        self.assertRaises(SystemExit, worker2.start)
+        with self.assertRaises(SystemExit):
+            worker2.start()
         self.assertTrue(sec.stop.call_count)
         self.assertTrue(sec.stop.call_count)
 
 
     def test_state_db(self):
     def test_state_db(self):

+ 2 - 1
celery/tests/test_worker/test_worker_autoscale.py

@@ -1,3 +1,5 @@
+from __future__ import absolute_import
+
 import logging
 import logging
 
 
 from time import time
 from time import time
@@ -7,7 +9,6 @@ from mock import Mock, patch
 from celery.concurrency.base import BasePool
 from celery.concurrency.base import BasePool
 from celery.worker import state
 from celery.worker import state
 from celery.worker import autoscale
 from celery.worker import autoscale
-
 from celery.tests.utils import unittest, sleepdeprived
 from celery.tests.utils import unittest, sleepdeprived
 
 
 logger = logging.getLogger("celery.tests.autoscale")
 logger = logging.getLogger("celery.tests.autoscale")

+ 9 - 5
celery/tests/test_worker/test_worker_control.py

@@ -1,25 +1,27 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
 import socket
 import socket
-from celery.tests.utils import unittest
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
 from kombu import pidbox
 from kombu import pidbox
 from mock import Mock
 from mock import Mock
 
 
-from celery.utils.timer2 import Timer
-
 from celery import current_app
 from celery import current_app
 from celery.datastructures import AttributeDict
 from celery.datastructures import AttributeDict
 from celery.task import task
 from celery.task import task
 from celery.registry import tasks
 from celery.registry import tasks
 from celery.task import PingTask
 from celery.task import PingTask
 from celery.utils import uuid
 from celery.utils import uuid
+from celery.utils.timer2 import Timer
 from celery.worker.buckets import FastQueue
 from celery.worker.buckets import FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.job import TaskRequest
 from celery.worker import state
 from celery.worker import state
 from celery.worker.state import revoked
 from celery.worker.state import revoked
 from celery.worker.control import builtins
 from celery.worker.control import builtins
 from celery.worker.control.registry import Panel
 from celery.worker.control.registry import Panel
+from celery.tests.utils import unittest
 
 
 hostname = socket.gethostname()
 hostname = socket.gethostname()
 
 
@@ -295,7 +297,8 @@ class test_ControlPanel(unittest.TestCase):
                                 "rate_limit": "1000/s"})
                                 "rate_limit": "1000/s"})
 
 
     def test_unexposed_command(self):
     def test_unexposed_command(self):
-        self.assertRaises(KeyError, self.panel.handle, "foo", arguments={})
+        with self.assertRaises(KeyError):
+            self.panel.handle("foo", arguments={})
 
 
     def test_revoke_with_name(self):
     def test_revoke_with_name(self):
         tid = uuid()
         tid = uuid()
@@ -353,7 +356,8 @@ class test_ControlPanel(unittest.TestCase):
     def test_shutdown(self):
     def test_shutdown(self):
         m = {"method": "shutdown",
         m = {"method": "shutdown",
              "destination": hostname}
              "destination": hostname}
-        self.assertRaises(SystemExit, self.panel.dispatch_from_message, m)
+        with self.assertRaises(SystemExit):
+            self.panel.dispatch_from_message(m)
 
 
     def test_panel_reply(self):
     def test_panel_reply(self):
 
 

+ 2 - 1
celery/tests/test_worker/test_worker_heartbeat.py

@@ -1,5 +1,6 @@
-from celery.worker.heartbeat import Heart
+from __future__ import absolute_import
 
 
+from celery.worker.heartbeat import Heart
 from celery.tests.utils import unittest, sleepdeprived
 from celery.tests.utils import unittest, sleepdeprived
 
 
 
 

+ 7 - 6
celery/tests/test_worker/test_worker_job.py

@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
+from __future__ import absolute_import
 from __future__ import with_statement
 from __future__ import with_statement
 
 
 import anyjson
 import anyjson
@@ -132,8 +133,8 @@ class test_WorkerTaskTrace(unittest.TestCase):
         mytask.backend = Mock()
         mytask.backend = Mock()
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         mytask.backend.process_cleanup = Mock(side_effect=SystemExit())
         try:
         try:
-            self.assertRaises(SystemExit,
-                    jail, uuid(), mytask.name, [2], {})
+            with self.assertRaises(SystemExit):
+                jail(uuid(), mytask.name, [2], {})
         finally:
         finally:
             mytask.backend = backend
             mytask.backend = backend
 
 
@@ -416,8 +417,8 @@ class test_TaskRequest(unittest.TestCase):
 
 
     def test_from_message_invalid_kwargs(self):
     def test_from_message_invalid_kwargs(self):
         body = dict(task="foo", id=1, args=(), kwargs="foo")
         body = dict(task="foo", id=1, args=(), kwargs="foo")
-        self.assertRaises(InvalidTaskError,
-                          TaskRequest.from_message, None, body)
+        with self.assertRaises(InvalidTaskError):
+            TaskRequest.from_message(None, body)
 
 
     def test_on_timeout(self):
     def test_on_timeout(self):
 
 
@@ -542,8 +543,8 @@ class test_TaskRequest(unittest.TestCase):
         m = Message(None, body=anyjson.serialize(body), backend="foo",
         m = Message(None, body=anyjson.serialize(body), backend="foo",
                           content_type="application/json",
                           content_type="application/json",
                           content_encoding="utf-8")
                           content_encoding="utf-8")
-        self.assertRaises(NotRegistered, TaskRequest.from_message,
-                          m, m.decode())
+        with self.assertRaises(NotRegistered):
+            TaskRequest.from_message(m, m.decode())
 
 
     def test_execute(self):
     def test_execute(self):
         tid = uuid()
         tid = uuid()

+ 2 - 1
celery/tests/test_worker/test_worker_mediator.py

@@ -1,4 +1,4 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from Queue import Queue
 from Queue import Queue
 
 
@@ -7,6 +7,7 @@ from mock import Mock, patch
 from celery.utils import uuid
 from celery.utils import uuid
 from celery.worker.mediator import Mediator
 from celery.worker.mediator import Mediator
 from celery.worker.state import revoked as revoked_tasks
 from celery.worker.state import revoked as revoked_tasks
+from celery.tests.utils import unittest
 
 
 
 
 class MockTask(object):
 class MockTask(object):

+ 2 - 1
celery/tests/test_worker/test_worker_revoke.py

@@ -1,6 +1,7 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery.worker import state
 from celery.worker import state
+from celery.tests.utils import unittest
 
 
 
 
 class TestRevokeRegistry(unittest.TestCase):
 class TestRevokeRegistry(unittest.TestCase):

+ 2 - 1
celery/tests/test_worker/test_worker_state.py

@@ -1,7 +1,8 @@
-from celery.tests.utils import unittest
+from __future__ import absolute_import
 
 
 from celery.datastructures import LimitedSet
 from celery.datastructures import LimitedSet
 from celery.worker import state
 from celery.worker import state
+from celery.tests.utils import unittest
 
 
 
 
 class StateResetCase(unittest.TestCase):
 class StateResetCase(unittest.TestCase):