Browse Source

Use unittest2 features for better test diagnostics.

Ask Solem 16 năm trước cách đây
mục cha
commit
8b2643618f
37 tập tin đã thay đổi với 289 bổ sung305 xóa
  1. 11 13
      celery/tests/test_backends/__init__.py
  2. 2 2
      celery/tests/test_backends/test_amqp.py
  3. 10 10
      celery/tests/test_backends/test_base.py
  4. 6 7
      celery/tests/test_backends/test_cache.py
  5. 5 6
      celery/tests/test_backends/test_database.py
  6. 11 12
      celery/tests/test_backends/test_redis.py
  7. 7 8
      celery/tests/test_backends/test_tyrant.py
  8. 18 18
      celery/tests/test_beat.py
  9. 1 1
      celery/tests/test_bin_celeryd.py
  10. 8 8
      celery/tests/test_buckets.py
  11. 5 3
      celery/tests/test_conf.py
  12. 14 14
      celery/tests/test_datastructures.py
  13. 4 2
      celery/tests/test_discovery.py
  14. 1 1
      celery/tests/test_events.py
  15. 4 4
      celery/tests/test_loaders.py
  16. 12 17
      celery/tests/test_log.py
  17. 2 1
      celery/tests/test_messaging.py
  18. 17 17
      celery/tests/test_models.py
  19. 12 12
      celery/tests/test_pickle.py
  20. 20 31
      celery/tests/test_pool.py
  21. 16 16
      celery/tests/test_registry.py
  22. 6 6
      celery/tests/test_result.py
  23. 2 2
      celery/tests/test_serialization.py
  24. 19 19
      celery/tests/test_task.py
  25. 1 1
      celery/tests/test_task_builtins.py
  26. 6 6
      celery/tests/test_task_control.py
  27. 10 10
      celery/tests/test_task_http.py
  28. 6 6
      celery/tests/test_utils.py
  29. 1 1
      celery/tests/test_utils_info.py
  30. 17 17
      celery/tests/test_worker.py
  31. 3 3
      celery/tests/test_worker_control.py
  32. 2 2
      celery/tests/test_worker_controllers.py
  33. 6 6
      celery/tests/test_worker_heartbeat.py
  34. 17 17
      celery/tests/test_worker_job.py
  35. 3 3
      celery/tests/test_worker_revoke.py
  36. 3 3
      celery/tests/test_worker_scheduler.py
  37. 1 0
      contrib/requirements/test.txt

+ 11 - 13
celery/tests/test_backends/__init__.py

@@ -1,27 +1,25 @@
-import unittest
-
+import unittest2 as unittest
 
 
+from celery import backends
 from celery.backends.database import DatabaseBackend
 from celery.backends.database import DatabaseBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.pyredis import RedisBackend
 from celery.backends.pyredis import RedisBackend
-from celery import backends
 
 
 
 
 class TestBackends(unittest.TestCase):
 class TestBackends(unittest.TestCase):
 
 
     def test_get_backend_aliases(self):
     def test_get_backend_aliases(self):
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("amqp"), AMQPBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("database"), DatabaseBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("db"), DatabaseBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("redis"), RedisBackend))
+        expects = [("amqp", AMQPBackend),
+                   ("database", DatabaseBackend),
+                   ("db", DatabaseBackend),
+                   ("redis", RedisBackend)]
+        for expect_name, expect_cls in expects:
+            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
+                                  expect_cls)
 
 
     def test_get_backend_cahe(self):
     def test_get_backend_cahe(self):
         backends._backend_cache = {}
         backends._backend_cache = {}
         backends.get_backend_cls("amqp")
         backends.get_backend_cls("amqp")
-        self.assertTrue("amqp" in backends._backend_cache)
+        self.assertIn("amqp", backends._backend_cache)
         amqp_backend = backends.get_backend_cls("amqp")
         amqp_backend = backends.get_backend_cls("amqp")
-        self.assertTrue(amqp_backend is backends._backend_cache["amqp"])
+        self.assertIs(amqp_backend, backends._backend_cache["amqp"])

+ 2 - 2
celery/tests/test_backends/test_amqp.py

@@ -1,6 +1,6 @@
 import sys
 import sys
 import errno
 import errno
-import unittest
+import unittest2 as unittest
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
@@ -56,7 +56,7 @@ class TestRedisBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(tb.is_successful(tid3))
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
         self.assertEqual(tb.get_traceback(tid3), einfo.traceback)
         self.assertEqual(tb.get_traceback(tid3), einfo.traceback)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):

+ 10 - 10
celery/tests/test_backends/test_base.py

@@ -1,6 +1,6 @@
 import sys
 import sys
 import types
 import types
-import unittest
+import unittest2 as unittest
 
 
 from django.db.models.base import subclass_exception
 from django.db.models.base import subclass_exception
 from billiard.serialization import find_nearest_pickleable_exception as fnpe
 from billiard.serialization import find_nearest_pickleable_exception as fnpe
@@ -54,31 +54,31 @@ class TestBaseBackendInterface(unittest.TestCase):
 class TestPickleException(unittest.TestCase):
 class TestPickleException(unittest.TestCase):
 
 
     def test_oldstyle(self):
     def test_oldstyle(self):
-        self.assertTrue(fnpe(Oldstyle()) is None)
+        self.assertIsNone(fnpe(Oldstyle()))
 
 
     def test_BaseException(self):
     def test_BaseException(self):
-        self.assertTrue(fnpe(Exception()) is None)
+        self.assertIsNone(fnpe(Exception()))
 
 
     def test_get_pickleable_exception(self):
     def test_get_pickleable_exception(self):
         exc = Exception("foo")
         exc = Exception("foo")
         self.assertEqual(gpe(exc), exc)
         self.assertEqual(gpe(exc), exc)
 
 
     def test_unpickleable(self):
     def test_unpickleable(self):
-        self.assertTrue(isinstance(fnpe(Unpickleable()), KeyError))
-        self.assertEqual(fnpe(Impossible()), None)
+        self.assertIsInstance(fnpe(Unpickleable()), KeyError)
+        self.assertIsNone(fnpe(Impossible()))
 
 
 
 
 class TestPrepareException(unittest.TestCase):
 class TestPrepareException(unittest.TestCase):
 
 
     def test_unpickleable(self):
     def test_unpickleable(self):
         x = b.prepare_exception(Unpickleable(1, 2, "foo"))
         x = b.prepare_exception(Unpickleable(1, 2, "foo"))
-        self.assertTrue(isinstance(x, KeyError))
+        self.assertIsInstance(x, KeyError)
         y = b.exception_to_python(x)
         y = b.exception_to_python(x)
-        self.assertTrue(isinstance(y, KeyError))
+        self.assertIsInstance(y, KeyError)
 
 
     def test_impossible(self):
     def test_impossible(self):
         x = b.prepare_exception(Impossible())
         x = b.prepare_exception(Impossible())
-        self.assertTrue(isinstance(x, UnpickleableExceptionWrapper))
+        self.assertIsInstance(x, UnpickleableExceptionWrapper)
         y = b.exception_to_python(x)
         y = b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, "Impossible")
         self.assertEqual(y.__class__.__name__, "Impossible")
         if sys.version_info < (2, 5):
         if sys.version_info < (2, 5):
@@ -88,9 +88,9 @@ class TestPrepareException(unittest.TestCase):
 
 
     def test_regular(self):
     def test_regular(self):
         x = b.prepare_exception(KeyError("baz"))
         x = b.prepare_exception(KeyError("baz"))
-        self.assertTrue(isinstance(x, KeyError))
+        self.assertIsInstance(x, KeyError)
         y = b.exception_to_python(x)
         y = b.exception_to_python(x)
-        self.assertTrue(isinstance(y, KeyError))
+        self.assertIsInstance(y, KeyError)
 
 
 
 
 class TestKeyValueStoreBackendInterface(unittest.TestCase):
 class TestKeyValueStoreBackendInterface(unittest.TestCase):

+ 6 - 7
celery/tests/test_backends/test_cache.py

@@ -1,5 +1,5 @@
 import sys
 import sys
-import unittest
+import unittest2 as unittest
 
 
 from billiard.serialization import pickle
 from billiard.serialization import pickle
 from django.core.cache.backends.base import InvalidCacheBackendError
 from django.core.cache.backends.base import InvalidCacheBackendError
@@ -26,7 +26,7 @@ class TestCacheBackend(unittest.TestCase):
 
 
         self.assertFalse(cb.is_successful(tid))
         self.assertFalse(cb.is_successful(tid))
         self.assertEqual(cb.get_status(tid), states.PENDING)
         self.assertEqual(cb.get_status(tid), states.PENDING)
-        self.assertEqual(cb.get_result(tid), None)
+        self.assertIsNone(cb.get_result(tid))
 
 
         cb.mark_as_done(tid, 42)
         cb.mark_as_done(tid, 42)
         self.assertTrue(cb.is_successful(tid))
         self.assertTrue(cb.is_successful(tid))
@@ -42,7 +42,7 @@ class TestCacheBackend(unittest.TestCase):
         res = result.TaskSetResult(taskset_id, subtasks)
         res = result.TaskSetResult(taskset_id, subtasks)
         res.save(backend=backend)
         res.save(backend=backend)
         saved = result.TaskSetResult.restore(taskset_id, backend=backend)
         saved = result.TaskSetResult.restore(taskset_id, backend=backend)
-        self.assertEqual(saved.subtasks, subtasks)
+        self.assertListEqual(saved.subtasks, subtasks)
         self.assertEqual(saved.taskset_id, taskset_id)
         self.assertEqual(saved.taskset_id, taskset_id)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
@@ -69,12 +69,11 @@ class TestCacheBackend(unittest.TestCase):
         cb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         cb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(cb.is_successful(tid3))
         self.assertFalse(cb.is_successful(tid3))
         self.assertEqual(cb.get_status(tid3), states.FAILURE)
         self.assertEqual(cb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(cb.get_result(tid3), KeyError))
+        self.assertIsInstance(cb.get_result(tid3), KeyError)
         self.assertEqual(cb.get_traceback(tid3), einfo.traceback)
         self.assertEqual(cb.get_traceback(tid3), einfo.traceback)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         cb = CacheBackend()
         cb = CacheBackend()
-
         cb.process_cleanup()
         cb.process_cleanup()
 
 
 
 
@@ -91,7 +90,7 @@ class TestCustomCacheBackend(unittest.TestCase):
             from django.core.cache import cache as django_cache
             from django.core.cache import cache as django_cache
             self.assertEqual(cache.__class__.__module__,
             self.assertEqual(cache.__class__.__module__,
                               "django.core.cache.backends.dummy")
                               "django.core.cache.backends.dummy")
-            self.assertTrue(cache is not django_cache)
+            self.assertIsNot(cache, django_cache)
         finally:
         finally:
             conf.CELERY_CACHE_BACKEND = prev_backend
             conf.CELERY_CACHE_BACKEND = prev_backend
             sys.modules["celery.backends.cache"] = prev_module
             sys.modules["celery.backends.cache"] = prev_module
@@ -113,7 +112,7 @@ class TestMemcacheWrapper(unittest.TestCase):
         prev_backend_module = sys.modules.pop("celery.backends.cache")
         prev_backend_module = sys.modules.pop("celery.backends.cache")
         try:
         try:
             from celery.backends.cache import cache, DjangoMemcacheWrapper
             from celery.backends.cache import cache, DjangoMemcacheWrapper
-            self.assertTrue(isinstance(cache, DjangoMemcacheWrapper))
+            self.assertIsInstance(cache, DjangoMemcacheWrapper)
 
 
             key = "cu.test_memcache_wrapper"
             key = "cu.test_memcache_wrapper"
             val = "The quick brown fox."
             val = "The quick brown fox."

+ 5 - 6
celery/tests/test_backends/test_database.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from datetime import timedelta
 from datetime import timedelta
 
 
 from celery import states
 from celery import states
@@ -29,13 +29,12 @@ class TestDatabaseBackend(unittest.TestCase):
 
 
         self.assertFalse(b.is_successful(tid))
         self.assertFalse(b.is_successful(tid))
         self.assertEqual(b.get_status(tid), states.PENDING)
         self.assertEqual(b.get_status(tid), states.PENDING)
-        self.assertTrue(b.get_result(tid) is None)
+        self.assertIsNone(b.get_result(tid))
 
 
         b.mark_as_done(tid, 42)
         b.mark_as_done(tid, 42)
         self.assertTrue(b.is_successful(tid))
         self.assertTrue(b.is_successful(tid))
         self.assertEqual(b.get_status(tid), states.SUCCESS)
         self.assertEqual(b.get_status(tid), states.SUCCESS)
         self.assertEqual(b.get_result(tid), 42)
         self.assertEqual(b.get_result(tid), 42)
-        self.assertTrue(b.get_result(tid), 42)
 
 
         tid2 = gen_unique_id()
         tid2 = gen_unique_id()
         result = {"foo": "baz", "bar": SomeClass(12345)}
         result = {"foo": "baz", "bar": SomeClass(12345)}
@@ -53,17 +52,17 @@ class TestDatabaseBackend(unittest.TestCase):
         b.mark_as_failure(tid3, exception)
         b.mark_as_failure(tid3, exception)
         self.assertFalse(b.is_successful(tid3))
         self.assertFalse(b.is_successful(tid3))
         self.assertEqual(b.get_status(tid3), states.FAILURE)
         self.assertEqual(b.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(b.get_result(tid3), KeyError))
+        self.assertIsInstance(b.get_result(tid3), KeyError)
 
 
     def test_taskset_store(self):
     def test_taskset_store(self):
         b = DatabaseBackend()
         b = DatabaseBackend()
         tid = gen_unique_id()
         tid = gen_unique_id()
 
 
-        self.assertTrue(b.restore_taskset(tid) is None)
+        self.assertIsNone(b.restore_taskset(tid))
 
 
         result = {"foo": "baz", "bar": SomeClass(12345)}
         result = {"foo": "baz", "bar": SomeClass(12345)}
         b.save_taskset(tid, result)
         b.save_taskset(tid, result)
         rindb = b.restore_taskset(tid)
         rindb = b.restore_taskset(tid)
-        self.assertTrue(rindb is not None)
+        self.assertIsNotNone(rindb)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)
         self.assertEqual(rindb.get("bar").data, 12345)

+ 11 - 12
celery/tests/test_backends/test_redis.py

@@ -1,7 +1,7 @@
 import sys
 import sys
 import errno
 import errno
 import socket
 import socket
-import unittest
+import unittest2 as unittest
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
@@ -60,11 +60,11 @@ class TestRedisBackend(unittest.TestCase):
         if not tb:
         if not tb:
             return # Skip test
             return # Skip test
 
 
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.open()
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
         tb = get_redis_or_None()
         tb = get_redis_or_None()
@@ -75,13 +75,12 @@ class TestRedisBackend(unittest.TestCase):
 
 
         self.assertFalse(tb.is_successful(tid))
         self.assertFalse(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.PENDING)
         self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertEqual(tb.get_result(tid), None)
+        self.assertIsNone(tb.get_result(tid))
 
 
         tb.mark_as_done(tid, 42)
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
         self.assertTrue(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_result(tid), 42)
         self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb.get_result(tid), 42)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
         tb = get_redis_or_None()
         tb = get_redis_or_None()
@@ -109,7 +108,7 @@ class TestRedisBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception)
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         tb = get_redis_or_None()
         tb = get_redis_or_None()
@@ -118,7 +117,7 @@ class TestRedisBackend(unittest.TestCase):
 
 
         tb.process_cleanup()
         tb.process_cleanup()
 
 
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
 
 
     def test_connection_close_if_connected(self):
     def test_connection_close_if_connected(self):
         tb = get_redis_or_None()
         tb = get_redis_or_None()
@@ -126,11 +125,11 @@ class TestRedisBackend(unittest.TestCase):
             return
             return
 
 
         tb.open()
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.close()
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
 
 
 
 
 class TestTyrantBackendNoTyrant(unittest.TestCase):
 class TestTyrantBackendNoTyrant(unittest.TestCase):
@@ -140,7 +139,7 @@ class TestTyrantBackendNoTyrant(unittest.TestCase):
         try:
         try:
             def with_redis_masked(_val):
             def with_redis_masked(_val):
                 from celery.backends.pyredis import redis
                 from celery.backends.pyredis import redis
-                self.assertTrue(redis is None)
+                self.assertIsNone(redis)
             context = mask_modules("redis")
             context = mask_modules("redis")
             execute_context(context, with_redis_masked)
             execute_context(context, with_redis_masked)
         finally:
         finally:

+ 7 - 8
celery/tests/test_backends/test_tyrant.py

@@ -1,7 +1,7 @@
 import sys
 import sys
 import errno
 import errno
 import socket
 import socket
-import unittest
+import unittest2 as unittest
 
 
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
@@ -50,11 +50,11 @@ class TestTyrantBackend(unittest.TestCase):
         if not tb:
         if not tb:
             return # Skip test
             return # Skip test
 
 
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.open()
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNone(tb._connection)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
         tb = get_tyrant_or_None()
         tb = get_tyrant_or_None()
@@ -65,13 +65,12 @@ class TestTyrantBackend(unittest.TestCase):
 
 
         self.assertFalse(tb.is_successful(tid))
         self.assertFalse(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.PENDING)
         self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertEqual(tb.get_result(tid), None)
+        self.assertIsNone(tb.get_result(tid), None)
 
 
         tb.mark_as_done(tid, 42)
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
         self.assertTrue(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_result(tid), 42)
         self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb.get_result(tid), 42)
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
         tb = get_tyrant_or_None()
         tb = get_tyrant_or_None()
@@ -99,7 +98,7 @@ class TestTyrantBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception)
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         tb = get_tyrant_or_None()
         tb = get_tyrant_or_None()
@@ -108,4 +107,4 @@ class TestTyrantBackend(unittest.TestCase):
 
 
         tb.process_cleanup()
         tb.process_cleanup()
 
 
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)

+ 18 - 18
celery/tests/test_beat.py

@@ -1,5 +1,5 @@
-import unittest
 import logging
 import logging
+import unittest2 as unittest
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
 from celery import log
 from celery import log
@@ -87,7 +87,7 @@ class TestScheduleEntry(unittest.TestCase):
     def test_constructor(self):
     def test_constructor(self):
         s = beat.ScheduleEntry(DuePeriodicTask.name)
         s = beat.ScheduleEntry(DuePeriodicTask.name)
         self.assertEqual(s.name, DuePeriodicTask.name)
         self.assertEqual(s.name, DuePeriodicTask.name)
-        self.assertTrue(isinstance(s.last_run_at, datetime))
+        self.assertIsInstance(s.last_run_at, datetime)
         self.assertEqual(s.total_run_count, 0)
         self.assertEqual(s.total_run_count, 0)
 
 
         now = datetime.now()
         now = datetime.now()
@@ -101,7 +101,7 @@ class TestScheduleEntry(unittest.TestCase):
         n = s.next()
         n = s.next()
         self.assertEqual(n.name, s.name)
         self.assertEqual(n.name, s.name)
         self.assertEqual(n.total_run_count, 301)
         self.assertEqual(n.total_run_count, 301)
-        self.assertTrue(n.last_run_at > s.last_run_at)
+        self.assertGreater(n.last_run_at, s.last_run_at)
 
 
     def test_is_due(self):
     def test_is_due(self):
         due = beat.ScheduleEntry(DuePeriodicTask.name)
         due = beat.ScheduleEntry(DuePeriodicTask.name)
@@ -123,20 +123,20 @@ class TestScheduler(unittest.TestCase):
 
 
     def test_constructor(self):
     def test_constructor(self):
         s = beat.Scheduler()
         s = beat.Scheduler()
-        self.assertTrue(isinstance(s.registry, TaskRegistry))
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.logger, logging.Logger))
+        self.assertIsInstance(s.registry, TaskRegistry)
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.logger, logging.Logger)
         self.assertEqual(s.max_interval, conf.CELERYBEAT_MAX_LOOP_INTERVAL)
         self.assertEqual(s.max_interval, conf.CELERYBEAT_MAX_LOOP_INTERVAL)
 
 
     def test_cleanup(self):
     def test_cleanup(self):
         self.scheduler.schedule["fbz"] = beat.ScheduleEntry("fbz")
         self.scheduler.schedule["fbz"] = beat.ScheduleEntry("fbz")
         self.scheduler.cleanup()
         self.scheduler.cleanup()
-        self.assertTrue("fbz" not in self.scheduler.schedule)
+        self.assertNotIn("fbz", self.scheduler.schedule)
 
 
     def test_schedule_registry(self):
     def test_schedule_registry(self):
         self.registry.register(AdditionalTask)
         self.registry.register(AdditionalTask)
         self.scheduler.schedule_registry()
         self.scheduler.schedule_registry()
-        self.assertTrue(AdditionalTask.name in self.scheduler.schedule)
+        self.assertIn(AdditionalTask.name, self.scheduler.schedule)
 
 
     def test_apply_async(self):
     def test_apply_async(self):
         due_task = self.registry[DuePeriodicTask.name]
         due_task = self.registry[DuePeriodicTask.name]
@@ -178,13 +178,13 @@ class TestClockService(unittest.TestCase):
         sh = MockShelve()
         sh = MockShelve()
         s.open_schedule = lambda *a, **kw: sh
         s.open_schedule = lambda *a, **kw: sh
 
 
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
-        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.scheduler, beat.Scheduler)
+        self.assertIsInstance(s.scheduler, beat.Scheduler)
 
 
-        self.assertTrue(s.schedule is sh)
-        self.assertTrue(s._schedule is sh)
+        self.assertIs(s.schedule, sh)
+        self.assertIs(s._schedule, sh)
 
 
         s._in_sync = False
         s._in_sync = False
         s.sync()
         s.sync()
@@ -204,8 +204,8 @@ class TestEmbeddedClockService(unittest.TestCase):
     def test_start_stop_process(self):
     def test_start_stop_process(self):
         s = beat.EmbeddedClockService()
         s = beat.EmbeddedClockService()
         from multiprocessing import Process
         from multiprocessing import Process
-        self.assertTrue(isinstance(s, Process))
-        self.assertTrue(isinstance(s.clockservice, beat.ClockService))
+        self.assertIsInstance(s, Process)
+        self.assertIsInstance(s.clockservice, beat.ClockService)
         s.clockservice = MockClockService()
         s.clockservice = MockClockService()
 
 
         class _Popen(object):
         class _Popen(object):
@@ -225,8 +225,8 @@ class TestEmbeddedClockService(unittest.TestCase):
     def test_start_stop_threaded(self):
     def test_start_stop_threaded(self):
         s = beat.EmbeddedClockService(thread=True)
         s = beat.EmbeddedClockService(thread=True)
         from threading import Thread
         from threading import Thread
-        self.assertTrue(isinstance(s, Thread))
-        self.assertTrue(isinstance(s.clockservice, beat.ClockService))
+        self.assertIsInstance(s, Thread)
+        self.assertIsInstance(s.clockservice, beat.ClockService)
         s.clockservice = MockClockService()
         s.clockservice = MockClockService()
 
 
         s.run()
         s.run()

+ 1 - 1
celery/tests/test_bin_celeryd.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery.bin import celeryd
 from celery.bin import celeryd
 
 

+ 8 - 8
celery/tests/test_buckets.py

@@ -3,7 +3,7 @@ import os
 import sys
 import sys
 sys.path.insert(0, os.getcwd())
 sys.path.insert(0, os.getcwd())
 import time
 import time
-import unittest
+import unittest2 as unittest
 from itertools import chain, izip
 from itertools import chain, izip
 
 
 from billiard.utils.functional import curry
 from billiard.utils.functional import curry
@@ -66,7 +66,7 @@ class TestTokenBucketQueue(unittest.TestCase):
         for i in xrange(20):
         for i in xrange(20):
             sys.stderr.write("x")
             sys.stderr.write("x")
             x.wait()
             x.wait()
-        self.assertTrue(time.time() - time_start > 1.5)
+        self.assertGreater(time.time() - time_start, 1.5)
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_can_consume(self):
     def test_can_consume(self):
@@ -90,7 +90,7 @@ class TestTokenBucketQueue(unittest.TestCase):
         x = buckets.TokenBucketQueue(fill_rate=1)
         x = buckets.TokenBucketQueue(fill_rate=1)
         x.put("The quick brown fox")
         x.put("The quick brown fox")
         self.assertEqual(x.qsize(), 1)
         self.assertEqual(x.qsize(), 1)
-        self.assertTrue(x.get_nowait(), "The quick brown fox")
+        self.assertEqual(x.get_nowait(), "The quick brown fox")
 
 
 
 
 class TestRateLimitString(unittest.TestCase):
 class TestRateLimitString(unittest.TestCase):
@@ -136,17 +136,17 @@ class TestTaskBuckets(unittest.TestCase):
     def test_auto_add_on_missing(self):
     def test_auto_add_on_missing(self):
         b = buckets.TaskBucket(task_registry=self.registry)
         b = buckets.TaskBucket(task_registry=self.registry)
         for task_cls in self.task_classes:
         for task_cls in self.task_classes:
-            self.assertTrue(task_cls.name in b.buckets.keys())
+            self.assertIn(task_cls.name, b.buckets.keys())
         self.registry.register(TaskD)
         self.registry.register(TaskD)
         self.assertTrue(b.get_bucket_for_type(TaskD.name))
         self.assertTrue(b.get_bucket_for_type(TaskD.name))
-        self.assertTrue(TaskD.name in b.buckets.keys())
+        self.assertIn(TaskD.name, b.buckets.keys())
         self.registry.unregister(TaskD)
         self.registry.unregister(TaskD)
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_has_rate_limits(self):
     def test_has_rate_limits(self):
         b = buckets.TaskBucket(task_registry=self.registry)
         b = buckets.TaskBucket(task_registry=self.registry)
         self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
         self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
-        self.assertTrue(isinstance(b.buckets[TaskB.name], buckets.Queue))
+        self.assertIsInstance(b.buckets[TaskB.name], buckets.Queue)
         self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
         self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
         self.registry.register(TaskD)
         self.registry.register(TaskD)
         b.init_with_registry()
         b.init_with_registry()
@@ -183,7 +183,7 @@ class TestTaskBuckets(unittest.TestCase):
         for i, job in enumerate(jobs):
         for i, job in enumerate(jobs):
             sys.stderr.write("i")
             sys.stderr.write("i")
             self.assertEqual(b.get(), job)
             self.assertEqual(b.get(), job)
-        self.assertTrue(time.time() - time_start > 1.5)
+        self.assertGreater(time.time() - time_start, 1.5)
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test__very_busy_queue_doesnt_block_others(self):
     def test__very_busy_queue_doesnt_block_others(self):
@@ -200,7 +200,7 @@ class TestTaskBuckets(unittest.TestCase):
             if job.task_name == TaskA.name:
             if job.task_name == TaskA.name:
                 got_ajobs += 1
                 got_ajobs += 1
 
 
-        self.assertTrue(got_ajobs > 2)
+        self.assertGreater(got_ajobs, 2)
 
 
     @skip_if_disabled
     @skip_if_disabled
     def test_thorough__multiple_types(self):
     def test_thorough__multiple_types(self):

+ 5 - 3
celery/tests/test_conf.py

@@ -1,7 +1,9 @@
-import unittest
-from celery import conf
+import unittest2 as unittest
+
 from django.conf import settings
 from django.conf import settings
 
 
+from celery import conf
+
 
 
 SETTING_VARS = (
 SETTING_VARS = (
     ("CELERY_DEFAULT_QUEUE", "DEFAULT_QUEUE"),
     ("CELERY_DEFAULT_QUEUE", "DEFAULT_QUEUE"),
@@ -31,4 +33,4 @@ class TestConf(unittest.TestCase):
     def test_configuration_cls(self):
     def test_configuration_cls(self):
         for setting_name, result_var in SETTING_VARS:
         for setting_name, result_var in SETTING_VARS:
             self.assertDefaultSetting(setting_name, result_var)
             self.assertDefaultSetting(setting_name, result_var)
-        self.assertTrue(isinstance(conf.CELERYD_LOG_LEVEL, int))
+        self.assertIsInstance(conf.CELERYD_LOG_LEVEL, int)

+ 14 - 14
celery/tests/test_datastructures.py

@@ -1,5 +1,5 @@
 import sys
 import sys
-import unittest
+import unittest2 as unittest
 from Queue import Queue
 from Queue import Queue
 
 
 from celery.datastructures import PositionQueue, ExceptionInfo, LocalCache
 from celery.datastructures import PositionQueue, ExceptionInfo, LocalCache
@@ -11,9 +11,9 @@ class TestPositionQueue(unittest.TestCase):
     def test_position_queue_unfilled(self):
     def test_position_queue_unfilled(self):
         q = PositionQueue(length=10)
         q = PositionQueue(length=10)
         for position in q.data:
         for position in q.data:
-            self.assertTrue(isinstance(position, q.UnfilledPosition))
+            self.assertIsInstance(position, q.UnfilledPosition)
 
 
-        self.assertEqual(q.filled, [])
+        self.assertListEqual(q.filled, [])
         self.assertEqual(len(q), 0)
         self.assertEqual(len(q), 0)
         self.assertFalse(q.full())
         self.assertFalse(q.full())
 
 
@@ -23,7 +23,7 @@ class TestPositionQueue(unittest.TestCase):
         q[6] = 6
         q[6] = 6
         q[9] = 9
         q[9] = 9
 
 
-        self.assertEqual(q.filled, [3, 6, 9])
+        self.assertListEqual(q.filled, [3, 6, 9])
         self.assertEqual(len(q), 3)
         self.assertEqual(len(q), 3)
         self.assertFalse(q.full())
         self.assertFalse(q.full())
 
 
@@ -31,7 +31,7 @@ class TestPositionQueue(unittest.TestCase):
         q = PositionQueue(length=10)
         q = PositionQueue(length=10)
         for i in xrange(10):
         for i in xrange(10):
             q[i] = i
             q[i] = i
-        self.assertEqual(q.filled, list(xrange(10)))
+        self.assertListEqual(q.filled, list(xrange(10)))
         self.assertEqual(len(q), 10)
         self.assertEqual(len(q), 10)
         self.assertTrue(q.full())
         self.assertTrue(q.full())
 
 
@@ -47,8 +47,8 @@ class TestExceptionInfo(unittest.TestCase):
 
 
         einfo = ExceptionInfo(exc_info)
         einfo = ExceptionInfo(exc_info)
         self.assertEqual(str(einfo), einfo.traceback)
         self.assertEqual(str(einfo), einfo.traceback)
-        self.assertTrue(isinstance(einfo.exception, LookupError))
-        self.assertEqual(einfo.exception.args,
+        self.assertIsInstance(einfo.exception, LookupError)
+        self.assertTupleEqual(einfo.exception.args,
                 ("The quick brown fox jumps...", ))
                 ("The quick brown fox jumps...", ))
         self.assertTrue(einfo.traceback)
         self.assertTrue(einfo.traceback)
 
 
@@ -98,7 +98,7 @@ class TestSharedCounter(unittest.TestCase):
         self.assertEqual(int(c), -10)
         self.assertEqual(int(c), -10)
 
 
     def test_repr(self):
     def test_repr(self):
-        self.assertTrue(repr(SharedCounter(10)).startswith("<SharedCounter:"))
+        self.assertIn("<SharedCounter:", repr(SharedCounter(10)))
 
 
 
 
 class TestLimitedSet(unittest.TestCase):
 class TestLimitedSet(unittest.TestCase):
@@ -108,11 +108,11 @@ class TestLimitedSet(unittest.TestCase):
         s.add("foo")
         s.add("foo")
         s.add("bar")
         s.add("bar")
         for n in "foo", "bar":
         for n in "foo", "bar":
-            self.assertTrue(n in s)
+            self.assertIn(n, s)
         s.add("baz")
         s.add("baz")
         for n in "bar", "baz":
         for n in "bar", "baz":
-            self.assertTrue(n in s)
-        self.assertTrue("foo" not in s)
+            self.assertIn(n, s)
+        self.assertNotIn("foo", s)
 
 
     def test_iter(self):
     def test_iter(self):
         s = LimitedSet(maxlen=2)
         s = LimitedSet(maxlen=2)
@@ -120,13 +120,13 @@ class TestLimitedSet(unittest.TestCase):
         map(s.add, items)
         map(s.add, items)
         l = list(iter(items))
         l = list(iter(items))
         for item in items:
         for item in items:
-            self.assertTrue(item in l)
+            self.assertIn(item, l)
 
 
     def test_repr(self):
     def test_repr(self):
         s = LimitedSet(maxlen=2)
         s = LimitedSet(maxlen=2)
         items = "foo", "bar"
         items = "foo", "bar"
         map(s.add, items)
         map(s.add, items)
-        self.assertTrue(repr(s).startswith("LimitedSet("))
+        self.assertIn("LimitedSet(", repr(s))
 
 
 
 
 class TestLocalCache(unittest.TestCase):
 class TestLocalCache(unittest.TestCase):
@@ -137,4 +137,4 @@ class TestLocalCache(unittest.TestCase):
         slots = list(range(limit * 2))
         slots = list(range(limit * 2))
         for i in slots:
         for i in slots:
             x[i] = i
             x[i] = i
-        self.assertEqual(x.keys(), slots[limit:])
+        self.assertListEqual(x.keys(), slots[limit:])

+ 4 - 2
celery/tests/test_discovery.py

@@ -1,5 +1,7 @@
-import unittest
+import unittest2 as unittest
+
 from django.conf import settings
 from django.conf import settings
+
 from celery.loaders.djangoapp import autodiscover
 from celery.loaders.djangoapp import autodiscover
 from celery.task import tasks
 from celery.task import tasks
 
 
@@ -9,7 +11,7 @@ class TestDiscovery(unittest.TestCase):
     def assertDiscovery(self):
     def assertDiscovery(self):
         apps = autodiscover()
         apps = autodiscover()
         self.assertTrue(apps)
         self.assertTrue(apps)
-        self.assertTrue("c.unittest.SomeAppTask" in tasks)
+        self.assertIn("c.unittest.SomeAppTask", tasks)
         self.assertEqual(tasks["c.unittest.SomeAppTask"].run(), 42)
         self.assertEqual(tasks["c.unittest.SomeAppTask"].run(), 42)
 
 
     def test_discovery(self):
     def test_discovery(self):

+ 1 - 1
celery/tests/test_events.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery import events
 from celery import events
 
 

+ 4 - 4
celery/tests/test_loaders.py

@@ -1,6 +1,6 @@
 import os
 import os
 import sys
 import sys
-import unittest
+import unittest2 as unittest
 
 
 from billiard.utils.functional import wraps
 from billiard.utils.functional import wraps
 
 
@@ -64,7 +64,7 @@ class TestLoaderBase(unittest.TestCase):
     def test_import_default_modules(self):
     def test_import_default_modules(self):
         import os
         import os
         import sys
         import sys
-        self.assertEqual(self.loader.import_default_modules(), [os, sys])
+        self.assertSameElements(self.loader.import_default_modules(), [os, sys])
 
 
 
 
 class TestDjangoLoader(unittest.TestCase):
 class TestDjangoLoader(unittest.TestCase):
@@ -117,11 +117,11 @@ class TestDefaultLoader(unittest.TestCase):
         try:
         try:
             l = default.Loader()
             l = default.Loader()
             settings = l.read_configuration()
             settings = l.read_configuration()
-            self.assertEqual(settings.CELERY_IMPORTS, ("os", "sys"))
+            self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             from django.conf import settings
             from django.conf import settings
             settings.configured = False
             settings.configured = False
             settings = l.read_configuration()
             settings = l.read_configuration()
-            self.assertEqual(settings.CELERY_IMPORTS, ("os", "sys"))
+            self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             self.assertTrue(settings.configured)
             self.assertTrue(settings.configured)
             l.on_worker_init()
             l.on_worker_init()
         finally:
         finally:

+ 12 - 17
celery/tests/test_log.py

@@ -3,7 +3,7 @@ from __future__ import generators
 import os
 import os
 import sys
 import sys
 import logging
 import logging
-import unittest
+import unittest2 as unittest
 from tempfile import mktemp
 from tempfile import mktemp
 from StringIO import StringIO
 from StringIO import StringIO
 
 
@@ -54,13 +54,8 @@ class TestLog(unittest.TestCase):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
         logger.handlers = [] # Reset previously set logger.
         logger.handlers = [] # Reset previously set logger.
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
-        self.assertTrue(logger.handlers[0].stream is sys.__stderr__,
+        self.assertIs(logger.handlers[0].stream, sys.__stderr__,
                 "setup_logger logs to stderr without logfile argument.")
                 "setup_logger logs to stderr without logfile argument.")
-        #self.assertTrue(logger._process_aware,
-        #        "setup_logger() returns process aware logger.")
-        #self.assertDidLogTrue(logger, "Logging something",
-        #        "Logger logs error when loglevel is ERROR",
-        #        loglevel=logging.ERROR)
         self.assertDidLogFalse(logger, "Logging something",
         self.assertDidLogFalse(logger, "Logging something",
                 "Logger doesn't info when loglevel is ERROR",
                 "Logger doesn't info when loglevel is ERROR",
                 loglevel=logging.INFO)
                 loglevel=logging.INFO)
@@ -80,7 +75,7 @@ class TestLog(unittest.TestCase):
             stdout, stderr = outs
             stdout, stderr = outs
             l = setup_logger(logfile=stderr, loglevel=logging.INFO)
             l = setup_logger(logfile=stderr, loglevel=logging.INFO)
             l.info("The quick brown fox...")
             l.info("The quick brown fox...")
-            self.assertTrue("The quick brown fox..." in stderr.getvalue())
+            self.assertIn("The quick brown fox...", stderr.getvalue())
 
 
         context = override_stdouts()
         context = override_stdouts()
         execute_context(context, with_override_stdouts)
         execute_context(context, with_override_stdouts)
@@ -91,14 +86,14 @@ class TestLog(unittest.TestCase):
         l.handlers = []
         l.handlers = []
         tempfile = mktemp(suffix="unittest", prefix="celery")
         tempfile = mktemp(suffix="unittest", prefix="celery")
         l = setup_logger(logfile=tempfile, loglevel=0)
         l = setup_logger(logfile=tempfile, loglevel=0)
-        self.assertTrue(isinstance(l.handlers[0], logging.FileHandler))
+        self.assertIsInstance(l.handlers[0], logging.FileHandler)
 
 
     def test_emergency_error_stderr(self):
     def test_emergency_error_stderr(self):
         def with_override_stdouts(outs):
         def with_override_stdouts(outs):
             stdout, stderr = outs
             stdout, stderr = outs
             emergency_error(None, "The lazy dog crawls under the fast fox")
             emergency_error(None, "The lazy dog crawls under the fast fox")
-            self.assertTrue("The lazy dog crawls under the fast fox" in
-                                stderr.getvalue())
+            self.assertIn("The lazy dog crawls under the fast fox",
+                          stderr.getvalue())
 
 
         context = override_stdouts()
         context = override_stdouts()
         execute_context(context, with_override_stdouts)
         execute_context(context, with_override_stdouts)
@@ -108,7 +103,7 @@ class TestLog(unittest.TestCase):
         emergency_error(tempfile, "Vandelay Industries")
         emergency_error(tempfile, "Vandelay Industries")
         tempfilefh = open(tempfile, "r")
         tempfilefh = open(tempfile, "r")
         try:
         try:
-            self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
+            self.assertIn("Vandelay Industries", "".join(tempfilefh))
         finally:
         finally:
             tempfilefh.close()
             tempfilefh.close()
             os.unlink(tempfile)
             os.unlink(tempfile)
@@ -119,7 +114,7 @@ class TestLog(unittest.TestCase):
             def with_wrap_logger(sio):
             def with_wrap_logger(sio):
                 redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
                 logger.error("foo")
-                self.assertTrue("foo" in sio.getvalue())
+                self.assertIn("foo", sio.getvalue())
 
 
             context = wrap_logger(logger)
             context = wrap_logger(logger)
             execute_context(context, with_wrap_logger)
             execute_context(context, with_wrap_logger)
@@ -133,18 +128,18 @@ class TestLog(unittest.TestCase):
             p = LoggingProxy(logger)
             p = LoggingProxy(logger)
             p.close()
             p.close()
             p.write("foo")
             p.write("foo")
-            self.assertTrue("foo" not in sio.getvalue())
+            self.assertNotIn("foo", sio.getvalue())
             p.closed = False
             p.closed = False
             p.write("foo")
             p.write("foo")
-            self.assertTrue("foo" in sio.getvalue())
+            self.assertIn("foo", sio.getvalue())
             lines = ["baz", "xuzzy"]
             lines = ["baz", "xuzzy"]
             p.writelines(lines)
             p.writelines(lines)
             for line in lines:
             for line in lines:
-                self.assertTrue(line in sio.getvalue())
+                self.assertIn(line, sio.getvalue())
             p.flush()
             p.flush()
             p.close()
             p.close()
             self.assertFalse(p.isatty())
             self.assertFalse(p.isatty())
-            self.assertTrue(p.fileno() is None)
+            self.assertIsNone(p.fileno())
 
 
         context = wrap_logger(logger)
         context = wrap_logger(logger)
         execute_context(context, with_wrap_logger)
         execute_context(context, with_wrap_logger)

+ 2 - 1
celery/tests/test_messaging.py

@@ -1,4 +1,5 @@
-import unittest
+import unittest2 as unittest
+
 from celery.messaging import MSG_OPTIONS, extract_msg_options
 from celery.messaging import MSG_OPTIONS, extract_msg_options
 
 
 
 

+ 17 - 17
celery/tests/test_models.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
 from celery import states
 from celery import states
@@ -24,30 +24,30 @@ class TestModels(unittest.TestCase):
         m3 = self.createTaskMeta()
         m3 = self.createTaskMeta()
         self.assertTrue(unicode(m1).startswith("<Task:"))
         self.assertTrue(unicode(m1).startswith("<Task:"))
         self.assertTrue(m1.task_id)
         self.assertTrue(m1.task_id)
-        self.assertTrue(isinstance(m1.date_done, datetime))
+        self.assertIsInstance(m1.date_done, datetime)
 
 
         self.assertEqual(TaskMeta.objects.get_task(m1.task_id).task_id,
         self.assertEqual(TaskMeta.objects.get_task(m1.task_id).task_id,
                 m1.task_id)
                 m1.task_id)
-        self.assertFalse(
-                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
+        self.assertNotEqual(TaskMeta.objects.get_task(m1.task_id).status,
+                            states.SUCCESS)
         TaskMeta.objects.store_result(m1.task_id, True, status=states.SUCCESS)
         TaskMeta.objects.store_result(m1.task_id, True, status=states.SUCCESS)
         TaskMeta.objects.store_result(m2.task_id, True, status=states.SUCCESS)
         TaskMeta.objects.store_result(m2.task_id, True, status=states.SUCCESS)
-        self.assertTrue(
-                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
-        self.assertTrue(
-                TaskMeta.objects.get_task(m2.task_id).status == states.SUCCESS)
+        self.assertEqual(TaskMeta.objects.get_task(m1.task_id).status,
+                         states.SUCCESS)
+        self.assertEqual(TaskMeta.objects.get_task(m2.task_id).status,
+                         states.SUCCESS)
 
 
         # Have to avoid save() because it applies the auto_now=True.
         # Have to avoid save() because it applies the auto_now=True.
         TaskMeta.objects.filter(task_id=m1.task_id).update(
         TaskMeta.objects.filter(task_id=m1.task_id).update(
                 date_done=datetime.now() - timedelta(days=10))
                 date_done=datetime.now() - timedelta(days=10))
 
 
         expired = TaskMeta.objects.get_all_expired()
         expired = TaskMeta.objects.get_all_expired()
-        self.assertTrue(m1 in expired)
-        self.assertFalse(m2 in expired)
-        self.assertFalse(m3 in expired)
+        self.assertIn(m1, expired)
+        self.assertNotIn(m2, expired)
+        self.assertNotIn(m3, expired)
 
 
         TaskMeta.objects.delete_expired()
         TaskMeta.objects.delete_expired()
-        self.assertFalse(m1 in TaskMeta.objects.all())
+        self.assertNotIn(m1, TaskMeta.objects.all())
 
 
     def test_tasksetmeta(self):
     def test_tasksetmeta(self):
         m1 = self.createTaskSetMeta()
         m1 = self.createTaskSetMeta()
@@ -55,7 +55,7 @@ class TestModels(unittest.TestCase):
         m3 = self.createTaskSetMeta()
         m3 = self.createTaskSetMeta()
         self.assertTrue(unicode(m1).startswith("<TaskSet:"))
         self.assertTrue(unicode(m1).startswith("<TaskSet:"))
         self.assertTrue(m1.taskset_id)
         self.assertTrue(m1.taskset_id)
-        self.assertTrue(isinstance(m1.date_done, datetime))
+        self.assertIsInstance(m1.date_done, datetime)
 
 
         self.assertEqual(
         self.assertEqual(
                 TaskSetMeta.objects.restore_taskset(m1.taskset_id).taskset_id,
                 TaskSetMeta.objects.restore_taskset(m1.taskset_id).taskset_id,
@@ -66,9 +66,9 @@ class TestModels(unittest.TestCase):
                 date_done=datetime.now() - timedelta(days=10))
                 date_done=datetime.now() - timedelta(days=10))
 
 
         expired = TaskSetMeta.objects.get_all_expired()
         expired = TaskSetMeta.objects.get_all_expired()
-        self.assertTrue(m1 in expired)
-        self.assertFalse(m2 in expired)
-        self.assertFalse(m3 in expired)
+        self.assertIn(m1, expired)
+        self.assertNotIn(m2, expired)
+        self.assertNotIn(m3, expired)
 
 
         TaskSetMeta.objects.delete_expired()
         TaskSetMeta.objects.delete_expired()
-        self.assertFalse(m1 in TaskSetMeta.objects.all())
+        self.assertNotIn(m1, TaskSetMeta.objects.all())

+ 12 - 12
celery/tests/test_pickle.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from billiard.serialization import pickle
 from billiard.serialization import pickle
 
 
@@ -17,33 +17,33 @@ class ArgOverrideException(Exception):
 class TestPickle(unittest.TestCase):
 class TestPickle(unittest.TestCase):
 
 
     def test_pickle_regular_exception(self):
     def test_pickle_regular_exception(self):
-        e = None
+        exc = None
         try:
         try:
             raise RegularException("RegularException raised")
             raise RegularException("RegularException raised")
-        except RegularException, e:
+        except RegularException, exc:
             pass
             pass
 
 
-        pickled = pickle.dumps({"exception": e})
+        pickled = pickle.dumps({"exception": exc})
         unpickled = pickle.loads(pickled)
         unpickled = pickle.loads(pickled)
         exception = unpickled.get("exception")
         exception = unpickled.get("exception")
         self.assertTrue(exception)
         self.assertTrue(exception)
-        self.assertTrue(isinstance(exception, RegularException))
-        self.assertEqual(exception.args, ("RegularException raised", ))
+        self.assertIsInstance(exception, RegularException)
+        self.assertTupleEqual(exception.args, ("RegularException raised", ))
 
 
     def test_pickle_arg_override_exception(self):
     def test_pickle_arg_override_exception(self):
 
 
-        e = None
+        exc = None
         try:
         try:
             raise ArgOverrideException("ArgOverrideException raised",
             raise ArgOverrideException("ArgOverrideException raised",
                     status_code=100)
                     status_code=100)
-        except ArgOverrideException, e:
+        except ArgOverrideException, exc:
             pass
             pass
 
 
-        pickled = pickle.dumps({"exception": e})
+        pickled = pickle.dumps({"exception": exc})
         unpickled = pickle.loads(pickled)
         unpickled = pickle.loads(pickled)
         exception = unpickled.get("exception")
         exception = unpickled.get("exception")
         self.assertTrue(exception)
         self.assertTrue(exception)
-        self.assertTrue(isinstance(exception, ArgOverrideException))
-        self.assertEqual(exception.args, ("ArgOverrideException raised",
-                                          100))
+        self.assertIsInstance(exception, ArgOverrideException)
+        self.assertTupleEqual(exception.args, (
+                              "ArgOverrideException raised", 100))
         self.assertEqual(exception.status_code, 100)
         self.assertEqual(exception.status_code, 100)

+ 20 - 31
celery/tests/test_pool.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 import logging
 import logging
 import itertools
 import itertools
 import time
 import time
@@ -27,15 +27,15 @@ class TestTaskPool(unittest.TestCase):
     def test_attrs(self):
     def test_attrs(self):
         p = TaskPool(limit=2)
         p = TaskPool(limit=2)
         self.assertEqual(p.limit, 2)
         self.assertEqual(p.limit, 2)
-        self.assertTrue(isinstance(p.logger, logging.Logger))
-        self.assertTrue(p._pool is None)
+        self.assertIsInstance(p.logger, logging.Logger)
+        self.assertIsNone(p._pool)
 
 
-    def x_start_stop(self):
+    def test_start_stop(self):
         p = TaskPool(limit=2)
         p = TaskPool(limit=2)
         p.start()
         p.start()
-        self.assertTrue(p._pool)
+        self.assertIsNotNone(p._pool)
         p.stop()
         p.stop()
-        self.assertTrue(p._pool is None)
+        self.assertIsNone(p._pool)
 
 
     def x_apply(self):
     def x_apply(self):
         p = TaskPool(limit=2)
         p = TaskPool(limit=2)
@@ -43,50 +43,39 @@ class TestTaskPool(unittest.TestCase):
         scratchpad = {}
         scratchpad = {}
         proc_counter = itertools.count().next
         proc_counter = itertools.count().next
 
 
-        def mycallback(ret_value, meta):
+        def mycallback(ret_value):
             process = proc_counter()
             process = proc_counter()
             scratchpad[process] = {}
             scratchpad[process] = {}
             scratchpad[process]["ret_value"] = ret_value
             scratchpad[process]["ret_value"] = ret_value
-            scratchpad[process]["meta"] = meta
 
 
         myerrback = mycallback
         myerrback = mycallback
 
 
-        res = p.apply_async(do_something, args=[10], callbacks=[mycallback],
-                            meta={"foo": "bar"})
-        res2 = p.apply_async(raise_something, args=[10], errbacks=[myerrback],
-                            meta={"foo2": "bar2"})
-        res3 = p.apply_async(do_something, args=[20], callbacks=[mycallback],
-                            meta={"foo3": "bar3"})
+        res = p.apply_async(do_something, args=[10], callbacks=[mycallback])
+        res2 = p.apply_async(raise_something, args=[10], errbacks=[myerrback])
+        res3 = p.apply_async(do_something, args=[20], callbacks=[mycallback])
 
 
         self.assertEqual(res.get(), 100)
         self.assertEqual(res.get(), 100)
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(0))
-        self.assertEqual(scratchpad[0]["ret_value"], 100)
-        self.assertEqual(scratchpad[0]["meta"], {"foo": "bar"})
+        self.assertDictContainsSubset({"ret_value": 100},
+                                       scratchpad.get(0))
 
 
-        self.assertTrue(isinstance(res2.get(), ExceptionInfo))
+        self.assertIsInstance(res2.get(), ExceptionInfo)
         self.assertTrue(scratchpad.get(1))
         self.assertTrue(scratchpad.get(1))
         time.sleep(1)
         time.sleep(1)
-        #self.assertEqual(scratchpad[1]["ret_value"], "FOO")
-        self.assertTrue(isinstance(scratchpad[1]["ret_value"],
-                          ExceptionInfo))
+        self.assertIsInstance(scratchpad[1]["ret_value"],
+                              ExceptionInfo)
         self.assertEqual(scratchpad[1]["ret_value"].exception.args,
         self.assertEqual(scratchpad[1]["ret_value"].exception.args,
                           ("FOO EXCEPTION", ))
                           ("FOO EXCEPTION", ))
-        self.assertEqual(scratchpad[1]["meta"], {"foo2": "bar2"})
 
 
         self.assertEqual(res3.get(), 400)
         self.assertEqual(res3.get(), 400)
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(2))
-        self.assertEqual(scratchpad[2]["ret_value"], 400)
-        self.assertEqual(scratchpad[2]["meta"], {"foo3": "bar3"})
+        self.assertDictContainsSubset({"ret_value": 400},
+                                       scratchpad.get(2))
 
 
-        res3 = p.apply_async(do_something, args=[30], callbacks=[mycallback],
-                            meta={"foo4": "bar4"})
+        res3 = p.apply_async(do_something, args=[30], callbacks=[mycallback])
 
 
         self.assertEqual(res3.get(), 900)
         self.assertEqual(res3.get(), 900)
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(3))
-        self.assertEqual(scratchpad[3]["ret_value"], 900)
-        self.assertEqual(scratchpad[3]["meta"], {"foo4": "bar4"})
-
+        self.assertDictContainsSubset({"ret_value": 900},
+                                       scratchpad.get(3))
         p.stop()
         p.stop()

+ 16 - 16
celery/tests/test_registry.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery import registry
 from celery import registry
 from celery.task import Task, PeriodicTask
 from celery.task import Task, PeriodicTask
@@ -24,42 +24,42 @@ class TestTaskRegistry(unittest.TestCase):
     def assertRegisterUnregisterCls(self, r, task):
     def assertRegisterUnregisterCls(self, r, task):
         self.assertRaises(r.NotRegistered, r.unregister, task)
         self.assertRaises(r.NotRegistered, r.unregister, task)
         r.register(task)
         r.register(task)
-        self.assertTrue(task.name in r)
+        self.assertIn(task.name, r)
 
 
     def assertRegisterUnregisterFunc(self, r, task, task_name):
     def assertRegisterUnregisterFunc(self, r, task, task_name):
         self.assertRaises(r.NotRegistered, r.unregister, task_name)
         self.assertRaises(r.NotRegistered, r.unregister, task_name)
         r.register(task, task_name)
         r.register(task, task_name)
-        self.assertTrue(task_name in r)
+        self.assertIn(task_name, r)
 
 
     def test_task_registry(self):
     def test_task_registry(self):
         r = registry.TaskRegistry()
         r = registry.TaskRegistry()
-        self.assertTrue(isinstance(r.data, dict),
+        self.assertIsInstance(r.data, dict,
                 "TaskRegistry has composited dict")
                 "TaskRegistry has composited dict")
 
 
         self.assertRegisterUnregisterCls(r, TestTask)
         self.assertRegisterUnregisterCls(r, TestTask)
         self.assertRegisterUnregisterCls(r, TestPeriodicTask)
         self.assertRegisterUnregisterCls(r, TestPeriodicTask)
 
 
         tasks = dict(r)
         tasks = dict(r)
-        self.assertTrue(isinstance(tasks.get(TestTask.name), TestTask))
-        self.assertTrue(isinstance(tasks.get(TestPeriodicTask.name),
-                                   TestPeriodicTask))
+        self.assertIsInstance(tasks.get(TestTask.name), TestTask)
+        self.assertIsInstance(tasks.get(TestPeriodicTask.name),
+                                   TestPeriodicTask)
 
 
         regular = r.regular()
         regular = r.regular()
-        self.assertTrue(TestTask.name in regular)
-        self.assertFalse(TestPeriodicTask.name in regular)
+        self.assertIn(TestTask.name, regular)
+        self.assertNotIn(TestPeriodicTask.name, regular)
 
 
         periodic = r.periodic()
         periodic = r.periodic()
-        self.assertFalse(TestTask.name in periodic)
-        self.assertTrue(TestPeriodicTask.name in periodic)
+        self.assertNotIn(TestTask.name, periodic)
+        self.assertIn(TestPeriodicTask.name, periodic)
 
 
-        self.assertTrue(isinstance(r[TestTask.name], TestTask))
-        self.assertTrue(isinstance(r[TestPeriodicTask.name],
-                                   TestPeriodicTask))
+        self.assertIsInstance(r[TestTask.name], TestTask)
+        self.assertIsInstance(r[TestPeriodicTask.name],
+                                   TestPeriodicTask)
 
 
         r.unregister(TestTask)
         r.unregister(TestTask)
-        self.assertFalse(TestTask.name in r)
+        self.assertNotIn(TestTask.name, r)
         r.unregister(TestPeriodicTask)
         r.unregister(TestPeriodicTask)
-        self.assertFalse(TestPeriodicTask.name in r)
+        self.assertNotIn(TestPeriodicTask.name, r)
 
 
         self.assertTrue(TestTask().run())
         self.assertTrue(TestTask().run())
         self.assertTrue(TestPeriodicTask().run())
         self.assertTrue(TestPeriodicTask().run())

+ 6 - 6
celery/tests/test_result.py

@@ -1,6 +1,6 @@
 from __future__ import generators
 from __future__ import generators
 
 
-import unittest
+import unittest2 as unittest
 
 
 from celery import states
 from celery import states
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
@@ -90,7 +90,7 @@ class TestAsyncResult(unittest.TestCase):
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok2_res.get(), "quick")
         self.assertEqual(ok2_res.get(), "quick")
         self.assertRaises(KeyError, nok_res.get)
         self.assertRaises(KeyError, nok_res.get)
-        self.assertTrue(isinstance(nok2_res.result, KeyError))
+        self.assertIsInstance(nok2_res.result, KeyError)
 
 
     def test_get_timeout(self):
     def test_get_timeout(self):
         res = AsyncResult(self.task4["id"]) # has RETRY status
         res = AsyncResult(self.task4["id"]) # has RETRY status
@@ -105,7 +105,7 @@ class TestAsyncResult(unittest.TestCase):
         oks = (AsyncResult(self.task1["id"]),
         oks = (AsyncResult(self.task1["id"]),
                AsyncResult(self.task2["id"]),
                AsyncResult(self.task2["id"]),
                AsyncResult(self.task3["id"]))
                AsyncResult(self.task3["id"]))
-        [self.assertTrue(ok.ready()) for ok in oks]
+        self.assertTrue(all(result.ready() for result in oks))
         self.assertFalse(AsyncResult(self.task4["id"]).ready())
         self.assertFalse(AsyncResult(self.task4["id"]).ready())
 
 
 
 
@@ -173,11 +173,11 @@ class TestTaskSetResult(unittest.TestCase):
         it = iter(self.ts)
         it = iter(self.ts)
 
 
         results = sorted(list(it))
         results = sorted(list(it))
-        self.assertEqual(results, list(xrange(self.size)))
+        self.assertListEqual(results, list(xrange(self.size)))
 
 
     def test_join(self):
     def test_join(self):
         joined = self.ts.join()
         joined = self.ts.join()
-        self.assertEqual(joined, list(xrange(self.size)))
+        self.assertListEqual(joined, list(xrange(self.size)))
 
 
     def test_successful(self):
     def test_successful(self):
         self.assertTrue(self.ts.successful())
         self.assertTrue(self.ts.successful())
@@ -201,7 +201,7 @@ class TestPendingAsyncResult(unittest.TestCase):
         self.task = AsyncResult(gen_unique_id())
         self.task = AsyncResult(gen_unique_id())
 
 
     def test_result(self):
     def test_result(self):
-        self.assertTrue(self.task.result is None)
+        self.assertIsNone(self.task.result)
 
 
 
 
 class TestFailedTaskSetResult(TestTaskSetResult):
 class TestFailedTaskSetResult(TestTaskSetResult):

+ 2 - 2
celery/tests/test_serialization.py

@@ -1,5 +1,5 @@
 import sys
 import sys
-import unittest
+import unittest2 as unittest
 
 
 from celery.tests.utils import execute_context, mask_modules
 from celery.tests.utils import execute_context, mask_modules
 
 
@@ -12,7 +12,7 @@ class TestAAPickle(unittest.TestCase):
             def with_cPickle_masked(_val):
             def with_cPickle_masked(_val):
                 from billiard.serialization import pickle
                 from billiard.serialization import pickle
                 import pickle as orig_pickle
                 import pickle as orig_pickle
-                self.assertTrue(pickle.dumps is orig_pickle.dumps)
+                self.assertIs(pickle.dumps, orig_pickle.dumps)
 
 
             context = mask_modules("cPickle")
             context = mask_modules("cPickle")
             execute_context(context, with_cPickle_masked)
             execute_context(context, with_cPickle_masked)

+ 19 - 19
celery/tests/test_task.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from StringIO import StringIO
 from StringIO import StringIO
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
@@ -220,8 +220,8 @@ class TestCeleryTasks(unittest.TestCase):
         import operator
         import operator
         conf.ALWAYS_EAGER = True
         conf.ALWAYS_EAGER = True
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
-        self.assertTrue(res, sum([operator.add(x, x)
-                                    for x in xrange(10)]))
+        self.assertEqual(sum(res), sum(operator.add(x, x)
+                                        for x in xrange(10)))
         conf.ALWAYS_EAGER = False
         conf.ALWAYS_EAGER = False
 
 
     def test_dmap_async(self):
     def test_dmap_async(self):
@@ -229,8 +229,8 @@ class TestCeleryTasks(unittest.TestCase):
         import operator
         import operator
         conf.ALWAYS_EAGER = True
         conf.ALWAYS_EAGER = True
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
-        self.assertTrue(res.get(), sum([operator.add(x, x)
-                                            for x in xrange(10)]))
+        self.assertEqual(sum(res.get()), sum(operator.add(x, x)
+                                                for x in xrange(10)))
         conf.ALWAYS_EAGER = False
         conf.ALWAYS_EAGER = False
 
 
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
@@ -241,9 +241,9 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertEqual(task_data["task"], task_name)
         self.assertEqual(task_data["task"], task_name)
         task_kwargs = task_data.get("kwargs", {})
         task_kwargs = task_data.get("kwargs", {})
         if test_eta:
         if test_eta:
-            self.assertTrue(isinstance(task_data.get("eta"), basestring))
+            self.assertIsInstance(task_data.get("eta"), basestring)
             to_datetime = parse_iso8601(task_data.get("eta"))
             to_datetime = parse_iso8601(task_data.get("eta"))
-            self.assertTrue(isinstance(to_datetime, datetime))
+            self.assertIsInstance(to_datetime, datetime)
         for arg_name, arg_value in kwargs.items():
         for arg_name, arg_value in kwargs.items():
             self.assertEqual(task_kwargs.get(arg_name), arg_value)
             self.assertEqual(task_kwargs.get(arg_name), arg_value)
 
 
@@ -256,7 +256,7 @@ class TestCeleryTasks(unittest.TestCase):
 
 
     def test_regular_task(self):
     def test_regular_task(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
-        self.assertTrue(isinstance(T1(), T1))
+        self.assertIsInstance(T1(), T1)
         self.assertTrue(T1().run())
         self.assertTrue(T1().run())
         self.assertTrue(callable(T1()),
         self.assertTrue(callable(T1()),
                 "Task class is callable()")
                 "Task class is callable()")
@@ -271,7 +271,7 @@ class TestCeleryTasks(unittest.TestCase):
         consumer = t1.get_consumer()
         consumer = t1.get_consumer()
         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
         consumer.discard_all()
         consumer.discard_all()
-        self.assertTrue(consumer.fetch() is None)
+        self.assertIsNone(consumer.fetch())
 
 
         # Without arguments.
         # Without arguments.
         presult = t1.delay()
         presult = t1.delay()
@@ -303,14 +303,14 @@ class TestCeleryTasks(unittest.TestCase):
         consumer.discard_all()
         consumer.discard_all()
         task.apply_async(t1)
         task.apply_async(t1)
         self.assertEqual(consumer.discard_all(), 1)
         self.assertEqual(consumer.discard_all(), 1)
-        self.assertTrue(consumer.fetch() is None)
+        self.assertIsNone(consumer.fetch())
 
 
         self.assertFalse(presult.successful())
         self.assertFalse(presult.successful())
         default_backend.mark_as_done(presult.task_id, result=None)
         default_backend.mark_as_done(presult.task_id, result=None)
         self.assertTrue(presult.successful())
         self.assertTrue(presult.successful())
 
 
         publisher = t1.get_publisher()
         publisher = t1.get_publisher()
-        self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
+        self.assertIsInstance(publisher, messaging.TaskPublisher)
 
 
     def test_get_publisher(self):
     def test_get_publisher(self):
         from celery.task import base
         from celery.task import base
@@ -339,7 +339,7 @@ class TestTaskSet(unittest.TestCase):
         ts = task.TaskSet(return_True_task.name, [
         ts = task.TaskSet(return_True_task.name, [
             [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
             [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
         res = ts.apply_async()
         res = ts.apply_async()
-        self.assertEqual(res.join(), [True, True, True, True, True])
+        self.assertListEqual(res.join(), [True, True, True, True, True])
 
 
         conf.ALWAYS_EAGER = False
         conf.ALWAYS_EAGER = False
 
 
@@ -367,9 +367,9 @@ class TestTaskSet(unittest.TestCase):
         taskset_id = taskset_res.taskset_id
         taskset_id = taskset_res.taskset_id
         for subtask in subtasks:
         for subtask in subtasks:
             m = consumer.fetch().payload
             m = consumer.fetch().payload
-            self.assertEqual(m.get("taskset"), taskset_id)
-            self.assertEqual(m.get("task"), IncrementCounterTask.name)
-            self.assertEqual(m.get("id"), subtask.task_id)
+            self.assertDictContainsSubset({"taskset": taskset_id,
+                                           "task": IncrementCounterTask.name,
+                                           "id": subtask.task_id}, m)
             IncrementCounterTask().run(
             IncrementCounterTask().run(
                     increment_by=m.get("kwargs", {}).get("increment_by"))
                     increment_by=m.get("kwargs", {}).get("increment_by"))
         self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
         self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
@@ -381,7 +381,7 @@ class TestTaskApply(unittest.TestCase):
         IncrementCounterTask.count = 0
         IncrementCounterTask.count = 0
 
 
         e = IncrementCounterTask.apply()
         e = IncrementCounterTask.apply()
-        self.assertTrue(isinstance(e, EagerResult))
+        self.assertIsInstance(e, EagerResult)
         self.assertEqual(e.get(), 1)
         self.assertEqual(e.get(), 1)
 
 
         e = IncrementCounterTask.apply(args=[1])
         e = IncrementCounterTask.apply(args=[1])
@@ -412,9 +412,9 @@ class TestPeriodicTask(unittest.TestCase):
             (task.PeriodicTask, ), {"__module__": __name__})
             (task.PeriodicTask, ), {"__module__": __name__})
 
 
     def test_remaining_estimate(self):
     def test_remaining_estimate(self):
-        self.assertTrue(isinstance(
+        self.assertIsInstance(
             MyPeriodic().remaining_estimate(datetime.now()),
             MyPeriodic().remaining_estimate(datetime.now()),
-            timedelta))
+            timedelta)
 
 
     def test_timedelta_seconds_returns_0_on_negative_time(self):
     def test_timedelta_seconds_returns_0_on_negative_time(self):
         delta = timedelta(days=-2)
         delta = timedelta(days=-2)
@@ -432,7 +432,7 @@ class TestPeriodicTask(unittest.TestCase):
     def test_is_due_not_due(self):
     def test_is_due_not_due(self):
         due, remaining = MyPeriodic().is_due(datetime.now())
         due, remaining = MyPeriodic().is_due(datetime.now())
         self.assertFalse(due)
         self.assertFalse(due)
-        self.assertTrue(remaining > 60)
+        self.assertGreater(remaining, 60)
 
 
     def test_is_due(self):
     def test_is_due(self):
         p = MyPeriodic()
         p = MyPeriodic()

+ 1 - 1
celery/tests/test_task_builtins.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from billiard.serialization import pickle
 from billiard.serialization import pickle
 
 

+ 6 - 6
celery/tests/test_task_control.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery.task import control
 from celery.task import control
 from celery.task.builtins import PingTask
 from celery.task.builtins import PingTask
@@ -36,23 +36,23 @@ class TestBroadcast(unittest.TestCase):
     @with_mock_broadcast
     @with_mock_broadcast
     def test_broadcast(self):
     def test_broadcast(self):
         control.broadcast("foobarbaz", arguments=[])
         control.broadcast("foobarbaz", arguments=[])
-        self.assertTrue("foobarbaz" in MockBroadcastPublisher.sent)
+        self.assertIn("foobarbaz", MockBroadcastPublisher.sent)
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_rate_limit(self):
     def test_rate_limit(self):
         control.rate_limit(PingTask.name, "100/m")
         control.rate_limit(PingTask.name, "100/m")
-        self.assertTrue("rate_limit" in MockBroadcastPublisher.sent)
+        self.assertIn("rate_limit", MockBroadcastPublisher.sent)
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke(self):
     def test_revoke(self):
         control.revoke("foozbaaz")
         control.revoke("foozbaaz")
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke_from_result(self):
     def test_revoke_from_result(self):
         from celery.result import AsyncResult
         from celery.result import AsyncResult
         AsyncResult("foozbazzbar").revoke()
         AsyncResult("foozbazzbar").revoke()
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
     def test_revoke_from_resultset(self):
@@ -60,4 +60,4 @@ class TestBroadcast(unittest.TestCase):
         r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
         r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
                                                         for i in range(10)]))
                                                         for i in range(10)]))
         r.revoke()
         r.revoke()
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)

+ 10 - 10
celery/tests/test_task_http.py

@@ -3,7 +3,7 @@ from __future__ import generators
 
 
 import sys
 import sys
 import logging
 import logging
-import unittest
+import unittest2 as unittest
 from urllib import addinfourl
 from urllib import addinfourl
 try:
 try:
     from contextlib import contextmanager
     from contextlib import contextmanager
@@ -63,23 +63,23 @@ class TestEncodings(unittest.TestCase):
               "foobar".encode("utf-8"): "xuzzybaz".encode("utf-8")}
               "foobar".encode("utf-8"): "xuzzybaz".encode("utf-8")}
 
 
         for key, value in http.utf8dict(d.items()).items():
         for key, value in http.utf8dict(d.items()).items():
-            self.assertTrue(isinstance(key, str))
-            self.assertTrue(isinstance(value, str))
+            self.assertIsInstance(key, str)
+            self.assertIsInstance(value, str)
 
 
 
 
 class TestMutableURL(unittest.TestCase):
 class TestMutableURL(unittest.TestCase):
 
 
     def test_url_query(self):
     def test_url_query(self):
         url = http.MutableURL("http://example.com?x=10&y=20&z=Foo")
         url = http.MutableURL("http://example.com?x=10&y=20&z=Foo")
-        self.assertEqual(url.query.get("x"), "10")
-        self.assertEqual(url.query.get("y"), "20")
-        self.assertEqual(url.query.get("z"), "Foo")
+        self.assertDictContainsSubset({"x": "10",
+                                       "y": "20",
+                                       "z": "Foo"}, url.query)
         url.query["name"] = "George"
         url.query["name"] = "George"
         url = http.MutableURL(str(url))
         url = http.MutableURL(str(url))
-        self.assertEqual(url.query.get("x"), "10")
-        self.assertEqual(url.query.get("y"), "20")
-        self.assertEqual(url.query.get("z"), "Foo")
-        self.assertEqual(url.query.get("name"), "George")
+        self.assertDictContainsSubset({"x": "10",
+                                       "y": "20",
+                                       "z": "Foo",
+                                       "name": "George"}, url.query)
 
 
     def test_url_keeps_everything(self):
     def test_url_keeps_everything(self):
         url = "https://e.com:808/foo/bar#zeta?x=10&y=20"
         url = "https://e.com:808/foo/bar#zeta?x=10&y=20"

+ 6 - 6
celery/tests/test_utils.py

@@ -1,6 +1,6 @@
 import sys
 import sys
 import socket
 import socket
-import unittest
+import unittest2 as unittest
 
 
 from billiard.utils.functional import wraps
 from billiard.utils.functional import wraps
 
 
@@ -15,17 +15,17 @@ class TestChunks(unittest.TestCase):
 
 
         # n == 2
         # n == 2
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
 
 
         # n == 3
         # n == 3
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
             [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
 
 
         # n == 2 (exact)
         # n == 2 (exact)
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
 
 
 
 
@@ -37,10 +37,10 @@ class TestGenUniqueId(unittest.TestCase):
         def with_ctypes_masked(_val):
         def with_ctypes_masked(_val):
             from celery.utils import ctypes, gen_unique_id
             from celery.utils import ctypes, gen_unique_id
 
 
-            self.assertTrue(ctypes is None)
+            self.assertIsNone(ctypes)
             uuid = gen_unique_id()
             uuid = gen_unique_id()
             self.assertTrue(uuid)
             self.assertTrue(uuid)
-            self.assertTrue(isinstance(uuid, basestring))
+            self.assertIsInstance(uuid, basestring)
 
 
         try:
         try:
             context = mask_modules("ctypes")
             context = mask_modules("ctypes")

+ 1 - 1
celery/tests/test_utils_info.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery.utils import info
 from celery.utils import info
 
 

+ 17 - 17
celery/tests/test_worker.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from Queue import Queue, Empty
 from Queue import Queue, Empty
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from multiprocessing import get_logger
 from multiprocessing import get_logger
@@ -162,13 +162,13 @@ class TestCarrotListener(unittest.TestCase):
 
 
         records.clear()
         records.clear()
         self.assertEqual(l._detect_wait_method(), l._mainloop)
         self.assertEqual(l._detect_wait_method(), l._mainloop)
-        self.assertTrue(records.get("broadcast_callback"))
-        self.assertTrue(records.get("consume_broadcast"))
-        self.assertTrue(records.get("consume_tasks"))
+        for record in ("broadcast_callback", "consume_broadcast",
+                "consume_tasks"):
+            self.assertTrue(records.get(record))
 
 
         records.clear()
         records.clear()
         l.connection.connection = PlaceHolder()
         l.connection.connection = PlaceHolder()
-        self.assertTrue(l._detect_wait_method() is l.task_consumer.iterconsume)
+        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
         self.assertTrue(records.get("consumer_add"))
         self.assertTrue(records.get("consumer_add"))
 
 
     def test_connection(self):
     def test_connection(self):
@@ -176,19 +176,19 @@ class TestCarrotListener(unittest.TestCase):
                            send_events=False)
                            send_events=False)
 
 
         l.reset_connection()
         l.reset_connection()
-        self.assertTrue(isinstance(l.connection, BrokerConnection))
+        self.assertIsInstance(l.connection, BrokerConnection)
 
 
         l.stop_consumers()
         l.stop_consumers()
-        self.assertTrue(l.connection is None)
-        self.assertTrue(l.task_consumer is None)
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
 
 
         l.reset_connection()
         l.reset_connection()
-        self.assertTrue(isinstance(l.connection, BrokerConnection))
+        self.assertIsInstance(l.connection, BrokerConnection)
 
 
         l.stop()
         l.stop()
         l.close_connection()
         l.close_connection()
-        self.assertTrue(l.connection is None)
-        self.assertTrue(l.task_consumer is None)
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
 
 
     def test_receive_message_control_command(self):
     def test_receive_message_control_command(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
@@ -198,7 +198,7 @@ class TestCarrotListener(unittest.TestCase):
         l.event_dispatcher = MockEventDispatcher()
         l.event_dispatcher = MockEventDispatcher()
         l.control_dispatch = MockControlDispatch()
         l.control_dispatch = MockControlDispatch()
         l.receive_message(m.decode(), m)
         l.receive_message(m.decode(), m)
-        self.assertTrue("shutdown" in l.control_dispatch.commands)
+        self.assertIn("shutdown", l.control_dispatch.commands)
 
 
     def test_close_connection(self):
     def test_close_connection(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
@@ -226,7 +226,7 @@ class TestCarrotListener(unittest.TestCase):
         def with_catch_warnings(log):
         def with_catch_warnings(log):
             l.receive_message(m.decode(), m)
             l.receive_message(m.decode(), m)
             self.assertTrue(log)
             self.assertTrue(log)
-            self.assertTrue("unknown message" in log[0].message.args[0])
+            self.assertIn("unknown message", log[0].message.args[0])
 
 
         context = catch_warnings(record=True)
         context = catch_warnings(record=True)
         execute_context(context, with_catch_warnings)
         execute_context(context, with_catch_warnings)
@@ -242,7 +242,7 @@ class TestCarrotListener(unittest.TestCase):
         l.receive_message(m.decode(), m)
         l.receive_message(m.decode(), m)
 
 
         in_bucket = self.ready_queue.get_nowait()
         in_bucket = self.ready_queue.get_nowait()
-        self.assertTrue(isinstance(in_bucket, TaskWrapper))
+        self.assertIsInstance(in_bucket, TaskWrapper)
         self.assertEqual(in_bucket.task_name, foo_task.name)
         self.assertEqual(in_bucket.task_name, foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_schedule.empty())
         self.assertTrue(self.eta_schedule.empty())
@@ -278,7 +278,7 @@ class TestCarrotListener(unittest.TestCase):
         l.event_dispatcher = MockEventDispatcher()
         l.event_dispatcher = MockEventDispatcher()
         l.receive_message(c.decode(), c)
         l.receive_message(c.decode(), c)
         from celery.worker.revoke import revoked
         from celery.worker.revoke import revoked
-        self.assertTrue(id in revoked)
+        self.assertIn(id, revoked)
 
 
         l.receive_message(t.decode(), t)
         l.receive_message(t.decode(), t)
         self.assertTrue(ready_queue.empty())
         self.assertTrue(ready_queue.empty())
@@ -314,7 +314,7 @@ class TestCarrotListener(unittest.TestCase):
         in_hold = self.eta_schedule.queue[0]
         in_hold = self.eta_schedule.queue[0]
         self.assertEqual(len(in_hold), 4)
         self.assertEqual(len(in_hold), 4)
         eta, priority, task, on_accept = in_hold
         eta, priority, task, on_accept = in_hold
-        self.assertTrue(isinstance(task, TaskWrapper))
+        self.assertIsInstance(task, TaskWrapper)
         self.assertTrue(callable(on_accept))
         self.assertTrue(callable(on_accept))
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         self.assertEqual(task.execute(), 2 * 4 * 8)
@@ -329,7 +329,7 @@ class TestWorkController(unittest.TestCase):
 
 
     def test_attrs(self):
     def test_attrs(self):
         worker = self.worker
         worker = self.worker
-        self.assertTrue(isinstance(worker.eta_schedule, Scheduler))
+        self.assertIsInstance(worker.eta_schedule, Scheduler)
         self.assertTrue(worker.scheduler)
         self.assertTrue(worker.scheduler)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.listener)
         self.assertTrue(worker.listener)

+ 3 - 3
celery/tests/test_worker_control.py

@@ -1,5 +1,5 @@
 import socket
 import socket
-import unittest
+import unittest2 as unittest
 
 
 from celery.task.builtins import PingTask
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
@@ -49,10 +49,10 @@ class TestControlPanel(unittest.TestCase):
              "destination": hostname,
              "destination": hostname,
              "task_id": uuid}
              "task_id": uuid}
         self.panel.dispatch_from_message(m)
         self.panel.dispatch_from_message(m)
-        self.assertTrue(uuid in revoked)
+        self.assertIn(uuid, revoked)
 
 
         m = {"command": "revoke",
         m = {"command": "revoke",
              "destination": "does.not.exist",
              "destination": "does.not.exist",
              "task_id": uuid + "xxx"}
              "task_id": uuid + "xxx"}
         self.panel.dispatch_from_message(m)
         self.panel.dispatch_from_message(m)
-        self.assertTrue(uuid + "xxx" not in revoked)
+        self.assertNotIn(uuid + "xxx", revoked)

+ 2 - 2
celery/tests/test_worker_controllers.py

@@ -1,5 +1,5 @@
 import time
 import time
-import unittest
+import unittest2 as unittest
 from Queue import Queue
 from Queue import Queue
 
 
 from celery.utils import gen_unique_id
 from celery.utils import gen_unique_id
@@ -90,7 +90,7 @@ class TestMediator(unittest.TestCase):
 
 
         m.on_iteration()
         m.on_iteration()
 
 
-        self.assertTrue("value" not in got)
+        self.assertNotIn("value", got)
         self.assertTrue(t.acked)
         self.assertTrue(t.acked)
 
 
 
 

+ 6 - 6
celery/tests/test_worker_heartbeat.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery.worker.heartbeat import Heart
 from celery.worker.heartbeat import Heart
 
 
@@ -27,16 +27,16 @@ class TestHeart(unittest.TestCase):
         heart = Heart(eventer, interval=1)
         heart = Heart(eventer, interval=1)
         heart._shutdown.set()
         heart._shutdown.set()
         heart.run()
         heart.run()
-        self.assertTrue(heart._state == "RUN")
-        self.assertTrue("worker-online" in eventer.sent)
-        self.assertTrue("worker-heartbeat" in eventer.sent)
-        self.assertTrue("worker-offline" in eventer.sent)
+        self.assertEqual(heart._state, "RUN")
+        self.assertIn("worker-online", eventer.sent)
+        self.assertIn("worker-heartbeat", eventer.sent)
+        self.assertIn("worker-offline", eventer.sent)
 
 
         self.assertTrue(heart._stopped.isSet())
         self.assertTrue(heart._stopped.isSet())
 
 
         heart.stop()
         heart.stop()
         heart.stop()
         heart.stop()
-        self.assertTrue(heart._state == "CLOSE")
+        self.assertEqual(heart._state, "CLOSE")
 
 
         heart = Heart(eventer, interval=0.00001)
         heart = Heart(eventer, interval=0.00001)
         heart._shutdown.set()
         heart._shutdown.set()

+ 17 - 17
celery/tests/test_worker_job.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 import sys
 import sys
 import logging
 import logging
-import unittest
+import unittest2 as unittest
 import simplejson
 import simplejson
 from StringIO import StringIO
 from StringIO import StringIO
 
 
@@ -90,14 +90,14 @@ class TestJail(unittest.TestCase):
     def test_execute_jail_failure(self):
     def test_execute_jail_failure(self):
         ret = jail(gen_unique_id(), mytask_raising.name,
         ret = jail(gen_unique_id(), mytask_raising.name,
                    [4], {})
                    [4], {})
-        self.assertTrue(isinstance(ret, ExceptionInfo))
-        self.assertEqual(ret.exception.args, (4, ))
+        self.assertIsInstance(ret, ExceptionInfo)
+        self.assertTupleEqual(ret.exception.args, (4, ))
 
 
     def test_execute_ignore_result(self):
     def test_execute_ignore_result(self):
         task_id = gen_unique_id()
         task_id = gen_unique_id()
         ret = jail(id, MyTaskIgnoreResult.name,
         ret = jail(id, MyTaskIgnoreResult.name,
                    [4], {})
                    [4], {})
-        self.assertTrue(ret, 8)
+        self.assertEquals(ret, 256)
         self.assertFalse(AsyncResult(task_id).ready())
         self.assertFalse(AsyncResult(task_id).ready())
 
 
     def test_django_db_connection_is_closed(self):
     def test_django_db_connection_is_closed(self):
@@ -179,7 +179,7 @@ class TestTaskWrapper(unittest.TestCase):
         tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.eventer = MockEventDispatcher()
         tw.eventer = MockEventDispatcher()
         tw.send_event("task-frobulated")
         tw.send_event("task-frobulated")
-        self.assertTrue("task-frobulated" in tw.eventer.sent)
+        self.assertIn("task-frobulated", tw.eventer.sent)
 
 
     def test_send_email(self):
     def test_send_email(self):
         from celery import conf
         from celery import conf
@@ -229,10 +229,10 @@ class TestTaskWrapper(unittest.TestCase):
             def with_catch_warnings(log):
             def with_catch_warnings(log):
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                                         [4], {})
                                         [4], {})
-                self.assertTrue(isinstance(res, ExceptionInfo))
+                self.assertIsInstance(res, ExceptionInfo)
                 self.assertTrue(log)
                 self.assertTrue(log)
-                self.assertTrue("Exception outside" in log[0].message.args[0])
-                self.assertTrue("KeyError" in log[0].message.args[0])
+                self.assertIn("Exception outside", log[0].message.args[0])
+                self.assertIn("KeyError", log[0].message.args[0])
 
 
             context = catch_warnings(record=True)
             context = catch_warnings(record=True)
             execute_context(context, with_catch_warnings)
             execute_context(context, with_catch_warnings)
@@ -303,13 +303,13 @@ class TestTaskWrapper(unittest.TestCase):
                         content_type="application/json",
                         content_type="application/json",
                         content_encoding="utf-8")
                         content_encoding="utf-8")
         tw = TaskWrapper.from_message(m, m.decode())
         tw = TaskWrapper.from_message(m, m.decode())
-        self.assertTrue(isinstance(tw, TaskWrapper))
+        self.assertIsInstance(tw, TaskWrapper)
         self.assertEqual(tw.task_name, body["task"])
         self.assertEqual(tw.task_name, body["task"])
         self.assertEqual(tw.task_id, body["id"])
         self.assertEqual(tw.task_id, body["id"])
         self.assertEqual(tw.args, body["args"])
         self.assertEqual(tw.args, body["args"])
         self.assertEqual(tw.kwargs.keys()[0],
         self.assertEqual(tw.kwargs.keys()[0],
                           u"æØåveéðƒeæ".encode("utf-8"))
                           u"æØåveéðƒeæ".encode("utf-8"))
-        self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
+        self.assertNotIsInstance(tw.kwargs.keys()[0], unicode)
         self.assertTrue(tw.logger)
         self.assertTrue(tw.logger)
 
 
     def test_from_message_nonexistant_task(self):
     def test_from_message_nonexistant_task(self):
@@ -359,10 +359,10 @@ class TestTaskWrapper(unittest.TestCase):
     def test_execute_fail(self):
     def test_execute_fail(self):
         tid = gen_unique_id()
         tid = gen_unique_id()
         tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"})
         tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"})
-        self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
+        self.assertIsInstance(tw.execute(), ExceptionInfo)
         meta = TaskMeta.objects.get(task_id=tid)
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEqual(meta.status, states.FAILURE)
         self.assertEqual(meta.status, states.FAILURE)
-        self.assertTrue(isinstance(meta.result, KeyError))
+        self.assertIsInstance(meta.result, KeyError)
 
 
     def test_execute_using_pool(self):
     def test_execute_using_pool(self):
         tid = gen_unique_id()
         tid = gen_unique_id()
@@ -370,13 +370,13 @@ class TestTaskWrapper(unittest.TestCase):
         p = TaskPool(2)
         p = TaskPool(2)
         p.start()
         p.start()
         asyncres = tw.execute_using_pool(p)
         asyncres = tw.execute_using_pool(p)
-        self.assertTrue(asyncres.get(), 256)
+        self.assertEquals(asyncres.get(), 256)
         p.stop()
         p.stop()
 
 
     def test_default_kwargs(self):
     def test_default_kwargs(self):
         tid = gen_unique_id()
         tid = gen_unique_id()
         tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
         tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
-        self.assertEqual(tw.extend_with_default_kwargs(10, "some_logfile"), {
+        self.assertDictEqual(tw.extend_with_default_kwargs(10, "some_logfile"), {
             "f": "x",
             "f": "x",
             "logfile": "some_logfile",
             "logfile": "some_logfile",
             "loglevel": 10,
             "loglevel": 10,
@@ -403,8 +403,8 @@ class TestTaskWrapper(unittest.TestCase):
 
 
         tw.on_failure(exc_info)
         tw.on_failure(exc_info)
         logvalue = logfh.getvalue()
         logvalue = logfh.getvalue()
-        self.assertTrue(mytask.name in logvalue)
-        self.assertTrue(tid in logvalue)
-        self.assertTrue("ERROR" in logvalue)
+        self.assertIn(mytask.name, logvalue)
+        self.assertIn(tid, logvalue)
+        self.assertIn("ERROR", logvalue)
 
 
         conf.CELERY_SEND_TASK_ERROR_EMAILS = False
         conf.CELERY_SEND_TASK_ERROR_EMAILS = False

+ 3 - 3
celery/tests/test_worker_revoke.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 
 from celery.worker import revoke
 from celery.worker import revoke
 
 
@@ -7,6 +7,6 @@ class TestRevokeRegistry(unittest.TestCase):
 
 
     def test_is_working(self):
     def test_is_working(self):
         revoke.revoked.add("foo")
         revoke.revoked.add("foo")
-        self.assertTrue("foo" in revoke.revoked)
+        self.assertIn("foo", revoke.revoked)
         revoke.revoked.pop_value("foo")
         revoke.revoked.pop_value("foo")
-        self.assertTrue("foo" not in revoke.revoked)
+        self.assertNotIn("foo", revoke.revoked)

+ 3 - 3
celery/tests/test_worker_scheduler.py

@@ -1,5 +1,6 @@
 from __future__ import generators
 from __future__ import generators
-import unittest
+
+import unittest2 as unittest
 from Queue import Queue, Empty
 from Queue import Queue, Empty
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 
 
@@ -50,5 +51,4 @@ class TestScheduler(unittest.TestCase):
     def test_empty_queue_yields_None(self):
     def test_empty_queue_yields_None(self):
         ready_queue = Queue()
         ready_queue = Queue()
         sched = Scheduler(ready_queue)
         sched = Scheduler(ready_queue)
-
-        self.assertTrue(iter(sched).next() is None)
+        self.assertIsNone(iter(sched).next())

+ 1 - 0
contrib/requirements/test.txt

@@ -7,3 +7,4 @@ coverage>=3.0
 pytyrant
 pytyrant
 redis
 redis
 pymongo
 pymongo
+git+git://github.com/exogen/nose-achievements.git