Pārlūkot izejas kodu

Use unittest2 features for better test diagnostics.

Ask Solem 15 gadi atpakaļ
vecāks
revīzija
8b2643618f
37 mainītis faili ar 289 papildinājumiem un 305 dzēšanām
  1. 11 13
      celery/tests/test_backends/__init__.py
  2. 2 2
      celery/tests/test_backends/test_amqp.py
  3. 10 10
      celery/tests/test_backends/test_base.py
  4. 6 7
      celery/tests/test_backends/test_cache.py
  5. 5 6
      celery/tests/test_backends/test_database.py
  6. 11 12
      celery/tests/test_backends/test_redis.py
  7. 7 8
      celery/tests/test_backends/test_tyrant.py
  8. 18 18
      celery/tests/test_beat.py
  9. 1 1
      celery/tests/test_bin_celeryd.py
  10. 8 8
      celery/tests/test_buckets.py
  11. 5 3
      celery/tests/test_conf.py
  12. 14 14
      celery/tests/test_datastructures.py
  13. 4 2
      celery/tests/test_discovery.py
  14. 1 1
      celery/tests/test_events.py
  15. 4 4
      celery/tests/test_loaders.py
  16. 12 17
      celery/tests/test_log.py
  17. 2 1
      celery/tests/test_messaging.py
  18. 17 17
      celery/tests/test_models.py
  19. 12 12
      celery/tests/test_pickle.py
  20. 20 31
      celery/tests/test_pool.py
  21. 16 16
      celery/tests/test_registry.py
  22. 6 6
      celery/tests/test_result.py
  23. 2 2
      celery/tests/test_serialization.py
  24. 19 19
      celery/tests/test_task.py
  25. 1 1
      celery/tests/test_task_builtins.py
  26. 6 6
      celery/tests/test_task_control.py
  27. 10 10
      celery/tests/test_task_http.py
  28. 6 6
      celery/tests/test_utils.py
  29. 1 1
      celery/tests/test_utils_info.py
  30. 17 17
      celery/tests/test_worker.py
  31. 3 3
      celery/tests/test_worker_control.py
  32. 2 2
      celery/tests/test_worker_controllers.py
  33. 6 6
      celery/tests/test_worker_heartbeat.py
  34. 17 17
      celery/tests/test_worker_job.py
  35. 3 3
      celery/tests/test_worker_revoke.py
  36. 3 3
      celery/tests/test_worker_scheduler.py
  37. 1 0
      contrib/requirements/test.txt

+ 11 - 13
celery/tests/test_backends/__init__.py

@@ -1,27 +1,25 @@
-import unittest
-
+import unittest2 as unittest
 
+from celery import backends
 from celery.backends.database import DatabaseBackend
 from celery.backends.amqp import AMQPBackend
 from celery.backends.pyredis import RedisBackend
-from celery import backends
 
 
 class TestBackends(unittest.TestCase):
 
     def test_get_backend_aliases(self):
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("amqp"), AMQPBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("database"), DatabaseBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("db"), DatabaseBackend))
-        self.assertTrue(issubclass(
-            backends.get_backend_cls("redis"), RedisBackend))
+        expects = [("amqp", AMQPBackend),
+                   ("database", DatabaseBackend),
+                   ("db", DatabaseBackend),
+                   ("redis", RedisBackend)]
+        for expect_name, expect_cls in expects:
+            self.assertIsInstance(backends.get_backend_cls(expect_name)(),
+                                  expect_cls)
 
     def test_get_backend_cahe(self):
         backends._backend_cache = {}
         backends.get_backend_cls("amqp")
-        self.assertTrue("amqp" in backends._backend_cache)
+        self.assertIn("amqp", backends._backend_cache)
         amqp_backend = backends.get_backend_cls("amqp")
-        self.assertTrue(amqp_backend is backends._backend_cache["amqp"])
+        self.assertIs(amqp_backend, backends._backend_cache["amqp"])

+ 2 - 2
celery/tests/test_backends/test_amqp.py

@@ -1,6 +1,6 @@
 import sys
 import errno
-import unittest
+import unittest2 as unittest
 
 from celery.exceptions import ImproperlyConfigured
 
@@ -56,7 +56,7 @@ class TestRedisBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
         self.assertEqual(tb.get_traceback(tid3), einfo.traceback)
 
     def test_process_cleanup(self):

+ 10 - 10
celery/tests/test_backends/test_base.py

@@ -1,6 +1,6 @@
 import sys
 import types
-import unittest
+import unittest2 as unittest
 
 from django.db.models.base import subclass_exception
 from billiard.serialization import find_nearest_pickleable_exception as fnpe
@@ -54,31 +54,31 @@ class TestBaseBackendInterface(unittest.TestCase):
 class TestPickleException(unittest.TestCase):
 
     def test_oldstyle(self):
-        self.assertTrue(fnpe(Oldstyle()) is None)
+        self.assertIsNone(fnpe(Oldstyle()))
 
     def test_BaseException(self):
-        self.assertTrue(fnpe(Exception()) is None)
+        self.assertIsNone(fnpe(Exception()))
 
     def test_get_pickleable_exception(self):
         exc = Exception("foo")
         self.assertEqual(gpe(exc), exc)
 
     def test_unpickleable(self):
-        self.assertTrue(isinstance(fnpe(Unpickleable()), KeyError))
-        self.assertEqual(fnpe(Impossible()), None)
+        self.assertIsInstance(fnpe(Unpickleable()), KeyError)
+        self.assertIsNone(fnpe(Impossible()))
 
 
 class TestPrepareException(unittest.TestCase):
 
     def test_unpickleable(self):
         x = b.prepare_exception(Unpickleable(1, 2, "foo"))
-        self.assertTrue(isinstance(x, KeyError))
+        self.assertIsInstance(x, KeyError)
         y = b.exception_to_python(x)
-        self.assertTrue(isinstance(y, KeyError))
+        self.assertIsInstance(y, KeyError)
 
     def test_impossible(self):
         x = b.prepare_exception(Impossible())
-        self.assertTrue(isinstance(x, UnpickleableExceptionWrapper))
+        self.assertIsInstance(x, UnpickleableExceptionWrapper)
         y = b.exception_to_python(x)
         self.assertEqual(y.__class__.__name__, "Impossible")
         if sys.version_info < (2, 5):
@@ -88,9 +88,9 @@ class TestPrepareException(unittest.TestCase):
 
     def test_regular(self):
         x = b.prepare_exception(KeyError("baz"))
-        self.assertTrue(isinstance(x, KeyError))
+        self.assertIsInstance(x, KeyError)
         y = b.exception_to_python(x)
-        self.assertTrue(isinstance(y, KeyError))
+        self.assertIsInstance(y, KeyError)
 
 
 class TestKeyValueStoreBackendInterface(unittest.TestCase):

+ 6 - 7
celery/tests/test_backends/test_cache.py

@@ -1,5 +1,5 @@
 import sys
-import unittest
+import unittest2 as unittest
 
 from billiard.serialization import pickle
 from django.core.cache.backends.base import InvalidCacheBackendError
@@ -26,7 +26,7 @@ class TestCacheBackend(unittest.TestCase):
 
         self.assertFalse(cb.is_successful(tid))
         self.assertEqual(cb.get_status(tid), states.PENDING)
-        self.assertEqual(cb.get_result(tid), None)
+        self.assertIsNone(cb.get_result(tid))
 
         cb.mark_as_done(tid, 42)
         self.assertTrue(cb.is_successful(tid))
@@ -42,7 +42,7 @@ class TestCacheBackend(unittest.TestCase):
         res = result.TaskSetResult(taskset_id, subtasks)
         res.save(backend=backend)
         saved = result.TaskSetResult.restore(taskset_id, backend=backend)
-        self.assertEqual(saved.subtasks, subtasks)
+        self.assertListEqual(saved.subtasks, subtasks)
         self.assertEqual(saved.taskset_id, taskset_id)
 
     def test_is_pickled(self):
@@ -69,12 +69,11 @@ class TestCacheBackend(unittest.TestCase):
         cb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(cb.is_successful(tid3))
         self.assertEqual(cb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(cb.get_result(tid3), KeyError))
+        self.assertIsInstance(cb.get_result(tid3), KeyError)
         self.assertEqual(cb.get_traceback(tid3), einfo.traceback)
 
     def test_process_cleanup(self):
         cb = CacheBackend()
-
         cb.process_cleanup()
 
 
@@ -91,7 +90,7 @@ class TestCustomCacheBackend(unittest.TestCase):
             from django.core.cache import cache as django_cache
             self.assertEqual(cache.__class__.__module__,
                               "django.core.cache.backends.dummy")
-            self.assertTrue(cache is not django_cache)
+            self.assertIsNot(cache, django_cache)
         finally:
             conf.CELERY_CACHE_BACKEND = prev_backend
             sys.modules["celery.backends.cache"] = prev_module
@@ -113,7 +112,7 @@ class TestMemcacheWrapper(unittest.TestCase):
         prev_backend_module = sys.modules.pop("celery.backends.cache")
         try:
             from celery.backends.cache import cache, DjangoMemcacheWrapper
-            self.assertTrue(isinstance(cache, DjangoMemcacheWrapper))
+            self.assertIsInstance(cache, DjangoMemcacheWrapper)
 
             key = "cu.test_memcache_wrapper"
             val = "The quick brown fox."

+ 5 - 6
celery/tests/test_backends/test_database.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from datetime import timedelta
 
 from celery import states
@@ -29,13 +29,12 @@ class TestDatabaseBackend(unittest.TestCase):
 
         self.assertFalse(b.is_successful(tid))
         self.assertEqual(b.get_status(tid), states.PENDING)
-        self.assertTrue(b.get_result(tid) is None)
+        self.assertIsNone(b.get_result(tid))
 
         b.mark_as_done(tid, 42)
         self.assertTrue(b.is_successful(tid))
         self.assertEqual(b.get_status(tid), states.SUCCESS)
         self.assertEqual(b.get_result(tid), 42)
-        self.assertTrue(b.get_result(tid), 42)
 
         tid2 = gen_unique_id()
         result = {"foo": "baz", "bar": SomeClass(12345)}
@@ -53,17 +52,17 @@ class TestDatabaseBackend(unittest.TestCase):
         b.mark_as_failure(tid3, exception)
         self.assertFalse(b.is_successful(tid3))
         self.assertEqual(b.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(b.get_result(tid3), KeyError))
+        self.assertIsInstance(b.get_result(tid3), KeyError)
 
     def test_taskset_store(self):
         b = DatabaseBackend()
         tid = gen_unique_id()
 
-        self.assertTrue(b.restore_taskset(tid) is None)
+        self.assertIsNone(b.restore_taskset(tid))
 
         result = {"foo": "baz", "bar": SomeClass(12345)}
         b.save_taskset(tid, result)
         rindb = b.restore_taskset(tid)
-        self.assertTrue(rindb is not None)
+        self.assertIsNotNone(rindb)
         self.assertEqual(rindb.get("foo"), "baz")
         self.assertEqual(rindb.get("bar").data, 12345)

+ 11 - 12
celery/tests/test_backends/test_redis.py

@@ -1,7 +1,7 @@
 import sys
 import errno
 import socket
-import unittest
+import unittest2 as unittest
 
 from celery.exceptions import ImproperlyConfigured
 
@@ -60,11 +60,11 @@ class TestRedisBackend(unittest.TestCase):
         if not tb:
             return # Skip test
 
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
 
     def test_mark_as_done(self):
         tb = get_redis_or_None()
@@ -75,13 +75,12 @@ class TestRedisBackend(unittest.TestCase):
 
         self.assertFalse(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertEqual(tb.get_result(tid), None)
+        self.assertIsNone(tb.get_result(tid))
 
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb.get_result(tid), 42)
 
     def test_is_pickled(self):
         tb = get_redis_or_None()
@@ -109,7 +108,7 @@ class TestRedisBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
 
     def test_process_cleanup(self):
         tb = get_redis_or_None()
@@ -118,7 +117,7 @@ class TestRedisBackend(unittest.TestCase):
 
         tb.process_cleanup()
 
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
 
     def test_connection_close_if_connected(self):
         tb = get_redis_or_None()
@@ -126,11 +125,11 @@ class TestRedisBackend(unittest.TestCase):
             return
 
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
 
 
 class TestTyrantBackendNoTyrant(unittest.TestCase):
@@ -140,7 +139,7 @@ class TestTyrantBackendNoTyrant(unittest.TestCase):
         try:
             def with_redis_masked(_val):
                 from celery.backends.pyredis import redis
-                self.assertTrue(redis is None)
+                self.assertIsNone(redis)
             context = mask_modules("redis")
             execute_context(context, with_redis_masked)
         finally:

+ 7 - 8
celery/tests/test_backends/test_tyrant.py

@@ -1,7 +1,7 @@
 import sys
 import errno
 import socket
-import unittest
+import unittest2 as unittest
 
 from celery.exceptions import ImproperlyConfigured
 
@@ -50,11 +50,11 @@ class TestTyrantBackend(unittest.TestCase):
         if not tb:
             return # Skip test
 
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNotNone(tb._connection)
         tb.close()
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)
         tb.open()
-        self.assertTrue(tb._connection is not None)
+        self.assertIsNone(tb._connection)
 
     def test_mark_as_done(self):
         tb = get_tyrant_or_None()
@@ -65,13 +65,12 @@ class TestTyrantBackend(unittest.TestCase):
 
         self.assertFalse(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.PENDING)
-        self.assertEqual(tb.get_result(tid), None)
+        self.assertIsNone(tb.get_result(tid), None)
 
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
         self.assertEqual(tb.get_status(tid), states.SUCCESS)
         self.assertEqual(tb.get_result(tid), 42)
-        self.assertTrue(tb.get_result(tid), 42)
 
     def test_is_pickled(self):
         tb = get_tyrant_or_None()
@@ -99,7 +98,7 @@ class TestTyrantBackend(unittest.TestCase):
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
-        self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
+        self.assertIsInstance(tb.get_result(tid3), KeyError)
 
     def test_process_cleanup(self):
         tb = get_tyrant_or_None()
@@ -108,4 +107,4 @@ class TestTyrantBackend(unittest.TestCase):
 
         tb.process_cleanup()
 
-        self.assertTrue(tb._connection is None)
+        self.assertIsNone(tb._connection)

+ 18 - 18
celery/tests/test_beat.py

@@ -1,5 +1,5 @@
-import unittest
 import logging
+import unittest2 as unittest
 from datetime import datetime, timedelta
 
 from celery import log
@@ -87,7 +87,7 @@ class TestScheduleEntry(unittest.TestCase):
     def test_constructor(self):
         s = beat.ScheduleEntry(DuePeriodicTask.name)
         self.assertEqual(s.name, DuePeriodicTask.name)
-        self.assertTrue(isinstance(s.last_run_at, datetime))
+        self.assertIsInstance(s.last_run_at, datetime)
         self.assertEqual(s.total_run_count, 0)
 
         now = datetime.now()
@@ -101,7 +101,7 @@ class TestScheduleEntry(unittest.TestCase):
         n = s.next()
         self.assertEqual(n.name, s.name)
         self.assertEqual(n.total_run_count, 301)
-        self.assertTrue(n.last_run_at > s.last_run_at)
+        self.assertGreater(n.last_run_at, s.last_run_at)
 
     def test_is_due(self):
         due = beat.ScheduleEntry(DuePeriodicTask.name)
@@ -123,20 +123,20 @@ class TestScheduler(unittest.TestCase):
 
     def test_constructor(self):
         s = beat.Scheduler()
-        self.assertTrue(isinstance(s.registry, TaskRegistry))
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.logger, logging.Logger))
+        self.assertIsInstance(s.registry, TaskRegistry)
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.logger, logging.Logger)
         self.assertEqual(s.max_interval, conf.CELERYBEAT_MAX_LOOP_INTERVAL)
 
     def test_cleanup(self):
         self.scheduler.schedule["fbz"] = beat.ScheduleEntry("fbz")
         self.scheduler.cleanup()
-        self.assertTrue("fbz" not in self.scheduler.schedule)
+        self.assertNotIn("fbz", self.scheduler.schedule)
 
     def test_schedule_registry(self):
         self.registry.register(AdditionalTask)
         self.scheduler.schedule_registry()
-        self.assertTrue(AdditionalTask.name in self.scheduler.schedule)
+        self.assertIn(AdditionalTask.name, self.scheduler.schedule)
 
     def test_apply_async(self):
         due_task = self.registry[DuePeriodicTask.name]
@@ -178,13 +178,13 @@ class TestClockService(unittest.TestCase):
         sh = MockShelve()
         s.open_schedule = lambda *a, **kw: sh
 
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.schedule, dict))
-        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
-        self.assertTrue(isinstance(s.scheduler, beat.Scheduler))
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.schedule, dict)
+        self.assertIsInstance(s.scheduler, beat.Scheduler)
+        self.assertIsInstance(s.scheduler, beat.Scheduler)
 
-        self.assertTrue(s.schedule is sh)
-        self.assertTrue(s._schedule is sh)
+        self.assertIs(s.schedule, sh)
+        self.assertIs(s._schedule, sh)
 
         s._in_sync = False
         s.sync()
@@ -204,8 +204,8 @@ class TestEmbeddedClockService(unittest.TestCase):
     def test_start_stop_process(self):
         s = beat.EmbeddedClockService()
         from multiprocessing import Process
-        self.assertTrue(isinstance(s, Process))
-        self.assertTrue(isinstance(s.clockservice, beat.ClockService))
+        self.assertIsInstance(s, Process)
+        self.assertIsInstance(s.clockservice, beat.ClockService)
         s.clockservice = MockClockService()
 
         class _Popen(object):
@@ -225,8 +225,8 @@ class TestEmbeddedClockService(unittest.TestCase):
     def test_start_stop_threaded(self):
         s = beat.EmbeddedClockService(thread=True)
         from threading import Thread
-        self.assertTrue(isinstance(s, Thread))
-        self.assertTrue(isinstance(s.clockservice, beat.ClockService))
+        self.assertIsInstance(s, Thread)
+        self.assertIsInstance(s.clockservice, beat.ClockService)
         s.clockservice = MockClockService()
 
         s.run()

+ 1 - 1
celery/tests/test_bin_celeryd.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery.bin import celeryd
 

+ 8 - 8
celery/tests/test_buckets.py

@@ -3,7 +3,7 @@ import os
 import sys
 sys.path.insert(0, os.getcwd())
 import time
-import unittest
+import unittest2 as unittest
 from itertools import chain, izip
 
 from billiard.utils.functional import curry
@@ -66,7 +66,7 @@ class TestTokenBucketQueue(unittest.TestCase):
         for i in xrange(20):
             sys.stderr.write("x")
             x.wait()
-        self.assertTrue(time.time() - time_start > 1.5)
+        self.assertGreater(time.time() - time_start, 1.5)
 
     @skip_if_disabled
     def test_can_consume(self):
@@ -90,7 +90,7 @@ class TestTokenBucketQueue(unittest.TestCase):
         x = buckets.TokenBucketQueue(fill_rate=1)
         x.put("The quick brown fox")
         self.assertEqual(x.qsize(), 1)
-        self.assertTrue(x.get_nowait(), "The quick brown fox")
+        self.assertEqual(x.get_nowait(), "The quick brown fox")
 
 
 class TestRateLimitString(unittest.TestCase):
@@ -136,17 +136,17 @@ class TestTaskBuckets(unittest.TestCase):
     def test_auto_add_on_missing(self):
         b = buckets.TaskBucket(task_registry=self.registry)
         for task_cls in self.task_classes:
-            self.assertTrue(task_cls.name in b.buckets.keys())
+            self.assertIn(task_cls.name, b.buckets.keys())
         self.registry.register(TaskD)
         self.assertTrue(b.get_bucket_for_type(TaskD.name))
-        self.assertTrue(TaskD.name in b.buckets.keys())
+        self.assertIn(TaskD.name, b.buckets.keys())
         self.registry.unregister(TaskD)
 
     @skip_if_disabled
     def test_has_rate_limits(self):
         b = buckets.TaskBucket(task_registry=self.registry)
         self.assertEqual(b.buckets[TaskA.name].fill_rate, 10)
-        self.assertTrue(isinstance(b.buckets[TaskB.name], buckets.Queue))
+        self.assertIsInstance(b.buckets[TaskB.name], buckets.Queue)
         self.assertEqual(b.buckets[TaskC.name].fill_rate, 1)
         self.registry.register(TaskD)
         b.init_with_registry()
@@ -183,7 +183,7 @@ class TestTaskBuckets(unittest.TestCase):
         for i, job in enumerate(jobs):
             sys.stderr.write("i")
             self.assertEqual(b.get(), job)
-        self.assertTrue(time.time() - time_start > 1.5)
+        self.assertGreater(time.time() - time_start, 1.5)
 
     @skip_if_disabled
     def test__very_busy_queue_doesnt_block_others(self):
@@ -200,7 +200,7 @@ class TestTaskBuckets(unittest.TestCase):
             if job.task_name == TaskA.name:
                 got_ajobs += 1
 
-        self.assertTrue(got_ajobs > 2)
+        self.assertGreater(got_ajobs, 2)
 
     @skip_if_disabled
     def test_thorough__multiple_types(self):

+ 5 - 3
celery/tests/test_conf.py

@@ -1,7 +1,9 @@
-import unittest
-from celery import conf
+import unittest2 as unittest
+
 from django.conf import settings
 
+from celery import conf
+
 
 SETTING_VARS = (
     ("CELERY_DEFAULT_QUEUE", "DEFAULT_QUEUE"),
@@ -31,4 +33,4 @@ class TestConf(unittest.TestCase):
     def test_configuration_cls(self):
         for setting_name, result_var in SETTING_VARS:
             self.assertDefaultSetting(setting_name, result_var)
-        self.assertTrue(isinstance(conf.CELERYD_LOG_LEVEL, int))
+        self.assertIsInstance(conf.CELERYD_LOG_LEVEL, int)

+ 14 - 14
celery/tests/test_datastructures.py

@@ -1,5 +1,5 @@
 import sys
-import unittest
+import unittest2 as unittest
 from Queue import Queue
 
 from celery.datastructures import PositionQueue, ExceptionInfo, LocalCache
@@ -11,9 +11,9 @@ class TestPositionQueue(unittest.TestCase):
     def test_position_queue_unfilled(self):
         q = PositionQueue(length=10)
         for position in q.data:
-            self.assertTrue(isinstance(position, q.UnfilledPosition))
+            self.assertIsInstance(position, q.UnfilledPosition)
 
-        self.assertEqual(q.filled, [])
+        self.assertListEqual(q.filled, [])
         self.assertEqual(len(q), 0)
         self.assertFalse(q.full())
 
@@ -23,7 +23,7 @@ class TestPositionQueue(unittest.TestCase):
         q[6] = 6
         q[9] = 9
 
-        self.assertEqual(q.filled, [3, 6, 9])
+        self.assertListEqual(q.filled, [3, 6, 9])
         self.assertEqual(len(q), 3)
         self.assertFalse(q.full())
 
@@ -31,7 +31,7 @@ class TestPositionQueue(unittest.TestCase):
         q = PositionQueue(length=10)
         for i in xrange(10):
             q[i] = i
-        self.assertEqual(q.filled, list(xrange(10)))
+        self.assertListEqual(q.filled, list(xrange(10)))
         self.assertEqual(len(q), 10)
         self.assertTrue(q.full())
 
@@ -47,8 +47,8 @@ class TestExceptionInfo(unittest.TestCase):
 
         einfo = ExceptionInfo(exc_info)
         self.assertEqual(str(einfo), einfo.traceback)
-        self.assertTrue(isinstance(einfo.exception, LookupError))
-        self.assertEqual(einfo.exception.args,
+        self.assertIsInstance(einfo.exception, LookupError)
+        self.assertTupleEqual(einfo.exception.args,
                 ("The quick brown fox jumps...", ))
         self.assertTrue(einfo.traceback)
 
@@ -98,7 +98,7 @@ class TestSharedCounter(unittest.TestCase):
         self.assertEqual(int(c), -10)
 
     def test_repr(self):
-        self.assertTrue(repr(SharedCounter(10)).startswith("<SharedCounter:"))
+        self.assertIn("<SharedCounter:", repr(SharedCounter(10)))
 
 
 class TestLimitedSet(unittest.TestCase):
@@ -108,11 +108,11 @@ class TestLimitedSet(unittest.TestCase):
         s.add("foo")
         s.add("bar")
         for n in "foo", "bar":
-            self.assertTrue(n in s)
+            self.assertIn(n, s)
         s.add("baz")
         for n in "bar", "baz":
-            self.assertTrue(n in s)
-        self.assertTrue("foo" not in s)
+            self.assertIn(n, s)
+        self.assertNotIn("foo", s)
 
     def test_iter(self):
         s = LimitedSet(maxlen=2)
@@ -120,13 +120,13 @@ class TestLimitedSet(unittest.TestCase):
         map(s.add, items)
         l = list(iter(items))
         for item in items:
-            self.assertTrue(item in l)
+            self.assertIn(item, l)
 
     def test_repr(self):
         s = LimitedSet(maxlen=2)
         items = "foo", "bar"
         map(s.add, items)
-        self.assertTrue(repr(s).startswith("LimitedSet("))
+        self.assertIn("LimitedSet(", repr(s))
 
 
 class TestLocalCache(unittest.TestCase):
@@ -137,4 +137,4 @@ class TestLocalCache(unittest.TestCase):
         slots = list(range(limit * 2))
         for i in slots:
             x[i] = i
-        self.assertEqual(x.keys(), slots[limit:])
+        self.assertListEqual(x.keys(), slots[limit:])

+ 4 - 2
celery/tests/test_discovery.py

@@ -1,5 +1,7 @@
-import unittest
+import unittest2 as unittest
+
 from django.conf import settings
+
 from celery.loaders.djangoapp import autodiscover
 from celery.task import tasks
 
@@ -9,7 +11,7 @@ class TestDiscovery(unittest.TestCase):
     def assertDiscovery(self):
         apps = autodiscover()
         self.assertTrue(apps)
-        self.assertTrue("c.unittest.SomeAppTask" in tasks)
+        self.assertIn("c.unittest.SomeAppTask", tasks)
         self.assertEqual(tasks["c.unittest.SomeAppTask"].run(), 42)
 
     def test_discovery(self):

+ 1 - 1
celery/tests/test_events.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery import events
 

+ 4 - 4
celery/tests/test_loaders.py

@@ -1,6 +1,6 @@
 import os
 import sys
-import unittest
+import unittest2 as unittest
 
 from billiard.utils.functional import wraps
 
@@ -64,7 +64,7 @@ class TestLoaderBase(unittest.TestCase):
     def test_import_default_modules(self):
         import os
         import sys
-        self.assertEqual(self.loader.import_default_modules(), [os, sys])
+        self.assertSameElements(self.loader.import_default_modules(), [os, sys])
 
 
 class TestDjangoLoader(unittest.TestCase):
@@ -117,11 +117,11 @@ class TestDefaultLoader(unittest.TestCase):
         try:
             l = default.Loader()
             settings = l.read_configuration()
-            self.assertEqual(settings.CELERY_IMPORTS, ("os", "sys"))
+            self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             from django.conf import settings
             settings.configured = False
             settings = l.read_configuration()
-            self.assertEqual(settings.CELERY_IMPORTS, ("os", "sys"))
+            self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
             self.assertTrue(settings.configured)
             l.on_worker_init()
         finally:

+ 12 - 17
celery/tests/test_log.py

@@ -3,7 +3,7 @@ from __future__ import generators
 import os
 import sys
 import logging
-import unittest
+import unittest2 as unittest
 from tempfile import mktemp
 from StringIO import StringIO
 
@@ -54,13 +54,8 @@ class TestLog(unittest.TestCase):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
         logger.handlers = [] # Reset previously set logger.
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
-        self.assertTrue(logger.handlers[0].stream is sys.__stderr__,
+        self.assertIs(logger.handlers[0].stream, sys.__stderr__,
                 "setup_logger logs to stderr without logfile argument.")
-        #self.assertTrue(logger._process_aware,
-        #        "setup_logger() returns process aware logger.")
-        #self.assertDidLogTrue(logger, "Logging something",
-        #        "Logger logs error when loglevel is ERROR",
-        #        loglevel=logging.ERROR)
         self.assertDidLogFalse(logger, "Logging something",
                 "Logger doesn't info when loglevel is ERROR",
                 loglevel=logging.INFO)
@@ -80,7 +75,7 @@ class TestLog(unittest.TestCase):
             stdout, stderr = outs
             l = setup_logger(logfile=stderr, loglevel=logging.INFO)
             l.info("The quick brown fox...")
-            self.assertTrue("The quick brown fox..." in stderr.getvalue())
+            self.assertIn("The quick brown fox...", stderr.getvalue())
 
         context = override_stdouts()
         execute_context(context, with_override_stdouts)
@@ -91,14 +86,14 @@ class TestLog(unittest.TestCase):
         l.handlers = []
         tempfile = mktemp(suffix="unittest", prefix="celery")
         l = setup_logger(logfile=tempfile, loglevel=0)
-        self.assertTrue(isinstance(l.handlers[0], logging.FileHandler))
+        self.assertIsInstance(l.handlers[0], logging.FileHandler)
 
     def test_emergency_error_stderr(self):
         def with_override_stdouts(outs):
             stdout, stderr = outs
             emergency_error(None, "The lazy dog crawls under the fast fox")
-            self.assertTrue("The lazy dog crawls under the fast fox" in
-                                stderr.getvalue())
+            self.assertIn("The lazy dog crawls under the fast fox",
+                          stderr.getvalue())
 
         context = override_stdouts()
         execute_context(context, with_override_stdouts)
@@ -108,7 +103,7 @@ class TestLog(unittest.TestCase):
         emergency_error(tempfile, "Vandelay Industries")
         tempfilefh = open(tempfile, "r")
         try:
-            self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
+            self.assertIn("Vandelay Industries", "".join(tempfilefh))
         finally:
             tempfilefh.close()
             os.unlink(tempfile)
@@ -119,7 +114,7 @@ class TestLog(unittest.TestCase):
             def with_wrap_logger(sio):
                 redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
                 logger.error("foo")
-                self.assertTrue("foo" in sio.getvalue())
+                self.assertIn("foo", sio.getvalue())
 
             context = wrap_logger(logger)
             execute_context(context, with_wrap_logger)
@@ -133,18 +128,18 @@ class TestLog(unittest.TestCase):
             p = LoggingProxy(logger)
             p.close()
             p.write("foo")
-            self.assertTrue("foo" not in sio.getvalue())
+            self.assertNotIn("foo", sio.getvalue())
             p.closed = False
             p.write("foo")
-            self.assertTrue("foo" in sio.getvalue())
+            self.assertIn("foo", sio.getvalue())
             lines = ["baz", "xuzzy"]
             p.writelines(lines)
             for line in lines:
-                self.assertTrue(line in sio.getvalue())
+                self.assertIn(line, sio.getvalue())
             p.flush()
             p.close()
             self.assertFalse(p.isatty())
-            self.assertTrue(p.fileno() is None)
+            self.assertIsNone(p.fileno())
 
         context = wrap_logger(logger)
         execute_context(context, with_wrap_logger)

+ 2 - 1
celery/tests/test_messaging.py

@@ -1,4 +1,5 @@
-import unittest
+import unittest2 as unittest
+
 from celery.messaging import MSG_OPTIONS, extract_msg_options
 
 

+ 17 - 17
celery/tests/test_models.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from datetime import datetime, timedelta
 
 from celery import states
@@ -24,30 +24,30 @@ class TestModels(unittest.TestCase):
         m3 = self.createTaskMeta()
         self.assertTrue(unicode(m1).startswith("<Task:"))
         self.assertTrue(m1.task_id)
-        self.assertTrue(isinstance(m1.date_done, datetime))
+        self.assertIsInstance(m1.date_done, datetime)
 
         self.assertEqual(TaskMeta.objects.get_task(m1.task_id).task_id,
                 m1.task_id)
-        self.assertFalse(
-                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
+        self.assertNotEqual(TaskMeta.objects.get_task(m1.task_id).status,
+                            states.SUCCESS)
         TaskMeta.objects.store_result(m1.task_id, True, status=states.SUCCESS)
         TaskMeta.objects.store_result(m2.task_id, True, status=states.SUCCESS)
-        self.assertTrue(
-                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
-        self.assertTrue(
-                TaskMeta.objects.get_task(m2.task_id).status == states.SUCCESS)
+        self.assertEqual(TaskMeta.objects.get_task(m1.task_id).status,
+                         states.SUCCESS)
+        self.assertEqual(TaskMeta.objects.get_task(m2.task_id).status,
+                         states.SUCCESS)
 
         # Have to avoid save() because it applies the auto_now=True.
         TaskMeta.objects.filter(task_id=m1.task_id).update(
                 date_done=datetime.now() - timedelta(days=10))
 
         expired = TaskMeta.objects.get_all_expired()
-        self.assertTrue(m1 in expired)
-        self.assertFalse(m2 in expired)
-        self.assertFalse(m3 in expired)
+        self.assertIn(m1, expired)
+        self.assertNotIn(m2, expired)
+        self.assertNotIn(m3, expired)
 
         TaskMeta.objects.delete_expired()
-        self.assertFalse(m1 in TaskMeta.objects.all())
+        self.assertNotIn(m1, TaskMeta.objects.all())
 
     def test_tasksetmeta(self):
         m1 = self.createTaskSetMeta()
@@ -55,7 +55,7 @@ class TestModels(unittest.TestCase):
         m3 = self.createTaskSetMeta()
         self.assertTrue(unicode(m1).startswith("<TaskSet:"))
         self.assertTrue(m1.taskset_id)
-        self.assertTrue(isinstance(m1.date_done, datetime))
+        self.assertIsInstance(m1.date_done, datetime)
 
         self.assertEqual(
                 TaskSetMeta.objects.restore_taskset(m1.taskset_id).taskset_id,
@@ -66,9 +66,9 @@ class TestModels(unittest.TestCase):
                 date_done=datetime.now() - timedelta(days=10))
 
         expired = TaskSetMeta.objects.get_all_expired()
-        self.assertTrue(m1 in expired)
-        self.assertFalse(m2 in expired)
-        self.assertFalse(m3 in expired)
+        self.assertIn(m1, expired)
+        self.assertNotIn(m2, expired)
+        self.assertNotIn(m3, expired)
 
         TaskSetMeta.objects.delete_expired()
-        self.assertFalse(m1 in TaskSetMeta.objects.all())
+        self.assertNotIn(m1, TaskSetMeta.objects.all())

+ 12 - 12
celery/tests/test_pickle.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from billiard.serialization import pickle
 
@@ -17,33 +17,33 @@ class ArgOverrideException(Exception):
 class TestPickle(unittest.TestCase):
 
     def test_pickle_regular_exception(self):
-        e = None
+        exc = None
         try:
             raise RegularException("RegularException raised")
-        except RegularException, e:
+        except RegularException, exc:
             pass
 
-        pickled = pickle.dumps({"exception": e})
+        pickled = pickle.dumps({"exception": exc})
         unpickled = pickle.loads(pickled)
         exception = unpickled.get("exception")
         self.assertTrue(exception)
-        self.assertTrue(isinstance(exception, RegularException))
-        self.assertEqual(exception.args, ("RegularException raised", ))
+        self.assertIsInstance(exception, RegularException)
+        self.assertTupleEqual(exception.args, ("RegularException raised", ))
 
     def test_pickle_arg_override_exception(self):
 
-        e = None
+        exc = None
         try:
             raise ArgOverrideException("ArgOverrideException raised",
                     status_code=100)
-        except ArgOverrideException, e:
+        except ArgOverrideException, exc:
             pass
 
-        pickled = pickle.dumps({"exception": e})
+        pickled = pickle.dumps({"exception": exc})
         unpickled = pickle.loads(pickled)
         exception = unpickled.get("exception")
         self.assertTrue(exception)
-        self.assertTrue(isinstance(exception, ArgOverrideException))
-        self.assertEqual(exception.args, ("ArgOverrideException raised",
-                                          100))
+        self.assertIsInstance(exception, ArgOverrideException)
+        self.assertTupleEqual(exception.args, (
+                              "ArgOverrideException raised", 100))
         self.assertEqual(exception.status_code, 100)

+ 20 - 31
celery/tests/test_pool.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 import logging
 import itertools
 import time
@@ -27,15 +27,15 @@ class TestTaskPool(unittest.TestCase):
     def test_attrs(self):
         p = TaskPool(limit=2)
         self.assertEqual(p.limit, 2)
-        self.assertTrue(isinstance(p.logger, logging.Logger))
-        self.assertTrue(p._pool is None)
+        self.assertIsInstance(p.logger, logging.Logger)
+        self.assertIsNone(p._pool)
 
-    def x_start_stop(self):
+    def test_start_stop(self):
         p = TaskPool(limit=2)
         p.start()
-        self.assertTrue(p._pool)
+        self.assertIsNotNone(p._pool)
         p.stop()
-        self.assertTrue(p._pool is None)
+        self.assertIsNone(p._pool)
 
     def x_apply(self):
         p = TaskPool(limit=2)
@@ -43,50 +43,39 @@ class TestTaskPool(unittest.TestCase):
         scratchpad = {}
         proc_counter = itertools.count().next
 
-        def mycallback(ret_value, meta):
+        def mycallback(ret_value):
             process = proc_counter()
             scratchpad[process] = {}
             scratchpad[process]["ret_value"] = ret_value
-            scratchpad[process]["meta"] = meta
 
         myerrback = mycallback
 
-        res = p.apply_async(do_something, args=[10], callbacks=[mycallback],
-                            meta={"foo": "bar"})
-        res2 = p.apply_async(raise_something, args=[10], errbacks=[myerrback],
-                            meta={"foo2": "bar2"})
-        res3 = p.apply_async(do_something, args=[20], callbacks=[mycallback],
-                            meta={"foo3": "bar3"})
+        res = p.apply_async(do_something, args=[10], callbacks=[mycallback])
+        res2 = p.apply_async(raise_something, args=[10], errbacks=[myerrback])
+        res3 = p.apply_async(do_something, args=[20], callbacks=[mycallback])
 
         self.assertEqual(res.get(), 100)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(0))
-        self.assertEqual(scratchpad[0]["ret_value"], 100)
-        self.assertEqual(scratchpad[0]["meta"], {"foo": "bar"})
+        self.assertDictContainsSubset({"ret_value": 100},
+                                       scratchpad.get(0))
 
-        self.assertTrue(isinstance(res2.get(), ExceptionInfo))
+        self.assertIsInstance(res2.get(), ExceptionInfo)
         self.assertTrue(scratchpad.get(1))
         time.sleep(1)
-        #self.assertEqual(scratchpad[1]["ret_value"], "FOO")
-        self.assertTrue(isinstance(scratchpad[1]["ret_value"],
-                          ExceptionInfo))
+        self.assertIsInstance(scratchpad[1]["ret_value"],
+                              ExceptionInfo)
         self.assertEqual(scratchpad[1]["ret_value"].exception.args,
                           ("FOO EXCEPTION", ))
-        self.assertEqual(scratchpad[1]["meta"], {"foo2": "bar2"})
 
         self.assertEqual(res3.get(), 400)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(2))
-        self.assertEqual(scratchpad[2]["ret_value"], 400)
-        self.assertEqual(scratchpad[2]["meta"], {"foo3": "bar3"})
+        self.assertDictContainsSubset({"ret_value": 400},
+                                       scratchpad.get(2))
 
-        res3 = p.apply_async(do_something, args=[30], callbacks=[mycallback],
-                            meta={"foo4": "bar4"})
+        res3 = p.apply_async(do_something, args=[30], callbacks=[mycallback])
 
         self.assertEqual(res3.get(), 900)
         time.sleep(0.5)
-        self.assertTrue(scratchpad.get(3))
-        self.assertEqual(scratchpad[3]["ret_value"], 900)
-        self.assertEqual(scratchpad[3]["meta"], {"foo4": "bar4"})
-
+        self.assertDictContainsSubset({"ret_value": 900},
+                                       scratchpad.get(3))
         p.stop()

+ 16 - 16
celery/tests/test_registry.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery import registry
 from celery.task import Task, PeriodicTask
@@ -24,42 +24,42 @@ class TestTaskRegistry(unittest.TestCase):
     def assertRegisterUnregisterCls(self, r, task):
         self.assertRaises(r.NotRegistered, r.unregister, task)
         r.register(task)
-        self.assertTrue(task.name in r)
+        self.assertIn(task.name, r)
 
     def assertRegisterUnregisterFunc(self, r, task, task_name):
         self.assertRaises(r.NotRegistered, r.unregister, task_name)
         r.register(task, task_name)
-        self.assertTrue(task_name in r)
+        self.assertIn(task_name, r)
 
     def test_task_registry(self):
         r = registry.TaskRegistry()
-        self.assertTrue(isinstance(r.data, dict),
+        self.assertIsInstance(r.data, dict,
                 "TaskRegistry has composited dict")
 
         self.assertRegisterUnregisterCls(r, TestTask)
         self.assertRegisterUnregisterCls(r, TestPeriodicTask)
 
         tasks = dict(r)
-        self.assertTrue(isinstance(tasks.get(TestTask.name), TestTask))
-        self.assertTrue(isinstance(tasks.get(TestPeriodicTask.name),
-                                   TestPeriodicTask))
+        self.assertIsInstance(tasks.get(TestTask.name), TestTask)
+        self.assertIsInstance(tasks.get(TestPeriodicTask.name),
+                                   TestPeriodicTask)
 
         regular = r.regular()
-        self.assertTrue(TestTask.name in regular)
-        self.assertFalse(TestPeriodicTask.name in regular)
+        self.assertIn(TestTask.name, regular)
+        self.assertNotIn(TestPeriodicTask.name, regular)
 
         periodic = r.periodic()
-        self.assertFalse(TestTask.name in periodic)
-        self.assertTrue(TestPeriodicTask.name in periodic)
+        self.assertNotIn(TestTask.name, periodic)
+        self.assertIn(TestPeriodicTask.name, periodic)
 
-        self.assertTrue(isinstance(r[TestTask.name], TestTask))
-        self.assertTrue(isinstance(r[TestPeriodicTask.name],
-                                   TestPeriodicTask))
+        self.assertIsInstance(r[TestTask.name], TestTask)
+        self.assertIsInstance(r[TestPeriodicTask.name],
+                                   TestPeriodicTask)
 
         r.unregister(TestTask)
-        self.assertFalse(TestTask.name in r)
+        self.assertNotIn(TestTask.name, r)
         r.unregister(TestPeriodicTask)
-        self.assertFalse(TestPeriodicTask.name in r)
+        self.assertNotIn(TestPeriodicTask.name, r)
 
         self.assertTrue(TestTask().run())
         self.assertTrue(TestPeriodicTask().run())

+ 6 - 6
celery/tests/test_result.py

@@ -1,6 +1,6 @@
 from __future__ import generators
 
-import unittest
+import unittest2 as unittest
 
 from celery import states
 from celery.utils import gen_unique_id
@@ -90,7 +90,7 @@ class TestAsyncResult(unittest.TestCase):
         self.assertEqual(ok_res.get(), "the")
         self.assertEqual(ok2_res.get(), "quick")
         self.assertRaises(KeyError, nok_res.get)
-        self.assertTrue(isinstance(nok2_res.result, KeyError))
+        self.assertIsInstance(nok2_res.result, KeyError)
 
     def test_get_timeout(self):
         res = AsyncResult(self.task4["id"]) # has RETRY status
@@ -105,7 +105,7 @@ class TestAsyncResult(unittest.TestCase):
         oks = (AsyncResult(self.task1["id"]),
                AsyncResult(self.task2["id"]),
                AsyncResult(self.task3["id"]))
-        [self.assertTrue(ok.ready()) for ok in oks]
+        self.assertTrue(all(result.ready() for result in oks))
         self.assertFalse(AsyncResult(self.task4["id"]).ready())
 
 
@@ -173,11 +173,11 @@ class TestTaskSetResult(unittest.TestCase):
         it = iter(self.ts)
 
         results = sorted(list(it))
-        self.assertEqual(results, list(xrange(self.size)))
+        self.assertListEqual(results, list(xrange(self.size)))
 
     def test_join(self):
         joined = self.ts.join()
-        self.assertEqual(joined, list(xrange(self.size)))
+        self.assertListEqual(joined, list(xrange(self.size)))
 
     def test_successful(self):
         self.assertTrue(self.ts.successful())
@@ -201,7 +201,7 @@ class TestPendingAsyncResult(unittest.TestCase):
         self.task = AsyncResult(gen_unique_id())
 
     def test_result(self):
-        self.assertTrue(self.task.result is None)
+        self.assertIsNone(self.task.result)
 
 
 class TestFailedTaskSetResult(TestTaskSetResult):

+ 2 - 2
celery/tests/test_serialization.py

@@ -1,5 +1,5 @@
 import sys
-import unittest
+import unittest2 as unittest
 
 from celery.tests.utils import execute_context, mask_modules
 
@@ -12,7 +12,7 @@ class TestAAPickle(unittest.TestCase):
             def with_cPickle_masked(_val):
                 from billiard.serialization import pickle
                 import pickle as orig_pickle
-                self.assertTrue(pickle.dumps is orig_pickle.dumps)
+                self.assertIs(pickle.dumps, orig_pickle.dumps)
 
             context = mask_modules("cPickle")
             execute_context(context, with_cPickle_masked)

+ 19 - 19
celery/tests/test_task.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from StringIO import StringIO
 from datetime import datetime, timedelta
 
@@ -220,8 +220,8 @@ class TestCeleryTasks(unittest.TestCase):
         import operator
         conf.ALWAYS_EAGER = True
         res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
-        self.assertTrue(res, sum([operator.add(x, x)
-                                    for x in xrange(10)]))
+        self.assertEqual(sum(res), sum(operator.add(x, x)
+                                        for x in xrange(10)))
         conf.ALWAYS_EAGER = False
 
     def test_dmap_async(self):
@@ -229,8 +229,8 @@ class TestCeleryTasks(unittest.TestCase):
         import operator
         conf.ALWAYS_EAGER = True
         res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
-        self.assertTrue(res.get(), sum([operator.add(x, x)
-                                            for x in xrange(10)]))
+        self.assertEqual(sum(res.get()), sum(operator.add(x, x)
+                                                for x in xrange(10)))
         conf.ALWAYS_EAGER = False
 
     def assertNextTaskDataEquals(self, consumer, presult, task_name,
@@ -241,9 +241,9 @@ class TestCeleryTasks(unittest.TestCase):
         self.assertEqual(task_data["task"], task_name)
         task_kwargs = task_data.get("kwargs", {})
         if test_eta:
-            self.assertTrue(isinstance(task_data.get("eta"), basestring))
+            self.assertIsInstance(task_data.get("eta"), basestring)
             to_datetime = parse_iso8601(task_data.get("eta"))
-            self.assertTrue(isinstance(to_datetime, datetime))
+            self.assertIsInstance(to_datetime, datetime)
         for arg_name, arg_value in kwargs.items():
             self.assertEqual(task_kwargs.get(arg_name), arg_value)
 
@@ -256,7 +256,7 @@ class TestCeleryTasks(unittest.TestCase):
 
     def test_regular_task(self):
         T1 = self.createTaskCls("T1", "c.unittest.t.t1")
-        self.assertTrue(isinstance(T1(), T1))
+        self.assertIsInstance(T1(), T1)
         self.assertTrue(T1().run())
         self.assertTrue(callable(T1()),
                 "Task class is callable()")
@@ -271,7 +271,7 @@ class TestCeleryTasks(unittest.TestCase):
         consumer = t1.get_consumer()
         self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
         consumer.discard_all()
-        self.assertTrue(consumer.fetch() is None)
+        self.assertIsNone(consumer.fetch())
 
         # Without arguments.
         presult = t1.delay()
@@ -303,14 +303,14 @@ class TestCeleryTasks(unittest.TestCase):
         consumer.discard_all()
         task.apply_async(t1)
         self.assertEqual(consumer.discard_all(), 1)
-        self.assertTrue(consumer.fetch() is None)
+        self.assertIsNone(consumer.fetch())
 
         self.assertFalse(presult.successful())
         default_backend.mark_as_done(presult.task_id, result=None)
         self.assertTrue(presult.successful())
 
         publisher = t1.get_publisher()
-        self.assertTrue(isinstance(publisher, messaging.TaskPublisher))
+        self.assertIsInstance(publisher, messaging.TaskPublisher)
 
     def test_get_publisher(self):
         from celery.task import base
@@ -339,7 +339,7 @@ class TestTaskSet(unittest.TestCase):
         ts = task.TaskSet(return_True_task.name, [
             [[1], {}], [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
         res = ts.apply_async()
-        self.assertEqual(res.join(), [True, True, True, True, True])
+        self.assertListEqual(res.join(), [True, True, True, True, True])
 
         conf.ALWAYS_EAGER = False
 
@@ -367,9 +367,9 @@ class TestTaskSet(unittest.TestCase):
         taskset_id = taskset_res.taskset_id
         for subtask in subtasks:
             m = consumer.fetch().payload
-            self.assertEqual(m.get("taskset"), taskset_id)
-            self.assertEqual(m.get("task"), IncrementCounterTask.name)
-            self.assertEqual(m.get("id"), subtask.task_id)
+            self.assertDictContainsSubset({"taskset": taskset_id,
+                                           "task": IncrementCounterTask.name,
+                                           "id": subtask.task_id}, m)
             IncrementCounterTask().run(
                     increment_by=m.get("kwargs", {}).get("increment_by"))
         self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
@@ -381,7 +381,7 @@ class TestTaskApply(unittest.TestCase):
         IncrementCounterTask.count = 0
 
         e = IncrementCounterTask.apply()
-        self.assertTrue(isinstance(e, EagerResult))
+        self.assertIsInstance(e, EagerResult)
         self.assertEqual(e.get(), 1)
 
         e = IncrementCounterTask.apply(args=[1])
@@ -412,9 +412,9 @@ class TestPeriodicTask(unittest.TestCase):
             (task.PeriodicTask, ), {"__module__": __name__})
 
     def test_remaining_estimate(self):
-        self.assertTrue(isinstance(
+        self.assertIsInstance(
             MyPeriodic().remaining_estimate(datetime.now()),
-            timedelta))
+            timedelta)
 
     def test_timedelta_seconds_returns_0_on_negative_time(self):
         delta = timedelta(days=-2)
@@ -432,7 +432,7 @@ class TestPeriodicTask(unittest.TestCase):
     def test_is_due_not_due(self):
         due, remaining = MyPeriodic().is_due(datetime.now())
         self.assertFalse(due)
-        self.assertTrue(remaining > 60)
+        self.assertGreater(remaining, 60)
 
     def test_is_due(self):
         p = MyPeriodic()

+ 1 - 1
celery/tests/test_task_builtins.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from billiard.serialization import pickle
 

+ 6 - 6
celery/tests/test_task_control.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery.task import control
 from celery.task.builtins import PingTask
@@ -36,23 +36,23 @@ class TestBroadcast(unittest.TestCase):
     @with_mock_broadcast
     def test_broadcast(self):
         control.broadcast("foobarbaz", arguments=[])
-        self.assertTrue("foobarbaz" in MockBroadcastPublisher.sent)
+        self.assertIn("foobarbaz", MockBroadcastPublisher.sent)
 
     @with_mock_broadcast
     def test_rate_limit(self):
         control.rate_limit(PingTask.name, "100/m")
-        self.assertTrue("rate_limit" in MockBroadcastPublisher.sent)
+        self.assertIn("rate_limit", MockBroadcastPublisher.sent)
 
     @with_mock_broadcast
     def test_revoke(self):
         control.revoke("foozbaaz")
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)
 
     @with_mock_broadcast
     def test_revoke_from_result(self):
         from celery.result import AsyncResult
         AsyncResult("foozbazzbar").revoke()
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)
 
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
@@ -60,4 +60,4 @@ class TestBroadcast(unittest.TestCase):
         r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
                                                         for i in range(10)]))
         r.revoke()
-        self.assertTrue("revoke" in MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockBroadcastPublisher.sent)

+ 10 - 10
celery/tests/test_task_http.py

@@ -3,7 +3,7 @@ from __future__ import generators
 
 import sys
 import logging
-import unittest
+import unittest2 as unittest
 from urllib import addinfourl
 try:
     from contextlib import contextmanager
@@ -63,23 +63,23 @@ class TestEncodings(unittest.TestCase):
               "foobar".encode("utf-8"): "xuzzybaz".encode("utf-8")}
 
         for key, value in http.utf8dict(d.items()).items():
-            self.assertTrue(isinstance(key, str))
-            self.assertTrue(isinstance(value, str))
+            self.assertIsInstance(key, str)
+            self.assertIsInstance(value, str)
 
 
 class TestMutableURL(unittest.TestCase):
 
     def test_url_query(self):
         url = http.MutableURL("http://example.com?x=10&y=20&z=Foo")
-        self.assertEqual(url.query.get("x"), "10")
-        self.assertEqual(url.query.get("y"), "20")
-        self.assertEqual(url.query.get("z"), "Foo")
+        self.assertDictContainsSubset({"x": "10",
+                                       "y": "20",
+                                       "z": "Foo"}, url.query)
         url.query["name"] = "George"
         url = http.MutableURL(str(url))
-        self.assertEqual(url.query.get("x"), "10")
-        self.assertEqual(url.query.get("y"), "20")
-        self.assertEqual(url.query.get("z"), "Foo")
-        self.assertEqual(url.query.get("name"), "George")
+        self.assertDictContainsSubset({"x": "10",
+                                       "y": "20",
+                                       "z": "Foo",
+                                       "name": "George"}, url.query)
 
     def test_url_keeps_everything(self):
         url = "https://e.com:808/foo/bar#zeta?x=10&y=20"

+ 6 - 6
celery/tests/test_utils.py

@@ -1,6 +1,6 @@
 import sys
 import socket
-import unittest
+import unittest2 as unittest
 
 from billiard.utils.functional import wraps
 
@@ -15,17 +15,17 @@ class TestChunks(unittest.TestCase):
 
         # n == 2
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]])
 
         # n == 3
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]])
 
         # n == 2 (exact)
         x = utils.chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertEqual(list(x),
+        self.assertListEqual(list(x),
             [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
 
 
@@ -37,10 +37,10 @@ class TestGenUniqueId(unittest.TestCase):
         def with_ctypes_masked(_val):
             from celery.utils import ctypes, gen_unique_id
 
-            self.assertTrue(ctypes is None)
+            self.assertIsNone(ctypes)
             uuid = gen_unique_id()
             self.assertTrue(uuid)
-            self.assertTrue(isinstance(uuid, basestring))
+            self.assertIsInstance(uuid, basestring)
 
         try:
             context = mask_modules("ctypes")

+ 1 - 1
celery/tests/test_utils_info.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery.utils import info
 

+ 17 - 17
celery/tests/test_worker.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 from Queue import Queue, Empty
 from datetime import datetime, timedelta
 from multiprocessing import get_logger
@@ -162,13 +162,13 @@ class TestCarrotListener(unittest.TestCase):
 
         records.clear()
         self.assertEqual(l._detect_wait_method(), l._mainloop)
-        self.assertTrue(records.get("broadcast_callback"))
-        self.assertTrue(records.get("consume_broadcast"))
-        self.assertTrue(records.get("consume_tasks"))
+        for record in ("broadcast_callback", "consume_broadcast",
+                "consume_tasks"):
+            self.assertTrue(records.get(record))
 
         records.clear()
         l.connection.connection = PlaceHolder()
-        self.assertTrue(l._detect_wait_method() is l.task_consumer.iterconsume)
+        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
         self.assertTrue(records.get("consumer_add"))
 
     def test_connection(self):
@@ -176,19 +176,19 @@ class TestCarrotListener(unittest.TestCase):
                            send_events=False)
 
         l.reset_connection()
-        self.assertTrue(isinstance(l.connection, BrokerConnection))
+        self.assertIsInstance(l.connection, BrokerConnection)
 
         l.stop_consumers()
-        self.assertTrue(l.connection is None)
-        self.assertTrue(l.task_consumer is None)
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
 
         l.reset_connection()
-        self.assertTrue(isinstance(l.connection, BrokerConnection))
+        self.assertIsInstance(l.connection, BrokerConnection)
 
         l.stop()
         l.close_connection()
-        self.assertTrue(l.connection is None)
-        self.assertTrue(l.task_consumer is None)
+        self.assertIsNone(l.connection)
+        self.assertIsNone(l.task_consumer)
 
     def test_receive_message_control_command(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
@@ -198,7 +198,7 @@ class TestCarrotListener(unittest.TestCase):
         l.event_dispatcher = MockEventDispatcher()
         l.control_dispatch = MockControlDispatch()
         l.receive_message(m.decode(), m)
-        self.assertTrue("shutdown" in l.control_dispatch.commands)
+        self.assertIn("shutdown", l.control_dispatch.commands)
 
     def test_close_connection(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,
@@ -226,7 +226,7 @@ class TestCarrotListener(unittest.TestCase):
         def with_catch_warnings(log):
             l.receive_message(m.decode(), m)
             self.assertTrue(log)
-            self.assertTrue("unknown message" in log[0].message.args[0])
+            self.assertIn("unknown message", log[0].message.args[0])
 
         context = catch_warnings(record=True)
         execute_context(context, with_catch_warnings)
@@ -242,7 +242,7 @@ class TestCarrotListener(unittest.TestCase):
         l.receive_message(m.decode(), m)
 
         in_bucket = self.ready_queue.get_nowait()
-        self.assertTrue(isinstance(in_bucket, TaskWrapper))
+        self.assertIsInstance(in_bucket, TaskWrapper)
         self.assertEqual(in_bucket.task_name, foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
         self.assertTrue(self.eta_schedule.empty())
@@ -278,7 +278,7 @@ class TestCarrotListener(unittest.TestCase):
         l.event_dispatcher = MockEventDispatcher()
         l.receive_message(c.decode(), c)
         from celery.worker.revoke import revoked
-        self.assertTrue(id in revoked)
+        self.assertIn(id, revoked)
 
         l.receive_message(t.decode(), t)
         self.assertTrue(ready_queue.empty())
@@ -314,7 +314,7 @@ class TestCarrotListener(unittest.TestCase):
         in_hold = self.eta_schedule.queue[0]
         self.assertEqual(len(in_hold), 4)
         eta, priority, task, on_accept = in_hold
-        self.assertTrue(isinstance(task, TaskWrapper))
+        self.assertIsInstance(task, TaskWrapper)
         self.assertTrue(callable(on_accept))
         self.assertEqual(task.task_name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
@@ -329,7 +329,7 @@ class TestWorkController(unittest.TestCase):
 
     def test_attrs(self):
         worker = self.worker
-        self.assertTrue(isinstance(worker.eta_schedule, Scheduler))
+        self.assertIsInstance(worker.eta_schedule, Scheduler)
         self.assertTrue(worker.scheduler)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.listener)

+ 3 - 3
celery/tests/test_worker_control.py

@@ -1,5 +1,5 @@
 import socket
-import unittest
+import unittest2 as unittest
 
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
@@ -49,10 +49,10 @@ class TestControlPanel(unittest.TestCase):
              "destination": hostname,
              "task_id": uuid}
         self.panel.dispatch_from_message(m)
-        self.assertTrue(uuid in revoked)
+        self.assertIn(uuid, revoked)
 
         m = {"command": "revoke",
              "destination": "does.not.exist",
              "task_id": uuid + "xxx"}
         self.panel.dispatch_from_message(m)
-        self.assertTrue(uuid + "xxx" not in revoked)
+        self.assertNotIn(uuid + "xxx", revoked)

+ 2 - 2
celery/tests/test_worker_controllers.py

@@ -1,5 +1,5 @@
 import time
-import unittest
+import unittest2 as unittest
 from Queue import Queue
 
 from celery.utils import gen_unique_id
@@ -90,7 +90,7 @@ class TestMediator(unittest.TestCase):
 
         m.on_iteration()
 
-        self.assertTrue("value" not in got)
+        self.assertNotIn("value", got)
         self.assertTrue(t.acked)
 
 

+ 6 - 6
celery/tests/test_worker_heartbeat.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery.worker.heartbeat import Heart
 
@@ -27,16 +27,16 @@ class TestHeart(unittest.TestCase):
         heart = Heart(eventer, interval=1)
         heart._shutdown.set()
         heart.run()
-        self.assertTrue(heart._state == "RUN")
-        self.assertTrue("worker-online" in eventer.sent)
-        self.assertTrue("worker-heartbeat" in eventer.sent)
-        self.assertTrue("worker-offline" in eventer.sent)
+        self.assertEqual(heart._state, "RUN")
+        self.assertIn("worker-online", eventer.sent)
+        self.assertIn("worker-heartbeat", eventer.sent)
+        self.assertIn("worker-offline", eventer.sent)
 
         self.assertTrue(heart._stopped.isSet())
 
         heart.stop()
         heart.stop()
-        self.assertTrue(heart._state == "CLOSE")
+        self.assertEqual(heart._state, "CLOSE")
 
         heart = Heart(eventer, interval=0.00001)
         heart._shutdown.set()

+ 17 - 17
celery/tests/test_worker_job.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import sys
 import logging
-import unittest
+import unittest2 as unittest
 import simplejson
 from StringIO import StringIO
 
@@ -90,14 +90,14 @@ class TestJail(unittest.TestCase):
     def test_execute_jail_failure(self):
         ret = jail(gen_unique_id(), mytask_raising.name,
                    [4], {})
-        self.assertTrue(isinstance(ret, ExceptionInfo))
-        self.assertEqual(ret.exception.args, (4, ))
+        self.assertIsInstance(ret, ExceptionInfo)
+        self.assertTupleEqual(ret.exception.args, (4, ))
 
     def test_execute_ignore_result(self):
         task_id = gen_unique_id()
         ret = jail(id, MyTaskIgnoreResult.name,
                    [4], {})
-        self.assertTrue(ret, 8)
+        self.assertEquals(ret, 256)
         self.assertFalse(AsyncResult(task_id).ready())
 
     def test_django_db_connection_is_closed(self):
@@ -179,7 +179,7 @@ class TestTaskWrapper(unittest.TestCase):
         tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.eventer = MockEventDispatcher()
         tw.send_event("task-frobulated")
-        self.assertTrue("task-frobulated" in tw.eventer.sent)
+        self.assertIn("task-frobulated", tw.eventer.sent)
 
     def test_send_email(self):
         from celery import conf
@@ -229,10 +229,10 @@ class TestTaskWrapper(unittest.TestCase):
             def with_catch_warnings(log):
                 res = execute_and_trace(mytask.name, gen_unique_id(),
                                         [4], {})
-                self.assertTrue(isinstance(res, ExceptionInfo))
+                self.assertIsInstance(res, ExceptionInfo)
                 self.assertTrue(log)
-                self.assertTrue("Exception outside" in log[0].message.args[0])
-                self.assertTrue("KeyError" in log[0].message.args[0])
+                self.assertIn("Exception outside", log[0].message.args[0])
+                self.assertIn("KeyError", log[0].message.args[0])
 
             context = catch_warnings(record=True)
             execute_context(context, with_catch_warnings)
@@ -303,13 +303,13 @@ class TestTaskWrapper(unittest.TestCase):
                         content_type="application/json",
                         content_encoding="utf-8")
         tw = TaskWrapper.from_message(m, m.decode())
-        self.assertTrue(isinstance(tw, TaskWrapper))
+        self.assertIsInstance(tw, TaskWrapper)
         self.assertEqual(tw.task_name, body["task"])
         self.assertEqual(tw.task_id, body["id"])
         self.assertEqual(tw.args, body["args"])
         self.assertEqual(tw.kwargs.keys()[0],
                           u"æØåveéðƒeæ".encode("utf-8"))
-        self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
+        self.assertNotIsInstance(tw.kwargs.keys()[0], unicode)
         self.assertTrue(tw.logger)
 
     def test_from_message_nonexistant_task(self):
@@ -359,10 +359,10 @@ class TestTaskWrapper(unittest.TestCase):
     def test_execute_fail(self):
         tid = gen_unique_id()
         tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"})
-        self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
+        self.assertIsInstance(tw.execute(), ExceptionInfo)
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEqual(meta.status, states.FAILURE)
-        self.assertTrue(isinstance(meta.result, KeyError))
+        self.assertIsInstance(meta.result, KeyError)
 
     def test_execute_using_pool(self):
         tid = gen_unique_id()
@@ -370,13 +370,13 @@ class TestTaskWrapper(unittest.TestCase):
         p = TaskPool(2)
         p.start()
         asyncres = tw.execute_using_pool(p)
-        self.assertTrue(asyncres.get(), 256)
+        self.assertEquals(asyncres.get(), 256)
         p.stop()
 
     def test_default_kwargs(self):
         tid = gen_unique_id()
         tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"})
-        self.assertEqual(tw.extend_with_default_kwargs(10, "some_logfile"), {
+        self.assertDictEqual(tw.extend_with_default_kwargs(10, "some_logfile"), {
             "f": "x",
             "logfile": "some_logfile",
             "loglevel": 10,
@@ -403,8 +403,8 @@ class TestTaskWrapper(unittest.TestCase):
 
         tw.on_failure(exc_info)
         logvalue = logfh.getvalue()
-        self.assertTrue(mytask.name in logvalue)
-        self.assertTrue(tid in logvalue)
-        self.assertTrue("ERROR" in logvalue)
+        self.assertIn(mytask.name, logvalue)
+        self.assertIn(tid, logvalue)
+        self.assertIn("ERROR", logvalue)
 
         conf.CELERY_SEND_TASK_ERROR_EMAILS = False

+ 3 - 3
celery/tests/test_worker_revoke.py

@@ -1,4 +1,4 @@
-import unittest
+import unittest2 as unittest
 
 from celery.worker import revoke
 
@@ -7,6 +7,6 @@ class TestRevokeRegistry(unittest.TestCase):
 
     def test_is_working(self):
         revoke.revoked.add("foo")
-        self.assertTrue("foo" in revoke.revoked)
+        self.assertIn("foo", revoke.revoked)
         revoke.revoked.pop_value("foo")
-        self.assertTrue("foo" not in revoke.revoked)
+        self.assertNotIn("foo", revoke.revoked)

+ 3 - 3
celery/tests/test_worker_scheduler.py

@@ -1,5 +1,6 @@
 from __future__ import generators
-import unittest
+
+import unittest2 as unittest
 from Queue import Queue, Empty
 from datetime import datetime, timedelta
 
@@ -50,5 +51,4 @@ class TestScheduler(unittest.TestCase):
     def test_empty_queue_yields_None(self):
         ready_queue = Queue()
         sched = Scheduler(ready_queue)
-
-        self.assertTrue(iter(sched).next() is None)
+        self.assertIsNone(iter(sched).next())

+ 1 - 0
contrib/requirements/test.txt

@@ -7,3 +7,4 @@ coverage>=3.0
 pytyrant
 redis
 pymongo
+git+git://github.com/exogen/nose-achievements.git