Kaynağa Gözat

[tests] Use moar Mock.assert_called

Ask Solem 9 yıl önce
ebeveyn
işleme
26f69df6da
50 değiştirilmiş dosya ile 220 ekleme ve 235 silme
  1. 1 1
      celery/tests/app/test_amqp.py
  2. 5 5
      celery/tests/app/test_app.py
  3. 4 4
      celery/tests/app/test_beat.py
  4. 6 6
      celery/tests/app/test_builtins.py
  5. 3 7
      celery/tests/app/test_loaders.py
  6. 3 3
      celery/tests/app/test_log.py
  7. 5 5
      celery/tests/backends/test_base.py
  8. 4 3
      celery/tests/backends/test_cache.py
  9. 1 1
      celery/tests/backends/test_couchbase.py
  10. 1 1
      celery/tests/backends/test_couchdb.py
  11. 3 3
      celery/tests/backends/test_mongodb.py
  12. 1 1
      celery/tests/backends/test_redis.py
  13. 1 1
      celery/tests/backends/test_riak.py
  14. 1 1
      celery/tests/bin/test_amqp.py
  15. 2 2
      celery/tests/bin/test_base.py
  16. 3 3
      celery/tests/bin/test_beat.py
  17. 11 11
      celery/tests/bin/test_celery.py
  18. 2 2
      celery/tests/bin/test_celeryd_detach.py
  19. 2 2
      celery/tests/bin/test_celeryevdump.py
  20. 2 2
      celery/tests/bin/test_events.py
  21. 5 5
      celery/tests/bin/test_multi.py
  22. 6 6
      celery/tests/bin/test_worker.py
  23. 2 2
      celery/tests/concurrency/test_concurrency.py
  24. 1 1
      celery/tests/concurrency/test_eventlet.py
  25. 2 2
      celery/tests/concurrency/test_gevent.py
  26. 8 11
      celery/tests/concurrency/test_prefork.py
  27. 19 19
      celery/tests/contrib/test_migrate.py
  28. 2 2
      celery/tests/contrib/test_rdb.py
  29. 1 1
      celery/tests/events/test_events.py
  30. 8 19
      celery/tests/events/test_snapshot.py
  31. 1 1
      celery/tests/events/test_state.py
  32. 11 11
      celery/tests/fixups/test_django.py
  33. 10 8
      celery/tests/tasks/test_canvas.py
  34. 7 7
      celery/tests/tasks/test_chord.py
  35. 6 6
      celery/tests/tasks/test_result.py
  36. 6 6
      celery/tests/tasks/test_trace.py
  37. 2 2
      celery/tests/utils/test_imports.py
  38. 1 1
      celery/tests/utils/test_mail.py
  39. 6 6
      celery/tests/utils/test_platforms.py
  40. 1 1
      celery/tests/utils/test_timer2.py
  41. 1 1
      celery/tests/utils/test_utils.py
  42. 2 2
      celery/tests/worker/test_autoreload.py
  43. 2 2
      celery/tests/worker/test_autoscale.py
  44. 6 6
      celery/tests/worker/test_consumer.py
  45. 13 13
      celery/tests/worker/test_control.py
  46. 12 12
      celery/tests/worker/test_loops.py
  47. 10 10
      celery/tests/worker/test_request.py
  48. 1 1
      celery/tests/worker/test_strategy.py
  49. 6 6
      celery/tests/worker/test_worker.py
  50. 1 1
      requirements/test.txt

+ 1 - 1
celery/tests/app/test_amqp.py

@@ -231,7 +231,7 @@ class test_AMQP(AppCase):
             exchange='xyz', routing_key='xyb',
             event_dispatcher=evd,
         )
-        self.assertTrue(evd.publish.called)
+        evd.publish.assert_called()
         event = evd.publish.call_args[0][1]
         self.assertEqual(event['routing_key'], 'xyb')
         self.assertEqual(event['exchange'], 'xyz')

+ 5 - 5
celery/tests/app/test_app.py

@@ -214,7 +214,7 @@ class test_App(AppCase):
             def lazy_list():
                 return [1, 2, 3]
             self.app.autodiscover_tasks(lazy_list)
-            self.assertTrue(import_modules.connect.called)
+            import_modules.connect.assert_called()
             prom = import_modules.connect.call_args[0][0]
             self.assertIsInstance(prom, promise)
             self.assertEqual(prom.fun, self.app._autodiscover_tasks)
@@ -378,7 +378,7 @@ class test_App(AppCase):
             @self.app.task(shared=False)
             def foo():
                 pass
-            self.assertFalse(sh.called)
+            sh.assert_not_called()
 
     def test_task_compat_with_filter(self):
         with self.Celery() as app:
@@ -469,7 +469,7 @@ class test_App(AppCase):
                 aawsX.apply_async((4, 5))
                 args = create.call_args[0][2]
                 self.assertEqual(args, ('hello', 4, 5))
-                self.assertTrue(send.called)
+                send.assert_called()
 
     def test_apply_async_adds_children(self):
         from celery._state import _task_stack
@@ -707,7 +707,7 @@ class test_App(AppCase):
     @patch('celery.bin.celery.CeleryCommand.execute_from_commandline')
     def test_start(self, execute):
         self.app.start()
-        self.assertTrue(execute.called)
+        execute.assert_called()
 
     def test_mail_admins(self):
 
@@ -915,7 +915,7 @@ class test_App(AppCase):
         x.should_send = Mock()
         x.should_send.return_value = False
         x.send(Mock(), Mock())
-        self.assertFalse(task.app.mail_admins.called)
+        task.app.mail_admins.assert_not_called()
 
     def test_select_queues(self):
         self.app.amqp = Mock(name='amqp')

+ 4 - 4
celery/tests/app/test_beat.py

@@ -174,7 +174,7 @@ class test_Scheduler(AppCase):
 
         scheduler = mScheduler(app=self.app)
         scheduler.apply_async(scheduler.Entry(task=foo.name, app=self.app))
-        self.assertTrue(foo.apply_async.called)
+        foo.apply_async.assert_called()
 
     def test_should_sync(self):
 
@@ -193,7 +193,7 @@ class test_Scheduler(AppCase):
         s._do_sync = Mock()
         s.should_sync.return_value = False
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
-        self.assertFalse(s._do_sync.called)
+        s._do_sync.assert_not_called()
 
     def test_should_sync_increments_sync_every_counter(self):
         self.app.conf.beat_sync_every = 2
@@ -257,7 +257,7 @@ class test_Scheduler(AppCase):
     def test_ensure_connection_error_handler(self, ensure):
         s = mScheduler(app=self.app)
         self.assertTrue(s._ensure_connected())
-        self.assertTrue(ensure.called)
+        ensure.assert_called()
         callback = ensure.call_args[0][0]
 
         callback(KeyError(), 5)
@@ -295,7 +295,7 @@ class test_Scheduler(AppCase):
         scheduler.add(name='test_due_tick_SchedulingError',
                       schedule=always_due)
         self.assertEqual(scheduler.tick(), 0)
-        self.assertTrue(error.called)
+        error.assert_called()
 
     def test_pending_tick(self):
         scheduler = mScheduler(app=self.app)

+ 6 - 6
celery/tests/app/test_builtins.py

@@ -29,7 +29,7 @@ class test_backend_cleanup(BuiltinsCase):
         self.app.backend.cleanup.__name__ = 'cleanup'
         cleanup_task = builtins.add_backend_cleanup_task(self.app)
         cleanup_task()
-        self.assertTrue(self.app.backend.cleanup.called)
+        self.app.backend.cleanup.assert_called()
 
 
 class test_accumulate(BuiltinsCase):
@@ -84,7 +84,7 @@ class test_chunks(BuiltinsCase):
         self.app.tasks['celery.chunks'](
             chunks_mul, [(2, 2), (4, 4), (8, 8)], 1,
         )
-        self.assertTrue(apply_chunks.called)
+        apply_chunks.assert_called()
 
 
 class test_group(BuiltinsCase):
@@ -101,7 +101,7 @@ class test_group(BuiltinsCase):
     def test_apply_async_eager(self):
         self.task.apply = Mock(name='apply')
         self.task.apply_async((1, 2, 3, 4, 5))
-        self.assertTrue(self.task.apply.called)
+        self.task.apply.assert_called()
 
     def mock_group(self, *tasks):
         g = group(*tasks, app=self.app)
@@ -125,7 +125,7 @@ class test_group(BuiltinsCase):
     def test_task__disable_add_to_parent(self, current_worker_task):
         g, result = self.mock_group(self.add.s(2, 2), self.add.s(4, 4))
         self.task(g.tasks, result, result.id, None, add_to_parent=False)
-        self.assertFalse(current_worker_task.add_trail.called)
+        current_worker_task.add_trail.assert_not_called()
 
 
 class test_chain(BuiltinsCase):
@@ -159,13 +159,13 @@ class test_chord(BuiltinsCase):
         x = chord([self.add.s(i, i) for i in range(10)], body=body)
         x.run = Mock(name='chord.run(x)')
         x.apply_async(group_id='some_group_id')
-        self.assertTrue(x.run.called)
+        x.run.assert_called()
         resbody = x.run.call_args[0][1]
         self.assertEqual(resbody.options['group_id'], 'some_group_id')
         x2 = chord([self.add.s(i, i) for i in range(10)], body=body)
         x2.run = Mock(name='chord.run(x2)')
         x2.apply_async(chord='some_chord_id')
-        self.assertTrue(x2.run.called)
+        x2.run.assert_called()
         resbody = x2.run.call_args[0][1]
         self.assertEqual(resbody.options['chord'], 'some_chord_id')
 

+ 3 - 7
celery/tests/app/test_loaders.py

@@ -94,13 +94,9 @@ class test_LoaderBase(AppCase):
         )
 
     def test_import_from_cwd_custom_imp(self):
-
-        def imp(module, package=None):
-            imp.called = True
-        imp.called = False
-
+        imp = Mock(name='imp')
         self.loader.import_from_cwd('foo', imp=imp)
-        self.assertTrue(imp.called)
+        imp.assert_called()
 
     @patch('celery.utils.mail.Mailer._send')
     def test_mail_admins_errors(self, send):
@@ -257,7 +253,7 @@ class test_autodiscovery(Case):
             base._RACE_PROTECTION = False
         with patch('celery.loaders.base.find_related_module') as frm:
             base.autodiscover_tasks(['foo'])
-            self.assertTrue(frm.called)
+            frm.assert_called()
 
     def test_find_related_module(self):
         with patch('importlib.import_module') as imp:

+ 3 - 3
celery/tests/app/test_log.py

@@ -99,8 +99,8 @@ class test_ColorFormatter(AppCase):
         value = KeyError()
         fe.return_value = value
         self.assertIs(x.formatException(value), value)
-        self.assertTrue(fe.called)
-        self.assertFalse(safe_str.called)
+        fe.assert_called()
+        safe_str.assert_not_called()
 
     @patch('logging.Formatter.formatException')
     @patch('celery.utils.log.safe_str')
@@ -112,7 +112,7 @@ class test_ColorFormatter(AppCase):
         except Exception:
             self.assertTrue(x.formatException(sys.exc_info()))
         if sys.version_info[0] == 2:
-            self.assertTrue(safe_str.called)
+            safe_str.assert_called()
 
     @patch('logging.Formatter.format')
     def test_format_object(self, _format):

+ 5 - 5
celery/tests/backends/test_base.py

@@ -238,7 +238,7 @@ class test_BaseBackend_dict(AppCase):
             raise KeyError('foo')
         except KeyError as exc:
             self.b.fail_from_current_stack('task_id')
-            self.assertTrue(self.b.mark_as_failure.called)
+            self.b.mark_as_failure.assert_called()
             args = self.b.mark_as_failure.call_args[0]
             self.assertEqual(args[0], 'task_id')
             self.assertIs(args[1], exc)
@@ -466,14 +466,14 @@ class test_KeyValueStoreBackend(AppCase):
     def test_chord_part_return_propagate_set(self):
         with self._chord_part_context(self.b) as (task, deps, _):
             self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
-            self.assertFalse(self.b.expire.called)
+            self.b.expire.assert_not_called()
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(propagate=True, timeout=3.0)
 
     def test_chord_part_return_propagate_default(self):
         with self._chord_part_context(self.b) as (task, deps, _):
             self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
-            self.assertFalse(self.b.expire.called)
+            self.b.expire.assert_not_called()
             deps.delete.assert_called_with()
             deps.join_native.assert_called_with(propagate=True, timeout=3.0)
 
@@ -482,7 +482,7 @@ class test_KeyValueStoreBackend(AppCase):
             deps._failed_join_report = lambda: iter([])
             deps.join_native.side_effect = KeyError('foo')
             self.b.on_chord_part_return(task.request, 'SUCCESS', 10)
-            self.assertTrue(self.b.fail_from_current_stack.called)
+            self.b.fail_from_current_stack.assert_called()
             args = self.b.fail_from_current_stack.call_args
             exc = args[1]['exc']
             self.assertIsInstance(exc, ChordError)
@@ -496,7 +496,7 @@ class test_KeyValueStoreBackend(AppCase):
             ])
             deps.join_native.side_effect = KeyError('foo')
             b.on_chord_part_return(task.request, 'SUCCESS', 10)
-            self.assertTrue(b.fail_from_current_stack.called)
+            b.fail_from_current_stack.assert_called()
             args = b.fail_from_current_stack.call_args
             exc = args[1]['exc']
             self.assertIsInstance(exc, ChordError)

+ 4 - 3
celery/tests/backends/test_cache.py

@@ -15,7 +15,7 @@ from celery.exceptions import ImproperlyConfigured
 from celery.five import items, module_name_t, string, text_t
 from celery.utils import uuid
 
-from celery.tests.case import AppCase, Mock, mock, patch
+from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 PY3 = sys.version_info[0] == 3
 
@@ -89,9 +89,9 @@ class test_CacheBackend(AppCase):
         task.request.group = gid
         tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
 
-        self.assertFalse(deps.join_native.called)
+        deps.join_native.assert_not_called()
         tb.on_chord_part_return(task.request, 'SUCCESS', 10)
-        self.assertFalse(deps.join_native.called)
+        deps.join_native.assert_not_called()
 
         tb.on_chord_part_return(task.request, 'SUCCESS', 10)
         deps.join_native.assert_called_with(propagate=True, timeout=3.0)
@@ -135,6 +135,7 @@ class test_CacheBackend(AppCase):
         self.assertEqual(b.as_uri(), backend)
 
     @mock.stdouts
+    @skip.unless_module('memcached', name='python-memcached')
     def test_regression_worker_startup_info(self, stdout, stderr):
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'

+ 1 - 1
celery/tests/backends/test_couchbase.py

@@ -48,7 +48,7 @@ class test_CouchBaseBackend(AppCase):
             connection = self.backend._get_connection()
 
             self.assertEqual(sentinel._connection, connection)
-            self.assertFalse(mock_Connection.called)
+            mock_Connection.assert_not_called()
 
     def test_get(self):
         self.app.conf.couchbase_backend_settings = {}

+ 1 - 1
celery/tests/backends/test_couchdb.py

@@ -36,7 +36,7 @@ class test_CouchBackend(AppCase):
             connection = self.backend._get_connection()
 
             self.assertEqual(sentinel._connection, connection)
-            self.assertFalse(mock_Connection.called)
+            mock_Connection.assert_not_called()
 
     def test_get(self):
         """test_get

+ 3 - 3
celery/tests/backends/test_mongodb.py

@@ -143,7 +143,7 @@ class test_MongoBackend(AppCase):
             connection = self.backend._get_connection()
 
             self.assertEqual(sentinel._connection, connection)
-            self.assertFalse(mock_Connection.called)
+            mock_Connection.assert_not_called()
 
     def test_get_connection_no_connection_host(self):
         with patch('pymongo.MongoClient') as mock_Connection:
@@ -205,7 +205,7 @@ class test_MongoBackend(AppCase):
         database = self.backend.database
 
         self.assertTrue(database is mock_database)
-        self.assertFalse(mock_database.authenticate.called)
+        mock_database.authenticate.assert_not_called()
         self.assertTrue(self.backend.__dict__['database'] is mock_database)
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
@@ -371,7 +371,7 @@ class test_MongoBackend(AppCase):
         self.backend.cleanup()
 
         mock_get_database.assert_called_once_with()
-        self.assertTrue(mock_collection.remove.called)
+        mock_collection.remove.assert_called()
 
     def test_get_database_authfailure(self):
         x = MongoBackend(app=self.app)

+ 1 - 1
celery/tests/backends/test_redis.py

@@ -327,7 +327,7 @@ class test_RedisBackend(AppCase):
     def test_on_chord_part_return__success(self):
         with self.chord_context(2) as (_, request, callback):
             self.b.on_chord_part_return(request, states.SUCCESS, 10)
-            self.assertFalse(callback.delay.called)
+            callback.delay.assert_not_called()
             self.b.on_chord_part_return(request, states.SUCCESS, 20)
             callback.delay.assert_called_with([10, 20])
 

+ 1 - 1
celery/tests/backends/test_riak.py

@@ -45,7 +45,7 @@ class test_RiakBackend(AppCase):
             mocked_is_alive.return_value.value = True
             client = self.backend._get_client()
             self.assertEquals(sentinel._client, client)
-            self.assertFalse(mock_connection.called)
+            mock_connection.assert_not_called()
 
     def test_get(self):
         self.app.conf.couchbase_backend_settings = {}

+ 1 - 1
celery/tests/bin/test_amqp.py

@@ -60,7 +60,7 @@ class test_AMQShell(AppCase):
         self.shell.say = Mock()
         self.assertFalse(self.shell.needs_reconnect)
         self.shell.onecmd('hello')
-        self.assertTrue(self.shell.say.called)
+        self.shell.say.assert_called()
         self.assertTrue(self.shell.needs_reconnect)
 
     def test_exit(self):

+ 2 - 2
celery/tests/bin/test_base.py

@@ -54,7 +54,7 @@ class test_Extensions(AppCase):
                 symbyname.side_effect = SyntaxError()
                 with patch('warnings.warn') as warn:
                     e.load()
-                    self.assertTrue(warn.called)
+                    warn.assert_called()
 
             with patch('celery.bin.base.symbol_by_name') as symbyname:
                 symbyname.side_effect = KeyError('foo')
@@ -219,7 +219,7 @@ class test_Command(AppCase):
         cmd.respects_app_option = False
         with patch('celery.bin.base.Celery') as cp:
             cmd.setup_app_from_commandline(['--app=x.y:z'])
-            self.assertTrue(cp.called)
+            cp.assert_called()
 
     def test_setup_app_custom_app(self):
         cmd = MockCommand(app=self.app)

+ 3 - 3
celery/tests/bin/test_beat.py

@@ -70,7 +70,7 @@ class test_Beat(AppCase):
         b = beatapp.Beat(app=self.app, no_color=True,
                          redirect_stdouts=False)
         b.setup_logging()
-        self.assertTrue(self.app.log.setup.called)
+        self.app.log.setup.assert_called()
         self.assertEqual(self.app.log.setup.call_args[1]['colorize'], False)
 
     def test_init_loader(self):
@@ -137,7 +137,7 @@ class test_Beat(AppCase):
             app=self.app, redirect_stdouts=False, socket_timeout=None,
         )
         b.start_scheduler()
-        self.assertTrue(logger.critical.called)
+        logger.critical.assert_called()
 
     @patch('celery.platforms.create_pidlock')
     @mock.stdouts
@@ -145,7 +145,7 @@ class test_Beat(AppCase):
         b = MockBeat2(app=self.app, pidfile='pidfilelockfilepid',
                       socket_timeout=None, redirect_stdouts=False)
         b.start_scheduler()
-        self.assertTrue(create_pidlock.called)
+        create_pidlock.assert_called()
 
 
 class MockDaemonContext(object):

+ 11 - 11
celery/tests/bin/test_celery.py

@@ -49,7 +49,7 @@ class test__main__(AppCase):
                 prev, sys.argv = sys.argv, ['foo', 'multi']
                 try:
                     __main__.main()
-                    self.assertFalse(mpc.called)
+                    mpc.assert_not_called()
                     main.assert_called_with()
                 finally:
                     sys.argv = prev
@@ -71,7 +71,7 @@ class test_Command(AppCase):
     def test_error(self):
         self.cmd.out = Mock()
         self.cmd.error('FOO')
-        self.assertTrue(self.cmd.out.called)
+        self.cmd.out.assert_called()
 
     def test_out(self):
         f = Mock()
@@ -147,7 +147,7 @@ class test_call(AppCase):
     def test_run(self, send_task):
         a = call(app=self.app, stderr=WhateverIO(), stdout=WhateverIO())
         a.run(self.add.name)
-        self.assertTrue(send_task.called)
+        send_task.assert_called()
 
         a.run(self.add.name,
               args=dumps([4, 4]),
@@ -240,10 +240,10 @@ class test_migrate(AppCase):
         m = migrate(app=self.app, stdout=out, stderr=WhateverIO())
         with self.assertRaises(TypeError):
             m.run()
-        self.assertFalse(migrate_tasks.called)
+        migrate_tasks.assert_not_called()
 
         m.run('memory://foo', 'memory://bar')
-        self.assertTrue(migrate_tasks.called)
+        migrate_tasks.assert_called()
 
         state = Mock()
         state.count = 10
@@ -403,7 +403,7 @@ class test_CeleryCommand(AppCase):
         x = CeleryCommand(app=self.app)
         x.error = Mock()
         x.on_usage_error(x.UsageError('foo'), command=None)
-        self.assertTrue(x.error.called)
+        x.error.assert_called()
         x.on_usage_error(x.UsageError('foo'), command='dummy')
 
     def test_prepare_prog_name(self):
@@ -458,15 +458,15 @@ class test_inspect(AppCase):
         i.out = Mock()
         i.quiet = True
         i.say_chat('<-', 'hello out')
-        self.assertFalse(i.out.called)
+        i.out.assert_not_called()
 
         i.say_chat('->', 'hello in')
-        self.assertTrue(i.out.called)
+        i.out.assert_called()
 
         i.quiet = False
         i.out.reset_mock()
         i.say_chat('<-', 'hello out', 'body')
-        self.assertTrue(i.out.called)
+        i.out.assert_called()
 
     @patch('celery.app.control.Control.inspect')
     def test_run(self, real):
@@ -480,7 +480,7 @@ class test_inspect(AppCase):
             i.run('xyzzybaz')
 
         i.run('ping')
-        self.assertTrue(real.called)
+        real.assert_called()
         i.run('ping', destination='foo,bar')
         self.assertEqual(real.call_args[1]['destination'], ['foo', 'bar'])
         self.assertEqual(real.call_args[1]['timeout'], 0.2)
@@ -491,7 +491,7 @@ class test_inspect(AppCase):
 
         with patch('celery.bin.celery.json.dumps') as dumps:
             i.run('ping', json=True)
-            self.assertTrue(dumps.called)
+            dumps.assert_called()
 
         instance = real.return_value = Mock()
         instance.ping.return_value = None

+ 2 - 2
celery/tests/bin/test_celeryd_detach.py

@@ -41,7 +41,7 @@ if not IS_WINDOWS:
                 logfile='/var/log', pidfile='/var/pid',
                 hostname='foo@example.com', app=self.app)
             context.__enter__.assert_called_with()
-            self.assertTrue(logger.critical.called)
+            logger.critical.assert_called()
             setup_logs.assert_called_with(
                 'ERROR', '/var/log', hostname='foo@example.com')
             self.assertEqual(r, 1)
@@ -109,7 +109,7 @@ class test_Command(AppCase):
     def test_execute_from_commandline(self, detach, exit):
         x = detached_celeryd(app=self.app)
         x.execute_from_commandline(self.argv)
-        self.assertTrue(exit.called)
+        exit.assert_called()
         detach.assert_called_with(
             path=x.execv_path, uid=None, gid=None,
             umask=None, fake=False, logfile='/var/log', pidfile='celeryd.pid',

+ 2 - 2
celery/tests/bin/test_celeryevdump.py

@@ -63,7 +63,7 @@ class test_Dumper(AppCase):
             conn.channel_errors = ()
 
             evdump(app)
-            self.assertTrue(conn.ensure_connection.called)
+            conn.ensure_connection.assert_called()
             errback = conn.ensure_connection.call_args[0][0]
             errback(KeyError(), 1)
-            self.assertTrue(conn.as_uri.called)
+            conn.as_uri.assert_called()

+ 2 - 2
celery/tests/bin/test_events.py

@@ -56,8 +56,8 @@ class test_events(AppCase):
     def test_run_cam_detached(self, detached, evcam):
         self.ev.prog_name = 'celery events'
         self.ev.run_evcam('myapp.Camera', detach=True)
-        self.assertTrue(detached.called)
-        self.assertTrue(evcam.called)
+        detached.assert_called()
+        evcam.assert_called()
 
     def test_get_options(self):
         self.assertFalse(self.ev.get_options())

+ 5 - 5
celery/tests/bin/test_multi.py

@@ -190,7 +190,7 @@ class test_MultiTool(AppCase):
 
         self.t.carp = Mock()
         self.assertEqual(self.t.error(), 1)
-        self.assertFalse(self.t.carp.called)
+        self.t.carp.assert_not_called()
 
         self.assertEqual(self.t.retcode, 1)
 
@@ -240,7 +240,7 @@ class test_MultiTool(AppCase):
         stop = self.t._stop_nodes = Mock()
         self.t.restart(['jerry', 'george'], 'celery worker')
         waitexec = self.t.waitexec = Mock()
-        self.assertTrue(stop.called)
+        stop.assert_called()
         callback = stop.call_args[1]['callback']
         self.assertTrue(callback)
 
@@ -322,7 +322,7 @@ class test_MultiTool(AppCase):
                     '-n bar@e.com', '')),
         )
         self.assertEqual(node_1[2], 11)
-        self.assertTrue(callback.called)
+        callback.assert_called()
         cargs, _ = callback.call_args
         self.assertEqual(cargs[0], 'baz@e.com')
         self.assertItemsEqual(
@@ -359,7 +359,7 @@ class test_MultiTool(AppCase):
             {tup[0] for tup in sigs},
         )
         self.t.signal_node.return_value = False
-        self.assertTrue(callback.called)
+        callback.assert_called()
         self.t.stop(['foo', 'bar', 'baz'], 'celery worker', callback=None)
 
         def on_node_alive(pid):
@@ -433,7 +433,7 @@ class test_MultiTool(AppCase):
         start = self.t.commands['start'] = Mock()
         self.t.error = Mock()
         self.t.execute_from_commandline(['multi', 'start', 'foo', 'bar'])
-        self.assertFalse(self.t.error.called)
+        self.t.error.assert_not_called()
         start.assert_called_with(['foo', 'bar'], 'celery worker')
 
         self.t.error = Mock()

+ 6 - 6
celery/tests/bin/test_worker.py

@@ -73,16 +73,16 @@ class test_Worker(WorkerAppCase):
             pass
         x.run = run
         x.run_from_argv('celery', [])
-        self.assertTrue(x.maybe_detach.called)
+        x.maybe_detach.assert_called()
 
     def test_maybe_detach(self):
         x = worker(app=self.app)
         with patch('celery.bin.worker.detached_celeryd') as detached:
             x.maybe_detach([])
-            self.assertFalse(detached.called)
+            detached.assert_not_called()
             with self.assertRaises(SystemExit):
                 x.maybe_detach(['--detach'])
-            self.assertTrue(detached.called)
+            detached.assert_called()
 
     @mock.stdouts
     def test_invalid_loglevel_gives_error(self, stdout, stderr):
@@ -283,7 +283,7 @@ class test_Worker(WorkerAppCase):
         worker = self.Worker(app=self.app, redirect_stoutds=True)
         worker._custom_logging = True
         worker.on_start()
-        self.assertFalse(self.app.log.redirect_stdouts.called)
+        self.app.log.redirect_stdouts.assert_not_called()
 
     def test_setup_logging_no_color(self):
         worker = self.Worker(
@@ -581,7 +581,7 @@ class test_signal_handlers(WorkerAppCase):
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
-        self.assertTrue(stderr.write.called)
+        stderr.write.assert_called()
 
     @skip.unless_module('multiprocessing')
     def test_worker_term_handler_only_stop_MainProcess(self):
@@ -620,7 +620,7 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers['SIGHUP']('SIGHUP', object())
             self.assertEqual(state.should_stop, EX_OK)
-            self.assertTrue(register.called)
+            register.assert_called()
             callback = register.call_args[0][0]
             callback()
             self.assertTrue(argv)

+ 2 - 2
celery/tests/concurrency/test_concurrency.py

@@ -77,7 +77,7 @@ class test_BasePool(AppCase):
         callback = Mock(name='callback')
         target.side_effect = BaseException()
         apply_target(target, callback=callback)
-        self.assertTrue(callback.called)
+        callback.assert_called()
 
     @patch('celery.concurrency.base.reraise')
     def test_apply_target__raises_BaseException_raises_else(self, reraise):
@@ -87,7 +87,7 @@ class test_BasePool(AppCase):
         target.side_effect = BaseException()
         with self.assertRaises(KeyError):
             apply_target(target, callback=callback)
-        self.assertFalse(callback.called)
+        callback.assert_not_called()
 
     def test_does_not_debug(self):
         x = BasePool(10)

+ 1 - 1
celery/tests/concurrency/test_eventlet.py

@@ -111,7 +111,7 @@ class test_TaskPool(EventletCase):
     @patch('celery.concurrency.eventlet.base')
     def test_apply_target(self, base):
         apply_target(Mock(), getpid=Mock())
-        self.assertTrue(base.apply_target.called)
+        base.apply_target.assert_called()
 
     def test_grow(self):
         x = TaskPool(10)

+ 2 - 2
celery/tests/concurrency/test_gevent.py

@@ -32,7 +32,7 @@ class test_gevent_patch(GeventCase):
             gevent.version_info = (1, 0, 0)
             from celery import maybe_patch_concurrency
             maybe_patch_concurrency(['x', '-P', 'gevent'])
-            self.assertTrue(patch_all.called)
+            patch_all.assert_called()
 
 
 class test_Timer(GeventCase):
@@ -118,7 +118,7 @@ class test_apply_timeout(AppCase):
                 apply_target=apply_target, Timeout=Timeout,
             )
             self.assertEqual(Timeout.value, 10)
-            self.assertTrue(apply_target.called)
+            apply_target.assert_called()
 
             apply_target.side_effect = Timeout(10)
             apply_timeout(

+ 8 - 11
celery/tests/concurrency/test_prefork.py

@@ -64,10 +64,7 @@ class test_process_initializer(AppCase):
             from celery.concurrency.prefork import (
                 process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
             )
-
-            def on_worker_process_init(**kwargs):
-                on_worker_process_init.called = True
-            on_worker_process_init.called = False
+            on_worker_process_init = Mock()
             signals.worker_process_init.connect(on_worker_process_init)
 
             def Loader(*args, **kwargs):
@@ -82,7 +79,7 @@ class test_process_initializer(AppCase):
                 _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                 _signals.reset.assert_any_call(*WORKER_SIGRESET)
                 self.assertTrue(app.loader.init_worker.call_count)
-                self.assertTrue(on_worker_process_init.called)
+                on_worker_process_init.assert_called()
                 self.assertIs(_tls.current_app, app)
                 set_mp_process_title.assert_called_with(
                     'celeryd', hostname='awesome.worker.com',
@@ -233,7 +230,7 @@ class test_AsynPool(PoolCase):
             )
             self.assertIn(3, readers)
 
-        with patch('select.poll') as poller:
+        with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll.side_effect = ebadf
             with patch('select.select') as selcheck:
@@ -245,13 +242,13 @@ class test_AsynPool(PoolCase):
                 )
                 self.assertNotIn(3, readers)
 
-        with patch('select.poll') as poller:
+        with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll.side_effect = MemoryError()
             with self.assertRaises(MemoryError):
                 asynpool._select({1}, poll=poll)
 
-        with patch('select.poll') as poller:
+        with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             with patch('select.select') as selcheck:
 
@@ -262,7 +259,7 @@ class test_AsynPool(PoolCase):
                 with self.assertRaises(MemoryError):
                     asynpool._select({3}, poll=poll)
 
-        with patch('select.poll') as poller:
+        with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             with patch('select.select') as selcheck:
 
@@ -274,7 +271,7 @@ class test_AsynPool(PoolCase):
                 with self.assertRaises(socket.error):
                     asynpool._select({3}, poll=poll)
 
-        with patch('select.poll') as poller:
+        with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
 
             poll.side_effect = socket.error()
@@ -373,7 +370,7 @@ class test_TaskPool(PoolCase):
         pool._pool = Mock(name='pool')
         pool._pool._state = mp.CLOSE
         pool.on_close()
-        self.assertFalse(pool._pool.close.called)
+        pool._pool.close.assert_not_called()
 
     def test_apply_async(self):
         pool = TaskPool(10)

+ 19 - 19
celery/tests/contrib/test_migrate.py

@@ -76,7 +76,7 @@ class test_move(AppCase):
                 pred = Mock(name='predicate')
                 move(pred, app=self.app,
                      connection=self.app.connection(), **kwargs)
-                self.assertTrue(start.called)
+                start.assert_called()
                 callback = start.call_args[0][2]
                 yield callback, pred, republish
 
@@ -89,13 +89,13 @@ class test_move(AppCase):
             pred.return_value = None
             body, message = self.msgpair()
             callback(body, message)
-            self.assertFalse(message.ack.called)
-            self.assertFalse(republish.called)
+            message.ack.assert_not_called()
+            republish.assert_not_called()
 
             pred.return_value = 'foo'
             callback(body, message)
             message.ack.assert_called_with()
-            self.assertTrue(republish.called)
+            republish.assert_called()
 
     def test_move_transform(self):
         trans = Mock(name='transform')
@@ -106,8 +106,8 @@ class test_move(AppCase):
             with patch('celery.contrib.migrate.maybe_declare') as maybed:
                 callback(body, message)
                 trans.assert_called_with('foo')
-                self.assertTrue(maybed.called)
-                self.assertTrue(republish.called)
+                maybed.assert_called()
+                republish.assert_called()
 
     def test_limit(self):
         with self.move_context(limit=1) as (callback, pred, republish):
@@ -115,7 +115,7 @@ class test_move(AppCase):
             body, message = self.msgpair()
             with self.assertRaises(StopFiltering):
                 callback(body, message)
-            self.assertTrue(republish.called)
+            republish.assert_called()
 
     def test_callback(self):
         cb = Mock()
@@ -123,8 +123,8 @@ class test_move(AppCase):
             pred.return_value = 'foo'
             body, message = self.msgpair()
             callback(body, message)
-            self.assertTrue(republish.called)
-            self.assertTrue(cb.called)
+            republish.assert_called()
+            cb.assert_called()
 
 
 class test_start_filter(AppCase):
@@ -157,12 +157,12 @@ class test_start_filter(AppCase):
             start_filter(app, conn, filt, tasks='add,mul', callback=cb)
             for callback in consumer.callbacks:
                 callback(body, Message(body))
-            self.assertTrue(cb.called)
+            cb.assert_called()
 
             on_declare_queue = Mock()
             start_filter(app, conn, filt, tasks='add,mul', queues='foo',
                          on_declare_queue=on_declare_queue)
-            self.assertTrue(on_declare_queue.called)
+            on_declare_queue.assert_called()
             start_filter(app, conn, filt, queues=['foo', 'bar'])
             consumer.callbacks[:] = []
             state = State()
@@ -188,7 +188,7 @@ class test_filter_callback(AppCase):
 
         message = Mock()
         filt(t2, message)
-        self.assertFalse(callback.called)
+        callback.assert_not_called()
         filt(t1, message)
         callback.assert_called_with(t1, message)
 
@@ -221,21 +221,21 @@ class test_utils(AppCase):
     def test_move_by_taskmap(self):
         with patch('celery.contrib.migrate.move') as move:
             move_by_taskmap({'add': Queue('foo')})
-            self.assertTrue(move.called)
+            move.assert_called()
             cb = move.call_args[0][0]
             self.assertTrue(cb({'task': 'add'}, Mock()))
 
     def test_move_by_idmap(self):
         with patch('celery.contrib.migrate.move') as move:
             move_by_idmap({'123f': Queue('foo')})
-            self.assertTrue(move.called)
+            move.assert_called()
             cb = move.call_args[0][0]
             self.assertTrue(cb({'id': '123f'}, Mock()))
 
     def test_move_task_by_id(self):
         with patch('celery.contrib.migrate.move') as move:
             move_task_by_id('123f', Queue('foo'))
-            self.assertTrue(move.called)
+            move.assert_called()
             cb = move.call_args[0][0]
             self.assertEqual(
                 cb({'id': '123f'}, Mock()),
@@ -249,7 +249,7 @@ class test_migrate_task(AppCase):
         x = Message('foo', compression='zlib')
         producer = Mock()
         migrate_task(producer, x.body, x)
-        self.assertTrue(producer.publish.called)
+        producer.publish.assert_called()
         args, kwargs = producer.publish.call_args
         self.assertIsInstance(args[0], bytes_t)
         self.assertNotIn('compression', kwargs['headers'])
@@ -289,12 +289,12 @@ class test_migrate_tasks(AppCase):
         callback = Mock()
         migrate_tasks(x, y,
                       callback=callback, accept=['text/plain'], app=self.app)
-        self.assertTrue(callback.called)
+        callback.assert_called()
         migrate = Mock()
         Producer(x).publish('baz', exchange=name, routing_key=name)
         migrate_tasks(x, y, callback=callback,
                       migrate=migrate, accept=['text/plain'], app=self.app)
-        self.assertTrue(migrate.called)
+        migrate.assert_called()
 
         with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
 
@@ -311,4 +311,4 @@ class test_migrate_tasks(AppCase):
         callback = Mock()
         migrate_tasks(x, y,
                       callback=callback, accept=['text/plain'], app=self.app)
-        self.assertFalse(callback.called)
+        callback.assert_not_called()

+ 2 - 2
celery/tests/contrib/test_rdb.py

@@ -29,7 +29,7 @@ class test_Rdb(AppCase):
     def test_set_trace(self, _frame, debugger):
         self.assertTrue(set_trace(Mock()))
         self.assertTrue(set_trace())
-        self.assertTrue(debugger.return_value.set_trace.called)
+        debugger.return_value.set_trace.assert_called()
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
     @skip.if_pypy()
@@ -39,7 +39,7 @@ class test_Rdb(AppCase):
         sock.accept.return_value = (Mock(), ['helu'])
         out = WhateverIO()
         with Rdb(out=out) as rdb:
-            self.assertTrue(get_avail_port.called)
+            get_avail_port.assert_called()
             self.assertIn('helu', out.getvalue())
 
             # set_quit

+ 1 - 1
celery/tests/events/test_events.py

@@ -258,7 +258,7 @@ class test_EventReceiver(AppCase):
             localize=False,
             adjust_timestamp=ts_adjust,
         )
-        self.assertFalse(ts_adjust.called)
+        ts_adjust.assert_not_called()
         r.adjust_clock.assert_called_with(313)
 
     def test_event_from_message_clock_from_client(self):

+ 8 - 19
celery/tests/events/test_snapshot.py

@@ -3,18 +3,7 @@ from __future__ import absolute_import, unicode_literals
 from celery.events import Events
 from celery.events.snapshot import Polaroid, evcam
 
-from celery.tests.case import AppCase, mock, patch
-
-
-class TRef(object):
-    active = True
-    called = False
-
-    def __call__(self):
-        self.called = True
-
-    def cancel(self):
-        self.active = False
+from celery.tests.case import AppCase, Mock, mock, patch
 
 
 class MockTimer(object):
@@ -22,7 +11,7 @@ class MockTimer(object):
 
     def call_repeatedly(self, secs, fun, *args, **kwargs):
         self.installed.append(fun)
-        return TRef()
+        return Mock(name='TRef')
 timer = MockTimer()
 
 
@@ -47,13 +36,13 @@ class test_Polaroid(AppCase):
         x.__enter__()
         self.assertIn(x.capture, MockTimer.installed)
         self.assertIn(x.cleanup, MockTimer.installed)
-        self.assertTrue(x._tref.active)
-        self.assertTrue(x._ctref.active)
+        x._tref.cancel.assert_not_called()
+        x._ctref.cancel.assert_not_called()
         x.__exit__()
-        self.assertFalse(x._tref.active)
-        self.assertFalse(x._ctref.active)
-        self.assertTrue(x._tref.called)
-        self.assertFalse(x._ctref.called)
+        x._tref.cancel.assert_called()
+        x._ctref.cancel.assert_called()
+        x._tref.assert_called()
+        x._ctref.assert_not_called()
 
     def test_cleanup(self):
         x = Polaroid(self.state, app=self.app)

+ 1 - 1
celery/tests/events/test_state.py

@@ -251,7 +251,7 @@ class test_Worker(AppCase):
         worker = Worker(hostname='foo')
         with patch('celery.events.state.warn') as warn:
             worker.event(None, time() + (HEARTBEAT_DRIFT_MAX * 2), time())
-            self.assertTrue(warn.called)
+            warn.assert_called()
             self.assertIn('Substantial drift', warn.call_args[0][0])
 
     def test_updates_heartbeat(self):

+ 11 - 11
celery/tests/fixups/test_django.py

@@ -74,15 +74,15 @@ class test_DjangoFixup(FixupCase):
         with patch('celery.fixups.django.DjangoFixup') as Fixup:
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE=''):
                 fixup(self.app)
-                self.assertFalse(Fixup.called)
+                Fixup.assert_not_called()
             with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
                 with mock.mask_modules('django'):
                     with self.assertWarnsRegex(UserWarning, 'but Django is'):
                         fixup(self.app)
-                    self.assertFalse(Fixup.called)
+                    Fixup.assert_not_called()
                 with mock.module_exists('django'):
                     fixup(self.app)
-                    self.assertTrue(Fixup.called)
+                    Fixup.assert_called()
 
     def test_maybe_close_fd(self):
         with patch('os.close'):
@@ -116,9 +116,9 @@ class test_DjangoFixup(FixupCase):
     def test_now(self):
         with self.fixup_context(self.app) as (f, _, _):
             self.assertTrue(f.now(utc=True))
-            self.assertFalse(f._now.called)
+            f._now.assert_not_called()
             self.assertTrue(f.now(utc=False))
-            self.assertTrue(f._now.called)
+            f._now.assert_called()
 
     def test_mail_admins(self):
         with self.fixup_context(self.app) as (f, _, _):
@@ -204,7 +204,7 @@ class test_DjangoWorkerFixup(FixupCase):
             task.request.is_eager = True
             with patch.object(f, 'close_database'):
                 f.on_task_prerun(task)
-                self.assertFalse(f.close_database.called)
+                f.close_database.assert_not_called()
 
     def test_on_task_postrun(self):
         task = Mock()
@@ -213,16 +213,16 @@ class test_DjangoWorkerFixup(FixupCase):
                 task.request.is_eager = False
                 with patch.object(f, 'close_database'):
                     f.on_task_postrun(task)
-                    self.assertTrue(f.close_database.called)
-                    self.assertTrue(f.close_cache.called)
+                    f.close_database.assert_called()
+                    f.close_cache.assert_called()
 
             # when a task is eager, do not close connections
             with patch.object(f, 'close_cache'):
                 task.request.is_eager = True
                 with patch.object(f, 'close_database'):
                     f.on_task_postrun(task)
-                    self.assertFalse(f.close_database.called)
-                    self.assertFalse(f.close_cache.called)
+                    f.close_database.assert_not_called()
+                    f.close_cache.assert_not_called()
 
     def test_close_database(self):
         with self.fixup_context(self.app) as (f, _, _):
@@ -239,7 +239,7 @@ class test_DjangoWorkerFixup(FixupCase):
                 f.db_reuse_max = 10
                 f._db_recycles = 3
                 f.close_database()
-                self.assertFalse(_close.called)
+                _close.assert_not_called()
                 self.assertEqual(f._db_recycles, 4)
                 _close.reset_mock()
 

+ 10 - 8
celery/tests/tasks/test_canvas.py

@@ -20,11 +20,13 @@ from celery.tests.case import (
     AppCase, ContextMock, MagicMock, Mock, depends_on_current_app,
 )
 
-SIG = Signature({'task': 'TASK',
-                 'args': ('A1',),
-                 'kwargs': {'K1': 'V1'},
-                 'options': {'task_id': 'TASK_ID'},
-                 'subtask_type': ''})
+SIG = Signature({
+    'task': 'TASK',
+    'args': ('A1',),
+    'kwargs': {'K1': 'V1'},
+    'options': {'task_id': 'TASK_ID'},
+    'subtask_type': ''},
+)
 
 
 class test_maybe_unroll_group(AppCase):
@@ -158,7 +160,7 @@ class test_Signature(CanvasCase):
         x.apply_async.return_value.get = Mock()
         x.apply_async.return_value.get.return_value = 4
         self.assertEqual(~x, 4)
-        self.assertTrue(x.apply_async.called)
+        x.apply_async.assert_called()
 
     def test_merge_immutable(self):
         x = self.add.si(2, 2, foo=1)
@@ -180,7 +182,7 @@ class test_Signature(CanvasCase):
         x.freeze('foo')
         x.type.app.control = Mock()
         r = x.election()
-        self.assertTrue(x.type.app.control.election.called)
+        x.type.app.control.election.assert_called()
         self.assertEqual(r.id, 'foo')
 
     def test_AsyncResult_when_not_registered(self):
@@ -444,7 +446,7 @@ class test_chain(CanvasCase):
         self.app.producer_or_acquire = ContextMock()
 
         task.apply_async(**options)
-        self.assertTrue(self.app.amqp.send_task_message.called)
+        self.app.amqp.send_task_message.assert_called()
         message = self.app.amqp.send_task_message.call_args[0][2]
         self.assertEqual(message.headers['parent_id'], pid)
         self.assertEqual(message.headers['root_id'], rid)

+ 7 - 7
celery/tests/tasks/test_chord.py

@@ -98,7 +98,7 @@ class test_unlock_chord_task(ChordCase):
             callback.apply_async.side_effect = IOError()
 
         with self._chord_context(AlwaysReady, setup) as (cb, retry, fail):
-            self.assertTrue(fail.called)
+            fail.assert_called()
             self.assertEqual(
                 fail.call_args[0][0], cb.id,
             )
@@ -113,10 +113,10 @@ class test_unlock_chord_task(ChordCase):
             value = [2, KeyError('foo'), 8, 6]
 
         with self._chord_context(Failed) as (cb, retry, fail_current):
-            self.assertFalse(cb.type.apply_async.called)
+            cb.type.apply_async.assert_not_called()
             # did not retry
             self.assertFalse(retry.call_count)
-            self.assertTrue(fail_current.called)
+            fail_current.assert_called()
             self.assertEqual(
                 fail_current.call_args[0][0], cb.id,
             )
@@ -131,7 +131,7 @@ class test_unlock_chord_task(ChordCase):
             value = [2, KeyError('foo'), 8, 6]
 
         with self._chord_context(Failed) as (cb, retry, fail_current):
-            self.assertTrue(fail_current.called)
+            fail_current.assert_called()
             self.assertEqual(
                 fail_current.call_args[0][0], cb.id,
             )
@@ -182,7 +182,7 @@ class test_unlock_chord_task(ChordCase):
 
         with self._chord_context(NeverReady, interval=10, max_retries=30) \
                 as (cb, retry, _):
-            self.assertFalse(cb.type.apply_async.called)
+            cb.type.apply_async.assert_not_called()
             # did retry
             retry.assert_called_with(countdown=10, max_retries=30)
 
@@ -225,7 +225,7 @@ class test_chord(ChordCase):
             # does not modify original signature
             with self.assertRaises(KeyError):
                 body.options['task_id']
-            self.assertTrue(chord.run.called)
+            chord.run.assert_called()
         finally:
             chord.run = prev
 
@@ -262,7 +262,7 @@ class test_add_to_chord(AppCase):
         self.assertTrue(sig.options['task_id'])
         self.assertEqual(sig.options['group_id'], self.adds.request.group)
         self.assertEqual(sig.options['chord'], self.adds.request.chord)
-        self.assertFalse(sig.delay.called)
+        sig.delay.assert_not_called()
         self.app.backend.add_to_chord.assert_called_with(
             self.adds.request.group, sig.freeze(),
         )

+ 6 - 6
celery/tests/tasks/test_result.py

@@ -105,11 +105,11 @@ class test_AsyncResult(AppCase):
         x.parent = EagerResult(uuid(), KeyError('foo'), states.FAILURE)
         with self.assertRaises(KeyError):
             x.get(propagate=True)
-        self.assertFalse(x.backend.wait_for_pending.called)
+        x.backend.wait_for_pending.assert_not_called()
 
         x.parent = EagerResult(uuid(), 42, states.SUCCESS)
         self.assertEqual(x.get(propagate=True), 84)
-        self.assertTrue(x.backend.wait_for_pending.called)
+        x.backend.wait_for_pending.assert_called()
 
     def test_get_children(self):
         tid = uuid()
@@ -336,10 +336,10 @@ class test_ResultSet(AppCase):
         x.join_native = Mock()
         x.join = Mock()
         x.get()
-        self.assertTrue(x.join.called)
+        x.join.assert_called()
         b.supports_native_join = True
         x.get()
-        self.assertTrue(x.join_native.called)
+        x.join_native.assert_called()
 
     def test_eq_ne(self):
         g1 = self.app.ResultSet([
@@ -369,7 +369,7 @@ class test_ResultSet(AppCase):
         self.assertIsNone(x.supports_native_join)
         x.join = Mock(name='join')
         x.get()
-        self.assertTrue(x.join.called)
+        x.join.assert_called()
 
     def test_add(self):
         x = self.app.ResultSet([self.app.AsyncResult(1)])
@@ -417,7 +417,7 @@ class test_ResultSet(AppCase):
                         ready.return_value = False
                         ready.side_effect = se
                         list(x.iterate())
-                    self.assertFalse(_time.sleep.called)
+                    _time.sleep.assert_not_called()
 
     def test_times_out(self):
         r1 = self.app.AsyncResult(uuid)

+ 6 - 6
celery/tests/tasks/test_trace.py

@@ -68,7 +68,7 @@ class test_trace(TraceCase):
             return x + y
 
         self.trace(add_with_success, (2, 2), {})
-        self.assertTrue(add_with_success.on_success.called)
+        add_with_success.on_success.assert_called()
 
     def test_get_log_policy(self):
         einfo = Mock(name='einfo')
@@ -104,14 +104,14 @@ class test_trace(TraceCase):
             return x + y
 
         self.trace(add_with_after_return, (2, 2), {})
-        self.assertTrue(add_with_after_return.after_return.called)
+        add_with_after_return.after_return.assert_called()
 
     def test_with_prerun_receivers(self):
         on_prerun = Mock()
         signals.task_prerun.connect(on_prerun)
         try:
             self.trace(self.add, (2, 2), {})
-            self.assertTrue(on_prerun.called)
+            on_prerun.assert_called()
         finally:
             signals.task_prerun.receivers[:] = []
 
@@ -120,7 +120,7 @@ class test_trace(TraceCase):
         signals.task_postrun.connect(on_postrun)
         try:
             self.trace(self.add, (2, 2), {})
-            self.assertTrue(on_postrun.called)
+            on_postrun.assert_called()
         finally:
             signals.task_postrun.receivers[:] = []
 
@@ -129,7 +129,7 @@ class test_trace(TraceCase):
         signals.task_success.connect(on_success)
         try:
             self.trace(self.add, (2, 2), {})
-            self.assertTrue(on_success.called)
+            on_success.assert_called()
         finally:
             signals.task_success.receivers[:] = []
 
@@ -142,7 +142,7 @@ class test_trace(TraceCase):
 
         request = {'chord': uuid()}
         self.trace(add, (2, 2), {}, request=request)
-        self.assertTrue(add.backend.mark_as_done.called)
+        add.backend.mark_as_done.assert_called()
         args, kwargs = add.backend.mark_as_done.call_args
         self.assertEqual(args[0], 'id-1')
         self.assertEqual(args[1], 4)

+ 2 - 2
celery/tests/utils/test_imports.py

@@ -33,12 +33,12 @@ class test_import_utils(Case):
     @patch('celery.utils.imports.reload')
     def test_reload_from_cwd(self, reload):
         reload_from_cwd('foo')
-        self.assertTrue(reload.called)
+        reload.assert_called()
 
     def test_reload_from_cwd_custom_reloader(self):
         reload = Mock()
         reload_from_cwd('foo', reload)
-        self.assertTrue(reload.called)
+        reload.assert_called()
 
     def test_module_file(self):
         m1 = Mock()

+ 1 - 1
celery/tests/utils/test_mail.py

@@ -31,7 +31,7 @@ class test_Mailer(Case):
         mailer = Mailer(use_ssl=True, use_tls=True)
         client = SMTP_SSL.return_value = Mock()
         mailer._send(msg)
-        self.assertTrue(client.starttls.called)
+        client.starttls.assert_called()
         self.assertEqual(client.ehlo.call_count, 2)
         client.quit.assert_called_with()
         client.sendmail.assert_called_with(msg.sender, msg.to, str(msg))

+ 6 - 6
celery/tests/utils/test_platforms.py

@@ -147,7 +147,7 @@ class test_Signals(Case):
         if hasattr(signal, 'setitimer'):
             with patch('signal.setitimer', create=True) as seti:
                 signals.arm_alarm(30)
-                self.assertTrue(seti.called)
+                seti.assert_called()
 
     def test_signum(self):
         self.assertEqual(signals.signum(13), 13)
@@ -315,7 +315,7 @@ class test_maybe_drop_privileges(Case):
         maybe_drop_privileges(gid='group')
         parse_gid.assert_called_with('group')
         setgid.assert_called_with(50001)
-        self.assertFalse(setuid.called)
+        setuid.assert_not_called()
 
 
 @skip.if_win32()
@@ -476,11 +476,11 @@ class test_DaemonContext(Case):
                 pass
         self.assertEqual(fork.call_count, 2)
         setsid.assert_called_with()
-        self.assertFalse(_exit.called)
+        _exit.assert_not_called()
 
         chdir.assert_called_with(x.workdir)
         umask.assert_called_with(0o22)
-        self.assertTrue(dup2.called)
+        dup2.assert_called()
 
         fork.reset_mock()
         fork.return_value = 1
@@ -496,7 +496,7 @@ class test_DaemonContext(Case):
         x._detach = Mock()
         with x:
             pass
-        self.assertFalse(x._detach.called)
+        x._detach.assert_not_called()
 
         x.after_chdir = Mock()
         with x:
@@ -700,7 +700,7 @@ class test_Pidfile(Case):
         p.write_pid()
         w.seek(0)
         self.assertEqual(w.readline(), '1816\n')
-        self.assertTrue(w.close.called)
+        w.close.assert_called()
         getpid.assert_called_with()
         osopen.assert_called_with(
             p.path, platforms.PIDFILE_FLAGS, platforms.PIDFILE_MODE,

+ 1 - 1
celery/tests/utils/test_timer2.py

@@ -39,7 +39,7 @@ class test_Timer(Case):
         t.running = True
         t.start = Mock()
         t.ensure_started()
-        self.assertFalse(t.start.called)
+        t.start.assert_not_called()
         t.running = False
         t.on_start = Mock()
         t.ensure_started()

+ 1 - 1
celery/tests/utils/test_utils.py

@@ -179,4 +179,4 @@ class test_utils(Case):
     @patch('warnings.warn')
     def test_warn_deprecated(self, warn):
         warn_deprecated('Foo')
-        self.assertTrue(warn.called)
+        warn.assert_called()

+ 2 - 2
celery/tests/worker/test_autoreload.py

@@ -229,7 +229,7 @@ class test_InotifyMonitor(Case):
         x.process_(Mock())
         x._on_change = Mock()
         x.process_(Mock())
-        self.assertTrue(x._on_change.called)
+        x._on_change.assert_called()
 
         x.create_notifier = Mock()
         x._wm = Mock()
@@ -337,7 +337,7 @@ class test_Autoreloader(AppCase):
         x._reload = Mock()
         x.file_to_module[__name__] = __name__
         x.on_change([__name__])
-        self.assertTrue(x._reload.called)
+        x._reload.assert_called()
         mm.return_value = False
         x.on_change([__name__])
 

+ 2 - 2
celery/tests/worker/test_autoscale.py

@@ -100,14 +100,14 @@ class test_Autoscaler(AppCase):
         x.body()
         x.body()
         self.assertEqual(x.pool.num_processes, 10)
-        self.assertTrue(worker.consumer._update_prefetch_count.called)
+        worker.consumer._update_prefetch_count.assert_called()
         state.reserved_requests.clear()
         x.body()
         self.assertEqual(x.pool.num_processes, 10)
         x._last_scale_up = monotonic() - 10000
         x.body()
         self.assertEqual(x.pool.num_processes, 3)
-        self.assertTrue(worker.consumer._update_prefetch_count.called)
+        worker.consumer._update_prefetch_count.assert_called()
 
     def test_run(self):
 

+ 6 - 6
celery/tests/worker/test_consumer.py

@@ -126,7 +126,7 @@ class test_Consumer(AppCase):
                 priority=c._limit_order,
             )
             bucket.expected_time.assert_called_with(4)
-            self.assertFalse(reserv.called)
+            reserv.assert_not_called()
 
     def test_start_blueprint_raises_EMFILE(self):
         c = self.get_consumer()
@@ -192,7 +192,7 @@ class test_Consumer(AppCase):
         conn = self.app._connection.return_value
         c = self.get_consumer()
         self.assertTrue(c.connect())
-        self.assertTrue(conn.ensure_connection.called)
+        conn.ensure_connection.assert_called()
         errback = conn.ensure_connection.call_args[0][0]
         conn.alt = [(1, 2, 3)]
         errback(Mock(), 0)
@@ -375,7 +375,7 @@ class test_Gossip(AppCase):
         signature.return_value.apply_async.side_effect = MemoryError()
         with patch('celery.worker.consumer.gossip.error') as error:
             g.call_task(task)
-            self.assertTrue(error.called)
+            error.assert_called()
 
     def Event(self, id='id', clock=312,
               hostname='foo@example.com', pid=4312,
@@ -405,7 +405,7 @@ class test_Gossip(AppCase):
         event.pop('clock')
         with patch('celery.worker.consumer.gossip.error') as error:
             g.on_elect(event)
-            self.assertTrue(error.called)
+            error.assert_called()
 
     def Consumer(self, hostname='foo@x.com', pid=4312):
         c = Mock()
@@ -456,7 +456,7 @@ class test_Gossip(AppCase):
         g = Gossip(c)
         handler = g.election_handlers['topic'] = Mock()
         self.setup_election(g, c)
-        self.assertFalse(handler.called)
+        handler.assert_not_called()
 
     def test_on_elect_ack_win_but_no_action(self):
         c = self.Consumer(hostname='foo@x.com')  # I will win
@@ -465,7 +465,7 @@ class test_Gossip(AppCase):
         g.election_handlers = {}
         with patch('celery.worker.consumer.gossip.error') as error:
             self.setup_election(g, c)
-            self.assertTrue(error.called)
+            error.assert_called()
 
     def test_on_node_join(self):
         c = self.Consumer()

+ 13 - 13
celery/tests/worker/test_control.py

@@ -339,9 +339,9 @@ class test_ControlPanel(AppCase):
         panel.state.consumer.controller = Mock()
         sc = panel.state.consumer.controller.autoscaler = Mock()
         panel.handle('pool_grow')
-        self.assertTrue(sc.force_scale_up.called)
+        sc.force_scale_up.assert_called()
         panel.handle('pool_shrink')
-        self.assertTrue(sc.force_scale_down.called)
+        sc.force_scale_down.assert_called()
 
     def test_add__cancel_consumer(self):
 
@@ -565,11 +565,11 @@ class test_ControlPanel(AppCase):
 
         self.app.conf.worker_pool_restarts = True
         panel.handle('pool_restart', {'reloader': _reload})
-        self.assertTrue(consumer.controller.pool.restart.called)
+        consumer.controller.pool.restart.assert_called()
         consumer.reset_rate_limits.assert_called_with()
         consumer.update_strategies.assert_called_with()
-        self.assertFalse(_reload.called)
-        self.assertFalse(_import.called)
+        _reload.assert_not_called()
+        _import.assert_not_called()
         consumer.controller.pool.restart.side_effect = NotImplementedError()
         panel.handle('pool_restart', {'reloader': _reload})
         consumer.controller.consumer = None
@@ -592,10 +592,10 @@ class test_ControlPanel(AppCase):
         panel.handle('pool_restart', {'modules': ['foo', 'bar'],
                                       'reloader': _reload})
 
-        self.assertTrue(consumer.controller.pool.restart.called)
+        consumer.controller.pool.restart.assert_called()
         consumer.reset_rate_limits.assert_called_with()
         consumer.update_strategies.assert_called_with()
-        self.assertFalse(_reload.called)
+        _reload.assert_not_called()
         self.assertItemsEqual(
             [call('bar'), call('foo')],
             _import.call_args_list,
@@ -619,9 +619,9 @@ class test_ControlPanel(AppCase):
                                           'reload': False,
                                           'reloader': _reload})
 
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertFalse(_reload.called)
-            self.assertFalse(_import.called)
+            consumer.controller.pool.restart.assert_called()
+            _reload.assert_not_called()
+            _import.assert_not_called()
 
             _import.reset_mock()
             _reload.reset_mock()
@@ -631,9 +631,9 @@ class test_ControlPanel(AppCase):
                                           'reload': True,
                                           'reloader': _reload})
 
-            self.assertTrue(consumer.controller.pool.restart.called)
-            self.assertTrue(_reload.called)
-            self.assertFalse(_import.called)
+            consumer.controller.pool.restart.assert_called()
+            _reload.assert_called()
+            _import.assert_not_called()
 
     def test_query_task(self):
         consumer = Consumer(self.app)

+ 12 - 12
celery/tests/worker/test_loops.py

@@ -250,7 +250,7 @@ class test_asynloop(AppCase):
         x.hub.on_tick.add(x.closer(mod=2))
         x.hub.timer._queue = [1]
         asynloop(*x.args)
-        self.assertFalse(x.qos.update.called)
+        x.qos.update.assert_not_called()
 
         x = X(self.app)
         x.qos.prev = 1
@@ -282,7 +282,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         reader.assert_called_with(6)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_readable_raises_Empty(self):
         x = X(self.app)
@@ -295,7 +295,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         reader.assert_called_with(6)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_writable(self):
         x = X(self.app)
@@ -307,7 +307,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         writer.assert_called_with(6)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_writable_none_registered(self):
         x = X(self.app)
@@ -318,7 +318,7 @@ class test_asynloop(AppCase):
         poller.poll.return_value = [(7, WRITE)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_unknown_event(self):
         x = X(self.app)
@@ -329,7 +329,7 @@ class test_asynloop(AppCase):
         poller.poll.return_value = [(6, 0)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_keep_draining_disabled(self):
         x = X(self.app)
@@ -344,7 +344,7 @@ class test_asynloop(AppCase):
         poll.return_value = [(6, 0)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_err_writable(self):
         x = X(self.app)
@@ -356,7 +356,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         writer.assert_called_with(6, 48)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_write_generator(self):
         x = X(self.app)
@@ -373,7 +373,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         self.assertTrue(gen.gi_frame.f_lasti != -1)
-        self.assertFalse(x.hub.remove.called)
+        x.hub.remove.assert_not_called()
 
     def test_poll_write_generator_stopped(self):
         x = X(self.app)
@@ -416,7 +416,7 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         reader.assert_called_with(6, 24)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
     def test_poll_raises_ValueError(self):
         x = X(self.app)
@@ -424,7 +424,7 @@ class test_asynloop(AppCase):
         poller = x.hub.poller
         x.close_then_error(poller.poll, exc=ValueError)
         asynloop(*x.args)
-        self.assertTrue(poller.poll.called)
+        poller.poll.assert_called()
 
 
 class test_synloop(AppCase):
@@ -443,7 +443,7 @@ class test_synloop(AppCase):
         x.timeout_then_error(x.connection.drain_events)
         with self.assertRaises(socket.error):
             synloop(*x.args)
-        self.assertFalse(x.qos.update.called)
+        x.qos.update.assert_not_called()
 
         x.qos.value = 4
         x.timeout_then_error(x.connection.drain_events)

+ 10 - 10
celery/tests/worker/test_request.py

@@ -177,7 +177,7 @@ class test_trace_task(RequestCase):
         tid = uuid()
         ret = jail(self.app, tid, self.mytask.name, [2], {})
         self.assertEqual(ret, 4)
-        self.assertTrue(self.mytask.backend.mark_as_done.called)
+        self.mytask.backend.mark_as_done.assert_called()
         self.assertIn('Process cleanup failed', _logger.error.call_args[0][0])
 
     def test_process_cleanup_BaseException(self):
@@ -270,7 +270,7 @@ class test_Request(RequestCase):
         with patch('celery.worker.request.maybe_make_aware') as mma:
             self.get_request(self.add.s(2, 2).set(expires=10),
                              maybe_make_aware=mma)
-            self.assertTrue(mma.called)
+            mma.assert_called()
 
     def test_maybe_expire_when_expires_is_None(self):
         req = self.get_request(self.add.s(2, 2))
@@ -455,7 +455,7 @@ class test_Request(RequestCase):
         job = self.get_request(self.mytask.s(1, f='x'))
         job.time_start = None
         job.terminate(pool, signal='TERM')
-        self.assertFalse(pool.terminate_job.called)
+        pool.terminate_job.assert_not_called()
         self.assertTupleEqual(job._terminate_on_ack, (pool, 15))
         job.terminate(pool, signal='TERM')
 
@@ -587,7 +587,7 @@ class test_Request(RequestCase):
         job.eventer = Mock()
         job.eventer.send = Mock()
         job.on_success((0, 42, 0.001))
-        self.assertTrue(job.eventer.send.called)
+        job.eventer.send.assert_called()
 
     def test_on_success_when_failure(self):
         job = self.xRequest()
@@ -597,7 +597,7 @@ class test_Request(RequestCase):
             raise KeyError('foo')
         except Exception:
             job.on_success((1, ExceptionInfo(), 0.001))
-            self.assertTrue(job.on_failure.called)
+            job.on_failure.assert_called()
 
     def test_on_success_acks_late(self):
         job = self.xRequest()
@@ -673,7 +673,7 @@ class test_Request(RequestCase):
         job.acknowledge = Mock(name='ack')
         job.task.acks_late = False
         job.on_timeout(soft=True, timeout=1335)
-        self.assertFalse(job.acknowledge.called)
+        job.acknowledge.assert_not_called()
 
     def test_fast_trace_task(self):
         from celery.app import trace
@@ -917,7 +917,7 @@ class test_Request(RequestCase):
         except type(exception):
             exc_info = ExceptionInfo()
             job.on_failure(exc_info, **kwargs)
-            self.assertTrue(job.send_event.called)
+            job.send_event.assert_called()
         return job
 
     def test_on_failure(self):
@@ -951,7 +951,7 @@ class test_Request(RequestCase):
         self.assertTrue(job.acknowledged)
         job.on_reject.reset_mock()
         job.reject(requeue=True)
-        self.assertFalse(job.on_reject.called)
+        job.on_reject.assert_not_called()
 
     def test_group(self):
         gid = uuid()
@@ -1013,14 +1013,14 @@ class test_create_request_class(RequestCase):
         job = self.zRequest(id=uuid())
         job.acknowledge = Mock(name='ack')
         job.on_success((False, 'foo', 1.0))
-        self.assertFalse(job.acknowledge.called)
+        job.acknowledge.assert_not_called()
 
     def test_on_success__no_events(self):
         self.eventer = None
         job = self.zRequest(id=uuid())
         job.send_event = Mock(name='send_event')
         job.on_success((False, 'foo', 1.0))
-        self.assertFalse(job.send_event.called)
+        job.send_event.assert_not_called()
 
     def test_on_success__with_events(self):
         job = self.zRequest(id=uuid())

+ 1 - 1
celery/tests/worker/test_strategy.py

@@ -139,7 +139,7 @@ class test_default_strategy_proto2(AppCase):
             logger.isEnabledFor.return_value = False
             with self._context(self.add.s(2, 2)) as C:
                 C()
-                self.assertFalse(logger.info.called)
+                logger.info.assert_not_called()
 
     def test_task_strategy(self):
         with self._context(self.add.s(2, 2)) as C:

+ 6 - 6
celery/tests/worker/test_worker.py

@@ -199,7 +199,7 @@ class test_Consumer(AppCase):
         step = find_step(l, consumer.Connection)
         conn = l.connection = Mock()
         step.shutdown(l)
-        self.assertTrue(conn.close.called)
+        conn.close.assert_called()
         self.assertIsNone(l.connection)
 
         l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
@@ -268,7 +268,7 @@ class test_Consumer(AppCase):
 
         callback = self._get_on_message(l)
         callback(m)
-        self.assertTrue(error.called)
+        error.assert_called()
         self.assertIn('Received invalid task message', error.call_args[0][0])
 
     @patch('celery.worker.consumer.consumer.crit')
@@ -497,7 +497,7 @@ class test_Consumer(AppCase):
         con.node.handle_message.side_effect = ValueError('foo')
         con.on_message('foo', 'bar')
         con.node.handle_message.assert_called_with('foo', 'bar')
-        self.assertTrue(con.reset.called)
+        con.reset.assert_called()
 
     def test_revoke(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
@@ -705,7 +705,7 @@ class test_Consumer(AppCase):
         controller = find_step(l, consumer.Control)
         controller.box.loop(l)
 
-        self.assertTrue(controller.box.node.listen.called)
+        controller.box.node.listen.assert_called()
         self.assertTrue(controller.box.consumer)
         controller.box.consumer.consume.assert_called_with()
 
@@ -883,9 +883,9 @@ class test_WorkController(AppCase):
         worker = self.create_worker(pidfile='pidfilelockfilepid')
         worker.steps = []
         worker.start()
-        self.assertTrue(create_pidlock.called)
+        create_pidlock.assert_called()
         worker.stop()
-        self.assertTrue(worker.pidlock.release.called)
+        worker.pidlock.release.assert_called()
 
     def test_attrs(self):
         worker = self.worker

+ 1 - 1
requirements/test.txt

@@ -1 +1 @@
-case
+case>=1.1