Jelajahi Sumber

Use py.test for everything :-)

Ask Solem 8 tahun lalu
induk
melakukan
29df527147
100 mengubah file dengan 3420 tambahan dan 5078 penghapusan
  1. 4 0
      .gitignore
  2. 12 10
      CONTRIBUTING.rst
  3. 5 4
      Makefile
  4. 2 2
      celery/app/trace.py
  5. 0 96
      celery/tests/__init__.py
  6. 0 301
      celery/tests/app/test_amqp.py
  7. 0 19
      celery/tests/app/test_celery.py
  8. 0 65
      celery/tests/app/test_defaults.py
  9. 0 78
      celery/tests/app/test_registry.py
  10. 0 41
      celery/tests/backends/test_backends.py
  11. 0 349
      celery/tests/case.py
  12. 0 38
      celery/tests/compat_modules/test_decorators.py
  13. 0 129
      celery/tests/concurrency/test_gevent.py
  14. 0 54
      celery/tests/tasks/test_states.py
  15. 0 98
      celery/tests/utils/test_debug.py
  16. 0 20
      celery/tests/utils/test_encoding.py
  17. 0 210
      celery/tests/utils/test_functional.py
  18. 0 58
      celery/tests/utils/test_imports.py
  19. 0 372
      celery/tests/utils/test_local.py
  20. 0 77
      celery/tests/utils/test_serialization.py
  21. 0 27
      celery/tests/utils/test_sysinfo.py
  22. 0 87
      celery/tests/utils/test_term.py
  23. 0 94
      celery/tests/utils/test_text.py
  24. 0 253
      celery/tests/utils/test_timeutils.py
  25. 0 44
      celery/tests/utils/test_utils.py
  26. 7 11
      docs/contributing.rst
  27. 2 2
      docs/internals/guide.rst
  28. 1 1
      docs/whatsnew-4.0.rst
  29. 5 2
      funtests/suite/test_leak.py
  30. 1 0
      requirements/test-ci-base.txt
  31. 2 1
      requirements/test.txt
  32. 3 2
      setup.cfg
  33. 39 35
      setup.py
  34. 0 0
      t/__init__.py
  35. 418 0
      t/conftest.py
  36. 0 0
      t/unit/__init__.py
  37. 0 0
      t/unit/app/__init__.py
  38. 269 0
      t/unit/app/test_amqp.py
  39. 12 14
      t/unit/app/test_annotations.py
  40. 222 246
      t/unit/app/test_app.py
  41. 79 80
      t/unit/app/test_beat.py
  42. 19 17
      t/unit/app/test_builtins.py
  43. 18 0
      t/unit/app/test_celery.py
  44. 50 47
      t/unit/app/test_control.py
  45. 65 0
      t/unit/app/test_defaults.py
  46. 8 10
      t/unit/app/test_exceptions.py
  47. 32 36
      t/unit/app/test_loaders.py
  48. 45 60
      t/unit/app/test_log.py
  49. 72 0
      t/unit/app/test_registry.py
  50. 30 37
      t/unit/app/test_routes.py
  51. 276 328
      t/unit/app/test_schedules.py
  52. 16 20
      t/unit/app/test_utils.py
  53. 0 0
      t/unit/apps/__init__.py
  54. 96 124
      t/unit/apps/test_multi.py
  55. 0 0
      t/unit/backends/__init__.py
  56. 40 47
      t/unit/backends/test_amqp.py
  57. 41 0
      t/unit/backends/test_backends.py
  58. 106 111
      t/unit/backends/test_base.py
  59. 35 37
      t/unit/backends/test_cache.py
  60. 18 15
      t/unit/backends/test_cassandra.py
  61. 6 5
      t/unit/backends/test_consul.py
  62. 24 22
      t/unit/backends/test_couchbase.py
  63. 20 20
      t/unit/backends/test_couchdb.py
  64. 47 49
      t/unit/backends/test_database.py
  65. 16 14
      t/unit/backends/test_elasticsearch.py
  66. 13 16
      t/unit/backends/test_filesystem.py
  67. 80 92
      t/unit/backends/test_mongodb.py
  68. 50 61
      t/unit/backends/test_redis.py
  69. 21 20
      t/unit/backends/test_riak.py
  70. 20 22
      t/unit/backends/test_rpc.py
  71. 0 0
      t/unit/bin/__init__.py
  72. 0 0
      t/unit/bin/celery.py
  73. 0 0
      t/unit/bin/proj/__init__.py
  74. 0 0
      t/unit/bin/proj/app.py
  75. 29 32
      t/unit/bin/test_amqp.py
  76. 130 143
      t/unit/bin/test_base.py
  77. 20 20
      t/unit/bin/test_beat.py
  78. 104 116
      t/unit/bin/test_celery.py
  79. 27 21
      t/unit/bin/test_celeryd_detach.py
  80. 7 7
      t/unit/bin/test_celeryevdump.py
  81. 35 14
      t/unit/bin/test_events.py
  82. 63 84
      t/unit/bin/test_multi.py
  83. 235 239
      t/unit/bin/test_worker.py
  84. 0 0
      t/unit/compat_modules/__init__.py
  85. 16 14
      t/unit/compat_modules/test_compat.py
  86. 13 14
      t/unit/compat_modules/test_compat_utils.py
  87. 37 0
      t/unit/compat_modules/test_decorators.py
  88. 4 3
      t/unit/compat_modules/test_messaging.py
  89. 0 0
      t/unit/concurrency/__init__.py
  90. 32 33
      t/unit/concurrency/test_concurrency.py
  91. 36 38
      t/unit/concurrency/test_eventlet.py
  92. 128 0
      t/unit/concurrency/test_gevent.py
  93. 15 20
      t/unit/concurrency/test_pool.py
  94. 42 52
      t/unit/concurrency/test_prefork.py
  95. 2 3
      t/unit/concurrency/test_solo.py
  96. 0 0
      t/unit/contrib/__init__.py
  97. 6 9
      t/unit/contrib/test_abortable.py
  98. 80 76
      t/unit/contrib/test_migrate.py
  99. 12 10
      t/unit/contrib/test_rdb.py
  100. 0 0
      t/unit/events/__init__.py

+ 4 - 0
.gitignore

@@ -25,3 +25,7 @@ celery/tests/cover/
 .ve*
 cover/
 .vagrant/
+.cache/
+htmlcov/
+coverage.xml
+test.db

+ 12 - 10
CONTRIBUTING.rst

@@ -464,12 +464,12 @@ dependencies, so install these next:
     $ pip install -U -r requirements/default.txt
 
 After installing the dependencies required, you can now execute
-the test suite by calling ``nosetests <nose>``:
+the test suite by calling ``py.test <pytest``:
 ::
 
-    $ nosetests
+    $ py.test
 
-Some useful options to ``nosetests`` are:
+Some useful options to ``py.test`` are:
 
 * ``-x``
 
@@ -479,10 +479,6 @@ Some useful options to ``nosetests`` are:
 
     Don't capture output
 
-* ``-nologcapture``
-
-    Don't capture log output.
-
 * ``-v``
 
     Run with verbose output.
@@ -491,7 +487,7 @@ If you want to run the tests for a single test file only
 you can do so like this:
 ::
 
-    $ nosetests celery.tests.test_worker.test_worker_job
+    $ py.test t/unit/worker/test_worker_job.py
 
 .. _contributing-pull-requests:
 
@@ -525,7 +521,7 @@ Installing the ``coverage`` module:
 Code coverage in HTML:
 ::
 
-    $ nosetests --with-coverage --cover-html
+    $ py.test --cov=celery --cov-report=html
 
 The coverage output will then be located at
 ``celery/tests/cover/index.html``.
@@ -533,7 +529,7 @@ The coverage output will then be located at
 Code coverage in XML (Cobertura-style):
 ::
 
-    $ nosetests --with-coverage --cover-xml --cover-xml-file=coverage.xml
+    $ py.test --cov=celery --cov-report=xml
 
 The coverage XML output will then be located at ``coverage.xml``
 
@@ -857,6 +853,12 @@ Ask Solem
 :github: https://github.com/ask
 :twitter: http://twitter.com/#!/asksol
 
+Asif Saif Uddin
+~~~~~~~~~~~~~~~
+
+:github: https://github.com/auvipy
+:twitter: https://twitter.com/#!/auvipy
+
 Dmitry Malinovsky
 ~~~~~~~~~~~~~~~~~
 

+ 5 - 4
Makefile

@@ -9,6 +9,8 @@ FLAKE8=flake8
 FLAKEPLUS=flakeplus
 SPHINX2RST=sphinx2rst
 
+TESTDIR=t
+
 SPHINX_DIR=docs/
 SPHINX_BUILDDIR="${SPHINX_DIR}/_build"
 README=README.rst
@@ -84,13 +86,13 @@ configcheck:
 
 flakecheck:
 	# the only way to enable all-1 errors is to ignore one of them.
-	$(FLAKE8) --ignore=X999 "$(PROJ)"
+	$(FLAKE8) --ignore=X999 "$(PROJ)" "$(TESTDIR)"
 
 flakediag:
 	-$(MAKE) flakecheck
 
 flakepluscheck:
-	$(FLAKEPLUS) --$(FLAKEPLUSTARGET) "$(PROJ)"
+	$(FLAKEPLUS) --$(FLAKEPLUSTARGET) "$(PROJ)" "$(TESTDIR)"
 
 flakeplusdiag:
 	-$(MAKE) flakepluscheck
@@ -138,7 +140,7 @@ test:
 	$(PYTHON) setup.py test
 
 cov:
-	$(NOSETESTS) -xv --with-coverage --cover-html --cover-branch
+	py.test -x --cov="$(PROJ)" --cov-report=html
 
 build:
 	$(PYTHON) setup.py sdist bdist_wheel
@@ -158,4 +160,3 @@ graph: clean-graph $(WORKER_GRAPH)
 
 authorcheck:
 	git shortlog -se | cut -f2 | extra/release/attribution.py
-

+ 2 - 2
celery/app/trace.py

@@ -418,13 +418,13 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     except EncodeError as exc:
                         I, R, state, retval = on_error(task_request, exc, uuid)
                     else:
+                        Rstr = saferepr(R, resultrepr_maxsize)
+                        T = monotonic() - time_start
                         if task_on_success:
                             task_on_success(retval, uuid, args, kwargs)
                         if success_receivers:
                             send_success(sender=task, result=retval)
                         if _does_info:
-                            T = monotonic() - time_start
-                            Rstr = saferepr(R, resultrepr_maxsize)
                             info(LOG_SUCCESS, {
                                 'id': uuid, 'name': name,
                                 'return_value': Rstr, 'runtime': T,

+ 0 - 96
celery/tests/__init__.py

@@ -1,96 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import logging
-import os
-import sys
-import warnings
-
-from importlib import import_module
-
-PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
-
-try:
-    WindowsError = WindowsError  # noqa
-except NameError:
-
-    class WindowsError(Exception):
-        pass
-
-
-def setup():
-    using_coverage = (
-        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
-    )
-    os.environ.update(
-        # warn if config module not found
-        C_WNOCONF='yes',
-        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
-    )
-
-    if using_coverage and not PYPY3:
-        from warnings import catch_warnings
-        with catch_warnings(record=True):
-            import_all_modules()
-        warnings.resetwarnings()
-    from celery.tests.case import Trap
-    from celery._state import set_default_app
-    set_default_app(Trap())
-
-
-def teardown():
-    # Don't want SUBDEBUG log messages at finalization.
-    try:
-        from multiprocessing.util import get_logger
-    except ImportError:
-        pass
-    else:
-        get_logger().setLevel(logging.WARNING)
-
-    # Make sure test database is removed.
-    import os
-    if os.path.exists('test.db'):
-        try:
-            os.remove('test.db')
-        except WindowsError:
-            pass
-
-    # Make sure there are no remaining threads at shutdown.
-    import threading
-    remaining_threads = [thread for thread in threading.enumerate()
-                         if thread.getName() != 'MainThread']
-    if remaining_threads:
-        sys.stderr.write(
-            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
-                remaining_threads))
-
-
-def find_distribution_modules(name=__name__, file=__file__):
-    current_dist_depth = len(name.split('.')) - 1
-    current_dist = os.path.join(os.path.dirname(file),
-                                *([os.pardir] * current_dist_depth))
-    abs = os.path.abspath(current_dist)
-    dist_name = os.path.basename(abs)
-
-    for dirpath, dirnames, filenames in os.walk(abs):
-        package = (dist_name + dirpath[len(abs):]).replace('/', '.')
-        if '__init__.py' in filenames:
-            yield package
-            for filename in filenames:
-                if filename.endswith('.py') and filename != '__init__.py':
-                    yield '.'.join([package, filename])[:-3]
-
-
-def import_all_modules(name=__name__, file=__file__,
-                       skip=('celery.decorators',
-                             'celery.task')):
-    for module in find_distribution_modules(name, file):
-        if not module.startswith(skip):
-            try:
-                import_module(module)
-            except ImportError:
-                pass
-            except OSError as exc:
-                warnings.warn(UserWarning(
-                    'Ignored error importing module {0}: {1!r}'.format(
-                        module, exc,
-                    )))

+ 0 - 301
celery/tests/app/test_amqp.py

@@ -1,301 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from datetime import datetime, timedelta
-
-from kombu import Exchange, Queue
-
-from celery import uuid
-from celery.app.amqp import Queues, utf8dict
-from celery.five import keys
-from celery.utils.time import to_utc
-
-from celery.tests.case import AppCase, Mock
-
-
-class test_TaskConsumer(AppCase):
-
-    def test_accept_content(self):
-        with self.app.pool.acquire(block=True) as conn:
-            self.app.conf.accept_content = ['application/json']
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn).accept,
-                {'application/json'},
-            )
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn, accept=['json']).accept,
-                {'application/json'},
-            )
-
-
-class test_ProducerPool(AppCase):
-
-    def test_setup_nolimit(self):
-        self.app.conf.broker_pool_limit = None
-        try:
-            delattr(self.app, '_pool')
-        except AttributeError:
-            pass
-        self.app.amqp._producer_pool = None
-        pool = self.app.amqp.producer_pool
-        self.assertEqual(pool.limit, self.app.pool.limit)
-        self.assertFalse(pool._resource.queue)
-
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-        r1.release()
-        r2.release()
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-
-    def test_setup(self):
-        self.app.conf.broker_pool_limit = 2
-        try:
-            delattr(self.app, '_pool')
-        except AttributeError:
-            pass
-        self.app.amqp._producer_pool = None
-        pool = self.app.amqp.producer_pool
-        self.assertEqual(pool.limit, self.app.pool.limit)
-        self.assertTrue(pool._resource.queue)
-
-        p1 = r1 = pool.acquire()
-        p2 = r2 = pool.acquire()
-        r1.release()
-        r2.release()
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-        self.assertIs(p2, r1)
-        self.assertIs(p1, r2)
-        r1.release()
-        r2.release()
-
-
-class test_Queues(AppCase):
-
-    def test_queues_format(self):
-        self.app.amqp.queues._consume_from = {}
-        self.assertEqual(self.app.amqp.queues.format(), '')
-
-    def test_with_defaults(self):
-        self.assertEqual(Queues(None), {})
-
-    def test_add(self):
-        q = Queues()
-        q.add('foo', exchange='ex', routing_key='rk')
-        self.assertIn('foo', q)
-        self.assertIsInstance(q['foo'], Queue)
-        self.assertEqual(q['foo'].routing_key, 'rk')
-
-    def test_with_ha_policy(self):
-        qn = Queues(ha_policy=None, create_missing=False)
-        qn.add('xyz')
-        self.assertIsNone(qn['xyz'].queue_arguments)
-
-        qn.add('xyx', queue_arguments={'x-foo': 'bar'})
-        self.assertEqual(qn['xyx'].queue_arguments, {'x-foo': 'bar'})
-
-        q = Queues(ha_policy='all', create_missing=False)
-        q.add(Queue('foo'))
-        self.assertEqual(q['foo'].queue_arguments, {'x-ha-policy': 'all'})
-
-        qq = Queue('xyx2', queue_arguments={'x-foo': 'bari'})
-        q.add(qq)
-        self.assertEqual(q['xyx2'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-foo': 'bari',
-        })
-
-        q2 = Queues(ha_policy=['A', 'B', 'C'], create_missing=False)
-        q2.add(Queue('foo'))
-        self.assertEqual(q2['foo'].queue_arguments, {
-            'x-ha-policy': 'nodes',
-            'x-ha-policy-params': ['A', 'B', 'C'],
-        })
-
-    def test_select_add(self):
-        q = Queues()
-        q.select(['foo', 'bar'])
-        q.select_add('baz')
-        self.assertItemsEqual(keys(q._consume_from), ['foo', 'bar', 'baz'])
-
-    def test_deselect(self):
-        q = Queues()
-        q.select(['foo', 'bar'])
-        q.deselect('bar')
-        self.assertItemsEqual(keys(q._consume_from), ['foo'])
-
-    def test_with_ha_policy_compat(self):
-        q = Queues(ha_policy='all')
-        q.add('bar')
-        self.assertEqual(q['bar'].queue_arguments, {'x-ha-policy': 'all'})
-
-    def test_add_default_exchange(self):
-        ex = Exchange('fff', 'fanout')
-        q = Queues(default_exchange=ex)
-        q.add(Queue('foo'))
-        self.assertEqual(q['foo'].exchange.name, '')
-
-    def test_alias(self):
-        q = Queues()
-        q.add(Queue('foo', alias='barfoo'))
-        self.assertIs(q['barfoo'], q['foo'])
-
-    def test_with_max_priority(self):
-        qs1 = Queues(max_priority=10)
-        qs1.add('foo')
-        self.assertEqual(qs1['foo'].queue_arguments, {'x-max-priority': 10})
-
-        q1 = Queue('xyx', queue_arguments={'x-max-priority': 3})
-        qs1.add(q1)
-        self.assertEqual(qs1['xyx'].queue_arguments, {
-            'x-max-priority': 3,
-        })
-
-        q1 = Queue('moo', queue_arguments=None)
-        qs1.add(q1)
-        self.assertEqual(qs1['moo'].queue_arguments, {
-            'x-max-priority': 10,
-        })
-
-        qs2 = Queues(ha_policy='all', max_priority=5)
-        qs2.add('bar')
-        self.assertEqual(qs2['bar'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-max-priority': 5
-        })
-
-        q2 = Queue('xyx2', queue_arguments={'x-max-priority': 2})
-        qs2.add(q2)
-        self.assertEqual(qs2['xyx2'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-max-priority': 2,
-        })
-
-        qs3 = Queues(max_priority=None)
-        qs3.add('foo2')
-        self.assertEqual(qs3['foo2'].queue_arguments, None)
-
-        q3 = Queue('xyx3', queue_arguments={'x-max-priority': 7})
-        qs3.add(q3)
-        self.assertEqual(qs3['xyx3'].queue_arguments, {
-            'x-max-priority': 7,
-        })
-
-
-class test_AMQP(AppCase):
-
-    def setup(self):
-        self.simple_message = self.app.amqp.as_task_v2(
-            uuid(), 'foo', create_sent_event=True,
-        )
-
-    def test_Queues__with_ha_policy(self):
-        x = self.app.amqp.Queues({}, ha_policy='all')
-        self.assertEqual(x.ha_policy, 'all')
-
-    def test_Queues__with_max_priority(self):
-        x = self.app.amqp.Queues({}, max_priority=23)
-        self.assertEqual(x.max_priority, 23)
-
-    def test_send_task_message__no_kwargs(self):
-        self.app.amqp.send_task_message(Mock(), 'foo', self.simple_message)
-
-    def test_send_task_message__properties(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, foo=1, retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['foo'], 1)
-
-    def test_send_task_message__headers(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, headers={'x1x': 'y2x'},
-            retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['headers']['x1x'], 'y2x')
-
-    def test_send_task_message__queue_string(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, queue='foo', retry=False,
-        )
-        kwargs = prod.publish.call_args[1]
-        self.assertEqual(kwargs['routing_key'], 'foo')
-        self.assertEqual(kwargs['exchange'], '')
-
-    def test_send_event_exchange_string(self):
-        evd = Mock(name='evd')
-        self.app.amqp.send_task_message(
-            Mock(), 'foo', self.simple_message, retry=False,
-            exchange='xyz', routing_key='xyb',
-            event_dispatcher=evd,
-        )
-        evd.publish.assert_called()
-        event = evd.publish.call_args[0][1]
-        self.assertEqual(event['routing_key'], 'xyb')
-        self.assertEqual(event['exchange'], 'xyz')
-
-    def test_send_task_message__with_delivery_mode(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, delivery_mode=33, retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['delivery_mode'], 33)
-
-    def test_routes(self):
-        r1 = self.app.amqp.routes
-        r2 = self.app.amqp.routes
-        self.assertIs(r1, r2)
-
-
-class test_as_task_v2(AppCase):
-
-    def test_raises_if_args_is_not_tuple(self):
-        with self.assertRaises(TypeError):
-            self.app.amqp.as_task_v2(uuid(), 'foo', args='123')
-
-    def test_raises_if_kwargs_is_not_mapping(self):
-        with self.assertRaises(TypeError):
-            self.app.amqp.as_task_v2(uuid(), 'foo', kwargs=(1, 2, 3))
-
-    def test_countdown_to_eta(self):
-        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo', countdown=10, now=now,
-        )
-        self.assertEqual(
-            m.headers['eta'],
-            (now + timedelta(seconds=10)).isoformat(),
-        )
-
-    def test_expires_to_datetime(self):
-        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo', expires=30, now=now,
-        )
-        self.assertEqual(
-            m.headers['expires'],
-            (now + timedelta(seconds=30)).isoformat(),
-        )
-
-    def test_callbacks_errbacks_chord(self):
-
-        @self.app.task
-        def t(i):
-            pass
-
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo',
-            callbacks=[t.s(1), t.s(2)],
-            errbacks=[t.s(3), t.s(4)],
-            chord=t.s(5),
-        )
-        _, _, embed = m.body
-        self.assertListEqual(
-            embed['callbacks'], [utf8dict(t.s(1)), utf8dict(t.s(2))],
-        )
-        self.assertListEqual(
-            embed['errbacks'], [utf8dict(t.s(3)), utf8dict(t.s(4))],
-        )
-        self.assertEqual(embed['chord'], utf8dict(t.s(5)))

+ 0 - 19
celery/tests/app/test_celery.py

@@ -1,19 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.tests.case import AppCase
-
-import celery
-
-
-class test_celery_package(AppCase):
-
-    def test_version(self):
-        self.assertTrue(celery.VERSION)
-        self.assertGreaterEqual(len(celery.VERSION), 3)
-        celery.VERSION = (0, 3, 0)
-        self.assertGreaterEqual(celery.__version__.count('.'), 2)
-
-    def test_meta(self):
-        for m in ('__author__', '__contact__', '__homepage__',
-                  '__docformat__'):
-            self.assertTrue(getattr(celery, m, None))

+ 0 - 65
celery/tests/app/test_defaults.py

@@ -1,65 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from importlib import import_module
-
-from celery.app.defaults import (
-    _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY,
-    DEFAULTS, NAMESPACES, SETTING_KEYS
-)
-from celery.five import values
-
-from celery.tests.case import AppCase, mock
-
-
-class test_defaults(AppCase):
-
-    def setup(self):
-        self._prev = sys.modules.pop('celery.app.defaults', None)
-
-    def teardown(self):
-        if self._prev:
-            sys.modules['celery.app.defaults'] = self._prev
-
-    def test_option_repr(self):
-        self.assertTrue(repr(NAMESPACES['broker']['url']))
-
-    def test_any(self):
-        val = object()
-        self.assertIs(self.defaults.Option.typemap['any'](val), val)
-
-    @mock.sys_platform('darwin')
-    @mock.pypy_version((1, 4, 0))
-    def test_default_pool_pypy_14(self):
-        self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
-
-    @mock.sys_platform('darwin')
-    @mock.pypy_version((1, 5, 0))
-    def test_default_pool_pypy_15(self):
-        self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
-
-    def test_compat_indices(self):
-        self.assertFalse(any(key.isupper() for key in DEFAULTS))
-        self.assertFalse(any(key.islower() for key in _OLD_DEFAULTS))
-        self.assertFalse(any(key.isupper() for key in _TO_OLD_KEY))
-        self.assertFalse(any(key.islower() for key in _TO_NEW_KEY))
-        self.assertFalse(any(key.isupper() for key in SETTING_KEYS))
-        self.assertFalse(any(key.islower() for key in _OLD_SETTING_KEYS))
-        self.assertFalse(any(value.isupper() for value in values(_TO_NEW_KEY)))
-        self.assertFalse(any(value.islower() for value in values(_TO_OLD_KEY)))
-
-        for key in _TO_NEW_KEY:
-            self.assertIn(key, _OLD_SETTING_KEYS)
-        for key in _TO_OLD_KEY:
-            self.assertIn(key, SETTING_KEYS)
-
-    def test_find(self):
-        find = self.defaults.find
-
-        self.assertEqual(find('default_queue')[2].default, 'celery')
-        self.assertEqual(find('task_default_exchange')[2], 'celery')
-
-    @property
-    def defaults(self):
-        return import_module('celery.app.defaults')

+ 0 - 78
celery/tests/app/test_registry.py

@@ -1,78 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.app.registry import _unpickle_task, _unpickle_task_v2
-from celery.tests.case import AppCase, depends_on_current_app
-
-
-def returns():
-    return 1
-
-
-class test_unpickle_task(AppCase):
-
-    @depends_on_current_app
-    def test_unpickle_v1(self):
-        self.app.tasks['txfoo'] = 'bar'
-        self.assertEqual(_unpickle_task('txfoo'), 'bar')
-
-    @depends_on_current_app
-    def test_unpickle_v2(self):
-        self.app.tasks['txfoo1'] = 'bar1'
-        self.assertEqual(_unpickle_task_v2('txfoo1'), 'bar1')
-        self.assertEqual(_unpickle_task_v2('txfoo1', module='celery'), 'bar1')
-
-
-class test_TaskRegistry(AppCase):
-
-    def setup(self):
-        self.mytask = self.app.task(name='A', shared=False)(returns)
-        self.myperiodic = self.app.task(
-            name='B', shared=False, type='periodic',
-        )(returns)
-
-    def test_NotRegistered_str(self):
-        self.assertTrue(repr(self.app.tasks.NotRegistered('tasks.add')))
-
-    def assertRegisterUnregisterCls(self, r, task):
-        r.unregister(task)
-        with self.assertRaises(r.NotRegistered):
-            r.unregister(task)
-        r.register(task)
-        self.assertIn(task.name, r)
-
-    def assertRegisterUnregisterFunc(self, r, task, task_name):
-        with self.assertRaises(r.NotRegistered):
-            r.unregister(task_name)
-        r.register(task, task_name)
-        self.assertIn(task_name, r)
-
-    def test_task_registry(self):
-        r = self.app._tasks
-        self.assertIsInstance(r, dict, 'TaskRegistry is mapping')
-
-        self.assertRegisterUnregisterCls(r, self.mytask)
-        self.assertRegisterUnregisterCls(r, self.myperiodic)
-
-        r.register(self.myperiodic)
-        r.unregister(self.myperiodic.name)
-        self.assertNotIn(self.myperiodic, r)
-        r.register(self.myperiodic)
-
-        tasks = dict(r)
-        self.assertIs(tasks.get(self.mytask.name), self.mytask)
-        self.assertIs(tasks.get(self.myperiodic.name), self.myperiodic)
-
-        self.assertIs(r[self.mytask.name], self.mytask)
-        self.assertIs(r[self.myperiodic.name], self.myperiodic)
-
-        r.unregister(self.mytask)
-        self.assertNotIn(self.mytask.name, r)
-        r.unregister(self.myperiodic)
-        self.assertNotIn(self.myperiodic.name, r)
-
-        self.assertTrue(self.mytask.run())
-        self.assertTrue(self.myperiodic.run())
-
-    def test_compat(self):
-        self.assertTrue(self.app.tasks.regular())
-        self.assertTrue(self.app.tasks.periodic())

+ 0 - 41
celery/tests/backends/test_backends.py

@@ -1,41 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery import backends
-from celery.backends.amqp import AMQPBackend
-from celery.backends.cache import CacheBackend
-from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, depends_on_current_app, patch
-
-
-class test_backends(AppCase):
-
-    def test_get_backend_aliases(self):
-        expects = [('amqp://', AMQPBackend),
-                   ('cache+memory://', CacheBackend)]
-
-        for url, expect_cls in expects:
-            backend, url = backends.get_backend_by_url(url, self.app.loader)
-            self.assertIsInstance(
-                backend(app=self.app, url=url),
-                expect_cls,
-            )
-
-    def test_unknown_backend(self):
-        with self.assertRaises(ImportError):
-            backends.get_backend_cls('fasodaopjeqijwqe', self.app.loader)
-
-    @depends_on_current_app
-    def test_default_backend(self):
-        self.assertEqual(backends.default_backend, self.app.backend)
-
-    def test_backend_by_url(self, url='redis://localhost/1'):
-        from celery.backends.redis import RedisBackend
-        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, RedisBackend)
-        self.assertEqual(url_, url)
-
-    def test_sym_raises_ValuError(self):
-        with patch('celery.backends.symbol_by_name') as sbn:
-            sbn.side_effect = ValueError()
-            with self.assertRaises(ImproperlyConfigured):
-                backends.get_backend_cls('xxx.xxx:foo', self.app.loader)

+ 0 - 349
celery/tests/case.py

@@ -1,349 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import importlib
-import inspect
-import logging
-import numbers
-import os
-import sys
-import threading
-
-from copy import deepcopy
-from datetime import datetime, timedelta
-from functools import partial
-
-from kombu import Queue
-from kombu.utils.imports import symbol_by_name
-from vine.utils import wraps
-
-from celery import Celery
-from celery.app import current_app
-from celery.backends.cache import CacheBackend, DummyClient
-from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
-from celery.utils.imports import qualname
-
-from case import (
-    ANY, ContextMock, MagicMock, Mock, call, mock, skip, patch, sentinel,
-)
-from case import Case as _Case
-from case.utils import decorator
-
-__all__ = [
-    'ANY', 'ContextMock', 'MagicMock', 'Mock',
-    'call', 'mock', 'skip', 'patch', 'sentinel',
-
-    'AppCase', 'TaskMessage', 'TaskMessage1',
-    'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
-]
-
-CASE_REDEFINES_SETUP = """\
-{name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
-"""
-CASE_REDEFINES_TEARDOWN = """\
-{name} (subclass of AppCase) redefines private "tearDown", \
-should be: "teardown"\
-"""
-CASE_LOG_REDIRECT_EFFECT = """\
-Test {0} didn't disable LoggingProxy for {1}\
-"""
-CASE_LOG_LEVEL_EFFECT = """\
-Test {0} Modified the level of the root logger\
-"""
-CASE_LOG_HANDLER_EFFECT = """\
-Test {0} Modified handlers for the root logger\
-"""
-
-CELERY_TEST_CONFIG = {
-    #: Don't want log output when running suite.
-    'worker_hijack_root_logger': False,
-    'worker_log_color': False,
-    'task_default_queue': 'testcelery',
-    'task_default_exchange': 'testcelery',
-    'task_default_routing_key': 'testcelery',
-    'task_queues': (
-        Queue('testcelery', routing_key='testcelery'),
-    ),
-    'accept_content': ('json', 'pickle'),
-    'enable_utc': True,
-    'timezone': 'UTC',
-
-    # Mongo results tests (only executed if installed and running)
-    'mongodb_backend_settings': {
-        'host': os.environ.get('MONGO_HOST') or 'localhost',
-        'port': os.environ.get('MONGO_PORT') or 27017,
-        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
-        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
-                                'taskmeta_collection'),
-        'user': os.environ.get('MONGO_USER'),
-        'password': os.environ.get('MONGO_PASSWORD'),
-    }
-}
-
-
-class Case(_Case):
-    DeprecationWarning = CDeprecationWarning
-    PendingDeprecationWarning = CPendingDeprecationWarning
-
-
-class Trap(object):
-
-    def __getattr__(self, name):
-        raise RuntimeError('Test depends on current_app')
-
-
-class UnitLogging(symbol_by_name(Celery.log_cls)):
-
-    def __init__(self, *args, **kwargs):
-        super(UnitLogging, self).__init__(*args, **kwargs)
-        self.already_setup = True
-
-
-def UnitApp(name=None, set_as_current=False, log=UnitLogging,
-            broker='memory://', backend='cache+memory://', **kwargs):
-    app = Celery(name or 'celery.tests',
-                 set_as_current=set_as_current,
-                 log=log, broker=broker, backend=backend,
-                 **kwargs)
-    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
-    return app
-
-
-def alive_threads():
-    return [thread for thread in threading.enumerate() if thread.is_alive()]
-
-
-def depends_on_current_app(fun):
-    if inspect.isclass(fun):
-        fun.contained = False
-    else:
-        @wraps(fun)
-        def __inner(self, *args, **kwargs):
-            self.app.set_current()
-            return fun(self, *args, **kwargs)
-        return __inner
-
-
-class AppCase(Case):
-    contained = True
-    _threads_at_startup = [None]
-
-    def __init__(self, *args, **kwargs):
-        super(AppCase, self).__init__(*args, **kwargs)
-        setUp = self.__class__.__dict__.get('setUp')
-        tearDown = self.__class__.__dict__.get('tearDown')
-        if setUp and not hasattr(setUp, '__wrapped__'):
-            raise RuntimeError(
-                CASE_REDEFINES_SETUP.format(name=qualname(self)),
-            )
-        if tearDown and not hasattr(tearDown, '__wrapped__'):
-            raise RuntimeError(
-                CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
-            )
-
-    def Celery(self, *args, **kwargs):
-        return UnitApp(*args, **kwargs)
-
-    def threads_at_startup(self):
-        if self._threads_at_startup[0] is None:
-            self._threads_at_startup[0] = alive_threads()
-        return self._threads_at_startup[0]
-
-    def setUp(self):
-        self._threads_at_setup = self.threads_at_startup()
-        from celery import _state
-        from celery import result
-        self._prev_res_join_block = result.task_join_will_block
-        self._prev_state_join_block = _state.task_join_will_block
-        result.task_join_will_block = \
-            _state.task_join_will_block = lambda: False
-        self._current_app = current_app()
-        self._default_app = _state.default_app
-        trap = Trap()
-        self._prev_tls = _state._tls
-        _state.set_default_app(trap)
-
-        class NonTLS(object):
-            current_app = trap
-        _state._tls = NonTLS()
-
-        self.app = self.Celery(set_as_current=False)
-        if not self.contained:
-            self.app.set_current()
-        root = logging.getLogger()
-        self.__rootlevel = root.level
-        self.__roothandlers = root.handlers
-        _state._set_task_join_will_block(False)
-        try:
-            self.setup()
-        except:
-            self._teardown_app()
-            raise
-
-    def _teardown_app(self):
-        from celery import _state
-        from celery import result
-        from celery.utils.log import LoggingProxy
-        assert sys.stdout
-        assert sys.stderr
-        assert sys.__stdout__
-        assert sys.__stderr__
-        this = self._get_test_name()
-        result.task_join_will_block = self._prev_res_join_block
-        _state.task_join_will_block = self._prev_state_join_block
-        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
-                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
-            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
-        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
-                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
-            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
-        backend = self.app.__dict__.get('backend')
-        if backend is not None:
-            if isinstance(backend, CacheBackend):
-                if isinstance(backend.client, DummyClient):
-                    backend.client.cache.clear()
-                backend._cache.clear()
-        from celery import _state
-        _state._set_task_join_will_block(False)
-
-        _state.set_default_app(self._default_app)
-        _state._tls = self._prev_tls
-        _state._tls.current_app = self._current_app
-        if self.app is not self._current_app:
-            self.app.close()
-        self.app = None
-        self.assertEqual(self._threads_at_setup, alive_threads())
-
-        # Make sure no test left the shutdown flags enabled.
-        from celery.worker import state as worker_state
-        # check for EX_OK
-        self.assertIsNot(worker_state.should_stop, False)
-        self.assertIsNot(worker_state.should_terminate, False)
-        # check for other true values
-        self.assertFalse(worker_state.should_stop)
-        self.assertFalse(worker_state.should_terminate)
-
-    def _get_test_name(self):
-        return '.'.join([self.__class__.__name__, self._testMethodName])
-
-    def tearDown(self):
-        try:
-            self.teardown()
-        finally:
-            self._teardown_app()
-        self.assert_no_logging_side_effect()
-
-    def assert_no_logging_side_effect(self):
-        this = self._get_test_name()
-        root = logging.getLogger()
-        if root.level != self.__rootlevel:
-            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
-        if root.handlers != self.__roothandlers:
-            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
-
-    def assert_signal_called(self, signal, **expected):
-        return assert_signal_called(signal, **expected)
-
-    def setup(self):
-        pass
-
-    def teardown(self):
-        pass
-
-
-@decorator
-def assert_signal_called(signal, **expected):
-    handler = Mock()
-    call_handler = partial(handler)
-    signal.connect(call_handler)
-    try:
-        yield handler
-    finally:
-        signal.disconnect(call_handler)
-    handler.assert_called_with(signal=signal, **expected)
-
-
-def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
-                errbacks=None, chain=None, shadow=None, utc=None, **options):
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {
-        'id': id,
-        'task': name,
-        'shadow': shadow,
-    }
-    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
-    message.headers.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        (args, kwargs, embed), serializer='json',
-    )
-    message.payload = (args, kwargs, embed)
-    return message
-
-
-def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
-                 errbacks=None, chain=None, **options):
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {}
-    message.payload = {
-        'task': name,
-        'id': id,
-        'args': args,
-        'kwargs': kwargs,
-        'callbacks': callbacks,
-        'errbacks': errbacks,
-    }
-    message.payload.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        message.payload,
-    )
-    return message
-
-
-def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
-    sig.freeze()
-    callbacks = sig.options.pop('link', None)
-    errbacks = sig.options.pop('link_error', None)
-    countdown = sig.options.pop('countdown', None)
-    if countdown:
-        eta = app.now() + timedelta(seconds=countdown)
-    else:
-        eta = sig.options.pop('eta', None)
-    if eta and isinstance(eta, datetime):
-        eta = eta.isoformat()
-    expires = sig.options.pop('expires', None)
-    if expires and isinstance(expires, numbers.Real):
-        expires = app.now() + timedelta(seconds=expires)
-    if expires and isinstance(expires, datetime):
-        expires = expires.isoformat()
-    return TaskMessage(
-        sig.task, id=sig.id, args=sig.args,
-        kwargs=sig.kwargs,
-        callbacks=[dict(s) for s in callbacks] if callbacks else None,
-        errbacks=[dict(s) for s in errbacks] if errbacks else None,
-        eta=eta,
-        expires=expires,
-        utc=utc,
-        **sig.options
-    )
-
-
-def _old_patch(module, name, mocked):
-    module = importlib.import_module(module)
-
-    def _patch(fun):
-
-        @wraps(fun)
-        def __patched(*args, **kwargs):
-            prev = getattr(module, name)
-            setattr(module, name, mocked)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                setattr(module, name, prev)
-        return __patched
-    return _patch

+ 0 - 38
celery/tests/compat_modules/test_decorators.py

@@ -1,38 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import warnings
-
-from celery.task import base
-
-from celery.tests.case import AppCase, depends_on_current_app
-
-
-def add(x, y):
-    return x + y
-
-
-@depends_on_current_app
-class test_decorators(AppCase):
-
-    def test_task_alias(self):
-        from celery import task
-        self.assertTrue(task.__file__)
-        self.assertTrue(task(add))
-
-    def setup(self):
-        with warnings.catch_warnings(record=True):
-            from celery import decorators
-            self.decorators = decorators
-
-    def assertCompatDecorator(self, decorator, type, **opts):
-        task = decorator(**opts)(add)
-        self.assertEqual(task(8, 8), 16)
-        self.assertIsInstance(task, type)
-
-    def test_task(self):
-        self.assertCompatDecorator(self.decorators.task, base.BaseTask)
-
-    def test_periodic_task(self):
-        self.assertCompatDecorator(self.decorators.periodic_task,
-                                   base.BaseTask,
-                                   run_every=1)

+ 0 - 129
celery/tests/concurrency/test_gevent.py

@@ -1,129 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.concurrency.gevent import (
-    Timer,
-    TaskPool,
-    apply_timeout,
-)
-
-from celery.tests.case import AppCase, Mock, patch, skip
-
-gevent_modules = (
-    'gevent',
-    'gevent.monkey',
-    'gevent.greenlet',
-    'gevent.pool',
-    'greenlet',
-)
-
-
-@skip.if_pypy()
-class GeventCase(AppCase):
-
-    def setup(self):
-        self.mock_modules(*gevent_modules)
-
-
-class test_gevent_patch(GeventCase):
-
-    def test_is_patched(self):
-        with patch('gevent.monkey.patch_all', create=True) as patch_all:
-            import gevent
-            gevent.version_info = (1, 0, 0)
-            from celery import maybe_patch_concurrency
-            maybe_patch_concurrency(['x', '-P', 'gevent'])
-            patch_all.assert_called()
-
-
-class test_Timer(GeventCase):
-
-    def setup(self):
-        GeventCase.setup(self)
-        self.greenlet = self.patch('gevent.greenlet')
-        self.GreenletExit = self.patch('gevent.greenlet.GreenletExit')
-
-    def test_sched(self):
-        self.greenlet.Greenlet = object
-        x = Timer()
-        self.greenlet.Greenlet = Mock()
-        x._Greenlet.spawn_later = Mock()
-        x._GreenletExit = KeyError
-        entry = Mock()
-        g = x._enter(1, 0, entry)
-        self.assertTrue(x.queue)
-
-        x._entry_exit(g)
-        g.kill.assert_called_with()
-        self.assertFalse(x._queue)
-
-        x._queue.add(g)
-        x.clear()
-        x._queue.add(g)
-        g.kill.side_effect = KeyError()
-        x.clear()
-
-        g = x._Greenlet()
-        g.cancel()
-
-
-class test_TaskPool(GeventCase):
-
-    def setup(self):
-        GeventCase.setup(self)
-        self.spawn_raw = self.patch('gevent.spawn_raw')
-        self.Pool = self.patch('gevent.pool.Pool')
-
-    def test_pool(self):
-        x = TaskPool()
-        x.on_start()
-        x.on_stop()
-        x.on_apply(Mock())
-        x._pool = None
-        x.on_stop()
-
-        x._pool = Mock()
-        x._pool._semaphore.counter = 1
-        x._pool.size = 1
-        x.grow()
-        self.assertEqual(x._pool.size, 2)
-        self.assertEqual(x._pool._semaphore.counter, 2)
-        x.shrink()
-        self.assertEqual(x._pool.size, 1)
-        self.assertEqual(x._pool._semaphore.counter, 1)
-
-        x._pool = [4, 5, 6]
-        self.assertEqual(x.num_processes, 3)
-
-
-class test_apply_timeout(AppCase):
-
-    def test_apply_timeout(self):
-
-            class Timeout(Exception):
-                value = None
-
-                def __init__(self, value):
-                    self.__class__.value = value
-
-                def __enter__(self):
-                    return self
-
-                def __exit__(self, *exc_info):
-                    pass
-            timeout_callback = Mock(name='timeout_callback')
-            apply_target = Mock(name='apply_target')
-            apply_timeout(
-                Mock(), timeout=10, callback=Mock(name='callback'),
-                timeout_callback=timeout_callback,
-                apply_target=apply_target, Timeout=Timeout,
-            )
-            self.assertEqual(Timeout.value, 10)
-            apply_target.assert_called()
-
-            apply_target.side_effect = Timeout(10)
-            apply_timeout(
-                Mock(), timeout=10, callback=Mock(),
-                timeout_callback=timeout_callback,
-                apply_target=apply_target, Timeout=Timeout,
-            )
-            timeout_callback.assert_called_with(False, 10)

+ 0 - 54
celery/tests/tasks/test_states.py

@@ -1,54 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery import states
-
-from celery.tests.case import Case
-
-
-class test_state_precedence(Case):
-
-    def test_gt(self):
-        self.assertGreater(
-            states.state(states.SUCCESS), states.state(states.PENDING),
-        )
-        self.assertGreater(
-            states.state(states.FAILURE), states.state(states.RECEIVED),
-        )
-        self.assertGreater(
-            states.state(states.REVOKED), states.state(states.STARTED),
-        )
-        self.assertGreater(
-            states.state(states.SUCCESS), states.state('CRASHED'),
-        )
-        self.assertGreater(
-            states.state(states.FAILURE), states.state('CRASHED'),
-        )
-        self.assertLessEqual(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-
-    def test_lt(self):
-        self.assertLess(
-            states.state(states.PENDING), states.state(states.SUCCESS),
-        )
-        self.assertLess(
-            states.state(states.RECEIVED), states.state(states.FAILURE),
-        )
-        self.assertLess(
-            states.state(states.STARTED), states.state(states.REVOKED),
-        )
-        self.assertLess(
-            states.state('CRASHED'), states.state(states.SUCCESS),
-        )
-        self.assertLess(
-            states.state('CRASHED'), states.state(states.FAILURE),
-        )
-        self.assertLess(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-        self.assertLessEqual(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-        self.assertGreaterEqual(
-            states.state('CRASHED'), states.state(states.REVOKED),
-        )

+ 0 - 98
celery/tests/utils/test_debug.py

@@ -1,98 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import debug
-
-from celery.tests.case import Case, Mock, patch
-
-
-class test_on_blocking(Case):
-
-    @patch('inspect.getframeinfo')
-    def test_on_blocking(self, getframeinfo):
-        frame = Mock(name='frame')
-        with self.assertRaises(RuntimeError):
-            debug._on_blocking(1, frame)
-            getframeinfo.assert_called_with(frame)
-
-
-class test_blockdetection(Case):
-
-    @patch('celery.utils.debug.signals')
-    def test_context(self, signals):
-        with debug.blockdetection(10):
-            signals.arm_alarm.assert_called_with(10)
-            signals.__setitem__.assert_called_with('ALRM', debug._on_blocking)
-        signals.__setitem__.assert_called_with('ALRM', signals['ALRM'])
-        signals.reset_alarm.assert_called_with()
-
-
-class test_sample_mem(Case):
-
-    @patch('celery.utils.debug.mem_rss')
-    def test_sample_mem(self, mem_rss):
-        prev, debug._mem_sample = debug._mem_sample, []
-        try:
-            debug.sample_mem()
-            self.assertIs(debug._mem_sample[0], mem_rss())
-        finally:
-            debug._mem_sample = prev
-
-
-class test_sample(Case):
-
-    def test_sample(self):
-        x = list(range(100))
-        self.assertEqual(
-            list(debug.sample(x, 10)),
-            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
-        )
-        x = list(range(91))
-        self.assertEqual(
-            list(debug.sample(x, 10)),
-            [0, 9, 18, 27, 36, 45, 54, 63, 72, 81],
-        )
-
-
-class test_hfloat(Case):
-
-    def test_hfloat(self):
-        self.assertEqual(str(debug.hfloat(10, 5)), '10')
-        self.assertEqual(str(debug.hfloat(10.45645234234, 5)), '10.456')
-
-
-class test_humanbytes(Case):
-
-    def test_humanbytes(self):
-        self.assertEqual(debug.humanbytes(2 ** 20), '1MB')
-        self.assertEqual(debug.humanbytes(4 * 2 ** 20), '4MB')
-        self.assertEqual(debug.humanbytes(2 ** 16), '64KB')
-        self.assertEqual(debug.humanbytes(2 ** 16), '64KB')
-        self.assertEqual(debug.humanbytes(2 ** 8), '256b')
-
-
-class test_mem_rss(Case):
-
-    @patch('celery.utils.debug.ps')
-    @patch('celery.utils.debug.humanbytes')
-    def test_mem_rss(self, humanbytes, ps):
-        ret = debug.mem_rss()
-        ps.assert_called_with()
-        ps().memory_info.assert_called_with()
-        humanbytes.assert_called_with(ps().memory_info().rss)
-        self.assertIs(ret, humanbytes())
-        ps.return_value = None
-        self.assertIsNone(debug.mem_rss())
-
-
-class test_ps(Case):
-
-    @patch('celery.utils.debug.Process')
-    @patch('os.getpid')
-    def test_ps(self, getpid, Process):
-        prev, debug._process = debug._process, None
-        try:
-            debug.ps()
-            Process.assert_called_with(getpid())
-            self.assertIs(debug._process, Process())
-        finally:
-            debug._process = prev

+ 0 - 20
celery/tests/utils/test_encoding.py

@@ -1,20 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import encoding
-from celery.tests.case import Case
-
-
-class test_encoding(Case):
-
-    def test_safe_str(self):
-        self.assertTrue(encoding.safe_str(object()))
-        self.assertTrue(encoding.safe_str('foo'))
-
-    def test_safe_repr(self):
-        self.assertTrue(encoding.safe_repr(object()))
-
-        class foo(object):
-            def __repr__(self):
-                raise ValueError('foo')
-
-        self.assertTrue(encoding.safe_repr(foo()))

+ 0 - 210
celery/tests/utils/test_functional.py

@@ -1,210 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu.utils.functional import lazy
-
-from celery.five import range, nextfun
-from celery.utils.functional import (
-    DummyContext,
-    fun_takes_argument,
-    head_from_fun,
-    firstmethod,
-    first,
-    maybe_list,
-    mlazy,
-    padlist,
-    regen,
-)
-
-from celery.tests.case import Case
-
-
-class test_DummyContext(Case):
-
-    def test_context(self):
-        with DummyContext():
-            pass
-        with self.assertRaises(KeyError):
-            with DummyContext():
-                raise KeyError()
-
-
-class test_utils(Case):
-
-    def test_padlist(self):
-        self.assertListEqual(
-            padlist(['George', 'Costanza', 'NYC'], 3),
-            ['George', 'Costanza', 'NYC'],
-        )
-        self.assertListEqual(
-            padlist(['George', 'Costanza'], 3),
-            ['George', 'Costanza', None],
-        )
-        self.assertListEqual(
-            padlist(['George', 'Costanza', 'NYC'], 4, default='Earth'),
-            ['George', 'Costanza', 'NYC', 'Earth'],
-        )
-
-    def test_firstmethod_AttributeError(self):
-        self.assertIsNone(firstmethod('foo')([object()]))
-
-    def test_firstmethod_handles_lazy(self):
-
-        class A(object):
-
-            def __init__(self, value=None):
-                self.value = value
-
-            def m(self):
-                return self.value
-
-        self.assertEqual('four', firstmethod('m')([
-            A(), A(), A(), A('four'), A('five')]))
-        self.assertEqual('four', firstmethod('m')([
-            A(), A(), A(), lazy(lambda: A('four')), A('five')]))
-
-    def test_first(self):
-        iterations = [0]
-
-        def predicate(value):
-            iterations[0] += 1
-            if value == 5:
-                return True
-            return False
-
-        self.assertEqual(5, first(predicate, range(10)))
-        self.assertEqual(iterations[0], 6)
-
-        iterations[0] = 0
-        self.assertIsNone(first(predicate, range(10, 20)))
-        self.assertEqual(iterations[0], 10)
-
-    def test_maybe_list(self):
-        self.assertEqual(maybe_list(1), [1])
-        self.assertEqual(maybe_list([1]), [1])
-        self.assertIsNone(maybe_list(None))
-
-
-class test_mlazy(Case):
-
-    def test_is_memoized(self):
-
-        it = iter(range(20, 30))
-        p = mlazy(nextfun(it))
-        self.assertEqual(p(), 20)
-        self.assertTrue(p.evaluated)
-        self.assertEqual(p(), 20)
-        self.assertEqual(repr(p), '20')
-
-
-class test_regen(Case):
-
-    def test_regen_list(self):
-        l = [1, 2]
-        r = regen(iter(l))
-        self.assertIs(regen(l), l)
-        self.assertEqual(r, l)
-        self.assertEqual(r, l)
-        self.assertEqual(r.__length_hint__(), 0)
-
-        fun, args = r.__reduce__()
-        self.assertEqual(fun(*args), l)
-
-    def test_regen_gen(self):
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[7], 7)
-        self.assertEqual(g[6], 6)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g[4], 4)
-        self.assertEqual(g[3], 3)
-        self.assertEqual(g[2], 2)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g.data, list(range(10)))
-        self.assertEqual(g[8], 8)
-        self.assertEqual(g[0], 0)
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g.data, list(range(10)))
-        g = regen(iter([1]))
-        self.assertEqual(g[0], 1)
-        with self.assertRaises(IndexError):
-            g[1]
-        self.assertEqual(g.data, [1])
-
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[-1], 9)
-        self.assertEqual(g[-2], 8)
-        self.assertEqual(g[-3], 7)
-        self.assertEqual(g[-4], 6)
-        self.assertEqual(g[-5], 5)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g.data, list(range(10)))
-
-        self.assertListEqual(list(iter(g)), list(range(10)))
-
-
-class test_head_from_fun(Case):
-
-    def test_from_cls(self):
-        class X(object):
-            def __call__(x, y, kwarg=1):
-                pass
-
-        g = head_from_fun(X())
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-    def test_from_fun(self):
-        def f(x, y, kwarg=1):
-            pass
-        g = head_from_fun(f)
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-    def test_from_fun_with_hints(self):
-        local = {}
-        fun = ('def f_hints(x: int, y: int, kwarg: int=1):'
-               '    pass')
-        try:
-            exec(fun, {}, local)
-        except SyntaxError:
-            # py2
-            return
-        f_hints = local['f_hints']
-
-        g = head_from_fun(f_hints)
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-
-class test_fun_takes_argument(Case):
-
-    def test_starkwargs(self):
-        self.assertTrue(fun_takes_argument('foo', lambda **kw: 1))
-
-    def test_named(self):
-        self.assertTrue(fun_takes_argument('foo', lambda a, foo, bar: 1))
-
-        def fun(a, b, c, d):
-            return 1
-
-        self.assertTrue(fun_takes_argument('foo', fun, position=4))
-
-    def test_starargs(self):
-        self.assertTrue(fun_takes_argument('foo', lambda a, *args: 1))
-
-    def test_does_not(self):
-        self.assertFalse(fun_takes_argument('foo', lambda a, bar, baz: 1))
-        self.assertFalse(fun_takes_argument('foo', lambda: 1))
-
-        def fun(a, b, foo):
-            return 1
-
-        self.assertFalse(fun_takes_argument('foo', fun, position=4))

+ 0 - 58
celery/tests/utils/test_imports.py

@@ -1,58 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.five import bytes_if_py2
-
-from celery.utils.imports import (
-    NotAPackage,
-    qualname,
-    gen_task_name,
-    reload_from_cwd,
-    module_file,
-    find_module,
-)
-
-from celery.tests.case import Case, Mock, patch
-
-
-class test_import_utils(Case):
-
-    def test_find_module(self):
-        self.assertTrue(find_module('celery'))
-        imp = Mock()
-        imp.return_value = None
-        with self.assertRaises(NotAPackage):
-            find_module('foo.bar.baz', imp=imp)
-        self.assertTrue(find_module('celery.worker.request'))
-
-    def test_qualname(self):
-        Class = type(bytes_if_py2('Fox'), (object,), {
-            '__module__': 'quick.brown',
-        })
-        self.assertEqual(qualname(Class), 'quick.brown.Fox')
-        self.assertEqual(qualname(Class()), 'quick.brown.Fox')
-
-    @patch('celery.utils.imports.reload')
-    def test_reload_from_cwd(self, reload):
-        reload_from_cwd('foo')
-        reload.assert_called()
-
-    def test_reload_from_cwd_custom_reloader(self):
-        reload = Mock()
-        reload_from_cwd('foo', reload)
-        reload.assert_called()
-
-    def test_module_file(self):
-        m1 = Mock()
-        m1.__file__ = '/opt/foo/xyz.pyc'
-        self.assertEqual(module_file(m1), '/opt/foo/xyz.py')
-        m2 = Mock()
-        m2.__file__ = '/opt/foo/xyz.py'
-        self.assertEqual(module_file(m1), '/opt/foo/xyz.py')
-
-
-class test_gen_task_name(Case):
-
-    def test_no_module(self):
-        app = Mock()
-        app.name == '__main__'
-        self.assertTrue(gen_task_name(app, 'foo', 'axsadaewe'))

+ 0 - 372
celery/tests/utils/test_local.py

@@ -1,372 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from celery.five import python_2_unicode_compatible, string, long_t
-from celery.local import (
-    Proxy,
-    PromiseProxy,
-    maybe_evaluate,
-    try_import,
-)
-from celery.tests.case import Case, Mock, skip
-
-PY3 = sys.version_info[0] == 3
-
-
-class test_try_import(Case):
-
-    def test_imports(self):
-        self.assertTrue(try_import(__name__))
-
-    def test_when_default(self):
-        default = object()
-        self.assertIs(try_import('foobar.awqewqe.asdwqewq', default), default)
-
-
-class test_Proxy(Case):
-
-    def test_std_class_attributes(self):
-        self.assertEqual(Proxy.__name__, 'Proxy')
-        self.assertEqual(Proxy.__module__, 'celery.local')
-        self.assertIsInstance(Proxy.__doc__, str)
-
-    def test_doc(self):
-        def real():
-            pass
-        x = Proxy(real, __doc__='foo')
-        self.assertEqual(x.__doc__, 'foo')
-
-    def test_name(self):
-
-        def real():
-            """real function"""
-            return 'REAL'
-
-        x = Proxy(lambda: real, name='xyz')
-        self.assertEqual(x.__name__, 'xyz')
-
-        y = Proxy(lambda: real)
-        self.assertEqual(y.__name__, 'real')
-
-        self.assertEqual(x.__doc__, 'real function')
-
-        self.assertEqual(x.__class__, type(real))
-        self.assertEqual(x.__dict__, real.__dict__)
-        self.assertEqual(repr(x), repr(real))
-        self.assertTrue(x.__module__)
-
-    def test_get_current_local(self):
-        x = Proxy(lambda: 10)
-        object.__setattr__(x, '_Proxy_local', Mock())
-        self.assertTrue(x._get_current_object())
-
-    def test_bool(self):
-
-        class X(object):
-
-            def __bool__(self):
-                return False
-            __nonzero__ = __bool__
-
-        x = Proxy(lambda: X())
-        self.assertFalse(x)
-
-    def test_slots(self):
-
-        class X(object):
-            __slots__ = ()
-
-        x = Proxy(X)
-        with self.assertRaises(AttributeError):
-            x.__dict__
-
-    @skip.if_python3()
-    def test_unicode(self):
-
-        @python_2_unicode_compatible
-        class X(object):
-
-            def __unicode__(self):
-                return 'UNICODE'
-            __str__ = __unicode__
-
-            def __repr__(self):
-                return 'REPR'
-
-        x = Proxy(lambda: X())
-        self.assertEqual(string(x), 'UNICODE')
-        del(X.__unicode__)
-        del(X.__str__)
-        self.assertEqual(string(x), 'REPR')
-
-    def test_dir(self):
-
-        class X(object):
-
-            def __dir__(self):
-                return ['a', 'b', 'c']
-
-        x = Proxy(lambda: X())
-        self.assertListEqual(dir(x), ['a', 'b', 'c'])
-
-        class Y(object):
-
-            def __dir__(self):
-                raise RuntimeError()
-        y = Proxy(lambda: Y())
-        self.assertListEqual(dir(y), [])
-
-    def test_getsetdel_attr(self):
-
-        class X(object):
-            a = 1
-            b = 2
-            c = 3
-
-            def __dir__(self):
-                return ['a', 'b', 'c']
-
-        v = X()
-
-        x = Proxy(lambda: v)
-        self.assertListEqual(x.__members__, ['a', 'b', 'c'])
-        self.assertEqual(x.a, 1)
-        self.assertEqual(x.b, 2)
-        self.assertEqual(x.c, 3)
-
-        setattr(x, 'a', 10)
-        self.assertEqual(x.a, 10)
-
-        del(x.a)
-        self.assertEqual(x.a, 1)
-
-    def test_dictproxy(self):
-        v = {}
-        x = Proxy(lambda: v)
-        x['foo'] = 42
-        self.assertEqual(x['foo'], 42)
-        self.assertEqual(len(x), 1)
-        self.assertIn('foo', x)
-        del(x['foo'])
-        with self.assertRaises(KeyError):
-            x['foo']
-        self.assertTrue(iter(x))
-
-    def test_listproxy(self):
-        v = []
-        x = Proxy(lambda: v)
-        x.append(1)
-        x.extend([2, 3, 4])
-        self.assertEqual(x[0], 1)
-        self.assertEqual(x[:-1], [1, 2, 3])
-        del(x[-1])
-        self.assertEqual(x[:-1], [1, 2])
-        x[0] = 10
-        self.assertEqual(x[0], 10)
-        self.assertIn(10, x)
-        self.assertEqual(len(x), 3)
-        self.assertTrue(iter(x))
-        x[0:2] = [1, 2]
-        del(x[0:2])
-        self.assertTrue(str(x))
-        if sys.version_info[0] < 3:
-            self.assertEqual(x.__cmp__(object()), -1)
-
-    def test_complex_cast(self):
-
-        class O(object):
-
-            def __complex__(self):
-                return complex(10.333)
-
-        o = Proxy(O)
-        self.assertEqual(o.__complex__(), complex(10.333))
-
-    def test_index(self):
-
-        class O(object):
-
-            def __index__(self):
-                return 1
-
-        o = Proxy(O)
-        self.assertEqual(o.__index__(), 1)
-
-    def test_coerce(self):
-
-        class O(object):
-
-            def __coerce__(self, other):
-                return self, other
-
-        o = Proxy(O)
-        self.assertTrue(o.__coerce__(3))
-
-    def test_int(self):
-        self.assertEqual(Proxy(lambda: 10) + 1, Proxy(lambda: 11))
-        self.assertEqual(Proxy(lambda: 10) - 1, Proxy(lambda: 9))
-        self.assertEqual(Proxy(lambda: 10) * 2, Proxy(lambda: 20))
-        self.assertEqual(Proxy(lambda: 10) ** 2, Proxy(lambda: 100))
-        self.assertEqual(Proxy(lambda: 20) / 2, Proxy(lambda: 10))
-        self.assertEqual(Proxy(lambda: 20) // 2, Proxy(lambda: 10))
-        self.assertEqual(Proxy(lambda: 11) % 2, Proxy(lambda: 1))
-        self.assertEqual(Proxy(lambda: 10) << 2, Proxy(lambda: 40))
-        self.assertEqual(Proxy(lambda: 10) >> 2, Proxy(lambda: 2))
-        self.assertEqual(Proxy(lambda: 10) ^ 7, Proxy(lambda: 13))
-        self.assertEqual(Proxy(lambda: 10) | 40, Proxy(lambda: 42))
-        self.assertEqual(~Proxy(lambda: 10), Proxy(lambda: -11))
-        self.assertEqual(-Proxy(lambda: 10), Proxy(lambda: -10))
-        self.assertEqual(+Proxy(lambda: -10), Proxy(lambda: -10))
-        self.assertTrue(Proxy(lambda: 10) < Proxy(lambda: 20))
-        self.assertTrue(Proxy(lambda: 20) > Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) >= Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) <= Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) == Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 20) != Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 100).__divmod__(30))
-        self.assertTrue(Proxy(lambda: 100).__truediv__(30))
-        self.assertTrue(abs(Proxy(lambda: -100)))
-
-        x = Proxy(lambda: 10)
-        x -= 1
-        self.assertEqual(x, 9)
-        x = Proxy(lambda: 9)
-        x += 1
-        self.assertEqual(x, 10)
-        x = Proxy(lambda: 10)
-        x *= 2
-        self.assertEqual(x, 20)
-        x = Proxy(lambda: 20)
-        x /= 2
-        self.assertEqual(x, 10)
-        x = Proxy(lambda: 10)
-        x %= 2
-        self.assertEqual(x, 0)
-        x = Proxy(lambda: 10)
-        x <<= 3
-        self.assertEqual(x, 80)
-        x = Proxy(lambda: 80)
-        x >>= 4
-        self.assertEqual(x, 5)
-        x = Proxy(lambda: 5)
-        x ^= 1
-        self.assertEqual(x, 4)
-        x = Proxy(lambda: 4)
-        x **= 4
-        self.assertEqual(x, 256)
-        x = Proxy(lambda: 256)
-        x //= 2
-        self.assertEqual(x, 128)
-        x = Proxy(lambda: 128)
-        x |= 2
-        self.assertEqual(x, 130)
-        x = Proxy(lambda: 130)
-        x &= 10
-        self.assertEqual(x, 2)
-
-        x = Proxy(lambda: 10)
-        self.assertEqual(type(x.__float__()), float)
-        self.assertEqual(type(x.__int__()), int)
-        if not PY3:
-            self.assertEqual(type(x.__long__()), long_t)
-        self.assertTrue(hex(x))
-        self.assertTrue(oct(x))
-
-    def test_hash(self):
-
-        class X(object):
-
-            def __hash__(self):
-                return 1234
-
-        self.assertEqual(hash(Proxy(lambda: X())), 1234)
-
-    def test_call(self):
-
-        class X(object):
-
-            def __call__(self):
-                return 1234
-
-        self.assertEqual(Proxy(lambda: X())(), 1234)
-
-    def test_context(self):
-
-        class X(object):
-            entered = exited = False
-
-            def __enter__(self):
-                self.entered = True
-                return 1234
-
-            def __exit__(self, *exc_info):
-                self.exited = True
-
-        v = X()
-        x = Proxy(lambda: v)
-        with x as val:
-            self.assertEqual(val, 1234)
-        self.assertTrue(x.entered)
-        self.assertTrue(x.exited)
-
-    def test_reduce(self):
-
-        class X(object):
-
-            def __reduce__(self):
-                return 123
-
-        x = Proxy(lambda: X())
-        self.assertEqual(x.__reduce__(), 123)
-
-
-class test_PromiseProxy(Case):
-
-    def test_only_evaluated_once(self):
-
-        class X(object):
-            attr = 123
-            evals = 0
-
-            def __init__(self):
-                self.__class__.evals += 1
-
-        p = PromiseProxy(X)
-        self.assertEqual(p.attr, 123)
-        self.assertEqual(p.attr, 123)
-        self.assertEqual(X.evals, 1)
-
-    def test_callbacks(self):
-        source = Mock(name='source')
-        p = PromiseProxy(source)
-        cbA = Mock(name='cbA')
-        cbB = Mock(name='cbB')
-        cbC = Mock(name='cbC')
-        p.__then__(cbA, p)
-        p.__then__(cbB, p)
-        self.assertFalse(p.__evaluated__())
-        self.assertTrue(object.__getattribute__(p, '__pending__'))
-
-        self.assertTrue(repr(p))
-        self.assertTrue(p.__evaluated__())
-        with self.assertRaises(AttributeError):
-            object.__getattribute__(p, '__pending__')
-        cbA.assert_called_with(p)
-        cbB.assert_called_with(p)
-
-        self.assertTrue(p.__evaluated__())
-        p.__then__(cbC, p)
-        cbC.assert_called_with(p)
-
-        with self.assertRaises(AttributeError):
-            object.__getattribute__(p, '__pending__')
-
-    def test_maybe_evaluate(self):
-        x = PromiseProxy(lambda: 30)
-        self.assertFalse(x.__evaluated__())
-        self.assertEqual(maybe_evaluate(x), 30)
-        self.assertEqual(maybe_evaluate(x), 30)
-
-        self.assertEqual(maybe_evaluate(30), 30)
-        self.assertTrue(x.__evaluated__())

+ 0 - 77
celery/tests/utils/test_serialization.py

@@ -1,77 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pytz
-import sys
-
-from datetime import datetime, date, time, timedelta
-
-from kombu import Queue
-
-from celery.utils.serialization import (
-    UnpickleableExceptionWrapper,
-    get_pickleable_etype,
-    jsonify,
-)
-
-from celery.tests.case import Case, Mock, mock
-
-
-class test_AAPickle(Case):
-
-    def test_no_cpickle(self):
-        prev = sys.modules.pop('celery.utils.serialization', None)
-        try:
-            with mock.mask_modules('cPickle'):
-                from celery.utils.serialization import pickle
-                import pickle as orig_pickle
-                self.assertIs(pickle.dumps, orig_pickle.dumps)
-        finally:
-            sys.modules['celery.utils.serialization'] = prev
-
-
-class test_UnpickleExceptionWrapper(Case):
-
-    def test_init(self):
-        x = UnpickleableExceptionWrapper('foo', 'Bar', [10, lambda x: x])
-        self.assertTrue(x.exc_args)
-        self.assertEqual(len(x.exc_args), 2)
-
-
-class test_get_pickleable_etype(Case):
-
-    def test_get_pickleable_etype(self):
-
-        class Unpickleable(Exception):
-            def __reduce__(self):
-                raise ValueError('foo')
-
-        self.assertIs(get_pickleable_etype(Unpickleable), Exception)
-
-
-class test_jsonify(Case):
-
-    def test_simple(self):
-        self.assertTrue(jsonify(Queue('foo')))
-        self.assertTrue(jsonify(['foo', 'bar', 'baz']))
-        self.assertTrue(jsonify({'foo': 'bar'}))
-        self.assertTrue(jsonify(datetime.utcnow()))
-        self.assertTrue(jsonify(datetime.utcnow().replace(tzinfo=pytz.utc)))
-        self.assertTrue(jsonify(datetime.utcnow().replace(microsecond=0)))
-        self.assertTrue(jsonify(date(2012, 1, 1)))
-        self.assertTrue(jsonify(time(hour=1, minute=30)))
-        self.assertTrue(jsonify(time(hour=1, minute=30, microsecond=3)))
-        self.assertTrue(jsonify(timedelta(seconds=30)))
-        self.assertTrue(jsonify(10))
-        self.assertTrue(jsonify(10.3))
-        self.assertTrue(jsonify('hello'))
-
-        unknown_type_filter = Mock()
-        obj = object()
-        self.assertIs(
-            jsonify(obj, unknown_type_filter=unknown_type_filter),
-            unknown_type_filter.return_value,
-        )
-        unknown_type_filter.assert_called_with(obj)
-
-        with self.assertRaises(ValueError):
-            jsonify(obj)

+ 0 - 27
celery/tests/utils/test_sysinfo.py

@@ -1,27 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils.sysinfo import load_average, df
-
-from celery.tests.case import Case, patch, skip
-
-
-@skip.unless_symbol('os.getloadavg')
-class test_load_average(Case):
-
-    def test_avg(self):
-        with patch('os.getloadavg') as getloadavg:
-            getloadavg.return_value = 0.54736328125, 0.6357421875, 0.69921875
-            l = load_average()
-            self.assertTrue(l)
-            self.assertEqual(l, (0.55, 0.64, 0.7))
-
-
-@skip.unless_symbol('posix.statvfs_result')
-class test_df(Case):
-
-    def test_df(self):
-        x = df('/')
-        self.assertTrue(x.total_blocks)
-        self.assertTrue(x.available)
-        self.assertTrue(x.capacity)
-        self.assertTrue(x.stat)

+ 0 - 87
celery/tests/utils/test_term.py

@@ -1,87 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from celery.utils import term
-from celery.utils.term import colored, fg
-from celery.five import text_t
-
-from celery.tests.case import Case, skip
-
-
-@skip.if_win32()
-class test_colored(Case):
-
-    def setUp(self):
-        self._prev_encoding = sys.getdefaultencoding
-
-        def getdefaultencoding():
-            return 'utf-8'
-
-        sys.getdefaultencoding = getdefaultencoding
-
-    def tearDown(self):
-        sys.getdefaultencoding = self._prev_encoding
-
-    def test_colors(self):
-        colors = (
-            ('black', term.BLACK),
-            ('red', term.RED),
-            ('green', term.GREEN),
-            ('yellow', term.YELLOW),
-            ('blue', term.BLUE),
-            ('magenta', term.MAGENTA),
-            ('cyan', term.CYAN),
-            ('white', term.WHITE),
-        )
-
-        for name, key in colors:
-            self.assertIn(fg(30 + key), str(colored().names[name]('foo')))
-
-        self.assertTrue(str(colored().bold('f')))
-        self.assertTrue(str(colored().underline('f')))
-        self.assertTrue(str(colored().blink('f')))
-        self.assertTrue(str(colored().reverse('f')))
-        self.assertTrue(str(colored().bright('f')))
-        self.assertTrue(str(colored().ired('f')))
-        self.assertTrue(str(colored().igreen('f')))
-        self.assertTrue(str(colored().iyellow('f')))
-        self.assertTrue(str(colored().iblue('f')))
-        self.assertTrue(str(colored().imagenta('f')))
-        self.assertTrue(str(colored().icyan('f')))
-        self.assertTrue(str(colored().iwhite('f')))
-        self.assertTrue(str(colored().reset('f')))
-
-        self.assertTrue(text_t(colored().green('∂bar')))
-
-        self.assertTrue(
-            colored().red('éefoo') + colored().green('∂bar'))
-
-        self.assertEqual(
-            colored().red('foo').no_color(), 'foo')
-
-        self.assertTrue(
-            repr(colored().blue('åfoo')))
-
-        self.assertIn("''", repr(colored()))
-
-        c = colored()
-        s = c.red('foo', c.blue('bar'), c.green('baz'))
-        self.assertTrue(s.no_color())
-
-        c._fold_no_color(s, 'øfoo')
-        c._fold_no_color('fooå', s)
-
-        c = colored().red('åfoo')
-        self.assertEqual(
-            c._add(c, 'baræ'),
-            '\x1b[1;31m\xe5foo\x1b[0mbar\xe6',
-        )
-
-        c2 = colored().blue('ƒƒz')
-        c3 = c._add(c, c2)
-        self.assertEqual(
-            c3,
-            '\x1b[1;31m\xe5foo\x1b[0m\x1b[1;34m\u0192\u0192z\x1b[0m',
-        )

+ 0 - 94
celery/tests/utils/test_text.py

@@ -1,94 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils.text import (
-    abbr,
-    abbrtask,
-    ensure_newlines,
-    indent,
-    pretty,
-    truncate,
-    truncate_bytes,
-)
-
-from celery.tests.case import AppCase, Case
-
-RANDTEXT = """\
-The quick brown
-fox jumps
-over the
-lazy dog\
-"""
-
-RANDTEXT_RES = """\
-    The quick brown
-    fox jumps
-    over the
-    lazy dog\
-"""
-
-QUEUES = {
-    'queue1': {
-        'exchange': 'exchange1',
-        'exchange_type': 'type1',
-        'routing_key': 'bind1',
-    },
-    'queue2': {
-        'exchange': 'exchange2',
-        'exchange_type': 'type2',
-        'routing_key': 'bind2',
-    },
-}
-
-
-QUEUE_FORMAT1 = '.> queue1           exchange=exchange1(type1) key=bind1'
-QUEUE_FORMAT2 = '.> queue2           exchange=exchange2(type2) key=bind2'
-
-
-class test_Info(AppCase):
-
-    def test_textindent(self):
-        self.assertEqual(indent(RANDTEXT, 4), RANDTEXT_RES)
-
-    def test_format_queues(self):
-        self.app.amqp.queues = self.app.amqp.Queues(QUEUES)
-        self.assertEqual(sorted(self.app.amqp.queues.format().split('\n')),
-                         sorted([QUEUE_FORMAT1, QUEUE_FORMAT2]))
-
-    def test_ensure_newlines(self):
-        self.assertEqual(
-            len(ensure_newlines('foo\nbar\nbaz\n').splitlines()), 3,
-        )
-        self.assertEqual(
-            len(ensure_newlines('foo\nbar').splitlines()), 2,
-        )
-
-
-class test_utils(Case):
-
-    def test_truncate_text(self):
-        self.assertEqual(truncate('ABCDEFGHI', 3), 'ABC...')
-        self.assertEqual(truncate('ABCDEFGHI', 10), 'ABCDEFGHI')
-
-    def test_truncate_bytes(self):
-        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 3), b'ABC...')
-        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 10), b'ABCDEFGHI')
-
-    def test_abbr(self):
-        self.assertEqual(abbr(None, 3), '???')
-        self.assertEqual(abbr('ABCDEFGHI', 6), 'ABC...')
-        self.assertEqual(abbr('ABCDEFGHI', 20), 'ABCDEFGHI')
-        self.assertEqual(abbr('ABCDEFGHI', 6, None), 'ABCDEF')
-
-    def test_abbrtask(self):
-        self.assertEqual(abbrtask(None, 3), '???')
-        self.assertEqual(
-            abbrtask('feeds.tasks.refresh', 10),
-            '[.]refresh',
-        )
-        self.assertEqual(
-            abbrtask('feeds.tasks.refresh', 30),
-            'feeds.tasks.refresh',
-        )
-
-    def test_pretty(self):
-        self.assertTrue(pretty(('a', 'b', 'c')))

+ 0 - 253
celery/tests/utils/test_timeutils.py

@@ -1,253 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pytz
-
-from datetime import datetime, timedelta, tzinfo
-from pytz import AmbiguousTimeError
-
-from celery.utils.time import (
-    delta_resolution,
-    humanize_seconds,
-    maybe_iso8601,
-    maybe_timedelta,
-    timezone,
-    rate,
-    remaining,
-    make_aware,
-    maybe_make_aware,
-    localize,
-    LocalTimezone,
-    ffwd,
-    utcoffset,
-)
-from celery.utils.iso8601 import parse_iso8601
-from celery.tests.case import Case, Mock, patch
-
-
-class test_LocalTimezone(Case):
-
-    def test_daylight(self):
-        with patch('celery.utils.time._time') as time:
-            time.timezone = 3600
-            time.daylight = False
-            x = LocalTimezone()
-            self.assertEqual(x.STDOFFSET, timedelta(seconds=-3600))
-            self.assertEqual(x.DSTOFFSET, x.STDOFFSET)
-            time.daylight = True
-            time.altzone = 3600
-            y = LocalTimezone()
-            self.assertEqual(y.STDOFFSET, timedelta(seconds=-3600))
-            self.assertEqual(y.DSTOFFSET, timedelta(seconds=-3600))
-
-            self.assertTrue(repr(y))
-
-            y._isdst = Mock()
-            y._isdst.return_value = True
-            self.assertTrue(y.utcoffset(datetime.now()))
-            self.assertFalse(y.dst(datetime.now()))
-            y._isdst.return_value = False
-            self.assertTrue(y.utcoffset(datetime.now()))
-            self.assertFalse(y.dst(datetime.now()))
-
-            self.assertTrue(y.tzname(datetime.now()))
-
-
-class test_iso8601(Case):
-
-    def test_parse_with_timezone(self):
-        d = datetime.utcnow().replace(tzinfo=pytz.utc)
-        self.assertEqual(parse_iso8601(d.isoformat()), d)
-        # 2013-06-07T20:12:51.775877+00:00
-        iso = d.isoformat()
-        iso1 = iso.replace('+00:00', '-01:00')
-        d1 = parse_iso8601(iso1)
-        self.assertEqual(d1.tzinfo._minutes, -60)
-        iso2 = iso.replace('+00:00', '+01:00')
-        d2 = parse_iso8601(iso2)
-        self.assertEqual(d2.tzinfo._minutes, +60)
-        iso3 = iso.replace('+00:00', 'Z')
-        d3 = parse_iso8601(iso3)
-        self.assertEqual(d3.tzinfo, pytz.UTC)
-
-
-class test_time_utils(Case):
-
-    def test_delta_resolution(self):
-        D = delta_resolution
-        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
-        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
-                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
-                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
-                    (timedelta(seconds=2), dt))
-        for delta, shoulda in deltamap:
-            self.assertEqual(D(dt, delta), shoulda)
-
-    def test_humanize_seconds(self):
-        t = ((4 * 60 * 60 * 24, '4.00 days'),
-             (1 * 60 * 60 * 24, '1.00 day'),
-             (4 * 60 * 60, '4.00 hours'),
-             (1 * 60 * 60, '1.00 hour'),
-             (4 * 60, '4.00 minutes'),
-             (1 * 60, '1.00 minute'),
-             (4, '4.00 seconds'),
-             (1, '1.00 second'),
-             (4.3567631221, '4.36 seconds'),
-             (0, 'now'))
-
-        for seconds, human in t:
-            self.assertEqual(humanize_seconds(seconds), human)
-
-        self.assertEqual(humanize_seconds(4, prefix='about '),
-                         'about 4.00 seconds')
-
-    def test_maybe_iso8601_datetime(self):
-        now = datetime.now()
-        self.assertIs(maybe_iso8601(now), now)
-
-    def test_maybe_timedelta(self):
-        D = maybe_timedelta
-
-        for i in (30, 30.6):
-            self.assertEqual(D(i), timedelta(seconds=i))
-
-        self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
-
-    def test_remaining_relative(self):
-        remaining(datetime.utcnow(), timedelta(hours=1), relative=True)
-
-
-class test_timezone(Case):
-
-    def test_get_timezone_with_pytz(self):
-        self.assertTrue(timezone.get_timezone('UTC'))
-
-    def test_tz_or_local(self):
-        self.assertEqual(timezone.tz_or_local(), timezone.local)
-        self.assertTrue(timezone.tz_or_local(timezone.utc))
-
-    def test_to_local(self):
-        self.assertTrue(
-            timezone.to_local(make_aware(datetime.utcnow(), timezone.utc)),
-        )
-        self.assertTrue(
-            timezone.to_local(datetime.utcnow())
-        )
-
-    def test_to_local_fallback(self):
-        self.assertTrue(
-            timezone.to_local_fallback(
-                make_aware(datetime.utcnow(), timezone.utc)),
-        )
-        self.assertTrue(
-            timezone.to_local_fallback(datetime.utcnow())
-        )
-
-
-class test_make_aware(Case):
-
-    def test_tz_without_localize(self):
-        tz = tzinfo()
-        self.assertFalse(hasattr(tz, 'localize'))
-        wtz = make_aware(datetime.utcnow(), tz)
-        self.assertEqual(wtz.tzinfo, tz)
-
-    def test_when_has_localize(self):
-
-        class tzz(tzinfo):
-            raises = False
-
-            def localize(self, dt, is_dst=None):
-                self.localized = True
-                if self.raises and is_dst is None:
-                    self.raised = True
-                    raise AmbiguousTimeError()
-                return 1  # needed by min() in Python 3 (None not hashable)
-
-        tz = tzz()
-        make_aware(datetime.utcnow(), tz)
-        self.assertTrue(tz.localized)
-
-        tz2 = tzz()
-        tz2.raises = True
-        make_aware(datetime.utcnow(), tz2)
-        self.assertTrue(tz2.localized)
-        self.assertTrue(tz2.raised)
-
-    def test_maybe_make_aware(self):
-        aware = datetime.utcnow().replace(tzinfo=timezone.utc)
-        self.assertTrue(maybe_make_aware(aware), timezone.utc)
-        naive = datetime.utcnow()
-        self.assertTrue(maybe_make_aware(naive))
-
-
-class test_localize(Case):
-
-    def test_tz_without_normalize(self):
-        tz = tzinfo()
-        self.assertFalse(hasattr(tz, 'normalize'))
-        self.assertTrue(localize(make_aware(datetime.utcnow(), tz), tz))
-
-    def test_when_has_normalize(self):
-
-        class tzz(tzinfo):
-            raises = None
-
-            def normalize(self, dt, **kwargs):
-                self.normalized = True
-                if self.raises and kwargs and kwargs.get('is_dst') is None:
-                    self.raised = True
-                    raise self.raises
-                return 1  # needed by min() in Python 3 (None not hashable)
-
-        tz = tzz()
-        localize(make_aware(datetime.utcnow(), tz), tz)
-        self.assertTrue(tz.normalized)
-
-        tz2 = tzz()
-        tz2.raises = AmbiguousTimeError()
-        localize(make_aware(datetime.utcnow(), tz2), tz2)
-        self.assertTrue(tz2.normalized)
-        self.assertTrue(tz2.raised)
-
-        tz3 = tzz()
-        tz3.raises = TypeError()
-        localize(make_aware(datetime.utcnow(), tz3), tz3)
-        self.assertTrue(tz3.normalized)
-        self.assertTrue(tz3.raised)
-
-
-class test_rate_limit_string(Case):
-
-    def test_conversion(self):
-        self.assertEqual(rate(999), 999)
-        self.assertEqual(rate(7.5), 7.5)
-        self.assertEqual(rate('2.5/s'), 2.5)
-        self.assertEqual(rate('1456/s'), 1456)
-        self.assertEqual(rate('100/m'),
-                         100 / 60.0)
-        self.assertEqual(rate('10/h'),
-                         10 / 60.0 / 60.0)
-
-        for zero in (0, None, '0', '0/m', '0/h', '0/s', '0.0/s'):
-            self.assertEqual(rate(zero), 0)
-
-
-class test_ffwd(Case):
-
-    def test_repr(self):
-        x = ffwd(year=2012)
-        self.assertTrue(repr(x))
-
-    def test_radd_with_unknown_gives_NotImplemented(self):
-        x = ffwd(year=2012)
-        self.assertEqual(x.__radd__(object()), NotImplemented)
-
-
-class test_utcoffset(Case):
-
-    def test_utcoffset(self):
-        with patch('celery.utils.time._time') as _time:
-            _time.daylight = True
-            self.assertIsNotNone(utcoffset(time=_time))
-            _time.daylight = False
-            self.assertIsNotNone(utcoffset(time=_time))

+ 0 - 44
celery/tests/utils/test_utils.py

@@ -1,44 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import chunks, cached_property
-
-from celery.tests.case import Case
-
-
-class test_chunks(Case):
-
-    def test_chunks(self):
-
-        # n == 2
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertListEqual(
-            list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]],
-        )
-
-        # n == 3
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertListEqual(
-            list(x),
-            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]],
-        )
-
-        # n == 2 (exact)
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertListEqual(
-            list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
-        )
-
-
-class test_utils(Case):
-
-    def test_cached_property(self):
-
-        def fun(obj):
-            return fun.value
-
-        x = cached_property(fun)
-        self.assertIs(x.__get__(None), x)
-        self.assertIs(x.__set__(None, None), x)
-        self.assertIs(x.__delete__(None), x)

+ 7 - 11
docs/contributing.rst

@@ -471,13 +471,13 @@ dependencies, so install these next:
     $ pip install -U -r requirements/default.txt
 
 After installing the dependencies required, you can now execute
-the test suite by calling :pypi:`nosetests <nose>`:
+the test suite by calling :pypi:`py.test <pytest`:
 
 .. code-block:: console
 
-    $ nosetests
+    $ py.test
 
-Some useful options to :command:`nosetests` are:
+Some useful options to :command:`py.test` are:
 
 * ``-x``
 
@@ -487,10 +487,6 @@ Some useful options to :command:`nosetests` are:
 
     Don't capture output
 
-* ``-nologcapture``
-
-    Don't capture log output.
-
 * ``-v``
 
     Run with verbose output.
@@ -500,7 +496,7 @@ you can do so like this:
 
 .. code-block:: console
 
-    $ nosetests celery.tests.test_worker.test_worker_job
+    $ py.test t/unit/worker/test_worker_job.py
 
 .. _contributing-pull-requests:
 
@@ -536,16 +532,16 @@ Code coverage in HTML:
 
 .. code-block:: console
 
-    $ nosetests --with-coverage --cover-html
+    $ py.test --cov=celery --cov-report=html
 
 The coverage output will then be located at
-:file:`celery/tests/cover/index.html`.
+:file:`cover/index.html`.
 
 Code coverage in XML (Cobertura-style):
 
 .. code-block:: console
 
-    $ nosetests --with-coverage --cover-xml --cover-xml-file=coverage.xml
+    $ py.test --cov=celery --cov-report=xml
 
 The coverage XML output will then be located at :file:`coverage.xml`
 

+ 2 - 2
docs/internals/guide.rst

@@ -292,9 +292,9 @@ Module Overview
 
     single-mode interface to creating tasks, and controlling workers.
 
-- celery.tests
+- t.unit (int distribution)
 
-    The unittest suite.
+    The unit test suite.
 
 - celery.utils
 

+ 1 - 1
docs/whatsnew-4.0.rst

@@ -1004,7 +1004,7 @@ Instead of using router classes you can now simply define a function:
             return {'queue': 'hipri'}
 
 If you don't need the arguments you can use start arguments, just make
-sure you always also accept star arguments so that we've the ability
+sure you always also accept star arguments so that we have the ability
 to add more features in the future:
 
 .. code-block:: python

+ 5 - 2
funtests/suite/test_leak.py

@@ -5,10 +5,13 @@ import os
 import sys
 import shlex
 import subprocess
+import unittest
+
+from case import Case
+from case.skip import SkipTest
 
 from celery import current_app
 from celery.five import range
-from celery.tests.case import SkipTest, unittest
 
 import suite  # noqa
 
@@ -25,7 +28,7 @@ class Sizes(list):
         return sum(self) / len(self)
 
 
-class LeakFunCase(unittest.TestCase):
+class LeakFunCase(Case):
 
     def setUp(self):
         self.app = current_app

+ 1 - 0
requirements/test-ci-base.txt

@@ -1,4 +1,5 @@
 coverage>=3.0
+pytest-cov
 codecov
 -r extras/redis.txt
 -r extras/sqlalchemy.txt

+ 2 - 1
requirements/test.txt

@@ -1 +1,2 @@
-case>=1.2.2
+case>=1.3.0
+pytest

+ 3 - 2
setup.cfg

@@ -1,5 +1,6 @@
-[nosetests]
-where = celery/tests
+[tool:pytest]
+testpaths = t/
+python_classes = test_*
 
 [build_sphinx]
 source-dir = docs/

+ 39 - 35
setup.py

@@ -1,12 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-from setuptools import setup, find_packages
-
+import codecs
 import os
 import re
 import sys
-import codecs
+
+import setuptools
+import setuptools.command.test
+
 
 try:
     import platform
@@ -15,6 +17,8 @@ except (AttributeError, ImportError):
     def _pyimp():
         return 'Python'
 
+NAME = 'celery'
+
 E_UNSUPPORTED_PYTHON = """
 ----------------------------------------
  Celery 4.0 requires %s %s or later
@@ -80,10 +84,6 @@ except:
 finally:
     sys.path[:] = orig_path
 
-NAME = 'celery'
-entrypoints = {}
-extra = {}
-
 # -*- Classifiers -*-
 
 classes = """
@@ -154,6 +154,10 @@ def _reqs(*f):
 def reqs(*f):
     return [req for subreq in _reqs(*f) for req in subreq]
 
+
+def extras(*p):
+    return reqs('extras', *p)
+
 install_requires = reqs('default.txt')
 if JYTHON:
     install_requires.extend(reqs('jython.txt'))
@@ -165,46 +169,46 @@ if os.path.exists('README.rst'):
 else:
     long_description = 'See http://pypi.python.org/pypi/celery'
 
-# -*- Entry Points -*- #
-
-console_scripts = entrypoints['console_scripts'] = [
-    'celery = celery.__main__:main',
-]
+# -*- %%% -*-
 
-# -*- Extras -*-
 
+class pytest(setuptools.command.test.test):
+    user_options = [('pytest-args=', 'a', 'Arguments to pass to py.test')]
 
-def extras(*p):
-    return reqs('extras', *p)
+    def initialize_options(self):
+        setuptools.command.test.test.initialize_options(self)
+        self.pytest_args = []
 
-# Celery specific
-features = set([
-    'auth', 'cassandra', 'elasticsearch', 'memcache', 'pymemcache',
-    'couchbase', 'eventlet', 'gevent', 'msgpack', 'yaml',
-    'redis', 'sqs', 'couchdb', 'riak', 'zookeeper', 'solar',
-    'sqlalchemy', 'librabbitmq', 'pyro', 'slmq', 'tblib', 'consul'
-])
-extras_require = dict((x, extras(x + '.txt')) for x in features)
-extra['extras_require'] = extras_require
+    def run_tests(self):
+        import pytest
+        sys.exit(pytest.main(self.pytest_args))
 
-# -*- %%% -*-
-
-setup(
+setuptools.setup(
     name=NAME,
+    packages=setuptools.find_packages(exclude=['t', 't.*']),
     version=meta['version'],
     description=meta['doc'],
+    long_description=long_description,
     author=meta['author'],
     author_email=meta['contact'],
-    url=meta['homepage'],
     platforms=['any'],
     license='BSD',
-    packages=find_packages(),
-    include_package_data=True,
-    zip_safe=False,
+    url=meta['homepage'],
     install_requires=install_requires,
     tests_require=reqs('test.txt'),
-    test_suite='nose.collector',
+    extras_require=dict((x, extras(x + '.txt')) for x in set([
+        'auth', 'cassandra', 'elasticsearch', 'memcache', 'pymemcache',
+        'couchbase', 'eventlet', 'gevent', 'msgpack', 'yaml',
+        'redis', 'sqs', 'couchdb', 'riak', 'zookeeper', 'solar',
+        'sqlalchemy', 'librabbitmq', 'pyro', 'slmq', 'tblib', 'consul'
+    ])),
     classifiers=classifiers,
-    entry_points=entrypoints,
-    long_description=long_description,
-    **extra)
+    entry_points={
+        'console_scripts': [
+            'celery = celery.__main__:main',
+        ]
+    },
+    cmdclass={'test': pytest},
+    include_package_data=True,
+    zip_safe=False,
+)

+ 0 - 0
celery/tests/app/__init__.py → t/__init__.py


+ 418 - 0
t/conftest.py

@@ -0,0 +1,418 @@
+from __future__ import absolute_import, unicode_literals
+
+import logging
+import numbers
+import os
+import pytest
+import sys
+import threading
+import warnings
+import weakref
+
+from copy import deepcopy
+from datetime import datetime, timedelta
+from functools import partial
+from importlib import import_module
+
+from case import Mock
+from case.utils import decorator
+from kombu import Queue
+from kombu.utils.imports import symbol_by_name
+
+from celery import Celery
+from celery.app import current_app
+from celery.backends.cache import CacheBackend, DummyClient
+
+try:
+    WindowsError = WindowsError  # noqa
+except NameError:
+
+    class WindowsError(Exception):
+        pass
+
+PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
+
+CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
+CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
+CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
+
+CELERY_TEST_CONFIG = {
+    #: Don't want log output when running suite.
+    'worker_hijack_root_logger': False,
+    'worker_log_color': False,
+    'task_default_queue': 'testcelery',
+    'task_default_exchange': 'testcelery',
+    'task_default_routing_key': 'testcelery',
+    'task_queues': (
+        Queue('testcelery', routing_key='testcelery'),
+    ),
+    'accept_content': ('json', 'pickle'),
+    'enable_utc': True,
+    'timezone': 'UTC',
+
+    # Mongo results tests (only executed if installed and running)
+    'mongodb_backend_settings': {
+        'host': os.environ.get('MONGO_HOST') or 'localhost',
+        'port': os.environ.get('MONGO_PORT') or 27017,
+        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
+        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
+                                'taskmeta_collection'),
+        'user': os.environ.get('MONGO_USER'),
+        'password': os.environ.get('MONGO_PASSWORD'),
+    }
+}
+
+
+class Trap(object):
+
+    def __getattr__(self, name):
+        raise RuntimeError('Test depends on current_app')
+
+
+class UnitLogging(symbol_by_name(Celery.log_cls)):
+
+    def __init__(self, *args, **kwargs):
+        super(UnitLogging, self).__init__(*args, **kwargs)
+        self.already_setup = True
+
+
+def TestApp(name=None, set_as_current=False, log=UnitLogging,
+            broker='memory://', backend='cache+memory://', **kwargs):
+    app = Celery(name or 'celery.tests',
+                 set_as_current=set_as_current,
+                 log=log, broker=broker, backend=backend,
+                 **kwargs)
+    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
+    return app
+
+
+def alive_threads():
+    return [thread for thread in threading.enumerate() if thread.is_alive()]
+
+
+@pytest.fixture(autouse=True)
+def task_join_will_not_block(request):
+    from celery import _state
+    from celery import result
+    prev_res_join_block = result.task_join_will_block
+    _state.orig_task_join_will_block = _state.task_join_will_block
+    prev_state_join_block = _state.task_join_will_block
+    result.task_join_will_block = \
+        _state.task_join_will_block = lambda: False
+    _state._set_task_join_will_block(False)
+
+    def fin():
+        result.task_join_will_block = prev_res_join_block
+        _state.task_join_will_block = prev_state_join_block
+        _state._set_task_join_will_block(False)
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(scope='session', autouse=True)
+def record_threads_at_startup(request):
+    try:
+        request.session._threads_at_startup
+    except AttributeError:
+        request.session._threads_at_startup = alive_threads()
+
+
+@pytest.fixture(autouse=True)
+def threads_not_lingering(request):
+    def fin():
+        assert request.session._threads_at_startup == alive_threads()
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def app(request):
+    from celery import _state
+    prev_current_app = current_app()
+    prev_default_app = _state.default_app
+    prev_finalizers = set(_state._on_app_finalizers)
+    prev_apps = weakref.WeakSet(_state._apps)
+    trap = Trap()
+    prev_tls = _state._tls
+    _state.set_default_app(trap)
+
+    class NonTLS(object):
+        current_app = trap
+    _state._tls = NonTLS()
+
+    app = TestApp(set_as_current=False)
+    is_not_contained = any([
+        not getattr(request.module, 'app_contained', True),
+        not getattr(request.cls, 'app_contained', True),
+        not getattr(request.function, 'app_contained', True)
+    ])
+    if is_not_contained:
+        app.set_current()
+
+    def fin():
+        _state.set_default_app(prev_default_app)
+        _state._tls = prev_tls
+        _state._tls.current_app = prev_current_app
+        if app is not prev_current_app:
+            app.close()
+        _state._on_app_finalizers = prev_finalizers
+        _state._apps = prev_apps
+    request.addfinalizer(fin)
+    return app
+
+
+@pytest.fixture()
+def depends_on_current_app(app):
+    app.set_current()
+
+
+@pytest.fixture(autouse=True)
+def test_cases_shortcuts(request, app, patching):
+    if request.instance:
+        @app.task
+        def add(x, y):
+            return x + y
+
+        # IMPORTANT: We set an .app attribute for every test case class.
+        request.instance.app = app
+        request.instance.Celery = TestApp
+        request.instance.assert_signal_called = assert_signal_called
+        request.instance.task_message_from_sig = task_message_from_sig
+        request.instance.TaskMessage = TaskMessage
+        request.instance.TaskMessage1 = TaskMessage1
+        request.instance.CELERY_TEST_CONFIG = dict(CELERY_TEST_CONFIG)
+        request.instance.add = add
+        request.instance.patching = patching
+
+        def fin():
+            request.instance.app = None
+        request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def zzzz_test_cases_calls_setup_teardown(request):
+    if request.instance:
+        # we set the .patching attribute for every test class.
+        setup = getattr(request.instance, 'setup', None)
+        # we also call .setup() and .teardown() after every test method.
+        teardown = getattr(request.instance, 'teardown', None)
+        setup and setup()
+        teardown and request.addfinalizer(teardown)
+
+
+@pytest.fixture(autouse=True)
+def sanity_no_shutdown_flags_set(request):
+    def fin():
+        # Make sure no test left the shutdown flags enabled.
+        from celery.worker import state as worker_state
+        # check for EX_OK
+        assert worker_state.should_stop is not False
+        assert worker_state.should_terminate is not False
+        # check for other true values
+        assert not worker_state.should_stop
+        assert not worker_state.should_terminate
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def reset_cache_backend_state(request, app):
+    def fin():
+        backend = app.__dict__.get('backend')
+        if backend is not None:
+            if isinstance(backend, CacheBackend):
+                if isinstance(backend.client, DummyClient):
+                    backend.client.cache.clear()
+                backend._cache.clear()
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def sanity_stdouts(request):
+    def fin():
+        from celery.utils.log import LoggingProxy
+        assert sys.stdout
+        assert sys.stderr
+        assert sys.__stdout__
+        assert sys.__stderr__
+        this = request.node.name
+        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
+            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
+        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
+            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def sanity_logging_side_effects(request):
+    root = logging.getLogger()
+    rootlevel = root.level
+    roothandlers = root.handlers
+
+    def fin():
+        this = request.node.name
+        root_now = logging.getLogger()
+        if root_now.level != rootlevel:
+            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
+        if root_now.handlers != roothandlers:
+            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
+    request.addfinalizer(fin)
+
+
+def setup_session(scope='session'):
+    using_coverage = (
+        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
+    )
+    os.environ.update(
+        # warn if config module not found
+        C_WNOCONF='yes',
+        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
+    )
+
+    if using_coverage and not PYPY3:
+        from warnings import catch_warnings
+        with catch_warnings(record=True):
+            import_all_modules()
+        warnings.resetwarnings()
+    from celery._state import set_default_app
+    set_default_app(Trap())
+
+
+def teardown():
+    # Don't want SUBDEBUG log messages at finalization.
+    try:
+        from multiprocessing.util import get_logger
+    except ImportError:
+        pass
+    else:
+        get_logger().setLevel(logging.WARNING)
+
+    # Make sure test database is removed.
+    import os
+    if os.path.exists('test.db'):
+        try:
+            os.remove('test.db')
+        except WindowsError:
+            pass
+
+    # Make sure there are no remaining threads at shutdown.
+    import threading
+    remaining_threads = [thread for thread in threading.enumerate()
+                         if thread.getName() != 'MainThread']
+    if remaining_threads:
+        sys.stderr.write(
+            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
+                remaining_threads))
+
+
+def find_distribution_modules(name=__name__, file=__file__):
+    current_dist_depth = len(name.split('.')) - 1
+    current_dist = os.path.join(os.path.dirname(file),
+                                *([os.pardir] * current_dist_depth))
+    abs = os.path.abspath(current_dist)
+    dist_name = os.path.basename(abs)
+
+    for dirpath, dirnames, filenames in os.walk(abs):
+        package = (dist_name + dirpath[len(abs):]).replace('/', '.')
+        if '__init__.py' in filenames:
+            yield package
+            for filename in filenames:
+                if filename.endswith('.py') and filename != '__init__.py':
+                    yield '.'.join([package, filename])[:-3]
+
+
+def import_all_modules(name=__name__, file=__file__,
+                       skip=('celery.decorators',
+                             'celery.task')):
+    for module in find_distribution_modules(name, file):
+        if not module.startswith(skip):
+            try:
+                import_module(module)
+            except ImportError:
+                pass
+            except OSError as exc:
+                warnings.warn(UserWarning(
+                    'Ignored error importing module {0}: {1!r}'.format(
+                        module, exc,
+                    )))
+
+
+@decorator
+def assert_signal_called(signal, **expected):
+    handler = Mock()
+    call_handler = partial(handler)
+    signal.connect(call_handler)
+    try:
+        yield handler
+    finally:
+        signal.disconnect(call_handler)
+    handler.assert_called_with(signal=signal, **expected)
+
+
+def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
+                errbacks=None, chain=None, shadow=None, utc=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {
+        'id': id,
+        'task': name,
+        'shadow': shadow,
+    }
+    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
+    message.headers.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        (args, kwargs, embed), serializer='json',
+    )
+    message.payload = (args, kwargs, embed)
+    return message
+
+
+def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
+                 errbacks=None, chain=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {}
+    message.payload = {
+        'task': name,
+        'id': id,
+        'args': args,
+        'kwargs': kwargs,
+        'callbacks': callbacks,
+        'errbacks': errbacks,
+    }
+    message.payload.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        message.payload,
+    )
+    return message
+
+
+def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
+    sig.freeze()
+    callbacks = sig.options.pop('link', None)
+    errbacks = sig.options.pop('link_error', None)
+    countdown = sig.options.pop('countdown', None)
+    if countdown:
+        eta = app.now() + timedelta(seconds=countdown)
+    else:
+        eta = sig.options.pop('eta', None)
+    if eta and isinstance(eta, datetime):
+        eta = eta.isoformat()
+    expires = sig.options.pop('expires', None)
+    if expires and isinstance(expires, numbers.Real):
+        expires = app.now() + timedelta(seconds=expires)
+    if expires and isinstance(expires, datetime):
+        expires = expires.isoformat()
+    return TaskMessage(
+        sig.task, id=sig.id, args=sig.args,
+        kwargs=sig.kwargs,
+        callbacks=[dict(s) for s in callbacks] if callbacks else None,
+        errbacks=[dict(s) for s in errbacks] if errbacks else None,
+        eta=eta,
+        expires=expires,
+        utc=utc,
+        **sig.options
+    )

+ 0 - 0
celery/tests/apps/__init__.py → t/unit/__init__.py


+ 0 - 0
celery/tests/backends/__init__.py → t/unit/app/__init__.py


+ 269 - 0
t/unit/app/test_amqp.py

@@ -0,0 +1,269 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from datetime import datetime, timedelta
+
+from case import Mock
+from kombu import Exchange, Queue
+
+from celery import uuid
+from celery.app.amqp import Queues, utf8dict
+from celery.five import keys
+from celery.utils.time import to_utc
+
+
+class test_TaskConsumer:
+
+    def test_accept_content(self, app):
+        with app.pool.acquire(block=True) as con:
+            app.conf.accept_content = ['application/json']
+            assert app.amqp.TaskConsumer(con).accept == {
+                'application/json',
+            }
+            assert app.amqp.TaskConsumer(con, accept=['json']).accept == {
+                'application/json',
+            }
+
+
+class test_ProducerPool:
+
+    def test_setup_nolimit(self, app):
+        app.conf.broker_pool_limit = None
+        try:
+            delattr(app, '_pool')
+        except AttributeError:
+            pass
+        app.amqp._producer_pool = None
+        pool = app.amqp.producer_pool
+        assert pool.limit == app.pool.limit
+        assert not pool._resource.queue
+
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+        r1.release()
+        r2.release()
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+
+    def test_setup(self, app):
+        app.conf.broker_pool_limit = 2
+        try:
+            delattr(app, '_pool')
+        except AttributeError:
+            pass
+        app.amqp._producer_pool = None
+        pool = app.amqp.producer_pool
+        assert pool.limit == app.pool.limit
+        assert pool._resource.queue
+
+        p1 = r1 = pool.acquire()
+        p2 = r2 = pool.acquire()
+        r1.release()
+        r2.release()
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+        assert p2 is r1
+        assert p1 is r2
+        r1.release()
+        r2.release()
+
+
+class test_Queues:
+
+    def test_queues_format(self):
+        self.app.amqp.queues._consume_from = {}
+        assert self.app.amqp.queues.format() == ''
+
+    def test_with_defaults(self):
+        assert Queues(None) == {}
+
+    def test_add(self):
+        q = Queues()
+        q.add('foo', exchange='ex', routing_key='rk')
+        assert 'foo' in q
+        assert isinstance(q['foo'], Queue)
+        assert q['foo'].routing_key == 'rk'
+
+    @pytest.mark.parametrize('ha_policy,qname,q,qargs,expected', [
+        (None, 'xyz', 'xyz', None, None),
+        (None, 'xyz', 'xyz', {'x-foo': 'bar'}, {'x-foo': 'bar'}),
+        ('all', 'foo', Queue('foo'), None, {'x-ha-policy': 'all'}),
+        ('all', 'xyx2',
+         Queue('xyx2', queue_arguments={'x-foo': 'bari'}),
+         None,
+         {'x-ha-policy': 'all', 'x-foo': 'bari'}),
+        (['A', 'B', 'C'], 'foo', Queue('foo'), None, {
+            'x-ha-policy': 'nodes',
+            'x-ha-policy-params': ['A', 'B', 'C']}),
+    ])
+    def test_with_ha_policy(self, ha_policy, qname, q, qargs, expected):
+        queues = Queues(ha_policy=ha_policy, create_missing=False)
+        queues.add(q, queue_arguments=qargs)
+        assert queues[qname].queue_arguments == expected
+
+    def test_select_add(self):
+        q = Queues()
+        q.select(['foo', 'bar'])
+        q.select_add('baz')
+        assert sorted(keys(q._consume_from)) == ['bar', 'baz', 'foo']
+
+    def test_deselect(self):
+        q = Queues()
+        q.select(['foo', 'bar'])
+        q.deselect('bar')
+        assert sorted(keys(q._consume_from)) == ['foo']
+
+    def test_with_ha_policy_compat(self):
+        q = Queues(ha_policy='all')
+        q.add('bar')
+        assert q['bar'].queue_arguments == {'x-ha-policy': 'all'}
+
+    def test_add_default_exchange(self):
+        ex = Exchange('fff', 'fanout')
+        q = Queues(default_exchange=ex)
+        q.add(Queue('foo'))
+        assert q['foo'].exchange.name == ''
+
+    def test_alias(self):
+        q = Queues()
+        q.add(Queue('foo', alias='barfoo'))
+        assert q['barfoo'] is q['foo']
+
+    @pytest.mark.parametrize('queues_kwargs,qname,q,expected', [
+        (dict(max_priority=10),
+         'foo', 'foo', {'x-max-priority': 10}),
+        (dict(max_priority=10),
+         'xyz', Queue('xyz', queue_arguments={'x-max-priority': 3}),
+         {'x-max-priority': 3}),
+        (dict(max_priority=10),
+         'moo', Queue('moo', queue_arguments=None),
+         {'x-max-priority': 10}),
+        (dict(ha_policy='all', max_priority=5),
+         'bar', 'bar',
+         {'x-ha-policy': 'all', 'x-max-priority': 5}),
+        (dict(ha_policy='all', max_priority=5),
+         'xyx2', Queue('xyx2', queue_arguments={'x-max-priority': 2}),
+         {'x-ha-policy': 'all', 'x-max-priority': 2}),
+        (dict(max_priority=None),
+         'foo2', 'foo2',
+         None),
+        (dict(max_priority=None),
+         'xyx3', Queue('xyx3', queue_arguments={'x-max-priority': 7}),
+         {'x-max-priority': 7}),
+
+    ])
+    def test_with_max_priority(self, queues_kwargs, qname, q, expected):
+        queues = Queues(**queues_kwargs)
+        queues.add(q)
+        assert queues[qname].queue_arguments == expected
+
+
+class test_AMQP:
+
+    def setup(self):
+        self.simple_message = self.app.amqp.as_task_v2(
+            uuid(), 'foo', create_sent_event=True,
+        )
+
+    def test_Queues__with_ha_policy(self):
+        x = self.app.amqp.Queues({}, ha_policy='all')
+        assert x.ha_policy == 'all'
+
+    def test_Queues__with_max_priority(self):
+        x = self.app.amqp.Queues({}, max_priority=23)
+        assert x.max_priority == 23
+
+    def test_send_task_message__no_kwargs(self):
+        self.app.amqp.send_task_message(Mock(), 'foo', self.simple_message)
+
+    def test_send_task_message__properties(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, foo=1, retry=False,
+        )
+        assert prod.publish.call_args[1]['foo'] == 1
+
+    def test_send_task_message__headers(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, headers={'x1x': 'y2x'},
+            retry=False,
+        )
+        assert prod.publish.call_args[1]['headers']['x1x'] == 'y2x'
+
+    def test_send_task_message__queue_string(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, queue='foo', retry=False,
+        )
+        kwargs = prod.publish.call_args[1]
+        assert kwargs['routing_key'] == 'foo'
+        assert kwargs['exchange'] == ''
+
+    def test_send_event_exchange_string(self):
+        evd = Mock(name='evd')
+        self.app.amqp.send_task_message(
+            Mock(), 'foo', self.simple_message, retry=False,
+            exchange='xyz', routing_key='xyb',
+            event_dispatcher=evd,
+        )
+        evd.publish.assert_called()
+        event = evd.publish.call_args[0][1]
+        assert event['routing_key'] == 'xyb'
+        assert event['exchange'] == 'xyz'
+
+    def test_send_task_message__with_delivery_mode(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, delivery_mode=33, retry=False,
+        )
+        assert prod.publish.call_args[1]['delivery_mode'] == 33
+
+    def test_routes(self):
+        r1 = self.app.amqp.routes
+        r2 = self.app.amqp.routes
+        assert r1 is r2
+
+
+class test_as_task_v2:
+
+    def test_raises_if_args_is_not_tuple(self):
+        with pytest.raises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', args='123')
+
+    def test_raises_if_kwargs_is_not_mapping(self):
+        with pytest.raises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', kwargs=(1, 2, 3))
+
+    def test_countdown_to_eta(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', countdown=10, now=now,
+        )
+        assert m.headers['eta'] == (now + timedelta(seconds=10)).isoformat()
+
+    def test_expires_to_datetime(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', expires=30, now=now,
+        )
+        assert m.headers['expires'] == (
+            now + timedelta(seconds=30)).isoformat()
+
+    def test_callbacks_errbacks_chord(self):
+
+        @self.app.task
+        def t(i):
+            pass
+
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo',
+            callbacks=[t.s(1), t.s(2)],
+            errbacks=[t.s(3), t.s(4)],
+            chord=t.s(5),
+        )
+        _, _, embed = m.body
+        assert embed['callbacks'] == [utf8dict(t.s(1)), utf8dict(t.s(2))]
+        assert embed['errbacks'] == [utf8dict(t.s(3)), utf8dict(t.s(4))]
+        assert embed['chord'] == utf8dict(t.s(5))

+ 12 - 14
celery/tests/app/test_annotations.py → t/unit/app/test_annotations.py

@@ -3,14 +3,12 @@ from __future__ import absolute_import, unicode_literals
 from celery.app.annotations import MapAnnotation, prepare
 from celery.utils.imports import qualname
 
-from celery.tests.case import AppCase
-
 
 class MyAnnotation(object):
     foo = 65
 
 
-class AnnotationCase(AppCase):
+class AnnotationCase:
 
     def setup(self):
         @self.app.task(shared=False)
@@ -28,29 +26,29 @@ class test_MapAnnotation(AnnotationCase):
 
     def test_annotate(self):
         x = MapAnnotation({self.add.name: {'foo': 1}})
-        self.assertDictEqual(x.annotate(self.add), {'foo': 1})
-        self.assertIsNone(x.annotate(self.mul))
+        assert x.annotate(self.add) == {'foo': 1}
+        assert x.annotate(self.mul) is None
 
     def test_annotate_any(self):
         x = MapAnnotation({'*': {'foo': 2}})
-        self.assertDictEqual(x.annotate_any(), {'foo': 2})
+        assert x.annotate_any() == {'foo': 2}
 
         x = MapAnnotation()
-        self.assertIsNone(x.annotate_any())
+        assert x.annotate_any() is None
 
 
 class test_prepare(AnnotationCase):
 
     def test_dict_to_MapAnnotation(self):
         x = prepare({self.add.name: {'foo': 3}})
-        self.assertIsInstance(x[0], MapAnnotation)
+        assert isinstance(x[0], MapAnnotation)
 
     def test_returns_list(self):
-        self.assertListEqual(prepare(1), [1])
-        self.assertListEqual(prepare([1]), [1])
-        self.assertListEqual(prepare((1,)), [1])
-        self.assertEqual(prepare(None), ())
+        assert prepare(1) == [1]
+        assert prepare([1]) == [1]
+        assert prepare((1,)) == [1]
+        assert prepare(None) == ()
 
     def test_evalutes_qualnames(self):
-        self.assertEqual(prepare(qualname(MyAnnotation))[0]().foo, 65)
-        self.assertEqual(prepare([qualname(MyAnnotation)])[0]().foo, 65)
+        assert prepare(qualname(MyAnnotation))[0]().foo == 65
+        assert prepare([qualname(MyAnnotation)])[0]().foo == 65

+ 222 - 246
celery/tests/app/test_app.py → t/unit/app/test_app.py

@@ -1,12 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 
 import gc
-import os
 import itertools
+import os
+import pytest
 
 from copy import deepcopy
 from pickle import loads, dumps
 
+from case import ContextMock, Mock, mock, patch
 from vine import promise
 
 from celery import Celery
@@ -16,7 +18,7 @@ from celery import _state
 from celery.app import base as _appbase
 from celery.app import defaults
 from celery.exceptions import ImproperlyConfigured
-from celery.five import keys
+from celery.five import items, keys
 from celery.loaders.base import unconfigured
 from celery.platforms import pyimplementation
 from celery.utils.collections import DictAttribute
@@ -24,17 +26,6 @@ from celery.utils.serialization import pickle
 from celery.utils.time import timezone
 from celery.utils.objects import Bunch
 
-from celery.tests.case import (
-    CELERY_TEST_CONFIG,
-    AppCase,
-    Mock,
-    Case,
-    ContextMock,
-    depends_on_current_app,
-    mock,
-    patch,
-)
-
 THIS_IS_A_KEY = 'this is a value'
 
 
@@ -54,37 +45,32 @@ class ObjectConfig2(object):
     UNDERSTAND_ME = True
 
 
-def _get_test_config():
-    return deepcopy(CELERY_TEST_CONFIG)
-test_config = _get_test_config()
-
-
-class test_module(AppCase):
+class test_module:
 
     def test_default_app(self):
-        self.assertEqual(_app.default_app, _state.default_app)
+        assert _app.default_app == _state.default_app
 
-    def test_bugreport(self):
-        self.assertTrue(_app.bugreport(app=self.app))
+    def test_bugreport(self, app):
+        assert _app.bugreport(app=app)
 
 
-class test_task_join_will_block(Case):
+class test_task_join_will_block:
 
-    def test_task_join_will_block(self):
-        prev, _state._task_join_will_block = _state._task_join_will_block, 0
-        try:
-            self.assertEqual(_state._task_join_will_block, 0)
-            _state._set_task_join_will_block(True)
-            print(_state.task_join_will_block)
-            self.assertTrue(_state.task_join_will_block())
-        finally:
-            _state._task_join_will_block = prev
+    def test_task_join_will_block(self, patching):
+        patching('celery._state._task_join_will_block', 0)
+        assert _state._task_join_will_block == 0
+        _state._set_task_join_will_block(True)
+        assert _state._task_join_will_block is True
+        # fixture 'app' sets this, so need to use orig_ function
+        # set there by that fixture.
+        res = _state.orig_task_join_will_block()
+        assert res is True
 
 
-class test_App(AppCase):
+class test_App:
 
     def setup(self):
-        self.app.add_defaults(test_config)
+        self.app.add_defaults(deepcopy(self.CELERY_TEST_CONFIG))
 
     def test_task_autofinalize_disabled(self):
         with self.Celery('xyzibari', autofinalize=False) as app:
@@ -92,7 +78,7 @@ class test_App(AppCase):
             def ttafd():
                 return 42
 
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 ttafd()
 
         with self.Celery('xyzibari', autofinalize=False) as app:
@@ -101,14 +87,14 @@ class test_App(AppCase):
                 return 42
 
             app.finalize()
-            self.assertEqual(ttafd2(), 42)
+            assert ttafd2() == 42
 
     def test_registry_autofinalize_disabled(self):
         with self.Celery('xyzibari', autofinalize=False) as app:
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 app.tasks['celery.chain']
             app.finalize()
-            self.assertTrue(app.tasks['celery.chain'])
+            assert app.tasks['celery.chain']
 
     def test_task(self):
         with self.Celery('foozibari') as app:
@@ -118,20 +104,20 @@ class test_App(AppCase):
 
             fun.__module__ = '__main__'
             task = app.task(fun)
-            self.assertEqual(task.name, app.main + '.fun')
+            assert task.name == app.main + '.fun'
 
     def test_task_too_many_args(self):
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             self.app.task(Mock(name='fun'), True)
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             self.app.task(Mock(name='fun'), True, 1, 2)
 
     def test_with_config_source(self):
         with self.Celery(config_source=ObjectConfig) as app:
-            self.assertEqual(app.conf.FOO, 1)
-            self.assertEqual(app.conf.BAR, 2)
+            assert app.conf.FOO == 1
+            assert app.conf.BAR == 2
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_task_windows_execv(self):
         prev, _appbase.USING_EXECV = _appbase.USING_EXECV, True
         try:
@@ -139,55 +125,55 @@ class test_App(AppCase):
             def foo():
                 pass
 
-            self.assertTrue(foo._get_current_object())  # is proxy
+            assert foo._get_current_object()  # is proxy
 
         finally:
             _appbase.USING_EXECV = prev
         assert not _appbase.USING_EXECV
 
     def test_task_takes_no_args(self):
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             @self.app.task(1)
             def foo():
                 pass
 
     def test_add_defaults(self):
-        self.assertFalse(self.app.configured)
+        assert not self.app.configured
         _conf = {'FOO': 300}
 
         def conf():
             return _conf
 
         self.app.add_defaults(conf)
-        self.assertIn(conf, self.app._pending_defaults)
-        self.assertFalse(self.app.configured)
-        self.assertEqual(self.app.conf.FOO, 300)
-        self.assertTrue(self.app.configured)
-        self.assertFalse(self.app._pending_defaults)
+        assert conf in self.app._pending_defaults
+        assert not self.app.configured
+        assert self.app.conf.FOO == 300
+        assert self.app.configured
+        assert not self.app._pending_defaults
 
         # defaults not pickled
         appr = loads(dumps(self.app))
-        with self.assertRaises(AttributeError):
+        with pytest.raises(AttributeError):
             appr.conf.FOO
 
         # add more defaults after configured
         conf2 = {'FOO': 'BAR'}
         self.app.add_defaults(conf2)
-        self.assertEqual(self.app.conf.FOO, 'BAR')
+        assert self.app.conf.FOO == 'BAR'
 
-        self.assertIn(_conf, self.app.conf.defaults)
-        self.assertIn(conf2, self.app.conf.defaults)
+        assert _conf in self.app.conf.defaults
+        assert conf2 in self.app.conf.defaults
 
     def test_connection_or_acquire(self):
         with self.app.connection_or_acquire(block=True):
-            self.assertTrue(self.app.pool._dirty)
+            assert self.app.pool._dirty
 
         with self.app.connection_or_acquire(pool=False):
-            self.assertFalse(self.app.pool._dirty)
+            assert not self.app.pool._dirty
 
     def test_using_v1_reduce(self):
         self.app._using_v1_reduce = True
-        self.assertTrue(loads(dumps(self.app)))
+        assert loads(dumps(self.app))
 
     def test_autodiscover_tasks_force(self):
         self.app.loader.autodiscover_tasks = Mock()
@@ -215,9 +201,9 @@ class test_App(AppCase):
             self.app.autodiscover_tasks(lazy_list)
             import_modules.connect.assert_called()
             prom = import_modules.connect.call_args[0][0]
-            self.assertIsInstance(prom, promise)
-            self.assertEqual(prom.fun, self.app._autodiscover_tasks)
-            self.assertEqual(prom.args[0](), [1, 2, 3])
+            assert isinstance(prom, promise)
+            assert prom.fun == self.app._autodiscover_tasks
+            assert prom.args[0](), [1, 2 == 3]
 
     def test_autodiscover_tasks__no_packages(self):
         fixup1 = Mock(name='fixup')
@@ -231,28 +217,28 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
 
-    @mock.environ('CELERY_BROKER_URL', '')
-    def test_with_broker(self):
+    def test_with_broker(self, patching):
+        patching.setenv('CELERY_BROKER_URL', '')
         with self.Celery(broker='foo://baribaz') as app:
-            self.assertEqual(app.conf.broker_url, 'foo://baribaz')
+            assert app.conf.broker_url == 'foo://baribaz'
 
     def test_pending_configuration__setattr(self):
         with self.Celery(broker='foo://bar') as app:
             app.conf.task_default_delivery_mode = 44
             app.conf.worker_agent = 'foo:Bar'
-            self.assertFalse(app.configured)
-            self.assertEqual(app.conf.worker_agent, 'foo:Bar')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app._preconf['worker_agent'], 'foo:Bar')
+            assert not app.configured
+            assert app.conf.worker_agent == 'foo:Bar'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app._preconf['worker_agent'] == 'foo:Bar'
 
-            self.assertTrue(app.configured)
+            assert app.configured
             reapp = pickle.loads(pickle.dumps(app))
-            self.assertEqual(reapp._preconf['worker_agent'], 'foo:Bar')
-            self.assertFalse(reapp.configured)
-            self.assertEqual(reapp.conf.worker_agent, 'foo:Bar')
-            self.assertTrue(reapp.configured)
-            self.assertEqual(reapp.conf.broker_url, 'foo://bar')
-            self.assertEqual(reapp._preconf['worker_agent'], 'foo:Bar')
+            assert reapp._preconf['worker_agent'] == 'foo:Bar'
+            assert not reapp.configured
+            assert reapp.conf.worker_agent == 'foo:Bar'
+            assert reapp.configured
+            assert reapp.conf.broker_url == 'foo://bar'
+            assert reapp._preconf['worker_agent'] == 'foo:Bar'
 
     def test_pending_configuration__update(self):
         with self.Celery(broker='foo://bar') as app:
@@ -260,10 +246,10 @@ class test_App(AppCase):
                 task_default_delivery_mode=44,
                 worker_agent='foo:Bar',
             )
-            self.assertFalse(app.configured)
-            self.assertEqual(app.conf.worker_agent, 'foo:Bar')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app._preconf['worker_agent'], 'foo:Bar')
+            assert not app.configured
+            assert app.conf.worker_agent == 'foo:Bar'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app._preconf['worker_agent'] == 'foo:Bar'
 
     def test_pending_configuration__compat_settings(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -272,11 +258,11 @@ class test_App(AppCase):
                 CELERY_DEFAULT_DELIVERY_MODE=63,
                 CELERYD_AGENT='foo:Barz',
             )
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.task_default_delivery_mode, 63)
-            self.assertEqual(app.conf.worker_agent, 'foo:Barz')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app.conf.result_backend, 'foo')
+            assert app.conf.task_always_eager == 4
+            assert app.conf.task_default_delivery_mode == 63
+            assert app.conf.worker_agent == 'foo:Barz'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app.conf.result_backend == 'foo'
 
     def test_pending_configuration__compat_settings_mixing(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -286,8 +272,8 @@ class test_App(AppCase):
                 CELERYD_AGENT='foo:Barz',
                 worker_consumer='foo:Fooz',
             )
-            with self.assertRaises(ImproperlyConfigured):
-                self.assertEqual(app.conf.task_always_eager, 4)
+            with pytest.raises(ImproperlyConfigured):
+                assert app.conf.task_always_eager == 4
 
     def test_pending_configuration__django_settings(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -297,13 +283,13 @@ class test_App(AppCase):
                 CELERY_WORKER_AGENT='foo:Barz',
                 CELERY_RESULT_SERIALIZER='pickle',
             )), namespace='CELERY')
-            self.assertEqual(app.conf.result_serializer, 'pickle')
-            self.assertEqual(app.conf.CELERY_RESULT_SERIALIZER, 'pickle')
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.task_default_delivery_mode, 63)
-            self.assertEqual(app.conf.worker_agent, 'foo:Barz')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app.conf.result_backend, 'foo')
+            assert app.conf.result_serializer == 'pickle'
+            assert app.conf.CELERY_RESULT_SERIALIZER == 'pickle'
+            assert app.conf.task_always_eager == 4
+            assert app.conf.task_default_delivery_mode == 63
+            assert app.conf.worker_agent == 'foo:Barz'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app.conf.result_backend == 'foo'
 
     def test_pending_configuration__compat_settings_mixing_new(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -314,8 +300,8 @@ class test_App(AppCase):
                 CELERYD_CONSUMER='foo:Fooz',
                 CELERYD_POOL='foo:Xuzzy',
             )
-            with self.assertRaises(ImproperlyConfigured):
-                self.assertEqual(app.conf.worker_consumer, 'foo:Fooz')
+            with pytest.raises(ImproperlyConfigured):
+                assert app.conf.worker_consumer == 'foo:Fooz'
 
     def test_pending_configuration__compat_settings_mixing_alt(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -328,52 +314,52 @@ class test_App(AppCase):
                 CELERYD_POOL='foo:Xuzzy',
                 worker_pool='foo:Xuzzy'
             )
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.worker_pool, 'foo:Xuzzy')
+            assert app.conf.task_always_eager == 4
+            assert app.conf.worker_pool == 'foo:Xuzzy'
 
     def test_pending_configuration__setdefault(self):
         with self.Celery(broker='foo://bar') as app:
             app.conf.setdefault('worker_agent', 'foo:Bar')
-            self.assertFalse(app.configured)
+            assert not app.configured
 
     def test_pending_configuration__iter(self):
         with self.Celery(broker='foo://bar') as app:
             app.conf.worker_agent = 'foo:Bar'
-            self.assertFalse(app.configured)
-            self.assertTrue(list(keys(app.conf)))
-            self.assertFalse(app.configured)
-            self.assertIn('worker_agent', app.conf)
-            self.assertFalse(app.configured)
-            self.assertTrue(dict(app.conf))
-            self.assertTrue(app.configured)
+            assert not app.configured
+            assert list(keys(app.conf))
+            assert not app.configured
+            assert 'worker_agent' in app.conf
+            assert not app.configured
+            assert dict(app.conf)
+            assert app.configured
 
     def test_pending_configuration__raises_ImproperlyConfigured(self):
         with self.Celery(set_as_current=False) as app:
             app.conf.worker_agent = 'foo://bar'
             app.conf.task_default_delivery_mode = 44
             app.conf.CELERY_ALWAYS_EAGER = 5
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 app.finalize()
 
         with self.Celery() as app:
-            self.assertFalse(self.app.conf.task_always_eager)
+            assert not self.app.conf.task_always_eager
 
     def test_repr(self):
-        self.assertTrue(repr(self.app))
+        assert repr(self.app)
 
     def test_custom_task_registry(self):
         with self.Celery(tasks=self.app.tasks) as app2:
-            self.assertIs(app2.tasks, self.app.tasks)
+            assert app2.tasks is self.app.tasks
 
     def test_include_argument(self):
         with self.Celery(include=('foo', 'bar.foo')) as app:
-            self.assertEqual(app.conf.include, ('foo', 'bar.foo'))
+            assert app.conf.include, ('foo' == 'bar.foo')
 
     def test_set_as_current(self):
         current = _state._tls.current_app
         try:
             app = self.Celery(set_as_current=True)
-            self.assertIs(_state._tls.current_app, app)
+            assert _state._tls.current_app is app
         finally:
             _state._tls.current_app = current
 
@@ -384,7 +370,7 @@ class test_App(AppCase):
 
         _state._task_stack.push(foo)
         try:
-            self.assertEqual(self.app.current_task.name, foo.name)
+            assert self.app.current_task.name == foo.name
         finally:
             _state._task_stack.pop()
 
@@ -433,7 +419,7 @@ class test_App(AppCase):
                 def foo():
                     pass
 
-                self.assertEqual(foo.name, 'xuzzy.foo')
+                assert foo.name == 'xuzzy.foo'
         finally:
             _imports.MP_MAIN_FILE = None
 
@@ -458,7 +444,7 @@ class test_App(AppCase):
             adX.name: {'@__call__': deco}
         }
         adX.bind(self.app)
-        self.assertIs(adX.app, self.app)
+        assert adX.app is self.app
 
         i = adX()
         i(2, 4, x=3)
@@ -472,9 +458,9 @@ class test_App(AppCase):
         def aawsX(x, y):
             pass
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             aawsX.apply_async(())
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             aawsX.apply_async((2,))
 
         with patch('celery.app.amqp.AMQP.create_task_message') as create:
@@ -482,7 +468,7 @@ class test_App(AppCase):
                 create.return_value = Mock(), Mock(), Mock(), Mock()
                 aawsX.apply_async((4, 5))
                 args = create.call_args[0][2]
-                self.assertEqual(args, ('hello', 4, 5))
+                assert args, ('hello', 4 == 5)
                 send.assert_called()
 
     def test_apply_async_adds_children(self):
@@ -501,7 +487,7 @@ class test_App(AppCase):
             a3cX1.push_request(called_directly=False)
             try:
                 res = a3cX2.apply_async(add_to_parent=True)
-                self.assertIn(res, a3cX1.request.children)
+                assert res in a3cX1.request.children
             finally:
                 a3cX1.pop_request()
         finally:
@@ -512,9 +498,10 @@ class test_App(AppCase):
                        THE_MII_MAR='jars')
         self.app.conf.update(changes)
         saved = pickle.dumps(self.app)
-        self.assertLess(len(saved), 2048)
+        assert len(saved) < 2048
         restored = pickle.loads(saved)
-        self.assertDictContainsSubset(changes, restored.conf)
+        for key, value in items(changes):
+            assert restored.conf[key] == value
 
     def test_worker_main(self):
         from celery.bin import worker as worker_bin
@@ -527,33 +514,33 @@ class test_App(AppCase):
         prev, worker_bin.worker = worker_bin.worker, worker
         try:
             ret = self.app.worker_main(argv=['--version'])
-            self.assertListEqual(ret, ['--version'])
+            assert ret == ['--version']
         finally:
             worker_bin.worker = prev
 
     def test_config_from_envvar(self):
-        os.environ['CELERYTEST_CONFIG_OBJECT'] = 'celery.tests.app.test_app'
+        os.environ['CELERYTEST_CONFIG_OBJECT'] = 't.unit.app.test_app'
         self.app.config_from_envvar('CELERYTEST_CONFIG_OBJECT')
-        self.assertEqual(self.app.conf.THIS_IS_A_KEY, 'this is a value')
+        assert self.app.conf.THIS_IS_A_KEY == 'this is a value'
 
     def assert_config2(self):
-        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
-        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
-        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
-        self.assertFalse(self.app.conf.WANT_ME_TO)
-        self.assertTrue(self.app.conf.UNDERSTAND_ME)
+        assert self.app.conf.LEAVE_FOR_WORK
+        assert self.app.conf.MOMENT_TO_STOP
+        assert self.app.conf.CALL_ME_BACK == 123456789
+        assert not self.app.conf.WANT_ME_TO
+        assert self.app.conf.UNDERSTAND_ME
 
     def test_config_from_object__lazy(self):
         conf = ObjectConfig2()
         self.app.config_from_object(conf)
-        self.assertIs(self.app.loader._conf, unconfigured)
-        self.assertIs(self.app._config_source, conf)
+        assert self.app.loader._conf is unconfigured
+        assert self.app._config_source is conf
 
         self.assert_config2()
 
     def test_config_from_object__force(self):
         self.app.config_from_object(ObjectConfig2(), force=True)
-        self.assertTrue(self.app.loader._conf)
+        assert self.app.loader._conf
 
         self.assert_config2()
 
@@ -565,10 +552,10 @@ class test_App(AppCase):
             CELERY_TASK_PUBLISH_RETRY = False
 
         self.app.config_from_object(Config)
-        self.assertEqual(self.app.conf.task_always_eager, 44)
-        self.assertEqual(self.app.conf.CELERY_ALWAYS_EAGER, 44)
-        self.assertFalse(self.app.conf.task_publish_retry)
-        self.assertEqual(self.app.conf.task_default_routing_key, 'testcelery')
+        assert self.app.conf.task_always_eager == 44
+        assert self.app.conf.CELERY_ALWAYS_EAGER == 44
+        assert not self.app.conf.task_publish_retry
+        assert self.app.conf.task_default_routing_key == 'testcelery'
 
     def test_config_from_object__supports_old_names(self):
 
@@ -577,11 +564,11 @@ class test_App(AppCase):
             task_default_delivery_mode = 301
 
         self.app.config_from_object(Config())
-        self.assertEqual(self.app.conf.CELERY_ALWAYS_EAGER, 45)
-        self.assertEqual(self.app.conf.task_always_eager, 45)
-        self.assertEqual(self.app.conf.CELERY_DEFAULT_DELIVERY_MODE, 301)
-        self.assertEqual(self.app.conf.task_default_delivery_mode, 301)
-        self.assertEqual(self.app.conf.task_default_routing_key, 'testcelery')
+        assert self.app.conf.CELERY_ALWAYS_EAGER == 45
+        assert self.app.conf.task_always_eager == 45
+        assert self.app.conf.CELERY_DEFAULT_DELIVERY_MODE == 301
+        assert self.app.conf.task_default_delivery_mode == 301
+        assert self.app.conf.task_default_routing_key == 'testcelery'
 
     def test_config_from_object__namespace_uppercase(self):
 
@@ -590,7 +577,7 @@ class test_App(AppCase):
             CELERY_TASK_DEFAULT_DELIVERY_MODE = 301
 
         self.app.config_from_object(Config(), namespace='CELERY')
-        self.assertEqual(self.app.conf.task_always_eager, 44)
+        assert self.app.conf.task_always_eager == 44
 
     def test_config_from_object__namespace_lowercase(self):
 
@@ -599,7 +586,7 @@ class test_App(AppCase):
             celery_task_default_delivery_mode = 301
 
         self.app.config_from_object(Config(), namespace='celery')
-        self.assertEqual(self.app.conf.task_always_eager, 44)
+        assert self.app.conf.task_always_eager == 44
 
     def test_config_from_object__mixing_new_and_old(self):
 
@@ -610,11 +597,10 @@ class test_App(AppCase):
             beat_schedule = '/foo/schedule'
             CELERY_DEFAULT_DELIVERY_MODE = 301
 
-        with self.assertRaises(ImproperlyConfigured) as exc:
+        with pytest.raises(ImproperlyConfigured) as exc:
             self.app.config_from_object(Config(), force=True)
-            self.assertTrue(
-                exc.args[0].startswith('CELERY_DEFAULT_DELIVERY_MODE'))
-            self.assertIn('task_default_delivery_mode', exc.args[0])
+            assert exc.args[0].startswith('CELERY_DEFAULT_DELIVERY_MODE')
+            assert 'task_default_delivery_mode' in exc.args[0]
 
     def test_config_from_object__mixing_old_and_new(self):
 
@@ -625,11 +611,10 @@ class test_App(AppCase):
             CELERYBEAT_SCHEDULE = '/foo/schedule'
             task_default_delivery_mode = 301
 
-        with self.assertRaises(ImproperlyConfigured) as exc:
+        with pytest.raises(ImproperlyConfigured) as exc:
             self.app.config_from_object(Config(), force=True)
-            self.assertTrue(
-                exc.args[0].startswith('task_default_delivery_mode'))
-            self.assertIn('CELERY_DEFAULT_DELIVERY_MODE', exc.args[0])
+            assert exc.args[0].startswith('task_default_delivery_mode')
+            assert 'CELERY_DEFAULT_DELIVERY_MODE' in exc.args[0]
 
     def test_config_from_cmdline(self):
         cmdline = ['task_always_eager=no',
@@ -639,139 +624,130 @@ class test_App(AppCase):
                    '.foobarint=(int)300',
                    'sqlalchemy_engine_options=(dict){"foo": "bar"}']
         self.app.config_from_cmdline(cmdline, namespace='worker')
-        self.assertFalse(self.app.conf.task_always_eager)
-        self.assertEqual(self.app.conf.result_backend, '/dev/null')
-        self.assertEqual(self.app.conf.worker_prefetch_multiplier, 368)
-        self.assertEqual(self.app.conf.worker_foobarstring, '300')
-        self.assertEqual(self.app.conf.worker_foobarint, 300)
-        self.assertDictEqual(self.app.conf.sqlalchemy_engine_options,
-                             {'foo': 'bar'})
+        assert not self.app.conf.task_always_eager
+        assert self.app.conf.result_backend == '/dev/null'
+        assert self.app.conf.worker_prefetch_multiplier == 368
+        assert self.app.conf.worker_foobarstring == '300'
+        assert self.app.conf.worker_foobarint == 300
+        assert self.app.conf.sqlalchemy_engine_options == {'foo': 'bar'}
 
     def test_setting__broker_transport_options(self):
 
         _args = {'foo': 'bar', 'spam': 'baz'}
 
         self.app.config_from_object(Bunch())
-        self.assertEqual(self.app.conf.broker_transport_options, {})
+        assert self.app.conf.broker_transport_options == {}
 
         self.app.config_from_object(Bunch(broker_transport_options=_args))
-        self.assertEqual(self.app.conf.broker_transport_options, _args)
+        assert self.app.conf.broker_transport_options == _args
 
     def test_Windows_log_color_disabled(self):
         self.app.IS_WINDOWS = True
-        self.assertFalse(self.app.log.supports_color(True))
+        assert not self.app.log.supports_color(True)
 
     def test_WorkController(self):
         x = self.app.WorkController
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
 
     def test_Worker(self):
         x = self.app.Worker
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_AsyncResult(self):
         x = self.app.AsyncResult('1')
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
         r = loads(dumps(x))
         # not set as current, so ends up as default app after reduce
-        self.assertIs(r.app, current_app._get_current_object())
+        assert r.app is current_app._get_current_object()
 
     def test_get_active_apps(self):
-        self.assertTrue(list(_state._get_active_apps()))
+        assert list(_state._get_active_apps())
 
         app1 = self.Celery()
         appid = id(app1)
-        self.assertIn(app1, _state._get_active_apps())
+        assert app1 in _state._get_active_apps()
         app1.close()
         del(app1)
 
         gc.collect()
 
         # weakref removed from list when app goes out of scope.
-        with self.assertRaises(StopIteration):
+        with pytest.raises(StopIteration):
             next(app for app in _state._get_active_apps() if id(app) == appid)
 
     def test_config_from_envvar_more(self, key='CELERY_HARNESS_CFG1'):
-        self.assertFalse(
-            self.app.config_from_envvar(
-                'HDSAJIHWIQHEWQU', force=True, silent=True),
-        )
-        with self.assertRaises(ImproperlyConfigured):
+        assert not self.app.config_from_envvar(
+            'HDSAJIHWIQHEWQU', force=True, silent=True)
+        with pytest.raises(ImproperlyConfigured):
             self.app.config_from_envvar(
                 'HDSAJIHWIQHEWQU', force=True, silent=False,
             )
         os.environ[key] = __name__ + '.object_config'
-        self.assertTrue(self.app.config_from_envvar(key, force=True))
-        self.assertEqual(self.app.conf['FOO'], 1)
-        self.assertEqual(self.app.conf['BAR'], 2)
+        assert self.app.config_from_envvar(key, force=True)
+        assert self.app.conf['FOO'] == 1
+        assert self.app.conf['BAR'] == 2
 
         os.environ[key] = 'unknown_asdwqe.asdwqewqe'
-        with self.assertRaises(ImportError):
+        with pytest.raises(ImportError):
             self.app.config_from_envvar(key, silent=False)
-        self.assertFalse(
-            self.app.config_from_envvar(key, force=True, silent=True),
-        )
+        assert not self.app.config_from_envvar(key, force=True, silent=True)
 
         os.environ[key] = __name__ + '.dict_config'
-        self.assertTrue(self.app.config_from_envvar(key, force=True))
-        self.assertEqual(self.app.conf['FOO'], 10)
-        self.assertEqual(self.app.conf['BAR'], 20)
+        assert self.app.config_from_envvar(key, force=True)
+        assert self.app.conf['FOO'] == 10
+        assert self.app.conf['BAR'] == 20
 
     @patch('celery.bin.celery.CeleryCommand.execute_from_commandline')
     def test_start(self, execute):
         self.app.start()
         execute.assert_called()
 
-    def test_amqp_get_broker_info(self):
-        self.assertDictContainsSubset(
-            {'hostname': 'localhost',
-             'userid': 'guest',
-             'password': 'guest',
-             'virtual_host': '/'},
-            self.app.connection('pyamqp://').info(),
-        )
-        self.app.conf.broker_port = 1978
-        self.app.conf.broker_vhost = 'foo'
-        self.assertDictContainsSubset(
-            {'port': 1978, 'virtual_host': 'foo'},
-            self.app.connection('pyamqp://:1978/foo').info(),
-        )
-        conn = self.app.connection('pyamqp:////value')
-        self.assertDictContainsSubset({'virtual_host': '/value'},
-                                      conn.info())
+    @pytest.mark.parametrize('url,expected_fields', [
+        ('pyamqp://', {
+            'hostname': 'localhost',
+            'userid': 'guest',
+            'password': 'guest',
+            'virtual_host': '/',
+        }),
+        ('pyamqp://:1978/foo', {
+            'port': 1978,
+            'virtual_host': 'foo',
+        }),
+        ('pyamqp:////value', {
+            'virtual_host': '/value',
+        })
+    ])
+    def test_amqp_get_broker_info(self, url, expected_fields):
+        info = self.app.connection(url).info()
+        for key, expected_value in items(expected_fields):
+            assert info[key] == expected_value
 
     def test_amqp_failover_strategy_selection(self):
         # Test passing in a string and make sure the string
         # gets there untouched
         self.app.conf.broker_failover_strategy = 'foo-bar'
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            'foo-bar',
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == 'foo-bar'
 
         # Try passing in None
         self.app.conf.broker_failover_strategy = None
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            itertools.cycle,
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == itertools.cycle
 
         # Test passing in a method
         def my_failover_strategy(it):
             yield True
 
         self.app.conf.broker_failover_strategy = my_failover_strategy
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            my_failover_strategy,
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == my_failover_strategy
 
     def test_after_fork(self):
         self.app._pool = Mock()
         self.app.on_after_fork = Mock(name='on_after_fork')
         self.app._after_fork()
-        self.assertIsNone(self.app._pool)
+        assert self.app._pool is None
         self.app.on_after_fork.send.assert_called_with(sender=self.app)
         self.app._after_fork()
 
@@ -794,21 +770,21 @@ class test_App(AppCase):
         try:
             self.app._after_fork_registered = False
             self.app._ensure_after_fork()
-            self.assertTrue(self.app._after_fork_registered)
+            assert self.app._after_fork_registered
         finally:
             _appbase.register_after_fork = prev
 
     def test_canvas(self):
-        self.assertTrue(self.app.canvas.Signature)
+        assert self.app.canvas.Signature
 
     def test_signature(self):
         sig = self.app.signature('foo', (1, 2))
-        self.assertIs(sig.app, self.app)
+        assert sig.app is self.app
 
     def test_timezone__none_set(self):
         self.app.conf.timezone = None
         tz = self.app.timezone
-        self.assertEqual(tz, timezone.get_timezone('UTC'))
+        assert tz == timezone.get_timezone('UTC')
 
     def test_compat_on_configure(self):
         _on_configure = Mock(name='on_configure')
@@ -837,22 +813,22 @@ class test_App(AppCase):
             10, self.app.signature('add', (2, 2)),
             name='add1', expires=3,
         )
-        self.assertTrue(self.app._pending_periodic_tasks)
+        assert self.app._pending_periodic_tasks
         assert not self.app.configured
 
         sig2 = add.s(4, 4)
-        self.assertTrue(self.app.configured)
+        assert self.app.configured
         self.app.add_periodic_task(20, sig2, name='add2', expires=4)
-        self.assertIn('add1', self.app.conf.beat_schedule)
-        self.assertIn('add2', self.app.conf.beat_schedule)
+        assert 'add1' in self.app.conf.beat_schedule
+        assert 'add2' in self.app.conf.beat_schedule
 
     def test_pool_no_multiprocessing(self):
         with mock.mask_modules('multiprocessing.util'):
             pool = self.app.pool
-            self.assertIs(pool, self.app._pool)
+            assert pool is self.app._pool
 
     def test_bugreport(self):
-        self.assertTrue(self.app.bugreport())
+        assert self.app.bugreport()
 
     def test_send_task__connection_provided(self):
         connection = Mock(name='connection')
@@ -896,8 +872,8 @@ class test_App(AppCase):
             exchange='moo_exchange', routing_key='moo_exchange',
             event_dispatcher=dispatcher,
         )
-        self.assertTrue(dispatcher.sent)
-        self.assertEqual(dispatcher.sent[0][0], 'task-sent')
+        assert dispatcher.sent
+        assert dispatcher.sent[0][0] == 'task-sent'
         self.app.amqp.send_task_message(
             prod, 'footask', message, event_dispatcher=dispatcher,
             exchange='bar_exchange', routing_key='bar_exchange',
@@ -909,56 +885,56 @@ class test_App(AppCase):
         self.app.amqp.queues.select.assert_called_with({'foo', 'bar'})
 
 
-class test_defaults(AppCase):
+class test_defaults:
 
     def test_strtobool(self):
         for s in ('false', 'no', '0'):
-            self.assertFalse(defaults.strtobool(s))
+            assert not defaults.strtobool(s)
         for s in ('true', 'yes', '1'):
-            self.assertTrue(defaults.strtobool(s))
-        with self.assertRaises(TypeError):
+            assert defaults.strtobool(s)
+        with pytest.raises(TypeError):
             defaults.strtobool('unsure')
 
 
-class test_debugging_utils(AppCase):
+class test_debugging_utils:
 
     def test_enable_disable_trace(self):
         try:
             _app.enable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
+            assert _app.app_or_default == _app._app_or_default_trace
             _app.disable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default)
+            assert _app.app_or_default == _app._app_or_default
         finally:
             _app.disable_trace()
 
 
-class test_pyimplementation(AppCase):
+class test_pyimplementation:
 
     def test_platform_python_implementation(self):
         with mock.platform_pyimp(lambda: 'Xython'):
-            self.assertEqual(pyimplementation(), 'Xython')
+            assert pyimplementation() == 'Xython'
 
     def test_platform_jython(self):
         with mock.platform_pyimp():
             with mock.sys_platform('java 1.6.51'):
-                self.assertIn('Jython', pyimplementation())
+                assert 'Jython' in pyimplementation()
 
     def test_platform_pypy(self):
         with mock.platform_pyimp():
             with mock.sys_platform('darwin'):
                 with mock.pypy_version((1, 4, 3)):
-                    self.assertIn('PyPy', pyimplementation())
+                    assert 'PyPy' in pyimplementation()
                 with mock.pypy_version((1, 4, 3, 'a4')):
-                    self.assertIn('PyPy', pyimplementation())
+                    assert 'PyPy' in pyimplementation()
 
     def test_platform_fallback(self):
         with mock.platform_pyimp():
             with mock.sys_platform('darwin'):
                 with mock.pypy_version():
-                    self.assertEqual('CPython', pyimplementation())
+                    assert 'CPython' == pyimplementation()
 
 
-class test_shared_task(AppCase):
+class test_shared_task:
 
     def test_registers_to_all_apps(self):
         with self.Celery('xproj', set_as_current=True) as xproj:
@@ -972,16 +948,16 @@ class test_shared_task(AppCase):
             def bar():
                 return 84
 
-            self.assertIs(foo.app, xproj)
-            self.assertIs(bar.app, xproj)
-            self.assertTrue(foo._get_current_object())
+            assert foo.app is xproj
+            assert bar.app is xproj
+            assert foo._get_current_object()
 
             with self.Celery('yproj', set_as_current=True) as yproj:
-                self.assertIs(foo.app, yproj)
-                self.assertIs(bar.app, yproj)
+                assert foo.app is yproj
+                assert bar.app is yproj
 
                 @shared_task()
                 def baz():
                     return 168
 
-                self.assertIs(baz.app, yproj)
+                assert baz.app is yproj

+ 79 - 80
celery/tests/app/test_beat.py → t/unit/app/test_beat.py

@@ -1,18 +1,19 @@
 from __future__ import absolute_import, unicode_literals
 
 import errno
+import pytest
 
 from datetime import datetime, timedelta
 from pickle import dumps, loads
 
+from case import Mock, call, patch, skip
+
 from celery import beat
 from celery import uuid
 from celery.five import keys, string_t
 from celery.schedules import schedule
 from celery.utils.objects import Bunch
 
-from celery.tests.case import AppCase, Mock, call, patch, skip
-
 
 class MockShelve(dict):
     closed = False
@@ -39,7 +40,7 @@ class MockService(object):
         self.stopped = True
 
 
-class test_ScheduleEntry(AppCase):
+class test_ScheduleEntry:
     Entry = beat.ScheduleEntry
 
     def create_entry(self, **kwargs):
@@ -54,38 +55,38 @@ class test_ScheduleEntry(AppCase):
 
     def test_next(self):
         entry = self.create_entry(schedule=10)
-        self.assertTrue(entry.last_run_at)
-        self.assertIsInstance(entry.last_run_at, datetime)
-        self.assertEqual(entry.total_run_count, 0)
+        assert entry.last_run_at
+        assert isinstance(entry.last_run_at, datetime)
+        assert entry.total_run_count == 0
 
         next_run_at = entry.last_run_at + timedelta(seconds=10)
         next_entry = entry.next(next_run_at)
-        self.assertGreaterEqual(next_entry.last_run_at, next_run_at)
-        self.assertEqual(next_entry.total_run_count, 1)
+        assert next_entry.last_run_at >= next_run_at
+        assert next_entry.total_run_count == 1
 
     def test_is_due(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
-        self.assertIs(entry.app, self.app)
-        self.assertIs(entry.schedule.app, self.app)
+        assert entry.app is self.app
+        assert entry.schedule.app is self.app
         due1, next_time_to_run1 = entry.is_due()
-        self.assertFalse(due1)
-        self.assertGreater(next_time_to_run1, 9)
+        assert not due1
+        assert next_time_to_run1 > 9
 
         next_run_at = entry.last_run_at - timedelta(seconds=10)
         next_entry = entry.next(next_run_at)
         due2, next_time_to_run2 = next_entry.is_due()
-        self.assertTrue(due2)
-        self.assertGreater(next_time_to_run2, 9)
+        assert due2
+        assert next_time_to_run2 > 9
 
     def test_repr(self):
         entry = self.create_entry()
-        self.assertIn('<ScheduleEntry:', repr(entry))
+        assert '<ScheduleEntry:' in repr(entry)
 
     def test_reduce(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
         fun, args = entry.__reduce__()
         res = fun(*args)
-        self.assertEqual(res.schedule, entry.schedule)
+        assert res.schedule == entry.schedule
 
     def test_lt(self):
         e1 = self.create_entry(schedule=timedelta(seconds=10))
@@ -99,20 +100,20 @@ class test_ScheduleEntry(AppCase):
 
     def test_update(self):
         entry = self.create_entry()
-        self.assertEqual(entry.schedule, timedelta(seconds=10))
-        self.assertTupleEqual(entry.args, (2, 2))
-        self.assertDictEqual(entry.kwargs, {})
-        self.assertDictEqual(entry.options, {'routing_key': 'cpu'})
+        assert entry.schedule == timedelta(seconds=10)
+        assert entry.args == (2, 2)
+        assert entry.kwargs == {}
+        assert entry.options == {'routing_key': 'cpu'}
 
         entry2 = self.create_entry(schedule=timedelta(minutes=20),
                                    args=(16, 16),
                                    kwargs={'callback': 'foo.bar.baz'},
                                    options={'routing_key': 'urgent'})
         entry.update(entry2)
-        self.assertEqual(entry.schedule, schedule(timedelta(minutes=20)))
-        self.assertTupleEqual(entry.args, (16, 16))
-        self.assertDictEqual(entry.kwargs, {'callback': 'foo.bar.baz'})
-        self.assertDictEqual(entry.options, {'routing_key': 'urgent'})
+        assert entry.schedule == schedule(timedelta(minutes=20))
+        assert entry.args == (16, 16)
+        assert entry.kwargs == {'callback': 'foo.bar.baz'}
+        assert entry.options == {'routing_key': 'urgent'}
 
 
 class mScheduler(beat.Scheduler):
@@ -157,12 +158,12 @@ always_due = mocked_schedule(True, 1)
 always_pending = mocked_schedule(False, 1)
 
 
-class test_Scheduler(AppCase):
+class test_Scheduler:
 
     def test_custom_schedule_dict(self):
         custom = {'foo': 'bar'}
         scheduler = mScheduler(app=self.app, schedule=custom, lazy=True)
-        self.assertIs(scheduler.data, custom)
+        assert scheduler.data is custom
 
     def test_apply_async_uses_registered_task_instances(self):
 
@@ -204,11 +205,11 @@ class test_Scheduler(AppCase):
         not_sync.apply_async = Mock()
 
         s = mScheduler(app=self.app)
-        self.assertEqual(s.sync_every_tasks, 2)
+        assert s.sync_every_tasks == 2
         s._do_sync = Mock()
 
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
-        self.assertEqual(s._tasks_since_sync, 1)
+        assert s._tasks_since_sync == 1
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s._do_sync.assert_called_with()
 
@@ -223,10 +224,10 @@ class test_Scheduler(AppCase):
         not_sync.apply_async = Mock()
 
         s = mScheduler(app=self.app)
-        self.assertEqual(s.sync_every_tasks, 1)
+        assert s.sync_every_tasks == 1
 
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
-        self.assertEqual(s._tasks_since_sync, 0)
+        assert s._tasks_since_sync == 0
 
         self.app.conf.beat_sync_every = 0
 
@@ -238,25 +239,23 @@ class test_Scheduler(AppCase):
 
     def test_info(self):
         scheduler = mScheduler(app=self.app)
-        self.assertIsInstance(scheduler.info, string_t)
+        assert isinstance(scheduler.info, string_t)
 
     def test_maybe_entry(self):
         s = mScheduler(app=self.app)
         entry = s.Entry(name='add every', task='tasks.add', app=self.app)
-        self.assertIs(s._maybe_entry(entry.name, entry), entry)
-        self.assertTrue(s._maybe_entry('add every', {
-            'task': 'tasks.add',
-        }))
+        assert s._maybe_entry(entry.name, entry) is entry
+        assert s._maybe_entry('add every', {'task': 'tasks.add'})
 
     def test_set_schedule(self):
         s = mScheduler(app=self.app)
         s.schedule = {'foo': 'bar'}
-        self.assertEqual(s.data, {'foo': 'bar'})
+        assert s.data == {'foo': 'bar'}
 
     @patch('kombu.connection.Connection.ensure_connection')
     def test_ensure_connection_error_handler(self, ensure):
         s = mScheduler(app=self.app)
-        self.assertTrue(s._ensure_connected())
+        assert s._ensure_connected()
         ensure.assert_called()
         callback = ensure.call_args[0][0]
 
@@ -267,19 +266,19 @@ class test_Scheduler(AppCase):
         self.app.conf.beat_schedule = {}
         s = mScheduler(app=self.app)
         s.install_default_entries({})
-        self.assertNotIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' not in s.data
         self.app.backend.supports_autoexpire = False
 
         self.app.conf.result_expires = 30
         s = mScheduler(app=self.app)
         s.install_default_entries({})
-        self.assertIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' in s.data
 
         self.app.backend.supports_autoexpire = True
         self.app.conf.result_expires = 31
         s = mScheduler(app=self.app)
         s.install_default_entries({})
-        self.assertNotIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' not in s.data
 
     def test_due_tick(self):
         scheduler = mScheduler(app=self.app)
@@ -287,28 +286,28 @@ class test_Scheduler(AppCase):
                       schedule=always_due,
                       args=(1, 2),
                       kwargs={'foo': 'bar'})
-        self.assertEqual(scheduler.tick(), 0)
+        assert scheduler.tick() == 0
 
     @patch('celery.beat.error')
     def test_due_tick_SchedulingError(self, error):
         scheduler = mSchedulerSchedulingError(app=self.app)
         scheduler.add(name='test_due_tick_SchedulingError',
                       schedule=always_due)
-        self.assertEqual(scheduler.tick(), 0)
+        assert scheduler.tick() == 0
         error.assert_called()
 
     def test_pending_tick(self):
         scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_pending_tick',
                       schedule=always_pending)
-        self.assertEqual(scheduler.tick(), 1 - 0.010)
+        assert scheduler.tick() == 1 - 0.010
 
     def test_honors_max_interval(self):
         scheduler = mScheduler(app=self.app)
         maxi = scheduler.max_interval
         scheduler.add(name='test_honors_max_interval',
                       schedule=mocked_schedule(False, maxi * 4))
-        self.assertEqual(scheduler.tick(), maxi)
+        assert scheduler.tick() == maxi
 
     def test_ticks(self):
         scheduler = mScheduler(app=self.app)
@@ -317,13 +316,13 @@ class test_Scheduler(AppCase):
                  {'schedule': mocked_schedule(False, j)})
                  for i, j in enumerate(nums))
         scheduler.update_from_dict(s)
-        self.assertEqual(scheduler.tick(), min(nums) - 0.010)
+        assert scheduler.tick() == min(nums) - 0.010
 
     def test_schedule_no_remain(self):
         scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_schedule_no_remain',
                       schedule=mocked_schedule(False, None))
-        self.assertEqual(scheduler.tick(), scheduler.max_interval)
+        assert scheduler.tick() == scheduler.max_interval
 
     def test_interface(self):
         scheduler = mScheduler(app=self.app)
@@ -340,9 +339,9 @@ class test_Scheduler(AppCase):
                             'baz': {'schedule': mocked_schedule(True, 10)}})
         a.merge_inplace(b.schedule)
 
-        self.assertNotIn('foo', a.schedule)
-        self.assertIn('baz', a.schedule)
-        self.assertEqual(a.schedule['bar'].schedule._next_run_at, 40)
+        assert 'foo' not in a.schedule
+        assert 'baz' in a.schedule
+        assert a.schedule['bar'].schedule._next_run_at == 40
 
 
 def create_persistent_scheduler(shelv=None):
@@ -367,7 +366,7 @@ def create_persistent_scheduler(shelv=None):
     return MockPersistentScheduler, shelv
 
 
-class test_PersistentScheduler(AppCase):
+class test_PersistentScheduler:
 
     @patch('os.remove')
     def test_remove_db(self, remove):
@@ -382,7 +381,7 @@ class test_PersistentScheduler(AppCase):
         remove.side_effect = err
         s._remove_db()
         err.errno = errno.EPERM
-        with self.assertRaises(OSError):
+        with pytest.raises(OSError):
             s._remove_db()
 
     def test_setup_schedule(self):
@@ -420,11 +419,11 @@ class test_PersistentScheduler(AppCase):
         )
         s._store = {'entries': {}}
         s.schedule = {'foo': 'bar'}
-        self.assertDictEqual(s.schedule, {'foo': 'bar'})
-        self.assertDictEqual(s._store['entries'], s.schedule)
+        assert s.schedule == {'foo': 'bar'}
+        assert s._store['entries'] == s.schedule
 
 
-class test_Service(AppCase):
+class test_Service:
 
     def get_service(self):
         Scheduler, mock_shelve = create_persistent_scheduler()
@@ -432,26 +431,26 @@ class test_Service(AppCase):
 
     def test_pickleable(self):
         s = beat.Service(app=self.app, scheduler_cls=Mock)
-        self.assertTrue(loads(dumps(s)))
+        assert loads(dumps(s))
 
     def test_start(self):
         s, sh = self.get_service()
         schedule = s.scheduler.schedule
-        self.assertIsInstance(schedule, dict)
-        self.assertIsInstance(s.scheduler, beat.Scheduler)
+        assert isinstance(schedule, dict)
+        assert isinstance(s.scheduler, beat.Scheduler)
         scheduled = list(schedule.keys())
         for task_name in keys(sh['entries']):
-            self.assertIn(task_name, scheduled)
+            assert task_name in scheduled
 
         s.sync()
-        self.assertTrue(sh.closed)
-        self.assertTrue(sh.synced)
-        self.assertTrue(s._is_stopped.isSet())
+        assert sh.closed
+        assert sh.synced
+        assert s._is_stopped.isSet()
         s.sync()
         s.stop(wait=False)
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
         s.stop(wait=True)
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
         p = s.scheduler._store
         s.scheduler._store = None
@@ -474,24 +473,24 @@ class test_Service(AppCase):
         s, sh = self.get_service()
         s.scheduler.tick_raises_exit = True
         s.start()
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
     def test_start_manages_one_tick_before_shutdown(self):
         s, sh = self.get_service()
         s.scheduler.shutdown_service = s
         s.start()
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
 
-class test_EmbeddedService(AppCase):
+class test_EmbeddedService:
 
     @skip.unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
         from billiard.process import Process
 
         s = beat.EmbeddedService(self.app)
-        self.assertIsInstance(s, Process)
-        self.assertIsInstance(s.service, beat.Service)
+        assert isinstance(s, Process)
+        assert isinstance(s.service, beat.Service)
         s.service = MockService()
 
         class _Popen(object):
@@ -502,43 +501,43 @@ class test_EmbeddedService(AppCase):
 
         with patch('celery.platforms.close_open_fds'):
             s.run()
-        self.assertTrue(s.service.started)
+        assert s.service.started
 
         s._popen = _Popen()
         s.stop()
-        self.assertTrue(s.service.stopped)
-        self.assertTrue(s._popen.terminated)
+        assert s.service.stopped
+        assert s._popen.terminated
 
     def test_start_stop_threaded(self):
         s = beat.EmbeddedService(self.app, thread=True)
         from threading import Thread
-        self.assertIsInstance(s, Thread)
-        self.assertIsInstance(s.service, beat.Service)
+        assert isinstance(s, Thread)
+        assert isinstance(s.service, beat.Service)
         s.service = MockService()
 
         s.run()
-        self.assertTrue(s.service.started)
+        assert s.service.started
 
         s.stop()
-        self.assertTrue(s.service.stopped)
+        assert s.service.stopped
 
 
-class test_schedule(AppCase):
+class test_schedule:
 
     def test_maybe_make_aware(self):
         x = schedule(10, app=self.app)
         x.utc_enabled = True
         d = x.maybe_make_aware(datetime.utcnow())
-        self.assertTrue(d.tzinfo)
+        assert d.tzinfo
         x.utc_enabled = False
         d2 = x.maybe_make_aware(datetime.utcnow())
-        self.assertTrue(d2.tzinfo)
+        assert d2.tzinfo
 
     def test_to_local(self):
         x = schedule(10, app=self.app)
         x.utc_enabled = True
         d = x.to_local(datetime.utcnow())
-        self.assertIsNone(d.tzinfo)
+        assert d.tzinfo is None
         x.utc_enabled = False
         d = x.to_local(datetime.utcnow())
-        self.assertTrue(d.tzinfo)
+        assert d.tzinfo

+ 19 - 17
celery/tests/app/test_builtins.py → t/unit/app/test_builtins.py

@@ -1,14 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import ContextMock, Mock, patch
+
 from celery import group, chord
 from celery.app import builtins
 from celery.five import range
 from celery.utils.functional import pass1
 
-from celery.tests.case import AppCase, ContextMock, Mock, patch
-
 
-class BuiltinsCase(AppCase):
+class BuiltinsCase:
 
     def setup(self):
         @self.app.task(shared=False)
@@ -38,10 +40,10 @@ class test_accumulate(BuiltinsCase):
         self.accumulate = self.app.tasks['celery.accumulate']
 
     def test_with_index(self):
-        self.assertEqual(self.accumulate(1, 2, 3, 4, index=0), 1)
+        assert self.accumulate(1, 2, 3, 4, index=0) == 1
 
     def test_no_index(self):
-        self.assertEqual(self.accumulate(1, 2, 3, 4), (1, 2, 3, 4))
+        assert self.accumulate(1, 2, 3, 4), (1, 2, 3 == 4)
 
 
 class test_map(BuiltinsCase):
@@ -55,7 +57,7 @@ class test_map(BuiltinsCase):
         res = self.app.tasks['celery.map'](
             map_mul, [(2, 2), (4, 4), (8, 8)],
         )
-        self.assertEqual(res, [4, 16, 64])
+        assert res, [4, 16 == 64]
 
 
 class test_starmap(BuiltinsCase):
@@ -69,7 +71,7 @@ class test_starmap(BuiltinsCase):
         res = self.app.tasks['celery.starmap'](
             smap_mul, [(2, 2), (4, 4), (8, 8)],
         )
-        self.assertEqual(res, [4, 16, 64])
+        assert res, [4, 16 == 64]
 
 
 class test_chunks(BuiltinsCase):
@@ -90,13 +92,13 @@ class test_chunks(BuiltinsCase):
 class test_group(BuiltinsCase):
 
     def setup(self):
-        self.maybe_signature = self.patch('celery.canvas.maybe_signature')
+        self.maybe_signature = self.patching('celery.canvas.maybe_signature')
         self.maybe_signature.side_effect = pass1
         self.app.producer_or_acquire = Mock()
         self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
         self.app.conf.task_always_eager = True
         self.task = builtins.add_group_task(self.app)
-        super(test_group, self).setup()
+        BuiltinsCase.setup(self)
 
     def test_apply_async_eager(self):
         self.task.apply = Mock(name='apply')
@@ -135,7 +137,7 @@ class test_chain(BuiltinsCase):
         self.task = builtins.add_chain_task(self.app)
 
     def test_not_implemented(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.task()
 
 
@@ -143,13 +145,13 @@ class test_chord(BuiltinsCase):
 
     def setup(self):
         self.task = builtins.add_chord_task(self.app)
-        super(test_chord, self).setup()
+        BuiltinsCase.setup(self)
 
     def test_apply_async(self):
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async()
-        self.assertTrue(r)
-        self.assertTrue(r.parent)
+        assert r
+        assert r.parent
 
     def test_run_header_not_group(self):
         self.task([self.add.s(i, i) for i in range(10)], self.xsum.s())
@@ -161,22 +163,22 @@ class test_chord(BuiltinsCase):
         x.apply_async(group_id='some_group_id')
         x.run.assert_called()
         resbody = x.run.call_args[0][1]
-        self.assertEqual(resbody.options['group_id'], 'some_group_id')
+        assert resbody.options['group_id'] == 'some_group_id'
         x2 = chord([self.add.s(i, i) for i in range(10)], body=body)
         x2.run = Mock(name='chord.run(x2)')
         x2.apply_async(chord='some_chord_id')
         x2.run.assert_called()
         resbody = x2.run.call_args[0][1]
-        self.assertEqual(resbody.options['chord'], 'some_chord_id')
+        assert resbody.options['chord'] == 'some_chord_id'
 
     def test_apply_eager(self):
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async()
-        self.assertEqual(r.get(), 90)
+        assert r.get() == 90
 
     def test_apply_eager_with_arguments(self):
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async([1])
-        self.assertEqual(r.get(), 55)
+        assert r.get() == 55

+ 18 - 0
t/unit/app/test_celery.py

@@ -0,0 +1,18 @@
+from __future__ import absolute_import, unicode_literals
+
+import celery
+import pytest
+
+
+def test_version():
+    assert celery.VERSION
+    assert len(celery.VERSION) >= 3
+    celery.VERSION = (0, 3, 0)
+    assert celery.__version__.count('.') >= 2
+
+
+@pytest.mark.parametrize('attr', [
+    '__author__', '__contact__', '__homepage__', '__docformat__',
+])
+def test_meta(attr):
+    assert getattr(celery, attr, None)

+ 50 - 47
celery/tests/app/test_control.py → t/unit/app/test_control.py

@@ -1,12 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from kombu.pidbox import Mailbox
 from vine.utils import wraps
 
 from celery import uuid
 from celery.app import control
 from celery.exceptions import DuplicateNodenameWarning
-from celery.tests.case import AppCase
 
 
 class MockMailbox(Mailbox):
@@ -38,7 +39,7 @@ def with_mock_broadcast(fun):
     return _resets
 
 
-class test_flatten_reply(AppCase):
+class test_flatten_reply:
 
     def test_flatten_reply(self):
         reply = [
@@ -46,18 +47,16 @@ class test_flatten_reply(AppCase):
             {'foo@example.com': {'hello': 20}},
             {'bar@example.com': {'hello': 30}}
         ]
-        with self.assertWarns(DuplicateNodenameWarning) as w:
+        with pytest.warns(DuplicateNodenameWarning) as w:
             nodes = control.flatten_reply(reply)
 
-        self.assertIn(
-            'Received multiple replies from node name: foo@example.com.',
-            str(w.warning)
-        )
-        self.assertIn('foo@example.com', nodes)
-        self.assertIn('bar@example.com', nodes)
+        assert 'Received multiple replies from node name: {0}.'.format(
+            next(iter(reply[0]))) in str(w[0].message.args[0])
+        assert 'foo@example.com' in nodes
+        assert 'bar@example.com' in nodes
 
 
-class test_inspect(AppCase):
+class test_inspect:
 
     def setup(self):
         self.c = Control(app=self.app)
@@ -65,91 +64,95 @@ class test_inspect(AppCase):
         self.i = self.c.inspect()
 
     def test_prepare_reply(self):
-        self.assertDictEqual(self.i._prepare([{'w1': {'ok': 1}},
-                                              {'w2': {'ok': 1}}]),
-                             {'w1': {'ok': 1}, 'w2': {'ok': 1}})
+        reply = self.i._prepare([
+            {'w1': {'ok': 1}},
+            {'w2': {'ok': 1}},
+        ])
+        assert reply == {
+            'w1': {'ok': 1},
+            'w2': {'ok': 1},
+        }
 
         i = self.c.inspect(destination='w1')
-        self.assertEqual(i._prepare([{'w1': {'ok': 1}}]),
-                         {'ok': 1})
+        assert i._prepare([{'w1': {'ok': 1}}]) == {'ok': 1}
 
     @with_mock_broadcast
     def test_active(self):
         self.i.active()
-        self.assertIn('active', MockMailbox.sent)
+        assert 'active' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_clock(self):
         self.i.clock()
-        self.assertIn('clock', MockMailbox.sent)
+        assert 'clock' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_conf(self):
         self.i.conf()
-        self.assertIn('conf', MockMailbox.sent)
+        assert 'conf' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_hello(self):
         self.i.hello('george@vandelay.com')
-        self.assertIn('hello', MockMailbox.sent)
+        assert 'hello' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_memsample(self):
         self.i.memsample()
-        self.assertIn('memsample', MockMailbox.sent)
+        assert 'memsample' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_memdump(self):
         self.i.memdump()
-        self.assertIn('memdump', MockMailbox.sent)
+        assert 'memdump' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_objgraph(self):
         self.i.objgraph()
-        self.assertIn('objgraph', MockMailbox.sent)
+        assert 'objgraph' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_scheduled(self):
         self.i.scheduled()
-        self.assertIn('scheduled', MockMailbox.sent)
+        assert 'scheduled' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_reserved(self):
         self.i.reserved()
-        self.assertIn('reserved', MockMailbox.sent)
+        assert 'reserved' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_stats(self):
         self.i.stats()
-        self.assertIn('stats', MockMailbox.sent)
+        assert 'stats' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_revoked(self):
         self.i.revoked()
-        self.assertIn('revoked', MockMailbox.sent)
+        assert 'revoked' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_tasks(self):
         self.i.registered()
-        self.assertIn('registered', MockMailbox.sent)
+        assert 'registered' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_ping(self):
         self.i.ping()
-        self.assertIn('ping', MockMailbox.sent)
+        assert 'ping' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_active_queues(self):
         self.i.active_queues()
-        self.assertIn('active_queues', MockMailbox.sent)
+        assert 'active_queues' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_report(self):
         self.i.report()
-        self.assertIn('report', MockMailbox.sent)
+        assert 'report' in MockMailbox.sent
 
 
-class test_Broadcast(AppCase):
+class test_Broadcast:
 
     def setup(self):
         self.control = Control(app=self.app)
@@ -166,80 +169,80 @@ class test_Broadcast(AppCase):
     @with_mock_broadcast
     def test_broadcast(self):
         self.control.broadcast('foobarbaz', arguments=[])
-        self.assertIn('foobarbaz', MockMailbox.sent)
+        assert 'foobarbaz' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_broadcast_limit(self):
         self.control.broadcast(
             'foobarbaz1', arguments=[], limit=None, destination=[1, 2, 3],
         )
-        self.assertIn('foobarbaz1', MockMailbox.sent)
+        assert 'foobarbaz1' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_broadcast_validate(self):
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             self.control.broadcast('foobarbaz2',
                                    destination='foo')
 
     @with_mock_broadcast
     def test_rate_limit(self):
         self.control.rate_limit(self.mytask.name, '100/m')
-        self.assertIn('rate_limit', MockMailbox.sent)
+        assert 'rate_limit' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_time_limit(self):
         self.control.time_limit(self.mytask.name, soft=10, hard=20)
-        self.assertIn('time_limit', MockMailbox.sent)
+        assert 'time_limit' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_add_consumer(self):
         self.control.add_consumer('foo')
-        self.assertIn('add_consumer', MockMailbox.sent)
+        assert 'add_consumer' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_cancel_consumer(self):
         self.control.cancel_consumer('foo')
-        self.assertIn('cancel_consumer', MockMailbox.sent)
+        assert 'cancel_consumer' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_enable_events(self):
         self.control.enable_events()
-        self.assertIn('enable_events', MockMailbox.sent)
+        assert 'enable_events' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_disable_events(self):
         self.control.disable_events()
-        self.assertIn('disable_events', MockMailbox.sent)
+        assert 'disable_events' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_revoke(self):
         self.control.revoke('foozbaaz')
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_ping(self):
         self.control.ping()
-        self.assertIn('ping', MockMailbox.sent)
+        assert 'ping' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_election(self):
         self.control.election('some_id', 'topic', 'action')
-        self.assertIn('election', MockMailbox.sent)
+        assert 'election' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_pool_grow(self):
         self.control.pool_grow(2)
-        self.assertIn('pool_grow', MockMailbox.sent)
+        assert 'pool_grow' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_pool_shrink(self):
         self.control.pool_shrink(2)
-        self.assertIn('pool_shrink', MockMailbox.sent)
+        assert 'pool_shrink' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_revoke_from_result(self):
         self.app.AsyncResult('foozbazzbar').revoke()
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent
 
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
@@ -247,4 +250,4 @@ class test_Broadcast(AppCase):
                                  [self.app.AsyncResult(x)
                                   for x in [uuid() for i in range(10)]])
         r.revoke()
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent

+ 65 - 0
t/unit/app/test_defaults.py

@@ -0,0 +1,65 @@
+from __future__ import absolute_import, unicode_literals
+
+import sys
+
+from importlib import import_module
+
+from case import mock
+
+from celery.app.defaults import (
+    _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY,
+    DEFAULTS, NAMESPACES, SETTING_KEYS
+)
+from celery.five import values
+
+
+class test_defaults:
+
+    def setup(self):
+        self._prev = sys.modules.pop('celery.app.defaults', None)
+
+    def teardown(self):
+        if self._prev:
+            sys.modules['celery.app.defaults'] = self._prev
+
+    def test_option_repr(self):
+        assert repr(NAMESPACES['broker']['url'])
+
+    def test_any(self):
+        val = object()
+        assert self.defaults.Option.typemap['any'](val) is val
+
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 4, 0))
+    def test_default_pool_pypy_14(self):
+        assert self.defaults.DEFAULT_POOL == 'solo'
+
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 5, 0))
+    def test_default_pool_pypy_15(self):
+        assert self.defaults.DEFAULT_POOL == 'prefork'
+
+    def test_compat_indices(self):
+        assert not any(key.isupper() for key in DEFAULTS)
+        assert not any(key.islower() for key in _OLD_DEFAULTS)
+        assert not any(key.isupper() for key in _TO_OLD_KEY)
+        assert not any(key.islower() for key in _TO_NEW_KEY)
+        assert not any(key.isupper() for key in SETTING_KEYS)
+        assert not any(key.islower() for key in _OLD_SETTING_KEYS)
+        assert not any(value.isupper() for value in values(_TO_NEW_KEY))
+        assert not any(value.islower() for value in values(_TO_OLD_KEY))
+
+        for key in _TO_NEW_KEY:
+            assert key in _OLD_SETTING_KEYS
+        for key in _TO_OLD_KEY:
+            assert key in SETTING_KEYS
+
+    def test_find(self):
+        find = self.defaults.find
+
+        assert find('default_queue')[2].default == 'celery'
+        assert find('task_default_exchange')[2] == 'celery'
+
+    @property
+    def defaults(self):
+        return import_module('celery.app.defaults')

+ 8 - 10
celery/tests/app/test_exceptions.py → t/unit/app/test_exceptions.py

@@ -6,30 +6,28 @@ from datetime import datetime
 
 from celery.exceptions import Reject, Retry
 
-from celery.tests.case import AppCase
 
-
-class test_Retry(AppCase):
+class test_Retry:
 
     def test_when_datetime(self):
         x = Retry('foo', KeyError(), when=datetime.utcnow())
-        self.assertTrue(x.humanize())
+        assert x.humanize()
 
     def test_pickleable(self):
         x = Retry('foo', KeyError(), when=datetime.utcnow())
-        self.assertTrue(pickle.loads(pickle.dumps(x)))
+        assert pickle.loads(pickle.dumps(x))
 
 
-class test_Reject(AppCase):
+class test_Reject:
 
     def test_attrs(self):
         x = Reject('foo', requeue=True)
-        self.assertEqual(x.reason, 'foo')
-        self.assertTrue(x.requeue)
+        assert x.reason == 'foo'
+        assert x.requeue
 
     def test_repr(self):
-        self.assertTrue(repr(Reject('foo', True)))
+        assert repr(Reject('foo', True))
 
     def test_pickleable(self):
         x = Retry('foo', True)
-        self.assertTrue(pickle.loads(pickle.dumps(x)))
+        assert pickle.loads(pickle.dumps(x))

+ 32 - 36
celery/tests/app/test_loaders.py → t/unit/app/test_loaders.py

@@ -1,9 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 
 import os
+import pytest
 import sys
 import warnings
 
+from case import Mock, mock, patch
+
 from celery import loaders
 from celery.exceptions import NotConfigured
 from celery.five import bytes_if_py2
@@ -12,8 +15,6 @@ from celery.loaders import default
 from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 
-from celery.tests.case import AppCase, Case, Mock, mock, patch
-
 
 class DummyLoader(base.BaseLoader):
 
@@ -21,14 +22,13 @@ class DummyLoader(base.BaseLoader):
         return {'foo': 'bar', 'imports': ('os', 'sys')}
 
 
-class test_loaders(AppCase):
+class test_loaders:
 
     def test_get_loader_cls(self):
-        self.assertEqual(loaders.get_loader_cls('default'),
-                         default.Loader)
+        assert loaders.get_loader_cls('default') is default.Loader
 
 
-class test_LoaderBase(AppCase):
+class test_LoaderBase:
     message_options = {'subject': 'Subject',
                        'body': 'Body',
                        'sender': 'x@x.com',
@@ -47,25 +47,23 @@ class test_LoaderBase(AppCase):
         self.loader.on_worker_init()
 
     def test_now(self):
-        self.assertTrue(self.loader.now(utc=True))
-        self.assertTrue(self.loader.now(utc=False))
+        assert self.loader.now(utc=True)
+        assert self.loader.now(utc=False)
 
     def test_read_configuration_no_env(self):
-        self.assertIsNone(
-            base.BaseLoader(app=self.app).read_configuration(
-                'FOO_X_S_WE_WQ_Q_WE'),
-        )
+        assert base.BaseLoader(app=self.app).read_configuration(
+            'FOO_X_S_WE_WQ_Q_WE') is None
 
     def test_autodiscovery(self):
         with patch('celery.loaders.base.autodiscover_tasks') as auto:
             auto.return_value = [Mock()]
             auto.return_value[0].__name__ = 'moo'
             self.loader.autodiscover_tasks(['A', 'B'])
-            self.assertIn('moo', self.loader.task_modules)
+            assert 'moo' in self.loader.task_modules
             self.loader.task_modules.discard('moo')
 
     def test_import_task_module(self):
-        self.assertEqual(sys, self.loader.import_task_module('sys'))
+        assert sys == self.loader.import_task_module('sys')
 
     def test_init_worker_process(self):
         self.loader.on_worker_process_init()
@@ -79,18 +77,16 @@ class test_LoaderBase(AppCase):
         self.loader.import_from_cwd.assert_called_with('module_name')
 
     def test_conf_property(self):
-        self.assertEqual(self.loader.conf['foo'], 'bar')
-        self.assertEqual(self.loader._conf['foo'], 'bar')
-        self.assertEqual(self.loader.conf['foo'], 'bar')
+        assert self.loader.conf['foo'] == 'bar'
+        assert self.loader._conf['foo'] == 'bar'
+        assert self.loader.conf['foo'] == 'bar'
 
     def test_import_default_modules(self):
         def modnames(l):
             return [m.__name__ for m in l]
         self.app.conf.imports = ('os', 'sys')
-        self.assertEqual(
-            sorted(modnames(self.loader.import_default_modules())),
-            sorted(modnames([os, sys])),
-        )
+        assert (sorted(modnames(self.loader.import_default_modules())) ==
+                sorted(modnames([os, sys])))
 
     def test_import_from_cwd_custom_imp(self):
         imp = Mock(name='imp')
@@ -98,17 +94,17 @@ class test_LoaderBase(AppCase):
         imp.assert_called()
 
     def test_cmdline_config_ValueError(self):
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             self.loader.cmdline_config_parser(['broker.port=foobar'])
 
 
-class test_DefaultLoader(AppCase):
+class test_DefaultLoader:
 
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_not_a_package(self, find_module):
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
-        with self.assertRaises(NotAPackage):
+        with pytest.raises(NotAPackage):
             l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
@@ -116,7 +112,7 @@ class test_DefaultLoader(AppCase):
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
-        with self.assertRaises(NotAPackage):
+        with pytest.raises(NotAPackage):
             l.read_configuration(fail_silently=False)
 
     @patch('celery.loaders.base.find_module')
@@ -124,7 +120,7 @@ class test_DefaultLoader(AppCase):
         default.C_WNOCONF = True
         find_module.side_effect = ImportError()
         l = default.Loader(app=self.app)
-        with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
+        with pytest.warns(NotConfigured):
             l.read_configuration(fail_silently=True)
         default.C_WNOCONF = False
         l.read_configuration(fail_silently=True)
@@ -145,9 +141,9 @@ class test_DefaultLoader(AppCase):
             l = default.Loader(app=self.app)
             l.find_module = Mock(name='find_module')
             settings = l.read_configuration(fail_silently=False)
-            self.assertTupleEqual(settings.imports, ('os', 'sys'))
+            assert settings.imports == ('os', 'sys')
             settings = l.read_configuration(fail_silently=False)
-            self.assertTupleEqual(settings.imports, ('os', 'sys'))
+            assert settings.imports == ('os', 'sys')
             l.on_worker_init()
         finally:
             if prevconfig:
@@ -160,7 +156,7 @@ class test_DefaultLoader(AppCase):
         )
         try:
             l = default.Loader(app=self.app)
-            with self.assertRaises(ImportError):
+            with pytest.raises(ImportError):
                 l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=True)
         finally:
@@ -179,11 +175,11 @@ class test_DefaultLoader(AppCase):
         celery = sys.modules.pop('celery', None)
         sys.modules.pop('celery.five', None)
         try:
-            self.assertTrue(l.import_from_cwd('celery'))
+            assert l.import_from_cwd('celery')
             sys.modules.pop('celery', None)
             sys.modules.pop('celery.five', None)
             sys.path.insert(0, os.getcwd())
-            self.assertTrue(l.import_from_cwd('celery'))
+            assert l.import_from_cwd('celery')
         finally:
             sys.path = old_path
             sys.modules['celery'] = celery
@@ -198,12 +194,12 @@ class test_DefaultLoader(AppCase):
 
         with warnings.catch_warnings(record=True):
             l = _Loader(app=self.app)
-            self.assertFalse(l.configured)
+            assert not l.configured
             context_executed[0] = True
-        self.assertTrue(context_executed[0])
+        assert context_executed[0]
 
 
-class test_AppLoader(AppCase):
+class test_AppLoader:
 
     def setup(self):
         self.loader = AppLoader(app=self.app)
@@ -212,10 +208,10 @@ class test_AppLoader(AppCase):
         self.app.conf.imports = ('subprocess',)
         sys.modules.pop('subprocess', None)
         self.loader.init_worker()
-        self.assertIn('subprocess', sys.modules)
+        assert 'subprocess' in sys.modules
 
 
-class test_autodiscovery(Case):
+class test_autodiscovery:
 
     def test_autodiscover_tasks(self):
         base._RACE_PROTECTION = True

+ 45 - 60
celery/tests/app/test_log.py → t/unit/app/test_log.py

@@ -1,12 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
-import sys
 import logging
+import pytest
+import sys
 
 from collections import defaultdict
 from io import StringIO
 from tempfile import mktemp
 
+from case import Mock, mock, patch, skip
+from case.utils import get_logger_handlers
+
 from celery import signals
 from celery import uuid
 from celery.app.log import TaskFormatter
@@ -22,11 +26,8 @@ from celery.utils.log import (
     logger_isa,
 )
 
-from case.utils import get_logger_handlers
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 
-class test_TaskFormatter(AppCase):
+class test_TaskFormatter:
 
     def test_no_task(self):
         class Record(object):
@@ -40,39 +41,39 @@ class test_TaskFormatter(AppCase):
         record = Record()
         x = TaskFormatter()
         x.format(record)
-        self.assertEqual(record.task_name, '???')
-        self.assertEqual(record.task_id, '???')
+        assert record.task_name == '???'
+        assert record.task_id == '???'
 
 
-class test_logger_isa(AppCase):
+class test_logger_isa:
 
     def test_isa(self):
         x = get_task_logger('Z1george')
-        self.assertTrue(logger_isa(x, task_logger))
+        assert logger_isa(x, task_logger)
         prev_x, x.parent = x.parent, None
         try:
-            self.assertFalse(logger_isa(x, task_logger))
+            assert not logger_isa(x, task_logger)
         finally:
             x.parent = prev_x
 
         y = get_task_logger('Z1elaine')
         y.parent = x
-        self.assertTrue(logger_isa(y, task_logger))
-        self.assertTrue(logger_isa(y, x))
-        self.assertTrue(logger_isa(y, y))
+        assert logger_isa(y, task_logger)
+        assert logger_isa(y, x)
+        assert logger_isa(y, y)
 
         z = get_task_logger('Z1jerry')
         z.parent = y
-        self.assertTrue(logger_isa(z, task_logger))
-        self.assertTrue(logger_isa(z, y))
-        self.assertTrue(logger_isa(z, x))
-        self.assertTrue(logger_isa(z, z))
+        assert logger_isa(z, task_logger)
+        assert logger_isa(z, y)
+        assert logger_isa(z, x)
+        assert logger_isa(z, z)
 
     def test_recursive(self):
         x = get_task_logger('X1foo')
         prev, x.parent = x.parent, x
         try:
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 logger_isa(x, task_logger)
         finally:
             x.parent = prev
@@ -83,7 +84,7 @@ class test_logger_isa(AppCase):
         try:
             prev_z, z.parent = z.parent, y
             try:
-                with self.assertRaises(RuntimeError):
+                with pytest.raises(RuntimeError):
                     logger_isa(y, task_logger)
             finally:
                 z.parent = prev_z
@@ -91,7 +92,7 @@ class test_logger_isa(AppCase):
             y.parent = prev_y
 
 
-class test_ColorFormatter(AppCase):
+class test_ColorFormatter:
 
     @patch('celery.utils.log.safe_str')
     @patch('logging.Formatter.formatException')
@@ -99,7 +100,7 @@ class test_ColorFormatter(AppCase):
         x = ColorFormatter()
         value = KeyError()
         fe.return_value = value
-        self.assertIs(x.formatException(value), value)
+        assert x.formatException(value) is value
         fe.assert_called()
         safe_str.assert_not_called()
 
@@ -111,7 +112,7 @@ class test_ColorFormatter(AppCase):
         try:
             raise Exception()
         except Exception:
-            self.assertTrue(x.formatException(sys.exc_info()))
+            assert x.formatException(sys.exc_info())
         if sys.version_info[0] == 2:
             safe_str.assert_called()
 
@@ -122,7 +123,7 @@ class test_ColorFormatter(AppCase):
         record = Mock()
         record.levelname = 'ERROR'
         record.msg = object()
-        self.assertTrue(x.format(record))
+        assert x.format(record)
 
     @patch('celery.utils.log.safe_str')
     def test_format_raises(self, safe_str):
@@ -153,8 +154,8 @@ class test_ColorFormatter(AppCase):
         safe_str.return_value = record
 
         msg = x.format(record)
-        self.assertIn('<Unrepresentable', msg)
-        self.assertEqual(safe_str.call_count, 1)
+        assert '<Unrepresentable' in msg
+        assert safe_str.call_count == 1
 
     @skip.if_python3()
     @patch('celery.utils.log.safe_str')
@@ -165,10 +166,10 @@ class test_ColorFormatter(AppCase):
         record.msg = 'HELLO'
         record.exc_text = 'error text'
         x.format(record)
-        self.assertEqual(safe_str.call_count, 1)
+        assert safe_str.call_count == 1
 
 
-class test_default_logger(AppCase):
+class test_default_logger:
 
     def setup(self):
         self.setup_logger = self.app.log.setup_logger
@@ -178,11 +179,11 @@ class test_default_logger(AppCase):
 
     def test_get_logger_sets_parent(self):
         logger = get_logger('celery.test_get_logger')
-        self.assertEqual(logger.parent.name, base_logger.name)
+        assert logger.parent.name == base_logger.name
 
     def test_get_logger_root(self):
         logger = get_logger(base_logger.name)
-        self.assertIs(logger.parent, logging.root)
+        assert logger.parent is logging.root
 
     @mock.restore_logging()
     def test_setup_logging_subsystem_misc(self):
@@ -194,7 +195,7 @@ class test_default_logger(AppCase):
         self.app.log.setup_logging_subsystem()
 
     def test_get_default_logger(self):
-        self.assertTrue(self.app.log.get_default_logger())
+        assert self.app.log.get_default_logger()
 
     def test_configure_logger(self):
         logger = self.app.log.get_default_logger()
@@ -212,20 +213,6 @@ class test_default_logger(AppCase):
         with mock.mask_modules('billiard.util'):
             self.app.log.setup_logging_subsystem()
 
-    def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
-
-        with mock.wrap_logger(logger, loglevel=loglevel) as sio:
-            logger.log(loglevel, logmsg)
-            return sio.getvalue().strip()
-
-    def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
-        val = self._assertLog(logger, logmsg, loglevel=loglevel)
-        return self.assertEqual(val, logmsg, reason)
-
-    def assertDidLogFalse(self, logger, logmsg, reason, loglevel=None):
-        val = self._assertLog(logger, logmsg, loglevel=loglevel)
-        return self.assertFalse(val, reason)
-
     @mock.restore_logging()
     def test_setup_logger(self):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
@@ -234,10 +221,9 @@ class test_default_logger(AppCase):
         self.app.log.already_setup = False
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False, colorize=None)
-        self.assertIs(
-            get_logger_handlers(logger)[0].stream, sys.__stderr__,
-            'setup_logger logs to stderr without logfile argument.',
-        )
+        # setup_logger logs to stderr without logfile argument.
+        assert (get_logger_handlers(logger)[0].stream is
+                sys.__stderr__)
 
     @mock.restore_logging()
     def test_setup_logger_no_handlers_stream(self):
@@ -249,7 +235,7 @@ class test_default_logger(AppCase):
             l = self.setup_logger(logfile=sys.stderr,
                                   loglevel=logging.INFO, root=False)
             l.info('The quick brown fox...')
-            self.assertIn('The quick brown fox...', stderr.getvalue())
+            assert 'The quick brown fox...' in stderr.getvalue()
 
     @patch('os.fstat')
     def test_setup_logger_no_handlers_file(self, *args):
@@ -272,10 +258,9 @@ class test_default_logger(AppCase):
                 l = self.setup_logger(
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                 )
-                self.assertIsInstance(
-                    get_logger_handlers(l)[0], logging.FileHandler,
-                )
-                self.assertIn(tempfile, files)
+                assert isinstance(get_logger_handlers(l)[0],
+                                  logging.FileHandler)
+                assert tempfile in files
 
     @mock.restore_logging()
     def test_redirect_stdouts(self):
@@ -287,7 +272,7 @@ class test_default_logger(AppCase):
                     logger, loglevel=logging.ERROR,
                 )
                 logger.error('foo')
-                self.assertIn('foo', sio.getvalue())
+                assert 'foo' in sio.getvalue()
                 self.app.log.redirect_stdouts_to_logger(
                     logger, stdout=False, stderr=False,
                 )
@@ -303,22 +288,22 @@ class test_default_logger(AppCase):
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p.close()
             p.write('foo')
-            self.assertNotIn('foo', sio.getvalue())
+            assert 'foo' not in sio.getvalue()
             p.closed = False
             p.write('foo')
-            self.assertIn('foo', sio.getvalue())
+            assert 'foo' in sio.getvalue()
             lines = ['baz', 'xuzzy']
             p.writelines(lines)
             for line in lines:
-                self.assertIn(line, sio.getvalue())
+                assert line in sio.getvalue()
             p.flush()
             p.close()
-            self.assertFalse(p.isatty())
+            assert not p.isatty()
 
             with mock.stdouts() as (stdout, stderr):
                 with in_sighandler():
                     p.write('foo')
-                    self.assertTrue(stderr.getvalue())
+                    assert stderr.getvalue()
 
     @mock.restore_logging()
     def test_logging_proxy_recurse_protection(self):
@@ -327,7 +312,7 @@ class test_default_logger(AppCase):
         p = LoggingProxy(logger, loglevel=logging.ERROR)
         p._thread.recurse_protection = True
         try:
-            self.assertIsNone(p.write('FOOFO'))
+            assert p.write('FOOFO') is None
         finally:
             p._thread.recurse_protection = False
 

+ 72 - 0
t/unit/app/test_registry.py

@@ -0,0 +1,72 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from celery.app.registry import _unpickle_task, _unpickle_task_v2
+
+
+def returns():
+    return 1
+
+
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_unpickle_task:
+
+    def test_unpickle_v1(self, app):
+        app.tasks['txfoo'] = 'bar'
+        assert _unpickle_task('txfoo') == 'bar'
+
+    def test_unpickle_v2(self, app):
+        app.tasks['txfoo1'] = 'bar1'
+        assert _unpickle_task_v2('txfoo1') == 'bar1'
+        assert _unpickle_task_v2('txfoo1', module='celery') == 'bar1'
+
+
+class test_TaskRegistry:
+
+    def setup(self):
+        self.mytask = self.app.task(name='A', shared=False)(returns)
+        self.myperiodic = self.app.task(
+            name='B', shared=False, type='periodic',
+        )(returns)
+
+    def test_NotRegistered_str(self):
+        assert repr(self.app.tasks.NotRegistered('tasks.add'))
+
+    def assert_register_unregister_cls(self, r, task):
+        r.unregister(task)
+        with pytest.raises(r.NotRegistered):
+            r.unregister(task)
+        r.register(task)
+        assert task.name in r
+
+    def test_task_registry(self):
+        r = self.app._tasks
+        assert isinstance(r, dict)
+
+        self.assert_register_unregister_cls(r, self.mytask)
+        self.assert_register_unregister_cls(r, self.myperiodic)
+
+        r.register(self.myperiodic)
+        r.unregister(self.myperiodic.name)
+        assert self.myperiodic not in r
+        r.register(self.myperiodic)
+
+        tasks = dict(r)
+        assert tasks.get(self.mytask.name) is self.mytask
+        assert tasks.get(self.myperiodic.name) is self.myperiodic
+
+        assert r[self.mytask.name] is self.mytask
+        assert r[self.myperiodic.name] is self.myperiodic
+
+        r.unregister(self.mytask)
+        assert self.mytask.name not in r
+        r.unregister(self.myperiodic)
+        assert self.myperiodic.name not in r
+
+        assert self.mytask.run()
+        assert self.myperiodic.run()
+
+    def test_compat(self):
+        assert self.app.tasks.regular()
+        assert self.app.tasks.periodic()

+ 30 - 37
celery/tests/app/test_routes.py → t/unit/app/test_routes.py

@@ -1,14 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import ANY, Mock
 from kombu import Exchange, Queue
 from kombu.utils.functional import maybe_evaluate
 
 from celery.app import routes
 from celery.exceptions import QueueNotFound
+from celery.five import items
 from celery.utils.imports import qualname
 
-from celery.tests.case import ANY, AppCase, Mock
-
 
 def Router(app, *args, **kwargs):
     return routes.Router(*args, app=app, **kwargs)
@@ -25,7 +27,7 @@ def set_queues(app, **queues):
     app.amqp.queues = app.amqp.Queues(queues)
 
 
-class RouteCase(AppCase):
+class RouteCase:
 
     def setup(self):
         self.a_queue = {
@@ -51,10 +53,7 @@ class RouteCase(AppCase):
 
     def assert_routes_to_queue(self, queue, router, name,
                                args=[], kwargs={}, options={}):
-        self.assertEqual(
-            router.route(options, name, args, kwargs)['queue'].name,
-            queue,
-        )
+        assert router.route(options, name, args, kwargs)['queue'].name == queue
 
     def assert_routes_to_default_queue(self, router, name, *args, **kwargs):
         self.assert_routes_to_queue(
@@ -67,38 +66,32 @@ class test_MapRoute(RouteCase):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         expand = E(self.app, self.app.amqp.queues)
         route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
-        self.assertEqual(
-            expand(route(self.mytask.name))['queue'].name,
-            'foo',
-        )
-        self.assertIsNone(route('celery.awesome'))
+        assert expand(route(self.mytask.name))['queue'].name == 'foo'
+        assert route('celery.awesome') is None
 
     def test_route_for_task(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         expand = E(self.app, self.app.amqp.queues)
         route = routes.MapRoute({self.mytask.name: self.b_queue})
-        self.assertDictContainsSubset(
-            self.b_queue,
-            expand(route(self.mytask.name)),
-        )
-        self.assertIsNone(route('celery.awesome'))
+        eroute = expand(route(self.mytask.name))
+        for key, value in items(self.b_queue):
+            assert eroute[key] == value
+        assert route('celery.awesome') is None
 
     def test_route_for_task__glob(self):
         route = routes.MapRoute([
             ('proj.tasks.*', 'routeA'),
             ('demoapp.tasks.bar.*', {'exchange': 'routeB'}),
         ])
-        self.assertDictEqual(route('proj.tasks.foo'), {'queue': 'routeA'})
-        self.assertDictEqual(route('demoapp.tasks.bar.moo'), {
-            'exchange': 'routeB',
-        })
-        self.assertIsNone(route('demoapp.foo.bar.moo'))
+        assert route('proj.tasks.foo') == {'queue': 'routeA'}
+        assert route('demoapp.tasks.bar.moo') == {'exchange': 'routeB'}
+        assert route('demoapp.foo.bar.moo') is None
 
     def test_expand_route_not_found(self):
         expand = E(self.app, self.app.amqp.Queues(
                    self.app.conf.task_queues, False))
         route = routes.MapRoute({'a': {'queue': 'x'}})
-        with self.assertRaises(QueueNotFound):
+        with pytest.raises(QueueNotFound):
             expand(route('a'))
 
 
@@ -106,7 +99,7 @@ class test_lookup_route(RouteCase):
 
     def test_init_queues(self):
         router = Router(self.app, queues=None)
-        self.assertDictEqual(router.queues, {})
+        assert router.queues == {}
 
     def test_lookup_takes_first(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
@@ -131,22 +124,22 @@ class test_lookup_route(RouteCase):
             self.mytask.name,
             args=[1, 2], kwargs={},
         )
-        self.assertEqual(route['queue'].name, 'testq')
-        self.assertEqual(route['queue'].exchange, Exchange('testq'))
-        self.assertEqual(route['queue'].routing_key, 'testq')
-        self.assertEqual(route['immediate'], False)
+        assert route['queue'].name == 'testq'
+        assert route['queue'].exchange == Exchange('testq')
+        assert route['queue'].routing_key == 'testq'
+        assert route['immediate'] is False
 
     def test_expand_destination_string(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         x = Router(self.app, {}, self.app.amqp.queues)
         dest = x.expand_destination('foo')
-        self.assertEqual(dest['queue'].name, 'foo')
+        assert dest['queue'].name == 'foo'
 
     def test_expand_destination__Queue(self):
         queue = Queue('foo')
         x = Router(self.app, {}, self.app.amqp.queues)
         dest = x.expand_destination({'queue': queue})
-        self.assertIs(dest['queue'], queue)
+        assert dest['queue'] is queue
 
     def test_lookup_paths_traversed(self):
         self.simple_queue_setup()
@@ -179,7 +172,7 @@ class test_lookup_route(RouteCase):
             task=self.mytask,
         )
         options = step.call_args[0][3]
-        self.assertEqual(options['priority'], 3)
+        assert options['priority'] == 3
 
     def test_compat_router_classes__called_with(self):
         self.simple_queue_setup()
@@ -205,7 +198,7 @@ class TestRouter(object):
             return 'bar'
 
 
-class test_prepare(AppCase):
+class test_prepare:
 
     def test_prepare(self):
         o = object()
@@ -215,13 +208,13 @@ class test_prepare(AppCase):
             o,
         ]
         p = routes.prepare(R)
-        self.assertIsInstance(p[0], routes.MapRoute)
-        self.assertIsInstance(maybe_evaluate(p[1]), TestRouter)
-        self.assertIs(p[2], o)
+        assert isinstance(p[0], routes.MapRoute)
+        assert isinstance(maybe_evaluate(p[1]), TestRouter)
+        assert p[2] is o
 
-        self.assertEqual(routes.prepare(o), [o])
+        assert routes.prepare(o) == [o]
 
     def test_prepare_item_is_dict(self):
         R = {'foo': 'bar'}
         p = routes.prepare(R)
-        self.assertIsInstance(p[0], routes.MapRoute)
+        assert isinstance(p[0], routes.MapRoute)

File diff ditekan karena terlalu besar
+ 276 - 328
t/unit/app/test_schedules.py


+ 16 - 20
celery/tests/app/test_utils.py → t/unit/app/test_utils.py

@@ -2,45 +2,41 @@ from __future__ import absolute_import, unicode_literals
 
 from collections import Mapping, MutableMapping
 
-from celery.app.utils import Settings, filter_hidden_settings, bugreport
+from case import Mock
 
-from celery.tests.case import AppCase, Mock
+from celery.app.utils import Settings, filter_hidden_settings, bugreport
 
 
-class test_Settings(AppCase):
+class test_Settings:
 
     def test_is_mapping(self):
         """Settings should be a collections.Mapping"""
-        self.assertTrue(issubclass(Settings, Mapping))
+        assert issubclass(Settings, Mapping)
 
     def test_is_mutable_mapping(self):
         """Settings should be a collections.MutableMapping"""
-        self.assertTrue(issubclass(Settings, MutableMapping))
+        assert issubclass(Settings, MutableMapping)
 
     def test_find(self):
-        self.assertTrue(self.app.conf.find_option('always_eager'))
+        assert self.app.conf.find_option('always_eager')
 
     def test_get_by_parts(self):
         self.app.conf.task_do_this_and_that = 303
-        self.assertEqual(
-            self.app.conf.get_by_parts('task', 'do', 'this', 'and', 'that'),
-            303,
-        )
+        assert self.app.conf.get_by_parts(
+            'task', 'do', 'this', 'and', 'that') == 303
 
     def test_find_value_for_key(self):
-        self.assertEqual(
-            self.app.conf.find_value_for_key('always_eager'),
-            False,
-        )
+        assert self.app.conf.find_value_for_key(
+            'always_eager') is False
 
     def test_table(self):
-        self.assertTrue(self.app.conf.table(with_defaults=True))
-        self.assertTrue(self.app.conf.table(with_defaults=False))
-        self.assertTrue(self.app.conf.table(censored=False))
-        self.assertTrue(self.app.conf.table(censored=True))
+        assert self.app.conf.table(with_defaults=True)
+        assert self.app.conf.table(with_defaults=False)
+        assert self.app.conf.table(censored=False)
+        assert self.app.conf.table(censored=True)
 
 
-class test_filter_hidden_settings(AppCase):
+class test_filter_hidden_settings:
 
     def test_handles_non_string_keys(self):
         """filter_hidden_settings shouldn't raise an exception when handling
@@ -56,7 +52,7 @@ class test_filter_hidden_settings(AppCase):
         filter_hidden_settings(conf)
 
 
-class test_bugreport(AppCase):
+class test_bugreport:
 
     def test_no_conn_driver_info(self):
         self.app.connection = Mock()

+ 0 - 0
celery/tests/bin/__init__.py → t/unit/apps/__init__.py


+ 96 - 124
celery/tests/apps/test_multi.py → t/unit/apps/test_multi.py

@@ -1,59 +1,61 @@
 from __future__ import absolute_import, unicode_literals
 
 import errno
+import pytest
 import signal
 import sys
 
+from case import Mock, call, patch, skip
+
 from celery.apps.multi import (
     Cluster, MultiParser, NamespacedOptionParser, Node, format_opt,
 )
 
-from celery.tests.case import AppCase, Mock, call, patch, skip
-
 
-class test_functions(AppCase):
+class test_functions:
 
     def test_parse_ns_range(self):
         m = MultiParser()
-        self.assertEqual(m._parse_ns_range('1-3', True), ['1', '2', '3'])
-        self.assertEqual(m._parse_ns_range('1-3', False), ['1-3'])
-        self.assertEqual(m._parse_ns_range(
-            '1-3,10,11,20', True),
-            ['1', '2', '3', '10', '11', '20'],
-        )
+        assert m._parse_ns_range('1-3', True), ['1', '2' == '3']
+        assert m._parse_ns_range('1-3', False) == ['1-3']
+        assert m._parse_ns_range('1-3,10,11,20', True) == [
+            '1', '2', '3', '10', '11', '20',
+        ]
 
     def test_format_opt(self):
-        self.assertEqual(format_opt('--foo', None), '--foo')
-        self.assertEqual(format_opt('-c', 1), '-c 1')
-        self.assertEqual(format_opt('--log', 'foo'), '--log=foo')
+        assert format_opt('--foo', None) == '--foo'
+        assert format_opt('-c', 1) == '-c 1'
+        assert format_opt('--log', 'foo') == '--log=foo'
 
 
-class test_NamespacedOptionParser(AppCase):
+class test_NamespacedOptionParser:
 
     def test_parse(self):
         x = NamespacedOptionParser(['-c:1,3', '4'])
         x.parse()
-        self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
+        assert x.namespaces.get('1,3') == {'-c': '4'}
         x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                     '--loglevel:kramer=DEBUG',
                                     '--flag',
                                     '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                     '--', '.disable_rate_limits=1'])
         x.parse()
-        self.assertEqual(x.options, {'--logfile': 'foo',
-                                     '-Q': 'bar',
-                                     '--flag': None})
-        self.assertEqual(x.values, ['a', 'b'])
-        self.assertEqual(x.namespaces.get('jerry,elaine'), {'-c': '5'})
-        self.assertEqual(x.namespaces.get('kramer'), {'--loglevel': 'DEBUG'})
-        self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
+        assert x.options == {
+            '--logfile': 'foo',
+            '-Q': 'bar',
+            '--flag': None,
+        }
+        assert x.values, ['a' == 'b']
+        assert x.namespaces.get('jerry,elaine') == {'-c': '5'}
+        assert x.namespaces.get('kramer') == {'--loglevel': 'DEBUG'}
+        assert x.passthrough == '-- .disable_rate_limits=1'
 
 
 def multi_args(p, *args, **kwargs):
     return MultiParser(*args, **kwargs).parse(p)
 
 
-class test_multi_args(AppCase):
+class test_multi_args:
 
     @patch('celery.apps.multi.gethostname')
     def test_parse(self, gethostname):
@@ -72,14 +74,14 @@ class test_multi_args(AppCase):
         nodes = list(it)
 
         def assert_line_in(name, args):
-            self.assertIn(name, {n.name for n in nodes})
+            assert name in {n.name for n in nodes}
             argv = None
             for node in nodes:
                 if node.name == name:
                     argv = node.argv
-            self.assertTrue(argv)
+            assert argv
             for arg in args:
-                self.assertIn(arg, argv)
+                assert arg in argv
 
         assert_line_in(
             '*P*jerry@*S*',
@@ -100,11 +102,11 @@ class test_multi_args(AppCase):
              '-- .disable_rate_limits=1', '*AP*'],
         )
         expand = nodes[0].expander
-        self.assertEqual(expand('%h'), '*P*jerry@*S*')
-        self.assertEqual(expand('%n'), '*P*jerry')
+        assert expand('%h') == '*P*jerry@*S*'
+        assert expand('%n') == '*P*jerry'
         nodes2 = list(multi_args(p, cmd='COMMAND', append='',
                       prefix='*P*', suffix='*S*'))
-        self.assertEqual(nodes2[0].argv[-1], '-- .disable_rate_limits=1')
+        assert nodes2[0].argv[-1] == '-- .disable_rate_limits=1'
 
         p2 = NamespacedOptionParser(['10', '-c:1', '5'])
         p2.parse()
@@ -118,58 +120,47 @@ class test_multi_args(AppCase):
                 '',
             )
 
-        self.assertEqual(len(nodes3), 10)
-        self.assertEqual(nodes3[0].name, 'celery1@example.com')
-        self.assertTupleEqual(
-            nodes3[0].argv,
-            ('COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1'),
-        )
+        assert len(nodes3) == 10
+        assert nodes3[0].name == 'celery1@example.com'
+        assert nodes3[0].argv == (
+            'COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1')
         for i, worker in enumerate(nodes3[1:]):
-            self.assertEqual(worker.name, 'celery%s@example.com' % (i + 2))
-            self.assertTupleEqual(
-                worker.argv,
-                (('COMMAND', '-n celery%s@example.com' % (i + 2)) +
-                 _args('celery%s' % (i + 2))),
-            )
+            assert worker.name == 'celery%s@example.com' % (i + 2)
+            node_i = 'celery%s' % (i + 2,)
+            assert worker.argv == (
+                'COMMAND',
+                '-n %s@example.com' % (node_i,)) + _args(node_i)
 
         nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
-        self.assertEqual(len(nodes4), 10)
-        self.assertEqual(nodes4[0].name, 'celery1@')
-        self.assertTupleEqual(
-            nodes4[0].argv,
-            ('COMMAND', '-c 5', '-n celery1@') + _args('celery1'),
-        )
+        assert len(nodes4) == 10
+        assert nodes4[0].name == 'celery1@'
+        assert nodes4[0].argv == (
+            'COMMAND', '-c 5', '-n celery1@') + _args('celery1')
 
         p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
         p3.parse()
         nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes5[0].name, 'foo@')
-        self.assertTupleEqual(
-            nodes5[0].argv,
-            ('COMMAND', '-c 5', '-n foo@') + _args('foo'),
-        )
+        assert nodes5[0].name == 'foo@'
+        assert nodes5[0].argv == (
+            'COMMAND', '-c 5', '-n foo@') + _args('foo')
 
         p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
         p4.parse()
         nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes6[0].name, 'foo@')
-        self.assertTupleEqual(
-            nodes6[0].argv,
-            ('COMMAND', '-Q test', '-n foo@') + _args('foo'),
-        )
+        assert nodes6[0].name == 'foo@'
+        assert nodes6[0].argv == (
+            'COMMAND', '-Q test', '-n foo@') + _args('foo')
 
         p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
         p5.parse()
         nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes7[0].name, 'foo@bar')
-        self.assertTupleEqual(
-            nodes7[0].argv,
-            ('COMMAND', '-Q test', '-n foo@bar') + _args('foo'),
-        )
+        assert nodes7[0].name == 'foo@bar'
+        assert nodes7[0].argv == (
+            'COMMAND', '-Q test', '-n foo@bar') + _args('foo')
 
         p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
         p6.parse()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             list(multi_args(p6))
 
     def test_optmerge(self):
@@ -177,10 +168,10 @@ class test_multi_args(AppCase):
         p.parse()
         p.options = {'x': 'y'}
         r = p.optmerge('foo')
-        self.assertEqual(r['x'], 'y')
+        assert r['x'] == 'y'
 
 
-class test_Node(AppCase):
+class test_Node:
 
     def setup(self):
         self.p = Mock(name='p')
@@ -198,7 +189,7 @@ class test_Node(AppCase):
             'foo@bar.com',
             max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
         )
-        self.assertItemsEqual(n.argv, (
+        assert sorted(n.argv) == sorted([
             '-m celery worker --detach',
             '-A foo',
             '--executable={0}'.format(n.executable),
@@ -209,31 +200,31 @@ class test_Node(AppCase):
             '--max-tasks-per-child=30',
             '--pidfile=foo.pid',
             '',
-        ))
+        ])
 
     @patch('os.kill')
     def test_send(self, kill):
-        self.assertTrue(self.node.send(9))
+        assert self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
 
     @patch('os.kill')
     def test_send__ESRCH(self, kill):
         kill.side_effect = OSError()
         kill.side_effect.errno = errno.ESRCH
-        self.assertFalse(self.node.send(9))
+        assert not self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
 
     @patch('os.kill')
     def test_send__error(self, kill):
         kill.side_effect = OSError()
         kill.side_effect.errno = errno.ENOENT
-        with self.assertRaises(OSError):
+        with pytest.raises(OSError):
             self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
 
     def test_alive(self):
         self.node.send = Mock(name='send')
-        self.assertIs(self.node.alive(), self.node.send.return_value)
+        assert self.node.alive() is self.node.send.return_value
         self.node.send.assert_called_with(0)
 
     def test_start(self):
@@ -270,40 +261,32 @@ class test_Node(AppCase):
         )
 
     def test_handle_process_exit(self):
-        self.assertEqual(
-            self.node.handle_process_exit(0),
-            0,
-        )
+        assert self.node.handle_process_exit(0) == 0
 
     def test_handle_process_exit__failure(self):
         on_failure = Mock(name='on_failure')
-        self.assertEqual(
-            self.node.handle_process_exit(9, on_failure=on_failure),
-            9,
-        )
+        assert self.node.handle_process_exit(9, on_failure=on_failure) == 9
         on_failure.assert_called_with(self.node, 9)
 
     def test_handle_process_exit__signalled(self):
         on_signalled = Mock(name='on_signalled')
-        self.assertEqual(
-            self.node.handle_process_exit(-9, on_signalled=on_signalled),
-            9,
-        )
+        assert self.node.handle_process_exit(
+            -9, on_signalled=on_signalled) == 9
         on_signalled.assert_called_with(self.node, 9)
 
     def test_logfile(self):
-        self.assertEqual(self.node.logfile, self.expander.return_value)
+        assert self.node.logfile == self.expander.return_value
         self.expander.assert_called_with('%n%I.log')
 
 
-class test_Cluster(AppCase):
+class test_Cluster:
 
     def setup(self):
-        self.Popen = self.patch('celery.apps.multi.Popen')
-        self.kill = self.patch('os.kill')
-        self.gethostname = self.patch('celery.apps.multi.gethostname')
+        self.Popen = self.patching('celery.apps.multi.Popen')
+        self.kill = self.patching('os.kill')
+        self.gethostname = self.patching('celery.apps.multi.gethostname')
         self.gethostname.return_value = 'example.com'
-        self.Pidfile = self.patch('celery.apps.multi.Pidfile')
+        self.Pidfile = self.patching('celery.apps.multi.Pidfile')
         self.cluster = Cluster(
             [Node('foo@example.com'),
              Node('bar@example.com'),
@@ -326,10 +309,10 @@ class test_Cluster(AppCase):
         )
 
     def test_len(self):
-        self.assertEqual(len(self.cluster), 3)
+        assert len(self.cluster) == 3
 
     def test_getitem(self):
-        self.assertEqual(self.cluster[0].name, 'foo@example.com')
+        assert self.cluster[0].name == 'foo@example.com'
 
     def test_start(self):
         self.cluster.start_node = Mock(name='start_node')
@@ -341,10 +324,8 @@ class test_Cluster(AppCase):
     def test_start_node(self):
         self.cluster._start_node = Mock(name='_start_node')
         node = self.cluster[0]
-        self.assertIs(
-            self.cluster.start_node(node),
-            self.cluster._start_node.return_value,
-        )
+        assert (self.cluster.start_node(node) is
+                self.cluster._start_node.return_value)
         self.cluster.on_node_start.assert_called_with(node)
         self.cluster._start_node.assert_called_with(node)
         self.cluster.on_node_status.assert_called_with(
@@ -354,10 +335,7 @@ class test_Cluster(AppCase):
     def test__start_node(self):
         node = self.cluster[0]
         node.start = Mock(name='node.start')
-        self.assertIs(
-            self.cluster._start_node(node),
-            node.start.return_value,
-        )
+        assert self.cluster._start_node(node) is node.start.return_value
         node.start.assert_called_with(
             self.cluster.env,
             on_spawn=self.cluster.on_child_spawn,
@@ -394,33 +372,27 @@ class test_Cluster(AppCase):
         ])
         nodes = p.getpids(on_down=callback)
         node_0, node_1 = nodes
-        self.assertEqual(node_0.name, 'foo@e.com')
-        self.assertEqual(
-            sorted(node_0.argv),
-            sorted([
-                '',
-                '--executable={0}'.format(node_0.executable),
-                '--logfile=foo%I.log',
-                '--pidfile=foo.pid',
-                '-m celery worker --detach',
-                '-n foo@e.com',
-            ]),
-        )
-        self.assertEqual(node_0.pid, 10)
+        assert node_0.name == 'foo@e.com'
+        assert sorted(node_0.argv) == sorted([
+            '',
+            '--executable={0}'.format(node_0.executable),
+            '--logfile=foo%I.log',
+            '--pidfile=foo.pid',
+            '-m celery worker --detach',
+            '-n foo@e.com',
+        ])
+        assert node_0.pid == 10
 
-        self.assertEqual(node_1.name, 'bar@e.com')
-        self.assertEqual(
-            sorted(node_1.argv),
-            sorted([
-                '',
-                '--executable={0}'.format(node_1.executable),
-                '--logfile=bar%I.log',
-                '--pidfile=bar.pid',
-                '-m celery worker --detach',
-                '-n bar@e.com',
-            ]),
-        )
-        self.assertEqual(node_1.pid, 11)
+        assert node_1.name == 'bar@e.com'
+        assert sorted(node_1.argv) == sorted([
+            '',
+            '--executable={0}'.format(node_1.executable),
+            '--logfile=bar%I.log',
+            '--pidfile=bar.pid',
+            '-m celery worker --detach',
+            '-n bar@e.com',
+        ])
+        assert node_1.pid == 11
 
         # without callback, should work
         nodes = p.getpids('celery worker')

+ 0 - 0
celery/tests/compat_modules/__init__.py → t/unit/backends/__init__.py


+ 40 - 47
celery/tests/backends/test_amqp.py → t/unit/backends/test_amqp.py

@@ -1,11 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 
 import pickle
+import pytest
 
 from contextlib import contextmanager
 from datetime import timedelta
 from pickle import dumps, loads
 
+from case import Mock, mock
 from billiard.einfo import ExceptionInfo
 
 from celery import states
@@ -14,8 +16,6 @@ from celery.backends.amqp import AMQPBackend
 from celery.five import Empty, Queue, range
 from celery.result import AsyncResult
 
-from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
-
 
 class SomeClass(object):
 
@@ -23,7 +23,7 @@ class SomeClass(object):
         self.data = data
 
 
-class test_AMQPBackend(AppCase):
+class test_AMQPBackend:
 
     def setup(self):
         self.app.conf.result_cache_max = 100
@@ -35,9 +35,8 @@ class test_AMQPBackend(AppCase):
     def test_destination_for(self):
         b = self.create_backend()
         request = Mock()
-        self.assertTupleEqual(
-            b.destination_for('id', request),
-            (b.rkey('id'), request.correlation_id),
+        assert b.destination_for('id', request) == (
+            b.rkey('id'), request.correlation_id,
         )
 
     def test_store_result__no_routing_key(self):
@@ -53,14 +52,14 @@ class test_AMQPBackend(AppCase):
         tid = uuid()
 
         tb1.mark_as_done(tid, 42)
-        self.assertEqual(tb2.get_state(tid), states.SUCCESS)
-        self.assertEqual(tb2.get_result(tid), 42)
-        self.assertTrue(tb2._cache.get(tid))
-        self.assertTrue(tb2.get_result(tid), 42)
+        assert tb2.get_state(tid) == states.SUCCESS
+        assert tb2.get_result(tid) == 42
+        assert tb2._cache.get(tid)
+        assert tb2.get_result(tid), 42
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_pickleable(self):
-        self.assertTrue(loads(dumps(self.create_backend())))
+        assert loads(dumps(self.create_backend()))
 
     def test_revive(self):
         tb = self.create_backend()
@@ -75,8 +74,8 @@ class test_AMQPBackend(AppCase):
         tb1.mark_as_done(tid2, result)
         # is serialized properly.
         rindb = tb2.get_result(tid2)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
     def test_mark_as_failure(self):
         tb1 = self.create_backend()
@@ -88,27 +87,27 @@ class test_AMQPBackend(AppCase):
         except KeyError as exception:
             einfo = ExceptionInfo()
             tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
-            self.assertEqual(tb2.get_state(tid3), states.FAILURE)
-            self.assertIsInstance(tb2.get_result(tid3), KeyError)
-            self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
+            assert tb2.get_state(tid3) == states.FAILURE
+            assert isinstance(tb2.get_result(tid3), KeyError)
+            assert tb2.get_traceback(tid3) == einfo.traceback
 
     def test_repair_uuid(self):
         from celery.backends.amqp import repair_uuid
         for i in range(10):
             tid = uuid()
-            self.assertEqual(repair_uuid(tid.replace('-', '')), tid)
+            assert repair_uuid(tid.replace('-', '')) == tid
 
     def test_expires_is_int(self):
         b = self.create_backend(expires=48)
-        self.assertEqual(b.queue_arguments.get('x-expires'), 48 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 48 * 1000.0
 
     def test_expires_is_float(self):
         b = self.create_backend(expires=48.3)
-        self.assertEqual(b.queue_arguments.get('x-expires'), 48.3 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 48.3 * 1000.0
 
     def test_expires_is_timedelta(self):
         b = self.create_backend(expires=timedelta(minutes=1))
-        self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 60 * 1000.0
 
     @mock.sleepdeprived()
     def test_store_result_retries(self):
@@ -125,22 +124,19 @@ class test_AMQPBackend(AppCase):
         from celery.app.amqp import Producer
         prod, Producer.publish = Producer.publish, publish
         try:
-            with self.assertRaises(KeyError):
+            with pytest.raises(KeyError):
                 backend.retry_policy['max_retries'] = None
                 backend.store_result('foo', 'bar', 'STARTED')
 
-            with self.assertRaises(KeyError):
+            with pytest.raises(KeyError):
                 backend.retry_policy['max_retries'] = 10
                 backend.store_result('foo', 'bar', 'STARTED')
         finally:
             Producer.publish = prod
 
-    def assertState(self, retval, state):
-        self.assertEqual(retval['status'], state)
-
     def test_poll_no_messages(self):
         b = self.create_backend()
-        self.assertState(b.get_task_meta(uuid()), states.PENDING)
+        assert b.get_task_meta(uuid())['status'] == states.PENDING
 
     @contextmanager
     def _result_context(self):
@@ -199,7 +195,7 @@ class test_AMQPBackend(AppCase):
         with self._result_context() as (results, backend, Message):
             for i in range(1001):
                 results.put(Message(task_id='id', status=states.RECEIVED))
-            with self.assertRaises(backend.BacklogLimitExceeded):
+            with pytest.raises(backend.BacklogLimitExceeded):
                 backend.get_task_meta('id')
 
     def test_poll_result(self):
@@ -214,27 +210,24 @@ class test_AMQPBackend(AppCase):
             for state_message in state_messages:
                 results.put(state_message)
             r1 = backend.get_task_meta(tid)
-            self.assertDictContainsSubset(
-                {'status': states.FAILURE, 'seq': 3}, r1,
-                'FFWDs to the last state',
-            )
+            # FFWDs to the last state.
+            assert r1['status'] == states.FAILURE
+            assert r1['seq'] == 3
 
             # Caches last known state.
             tid = uuid()
             results.put(Message(task_id=tid))
             backend.get_task_meta(tid)
-            self.assertIn(tid, backend._cache, 'Caches last known state')
+            assert tid, backend._cache in 'Caches last known state'
 
-            self.assertTrue(state_messages[-1].requeued)
+            assert state_messages[-1].requeued
 
             # Returns cache if no new states.
             results.queue.clear()
             assert not results.qsize()
             backend._cache[tid] = 'hello'
-            self.assertEqual(
-                backend.get_task_meta(tid), 'hello',
-                'Returns cache if no new states',
-            )
+            # returns cache if no new states.
+            assert backend.get_task_meta(tid) == 'hello'
 
     def test_drain_events_decodes_exceptions_in_meta(self):
         tid = uuid()
@@ -242,39 +235,39 @@ class test_AMQPBackend(AppCase):
         b.store_result(tid, RuntimeError('aap'), states.FAILURE)
         result = AsyncResult(tid, backend=b)
 
-        with self.assertRaises(Exception) as cm:
+        with pytest.raises(Exception) as excinfo:
             result.get()
 
-        self.assertEqual(cm.exception.__class__.__name__, 'RuntimeError')
-        self.assertEqual(str(cm.exception), 'aap')
+        assert excinfo.value.__class__.__name__ == 'RuntimeError'
+        assert str(excinfo.value) == 'aap'
 
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         app = self.app
         app.conf.result_expires = None
         b = self.create_backend(expires=None)
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             b.queue_arguments['x-expires']
 
     def test_process_cleanup(self):
         self.create_backend().process_cleanup()
 
     def test_reload_task_result(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().reload_task_result('x')
 
     def test_reload_group_result(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().reload_group_result('x')
 
     def test_save_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().save_group('x', 'x')
 
     def test_restore_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().restore_group('x')
 
     def test_delete_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().delete_group('x')

+ 41 - 0
t/unit/backends/test_backends.py

@@ -0,0 +1,41 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from case import patch
+
+from celery import backends
+from celery.backends.amqp import AMQPBackend
+from celery.backends.cache import CacheBackend
+from celery.exceptions import ImproperlyConfigured
+
+
+class test_backends:
+
+    @pytest.mark.parametrize('url,expect_cls', [
+        ('amqp://', AMQPBackend),
+        ('cache+memory://', CacheBackend),
+    ])
+    def test_get_backend_aliases(self, url, expect_cls, app):
+        backend, url = backends.get_backend_by_url(url, app.loader)
+        assert isinstance(backend(app=app, url=url), expect_cls)
+
+    def test_unknown_backend(self, app):
+        with pytest.raises(ImportError):
+            backends.get_backend_cls('fasodaopjeqijwqe', app.loader)
+
+    @pytest.mark.usefixtures('depends_on_current_app')
+    def test_default_backend(self, app):
+        assert backends.default_backend == app.backend
+
+    def test_backend_by_url(self, app, url='redis://localhost/1'):
+        from celery.backends.redis import RedisBackend
+        backend, url_ = backends.get_backend_by_url(url, app.loader)
+        assert backend is RedisBackend
+        assert url_ == url
+
+    def test_sym_raises_ValuError(self, app):
+        with patch('celery.backends.symbol_by_name') as sbn:
+            sbn.side_effect = ValueError()
+            with pytest.raises(ImproperlyConfigured):
+                backends.get_backend_cls('xxx.xxx:foo', app.loader)

+ 106 - 111
celery/tests/backends/test_base.py → t/unit/backends/test_base.py

@@ -1,17 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
 import sys
 import types
 
 from contextlib import contextmanager
 
-from celery.exceptions import ChordError, TimeoutError
-from celery.five import items, bytes_if_py2, range
-from celery.utils import serialization
-from celery.utils.serialization import subclass_exception
-from celery.utils.serialization import find_pickleable_exception as fnpe
-from celery.utils.serialization import UnpickleableExceptionWrapper
-from celery.utils.serialization import get_pickleable_exception as gpe
+from case import ANY, Mock, call, patch, skip
 
 from celery import states
 from celery import group, uuid
@@ -21,10 +16,15 @@ from celery.backends.base import (
     DisabledBackend,
     _nulldict,
 )
+from celery.exceptions import ChordError, TimeoutError
+from celery.five import items, bytes_if_py2, range
 from celery.result import result_from_tuple
+from celery.utils import serialization
 from celery.utils.functional import pass1
-
-from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
+from celery.utils.serialization import subclass_exception
+from celery.utils.serialization import find_pickleable_exception as fnpe
+from celery.utils.serialization import UnpickleableExceptionWrapper
+from celery.utils.serialization import get_pickleable_exception as gpe
 
 
 class wrapobject(object):
@@ -47,7 +47,7 @@ Lookalike = subclass_exception(
 )
 
 
-class test_nulldict(Case):
+class test_nulldict:
 
     def test_nulldict(self):
         x = _nulldict()
@@ -56,25 +56,24 @@ class test_nulldict(Case):
         x.setdefault('foo', 3)
 
 
-class test_serialization(AppCase):
+class test_serialization:
 
     def test_create_exception_cls(self):
-        self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
-        self.assertTrue(serialization.create_exception_cls('FooError', 'm',
-                                                           KeyError))
+        assert serialization.create_exception_cls('FooError', 'm')
+        assert serialization.create_exception_cls('FooError', 'm', KeyError)
 
 
-class test_BaseBackend_interface(AppCase):
+class test_BaseBackend_interface:
 
     def setup(self):
         self.b = BaseBackend(self.app)
 
     def test__forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.b._forget('SOMExx-N0Nex1stant-IDxx-')
 
     def test_forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.b.forget('SOMExx-N0nex1stant-IDxx-')
 
     def test_on_chord_part_return(self):
@@ -86,29 +85,29 @@ class test_BaseBackend_interface(AppCase):
             group(app=self.app), (), 'dakj221', None,
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
-        self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
+        assert self.app.tasks[unlock].apply_async.call_count
 
 
-class test_exception_pickle(AppCase):
+class test_exception_pickle:
 
     @skip.if_python3(reason='does not support old style classes')
     @skip.if_pypy()
     def test_oldstyle(self):
-        self.assertTrue(fnpe(Oldstyle()))
+        assert fnpe(Oldstyle())
 
     def test_BaseException(self):
-        self.assertIsNone(fnpe(Exception()))
+        assert fnpe(Exception()) is None
 
     def test_get_pickleable_exception(self):
         exc = Exception('foo')
-        self.assertEqual(gpe(exc), exc)
+        assert gpe(exc) == exc
 
     def test_unpickleable(self):
-        self.assertIsInstance(fnpe(Unpickleable()), KeyError)
-        self.assertIsNone(fnpe(Impossible()))
+        assert isinstance(fnpe(Unpickleable()), KeyError)
+        assert fnpe(Impossible()) is None
 
 
-class test_prepare_exception(AppCase):
+class test_prepare_exception:
 
     def setup(self):
         self.b = BaseBackend(self.app)
@@ -116,28 +115,28 @@ class test_prepare_exception(AppCase):
     def test_unpickleable(self):
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
-        self.assertIsInstance(x, KeyError)
+        assert isinstance(x, KeyError)
         y = self.b.exception_to_python(x)
-        self.assertIsInstance(y, KeyError)
+        assert isinstance(y, KeyError)
 
     def test_impossible(self):
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(Impossible())
-        self.assertIsInstance(x, UnpickleableExceptionWrapper)
-        self.assertTrue(str(x))
+        assert isinstance(x, UnpickleableExceptionWrapper)
+        assert str(x)
         y = self.b.exception_to_python(x)
-        self.assertEqual(y.__class__.__name__, 'Impossible')
+        assert y.__class__.__name__ == 'Impossible'
         if sys.version_info < (2, 5):
-            self.assertTrue(y.__class__.__module__)
+            assert y.__class__.__module__
         else:
-            self.assertEqual(y.__class__.__module__, 'foo.module')
+            assert y.__class__.__module__ == 'foo.module'
 
     def test_regular(self):
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(KeyError('baz'))
-        self.assertIsInstance(x, KeyError)
+        assert isinstance(x, KeyError)
         y = self.b.exception_to_python(x)
-        self.assertIsInstance(y, KeyError)
+        assert isinstance(y, KeyError)
 
 
 class KVBackend(KeyValueStoreBackend):
@@ -181,22 +180,22 @@ class DictBackend(BaseBackend):
         self._data.pop(group_id, None)
 
 
-class test_BaseBackend_dict(AppCase):
+class test_BaseBackend_dict:
 
     def setup(self):
         self.b = DictBackend(app=self.app)
 
     def test_delete_group(self):
         self.b.delete_group('can-delete')
-        self.assertNotIn('can-delete', self.b._data)
+        assert 'can-delete' not in self.b._data
 
     def test_prepare_exception_json(self):
         x = DictBackend(self.app, serializer='json')
         e = x.prepare_exception(KeyError('foo'))
-        self.assertIn('exc_type', e)
+        assert 'exc_type' in e
         e = x.exception_to_python(e)
-        self.assertEqual(e.__class__.__name__, 'KeyError')
-        self.assertEqual(str(e).strip('u'), "'foo'")
+        assert e.__class__.__name__ == 'KeyError'
+        assert str(e).strip('u') == "'foo'"
 
     def test_save_group(self):
         b = BaseBackend(self.app)
@@ -206,20 +205,20 @@ class test_BaseBackend_dict(AppCase):
 
     def test_add_to_chord_interface(self):
         b = BaseBackend(self.app)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             b.add_to_chord('group_id', 'sig')
 
     def test_forget_interface(self):
         b = BaseBackend(self.app)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             b.forget('foo')
 
     def test_restore_group(self):
-        self.assertIsNone(self.b.restore_group('missing'))
-        self.assertIsNone(self.b.restore_group('missing'))
-        self.assertEqual(self.b.restore_group('exists'), 'group')
-        self.assertEqual(self.b.restore_group('exists'), 'group')
-        self.assertEqual(self.b.restore_group('exists', cache=False), 'group')
+        assert self.b.restore_group('missing') is None
+        assert self.b.restore_group('missing') is None
+        assert self.b.restore_group('exists') == 'group'
+        assert self.b.restore_group('exists') == 'group'
+        assert self.b.restore_group('exists', cache=False) == 'group'
 
     def test_reload_group_result(self):
         self.b._cache = {}
@@ -239,29 +238,29 @@ class test_BaseBackend_dict(AppCase):
             self.b.fail_from_current_stack('task_id')
             self.b.mark_as_failure.assert_called()
             args = self.b.mark_as_failure.call_args[0]
-            self.assertEqual(args[0], 'task_id')
-            self.assertIs(args[1], exc)
-            self.assertTrue(args[2])
+            assert args[0] == 'task_id'
+            assert args[1] is exc
+            assert args[2]
 
     def test_prepare_value_serializes_group_result(self):
         self.b.serializer = 'json'
         g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
         v = self.b.prepare_value(g)
-        self.assertIsInstance(v, (list, tuple))
-        self.assertEqual(result_from_tuple(v, app=self.app), g)
+        assert isinstance(v, (list, tuple))
+        assert result_from_tuple(v, app=self.app) == g
 
         v2 = self.b.prepare_value(g[0])
-        self.assertIsInstance(v2, (list, tuple))
-        self.assertEqual(result_from_tuple(v2, app=self.app), g[0])
+        assert isinstance(v2, (list, tuple))
+        assert result_from_tuple(v2, app=self.app) == g[0]
 
         self.b.serializer = 'pickle'
-        self.assertIsInstance(self.b.prepare_value(g), self.app.GroupResult)
+        assert isinstance(self.b.prepare_value(g), self.app.GroupResult)
 
     def test_is_cached(self):
         b = BaseBackend(app=self.app, max_cached_results=1)
         b._cache['foo'] = 1
-        self.assertTrue(b.is_cached('foo'))
-        self.assertFalse(b.is_cached('false'))
+        assert b.is_cached('foo')
+        assert not b.is_cached('false')
 
     def test_mark_as_done__chord(self):
         b = BaseBackend(app=self.app)
@@ -297,7 +296,7 @@ class test_BaseBackend_dict(AppCase):
         callback.options = {'link_error': []}
         task = self.app.tasks[callback.task] = Mock()
         b.fail_from_current_stack = Mock()
-        group = self.patch('celery.group')
+        group = self.patching('celery.group')
         group.side_effect = exc
         b.chord_error_from_stack(callback, exc=ValueError())
         task.backend.fail_from_current_stack.assert_called_with(
@@ -305,15 +304,15 @@ class test_BaseBackend_dict(AppCase):
 
     def test_exception_to_python_when_None(self):
         b = BaseBackend(app=self.app)
-        self.assertIsNone(b.exception_to_python(None))
+        assert b.exception_to_python(None) is None
 
     def test_wait_for__on_interval(self):
-        self.patch('time.sleep')
+        self.patching('time.sleep')
         b = BaseBackend(app=self.app)
         b._get_task_meta_for = Mock()
         b._get_task_meta_for.return_value = {'status': states.PENDING}
         callback = Mock(name='callback')
-        with self.assertRaises(TimeoutError):
+        with pytest.raises(TimeoutError):
             b.wait_for(task_id='1', on_interval=callback, timeout=1)
         callback.assert_called_with()
 
@@ -324,12 +323,12 @@ class test_BaseBackend_dict(AppCase):
         b = BaseBackend(app=self.app)
         b._get_task_meta_for = Mock()
         b._get_task_meta_for.return_value = {}
-        self.assertIsNone(b.get_children('id'))
+        assert b.get_children('id') is None
         b._get_task_meta_for.return_value = {'children': 3}
-        self.assertEqual(b.get_children('id'), 3)
+        assert b.get_children('id') == 3
 
 
-class test_KeyValueStoreBackend(AppCase):
+class test_KeyValueStoreBackend:
 
     def setup(self):
         self.b = KVBackend(app=self.app)
@@ -341,15 +340,15 @@ class test_KeyValueStoreBackend(AppCase):
     def test_get_store_delete_result(self):
         tid = uuid()
         self.b.mark_as_done(tid, 'Hello world')
-        self.assertEqual(self.b.get_result(tid), 'Hello world')
-        self.assertEqual(self.b.get_state(tid), states.SUCCESS)
+        assert self.b.get_result(tid) == 'Hello world'
+        assert self.b.get_state(tid) == states.SUCCESS
         self.b.forget(tid)
-        self.assertEqual(self.b.get_state(tid), states.PENDING)
+        assert self.b.get_state(tid) == states.PENDING
 
     def test_strip_prefix(self):
         x = self.b.get_key_for_task('x1b34')
-        self.assertEqual(self.b._strip_prefix(x), 'x1b34')
-        self.assertEqual(self.b._strip_prefix('x1b34'), 'x1b34')
+        assert self.b._strip_prefix(x) == 'x1b34'
+        assert self.b._strip_prefix('x1b34') == 'x1b34'
 
     def test_get_many(self):
         for is_dict in True, False:
@@ -359,17 +358,17 @@ class test_KeyValueStoreBackend(AppCase):
                 self.b.mark_as_done(id, i)
             it = self.b.get_many(list(ids))
             for i, (got_id, got_state) in enumerate(it):
-                self.assertEqual(got_state['result'], ids[got_id])
-            self.assertEqual(i, 9)
-            self.assertTrue(list(self.b.get_many(list(ids))))
+                assert got_state['result'] == ids[got_id]
+            assert i == 9
+            assert list(self.b.get_many(list(ids)))
 
             self.b._cache.clear()
             callback = Mock(name='callback')
             it = self.b.get_many(list(ids), on_message=callback)
             for i, (got_id, got_state) in enumerate(it):
-                self.assertEqual(got_state['result'], ids[got_id])
-            self.assertEqual(i, 9)
-            self.assertTrue(list(self.b.get_many(list(ids))))
+                assert got_state['result'] == ids[got_id]
+            assert i == 9
+            assert list(self.b.get_many(list(ids)))
             callback.assert_has_calls([
                 call(ANY) for id in ids
             ])
@@ -377,7 +376,7 @@ class test_KeyValueStoreBackend(AppCase):
     def test_get_many_times_out(self):
         tasks = [uuid() for _ in range(4)]
         self.b._cache[tasks[1]] = {'status': 'PENDING'}
-        with self.assertRaises(self.b.TimeoutError):
+        with pytest.raises(self.b.TimeoutError):
             list(self.b.get_many(tasks, timeout=0.01, interval=0.01))
 
     def test_chord_part_return_no_gid(self):
@@ -390,9 +389,8 @@ class test_KeyValueStoreBackend(AppCase):
         self.b.get_key_for_chord.side_effect = AssertionError(
             'should not get here',
         )
-        self.assertIsNone(
-            self.b.on_chord_part_return(task.request, state, result),
-        )
+        assert self.b.on_chord_part_return(
+            task.request, state, result) is None
 
     @patch('celery.backends.base.GroupResult')
     @patch('celery.backends.base.maybe_signature')
@@ -429,14 +427,11 @@ class test_KeyValueStoreBackend(AppCase):
     def test_filter_ready(self):
         self.b.decode_result = Mock()
         self.b.decode_result.side_effect = pass1
-        self.assertEqual(
-            len(list(self.b._filter_ready([
-                (1, {'status': states.RETRY}),
-                (2, {'status': states.FAILURE}),
-                (3, {'status': states.SUCCESS}),
-            ]))),
-            2,
-        )
+        assert len(list(self.b._filter_ready([
+            (1, {'status': states.RETRY}),
+            (2, {'status': states.FAILURE}),
+            (3, {'status': states.SUCCESS}),
+        ]))) == 2
 
     @contextmanager
     def _chord_part_context(self, b):
@@ -484,8 +479,8 @@ class test_KeyValueStoreBackend(AppCase):
             self.b.fail_from_current_stack.assert_called()
             args = self.b.fail_from_current_stack.call_args
             exc = args[1]['exc']
-            self.assertIsInstance(exc, ChordError)
-            self.assertIn('foo', str(exc))
+            assert isinstance(exc, ChordError)
+            assert 'foo' in str(exc)
 
     def test_chord_part_return_join_raises_task(self):
         b = KVBackend(serializer='pickle', app=self.app)
@@ -498,8 +493,8 @@ class test_KeyValueStoreBackend(AppCase):
             b.fail_from_current_stack.assert_called()
             args = b.fail_from_current_stack.call_args
             exc = args[1]['exc']
-            self.assertIsInstance(exc, ChordError)
-            self.assertIn('Dependency culprit raised', str(exc))
+            assert isinstance(exc, ChordError)
+            assert 'Dependency culprit raised' in str(exc)
 
     def test_restore_group_from_json(self):
         b = KVBackend(serializer='json', app=self.app)
@@ -509,7 +504,7 @@ class test_KeyValueStoreBackend(AppCase):
         )
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
-        self.assertEqual(g2, g)
+        assert g2 == g
 
     def test_restore_group_from_pickle(self):
         b = KVBackend(serializer='pickle', app=self.app)
@@ -519,7 +514,7 @@ class test_KeyValueStoreBackend(AppCase):
         )
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
-        self.assertEqual(g2, g)
+        assert g2 == g
 
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
@@ -533,8 +528,8 @@ class test_KeyValueStoreBackend(AppCase):
         )
 
     def test_get_missing_meta(self):
-        self.assertIsNone(self.b.get_result('xxx-missing'))
-        self.assertEqual(self.b.get_state('xxx-missing'), states.PENDING)
+        assert self.b.get_result('xxx-missing') is None
+        assert self.b.get_state('xxx-missing') == states.PENDING
 
     def test_save_restore_delete_group(self):
         tid = uuid()
@@ -543,58 +538,58 @@ class test_KeyValueStoreBackend(AppCase):
         )
         self.b.save_group(tid, tsr)
         self.b.restore_group(tid)
-        self.assertEqual(self.b.restore_group(tid), tsr)
+        assert self.b.restore_group(tid) == tsr
         self.b.delete_group(tid)
-        self.assertIsNone(self.b.restore_group(tid))
+        assert self.b.restore_group(tid) is None
 
     def test_restore_missing_group(self):
-        self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
+        assert self.b.restore_group('xxx-nonexistant') is None
 
 
-class test_KeyValueStoreBackend_interface(AppCase):
+class test_KeyValueStoreBackend_interface:
 
     def test_get(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).get('a')
 
     def test_set(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).set('a', 1)
 
     def test_incr(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).incr('a')
 
     def test_cleanup(self):
-        self.assertFalse(KeyValueStoreBackend(self.app).cleanup())
+        assert not KeyValueStoreBackend(self.app).cleanup()
 
     def test_delete(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).delete('a')
 
     def test_mget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).mget(['a'])
 
     def test_forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).forget('a')
 
 
-class test_DisabledBackend(AppCase):
+class test_DisabledBackend:
 
     def test_store_result(self):
         DisabledBackend(self.app).store_result()
 
     def test_is_disabled(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             DisabledBackend(self.app).get_state('foo')
 
     def test_as_uri(self):
-        self.assertEqual(DisabledBackend(self.app).as_uri(), 'disabled://')
+        assert DisabledBackend(self.app).as_uri() == 'disabled://'
 
 
-class test_as_uri(AppCase):
+class test_as_uri:
 
     def setup(self):
         self.b = BaseBackend(
@@ -603,7 +598,7 @@ class test_as_uri(AppCase):
         )
 
     def test_as_uri_include_password(self):
-        self.assertEqual(self.b.as_uri(True), self.b.url)
+        assert self.b.as_uri(True) == self.b.url
 
     def test_as_uri_exclude_password(self):
-        self.assertEqual(self.b.as_uri(), 'sch://uuuu:**@hostname.dom/')
+        assert self.b.as_uri() == 'sch://uuuu:**@hostname.dom/'

+ 35 - 37
celery/tests/backends/test_cache.py → t/unit/backends/test_cache.py

@@ -1,10 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
 import sys
 import types
 
 from contextlib import contextmanager
 
+from case import Mock, mock, patch, skip
 from kombu.utils.encoding import str_to_bytes, ensure_bytes
 
 from celery import states
@@ -13,8 +15,6 @@ from celery.backends.cache import CacheBackend, DummyClient, backends
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, bytes_if_py2, string, text_t
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 PY3 = sys.version_info[0] == 3
 
 
@@ -24,7 +24,7 @@ class SomeClass(object):
         self.data = data
 
 
-class test_CacheBackend(AppCase):
+class test_CacheBackend:
 
     def setup(self):
         self.app.conf.result_serializer = 'pickle'
@@ -38,32 +38,32 @@ class test_CacheBackend(AppCase):
 
     def test_no_backend(self):
         self.app.conf.cache_backend = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CacheBackend(backend=None, app=self.app)
 
     def test_mark_as_done(self):
-        self.assertEqual(self.tb.get_state(self.tid), states.PENDING)
-        self.assertIsNone(self.tb.get_result(self.tid))
+        assert self.tb.get_state(self.tid) == states.PENDING
+        assert self.tb.get_result(self.tid) is None
 
         self.tb.mark_as_done(self.tid, 42)
-        self.assertEqual(self.tb.get_state(self.tid), states.SUCCESS)
-        self.assertEqual(self.tb.get_result(self.tid), 42)
+        assert self.tb.get_state(self.tid) == states.SUCCESS
+        assert self.tb.get_result(self.tid) == 42
 
     def test_is_pickled(self):
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
         self.tb.mark_as_done(self.tid, result)
         # is serialized properly.
         rindb = self.tb.get_result(self.tid)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
     def test_mark_as_failure(self):
         try:
             raise KeyError('foo')
         except KeyError as exception:
             self.tb.mark_as_failure(self.tid, exception)
-            self.assertEqual(self.tb.get_state(self.tid), states.FAILURE)
-            self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
+            assert self.tb.get_state(self.tid) == states.FAILURE
+            assert isinstance(self.tb.get_result(self.tid), KeyError)
 
     def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
@@ -99,48 +99,47 @@ class test_CacheBackend(AppCase):
         self.tb.set('foo', 1)
         self.tb.set('bar', 2)
 
-        self.assertDictEqual(self.tb.mget(['foo', 'bar']),
-                             {'foo': 1, 'bar': 2})
+        assert self.tb.mget(['foo', 'bar']) == {'foo': 1, 'bar': 2}
 
     def test_forget(self):
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
         x = self.app.AsyncResult(self.tid, backend=self.tb)
         x.forget()
-        self.assertIsNone(x.result)
+        assert x.result is None
 
     def test_process_cleanup(self):
         self.tb.process_cleanup()
 
     def test_expires_as_int(self):
         tb = CacheBackend(backend='memory://', expires=10, app=self.app)
-        self.assertEqual(tb.expires, 10)
+        assert tb.expires == 10
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CacheBackend(backend='unknown://', app=self.app)
 
     def test_as_uri_no_servers(self):
-        self.assertEqual(self.tb.as_uri(), 'memory:///')
+        assert self.tb.as_uri() == 'memory:///'
 
     def test_as_uri_one_server(self):
         backend = 'memcache://127.0.0.1:11211/'
         b = CacheBackend(backend=backend, app=self.app)
-        self.assertEqual(b.as_uri(), backend)
+        assert b.as_uri() == backend
 
     def test_as_uri_multiple_servers(self):
         backend = 'memcache://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         b = CacheBackend(backend=backend, app=self.app)
-        self.assertEqual(b.as_uri(), backend)
+        assert b.as_uri() == backend
 
-    @mock.stdouts
     @skip.unless_module('memcached', name='python-memcached')
-    def test_regression_worker_startup_info(self, stdout, stderr):
+    def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         )
         worker = self.app.Worker()
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
+        with mock.stdouts():
+            worker.on_start()
+            assert worker.startup_info()
 
 
 class MyMemcachedStringEncodingError(Exception):
@@ -190,15 +189,14 @@ class MockCacheMixin(object):
                 sys.modules['pylibmc'] = prev
 
 
-class test_get_best_memcache(AppCase, MockCacheMixin):
+class test_get_best_memcache(MockCacheMixin):
 
     def test_pylibmc(self):
         with self.mock_pylibmc():
             with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 cache._imp = [None]
-                self.assertEqual(cache.get_best_memcache()[0].__module__,
-                                 'pylibmc')
+                assert cache.get_best_memcache()[0].__module__ == 'pylibmc'
 
     def test_memcache(self):
         with self.mock_memcache():
@@ -206,15 +204,15 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
                 with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     cache._imp = [None]
-                    self.assertEqual(cache.get_best_memcache()[0]().__module__,
-                                     'memcache')
+                    assert (cache.get_best_memcache()[0]().__module__ ==
+                            'memcache')
 
     def test_no_implementations(self):
         with mock.mask_modules('pylibmc', 'memcache'):
             with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 cache._imp = [None]
-                with self.assertRaises(ImproperlyConfigured):
+                with pytest.raises(ImproperlyConfigured):
                     cache.get_best_memcache()
 
     def test_cached(self):
@@ -223,17 +221,17 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 cache._imp = [None]
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
-                self.assertTrue(cache._imp[0])
+                assert cache._imp[0]
                 cache.get_best_memcache()[0]()
 
     def test_backends(self):
         from celery.backends.cache import backends
         with self.mock_memcache():
             for name, fun in items(backends):
-                self.assertTrue(fun())
+                assert fun()
 
 
-class test_memcache_key(AppCase, MockCacheMixin):
+class test_memcache_key(MockCacheMixin):
 
     def test_memcache_unicode_key(self):
         with self.mock_memcache():
@@ -244,7 +242,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     task_id, result = string(uuid()), 42
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, state=states.SUCCESS)
-                    self.assertEqual(b.get_result(task_id), result)
+                    assert b.get_result(task_id) == result
 
     def test_memcache_bytes_key(self):
         with self.mock_memcache():
@@ -255,7 +253,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     task_id, result = str_to_bytes(uuid()), 42
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, state=states.SUCCESS)
-                    self.assertEqual(b.get_result(task_id), result)
+                    assert b.get_result(task_id) == result
 
     def test_pylibmc_unicode_key(self):
         with mock.reset_modules('celery.backends.cache'):
@@ -265,7 +263,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 task_id, result = string(uuid()), 42
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, state=states.SUCCESS)
-                self.assertEqual(b.get_result(task_id), result)
+                assert b.get_result(task_id) == result
 
     def test_pylibmc_bytes_key(self):
         with mock.reset_modules('celery.backends.cache'):
@@ -275,4 +273,4 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 task_id, result = str_to_bytes(uuid()), 42
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, state=states.SUCCESS)
-                self.assertEqual(b.get_result(task_id), result)
+                assert b.get_result(task_id) == result

+ 18 - 15
celery/tests/backends/test_cassandra.py → t/unit/backends/test_cassandra.py

@@ -1,18 +1,21 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from pickle import loads, dumps
 from datetime import datetime
 
+from case import Mock, mock
+
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.objects import Bunch
-from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 
 
 @mock.module(*CASSANDRA_MODULES)
-class test_CassandraBackend(AppCase):
+class test_CassandraBackend:
 
     def setup(self):
         self.app.conf.update(
@@ -27,7 +30,7 @@ class test_CassandraBackend(AppCase):
         from celery.backends import cassandra as mod
         prev, mod.cassandra = mod.cassandra, None
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 mod.CassandraBackend(app=self.app)
         finally:
             mod.cassandra = prev
@@ -48,16 +51,16 @@ class test_CassandraBackend(AppCase):
         mod.CassandraBackend(app=self.app)
 
         # no servers raises ImproperlyConfigured
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             self.app.conf.cassandra_servers = None
             mod.CassandraBackend(
                 app=self.app, keyspace='b', column_family='c',
             )
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self, *modules):
         from celery.backends.cassandra import CassandraBackend
-        self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+        assert loads(dumps(CassandraBackend(app=self.app)))
 
     def test_get_task_meta_for(self, *modules):
         from celery.backends import cassandra as mod
@@ -72,11 +75,11 @@ class test_CassandraBackend(AppCase):
         ]
         x.decode = Mock()
         meta = x._get_task_meta_for('task_id')
-        self.assertEqual(meta['status'], states.SUCCESS)
+        assert meta['status'] == states.SUCCESS
 
         x._session.execute.return_value = []
         meta = x._get_task_meta_for('task_id')
-        self.assertEqual(meta['status'], states.PENDING)
+        assert meta['status'] == states.PENDING
 
     def test_store_result(self, *modules):
         from celery.backends import cassandra as mod
@@ -93,8 +96,8 @@ class test_CassandraBackend(AppCase):
         x = mod.CassandraBackend(app=self.app)
         x.process_cleanup()
 
-        self.assertIsNone(x._connection)
-        self.assertIsNone(x._session)
+        assert x._connection is None
+        assert x._session is None
 
     def test_timeouting_cluster(self):
         # Tests behavior when Cluster.connect raises
@@ -121,10 +124,10 @@ class test_CassandraBackend(AppCase):
 
         x = mod.CassandraBackend(app=self.app)
 
-        with self.assertRaises(OTOExc):
+        with pytest.raises(OTOExc):
             x._store_result('task_id', 'result', states.SUCCESS)
-        self.assertIsNone(x._connection)
-        self.assertIsNone(x._session)
+        assert x._connection is None
+        assert x._session is None
 
         x.process_cleanup()  # shouldn't raise
 
@@ -156,7 +159,7 @@ class test_CassandraBackend(AppCase):
             x._store_result('task_id', 'result', states.SUCCESS)
             x.process_cleanup()
 
-        self.assertEquals(RAMHoggingCluster.objects_alive, 0)
+        assert RAMHoggingCluster.objects_alive == 0
 
     def test_auth_provider(self):
         # Ensure valid auth_provider works properly, and invalid one raises
@@ -181,5 +184,5 @@ class test_CassandraBackend(AppCase):
         self.app.conf.cassandra_auth_kwargs = {
             'username': 'Jack'
         }
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             mod.CassandraBackend(app=self.app)

+ 6 - 5
celery/tests/backends/test_consul.py → t/unit/backends/test_consul.py

@@ -1,25 +1,26 @@
 from __future__ import absolute_import, unicode_literals
 
-from celery.tests.case import AppCase, Mock, skip
+from case import Mock, skip
+
 from celery.backends.consul import ConsulBackend
 
 
 @skip.unless_module('consul')
-class test_ConsulBackend(AppCase):
+class test_ConsulBackend:
 
     def setup(self):
         self.backend = ConsulBackend(
             app=self.app, url='consul://localhost:800')
 
     def test_supports_autoexpire(self):
-        self.assertTrue(self.backend.supports_autoexpire)
+        assert self.backend.supports_autoexpire
 
     def test_consul_consistency(self):
-        self.assertEqual('consistent', self.backend.consistency)
+        assert self.backend.consistency == 'consistent'
 
     def test_get(self):
         index = 100
         data = {'Key': 'test-consul-1', 'Value': 'mypayload'}
         self.backend.client = Mock(name='c.client')
         self.backend.client.kv.get.return_value = (index, data)
-        self.assertEqual(self.backend.get(data['Key']), 'mypayload')
+        assert self.backend.get(data['Key']) == 'mypayload'

+ 24 - 22
celery/tests/backends/test_couchbase.py → t/unit/backends/test_couchbase.py

@@ -1,14 +1,16 @@
 """Tests for the CouchbaseBackend."""
-
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from kombu.utils.encoding import str_t
 
+from case import MagicMock, Mock, patch, sentinel, skip
+
 from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchbaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
-from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 try:
     import couchbase
@@ -19,7 +21,7 @@ COUCHBASE_BUCKET = 'celery_bucket'
 
 
 @skip.unless_module('couchbase')
-class test_CouchbaseBackend(AppCase):
+class test_CouchbaseBackend:
 
     def setup(self):
         self.backend = CouchbaseBackend(app=self.app)
@@ -27,14 +29,14 @@ class test_CouchbaseBackend(AppCase):
     def test_init_no_couchbase(self):
         prev, module.Couchbase = module.Couchbase, None
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 CouchbaseBackend(app=self.app)
         finally:
             module.Couchbase = prev
 
     def test_init_no_settings(self):
         self.app.conf.couchbase_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CouchbaseBackend(app=self.app)
 
     def test_init_settings_is_None(self):
@@ -47,7 +49,7 @@ class test_CouchbaseBackend(AppCase):
 
             connection = self.backend._get_connection()
 
-            self.assertEqual(sentinel._connection, connection)
+            assert sentinel._connection == connection
             mock_Connection.assert_not_called()
 
     def test_get(self):
@@ -57,7 +59,7 @@ class test_CouchbaseBackend(AppCase):
         mocked_get = x._connection.get = Mock()
         mocked_get.return_value.value = sentinel.retval
         # should return None
-        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        assert x.get('1f3fab') == sentinel.retval
         x._connection.get.assert_called_once_with('1f3fab')
 
     def test_set(self):
@@ -66,7 +68,7 @@ class test_CouchbaseBackend(AppCase):
         x._connection = MagicMock()
         x._connection.set = MagicMock()
         # should return None
-        self.assertIsNone(x.set(sentinel.key, sentinel.value))
+        assert x.set(sentinel.key, sentinel.value) is None
 
     def test_delete(self):
         self.app.conf.couchbase_backend_settings = {}
@@ -75,7 +77,7 @@ class test_CouchbaseBackend(AppCase):
         mocked_delete = x._connection.delete = Mock()
         mocked_delete.return_value = None
         # should return None
-        self.assertIsNone(x.delete('1f3fab'))
+        assert x.delete('1f3fab') is None
         x._connection.delete.assert_called_once_with('1f3fab')
 
     def test_config_params(self):
@@ -87,27 +89,27 @@ class test_CouchbaseBackend(AppCase):
             'port': '1234',
         }
         x = CouchbaseBackend(app=self.app)
-        self.assertEqual(x.bucket, 'mycoolbucket')
-        self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
-        self.assertEqual(x.username, 'johndoe',)
-        self.assertEqual(x.password, 'mysecret')
-        self.assertEqual(x.port, 1234)
+        assert x.bucket == 'mycoolbucket'
+        assert x.host == ['here.host.com', 'there.host.com']
+        assert x.username == 'johndoe'
+        assert x.password == 'mysecret'
+        assert x.port == 1234
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
         from celery.backends.couchbase import CouchbaseBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, CouchbaseBackend)
-        self.assertEqual(url_, url)
+        assert backend is CouchbaseBackend
+        assert url_ == url
 
     def test_backend_params_by_url(self):
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         with self.Celery(backend=url) as app:
             x = app.backend
-            self.assertEqual(x.bucket, 'mycoolbucket')
-            self.assertEqual(x.host, 'myhost')
-            self.assertEqual(x.username, 'johndoe')
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 123)
+            assert x.bucket == 'mycoolbucket'
+            assert x.host == 'myhost'
+            assert x.username == 'johndoe'
+            assert x.password == 'mysecret'
+            assert x.port == 123
 
     def test_correct_key_types(self):
         keys = [
@@ -119,4 +121,4 @@ class test_CouchbaseBackend(AppCase):
             self.backend.get_key_for_group('group_id', 'key'),
         ]
         for key in keys:
-            self.assertIsInstance(key, str_t)
+            assert isinstance(key, str_t)

+ 20 - 20
celery/tests/backends/test_couchdb.py → t/unit/backends/test_couchdb.py

@@ -1,10 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import MagicMock, Mock, sentinel, skip
+
 from celery.backends import couchdb as module
 from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
-from celery.tests.case import AppCase, Mock, patch, sentinel, skip
 
 try:
     import pycouchdb
@@ -15,28 +18,26 @@ COUCHDB_CONTAINER = 'celery_container'
 
 
 @skip.unless_module('pycouchdb')
-class test_CouchBackend(AppCase):
+class test_CouchBackend:
 
     def setup(self):
+        self.Server = self.patching('pycouchdb.Server')
         self.backend = CouchBackend(app=self.app)
 
     def test_init_no_pycouchdb(self):
         """test init no pycouchdb raises"""
         prev, module.pycouchdb = module.pycouchdb, None
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 CouchBackend(app=self.app)
         finally:
             module.pycouchdb = prev
 
     def test_get_container_exists(self):
-        with patch('pycouchdb.client.Database') as mock_Connection:
             self.backend._connection = sentinel._connection
-
-            connection = self.backend._get_connection()
-
-            self.assertEqual(sentinel._connection, connection)
-            mock_Connection.assert_not_called()
+            connection = self.backend.connection
+            assert connection is sentinel._connection
+            self.Server.assert_not_called()
 
     def test_get(self):
         """test_get
@@ -48,10 +49,9 @@ class test_CouchBackend(AppCase):
         """
         x = CouchBackend(app=self.app)
         x._connection = Mock()
-        mocked_get = x._connection.get = Mock()
-        mocked_get.return_value = sentinel.retval
+        get = x._connection.get = MagicMock()
         # should return None
-        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        assert x.get('1f3fab') == get.return_value['value']
         x._connection.get.assert_called_once_with('1f3fab')
 
     def test_delete(self):
@@ -67,21 +67,21 @@ class test_CouchBackend(AppCase):
         mocked_delete = x._connection.delete = Mock()
         mocked_delete.return_value = None
         # should return None
-        self.assertIsNone(x.delete('1f3fab'))
+        assert x.delete('1f3fab') is None
         x._connection.delete.assert_called_once_with('1f3fab')
 
     def test_backend_by_url(self, url='couchdb://myhost/mycoolcontainer'):
         from celery.backends.couchdb import CouchBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, CouchBackend)
-        self.assertEqual(url_, url)
+        assert backend is CouchBackend
+        assert url_ == url
 
     def test_backend_params_by_url(self):
         url = 'couchdb://johndoe:mysecret@myhost:123/mycoolcontainer'
         with self.Celery(backend=url) as app:
             x = app.backend
-            self.assertEqual(x.container, 'mycoolcontainer')
-            self.assertEqual(x.host, 'myhost')
-            self.assertEqual(x.username, 'johndoe')
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 123)
+            assert x.container == 'mycoolcontainer'
+            assert x.host == 'myhost'
+            assert x.username == 'johndoe'
+            assert x.password == 'mysecret'
+            assert x.port == 123

+ 47 - 49
celery/tests/backends/test_database.py → t/unit/backends/test_database.py

@@ -1,17 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
-from datetime import datetime
+import pytest
 
+from datetime import datetime
 from pickle import loads, dumps
 
+from case import Mock, patch, skip
+
 from celery import states
 from celery import uuid
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, patch, skip,
-)
-
 try:
     import sqlalchemy  # noqa
 except ImportError:
@@ -33,7 +32,7 @@ class SomeClass(object):
 
 
 @skip.unless_module('sqlalchemy')
-class test_session_cleanup(AppCase):
+class test_session_cleanup:
 
     def test_context(self):
         session = Mock(name='session')
@@ -43,7 +42,7 @@ class test_session_cleanup(AppCase):
 
     def test_context_raises(self):
         session = Mock(name='session')
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             with session_cleanup(session):
                 raise KeyError()
         session.rollback.assert_called_with()
@@ -53,7 +52,7 @@ class test_session_cleanup(AppCase):
 @skip.unless_module('sqlalchemy')
 @skip.if_pypy()
 @skip.if_jython()
-class test_DatabaseBackend(AppCase):
+class test_DatabaseBackend:
 
     def setup(self):
         self.uri = 'sqlite:///test.db'
@@ -69,39 +68,38 @@ class test_DatabaseBackend(AppCase):
             calls[0] += 1
             raise DatabaseError(1, 2, 3)
 
-        with self.assertRaises(DatabaseError):
+        with pytest.raises(DatabaseError):
             raises(max_retries=5)
-        self.assertEqual(calls[0], 5)
+        assert calls[0] == 5
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
         self.app.conf.sqlalchemy_dburi = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             DatabaseBackend(app=self.app)
 
     def test_missing_task_id_is_PENDING(self):
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertEqual(tb.get_state('xxx-does-not-exist'), states.PENDING)
+        assert tb.get_state('xxx-does-not-exist') == states.PENDING
 
     def test_missing_task_meta_is_dict_with_pending(self):
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertDictContainsSubset({
-            'status': states.PENDING,
-            'task_id': 'xxx-does-not-exist-at-all',
-            'result': None,
-            'traceback': None
-        }, tb.get_task_meta('xxx-does-not-exist-at-all'))
+        meta = tb.get_task_meta('xxx-does-not-exist-at-all')
+        assert meta['status'] == states.PENDING
+        assert meta['task_id'] == 'xxx-does-not-exist-at-all'
+        assert meta['result'] is None
+        assert meta['traceback'] is None
 
     def test_mark_as_done(self):
         tb = DatabaseBackend(self.uri, app=self.app)
 
         tid = uuid()
 
-        self.assertEqual(tb.get_state(tid), states.PENDING)
-        self.assertIsNone(tb.get_result(tid))
+        assert tb.get_state(tid) == states.PENDING
+        assert tb.get_result(tid) is None
 
         tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_state(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
+        assert tb.get_state(tid) == states.SUCCESS
+        assert tb.get_result(tid) == 42
 
     def test_is_pickled(self):
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -111,20 +109,20 @@ class test_DatabaseBackend(AppCase):
         tb.mark_as_done(tid2, result)
         # is serialized properly.
         rindb = tb.get_result(tid2)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
     def test_mark_as_started(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tb.mark_as_started(tid)
-        self.assertEqual(tb.get_state(tid), states.STARTED)
+        assert tb.get_state(tid) == states.STARTED
 
     def test_mark_as_revoked(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tb.mark_as_revoked(tid)
-        self.assertEqual(tb.get_state(tid), states.REVOKED)
+        assert tb.get_state(tid) == states.REVOKED
 
     def test_mark_as_retry(self):
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -135,9 +133,9 @@ class test_DatabaseBackend(AppCase):
             import traceback
             trace = '\n'.join(traceback.format_stack())
             tb.mark_as_retry(tid, exception, traceback=trace)
-            self.assertEqual(tb.get_state(tid), states.RETRY)
-            self.assertIsInstance(tb.get_result(tid), KeyError)
-            self.assertEqual(tb.get_traceback(tid), trace)
+            assert tb.get_state(tid) == states.RETRY
+            assert isinstance(tb.get_result(tid), KeyError)
+            assert tb.get_traceback(tid) == trace
 
     def test_mark_as_failure(self):
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -149,9 +147,9 @@ class test_DatabaseBackend(AppCase):
             import traceback
             trace = '\n'.join(traceback.format_stack())
             tb.mark_as_failure(tid3, exception, traceback=trace)
-            self.assertEqual(tb.get_state(tid3), states.FAILURE)
-            self.assertIsInstance(tb.get_result(tid3), KeyError)
-            self.assertEqual(tb.get_traceback(tid3), trace)
+            assert tb.get_state(tid3) == states.FAILURE
+            assert isinstance(tb.get_result(tid3), KeyError)
+            assert tb.get_traceback(tid3) == trace
 
     def test_forget(self):
         tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
@@ -160,31 +158,31 @@ class test_DatabaseBackend(AppCase):
         tb.mark_as_done(tid, {'foo': 'bar'})
         x = self.app.AsyncResult(tid, backend=tb)
         x.forget()
-        self.assertIsNone(x.result)
+        assert x.result is None
 
     def test_process_cleanup(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb.process_cleanup()
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self):
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertTrue(loads(dumps(tb)))
+        assert loads(dumps(tb))
 
     def test_save__restore__delete_group(self):
         tb = DatabaseBackend(self.uri, app=self.app)
 
         tid = uuid()
         res = {'something': 'special'}
-        self.assertEqual(tb.save_group(tid, res), res)
+        assert tb.save_group(tid, res) == res
 
         res2 = tb.restore_group(tid)
-        self.assertEqual(res2, res)
+        assert res2 == res
 
         tb.delete_group(tid)
-        self.assertIsNone(tb.restore_group(tid))
+        assert tb.restore_group(tid) is None
 
-        self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
+        assert tb.restore_group('xxx-nonexisting-id') is None
 
     def test_cleanup(self):
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -202,20 +200,20 @@ class test_DatabaseBackend(AppCase):
         tb.cleanup()
 
     def test_Task__repr__(self):
-        self.assertIn('foo', repr(Task('foo')))
+        assert 'foo' in repr(Task('foo'))
 
     def test_TaskSet__repr__(self):
-        self.assertIn('foo', repr(TaskSet('foo', None)))
+        assert 'foo', repr(TaskSet('foo' in None))
 
 
 @skip.unless_module('sqlalchemy')
-class test_SessionManager(AppCase):
+class test_SessionManager:
 
     def test_after_fork(self):
         s = SessionManager()
-        self.assertFalse(s.forked)
+        assert not s.forked
         s._after_fork()
-        self.assertTrue(s.forked)
+        assert s.forked
 
     @patch('celery.backends.database.session.create_engine')
     def test_get_engine_forked(self, create_engine):
@@ -223,9 +221,9 @@ class test_SessionManager(AppCase):
         s._after_fork()
         engine = s.get_engine('dburi', foo=1)
         create_engine.assert_called_with('dburi', foo=1)
-        self.assertIs(engine, create_engine())
+        assert engine is create_engine()
         engine2 = s.get_engine('dburi', foo=1)
-        self.assertIs(engine2, engine)
+        assert engine2 is engine
 
     @patch('celery.backends.database.session.sessionmaker')
     def test_create_session_forked(self, sessionmaker):
@@ -234,16 +232,16 @@ class test_SessionManager(AppCase):
         s._after_fork()
         engine, session = s.create_session('dburi', short_lived_sessions=True)
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIs(session, sessionmaker())
+        assert session is sessionmaker()
         sessionmaker.return_value = Mock(name='new')
         engine, session2 = s.create_session('dburi', short_lived_sessions=True)
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIsNot(session2, session)
+        assert session2 is not session
         sessionmaker.return_value = Mock(name='new2')
         engine, session3 = s.create_session(
             'dburi', short_lived_sessions=False)
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIs(session3, session2)
+        assert session3 is session2
 
     def test_coverage_madness(self):
         prev, session.register_after_fork = (

+ 16 - 14
celery/tests/backends/test_elasticsearch.py → t/unit/backends/test_elasticsearch.py

@@ -1,15 +1,17 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import Mock, sentinel, skip
+
 from celery import backends
 from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.case import AppCase, Mock, sentinel, skip
-
 
 @skip.unless_module('elasticsearch')
-class test_ElasticsearchBackend(AppCase):
+class test_ElasticsearchBackend:
 
     def setup(self):
         self.backend = ElasticsearchBackend(app=self.app)
@@ -17,7 +19,7 @@ class test_ElasticsearchBackend(AppCase):
     def test_init_no_elasticsearch(self):
         prev, module.elasticsearch = module.elasticsearch, None
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 ElasticsearchBackend(app=self.app)
         finally:
             module.elasticsearch = prev
@@ -31,7 +33,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.get.return_value = r
         dict_result = x.get(sentinel.task_id)
 
-        self.assertEqual(dict_result, sentinel.result)
+        assert dict_result == sentinel.result
         x._server.get.assert_called_once_with(
             doc_type=x.doc_type,
             id=sentinel.task_id,
@@ -45,7 +47,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.get.return_value = sentinel.result
         none_result = x.get(sentinel.task_id)
 
-        self.assertEqual(none_result, None)
+        assert none_result is None
         x._server.get.assert_called_once_with(
             doc_type=x.doc_type,
             id=sentinel.task_id,
@@ -58,7 +60,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.delete = Mock()
         x._server.delete.return_value = sentinel.result
 
-        self.assertIsNone(x.delete(sentinel.task_id), sentinel.result)
+        assert x.delete(sentinel.task_id) is None
         x._server.delete.assert_called_once_with(
             doc_type=x.doc_type,
             id=sentinel.task_id,
@@ -68,16 +70,16 @@ class test_ElasticsearchBackend(AppCase):
     def test_backend_by_url(self, url='elasticsearch://localhost:9200/index'):
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
 
-        self.assertIs(backend, ElasticsearchBackend)
-        self.assertEqual(url_, url)
+        assert backend is ElasticsearchBackend
+        assert url_ == url
 
     def test_backend_params_by_url(self):
         url = 'elasticsearch://localhost:9200/index/doc_type'
         with self.Celery(backend=url) as app:
             x = app.backend
 
-            self.assertEqual(x.index, 'index')
-            self.assertEqual(x.doc_type, 'doc_type')
-            self.assertEqual(x.scheme, 'elasticsearch')
-            self.assertEqual(x.host, 'localhost')
-            self.assertEqual(x.port, 9200)
+            assert x.index == 'index'
+            assert x.doc_type == 'doc_type'
+            assert x.scheme == 'elasticsearch'
+            assert x.host == 'localhost'
+            assert x.port == 9200

+ 13 - 16
celery/tests/backends/test_filesystem.py → t/unit/backends/test_filesystem.py

@@ -2,54 +2,51 @@
 from __future__ import absolute_import, unicode_literals
 
 import os
-import shutil
+import pytest
 import tempfile
 
+from case import skip
+
 from celery import uuid
 from celery import states
 from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.case import AppCase, skip
-
 
 @skip.if_win32()
-class test_FilesystemBackend(AppCase):
+class test_FilesystemBackend:
 
     def setup(self):
         self.directory = tempfile.mkdtemp()
         self.url = 'file://' + self.directory
         self.path = self.directory.encode('ascii')
 
-    def teardown(self):
-        shutil.rmtree(self.directory)
-
     def test_a_path_is_required(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             FilesystemBackend(app=self.app)
 
     def test_a_path_in_url(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
-        self.assertEqual(tb.path, self.path)
+        assert tb.path == self.path
 
     def test_path_is_incorrect(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             FilesystemBackend(app=self.app, url=self.url + '-incorrect')
 
     def test_missing_task_is_PENDING(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
-        self.assertEqual(tb.get_state('xxx-does-not-exist'), states.PENDING)
+        assert tb.get_state('xxx-does-not-exist') == states.PENDING
 
     def test_mark_as_done_writes_file(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb.mark_as_done(uuid(), 42)
-        self.assertEqual(len(os.listdir(self.directory)), 1)
+        assert len(os.listdir(self.directory)) == 1
 
     def test_done_task_is_SUCCESS(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_state(tid), states.SUCCESS)
+        assert tb.get_state(tid) == states.SUCCESS
 
     def test_correct_result(self):
         data = {'foo': 'bar'}
@@ -57,7 +54,7 @@ class test_FilesystemBackend(AppCase):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tb.mark_as_done(tid, data)
-        self.assertEqual(tb.get_result(tid), data)
+        assert tb.get_result(tid) == data
 
     def test_get_many(self):
         data = {uuid(): 'foo', uuid(): 'bar', uuid(): 'baz'}
@@ -67,11 +64,11 @@ class test_FilesystemBackend(AppCase):
             tb.mark_as_done(key, value)
 
         for key, result in tb.get_many(data.keys()):
-            self.assertEqual(result['result'], data[key])
+            assert result['result'] == data[key]
 
     def test_forget_deletes_file(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tb.mark_as_done(tid, 42)
         tb.forget(tid)
-        self.assertEqual(len(os.listdir(self.directory)), 0)
+        assert len(os.listdir(self.directory)) == 0

+ 80 - 92
celery/tests/backends/test_mongodb.py → t/unit/backends/test_mongodb.py

@@ -1,20 +1,18 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 import datetime
 
 from pickle import loads, dumps
 
+from case import ANY, MagicMock, Mock, mock, patch, sentinel, skip
 from kombu.exceptions import EncodeError
 
 from celery import uuid
 from celery import states
-from celery.backends import mongodb as module
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import (
-    ANY, AppCase, MagicMock, Mock,
-    mock, depends_on_current_app, patch, sentinel, skip,
-)
 
 COLLECTION = 'taskmeta_celery'
 TASK_ID = uuid()
@@ -28,7 +26,7 @@ MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
 @skip.unless_module('pymongo')
-class test_MongoBackend(AppCase):
+class test_MongoBackend:
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
     replica_set_url = (
@@ -42,31 +40,20 @@ class test_MongoBackend(AppCase):
     )
 
     def setup(self):
-        R = self._reset = {}
-        R['encode'], MongoBackend.encode = MongoBackend.encode, Mock()
-        R['decode'], MongoBackend.decode = MongoBackend.decode, Mock()
-        R['Binary'], module.Binary = module.Binary, Mock()
-        R['datetime'], datetime.datetime = datetime.datetime, Mock()
-
+        self.patching('celery.backends.mongodb.MongoBackend.encode')
+        self.patching('celery.backends.mongodb.MongoBackend.decode')
+        self.patching('celery.backends.mongodb.Binary')
+        self.patching('datetime.datetime')
         self.backend = MongoBackend(app=self.app, url=self.default_url)
 
-    def teardown(self):
-        MongoBackend.encode = self._reset['encode']
-        MongoBackend.decode = self._reset['decode']
-        module.Binary = self._reset['Binary']
-        datetime.datetime = self._reset['datetime']
-
-    def test_init_no_mongodb(self):
-        prev, module.pymongo = module.pymongo, None
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                MongoBackend(app=self.app)
-        finally:
-            module.pymongo = prev
+    def test_init_no_mongodb(self, patching):
+        patching('celery.backends.mongodb.pymongo', None)
+        with pytest.raises(ImproperlyConfigured):
+            MongoBackend(app=self.app)
 
     def test_init_no_settings(self):
         self.app.conf.mongodb_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             MongoBackend(app=self.app)
 
     def test_init_settings_is_None(self):
@@ -81,14 +68,14 @@ class test_MongoBackend(AppCase):
         # uri
         uri = 'mongodb://localhost:27017'
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['localhost:27017'])
-        self.assertEqual(mb.options, mb._prepare_client_options())
-        self.assertEqual(mb.database_name, 'celery')
+        assert mb.mongo_host == ['localhost:27017']
+        assert mb.options == mb._prepare_client_options()
+        assert mb.database_name == 'celery'
 
         # uri with database name
         uri = 'mongodb://localhost:27017/celerydb'
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.database_name, 'celerydb')
+        assert mb.database_name == 'celerydb'
 
         # uri with user, password, database name, replica set
         uri = ('mongodb://'
@@ -98,15 +85,18 @@ class test_MongoBackend(AppCase):
                'mongo3.example.com:27017/'
                'celerydatabase?replicaSet=rs0')
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['mongo1.example.com:27017',
-                                         'mongo2.example.com:27017',
-                                         'mongo3.example.com:27017'])
-        self.assertEqual(
-            mb.options, dict(mb._prepare_client_options(), replicaset='rs0'),
+        assert mb.mongo_host == [
+            'mongo1.example.com:27017',
+            'mongo2.example.com:27017',
+            'mongo3.example.com:27017',
+        ]
+        assert mb.options == dict(
+            mb._prepare_client_options(),
+            replicaset='rs0',
         )
-        self.assertEqual(mb.user, 'celeryuser')
-        self.assertEqual(mb.password, 'celerypassword')
-        self.assertEqual(mb.database_name, 'celerydatabase')
+        assert mb.user == 'celeryuser'
+        assert mb.password == 'celerypassword'
+        assert mb.database_name == 'celerydatabase'
 
         # same uri, change some parameters in backend settings
         self.app.conf.mongodb_backend_settings = {
@@ -118,23 +108,26 @@ class test_MongoBackend(AppCase):
             },
         }
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['mongo1.example.com:27017',
-                                         'mongo2.example.com:27017',
-                                         'mongo3.example.com:27017'])
-        self.assertEqual(
-            mb.options, dict(mb._prepare_client_options(),
-                             replicaset='rs1', socketKeepAlive=True),
+        assert mb.mongo_host == [
+            'mongo1.example.com:27017',
+            'mongo2.example.com:27017',
+            'mongo3.example.com:27017',
+        ]
+        assert mb.options == dict(
+            mb._prepare_client_options(),
+            replicaset='rs1',
+            socketKeepAlive=True,
         )
-        self.assertEqual(mb.user, 'backenduser')
-        self.assertEqual(mb.password, 'celerypassword')
-        self.assertEqual(mb.database_name, 'another_db')
+        assert mb.user == 'backenduser'
+        assert mb.password == 'celerypassword'
+        assert mb.database_name == 'another_db'
 
         mb = MongoBackend(app=self.app, url='mongodb://')
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self):
         x = MongoBackend(app=self.app)
-        self.assertTrue(loads(dumps(x)))
+        assert loads(dumps(x))
 
     def test_get_connection_connection_exists(self):
         with patch('pymongo.MongoClient') as mock_Connection:
@@ -142,7 +135,7 @@ class test_MongoBackend(AppCase):
 
             connection = self.backend._get_connection()
 
-            self.assertEqual(sentinel._connection, connection)
+            assert sentinel._connection == connection
             mock_Connection.assert_not_called()
 
     def test_get_connection_no_connection_host(self):
@@ -157,7 +150,7 @@ class test_MongoBackend(AppCase):
                 host='mongodb://localhost:27017',
                 **self.backend._prepare_client_options()
             )
-            self.assertEqual(sentinel.connection, connection)
+            assert sentinel.connection == connection
 
     def test_get_connection_no_connection_mongodb_uri(self):
         with patch('pymongo.MongoClient') as mock_Connection:
@@ -171,7 +164,7 @@ class test_MongoBackend(AppCase):
             mock_Connection.assert_called_once_with(
                 host=mongodb_uri, **self.backend._prepare_client_options()
             )
-            self.assertEqual(sentinel.connection, connection)
+            assert sentinel.connection == connection
 
     @patch('celery.backends.mongodb.MongoBackend._get_connection')
     def test_get_database_no_existing(self, mock_get_connection):
@@ -186,8 +179,8 @@ class test_MongoBackend(AppCase):
 
         database = self.backend.database
 
-        self.assertTrue(database is mock_database)
-        self.assertTrue(self.backend.__dict__['database'] is mock_database)
+        assert database is mock_database
+        assert self.backend.__dict__['database'] is mock_database
         mock_database.authenticate.assert_called_once_with(
             MONGODB_USER, MONGODB_PASSWORD)
 
@@ -204,9 +197,9 @@ class test_MongoBackend(AppCase):
 
         database = self.backend.database
 
-        self.assertTrue(database is mock_database)
+        assert database is mock_database
         mock_database.authenticate.assert_not_called()
-        self.assertTrue(self.backend.__dict__['database'] is mock_database)
+        assert self.backend.__dict__['database'] is mock_database
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_store_result(self, mock_get_database):
@@ -224,16 +217,15 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         mock_collection.save.assert_called_once_with(ANY)
-        self.assertEqual(sentinel.result, ret_val)
+        assert sentinel.result == ret_val
 
         mock_collection.save.side_effect = InvalidDocument()
-        with self.assertRaises(EncodeError):
+        with pytest.raises(EncodeError):
             self.backend._store_result(
                 sentinel.task_id, sentinel.result, sentinel.status)
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_get_task_meta_for(self, mock_get_database):
-        datetime.datetime = self._reset['datetime']
         self.backend.taskmeta_collection = MONGODB_COLLECTION
 
         mock_database = MagicMock(spec=['__getitem__', '__setitem__'])
@@ -247,11 +239,10 @@ class test_MongoBackend(AppCase):
 
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
-        self.assertEqual(
-            list(sorted(['status', 'task_id', 'date_done', 'traceback',
-                         'result', 'children'])),
-            list(sorted(ret_val.keys())),
-        )
+        assert list(sorted([
+            'status', 'task_id', 'date_done',
+            'traceback', 'result', 'children',
+        ])) == list(sorted(ret_val.keys()))
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_get_task_meta_for_no_result(self, mock_get_database):
@@ -268,7 +259,7 @@ class test_MongoBackend(AppCase):
 
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
-        self.assertEqual({'status': states.PENDING, 'result': None}, ret_val)
+        assert {'status': states.PENDING, 'result': None} == ret_val
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_save_group(self, mock_get_database):
@@ -288,7 +279,7 @@ class test_MongoBackend(AppCase):
             MONGODB_GROUP_COLLECTION,
         )
         mock_collection.save.assert_called_once_with(ANY)
-        self.assertEqual(res, ret_val)
+        assert res == ret_val
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_restore_group(self, mock_get_database):
@@ -311,10 +302,8 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_collection.find_one.assert_called_once_with(
             {'_id': sentinel.taskset_id})
-        self.assertItemsEqual(
-            ['date_done', 'result', 'task_id'],
-            list(ret_val.keys()),
-        )
+        assert (sorted(['date_done', 'result', 'task_id']) ==
+                sorted(list(ret_val.keys())))
 
         mock_collection.find_one.return_value = None
         self.backend._restore_group(sentinel.taskset_id)
@@ -355,7 +344,6 @@ class test_MongoBackend(AppCase):
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_cleanup(self, mock_get_database):
-        datetime.datetime = self._reset['datetime']
         self.backend.taskmeta_collection = MONGODB_COLLECTION
         self.backend.groupmeta_collection = MONGODB_GROUP_COLLECTION
 
@@ -381,56 +369,56 @@ class test_MongoBackend(AppCase):
         db.authenticate.return_value = False
         x.user = 'jerry'
         x.password = 'cere4l'
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             x._get_database()
         db.authenticate.assert_called_with('jerry', 'cere4l')
 
     def test_prepare_client_options(self):
         with patch('pymongo.version_tuple', new=(3, 0, 3)):
             options = self.backend._prepare_client_options()
-            self.assertDictEqual(options, {
+            assert options == {
                 'maxPoolSize': self.backend.max_pool_size
-            })
+            }
 
     def test_as_uri_include_password(self):
-        self.assertEqual(self.backend.as_uri(True), self.default_url)
+        assert self.backend.as_uri(True) == self.default_url
 
     def test_as_uri_exclude_password(self):
-        self.assertEqual(self.backend.as_uri(), self.sanitized_default_url)
+        assert self.backend.as_uri() == self.sanitized_default_url
 
     def test_as_uri_include_password_replica_set(self):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
-        self.assertEqual(backend.as_uri(True), self.replica_set_url)
+        assert backend.as_uri(True) == self.replica_set_url
 
     def test_as_uri_exclude_password_replica_set(self):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
-        self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
+        assert backend.as_uri() == self.sanitized_replica_set_url
 
-    @mock.stdouts
-    def test_regression_worker_startup_info(self, stdout, stderr):
+    def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             '/work4us?replicaSet=rs&ssl=true'
         )
         worker = self.app.Worker()
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
+        with mock.stdouts():
+            worker.on_start()
+            assert worker.startup_info()
 
 
 @skip.unless_module('pymongo')
-class test_MongoBackend_no_mock(AppCase):
+class test_MongoBackend_no_mock:
 
-    def test_encode_decode(self):
-        backend = MongoBackend(app=self.app)
+    def test_encode_decode(self, app):
+        backend = MongoBackend(app=app)
         data = {'foo': 1}
-        self.assertTrue(backend.decode(backend.encode(data)))
+        assert backend.decode(backend.encode(data))
         backend.serializer = 'bson'
-        self.assertEquals(backend.encode(data), data)
-        self.assertEquals(backend.decode(data), data)
+        assert backend.encode(data) == data
+        assert backend.decode(data) == data
 
-    def test_de(self):
-        backend = MongoBackend(app=self.app)
+    def test_de(self, app):
+        backend = MongoBackend(app=app)
         data = {'foo': 1}
-        self.assertTrue(backend.encode(data))
+        assert backend.encode(data)
         backend.serializer = 'bson'
-        self.assertEquals(backend.encode(data), data)
+        assert backend.encode(data) == data

+ 50 - 61
celery/tests/backends/test_redis.py → t/unit/backends/test_redis.py

@@ -1,21 +1,22 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from datetime import timedelta
 
 from contextlib import contextmanager
 from pickle import loads, dumps
 
+from case import ANY, ContextMock, Mock, mock, call, patch, skip
+
 from celery import signature
 from celery import states
 from celery import uuid
 from celery.canvas import Signature
-from celery.exceptions import ChordError, ImproperlyConfigured
-from celery.utils.collections import AttributeDict
-
-from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, mock,
-    call, depends_on_current_app, patch, skip,
+from celery.exceptions import (
+    ChordError, CPendingDeprecationWarning, ImproperlyConfigured,
 )
+from celery.utils.collections import AttributeDict
 
 
 def raise_on_second_call(mock, exc, *retval):
@@ -122,7 +123,7 @@ class redis(object):
             pass
 
 
-class test_RedisBackend(AppCase):
+class test_RedisBackend:
 
     def get_backend(self):
         from celery.backends.redis import RedisBackend
@@ -141,54 +142,52 @@ class test_RedisBackend(AppCase):
         self.E_LOST = self.get_E_LOST()
         self.b = self.Backend(app=self.app)
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     @skip.unless_module('redis')
     def test_reduce(self):
         from celery.backends.redis import RedisBackend
         x = RedisBackend(app=self.app)
-        self.assertTrue(loads(dumps(x)))
+        assert loads(dumps(x))
 
     def test_no_redis(self):
         self.Backend.redis = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             self.Backend(app=self.app)
 
     def test_url(self):
         x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
         )
-        self.assertTrue(x.connparams)
-        self.assertEqual(x.connparams['host'], 'vandelay.com')
-        self.assertEqual(x.connparams['db'], 1)
-        self.assertEqual(x.connparams['port'], 123)
-        self.assertEqual(x.connparams['password'], 'bosco')
+        assert x.connparams
+        assert x.connparams['host'] == 'vandelay.com'
+        assert x.connparams['db'] == 1
+        assert x.connparams['port'] == 123
+        assert x.connparams['password'] == 'bosco'
 
     def test_socket_url(self):
         x = self.Backend(
             'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
         )
-        self.assertTrue(x.connparams)
-        self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
-        self.assertIs(
-            x.connparams['connection_class'],
-            redis.UnixDomainSocketConnection,
-        )
-        self.assertNotIn('host', x.connparams)
-        self.assertNotIn('port', x.connparams)
-        self.assertEqual(x.connparams['db'], 3)
+        assert x.connparams
+        assert x.connparams['path'] == '/tmp/redis.sock'
+        assert (x.connparams['connection_class'] is
+                redis.UnixDomainSocketConnection)
+        assert 'host' not in x.connparams
+        assert 'port' not in x.connparams
+        assert x.connparams['db'] == 3
 
     def test_compat_propertie(self):
         x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
         )
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.host, 'vandelay.com')
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.db, 1)
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.port, 123)
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.password, 'bosco')
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.host == 'vandelay.com'
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.db == 1
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.port == 123
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.password == 'bosco'
 
     def test_conf_raises_KeyError(self):
         self.app.conf = AttributeDict({
@@ -203,17 +202,11 @@ class test_RedisBackend(AppCase):
     def test_on_connection_error(self, error):
         intervals = iter([10, 20, 30])
         exc = KeyError()
-        self.assertEqual(
-            self.b.on_connection_error(None, exc, intervals, 1), 10,
-        )
+        assert self.b.on_connection_error(None, exc, intervals, 1) == 10
         error.assert_called_with(self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
-        self.assertEqual(
-            self.b.on_connection_error(10, exc, intervals, 2), 20,
-        )
+        assert self.b.on_connection_error(10, exc, intervals, 2) == 20
         error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
-        self.assertEqual(
-            self.b.on_connection_error(10, exc, intervals, 3), 30,
-        )
+        assert self.b.on_connection_error(10, exc, intervals, 3) == 30
         error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
 
     def test_incr(self):
@@ -229,7 +222,6 @@ class test_RedisBackend(AppCase):
     def test_apply_chord(self):
         header = Mock(name='header')
         header.results = [Mock(name='t1'), Mock(name='t2')]
-        print(self.b.apply_chord,)
         self.b.apply_chord(
             header, (1, 2), 'gid', None,
             options={'max_retries': 10},
@@ -241,7 +233,7 @@ class test_RedisBackend(AppCase):
         decode = Mock(name='decode')
         exc = KeyError()
         tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
-        with self.assertRaises(ChordError):
+        with pytest.raises(ChordError):
             self.b._unpack_chord_result(tup, decode)
         decode.assert_called_with(tup)
         self.b.exception_to_python.assert_called_with(exc)
@@ -250,27 +242,27 @@ class test_RedisBackend(AppCase):
         tup = decode.return_value = (2, 'id2', states.RETRY, exc)
         ret = self.b._unpack_chord_result(tup, decode)
         self.b.exception_to_python.assert_called_with(exc)
-        self.assertIs(ret, self.b.exception_to_python())
+        assert ret is self.b.exception_to_python()
 
     def test_on_chord_part_return_no_gid_or_tid(self):
         request = Mock(name='request')
         request.id = request.group = None
-        self.assertIsNone(self.b.on_chord_part_return(request, 'SUCCESS', 10))
+        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
 
     def test_ConnectionPool(self):
         self.b.redis = Mock(name='redis')
-        self.assertIsNone(self.b._ConnectionPool)
-        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
-        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
+        assert self.b._ConnectionPool is None
+        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
+        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
 
     def test_expires_defaults_to_config(self):
         self.app.conf.result_expires = 10
         b = self.Backend(expires=None, app=self.app)
-        self.assertEqual(b.expires, 10)
+        assert b.expires == 10
 
     def test_expires_is_int(self):
         b = self.Backend(expires=48, app=self.app)
-        self.assertEqual(b.expires, 48)
+        assert b.expires == 48
 
     def test_add_to_chord(self):
         b = self.Backend('redis://', app=self.app)
@@ -280,17 +272,14 @@ class test_RedisBackend(AppCase):
 
     def test_expires_is_None(self):
         b = self.Backend(expires=None, app=self.app)
-        self.assertEqual(
-            b.expires,
-            self.app.conf.result_expires.total_seconds(),
-        )
+        assert b.expires == self.app.conf.result_expires.total_seconds()
 
     def test_expires_is_timedelta(self):
         b = self.Backend(expires=timedelta(minutes=1), app=self.app)
-        self.assertEqual(b.expires, 60)
+        assert b.expires == 60
 
     def test_mget(self):
-        self.assertTrue(self.b.mget(['a', 'b', 'c']))
+        assert self.b.mget(['a', 'b', 'c'])
         self.b.client.mget.assert_called_with(['a', 'b', 'c'])
 
     def test_set_no_expire(self):
@@ -314,9 +303,9 @@ class test_RedisBackend(AppCase):
 
         for i in range(10):
             self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
-            self.assertTrue(self.b.client.rpush.call_count)
+            assert self.b.client.rpush.call_count
             self.b.client.rpush.reset_mock()
-        self.assertTrue(self.b.client.lrange.call_count)
+        assert self.b.client.lrange.call_count
         jkey = self.b.get_key_for_group('group_id', '.j')
         tkey = self.b.get_key_for_group('group_id', '.t')
         self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
@@ -383,10 +372,10 @@ class test_RedisBackend(AppCase):
     def test_get_set_forget(self):
         tid = uuid()
         self.b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(self.b.get_state(tid), states.SUCCESS)
-        self.assertEqual(self.b.get_result(tid), 42)
+        assert self.b.get_state(tid) == states.SUCCESS
+        assert self.b.get_result(tid) == 42
         self.b.forget(tid)
-        self.assertEqual(self.b.get_state(tid), states.PENDING)
+        assert self.b.get_state(tid) == states.PENDING
 
     def test_set_expires(self):
         self.b = self.Backend(expires=512, app=self.app)

+ 21 - 20
celery/tests/backends/test_riak.py → t/unit/backends/test_riak.py

@@ -1,17 +1,19 @@
 # -*- coding: utf-8 -*-
-
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import MagicMock, Mock, patch, sentinel, skip
+
 from celery.backends import riak as module
 from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 RIAK_BUCKET = 'riak_bucket'
 
 
 @skip.unless_module('riak')
-class test_RiakBackend(AppCase):
+class test_RiakBackend:
 
     def setup(self):
         self.app.conf.result_backend = 'riak://'
@@ -23,28 +25,27 @@ class test_RiakBackend(AppCase):
     def test_init_no_riak(self):
         prev, module.riak = module.riak, None
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 RiakBackend(app=self.app)
         finally:
             module.riak = prev
 
     def test_init_no_settings(self):
         self.app.conf.riak_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             RiakBackend(app=self.app)
 
     def test_init_settings_is_None(self):
         self.app.conf.riak_backend_settings = None
-        self.assertTrue(self.app.backend)
+        assert self.app.backend
 
     def test_get_client_client_exists(self):
         with patch('riak.client.RiakClient') as mock_connection:
             self.backend._client = sentinel._client
-
             mocked_is_alive = self.backend._client.is_alive = Mock()
             mocked_is_alive.return_value.value = True
             client = self.backend._get_client()
-            self.assertEquals(sentinel._client, client)
+            assert sentinel._client == client
             mock_connection.assert_not_called()
 
     def test_get(self):
@@ -54,7 +55,7 @@ class test_RiakBackend(AppCase):
         mocked_get = self.backend._bucket.get = Mock(name='bucket.get')
         mocked_get.return_value.data = sentinel.retval
         # should return None
-        self.assertEqual(self.backend.get('1f3fab'), sentinel.retval)
+        assert self.backend.get('1f3fab') == sentinel.retval
         self.backend._bucket.get.assert_called_once_with('1f3fab')
 
     def test_set(self):
@@ -63,7 +64,7 @@ class test_RiakBackend(AppCase):
         self.backend._bucket = MagicMock()
         self.backend._bucket.set = MagicMock()
         # should return None
-        self.assertIsNone(self.backend.set(sentinel.key, sentinel.value))
+        assert self.backend.set(sentinel.key, sentinel.value) is None
 
     def test_delete(self):
         self.app.conf.couchbase_backend_settings = {}
@@ -73,7 +74,7 @@ class test_RiakBackend(AppCase):
         mocked_delete = self.backend._client.delete = Mock('client.delete')
         mocked_delete.return_value = None
         # should return None
-        self.assertIsNone(self.backend.delete('1f3fab'))
+        assert self.backend.delete('1f3fab') is None
         self.backend._bucket.delete.assert_called_once_with('1f3fab')
 
     def test_config_params(self):
@@ -82,22 +83,22 @@ class test_RiakBackend(AppCase):
             'host': 'there.host.com',
             'port': '1234',
         }
-        self.assertEqual(self.backend.bucket_name, 'mycoolbucket')
-        self.assertEqual(self.backend.host, 'there.host.com')
-        self.assertEqual(self.backend.port, 1234)
+        assert self.backend.bucket_name == 'mycoolbucket'
+        assert self.backend.host == 'there.host.com'
+        assert self.backend.port == 1234
 
     def test_backend_by_url(self, url='riak://myhost/mycoolbucket'):
         from celery import backends
         from celery.backends.riak import RiakBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, RiakBackend)
-        self.assertEqual(url_, url)
+        assert backend is RiakBackend
+        assert url_ == url
 
     def test_backend_params_by_url(self):
         self.app.conf.result_backend = 'riak://myhost:123/mycoolbucket'
-        self.assertEqual(self.backend.bucket_name, 'mycoolbucket')
-        self.assertEqual(self.backend.host, 'myhost')
-        self.assertEqual(self.backend.port, 123)
+        assert self.backend.bucket_name == 'mycoolbucket'
+        assert self.backend.host == 'myhost'
+        assert self.backend.port == 123
 
     def test_non_ASCII_bucket_raises(self):
         self.app.conf.riak_backend_settings = {
@@ -105,5 +106,5 @@ class test_RiakBackend(AppCase):
             'host': 'there.host.com',
             'port': '1234',
         }
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             RiakBackend(app=self.app)

+ 20 - 22
celery/tests/backends/test_rpc.py → t/unit/backends/test_rpc.py

@@ -1,12 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import Mock, patch
+
 from celery.backends.rpc import RPCBackend
 from celery._state import _task_stack
 
-from celery.tests.case import AppCase, Mock, patch
-
 
-class test_RPCBackend(AppCase):
+class test_RPCBackend:
 
     def setup(self):
         self.b = RPCBackend(app=self.app)
@@ -14,8 +16,8 @@ class test_RPCBackend(AppCase):
     def test_oid(self):
         oid = self.b.oid
         oid2 = self.b.oid
-        self.assertEqual(oid, oid2)
-        self.assertEqual(oid, self.app.oid)
+        assert oid == oid2
+        assert oid == self.app.oid
 
     def test_interface(self):
         self.b.on_reply_declare('task_id')
@@ -24,38 +26,34 @@ class test_RPCBackend(AppCase):
         req = Mock(name='request')
         req.reply_to = 'reply_to'
         req.correlation_id = 'corid'
-        self.assertTupleEqual(
-            self.b.destination_for('task_id', req),
-            ('reply_to', 'corid'),
-        )
+        assert self.b.destination_for('task_id', req) == ('reply_to', 'corid')
         task = Mock()
         _task_stack.push(task)
         try:
             task.request.reply_to = 'reply_to'
             task.request.correlation_id = 'corid'
-            self.assertTupleEqual(
-                self.b.destination_for('task_id', None),
-                ('reply_to', 'corid'),
+            assert self.b.destination_for('task_id', None) == (
+                'reply_to', 'corid',
             )
         finally:
             _task_stack.pop()
 
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             self.b.destination_for('task_id', None)
 
     def test_rkey(self):
-        self.assertEqual(self.b.rkey('id1'), 'id1')
+        assert self.b.rkey('id1') == 'id1'
 
     def test_binding(self):
         queue = self.b.binding
-        self.assertEqual(queue.name, self.b.oid)
-        self.assertEqual(queue.exchange, self.b.exchange)
-        self.assertEqual(queue.routing_key, self.b.oid)
-        self.assertFalse(queue.durable)
-        self.assertTrue(queue.auto_delete)
+        assert queue.name == self.b.oid
+        assert queue.exchange == self.b.exchange
+        assert queue.routing_key == self.b.oid
+        assert not queue.durable
+        assert queue.auto_delete
 
     def test_create_binding(self):
-        self.assertEqual(self.b._create_binding('id'), self.b.binding)
+        assert self.b._create_binding('id') == self.b.binding
 
     def test_on_task_call(self):
         with patch('celery.backends.rpc.maybe_declare') as md:
@@ -68,5 +66,5 @@ class test_RPCBackend(AppCase):
 
     def test_create_exchange(self):
         ex = self.b._create_exchange('name')
-        self.assertIsInstance(ex, self.b.Exchange)
-        self.assertEqual(ex.name, '')
+        assert isinstance(ex, self.b.Exchange)
+        assert ex.name == ''

+ 0 - 0
celery/tests/concurrency/__init__.py → t/unit/bin/__init__.py


+ 0 - 0
celery/tests/bin/celery.py → t/unit/bin/celery.py


+ 0 - 0
celery/tests/bin/proj/__init__.py → t/unit/bin/proj/__init__.py


+ 0 - 0
celery/tests/bin/proj/app.py → t/unit/bin/proj/app.py


+ 29 - 32
celery/tests/bin/test_amqp.py → t/unit/bin/test_amqp.py

@@ -1,5 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import Mock, patch
+
 from celery.bin.amqp import (
     AMQPAdmin,
     AMQShell,
@@ -9,10 +13,8 @@ from celery.bin.amqp import (
 )
 from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, patch
-
 
-class test_AMQShell(AppCase):
+class test_AMQShell:
 
     def setup(self):
         self.fh = WhateverIO()
@@ -24,54 +26,54 @@ class test_AMQShell(AppCase):
 
     def test_queue_declare(self):
         self.shell.onecmd('queue.declare foo')
-        self.assertIn('ok', self.fh.getvalue())
+        assert 'ok' in self.fh.getvalue()
 
     def test_missing_command(self):
         self.shell.onecmd('foo foo')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
     def RV(self):
         raise Exception(self.fh.getvalue())
 
     def test_spec_format_response(self):
         spec = self.shell.amqp['exchange.declare']
-        self.assertEqual(spec.format_response(None), 'ok.')
-        self.assertEqual(spec.format_response('NO'), 'NO')
+        assert spec.format_response(None) == 'ok.'
+        assert spec.format_response('NO') == 'NO'
 
     def test_missing_namespace(self):
         self.shell.onecmd('ns.cmd arg')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
     def test_help(self):
         self.shell.onecmd('help')
-        self.assertIn('Example:', self.fh.getvalue())
+        assert 'Example:' in self.fh.getvalue()
 
     def test_help_command(self):
         self.shell.onecmd('help queue.declare')
-        self.assertIn('passive:no', self.fh.getvalue())
+        assert 'passive:no' in self.fh.getvalue()
 
     def test_help_unknown_command(self):
         self.shell.onecmd('help foo.baz')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
     def test_onecmd_error(self):
         self.shell.dispatch = Mock()
         self.shell.dispatch.side_effect = MemoryError()
         self.shell.say = Mock()
-        self.assertFalse(self.shell.needs_reconnect)
+        assert not self.shell.needs_reconnect
         self.shell.onecmd('hello')
         self.shell.say.assert_called()
-        self.assertTrue(self.shell.needs_reconnect)
+        assert self.shell.needs_reconnect
 
     def test_exit(self):
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             self.shell.onecmd('exit')
-        self.assertIn("don't leave!", self.fh.getvalue())
+        assert "don't leave!" in self.fh.getvalue()
 
     def test_note_silent(self):
         self.shell.silent = True
         self.shell.note('foo bar')
-        self.assertNotIn('foo bar', self.fh.getvalue())
+        assert 'foo bar' not in self.fh.getvalue()
 
     def test_reconnect(self):
         self.shell.onecmd('queue.declare foo')
@@ -79,14 +81,9 @@ class test_AMQShell(AppCase):
         self.shell.onecmd('queue.delete foo')
 
     def test_completenames(self):
-        self.assertEqual(
-            self.shell.completenames('queue.dec'),
-            ['queue.declare'],
-        )
-        self.assertEqual(
-            sorted(self.shell.completenames('declare')),
-            sorted(['queue.declare', 'exchange.declare']),
-        )
+        assert self.shell.completenames('queue.dec') == ['queue.declare']
+        assert (sorted(self.shell.completenames('declare')) ==
+                sorted(['queue.declare', 'exchange.declare']))
 
     def test_empty_line(self):
         self.shell.emptyline = Mock()
@@ -98,10 +95,10 @@ class test_AMQShell(AppCase):
 
     def test_respond(self):
         self.shell.respond({'foo': 'bar'})
-        self.assertIn('foo', self.fh.getvalue())
+        assert 'foo' in self.fh.getvalue()
 
     def test_prompt(self):
-        self.assertTrue(self.shell.prompt)
+        assert self.shell.prompt
 
     def test_no_returns(self):
         self.shell.onecmd('queue.declare foo')
@@ -114,20 +111,20 @@ class test_AMQShell(AppCase):
         m.body = 'the quick brown fox'
         m.properties = {'a': 1}
         m.delivery_info = {'exchange': 'bar'}
-        self.assertTrue(dump_message(m))
+        assert dump_message(m)
 
     def test_dump_message_no_message(self):
-        self.assertIn('No messages in queue', dump_message(None))
+        assert 'No messages in queue' in dump_message(None)
 
     def test_note(self):
         self.adm.silent = True
         self.adm.note('FOO')
-        self.assertNotIn('FOO', self.fh.getvalue())
+        assert 'FOO' not in self.fh.getvalue()
 
     def test_run(self):
         a = self.create_adm('queue.declare', 'foo')
         a.run()
-        self.assertIn('ok', self.fh.getvalue())
+        assert 'ok' in self.fh.getvalue()
 
     def test_run_loop(self):
         a = self.create_adm()
@@ -139,7 +136,7 @@ class test_AMQShell(AppCase):
 
         shell.cmdloop.side_effect = KeyboardInterrupt()
         a.run()
-        self.assertIn('bibi', self.fh.getvalue())
+        assert 'bibi' in self.fh.getvalue()
 
     @patch('celery.bin.amqp.amqp')
     def test_main(self, Command):
@@ -151,4 +148,4 @@ class test_AMQShell(AppCase):
     def test_command(self, cls):
         x = amqp(app=self.app)
         x.run()
-        self.assertIs(cls.call_args[1]['app'], self.app)
+        assert cls.call_args[1]['app'] is self.app

+ 130 - 143
celery/tests/bin/test_base.py → t/unit/bin/test_base.py

@@ -1,6 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 
 import os
+import pytest
+
+from case import Mock, mock, patch
 
 from celery.bin.base import (
     Command,
@@ -11,10 +14,6 @@ from celery.bin.base import (
 from celery.five import bytes_if_py2
 from celery.utils.objects import Bunch
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, mock, patch,
-)
-
 
 class MyApp(object):
     user_options = {'preload': None}
@@ -33,7 +32,7 @@ class MockCommand(Command):
         return args, kwargs
 
 
-class test_Extensions(AppCase):
+class test_Extensions:
 
     def test_load(self):
         with patch('pkg_resources.iter_entry_points') as iterep:
@@ -58,28 +57,28 @@ class test_Extensions(AppCase):
 
             with patch('celery.utils.imports.symbol_by_name') as symbyname:
                 symbyname.side_effect = KeyError('foo')
-                with self.assertRaises(KeyError):
+                with pytest.raises(KeyError):
                     e.load()
 
 
-class test_HelpFormatter(AppCase):
+class test_HelpFormatter:
 
     def test_format_epilog(self):
         f = HelpFormatter()
-        self.assertTrue(f.format_epilog('hello'))
-        self.assertFalse(f.format_epilog(''))
+        assert f.format_epilog('hello')
+        assert not f.format_epilog('')
 
     def test_format_description(self):
         f = HelpFormatter()
-        self.assertTrue(f.format_description('hello'))
+        assert f.format_description('hello')
 
 
-class test_Command(AppCase):
+class test_Command:
 
     def test_get_options(self):
         cmd = Command()
         cmd.option_list = (1, 2, 3)
-        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
+        assert cmd.get_options() == (1, 2, 3)
 
     def test_custom_description(self):
 
@@ -87,12 +86,12 @@ class test_Command(AppCase):
             description = 'foo'
 
         c = C()
-        self.assertEqual(c.description, 'foo')
+        assert c.description == 'foo'
 
     def test_register_callbacks(self):
         c = Command(on_error=8, on_usage_error=9)
-        self.assertEqual(c.on_error, 8)
-        self.assertEqual(c.on_usage_error, 9)
+        assert c.on_error == 8
+        assert c.on_usage_error == 9
 
     def test_run_raises_UsageError(self):
         cb = Mock()
@@ -101,7 +100,7 @@ class test_Command(AppCase):
         c.run = Mock()
         exc = c.run.side_effect = c.UsageError('foo', status=3)
 
-        self.assertEqual(c(), exc.status)
+        assert c() == exc.status
         cb.assert_called_with(exc)
         c.verify_args.assert_called_with(())
 
@@ -119,238 +118,226 @@ class test_Command(AppCase):
             pass
         c.run = run
 
-        with self.assertRaises(c.UsageError):
+        with pytest.raises(c.UsageError):
             c.verify_args((1,))
         c.verify_args((1, 2, 3))
 
     def test_run_interface(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             Command().run()
 
     @patch('sys.stdout')
     def test_early_version(self, stdout):
         cmd = Command()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             cmd.early_version(['--version'])
 
-    def test_execute_from_commandline(self):
-        cmd = MockCommand(app=self.app)
+    def test_execute_from_commandline(self, app):
+        cmd = MockCommand(app=app)
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
-        self.assertTupleEqual(args1, cmd.mock_args)
-        self.assertDictContainsSubset({'foo': 'bar'}, kwargs1)
-        self.assertTrue(kwargs1.get('prog_name'))
+        assert args1 == cmd.mock_args
+        assert kwargs1['foo'] == 'bar'
+        assert kwargs1.get('prog_name')
         args2, kwargs2 = cmd.execute_from_commandline(['foo'])   # pass list
-        self.assertTupleEqual(args2, cmd.mock_args)
-        self.assertDictContainsSubset({'foo': 'bar', 'prog_name': 'foo'},
-                                      kwargs2)
-
-    @mock.stdouts
-    def test_with_bogus_args(self, _, stderr):
-        cmd = MockCommand(app=self.app)
-        cmd.supports_args = False
-        with self.assertRaises(SystemExit):
-            cmd.execute_from_commandline(argv=['--bogus'])
-        self.assertTrue(stderr.getvalue())
-        self.assertIn('Unrecognized', stderr.getvalue())
-
-    def test_with_custom_config_module(self):
+        assert args2 == cmd.mock_args
+        assert kwargs2['foo'] == 'bar'
+        assert kwargs2['prog_name'] == 'foo'
+
+    def test_with_bogus_args(self, app):
+        with mock.stdouts() as (_, stderr):
+            cmd = MockCommand(app=app)
+            cmd.supports_args = False
+            with pytest.raises(SystemExit):
+                cmd.execute_from_commandline(argv=['--bogus'])
+            assert stderr.getvalue()
+            assert 'Unrecognized' in stderr.getvalue()
+
+    def test_with_custom_config_module(self, app):
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--config=foo.bar.baz'])
-            self.assertEqual(os.environ.get('CELERY_CONFIG_MODULE'),
-                             'foo.bar.baz')
+            assert os.environ.get('CELERY_CONFIG_MODULE') == 'foo.bar.baz'
         finally:
             if prev:
                 os.environ['CELERY_CONFIG_MODULE'] = prev
             else:
                 os.environ.pop('CELERY_CONFIG_MODULE', None)
 
-    def test_with_custom_broker(self):
+    def test_with_custom_broker(self, app):
         prev = os.environ.pop('CELERY_BROKER_URL', None)
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--broker=xyzza://'])
-            self.assertEqual(
-                os.environ.get('CELERY_BROKER_URL'), 'xyzza://',
-            )
+            assert os.environ.get('CELERY_BROKER_URL') == 'xyzza://'
         finally:
             if prev:
                 os.environ['CELERY_BROKER_URL'] = prev
             else:
                 os.environ.pop('CELERY_BROKER_URL', None)
 
-    def test_with_custom_app(self):
-        cmd = MockCommand(app=self.app)
-        app = '.'.join([__name__, 'APP'])
-        cmd.setup_app_from_commandline(['--app=%s' % (app,),
+    def test_with_custom_app(self, app):
+        cmd = MockCommand(app=app)
+        appstr = '.'.join([__name__, 'APP'])
+        cmd.setup_app_from_commandline(['--app=%s' % (appstr,),
                                         '--loglevel=INFO'])
-        self.assertIs(cmd.app, APP)
-        cmd.setup_app_from_commandline(['-A', app,
+        assert cmd.app is APP
+        cmd.setup_app_from_commandline(['-A', appstr,
                                         '--loglevel=INFO'])
-        self.assertIs(cmd.app, APP)
+        assert cmd.app is APP
 
-    def test_setup_app_sets_quiet(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_sets_quiet(self, app):
+        cmd = MockCommand(app=app)
         cmd.setup_app_from_commandline(['-q'])
-        self.assertTrue(cmd.quiet)
-        cmd2 = MockCommand(app=self.app)
+        assert cmd.quiet
+        cmd2 = MockCommand(app=app)
         cmd2.setup_app_from_commandline(['--quiet'])
-        self.assertTrue(cmd2.quiet)
+        assert cmd2.quiet
 
-    def test_setup_app_sets_chdir(self):
+    def test_setup_app_sets_chdir(self, app):
         with patch('os.chdir') as chdir:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--workdir=/opt'])
             chdir.assert_called_with('/opt')
 
-    def test_setup_app_sets_loader(self):
+    def test_setup_app_sets_loader(self, app):
         prev = os.environ.get('CELERY_LOADER')
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--loader=X.Y:Z'])
-            self.assertEqual(os.environ['CELERY_LOADER'], 'X.Y:Z')
+            assert os.environ['CELERY_LOADER'] == 'X.Y:Z'
         finally:
             if prev is not None:
                 os.environ['CELERY_LOADER'] = prev
 
-    def test_setup_app_no_respect(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_no_respect(self, app):
+        cmd = MockCommand(app=app)
         cmd.respects_app_option = False
         with patch('celery.bin.base.Celery') as cp:
             cmd.setup_app_from_commandline(['--app=x.y:z'])
             cp.assert_called()
 
-    def test_setup_app_custom_app(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_custom_app(self, app):
+        cmd = MockCommand(app=app)
         app = cmd.app = Mock()
         app.user_options = {'preload': None}
         cmd.setup_app_from_commandline([])
-        self.assertEqual(cmd.app, app)
-
-    def test_find_app_suspects(self):
-        cmd = MockCommand(app=self.app)
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj:hello'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.hello'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app:app'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app.app'))
-        with self.assertRaises(AttributeError):
-            cmd.find_app('celery.tests.bin')
-
-        with self.assertRaises(AttributeError):
+        assert cmd.app == app
+
+    def test_find_app_suspects(self, app):
+        cmd = MockCommand(app=app)
+        assert cmd.find_app('t.unit.bin.proj.app')
+        assert cmd.find_app('t.unit.bin.proj')
+        assert cmd.find_app('t.unit.bin.proj:hello')
+        assert cmd.find_app('t.unit.bin.proj.hello')
+        assert cmd.find_app('t.unit.bin.proj.app:app')
+        assert cmd.find_app('t.unit.bin.proj.app.app')
+        with pytest.raises(AttributeError):
+            cmd.find_app('t.unit.bin')
+
+        with pytest.raises(AttributeError):
             cmd.find_app(__name__)
 
-    def test_ask(self):
+    def test_ask(self, app, patching):
         try:
-            input = self.patch('celery.bin.base.input')
+            input = patching('celery.bin.base.input')
         except AttributeError:
-            input = self.patch('builtins.input')
-        cmd = MockCommand(app=self.app)
+            input = patching('builtins.input')
+        cmd = MockCommand(app=app)
         input.return_value = 'yes'
-        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'yes')
+        assert cmd.ask('q', ('yes', 'no'), 'no') == 'yes'
         input.return_value = 'nop'
-        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'no')
+        assert cmd.ask('q', ('yes', 'no'), 'no') == 'no'
 
-    def test_host_format(self):
-        cmd = MockCommand(app=self.app)
+    def test_host_format(self, app):
+        cmd = MockCommand(app=app)
         with patch('celery.utils.nodenames.gethostname') as hn:
             hn.return_value = 'blacktron.example.com'
-            self.assertEqual(cmd.host_format(''), '')
-            self.assertEqual(
-                cmd.host_format('celery@%h'),
-                'celery@blacktron.example.com',
-            )
-            self.assertEqual(
-                cmd.host_format('celery@%d'),
-                'celery@example.com',
-            )
-            self.assertEqual(
-                cmd.host_format('celery@%n'),
-                'celery@blacktron',
-            )
-
-    def test_say_chat_quiet(self):
-        cmd = MockCommand(app=self.app)
+            assert cmd.host_format('') == ''
+            assert (cmd.host_format('celery@%h') ==
+                    'celery@blacktron.example.com')
+            assert cmd.host_format('celery@%d') == 'celery@example.com'
+            assert cmd.host_format('celery@%n') == 'celery@blacktron'
+
+    def test_say_chat_quiet(self, app):
+        cmd = MockCommand(app=app)
         cmd.quiet = True
-        self.assertIsNone(cmd.say_chat('<-', 'foo', 'foo'))
+        assert cmd.say_chat('<-', 'foo', 'foo') is None
 
-    def test_say_chat_show_body(self):
-        cmd = MockCommand(app=self.app)
+    def test_say_chat_show_body(self, app):
+        cmd = MockCommand(app=app)
         cmd.out = Mock()
         cmd.show_body = True
         cmd.say_chat('->', 'foo', 'body')
         cmd.out.assert_called_with('body')
 
-    def test_say_chat_no_body(self):
-        cmd = MockCommand(app=self.app)
+    def test_say_chat_no_body(self, app):
+        cmd = MockCommand(app=app)
         cmd.out = Mock()
         cmd.show_body = False
         cmd.say_chat('->', 'foo', 'body')
 
-    @depends_on_current_app
-    def test_with_cmdline_config(self):
-        cmd = MockCommand(app=self.app)
+    @pytest.mark.usefixtures('depends_on_current_app')
+    def test_with_cmdline_config(self, app):
+        cmd = MockCommand(app=app)
         cmd.enable_config_from_cmdline = True
         cmd.namespace = 'worker'
         rest = cmd.setup_app_from_commandline(argv=[
             '--loglevel=INFO', '--',
             'broker.url=amqp://broker.example.com',
             '.prefetch_multiplier=100'])
-        self.assertEqual(cmd.app.conf.broker_url,
-                         'amqp://broker.example.com')
-        self.assertEqual(cmd.app.conf.worker_prefetch_multiplier, 100)
-        self.assertListEqual(rest, ['--loglevel=INFO'])
+        assert cmd.app.conf.broker_url == 'amqp://broker.example.com'
+        assert cmd.app.conf.worker_prefetch_multiplier == 100
+        assert rest == ['--loglevel=INFO']
 
         cmd.app = None
         cmd.get_app = Mock(name='get_app')
-        cmd.get_app.return_value = self.app
-        self.app.user_options['preload'] = [
+        cmd.get_app.return_value = app
+        app.user_options['preload'] = [
             Option('--foo', action='store_true'),
         ]
         cmd.setup_app_from_commandline(argv=[
             '--foo', '--loglevel=INFO', '--',
             'broker.url=amqp://broker.example.com',
             '.prefetch_multiplier=100'])
-        self.assertIs(cmd.app, cmd.get_app())
+        assert cmd.app is cmd.get_app()
 
-    def test_preparse_options__required_short(self):
-        cmd = MockCommand(app=self.app)
-        with self.assertRaises(ValueError):
+    def test_preparse_options__required_short(self, app):
+        cmd = MockCommand(app=app)
+        with pytest.raises(ValueError):
             cmd.preparse_options(
                 ['a', '-f'], [Option('-f', action='store')])
 
-    def test_preparse_options__longopt_whitespace(self):
-        cmd = MockCommand(app=self.app)
+    def test_preparse_options__longopt_whitespace(self, app):
+        cmd = MockCommand(app=app)
         cmd.preparse_options(
             ['a', '--foo', 'val'], [Option('--foo', action='store')])
 
-    def test_preparse_options__shortopt_store_true(self):
-        cmd = MockCommand(app=self.app)
+    def test_preparse_options__shortopt_store_true(self, app):
+        cmd = MockCommand(app=app)
         cmd.preparse_options(
             ['a', '--foo'], [Option('--foo', action='store_true')])
 
-    def test_get_default_app(self):
-        self.patch('celery._state.get_current_app')
-        cmd = MockCommand(app=self.app)
+    def test_get_default_app(self, app, patching):
+        patching('celery._state.get_current_app')
+        cmd = MockCommand(app=app)
         from celery._state import get_current_app
-        self.assertIs(cmd._get_default_app(), get_current_app())
+        assert cmd._get_default_app() is get_current_app()
 
-    def test_set_colored(self):
-        cmd = MockCommand(app=self.app)
+    def test_set_colored(self, app):
+        cmd = MockCommand(app=app)
         cmd.colored = 'foo'
-        self.assertEqual(cmd.colored, 'foo')
+        assert cmd.colored == 'foo'
 
-    def test_set_no_color(self):
-        cmd = MockCommand(app=self.app)
+    def test_set_no_color(self, app):
+        cmd = MockCommand(app=app)
         cmd.no_color = False
         _ = cmd.colored  # noqa
         cmd.no_color = True
-        self.assertFalse(cmd.colored.enabled)
+        assert not cmd.colored.enabled
 
-    def test_find_app(self):
-        cmd = MockCommand(app=self.app)
+    def test_find_app(self, app):
+        cmd = MockCommand(app=app)
         with patch('celery.utils.imports.symbol_by_name') as sbn:
             from types import ModuleType
             x = ModuleType(bytes_if_py2('proj'))
@@ -365,13 +352,13 @@ class test_Command(AppCase):
                 return x
             sbn.side_effect = on_sbn
             x.__path__ = [True]
-            self.assertEqual(cmd.find_app('proj'), 'quick brown fox')
+            assert cmd.find_app('proj') == 'quick brown fox'
 
     def test_parse_preload_options_shortopt(self):
         cmd = Command()
         cmd.preload_options = (Option('-s', action='store', dest='silent'),)
         acc = cmd.parse_preload_options(['-s', 'yes'])
-        self.assertEqual(acc.get('silent'), 'yes')
+        assert acc.get('silent') == 'yes'
 
     def test_parse_preload_options_with_equals_and_append(self):
         cmd = Command()
@@ -379,7 +366,7 @@ class test_Command(AppCase):
         cmd.preload_options = (opt,)
         acc = cmd.parse_preload_options(['--zoom=1', '--zoom=2'])
 
-        self.assertEqual(acc, {'zoom': ['1', '2']})
+        assert acc, {'zoom': ['1' == '2']}
 
     def test_parse_preload_options_without_equals_and_append(self):
         cmd = Command()
@@ -387,4 +374,4 @@ class test_Command(AppCase):
         cmd.preload_options = (opt,)
         acc = cmd.parse_preload_options(['--zoom', '1', '--zoom', '2'])
 
-        self.assertEqual(acc, {'zoom': ['1', '2']})
+        assert acc, {'zoom': ['1' == '2']}

+ 20 - 20
celery/tests/bin/test_beat.py → t/unit/bin/test_beat.py

@@ -1,15 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
 import logging
+import pytest
 import sys
 
+from case import Mock, mock, patch
+
 from celery import beat
 from celery import platforms
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 
-from celery.tests.case import AppCase, Mock, mock, patch
-
 
 def MockBeat(*args, **kwargs):
     class _Beat(beatapp.Beat):
@@ -23,16 +24,16 @@ def MockBeat(*args, **kwargs):
     return b
 
 
-class test_Beat(AppCase):
+class test_Beat:
 
     def test_loglevel_string(self):
         b = beatapp.Beat(app=self.app, loglevel='DEBUG',
                          redirect_stdouts=False)
-        self.assertEqual(b.loglevel, logging.DEBUG)
+        assert b.loglevel == logging.DEBUG
 
         b2 = beatapp.Beat(app=self.app, loglevel=logging.DEBUG,
                           redirect_stdouts=False)
-        self.assertEqual(b2.loglevel, logging.DEBUG)
+        assert b2.loglevel == logging.DEBUG
 
     def test_colorize(self):
         self.app.log.setup = Mock()
@@ -40,7 +41,7 @@ class test_Beat(AppCase):
                          redirect_stdouts=False)
         b.setup_logging()
         self.app.log.setup.assert_called()
-        self.assertEqual(self.app.log.setup.call_args[1]['colorize'], False)
+        assert not self.app.log.setup.call_args[1]['colorize']
 
     def test_init_loader(self):
         b = beatapp.Beat(app=self.app, redirect_stdouts=False)
@@ -78,7 +79,7 @@ class test_Beat(AppCase):
         clock.start = Mock(name='beat.Service().start')
         clock.sync = Mock(name='beat.Service().sync')
         handlers = self.psig(b.install_sync_handler, clock)
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             handlers['SIGINT']('SIGINT', object())
         clock.sync.assert_called_with()
 
@@ -93,40 +94,39 @@ class test_Beat(AppCase):
         b.redirect_stdouts = False
         b.app.log.already_setup = False
         b.setup_logging()
-        with self.assertRaises(AttributeError):
+        with pytest.raises(AttributeError):
             sys.stdout.logger
 
     import sys
     orig_stdout = sys.__stdout__
 
     @patch('celery.apps.beat.logger')
-    @mock.restore_logging()
-    @mock.stdouts
-    def test_logs_errors(self, logger, stdout, stderr):
+    def test_logs_errors(self, logger):
         b = MockBeat(
             app=self.app, redirect_stdouts=False, socket_timeout=None,
         )
         b.install_sync_handler = Mock('beat.install_sync_handler')
         b.install_sync_handler.side_effect = RuntimeError('xxx')
-        with self.assertRaises(RuntimeError):
-            b.start_scheduler()
+        with mock.restore_logging():
+            with pytest.raises(RuntimeError):
+                b.start_scheduler()
         logger.critical.assert_called()
 
     @patch('celery.platforms.create_pidlock')
-    @mock.stdouts
-    def test_using_pidfile(self, create_pidlock, stdout, stderr):
+    def test_using_pidfile(self, create_pidlock):
         b = MockBeat(app=self.app, pidfile='pidfilelockfilepid',
                      socket_timeout=None, redirect_stdouts=False)
         b.install_sync_handler = Mock(name='beat.install_sync_handler')
-        b.start_scheduler()
+        with mock.stdouts():
+            b.start_scheduler()
         create_pidlock.assert_called()
 
 
-class test_div(AppCase):
+class test_div:
 
     def setup(self):
-        self.Beat = self.app.Beat = self.patch('celery.apps.beat.Beat')
-        self.detached = self.patch('celery.bin.beat.detached')
+        self.Beat = self.app.Beat = self.patching('celery.apps.beat.Beat')
+        self.detached = self.patching('celery.bin.beat.detached')
         self.Beat.__name__ = 'Beat'
 
     def test_main(self):
@@ -144,4 +144,4 @@ class test_div(AppCase):
         cmd = beat_bin.beat()
         cmd.app = self.app
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
-        self.assertEqual(options.schedule, 'foo')
+        assert options.schedule == 'foo'

+ 104 - 116
celery/tests/bin/test_celery.py → t/unit/bin/test_celery.py

@@ -1,9 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
 import sys
 
 from datetime import datetime
 
+from case import Mock, patch
 from kombu.utils.json import dumps
 
 from celery import __main__
@@ -31,10 +33,8 @@ from celery.bin.celery import (
 from celery.five import WhateverIO
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 
-from celery.tests.case import AppCase, Mock, patch
 
-
-class test__main__(AppCase):
+class test__main__:
 
     def test_main(self):
         with patch('celery.__main__.maybe_patch_concurrency') as mpc:
@@ -55,13 +55,13 @@ class test__main__(AppCase):
                     sys.argv = prev
 
 
-class test_Command(AppCase):
+class test_Command:
 
     def test_Error_repr(self):
         x = Error('something happened')
-        self.assertIsNotNone(x.status)
-        self.assertTrue(x.reason)
-        self.assertTrue(str(x))
+        assert x.status is not None
+        assert x.reason
+        assert str(x)
 
     def setup(self):
         self.out = WhateverIO()
@@ -83,58 +83,52 @@ class test_Command(AppCase):
             pass
 
         self.cmd.run = ok_run
-        self.assertEqual(self.cmd(), EX_OK)
+        assert self.cmd() == EX_OK
 
         def error_run():
             raise Error('error', EX_FAILURE)
         self.cmd.run = error_run
-        self.assertEqual(self.cmd(), EX_FAILURE)
+        assert self.cmd() == EX_FAILURE
 
     def test_run_from_argv(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.cmd.run_from_argv('prog', ['foo', 'bar'])
 
     def test_pretty_list(self):
-        self.assertEqual(self.cmd.pretty([])[1], '- empty -')
-        self.assertIn('bar', self.cmd.pretty(['foo', 'bar'])[1])
+        assert self.cmd.pretty([])[1] == '- empty -'
+        assert 'bar', self.cmd.pretty(['foo' in 'bar'][1])
 
-    def test_pretty_dict(self):
-        self.assertIn(
-            'OK',
-            str(self.cmd.pretty({'ok': 'the quick brown fox'})[0]),
-        )
-        self.assertIn(
-            'ERROR',
-            str(self.cmd.pretty({'error': 'the quick brown fox'})[0]),
-        )
+    def test_pretty_dict(self, text='the quick brown fox'):
+        assert 'OK' in str(self.cmd.pretty({'ok': text})[0])
+        assert 'ERROR' in str(self.cmd.pretty({'error': text})[0])
 
     def test_pretty(self):
-        self.assertIn('OK', str(self.cmd.pretty('the quick brown')))
-        self.assertIn('OK', str(self.cmd.pretty(object())))
-        self.assertIn('OK', str(self.cmd.pretty({'foo': 'bar'})))
+        assert 'OK' in str(self.cmd.pretty('the quick brown'))
+        assert 'OK' in str(self.cmd.pretty(object()))
+        assert 'OK' in str(self.cmd.pretty({'foo': 'bar'}))
 
 
-class test_list(AppCase):
+class test_list:
 
     def test_list_bindings_no_support(self):
         l = list_(app=self.app, stderr=WhateverIO())
         management = Mock()
         management.get_bindings.side_effect = NotImplementedError()
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.list_bindings(management)
 
     def test_run(self):
         l = list_(app=self.app, stderr=WhateverIO())
         l.run('bindings')
 
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.run(None)
 
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.run('foo')
 
 
-class test_call(AppCase):
+class test_call:
 
     def setup(self):
 
@@ -152,22 +146,22 @@ class test_call(AppCase):
         a.run(self.add.name,
               args=dumps([4, 4]),
               kwargs=dumps({'x': 2, 'y': 2}))
-        self.assertEqual(send_task.call_args[1]['args'], [4, 4])
-        self.assertEqual(send_task.call_args[1]['kwargs'], {'x': 2, 'y': 2})
+        assert send_task.call_args[1]['args'], [4 == 4]
+        assert send_task.call_args[1]['kwargs'] == {'x': 2, 'y': 2}
 
         a.run(self.add.name, expires=10, countdown=10)
-        self.assertEqual(send_task.call_args[1]['expires'], 10)
-        self.assertEqual(send_task.call_args[1]['countdown'], 10)
+        assert send_task.call_args[1]['expires'] == 10
+        assert send_task.call_args[1]['countdown'] == 10
 
         now = datetime.now()
         iso = now.isoformat()
         a.run(self.add.name, expires=iso)
-        self.assertEqual(send_task.call_args[1]['expires'], now)
-        with self.assertRaises(ValueError):
+        assert send_task.call_args[1]['expires'] == now
+        with pytest.raises(ValueError):
             a.run(self.add.name, expires='foobaribazibar')
 
 
-class test_purge(AppCase):
+class test_purge:
 
     def test_run(self):
         out = WhateverIO()
@@ -175,11 +169,11 @@ class test_purge(AppCase):
         a._purge = Mock(name='_purge')
         a._purge.return_value = 0
         a.run(force=True)
-        self.assertIn('No messages purged', out.getvalue())
+        assert 'No messages purged' in out.getvalue()
 
         a._purge.return_value = 100
         a.run(force=True)
-        self.assertIn('100 messages', out.getvalue())
+        assert '100 messages' in out.getvalue()
 
         a.out = Mock(name='out')
         a.ask = Mock(name='ask')
@@ -189,7 +183,7 @@ class test_purge(AppCase):
         a.run(force=False)
 
 
-class test_result(AppCase):
+class test_result:
 
     def setup(self):
 
@@ -204,18 +198,18 @@ class test_result(AppCase):
             r = result(app=self.app, stdout=out)
             get.return_value = 'Jerry'
             r.run('id')
-            self.assertIn('Jerry', out.getvalue())
+            assert 'Jerry' in out.getvalue()
 
             get.return_value = 'Elaine'
             r.run('id', task=self.add.name)
-            self.assertIn('Elaine', out.getvalue())
+            assert 'Elaine' in out.getvalue()
 
             with patch('celery.result.AsyncResult.traceback') as tb:
                 r.run('id', task=self.add.name, traceback=True)
-                self.assertIn(str(tb), out.getvalue())
+                assert str(tb) in out.getvalue()
 
 
-class test_status(AppCase):
+class test_status:
 
     @patch('celery.bin.celery.inspect')
     def test_run(self, inspect_):
@@ -223,22 +217,22 @@ class test_status(AppCase):
         ins = inspect_.return_value = Mock()
         ins.run.return_value = []
         s = status(self.app, stdout=out, stderr=err)
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             s.run()
 
         ins.run.return_value = ['a', 'b', 'c']
         s.run()
-        self.assertIn('3 nodes online', out.getvalue())
+        assert '3 nodes online' in out.getvalue()
         s.run(quiet=True)
 
 
-class test_migrate(AppCase):
+class test_migrate:
 
     @patch('celery.contrib.migrate.migrate_tasks')
     def test_run(self, migrate_tasks):
         out = WhateverIO()
         m = migrate(app=self.app, stdout=out, stderr=WhateverIO())
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             m.run()
         migrate_tasks.assert_not_called()
 
@@ -249,61 +243,61 @@ class test_migrate(AppCase):
         state.count = 10
         state.strtotal = 30
         m.on_migrate_task(state, {'task': 'tasks.add', 'id': 'ID'}, None)
-        self.assertIn('10/30', out.getvalue())
+        assert '10/30' in out.getvalue()
 
 
-class test_report(AppCase):
+class test_report:
 
     def test_run(self):
         out = WhateverIO()
         r = report(app=self.app, stdout=out)
-        self.assertEqual(r.run(), EX_OK)
-        self.assertTrue(out.getvalue())
+        assert r.run() == EX_OK
+        assert out.getvalue()
 
 
-class test_help(AppCase):
+class test_help:
 
     def test_run(self):
         out = WhateverIO()
         h = help(app=self.app, stdout=out)
         h.parser = Mock()
-        self.assertEqual(h.run(), EX_USAGE)
-        self.assertTrue(out.getvalue())
-        self.assertTrue(h.usage('help'))
+        assert h.run() == EX_USAGE
+        assert out.getvalue()
+        assert h.usage('help')
         h.parser.print_help.assert_called_with()
 
 
-class test_CeleryCommand(AppCase):
+class test_CeleryCommand:
 
     def test_execute_from_commandline(self):
         x = CeleryCommand(app=self.app)
         x.handle_argv = Mock()
         x.handle_argv.return_value = 1
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
 
         x.handle_argv.return_value = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
 
         x.handle_argv.side_effect = KeyboardInterrupt()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
 
         x.respects_app_option = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline(['celery', 'multi'])
-        self.assertFalse(x.respects_app_option)
+        assert not x.respects_app_option
         x.respects_app_option = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline(['manage.py', 'celery', 'multi'])
-        self.assertFalse(x.respects_app_option)
+        assert not x.respects_app_option
 
     def test_with_pool_option(self):
         x = CeleryCommand(app=self.app)
-        self.assertIsNone(x.with_pool_option(['celery', 'events']))
-        self.assertTrue(x.with_pool_option(['celery', 'worker']))
-        self.assertTrue(x.with_pool_option(['manage.py', 'celery', 'worker']))
+        assert x.with_pool_option(['celery', 'events']) is None
+        assert x.with_pool_option(['celery', 'worker'])
+        assert x.with_pool_option(['manage.py', 'celery', 'worker'])
 
     def test_load_extensions_no_commands(self):
         with patch('celery.bin.celery.Extensions') as Ext:
@@ -327,35 +321,33 @@ class test_CeleryCommand(AppCase):
                 mod.command_classes = prev
 
     def test_determine_exit_status(self):
-        self.assertEqual(determine_exit_status('true'), EX_OK)
-        self.assertEqual(determine_exit_status(''), EX_FAILURE)
+        assert determine_exit_status('true') == EX_OK
+        assert determine_exit_status('') == EX_FAILURE
 
     def test_relocate_args_from_start(self):
         x = CeleryCommand(app=self.app)
-        self.assertEqual(x._relocate_args_from_start(None), [])
-        self.assertEqual(
-            x._relocate_args_from_start(
-                ['-l', 'debug', 'worker', '-c', '3', '--foo'],
-            ),
-            ['worker', '-c', '3', '--foo', '-l', 'debug'],
-        )
-        self.assertEqual(
-            x._relocate_args_from_start(
-                ['--pool=gevent', '-l', 'debug', 'worker', '--foo', '-c', '3'],
-            ),
-            ['worker', '--foo', '-c', '3', '--pool=gevent', '-l', 'debug'],
-        )
-        self.assertEqual(
-            x._relocate_args_from_start(['foo', '--foo=1']),
-            ['foo', '--foo=1'],
-        )
+        assert x._relocate_args_from_start(None) == []
+        relargs1 = x._relocate_args_from_start([
+            '-l', 'debug', 'worker', '-c', '3', '--foo',
+        ])
+        assert relargs1 == ['worker', '-c', '3', '--foo', '-l', 'debug']
+        relargs2 = x._relocate_args_from_start([
+            '--pool=gevent', '-l', 'debug', 'worker', '--foo', '-c', '3',
+        ])
+        assert relargs2 == [
+            'worker', '--foo', '-c', '3',
+            '--pool=gevent', '-l', 'debug',
+        ]
+        assert x._relocate_args_from_start(['foo', '--foo=1']) == [
+            'foo', '--foo=1',
+        ]
 
     def test_register_command(self):
         prev, CeleryCommand.commands = dict(CeleryCommand.commands), {}
         try:
             fun = Mock(name='fun')
             CeleryCommand.register_command(fun, name='foo')
-            self.assertIs(CeleryCommand.commands['foo'], fun)
+            assert CeleryCommand.commands['foo'] is fun
         finally:
             CeleryCommand.commands = prev
 
@@ -411,46 +403,42 @@ class test_CeleryCommand(AppCase):
         main = Mock(name='__main__')
         main.__file__ = '/opt/foo.py'
         with patch.dict(sys.modules, __main__=main):
-            self.assertEqual(x.prepare_prog_name('__main__.py'), '/opt/foo.py')
-            self.assertEqual(x.prepare_prog_name('celery'), 'celery')
+            assert x.prepare_prog_name('__main__.py') == '/opt/foo.py'
+            assert x.prepare_prog_name('celery') == 'celery'
 
 
-class test_RemoteControl(AppCase):
+class test_RemoteControl:
 
     def test_call_interface(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             _RemoteControl(app=self.app).call()
 
 
-class test_inspect(AppCase):
+class test_inspect:
 
     def test_usage(self):
-        self.assertTrue(inspect(app=self.app).usage('foo'))
+        assert inspect(app=self.app).usage('foo')
 
     def test_command_info(self):
         i = inspect(app=self.app)
-        self.assertTrue(i.get_command_info(
+        assert i.get_command_info(
             'ping', help=True, color=i.colored.red, app=self.app,
-        ))
+        )
 
     def test_list_commands_color(self):
         i = inspect(app=self.app)
-        self.assertTrue(i.list_commands(
-            help=True, color=i.colored.red, app=self.app,
-        ))
-        self.assertTrue(i.list_commands(
-            help=False, color=None, app=self.app,
-        ))
+        assert i.list_commands(help=True, color=i.colored.red, app=self.app)
+        assert i.list_commands(help=False, color=None, app=self.app)
 
     def test_epilog(self):
-        self.assertTrue(inspect(app=self.app).epilog)
+        assert inspect(app=self.app).epilog
 
     def test_do_call_method_sql_transport_type(self):
         self.app.connection = Mock()
         conn = self.app.connection.return_value = Mock(name='Connection')
         conn.transport.driver_type = 'sql'
         i = inspect(app=self.app)
-        with self.assertRaises(i.Error):
+        with pytest.raises(i.Error):
             i.do_call_method(['ping'])
 
     def test_say_directions(self):
@@ -472,22 +460,22 @@ class test_inspect(AppCase):
     def test_run(self, real):
         out = WhateverIO()
         i = inspect(app=self.app, stdout=out)
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run()
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('help')
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('xyzzybaz')
 
         i.run('ping')
         real.assert_called()
         i.run('ping', destination='foo,bar')
-        self.assertEqual(real.call_args[1]['destination'], ['foo', 'bar'])
-        self.assertEqual(real.call_args[1]['timeout'], 0.2)
+        assert real.call_args[1]['destination'], ['foo' == 'bar']
+        assert real.call_args[1]['timeout'] == 0.2
         callback = real.call_args[1]['callback']
 
         callback({'foo': {'ok': 'pong'}})
-        self.assertIn('OK', out.getvalue())
+        assert 'OK' in out.getvalue()
 
         with patch('celery.bin.celery.dumps') as dumps:
             i.run('ping', json=True)
@@ -495,17 +483,17 @@ class test_inspect(AppCase):
 
         instance = real.return_value = Mock()
         instance._request.return_value = None
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('ping')
 
         out.seek(0)
         out.truncate()
         i.quiet = True
         i.say_chat('<-', 'hello')
-        self.assertFalse(out.getvalue())
+        assert not out.getvalue()
 
 
-class test_control(AppCase):
+class test_control:
 
     def control(self, patch_call, *args, **kwargs):
         kwargs.setdefault('app', Mock(name='app'))
@@ -521,10 +509,10 @@ class test_control(AppCase):
             'foo', arguments={'kw': 2}, reply=True)
 
 
-class test_multi(AppCase):
+class test_multi:
 
     def test_get_options(self):
-        self.assertIsNone(multi(app=self.app).get_options())
+        assert multi(app=self.app).get_options() is None
 
     def test_run_from_argv(self):
         with patch('celery.bin.multi.MultiTool') as MultiTool:
@@ -535,7 +523,7 @@ class test_multi(AppCase):
             )
 
 
-class test_main(AppCase):
+class test_main:
 
     @patch('celery.bin.celery.CeleryCommand')
     def test_main(self, Command):
@@ -551,11 +539,11 @@ class test_main(AppCase):
         cmd.execute_from_commandline.assert_called_with(None)
 
 
-class test_compat(AppCase):
+class test_compat:
 
     def test_compat_command_decorator(self):
         with patch('celery.bin.celery.CeleryCommand') as CC:
-            self.assertEqual(command(), CC.register_command)
+            assert command() == CC.register_command
             fun = Mock(name='fun')
             command(fun)
             CC.register_command.assert_called_with(fun)

+ 27 - 21
celery/tests/bin/test_celeryd_detach.py → t/unit/bin/test_celeryd_detach.py

@@ -1,5 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
+from case import Mock, mock, patch
+
 from celery.platforms import IS_WINDOWS
 from celery.bin.celeryd_detach import (
     detach,
@@ -7,11 +11,9 @@ from celery.bin.celeryd_detach import (
     main,
 )
 
-from celery.tests.case import AppCase, Mock, mock, patch
-
 
 if not IS_WINDOWS:
-    class test_detached(AppCase):
+    class test_detached:
 
         @patch('celery.bin.celeryd_detach.detached')
         @patch('os.execv')
@@ -44,9 +46,9 @@ if not IS_WINDOWS:
             logger.critical.assert_called()
             setup_logs.assert_called_with(
                 'ERROR', '/var/log', hostname='foo@example.com')
-            self.assertEqual(r, 1)
+            assert r == 1
 
-            self.patch('celery.current_app')
+            self.patching('celery.current_app')
             from celery import current_app
             r = detach(
                 '/bin/boo', ['a', 'b', 'c'],
@@ -57,7 +59,7 @@ if not IS_WINDOWS:
             )
 
 
-class test_PartialOptionParser(AppCase):
+class test_PartialOptionParser:
 
     def test_parser(self):
         x = detached_celeryd(self.app)
@@ -66,23 +68,23 @@ class test_PartialOptionParser(AppCase):
             '--logfile=foo', '--fake', '--enable',
             'a', 'b', '-c1', '-d', '2',
         ])
-        self.assertEqual(options.logfile, 'foo')
-        self.assertEqual(values, ['a', 'b'])
-        self.assertEqual(p.leftovers, ['--enable', '-c1', '-d', '2'])
+        assert options.logfile == 'foo'
+        assert values, ['a' == 'b']
+        assert p.leftovers, ['--enable', '-c1', '-d' == '2']
         options, values = p.parse_args([
             '--fake', '--enable',
             '--pidfile=/var/pid/foo.pid',
             'a', 'b', '-c1', '-d', '2',
         ])
-        self.assertEqual(options.pidfile, '/var/pid/foo.pid')
+        assert options.pidfile == '/var/pid/foo.pid'
 
         with mock.stdouts():
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--logfile'])
             p.get_option('--logfile').nargs = 2
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--logfile=a'])
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--fake=abc'])
 
         assert p.get_option('--logfile').nargs == 2
@@ -90,18 +92,22 @@ class test_PartialOptionParser(AppCase):
         p.get_option('--logfile').nargs = 1
 
 
-class test_Command(AppCase):
-    argv = ['--foobar=10,2', '-c', '1',
-            '--logfile=/var/log', '-lDEBUG',
-            '--', '.disable_rate_limits=1']
+class test_Command:
+    argv = [
+        '--foobar=10,2', '-c', '1',
+        '--logfile=/var/log', '-lDEBUG',
+        '--', '.disable_rate_limits=1',
+    ]
 
     def test_parse_options(self):
         x = detached_celeryd(app=self.app)
         o, v, l = x.parse_options('cd', self.argv)
-        self.assertEqual(o.logfile, '/var/log')
-        self.assertEqual(l, ['--foobar=10,2', '-c', '1',
-                             '-lDEBUG', '--logfile=/var/log',
-                             '--pidfile=celeryd.pid'])
+        assert o.logfile == '/var/log'
+        assert l == [
+            '--foobar=10,2', '-c', '1',
+            '-lDEBUG', '--logfile=/var/log',
+            '--pidfile=celeryd.pid',
+        ]
         x.parse_options('cd', [])  # no args
 
     @patch('sys.exit')

+ 7 - 7
celery/tests/bin/test_celeryevdump.py → t/unit/bin/test_celeryevdump.py

@@ -2,6 +2,8 @@ from __future__ import absolute_import, unicode_literals
 
 from time import time
 
+from case import Mock, patch
+
 from celery.events.dumper import (
     humanize_type,
     Dumper,
@@ -9,23 +11,21 @@ from celery.events.dumper import (
 )
 from celery.five import WhateverIO
 
-from celery.tests.case import AppCase, Mock, patch
-
 
-class test_Dumper(AppCase):
+class test_Dumper:
 
     def setup(self):
         self.out = WhateverIO()
         self.dumper = Dumper(out=self.out)
 
     def test_humanize_type(self):
-        self.assertEqual(humanize_type('worker-offline'), 'shutdown')
-        self.assertEqual(humanize_type('task-started'), 'task started')
+        assert humanize_type('worker-offline') == 'shutdown'
+        assert humanize_type('task-started') == 'task started'
 
     def test_format_task_event(self):
         self.dumper.format_task_event(
             'worker@example.com', time(), 'task-started', 'tasks.add', {})
-        self.assertTrue(self.out.getvalue())
+        assert self.out.getvalue()
 
     def test_on_event(self):
         event = {
@@ -37,7 +37,7 @@ class test_Dumper(AppCase):
             'kwargs': '{}',
         }
         self.dumper.on_event(dict(event, type='task-received'))
-        self.assertTrue(self.out.getvalue())
+        assert self.out.getvalue()
         self.dumper.on_event(dict(event, type='task-revoked'))
         self.dumper.on_event(dict(event, type='worker-online'))
 

+ 35 - 14
celery/tests/bin/test_events.py → t/unit/bin/test_events.py

@@ -1,8 +1,29 @@
 from __future__ import absolute_import, unicode_literals
 
+import importlib
+
+from functools import wraps
+
+from case import patch, skip
+
 from celery.bin import events
 
-from celery.tests.case import AppCase, patch, _old_patch, skip
+
+def _old_patch(module, name, mocked):
+    module = importlib.import_module(module)
+
+    def _patch(fun):
+
+        @wraps(fun)
+        def __patched(*args, **kwargs):
+            prev = getattr(module, name)
+            setattr(module, name, mocked)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                setattr(module, name, prev)
+        return __patched
+    return _patch
 
 
 class MockCommand(object):
@@ -17,7 +38,7 @@ def proctitle(prog, info=None):
 proctitle.last = ()
 
 
-class test_events(AppCase):
+class test_events:
 
     def setup(self):
         self.ev = events.events(app=self.app)
@@ -26,8 +47,8 @@ class test_events(AppCase):
                 lambda **kw: 'me dumper, you?')
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     def test_run_dump(self):
-        self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
-        self.assertIn('celery events:dump', proctitle.last[0])
+        assert self.ev.run(dump=True), 'me dumper == you?'
+        assert 'celery events:dump' in proctitle.last[0]
 
     @skip.unless_module('curses', import_errors=(ImportError, OSError))
     def test_run_top(self):
@@ -35,8 +56,8 @@ class test_events(AppCase):
                     lambda **kw: 'me top, you?')
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)
         def _inner():
-            self.assertEqual(self.ev.run(), 'me top, you?')
-            self.assertIn('celery events:top', proctitle.last[0])
+            assert self.ev.run(), 'me top == you?'
+            assert 'celery events:top' in proctitle.last[0]
         return _inner()
 
     @_old_patch('celery.events.snapshot', 'evcam',
@@ -44,12 +65,12 @@ class test_events(AppCase):
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     def test_run_cam(self):
         a, kw = self.ev.run(camera='foo.bar.baz', logfile='logfile')
-        self.assertEqual(a[0], 'foo.bar.baz')
-        self.assertEqual(kw['freq'], 1.0)
-        self.assertIsNone(kw['maxrate'])
-        self.assertEqual(kw['loglevel'], 'INFO')
-        self.assertEqual(kw['logfile'], 'logfile')
-        self.assertIn('celery events:cam', proctitle.last[0])
+        assert a[0] == 'foo.bar.baz'
+        assert kw['freq'] == 1.0
+        assert kw['maxrate'] is None
+        assert kw['loglevel'] == 'INFO'
+        assert kw['logfile'] == 'logfile'
+        assert 'celery events:cam' in proctitle.last[0]
 
     @patch('celery.events.snapshot.evcam')
     @patch('celery.bin.events.detached')
@@ -60,10 +81,10 @@ class test_events(AppCase):
         evcam.assert_called()
 
     def test_get_options(self):
-        self.assertFalse(self.ev.get_options())
+        assert not self.ev.get_options()
 
     @_old_patch('celery.bin.events', 'events', MockCommand)
     def test_main(self):
         MockCommand.executed = []
         events.main()
-        self.assertTrue(MockCommand.executed)
+        assert MockCommand.executed

+ 63 - 84
celery/tests/bin/test_multi.py → t/unit/bin/test_multi.py

@@ -1,19 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
 import signal
 import sys
 
-from celery.bin.multi import (
-    main,
-    MultiTool,
-    __doc__ as doc,
-)
-from celery.five import WhateverIO
+from case import Mock, patch
 
-from celery.tests.case import AppCase, Mock, patch
+from celery.bin.multi import main, MultiTool, __doc__ as doc
+from celery.five import WhateverIO
 
 
-class test_MultiTool(AppCase):
+class test_MultiTool:
 
     def setup(self):
         self.fh = WhateverIO()
@@ -47,7 +44,7 @@ class test_MultiTool(AppCase):
     def assert_sig_argument(self, args, expected):
         p = self.t.OptionParser(args)
         p.parse()
-        self.assertEqual(self.t._find_sig_argument(p), expected)
+        assert self.t._find_sig_argument(p) == expected
 
     def test_execute_from_commandline(self):
         self.t.call_command = Mock(name='call_command')
@@ -55,51 +52,43 @@ class test_MultiTool(AppCase):
             'multi start --verbose 10 --foo'.split(),
             cmd='X',
         )
-        self.assertEqual(self.t.cmd, 'X')
-        self.assertEqual(self.t.prog_name, 'multi')
+        assert self.t.cmd == 'X'
+        assert self.t.prog_name == 'multi'
         self.t.call_command.assert_called_with('start', ['10', '--foo'])
 
     def test_execute_from_commandline__arguments(self):
-        self.assertTrue(self.t.execute_from_commandline('multi'.split()))
-        self.assertTrue(self.t.execute_from_commandline('multi -bar'.split()))
+        assert self.t.execute_from_commandline('multi'.split())
+        assert self.t.execute_from_commandline('multi -bar'.split())
 
     def test_call_command(self):
         cmd = self.t.commands['foo'] = Mock(name='foo')
         self.t.retcode = 303
-        self.assertIs(
-            self.t.call_command('foo', ['1', '2', '--foo=3']),
-            cmd.return_value,
-        )
+        assert (self.t.call_command('foo', ['1', '2', '--foo=3']) is
+                cmd.return_value)
         cmd.assert_called_with('1', '2', '--foo=3')
 
     def test_call_command__error(self):
-        self.assertEqual(
-            self.t.call_command('asdqwewqe', ['1', '2']),
-            1,
-        )
+        assert self.t.call_command('asdqwewqe', ['1', '2']) == 1
         self.t.carp.assert_called()
 
     def test_handle_reserved_options(self):
-        self.assertListEqual(
-            self.t._handle_reserved_options(
-                ['a', '-q', 'b', '--no-color', 'c']),
-            ['a', 'b', 'c'],
-        )
+        assert self.t._handle_reserved_options(
+            ['a', '-q', 'b', '--no-color', 'c']) == ['a', 'b', 'c']
 
     def test_start(self):
         self.cluster.start.return_value = [0, 0, 1, 0]
-        self.assertTrue(self.t.start('10', '-A', 'proj'))
+        assert self.t.start('10', '-A', 'proj')
         self.t.splash.assert_called_with()
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.cluster.start.assert_called_with()
 
     def test_start__exitcodes(self):
         self.cluster.start.return_value = [0, 0, 0]
-        self.assertFalse(self.t.start('foo', 'bar', 'baz'))
+        assert not self.t.start('foo', 'bar', 'baz')
         self.cluster.start.assert_called_with()
 
         self.cluster.start.return_value = [0, 1, 0]
-        self.assertTrue(self.t.start('foo', 'bar', 'baz'))
+        assert self.t.start('foo', 'bar', 'baz')
 
     def test_stop(self):
         self.t.stop('10', '-A', 'proj', retry=3)
@@ -130,17 +119,15 @@ class test_MultiTool(AppCase):
     def test_get(self):
         node = self.cluster.find.return_value = Mock(name='node')
         node.argv = ['A', 'B', 'C']
-        self.assertIs(
-            self.t.get('wanted', '10', '-A', 'proj'),
-            self.t.ok.return_value,
-        )
+        assert (self.t.get('wanted', '10', '-A', 'proj') is
+                self.t.ok.return_value)
         self.cluster.find.assert_called_with('wanted')
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.t.ok.assert_called_with(' '.join(node.argv))
 
     def test_get__KeyError(self):
         self.cluster.find.side_effect = KeyError()
-        self.assertTrue(self.t.get('wanted', '10', '-A', 'proj'))
+        assert self.t.get('wanted', '10', '-A', 'proj')
 
     def test_show(self):
         nodes = self.t.cluster_from_argv.return_value = [
@@ -150,10 +137,7 @@ class test_MultiTool(AppCase):
         nodes[0].argv_with_executable = ['python', 'foo', 'bar']
         nodes[1].argv_with_executable = ['python', 'xuzzy', 'baz']
 
-        self.assertIs(
-            self.t.show('10', '-A', 'proj'),
-            self.t.ok.return_value,
-        )
+        assert self.t.show('10', '-A', 'proj') is self.t.ok.return_value
         self.t.ok.assert_called_with(
             '\n'.join(' '.join(node.argv_with_executable) for node in nodes))
 
@@ -169,7 +153,7 @@ class test_MultiTool(AppCase):
         node1.expander.return_value = 'A'
         node2.expander.return_value = 'B'
         nodes = self.t.cluster_from_argv.return_value = [node1, node2]
-        self.assertIs(self.t.expand('%p', '10'), self.t.ok.return_value)
+        assert self.t.expand('%p', '10') is self.t.ok.return_value
         self.t.cluster_from_argv.assert_called_with(('10',))
         for node in nodes:
             node.expander.assert_called_with('%p')
@@ -196,26 +180,23 @@ class test_MultiTool(AppCase):
     def test_Cluster(self):
         m = MultiTool()
         c = m.cluster_from_argv(['A', 'B', 'C'])
-        self.assertIs(c.env, m.env)
-        self.assertEqual(c.cmd, 'celery worker')
-        self.assertEqual(c.on_stopping_preamble, m.on_stopping_preamble)
-        self.assertEqual(c.on_send_signal, m.on_send_signal)
-        self.assertEqual(c.on_still_waiting_for, m.on_still_waiting_for)
-        self.assertEqual(
-            c.on_still_waiting_progress,
-            m.on_still_waiting_progress,
-        )
-        self.assertEqual(c.on_still_waiting_end, m.on_still_waiting_end)
-        self.assertEqual(c.on_node_start, m.on_node_start)
-        self.assertEqual(c.on_node_restart, m.on_node_restart)
-        self.assertEqual(c.on_node_shutdown_ok, m.on_node_shutdown_ok)
-        self.assertEqual(c.on_node_status, m.on_node_status)
-        self.assertEqual(c.on_node_signal_dead, m.on_node_signal_dead)
-        self.assertEqual(c.on_node_signal, m.on_node_signal)
-        self.assertEqual(c.on_node_down, m.on_node_down)
-        self.assertEqual(c.on_child_spawn, m.on_child_spawn)
-        self.assertEqual(c.on_child_signalled, m.on_child_signalled)
-        self.assertEqual(c.on_child_failure, m.on_child_failure)
+        assert c.env is m.env
+        assert c.cmd == 'celery worker'
+        assert c.on_stopping_preamble == m.on_stopping_preamble
+        assert c.on_send_signal == m.on_send_signal
+        assert c.on_still_waiting_for == m.on_still_waiting_for
+        assert c.on_still_waiting_progress == m.on_still_waiting_progress
+        assert c.on_still_waiting_end == m.on_still_waiting_end
+        assert c.on_node_start == m.on_node_start
+        assert c.on_node_restart == m.on_node_restart
+        assert c.on_node_shutdown_ok == m.on_node_shutdown_ok
+        assert c.on_node_status == m.on_node_status
+        assert c.on_node_signal_dead == m.on_node_signal_dead
+        assert c.on_node_signal == m.on_node_signal
+        assert c.on_node_down == m.on_node_down
+        assert c.on_child_spawn == m.on_child_spawn
+        assert c.on_child_signalled == m.on_child_signalled
+        assert c.on_child_failure == m.on_child_failure
 
     def test_on_stopping_preamble(self):
         self.t.on_stopping_preamble([])
@@ -271,12 +252,12 @@ class test_MultiTool(AppCase):
         self.t.on_child_failure(Mock(), Mock())
 
     def test_constant_strings(self):
-        self.assertTrue(self.t.OK)
-        self.assertTrue(self.t.DOWN)
-        self.assertTrue(self.t.FAILED)
+        assert self.t.OK
+        assert self.t.DOWN
+        assert self.t.FAILED
 
 
-class test_MultiTool_functional(AppCase):
+class test_MultiTool_functional:
 
     def setup(self):
         self.fh = WhateverIO()
@@ -285,12 +266,12 @@ class test_MultiTool_functional(AppCase):
 
     def test_note(self):
         self.t.note('hello world')
-        self.assertEqual(self.fh.getvalue(), 'hello world\n')
+        assert self.fh.getvalue() == 'hello world\n'
 
     def test_note_quiet(self):
         self.t.quiet = True
         self.t.note('hello world')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
     def test_carp(self):
         self.t.say = Mock()
@@ -300,61 +281,59 @@ class test_MultiTool_functional(AppCase):
     def test_info(self):
         self.t.verbose = True
         self.t.info('hello info')
-        self.assertEqual(self.fh.getvalue(), 'hello info\n')
+        assert self.fh.getvalue() == 'hello info\n'
 
     def test_info_not_verbose(self):
         self.t.verbose = False
         self.t.info('hello info')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
     def test_error(self):
         self.t.carp = Mock()
         self.t.usage = Mock()
-        self.assertEqual(self.t.error('foo'), 1)
+        assert self.t.error('foo') == 1
         self.t.carp.assert_called_with('foo')
         self.t.usage.assert_called_with()
 
         self.t.carp = Mock()
-        self.assertEqual(self.t.error(), 1)
+        assert self.t.error() == 1
         self.t.carp.assert_not_called()
 
     def test_nosplash(self):
         self.t.nosplash = True
         self.t.splash()
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
     def test_splash(self):
         self.t.nosplash = False
         self.t.splash()
-        self.assertIn('celery multi', self.fh.getvalue())
+        assert 'celery multi' in self.fh.getvalue()
 
     def test_usage(self):
         self.t.usage()
-        self.assertTrue(self.fh.getvalue())
+        assert self.fh.getvalue()
 
     def test_help(self):
         self.t.help([])
-        self.assertIn(doc, self.fh.getvalue())
+        assert doc in self.fh.getvalue()
 
     def test_expand(self):
         self.t.expand('foo%n', 'ask', 'klask', 'dask')
-        self.assertEqual(
-            self.fh.getvalue(), 'fooask\nfooklask\nfoodask\n',
-        )
+        assert self.fh.getvalue() == 'fooask\nfooklask\nfoodask\n'
 
     @patch('celery.apps.multi.gethostname')
     def test_get(self, gethostname):
         gethostname.return_value = 'e.com'
         self.t.get('xuzzy@e.com', 'foo', 'bar', 'baz')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
         self.t.get('foo@e.com', 'foo', 'bar', 'baz')
-        self.assertTrue(self.fh.getvalue())
+        assert self.fh.getvalue()
 
     @patch('celery.apps.multi.gethostname')
     def test_names(self, gethostname):
         gethostname.return_value = 'e.com'
         self.t.names('foo', 'bar', 'baz')
-        self.assertIn('foo@e.com\nbar@e.com\nbaz@e.com', self.fh.getvalue())
+        assert 'foo@e.com\nbar@e.com\nbaz@e.com' in self.fh.getvalue()
 
     def test_execute_from_commandline(self):
         start = self.t.commands['start'] = Mock()
@@ -379,14 +358,14 @@ class test_MultiTool_functional(AppCase):
             ['multi', 'start', 'foo',
              '--nosplash', '--quiet', '-q', '--verbose', '--no-color'],
         )
-        self.assertTrue(self.t.nosplash)
-        self.assertTrue(self.t.quiet)
-        self.assertTrue(self.t.verbose)
-        self.assertTrue(self.t.no_color)
+        assert self.t.nosplash
+        assert self.t.quiet
+        assert self.t.verbose
+        assert self.t.no_color
 
     @patch('celery.bin.multi.MultiTool')
     def test_main(self, MultiTool):
         m = MultiTool.return_value = Mock()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             main()
         m.execute_from_commandline.assert_called_with(sys.argv)

+ 235 - 239
celery/tests/bin/test_worker.py → t/unit/bin/test_worker.py

@@ -2,9 +2,11 @@ from __future__ import absolute_import, unicode_literals
 
 import logging
 import os
+import pytest
 import sys
 
 from billiard.process import current_process
+from case import Mock, mock, patch, skip
 from kombu import Exchange, Queue
 
 from celery import platforms
@@ -18,13 +20,10 @@ from celery.exceptions import (
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.worker import state
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
 
-
-class WorkerAppCase(AppCase):
-
-    def teardown(self):
-        trace.reset_worker_optimizations()
+@pytest.fixture(autouse=True)
+def reset_worker_optimizations(request):
+    request.addfinalizer(trace.reset_worker_optimizations)
 
 
 class Worker(cd.Worker):
@@ -34,34 +33,34 @@ class Worker(cd.Worker):
         self.on_start()
 
 
-class test_Worker(WorkerAppCase):
+class test_Worker:
     Worker = Worker
 
-    @mock.stdouts
-    def test_queues_string(self, stdout, stderr):
-        w = self.app.Worker()
-        w.setup_queues('foo,bar,baz')
-        self.assertIn('foo', self.app.amqp.queues)
-
-    @mock.stdouts
-    def test_cpu_count(self, stdout, stderr):
-        with patch('celery.worker.cpu_count') as cpu_count:
-            cpu_count.side_effect = NotImplementedError()
-            w = self.app.Worker(concurrency=None)
-            self.assertEqual(w.concurrency, 2)
-        w = self.app.Worker(concurrency=5)
-        self.assertEqual(w.concurrency, 5)
-
-    @mock.stdouts
-    def test_windows_B_option(self, stdout, stderr):
-        self.app.IS_WINDOWS = True
-        with self.assertRaises(SystemExit):
-            worker(app=self.app).run(beat=True)
+    def test_queues_string(self):
+        with mock.stdouts():
+            w = self.app.Worker()
+            w.setup_queues('foo,bar,baz')
+            assert 'foo' in self.app.amqp.queues
+
+    def test_cpu_count(self):
+        with mock.stdouts():
+            with patch('celery.worker.cpu_count') as cpu_count:
+                cpu_count.side_effect = NotImplementedError()
+                w = self.app.Worker(concurrency=None)
+                assert w.concurrency == 2
+            w = self.app.Worker(concurrency=5)
+            assert w.concurrency == 5
+
+    def test_windows_B_option(self):
+        with mock.stdouts():
+            self.app.IS_WINDOWS = True
+            with pytest.raises(SystemExit):
+                worker(app=self.app).run(beat=True)
 
     def test_setup_concurrency_very_early(self):
         x = worker()
         x.run = Mock()
-        with self.assertRaises(ImportError):
+        with pytest.raises(ImportError):
             x.execute_from_commandline(['worker', '-P', 'xyzybox'])
 
     def test_run_from_argv_basic(self):
@@ -80,15 +79,15 @@ class test_Worker(WorkerAppCase):
         with patch('celery.bin.worker.detached_celeryd') as detached:
             x.maybe_detach([])
             detached.assert_not_called()
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 x.maybe_detach(['--detach'])
             detached.assert_called()
 
-    @mock.stdouts
-    def test_invalid_loglevel_gives_error(self, stdout, stderr):
-        x = worker(app=self.app)
-        with self.assertRaises(SystemExit):
-            x.run(loglevel='GRIM_REAPER')
+    def test_invalid_loglevel_gives_error(self):
+        with mock.stdouts():
+            x = worker(app=self.app)
+            with pytest.raises(SystemExit):
+                x.run(loglevel='GRIM_REAPER')
 
     def test_no_loglevel(self):
         self.app.Worker = Mock()
@@ -96,25 +95,24 @@ class test_Worker(WorkerAppCase):
 
     def test_tasklist(self):
         worker = self.app.Worker()
-        self.assertTrue(worker.app.tasks)
-        self.assertTrue(worker.app.finalized)
-        self.assertTrue(worker.tasklist(include_builtins=True))
+        assert worker.app.tasks
+        assert worker.app.finalized
+        assert worker.tasklist(include_builtins=True)
         worker.tasklist(include_builtins=False)
 
     def test_extra_info(self):
         worker = self.app.Worker()
         worker.loglevel = logging.WARNING
-        self.assertFalse(worker.extra_info())
+        assert not worker.extra_info()
         worker.loglevel = logging.INFO
-        self.assertTrue(worker.extra_info())
+        assert worker.extra_info()
 
-    @mock.stdouts
-    def test_loglevel_string(self, stdout, stderr):
-        worker = self.Worker(app=self.app, loglevel='INFO')
-        self.assertEqual(worker.loglevel, logging.INFO)
+    def test_loglevel_string(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app, loglevel='INFO')
+            assert worker.loglevel == logging.INFO
 
-    @mock.stdouts
-    def test_run_worker(self, stdout, stderr):
+    def test_run_worker(self, patching):
         handlers = {}
 
         class Signals(platforms.Signals):
@@ -122,158 +120,156 @@ class test_Worker(WorkerAppCase):
             def __setitem__(self, sig, handler):
                 handlers[sig] = handler
 
-        p = platforms.signals
-        platforms.signals = Signals()
-        try:
+        patching.setattr('celery.platforms.signals', Signals())
+        with mock.stdouts():
             w = self.Worker(app=self.app)
             w._isatty = False
             w.on_start()
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
-                self.assertIn(sig, handlers)
+                assert sig in handlers
 
             handlers.clear()
             w = self.Worker(app=self.app)
             w._isatty = True
             w.on_start()
             for sig in 'SIGINT', 'SIGTERM':
-                self.assertIn(sig, handlers)
-            self.assertNotIn('SIGHUP', handlers)
-        finally:
-            platforms.signals = p
+                assert sig in handlers
+            assert 'SIGHUP' not in handlers
 
-    @mock.stdouts
-    def test_startup_info(self, stdout, stderr):
-        worker = self.Worker(app=self.app)
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
-        worker.loglevel = logging.DEBUG
-        self.assertTrue(worker.startup_info())
-        worker.loglevel = logging.INFO
-        self.assertTrue(worker.startup_info())
-
-        prev_loader = self.app.loader
-        worker = self.Worker(app=self.app, queues='foo,bar,baz,xuzzy,do,re,mi')
-        with patch('celery.apps.worker.qualname') as qualname:
-            qualname.return_value = 'acme.backed_beans.Loader'
-            self.assertTrue(worker.startup_info())
-
-        with patch('celery.apps.worker.qualname') as qualname:
-            qualname.return_value = 'celery.loaders.Loader'
-            self.assertTrue(worker.startup_info())
-
-        from celery.loaders.app import AppLoader
-        self.app.loader = AppLoader(app=self.app)
-        self.assertTrue(worker.startup_info())
+    def test_startup_info(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert worker.startup_info()
+            worker.loglevel = logging.DEBUG
+            assert worker.startup_info()
+            worker.loglevel = logging.INFO
+            assert worker.startup_info()
+
+            prev_loader = self.app.loader
+            worker = self.Worker(
+                app=self.app,
+                queues='foo,bar,baz,xuzzy,do,re,mi',
+            )
+            with patch('celery.apps.worker.qualname') as qualname:
+                qualname.return_value = 'acme.backed_beans.Loader'
+                assert worker.startup_info()
+
+            with patch('celery.apps.worker.qualname') as qualname:
+                qualname.return_value = 'celery.loaders.Loader'
+                assert worker.startup_info()
+
+            from celery.loaders.app import AppLoader
+            self.app.loader = AppLoader(app=self.app)
+            assert worker.startup_info()
+
+            self.app.loader = prev_loader
+            worker.task_events = True
+            assert worker.startup_info()
+
+            # test when there are too few output lines
+            # to draft the ascii art onto
+            prev, cd.ARTLINES = cd.ARTLINES, ['the quick brown fox']
+            try:
+                assert worker.startup_info()
+            finally:
+                cd.ARTLINES = prev
 
-        self.app.loader = prev_loader
-        worker.task_events = True
-        self.assertTrue(worker.startup_info())
+    def test_run(self):
+        with mock.stdouts():
+            self.Worker(app=self.app).on_start()
+            self.Worker(app=self.app, purge=True).on_start()
+            worker = self.Worker(app=self.app)
+            worker.on_start()
 
-        # test when there are too few output lines
-        # to draft the ascii art onto
-        prev, cd.ARTLINES = cd.ARTLINES, ['the quick brown fox']
-        try:
-            self.assertTrue(worker.startup_info())
-        finally:
-            cd.ARTLINES = prev
-
-    @mock.stdouts
-    def test_run(self, stdout, stderr):
-        self.Worker(app=self.app).on_start()
-        self.Worker(app=self.app, purge=True).on_start()
-        worker = self.Worker(app=self.app)
-        worker.on_start()
-
-    @mock.stdouts
-    def test_purge_messages(self, stdout, stderr):
-        self.Worker(app=self.app).purge_messages()
-
-    @mock.stdouts
-    def test_init_queues(self, stdout, stderr):
-        app = self.app
-        c = app.conf
-        app.amqp.queues = app.amqp.Queues({
-            'celery': {'exchange': 'celery',
-                       'routing_key': 'celery'},
-            'video': {'exchange': 'video',
-                      'routing_key': 'video'},
-        })
-        worker = self.Worker(app=self.app)
-        worker.setup_queues(['video'])
-        self.assertIn('video', app.amqp.queues)
-        self.assertIn('video', app.amqp.queues.consume_from)
-        self.assertIn('celery', app.amqp.queues)
-        self.assertNotIn('celery', app.amqp.queues.consume_from)
-
-        c.task_create_missing_queues = False
-        del(app.amqp.queues)
-        with self.assertRaises(ImproperlyConfigured):
-            self.Worker(app=self.app).setup_queues(['image'])
-        del(app.amqp.queues)
-        c.task_create_missing_queues = True
-        worker = self.Worker(app=self.app)
-        worker.setup_queues(['image'])
-        self.assertIn('image', app.amqp.queues.consume_from)
-        self.assertEqual(
-            Queue('image', Exchange('image'), routing_key='image'),
-            app.amqp.queues['image'],
-        )
+    def test_purge_messages(self):
+        with mock.stdouts():
+            self.Worker(app=self.app).purge_messages()
+
+    def test_init_queues(self):
+        with mock.stdouts():
+            app = self.app
+            c = app.conf
+            app.amqp.queues = app.amqp.Queues({
+                'celery': {
+                    'exchange': 'celery',
+                    'routing_key': 'celery',
+                },
+                'video': {
+                    'exchange': 'video',
+                    'routing_key': 'video',
+                },
+            })
+            worker = self.Worker(app=self.app)
+            worker.setup_queues(['video'])
+            assert 'video' in app.amqp.queues
+            assert 'video' in app.amqp.queues.consume_from
+            assert 'celery' in app.amqp.queues
+            assert 'celery' not in app.amqp.queues.consume_from
+
+            c.task_create_missing_queues = False
+            del(app.amqp.queues)
+            with pytest.raises(ImproperlyConfigured):
+                self.Worker(app=self.app).setup_queues(['image'])
+            del(app.amqp.queues)
+            c.task_create_missing_queues = True
+            worker = self.Worker(app=self.app)
+            worker.setup_queues(['image'])
+            assert 'image' in app.amqp.queues.consume_from
+            assert app.amqp.queues['image'] == Queue(
+                'image', Exchange('image'),
+                routing_key='image',
+            )
 
     def test_include_argument(self):
         worker1 = self.Worker(app=self.app, include='os')
-        self.assertListEqual(worker1.include, ['os'])
+        assert worker1.include == ['os']
         worker2 = self.Worker(app=self.app,
                               include='os,sys')
-        self.assertListEqual(worker2.include, ['os', 'sys'])
+        assert worker2.include == ['os', 'sys']
         self.Worker(app=self.app, include=['os', 'sys'])
 
-    @mock.stdouts
-    def test_unknown_loglevel(self, stdout, stderr):
-        with self.assertRaises(SystemExit):
-            worker(app=self.app).run(loglevel='ALIEN')
-        worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
-        self.assertEqual(worker1.loglevel, 0xFFFF)
+    def test_unknown_loglevel(self):
+        with mock.stdouts():
+            with pytest.raises(SystemExit):
+                worker(app=self.app).run(loglevel='ALIEN')
+            worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
+            assert worker1.loglevel == 0xFFFF
 
     @patch('os._exit')
     @skip.if_win32()
-    @mock.stdouts
-    def test_warns_if_running_as_privileged_user(self, _exit, stdout, stderr):
-        with patch('os.getuid') as getuid:
+    def test_warns_if_running_as_privileged_user(self, _exit, patching):
+        getuid = patching('os.getuid')
+
+        with mock.stdouts() as (_, stderr):
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
             worker = self.Worker(app=self.app)
             worker.on_start()
             _exit.assert_called_with(1)
-            from celery import platforms
-            platforms.C_FORCE_ROOT = True
-            try:
-                with self.assertWarnsRegex(
-                        RuntimeWarning,
-                        r'absolutely not recommended'):
-                    worker = self.Worker(app=self.app)
-                    worker.on_start()
-            finally:
-                platforms.C_FORCE_ROOT = False
+            patching.setattr('celery.platforms.C_FORCE_ROOT', True)
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert 'a very bad idea' in stderr.getvalue()
+            patching.setattr('celery.platforms.C_FORCE_ROOT', False)
             self.app.conf.accept_content = ['json']
-            with self.assertWarnsRegex(
-                    RuntimeWarning,
-                    r'absolutely not recommended'):
-                worker = self.Worker(app=self.app)
-                worker.on_start()
-
-    @mock.stdouts
-    def test_redirect_stdouts(self, stdout, stderr):
-        self.Worker(app=self.app, redirect_stdouts=False)
-        with self.assertRaises(AttributeError):
-            sys.stdout.logger
-
-    @mock.stdouts
-    def test_on_start_custom_logging(self, stdout, stderr):
-        self.app.log.redirect_stdouts = Mock()
-        worker = self.Worker(app=self.app, redirect_stoutds=True)
-        worker._custom_logging = True
-        worker.on_start()
-        self.app.log.redirect_stdouts.assert_not_called()
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert 'superuser' in stderr.getvalue()
+
+    def test_redirect_stdouts(self):
+        with mock.stdouts():
+            self.Worker(app=self.app, redirect_stdouts=False)
+            with pytest.raises(AttributeError):
+                sys.stdout.logger
+
+    def test_on_start_custom_logging(self):
+        with mock.stdouts():
+            self.app.log.redirect_stdouts = Mock()
+            worker = self.Worker(app=self.app, redirect_stoutds=True)
+            worker._custom_logging = True
+            worker.on_start()
+            self.app.log.redirect_stdouts.assert_not_called()
 
     def test_setup_logging_no_color(self):
         worker = self.Worker(
@@ -282,15 +278,15 @@ class test_Worker(WorkerAppCase):
         prev, self.app.log.setup = self.app.log.setup, Mock()
         try:
             worker.setup_logging()
-            self.assertFalse(self.app.log.setup.call_args[1]['colorize'])
+            assert not self.app.log.setup.call_args[1]['colorize']
         finally:
             self.app.log.setup = prev
 
-    @mock.stdouts
-    def test_startup_info_pool_is_str(self, stdout, stderr):
-        worker = self.Worker(app=self.app, redirect_stdouts=False)
-        worker.pool_cls = 'foo'
-        worker.startup_info()
+    def test_startup_info_pool_is_str(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app, redirect_stdouts=False)
+            worker.pool_cls = 'foo'
+            worker.startup_info()
 
     def test_redirect_stdouts_already_handled(self):
         logging_setup = [False]
@@ -303,14 +299,13 @@ class test_Worker(WorkerAppCase):
             worker = self.Worker(app=self.app, redirect_stdouts=False)
             worker.app.log.already_setup = False
             worker.setup_logging()
-            self.assertTrue(logging_setup[0])
-            with self.assertRaises(AttributeError):
+            assert logging_setup[0]
+            with pytest.raises(AttributeError):
                 sys.stdout.logger
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
 
-    @mock.stdouts
-    def test_platform_tweaks_macOS(self, stdout, stderr):
+    def test_platform_tweaks_macOS(self):
 
         class macOSWorker(Worker):
             proxy_workaround_installed = False
@@ -318,27 +313,27 @@ class test_Worker(WorkerAppCase):
             def macOS_proxy_detection_workaround(self):
                 self.proxy_workaround_installed = True
 
-        worker = macOSWorker(app=self.app, redirect_stdouts=False)
+        with mock.stdouts():
+            worker = macOSWorker(app=self.app, redirect_stdouts=False)
 
-        def install_HUP_nosupport(controller):
-            controller.hup_not_supported_installed = True
+            def install_HUP_nosupport(controller):
+                controller.hup_not_supported_installed = True
 
-        class Controller(object):
-            pass
+            class Controller(object):
+                pass
 
-        prev = cd.install_HUP_not_supported_handler
-        cd.install_HUP_not_supported_handler = install_HUP_nosupport
-        try:
-            worker.app.IS_macOS = True
-            controller = Controller()
-            worker.install_platform_tweaks(controller)
-            self.assertTrue(controller.hup_not_supported_installed)
-            self.assertTrue(worker.proxy_workaround_installed)
-        finally:
-            cd.install_HUP_not_supported_handler = prev
+            prev = cd.install_HUP_not_supported_handler
+            cd.install_HUP_not_supported_handler = install_HUP_nosupport
+            try:
+                worker.app.IS_macOS = True
+                controller = Controller()
+                worker.install_platform_tweaks(controller)
+                assert controller.hup_not_supported_installed
+                assert worker.proxy_workaround_installed
+            finally:
+                cd.install_HUP_not_supported_handler = prev
 
-    @mock.stdouts
-    def test_general_platform_tweaks(self, stdout, stderr):
+    def test_general_platform_tweaks(self):
 
         restart_worker_handler_installed = [False]
 
@@ -348,33 +343,34 @@ class test_Worker(WorkerAppCase):
         class Controller(object):
             pass
 
-        prev = cd.install_worker_restart_handler
-        cd.install_worker_restart_handler = install_worker_restart_handler
-        try:
-            worker = self.Worker(app=self.app)
-            worker.app.IS_macOS = False
-            worker.install_platform_tweaks(Controller())
-            self.assertTrue(restart_worker_handler_installed[0])
-        finally:
-            cd.install_worker_restart_handler = prev
+        with mock.stdouts():
+            prev = cd.install_worker_restart_handler
+            cd.install_worker_restart_handler = install_worker_restart_handler
+            try:
+                worker = self.Worker(app=self.app)
+                worker.app.IS_macOS = False
+                worker.install_platform_tweaks(Controller())
+                assert restart_worker_handler_installed[0]
+            finally:
+                cd.install_worker_restart_handler = prev
 
-    @mock.stdouts
-    def test_on_consumer_ready(self, stdout, stderr):
+    def test_on_consumer_ready(self):
         worker_ready_sent = [False]
 
         @signals.worker_ready.connect
         def on_worker_ready(**kwargs):
             worker_ready_sent[0] = True
 
-        self.Worker(app=self.app).on_consumer_ready(object())
-        self.assertTrue(worker_ready_sent[0])
+        with mock.stdouts():
+            self.Worker(app=self.app).on_consumer_ready(object())
+            assert worker_ready_sent[0]
 
 
 @mock.stdouts
-class test_funs(WorkerAppCase):
+class test_funs:
 
     def test_active_thread_count(self):
-        self.assertTrue(cd.active_thread_count())
+        assert cd.active_thread_count()
 
     @skip.unless_module('setproctitle')
     def test_set_process_status(self):
@@ -382,16 +378,16 @@ class test_funs(WorkerAppCase):
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
             st = worker.set_process_status('Running')
-            self.assertIn('celeryd', st)
-            self.assertIn('xyzza', st)
-            self.assertIn('Running', st)
+            assert 'celeryd' in st
+            assert 'xyzza' in st
+            assert 'Running' in st
             prev2, sys.argv = sys.argv, ['Arg0', 'Arg1']
             try:
                 st = worker.set_process_status('Running')
-                self.assertIn('celeryd', st)
-                self.assertIn('xyzza', st)
-                self.assertIn('Running', st)
-                self.assertIn('Arg1', st)
+                assert 'celeryd' in st
+                assert 'xyzza' in st
+                assert 'Running' in st
+                assert 'Arg1' in st
             finally:
                 sys.argv = prev2
         finally:
@@ -402,8 +398,8 @@ class test_funs(WorkerAppCase):
         cmd.app = self.app
         opts, args = cmd.parse_options('worker', ['--concurrency=512',
                                        '--heartbeat-interval=10'])
-        self.assertEqual(opts.concurrency, 512)
-        self.assertEqual(opts.heartbeat_interval, 10)
+        assert opts.concurrency == 512
+        assert opts.heartbeat_interval == 10
 
     def test_main(self):
         p, cd.Worker = cd.Worker, Worker
@@ -416,7 +412,7 @@ class test_funs(WorkerAppCase):
 
 
 @mock.stdouts
-class test_signal_handlers(WorkerAppCase):
+class test_signal_handlers:
 
     class _Worker(object):
         stopped = False
@@ -459,16 +455,16 @@ class test_signal_handlers(WorkerAppCase):
             p, platforms.signals = platforms.signals, Signals()
             try:
                 handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_stop)
-                self.assertEqual(state.should_stop, EX_FAILURE)
+                assert state.should_stop
+                assert state.should_stop == EX_FAILURE
             finally:
                 platforms.signals = p
                 state.should_stop = None
 
             try:
                 next_handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_terminate)
-                self.assertEqual(state.should_terminate, EX_FAILURE)
+                assert state.should_terminate
+                assert state.should_terminate == EX_FAILURE
             finally:
                 state.should_terminate = None
 
@@ -476,12 +472,12 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             p, platforms.signals = platforms.signals, Signals()
             try:
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
             finally:
                 platforms.signals = p
 
-            with self.assertRaises(WorkerTerminate):
+            with pytest.raises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
 
     @skip.unless_module('multiprocessing')
@@ -494,7 +490,7 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_int_handler, worker)
                 handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_stop)
+                assert state.should_stop
             finally:
                 process.name = name
                 state.should_stop = None
@@ -504,7 +500,7 @@ class test_signal_handlers(WorkerAppCase):
             try:
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_int_handler, worker)
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
             finally:
                 process.name = name
@@ -527,7 +523,7 @@ class test_signal_handlers(WorkerAppCase):
                     cd.install_worker_term_hard_handler, worker)
                 try:
                     handlers['SIGQUIT']('SIGQUIT', object())
-                    self.assertTrue(state.should_terminate)
+                    assert state.should_terminate
                 finally:
                     state.should_terminate = None
             with patch('celery.apps.worker.active_thread_count') as c:
@@ -536,7 +532,7 @@ class test_signal_handlers(WorkerAppCase):
                 handlers = self.psig(
                     cd.install_worker_term_hard_handler, worker)
                 try:
-                    with self.assertRaises(WorkerTerminate):
+                    with pytest.raises(WorkerTerminate):
                         handlers['SIGQUIT']('SIGQUIT', object())
                 finally:
                     state.should_terminate = None
@@ -550,7 +546,7 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertEqual(state.should_stop, EX_OK)
+                assert state.should_stop == EX_OK
             finally:
                 state.should_stop = None
 
@@ -560,7 +556,7 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
             finally:
                 state.should_stop = None
@@ -570,7 +566,7 @@ class test_signal_handlers(WorkerAppCase):
     @skip.if_jython()
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
-        self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
+        assert handlers['SIGUSR1']('SIGUSR1', object()) is None
         stderr.write.assert_called()
 
     @skip.unless_module('multiprocessing')
@@ -583,12 +579,12 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertEqual(state.should_stop, EX_OK)
+                assert state.should_stop == EX_OK
             with patch('celery.apps.worker.active_thread_count') as c:
                 c.return_value = 1
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
         finally:
             process.name = name
@@ -609,11 +605,11 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers['SIGHUP']('SIGHUP', object())
-            self.assertEqual(state.should_stop, EX_OK)
+            assert state.should_stop == EX_OK
             register.assert_called()
             callback = register.call_args[0][0]
             callback()
-            self.assertTrue(argv)
+            assert argv
         finally:
             os.execv = execv
             state.should_stop = None
@@ -625,7 +621,7 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
             try:
                 handlers['SIGQUIT']('SIGQUIT', object())
-                self.assertTrue(state.should_terminate)
+                assert state.should_terminate
             finally:
                 state.should_terminate = None
 
@@ -634,5 +630,5 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
-            with self.assertRaises(WorkerTerminate):
+            with pytest.raises(WorkerTerminate):
                 handlers['SIGQUIT']('SIGQUIT', object())

+ 0 - 0
celery/tests/contrib/__init__.py → t/unit/compat_modules/__init__.py


+ 16 - 14
celery/tests/compat_modules/test_compat.py → t/unit/compat_modules/test_compat.py

@@ -1,20 +1,22 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from datetime import timedelta
 
+from celery.five import bytes_if_py2
 from celery.schedules import schedule
 from celery.task import (
     periodic_task,
     PeriodicTask
 )
 
-from celery.tests.case import AppCase, depends_on_current_app  # noqa
-
 
-@depends_on_current_app
-class test_periodic_tasks(AppCase):
+class test_periodic_tasks:
 
     def setup(self):
+        self.app.set_current()  # @depends_on_current_app
+
         @periodic_task(app=self.app, shared=False,
                        run_every=schedule(timedelta(hours=1), app=self.app))
         def my_periodic():
@@ -25,32 +27,32 @@ class test_periodic_tasks(AppCase):
         return self.app.now()
 
     def test_must_have_run_every(self):
-        with self.assertRaises(NotImplementedError):
-            type('Foo', (PeriodicTask,), {'__module__': __name__})
+        with pytest.raises(NotImplementedError):
+            type(bytes_if_py2('Foo'), (PeriodicTask,), {
+                '__module__': __name__,
+            })
 
     def test_remaining_estimate(self):
         s = self.my_periodic.run_every
-        self.assertIsInstance(
+        assert isinstance(
             s.remaining_estimate(s.maybe_make_aware(self.now())),
             timedelta)
 
     def test_is_due_not_due(self):
         due, remaining = self.my_periodic.run_every.is_due(self.now())
-        self.assertFalse(due)
+        assert not due
         # This assertion may fail if executed in the
         # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
+        assert remaining > 59
 
     def test_is_due(self):
         p = self.my_periodic
         due, remaining = p.run_every.is_due(
             self.now() - p.run_every.run_every,
         )
-        self.assertTrue(due)
-        self.assertEqual(
-            remaining, p.run_every.run_every.total_seconds(),
-        )
+        assert due
+        assert remaining == p.run_every.run_every.total_seconds()
 
     def test_schedule_repr(self):
         p = self.my_periodic
-        self.assertTrue(repr(p.run_every))
+        assert repr(p.run_every)

+ 13 - 14
celery/tests/compat_modules/test_compat_utils.py → t/unit/compat_modules/test_compat_utils.py

@@ -1,39 +1,38 @@
 from __future__ import absolute_import, unicode_literals
 
 import celery
+import pytest
 
 from celery.app.task import Task as ModernTask
 from celery.task.base import Task as CompatTask
 
-from celery.tests.case import AppCase, depends_on_current_app
 
-
-@depends_on_current_app
-class test_MagicModule(AppCase):
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_MagicModule:
 
     def test_class_property_set_without_type(self):
-        self.assertTrue(ModernTask.__dict__['app'].__get__(CompatTask()))
+        assert ModernTask.__dict__['app'].__get__(CompatTask())
 
     def test_class_property_set_on_class(self):
-        self.assertIs(ModernTask.__dict__['app'].__set__(None, None),
-                      ModernTask.__dict__['app'])
+        assert (ModernTask.__dict__['app'].__set__(None, None) is
+                ModernTask.__dict__['app'])
 
-    def test_class_property_set(self):
+    def test_class_property_set(self, app):
 
         class X(CompatTask):
             pass
-        ModernTask.__dict__['app'].__set__(X(), self.app)
-        self.assertIs(X.app, self.app)
+        ModernTask.__dict__['app'].__set__(X(), app)
+        assert X.app is app
 
     def test_dir(self):
-        self.assertTrue(dir(celery.messaging))
+        assert dir(celery.messaging)
 
     def test_direct(self):
-        self.assertTrue(celery.task)
+        assert celery.task
 
     def test_app_attrs(self):
-        self.assertEqual(celery.task.control.broadcast,
-                         celery.current_app.control.broadcast)
+        assert (celery.task.control.broadcast ==
+                celery.current_app.control.broadcast)
 
     def test_decorators_task(self):
         @celery.decorators.task

+ 37 - 0
t/unit/compat_modules/test_decorators.py

@@ -0,0 +1,37 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+import warnings
+
+from celery.task import base
+
+
+def add(x, y):
+    return x + y
+
+
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_decorators:
+
+    def test_task_alias(self):
+        from celery import task
+        assert task.__file__
+        assert task(add)
+
+    def setup(self):
+        with warnings.catch_warnings(record=True):
+            from celery import decorators
+            self.decorators = decorators
+
+    def assert_compat_decorator(self, decorator, type, **opts):
+        task = decorator(**opts)(add)
+        assert task(8, 8) == 16
+        assert isinstance(task, type)
+
+    def test_task(self):
+        self.assert_compat_decorator(self.decorators.task, base.BaseTask)
+
+    def test_periodic_task(self):
+        self.assert_compat_decorator(
+            self.decorators.periodic_task, base.BaseTask, run_every=1,
+        )

+ 4 - 3
celery/tests/compat_modules/test_messaging.py → t/unit/compat_modules/test_messaging.py

@@ -1,11 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from celery import messaging
-from celery.tests.case import AppCase, depends_on_current_app
 
 
-@depends_on_current_app
-class test_compat_messaging_module(AppCase):
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_compat_messaging_module:
 
     def test_get_consume_set(self):
         conn = messaging.establish_connection()

+ 0 - 0
celery/tests/events/__init__.py → t/unit/concurrency/__init__.py


+ 32 - 33
celery/tests/concurrency/test_concurrency.py → t/unit/concurrency/test_concurrency.py

@@ -1,15 +1,17 @@
 from __future__ import absolute_import, unicode_literals
 
 import os
+import pytest
 
 from itertools import count
 
+from case import Mock, patch
+
 from celery.concurrency.base import apply_target, BasePool
 from celery.exceptions import WorkerShutdown, WorkerTerminate
-from celery.tests.case import AppCase, Mock, patch
 
 
-class test_BasePool(AppCase):
+class test_BasePool:
 
     def test_apply_target(self):
 
@@ -29,14 +31,12 @@ class test_BasePool(AppCase):
                      callback=gen_callback('callback'),
                      accept_callback=gen_callback('accept_callback'))
 
-        self.assertDictContainsSubset(
-            {'target': (1, (8, 16)), 'callback': (2, (42,))},
-            scratch,
-        )
+        assert scratch['target'] == (1, (8, 16))
+        assert scratch['callback'] == (2, (42,))
         pa1 = scratch['accept_callback']
-        self.assertEqual(0, pa1[0])
-        self.assertEqual(pa1[1][0], os.getpid())
-        self.assertTrue(pa1[1][1])
+        assert pa1[0] == 0
+        assert pa1[1][0] == os.getpid()
+        assert pa1[1][1]
 
         # No accept callback
         scratch.clear()
@@ -44,32 +44,33 @@ class test_BasePool(AppCase):
                      args=(8, 16),
                      callback=gen_callback('callback'),
                      accept_callback=None)
-        self.assertDictEqual(scratch,
-                             {'target': (3, (8, 16)),
-                              'callback': (4, (42,))})
+        assert scratch == {
+            'target': (3, (8, 16)),
+            'callback': (4, (42,)),
+        }
 
     def test_apply_target__propagate(self):
         target = Mock(name='target')
         target.side_effect = KeyError()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target, propagate=(KeyError,))
 
     def test_apply_target__raises(self):
         target = Mock(name='target')
         target.side_effect = KeyError()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target)
 
     def test_apply_target__raises_WorkerShutdown(self):
         target = Mock(name='target')
         target.side_effect = WorkerShutdown()
-        with self.assertRaises(WorkerShutdown):
+        with pytest.raises(WorkerShutdown):
             apply_target(target)
 
     def test_apply_target__raises_WorkerTerminate(self):
         target = Mock(name='target')
         target.side_effect = WorkerTerminate()
-        with self.assertRaises(WorkerTerminate):
+        with pytest.raises(WorkerTerminate):
             apply_target(target)
 
     def test_apply_target__raises_BaseException(self):
@@ -85,7 +86,7 @@ class test_BasePool(AppCase):
         callback = Mock(name='callback')
         reraise.side_effect = KeyError()
         target.side_effect = BaseException()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target, callback=callback)
         callback.assert_not_called()
 
@@ -95,7 +96,7 @@ class test_BasePool(AppCase):
         x.apply_async(object)
 
     def test_num_processes(self):
-        self.assertEqual(BasePool(7).num_processes, 7)
+        assert BasePool(7).num_processes == 7
 
     def test_interface_on_start(self):
         BasePool(10).on_start()
@@ -107,22 +108,22 @@ class test_BasePool(AppCase):
         BasePool(10).on_apply()
 
     def test_interface_info(self):
-        self.assertDictEqual(BasePool(10).info, {
+        assert BasePool(10).info == {
             'max-concurrency': 10,
-        })
+        }
 
     def test_interface_flush(self):
-        self.assertIsNone(BasePool(10).flush())
+        assert BasePool(10).flush() is None
 
     def test_active(self):
         p = BasePool(10)
-        self.assertFalse(p.active)
+        assert not p.active
         p._state = p.RUN
-        self.assertTrue(p.active)
+        assert p.active
 
     def test_restart(self):
         p = BasePool(10)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             p.restart()
 
     def test_interface_on_terminate(self):
@@ -130,29 +131,27 @@ class test_BasePool(AppCase):
         p.on_terminate()
 
     def test_interface_terminate_job(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             BasePool(10).terminate_job(101)
 
     def test_interface_did_start_ok(self):
-        self.assertTrue(BasePool(10).did_start_ok())
+        assert BasePool(10).did_start_ok()
 
     def test_interface_register_with_event_loop(self):
-        self.assertIsNone(
-            BasePool(10).register_with_event_loop(Mock()),
-        )
+        assert BasePool(10).register_with_event_loop(Mock()) is None
 
     def test_interface_on_soft_timeout(self):
-        self.assertIsNone(BasePool(10).on_soft_timeout(Mock()))
+        assert BasePool(10).on_soft_timeout(Mock()) is None
 
     def test_interface_on_hard_timeout(self):
-        self.assertIsNone(BasePool(10).on_hard_timeout(Mock()))
+        assert BasePool(10).on_hard_timeout(Mock()) is None
 
     def test_interface_close(self):
         p = BasePool(10)
         p.on_close = Mock()
         p.close()
-        self.assertEqual(p._state, p.CLOSE)
+        assert p._state == p.CLOSE
         p.on_close.assert_called_with()
 
     def test_interface_no_close(self):
-        self.assertIsNone(BasePool(10).on_close())
+        assert BasePool(10).on_close() is None

+ 36 - 38
celery/tests/concurrency/test_eventlet.py → t/unit/concurrency/test_eventlet.py

@@ -1,25 +1,34 @@
 from __future__ import absolute_import, unicode_literals
 
-import os
+import pytest
 import sys
 
+from case import Mock, patch, skip
+
 from celery.concurrency.eventlet import (
     apply_target,
     Timer,
     TaskPool,
 )
 
-from celery.tests.case import AppCase, Mock, patch, skip
+eventlet_modules = (
+    'eventlet',
+    'eventlet.debug',
+    'eventlet.greenthread',
+    'eventlet.greenpool',
+    'greenlet',
+)
 
 
 @skip.if_pypy()
-class EventletCase(AppCase):
+class EventletCase:
 
     def setup(self):
-        self.mock_modules(*eventlet_modules)
+        self.patching.modules(*eventlet_modules)
 
     def teardown(self):
-        for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
+        for mod in [mod for mod in sys.modules
+                    if mod.startswith('eventlet')]:
             try:
                 del(sys.modules[mod])
             except KeyError:
@@ -36,45 +45,34 @@ class test_aaa_eventlet_patch(EventletCase):
 
     @patch('eventlet.debug.hub_blocking_detection', create=True)
     @patch('eventlet.monkey_patch', create=True)
-    def test_aaa_blockdetecet(self, monkey_patch, hub_blocking_detection):
-        os.environ['EVENTLET_NOBLOCK'] = '10.3'
-        try:
-            from celery import maybe_patch_concurrency
-            maybe_patch_concurrency(['x', '-P', 'eventlet'])
-            monkey_patch.assert_called_with()
-            hub_blocking_detection.assert_called_with(10.3, 10.3)
-        finally:
-            os.environ.pop('EVENTLET_NOBLOCK', None)
-
-
-eventlet_modules = (
-    'eventlet',
-    'eventlet.debug',
-    'eventlet.greenthread',
-    'eventlet.greenpool',
-    'greenlet',
-)
+    def test_aaa_blockdetecet(
+            self, monkey_patch, hub_blocking_detection, patching):
+        patching.setenv('EVENTLET_NOBLOCK', '10.3')
+        from celery import maybe_patch_concurrency
+        maybe_patch_concurrency(['x', '-P', 'eventlet'])
+        monkey_patch.assert_called_with()
+        hub_blocking_detection.assert_called_with(10.3, 10.3)
 
 
 class test_Timer(EventletCase):
 
-    def setup(self):
-        EventletCase.setup(self)
-        self.spawn_after = self.patch('eventlet.greenthread.spawn_after')
-        self.GreenletExit = self.patch('greenlet.GreenletExit')
+    @pytest.fixture(autouse=True)
+    def setup_patches(self, patching):
+        self.spawn_after = patching('eventlet.greenthread.spawn_after')
+        self.GreenletExit = patching('greenlet.GreenletExit')
 
     def test_sched(self):
         x = Timer()
         x.GreenletExit = KeyError
         entry = Mock()
         g = x._enter(1, 0, entry)
-        self.assertTrue(x.queue)
+        assert x.queue
 
         x._entry_exit(g, entry)
         g.wait.side_effect = KeyError()
         x._entry_exit(g, entry)
         entry.cancel.assert_called_with()
-        self.assertFalse(x._queue)
+        assert not x._queue
 
         x._queue.add(g)
         x.clear()
@@ -94,10 +92,10 @@ class test_Timer(EventletCase):
 
 class test_TaskPool(EventletCase):
 
-    def setup(self):
-        EventletCase.setup(self)
-        self.GreenPool = self.patch('eventlet.greenpool.GreenPool')
-        self.greenthread = self.patch('eventlet.greenthread')
+    @pytest.fixture(autouse=True)
+    def setup_patches(self, patching):
+        self.GreenPool = patching('eventlet.greenpool.GreenPool')
+        self.greenthread = patching('eventlet.greenthread')
 
     def test_pool(self):
         x = TaskPool()
@@ -106,7 +104,7 @@ class test_TaskPool(EventletCase):
         x.on_apply(Mock())
         x._pool = None
         x.on_stop()
-        self.assertTrue(x.getpid())
+        assert x.getpid()
 
     @patch('celery.concurrency.eventlet.base')
     def test_apply_target(self, base):
@@ -117,21 +115,21 @@ class test_TaskPool(EventletCase):
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
         x.grow(2)
-        self.assertEqual(x.limit, 12)
+        assert x.limit == 12
         x._pool.resize.assert_called_with(12)
 
     def test_shrink(self):
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
         x.shrink(2)
-        self.assertEqual(x.limit, 8)
+        assert x.limit == 8
         x._pool.resize.assert_called_with(8)
 
     def test_get_info(self):
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
-        self.assertDictEqual(x._get_info(), {
+        assert x._get_info() == {
             'max-concurrency': 10,
             'free-threads': x._pool.free(),
             'running-threads': x._pool.running(),
-        })
+        }

+ 128 - 0
t/unit/concurrency/test_gevent.py

@@ -0,0 +1,128 @@
+from __future__ import absolute_import, unicode_literals
+
+from case import Mock, skip
+
+from celery.concurrency.gevent import (
+    Timer,
+    TaskPool,
+    apply_timeout,
+)
+
+gevent_modules = (
+    'gevent',
+    'gevent.monkey',
+    'gevent.greenlet',
+    'gevent.pool',
+    'greenlet',
+)
+
+
+@skip.if_pypy()
+class test_gevent_patch:
+
+    def test_is_patched(self):
+        self.patching.modules(*gevent_modules)
+        patch_all = self.patching('gevent.monkey.patch_all')
+        import gevent
+        gevent.version_info = (1, 0, 0)
+        from celery import maybe_patch_concurrency
+        maybe_patch_concurrency(['x', '-P', 'gevent'])
+        patch_all.assert_called()
+
+
+@skip.if_pypy()
+class test_Timer:
+
+    def setup(self):
+        self.patching.modules(*gevent_modules)
+        self.greenlet = self.patching('gevent.greenlet')
+        self.GreenletExit = self.patching('gevent.greenlet.GreenletExit')
+
+    def test_sched(self):
+        self.greenlet.Greenlet = object
+        x = Timer()
+        self.greenlet.Greenlet = Mock()
+        x._Greenlet.spawn_later = Mock()
+        x._GreenletExit = KeyError
+        entry = Mock()
+        g = x._enter(1, 0, entry)
+        assert x.queue
+
+        x._entry_exit(g)
+        g.kill.assert_called_with()
+        assert not x._queue
+
+        x._queue.add(g)
+        x.clear()
+        x._queue.add(g)
+        g.kill.side_effect = KeyError()
+        x.clear()
+
+        g = x._Greenlet()
+        g.cancel()
+
+
+@skip.if_pypy()
+class test_TaskPool:
+
+    def setup(self):
+        self.patching.modules(*gevent_modules)
+        self.spawn_raw = self.patching('gevent.spawn_raw')
+        self.Pool = self.patching('gevent.pool.Pool')
+
+    def test_pool(self):
+        x = TaskPool()
+        x.on_start()
+        x.on_stop()
+        x.on_apply(Mock())
+        x._pool = None
+        x.on_stop()
+
+        x._pool = Mock()
+        x._pool._semaphore.counter = 1
+        x._pool.size = 1
+        x.grow()
+        assert x._pool.size == 2
+        assert x._pool._semaphore.counter == 2
+        x.shrink()
+        assert x._pool.size, 1
+        assert x._pool._semaphore.counter == 1
+
+        x._pool = [4, 5, 6]
+        assert x.num_processes == 3
+
+
+@skip.if_pypy()
+class test_apply_timeout:
+
+    def test_apply_timeout(self):
+        self.patching.modules(*gevent_modules)
+
+        class Timeout(Exception):
+            value = None
+
+            def __init__(self, value):
+                self.__class__.value = value
+
+            def __enter__(self):
+                return self
+
+            def __exit__(self, *exc_info):
+                pass
+        timeout_callback = Mock(name='timeout_callback')
+        apply_target = Mock(name='apply_target')
+        apply_timeout(
+            Mock(), timeout=10, callback=Mock(name='callback'),
+            timeout_callback=timeout_callback,
+            apply_target=apply_target, Timeout=Timeout,
+        )
+        assert Timeout.value == 10
+        apply_target.assert_called()
+
+        apply_target.side_effect = Timeout(10)
+        apply_timeout(
+            Mock(), timeout=10, callback=Mock(),
+            timeout_callback=timeout_callback,
+            apply_target=apply_target, Timeout=Timeout,
+        )
+        timeout_callback.assert_called_with(False, 10)

+ 15 - 20
celery/tests/concurrency/test_pool.py → t/unit/concurrency/test_pool.py

@@ -3,9 +3,9 @@ from __future__ import absolute_import, unicode_literals
 import time
 import itertools
 
-from billiard.einfo import ExceptionInfo
+from case import skip
 
-from celery.tests.case import AppCase, skip
+from billiard.einfo import ExceptionInfo
 
 
 def do_something(i):
@@ -24,7 +24,7 @@ def raise_something(i):
 
 
 @skip.unless_module('multiprocessing')
-class test_TaskPool(AppCase):
+class test_TaskPool:
 
     def setup(self):
         from celery.concurrency.prefork import TaskPool
@@ -32,8 +32,8 @@ class test_TaskPool(AppCase):
 
     def test_attrs(self):
         p = self.TaskPool(2)
-        self.assertEqual(p.limit, 2)
-        self.assertIsNone(p._pool)
+        assert p.limit == 2
+        assert p._pool is None
 
     def x_apply(self):
         p = self.TaskPool(2)
@@ -52,28 +52,23 @@ class test_TaskPool(AppCase):
         res2 = p.apply_async(raise_something, args=[10], errback=myerrback)
         res3 = p.apply_async(do_something, args=[20], callback=mycallback)
 
-        self.assertEqual(res.get(), 100)
+        assert res.get() == 100
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 100},
-                                      scratchpad.get(0))
+        assert scratchpad.get(0)['ret_value'] == 100
 
-        self.assertIsInstance(res2.get(), ExceptionInfo)
-        self.assertTrue(scratchpad.get(1))
+        assert isinstance(res2.get(), ExceptionInfo)
+        assert scratchpad.get(1)
         time.sleep(1)
-        self.assertIsInstance(scratchpad[1]['ret_value'],
-                              ExceptionInfo)
-        self.assertEqual(scratchpad[1]['ret_value'].exception.args,
-                         ('FOO EXCEPTION',))
+        assert isinstance(scratchpad[1]['ret_value'], ExceptionInfo)
+        assert scratchpad[1]['ret_value'].exception.args == ('FOO EXCEPTION',)
 
-        self.assertEqual(res3.get(), 400)
+        assert res3.get() == 400
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 400},
-                                      scratchpad.get(2))
+        assert scratchpad.get(2)['ret_value'] == 400
 
         res3 = p.apply_async(do_something, args=[30], callback=mycallback)
 
-        self.assertEqual(res3.get(), 900)
+        assert res3.get() == 900
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 900},
-                                      scratchpad.get(3))
+        assert scratchpad.get(3)['ret_value'] == 900
         p.stop()

+ 42 - 52
celery/tests/concurrency/test_prefork.py → t/unit/concurrency/test_prefork.py

@@ -2,18 +2,19 @@ from __future__ import absolute_import, unicode_literals
 
 import errno
 import os
+import pytest
 import socket
 
 from itertools import cycle
 
+from case import Mock, mock, patch, skip
+
 from celery.app.defaults import DEFAULTS
 from celery.five import range
 from celery.utils.collections import AttributeDict
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import asynpool
@@ -53,7 +54,7 @@ class MockResult(object):
         return self.value
 
 
-class test_process_initializer(AppCase):
+class test_process_initializer:
 
     @patch('celery.platforms.signals')
     @patch('celery.platforms.set_mp_process_title')
@@ -78,9 +79,9 @@ class test_process_initializer(AppCase):
                 process_initializer(app, 'awesome.worker.com')
                 _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                 _signals.reset.assert_any_call(*WORKER_SIGRESET)
-                self.assertTrue(app.loader.init_worker.call_count)
+                assert app.loader.init_worker.call_count
                 on_worker_process_init.assert_called()
-                self.assertIs(_tls.current_app, app)
+                assert _tls.current_app is app
                 set_mp_process_title.assert_called_with(
                     'celeryd', hostname='awesome.worker.com',
                 )
@@ -101,7 +102,7 @@ class test_process_initializer(AppCase):
                     os.environ.pop('CELERY_LOG_FILE', None)
 
 
-class test_process_destructor(AppCase):
+class test_process_destructor:
 
     @patch('celery.concurrency.prefork.signals')
     def test_process_destructor(self, signals):
@@ -181,13 +182,9 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
 
 
-@skip.unless_module('multiprocessing')
-class PoolCase(AppCase):
-    pass
-
-
 @skip.if_win32()
-class test_AsynPool(PoolCase):
+@skip.unless_module('multiprocessing')
+class test_AsynPool:
 
     def test_gen_not_started(self):
 
@@ -195,11 +192,11 @@ class test_AsynPool(PoolCase):
             yield 1
             yield 2
         g = gen()
-        self.assertTrue(asynpool.gen_not_started(g))
+        assert asynpool.gen_not_started(g)
         next(g)
-        self.assertFalse(asynpool.gen_not_started(g))
+        assert not asynpool.gen_not_started(g)
         list(g)
-        self.assertFalse(asynpool.gen_not_started(g))
+        assert not asynpool.gen_not_started(g)
 
     @patch('select.select', create=True)
     def test_select(self, __select):
@@ -208,15 +205,11 @@ class test_AsynPool(PoolCase):
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll.return_value = {3}, set(), 0
-            self.assertEqual(
-                asynpool._select({3}, poll=poll),
-                ({3}, set(), 0),
-            )
+            assert asynpool._select({3}, poll=poll) == ({3}, set(), 0)
 
             poll.return_value = {3}, set(), 0
-            self.assertEqual(
-                asynpool._select({3}, None, {3}, poll=poll),
-                ({3}, set(), 0),
+            assert asynpool._select({3}, None, {3}, poll=poll) == (
+                {3}, set(), 0,
             )
 
             eintr = socket.error()
@@ -224,11 +217,8 @@ class test_AsynPool(PoolCase):
             poll.side_effect = eintr
 
             readers = {3}
-            self.assertEqual(
-                asynpool._select(readers, poll=poll),
-                (set(), set(), 1),
-            )
-            self.assertIn(3, readers)
+            assert asynpool._select(readers, poll=poll) == (set(), set(), 1)
+            assert 3 in readers
 
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
@@ -236,16 +226,15 @@ class test_AsynPool(PoolCase):
             with patch('select.select') as selcheck:
                 selcheck.side_effect = ebadf
                 readers = {3}
-                self.assertEqual(
-                    asynpool._select(readers, poll=poll),
-                    (set(), set(), 1),
+                assert asynpool._select(readers, poll=poll) == (
+                    set(), set(), 1,
                 )
-                self.assertNotIn(3, readers)
+                assert 3 not in readers
 
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll.side_effect = MemoryError()
-            with self.assertRaises(MemoryError):
+            with pytest.raises(MemoryError):
                 asynpool._select({1}, poll=poll)
 
         with patch('select.poll', create=True) as poller:
@@ -256,7 +245,7 @@ class test_AsynPool(PoolCase):
                     selcheck.side_effect = MemoryError()
                     raise ebadf
                 poll.side_effect = se
-                with self.assertRaises(MemoryError):
+                with pytest.raises(MemoryError):
                     asynpool._select({3}, poll=poll)
 
         with patch('select.poll', create=True) as poller:
@@ -268,7 +257,7 @@ class test_AsynPool(PoolCase):
                     selcheck.side_effect.errno = 1321
                     raise ebadf
                 poll.side_effect = se2
-                with self.assertRaises(socket.error):
+                with pytest.raises(socket.error):
                     asynpool._select({3}, poll=poll)
 
         with patch('select.poll', create=True) as poller:
@@ -276,14 +265,14 @@ class test_AsynPool(PoolCase):
 
             poll.side_effect = socket.error()
             poll.side_effect.errno = 34134
-            with self.assertRaises(socket.error):
+            with pytest.raises(socket.error):
                 asynpool._select({3}, poll=poll)
 
     def test_promise(self):
         fun = Mock()
         x = asynpool.promise(fun, (1,), {'foo': 1})
         x()
-        self.assertTrue(x.ready)
+        assert x.ready
         fun.assert_called_with(1, foo=1)
 
     def test_Worker(self):
@@ -293,7 +282,8 @@ class test_AsynPool(PoolCase):
 
 
 @skip.if_win32()
-class test_ResultHandler(PoolCase):
+@skip.unless_module('multiprocessing')
+class test_ResultHandler:
 
     def test_process_result(self):
         x = asynpool.ResultHandler(
@@ -303,7 +293,7 @@ class test_ResultHandler(PoolCase):
             on_process_alive=Mock(),
             on_job_ready=Mock(),
         )
-        self.assertTrue(x)
+        assert x
         hub = Mock(name='hub')
         recv = x._recv_message = Mock(name='recv_message')
         recv.return_value = iter([])
@@ -319,25 +309,25 @@ class test_ResultHandler(PoolCase):
         )
 
 
-class test_TaskPool(PoolCase):
+class test_TaskPool:
 
     def test_start(self):
         pool = TaskPool(10)
         pool.start()
-        self.assertTrue(pool._pool.started)
-        self.assertEqual(pool._pool._state, asynpool.RUN)
+        assert pool._pool.started
+        assert pool._pool._state == asynpool.RUN
 
         _pool = pool._pool
         pool.stop()
-        self.assertTrue(_pool.closed)
-        self.assertTrue(_pool.joined)
+        assert _pool.closed
+        assert _pool.joined
         pool.stop()
 
         pool.start()
         _pool = pool._pool
         pool.terminate()
         pool.terminate()
-        self.assertTrue(_pool.terminated)
+        assert _pool.terminated
 
     def test_restart(self):
         pool = TaskPool(10)
@@ -349,7 +339,7 @@ class test_TaskPool(PoolCase):
     def test_did_start_ok(self):
         pool = TaskPool(10)
         pool._pool = Mock(name='pool')
-        self.assertIs(pool.did_start_ok(), pool._pool.did_start_ok())
+        assert pool.did_start_ok() is pool._pool.did_start_ok()
 
     def test_register_with_event_loop(self):
         pool = TaskPool(10)
@@ -380,11 +370,11 @@ class test_TaskPool(PoolCase):
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool.start()
-        self.assertEqual(pool._pool._processes, 10)
+        assert pool._pool._processes == 10
         pool.grow()
-        self.assertEqual(pool._pool._processes, 11)
+        assert pool._pool._processes == 11
         pool.shrink(2)
-        self.assertEqual(pool._pool._processes, 9)
+        assert pool._pool._processes == 9
 
     def test_info(self):
         pool = TaskPool(10)
@@ -400,11 +390,11 @@ class test_TaskPool(PoolCase):
                 return {}
         pool._pool = _Pool()
         info = pool.info
-        self.assertEqual(info['max-concurrency'], pool.limit)
-        self.assertEqual(info['max-tasks-per-child'], 'N/A')
-        self.assertEqual(info['timeouts'], (5, 10))
+        assert info['max-concurrency'] == pool.limit
+        assert info['max-tasks-per-child'] == 'N/A'
+        assert info['timeouts'] == (5, 10)
 
     def test_num_processes(self):
         pool = TaskPool(7)
         pool.start()
-        self.assertEqual(pool.num_processes, 7)
+        assert pool.num_processes == 7

+ 2 - 3
celery/tests/concurrency/test_solo.py → t/unit/concurrency/test_solo.py

@@ -4,10 +4,9 @@ import operator
 
 from celery.concurrency import solo
 from celery.utils.functional import noop
-from celery.tests.case import AppCase
 
 
-class test_solo_TaskPool(AppCase):
+class test_solo_TaskPool:
 
     def test_on_start(self):
         x = solo.TaskPool()
@@ -21,4 +20,4 @@ class test_solo_TaskPool(AppCase):
     def test_info(self):
         x = solo.TaskPool()
         x.on_start()
-        self.assertTrue(x.info)
+        assert x.info

+ 0 - 0
celery/tests/fixups/__init__.py → t/unit/contrib/__init__.py


+ 6 - 9
celery/tests/contrib/test_abortable.py → t/unit/contrib/test_abortable.py

@@ -1,13 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
-from celery.tests.case import AppCase
 
 
-class test_AbortableTask(AppCase):
+class test_AbortableTask:
 
     def setup(self):
-
         @self.app.task(base=AbortableTask, shared=False)
         def abortable():
             return True
@@ -16,16 +14,15 @@ class test_AbortableTask(AppCase):
     def test_async_result_is_abortable(self):
         result = self.abortable.apply_async()
         tid = result.id
-        self.assertIsInstance(
-            self.abortable.AsyncResult(tid), AbortableAsyncResult,
-        )
+        assert isinstance(
+            self.abortable.AsyncResult(tid), AbortableAsyncResult)
 
     def test_is_not_aborted(self):
         self.abortable.push_request()
         try:
             result = self.abortable.apply_async()
             tid = result.id
-            self.assertFalse(self.abortable.is_aborted(task_id=tid))
+            assert not self.abortable.is_aborted(task_id=tid)
         finally:
             self.abortable.pop_request()
 
@@ -34,7 +31,7 @@ class test_AbortableTask(AppCase):
         self.abortable.push_request()
         try:
             self.abortable.request.id = 'foo'
-            self.assertFalse(self.abortable.is_aborted())
+            assert not self.abortable.is_aborted()
         finally:
             self.abortable.pop_request()
 
@@ -44,6 +41,6 @@ class test_AbortableTask(AppCase):
             result = self.abortable.apply_async()
             result.abort()
             tid = result.id
-            self.assertTrue(self.abortable.is_aborted(task_id=tid))
+            assert self.abortable.is_aborted(task_id=tid)
         finally:
             self.abortable.pop_request()

+ 80 - 76
celery/tests/contrib/test_migrate.py → t/unit/contrib/test_migrate.py

@@ -1,8 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 
+import pytest
+
 from contextlib import contextmanager
 
 from amqp import ChannelError
+from case import Mock, mock, patch
 
 from kombu import Connection, Producer, Queue, Exchange
 
@@ -26,7 +29,6 @@ from celery.contrib.migrate import (
     move,
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Mock, mock, patch
 
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
@@ -52,22 +54,22 @@ def Message(body, exchange='exchange', routing_key='rkey',
     )
 
 
-class test_State(AppCase):
+class test_State:
 
     def test_strtotal(self):
         x = State()
-        self.assertEqual(x.strtotal, '?')
+        assert x.strtotal == '?'
         x.total_apx = 100
-        self.assertEqual(x.strtotal, '100')
+        assert x.strtotal == '100'
 
     def test_repr(self):
         x = State()
-        self.assertTrue(repr(x))
+        assert repr(x)
         x.filtered = 'foo'
-        self.assertTrue(repr(x))
+        assert repr(x)
 
 
-class test_move(AppCase):
+class test_move:
 
     @contextmanager
     def move_context(self, **kwargs):
@@ -113,7 +115,7 @@ class test_move(AppCase):
         with self.move_context(limit=1) as (callback, pred, republish):
             pred.return_value = 'foo'
             body, message = self.msgpair()
-            with self.assertRaises(StopFiltering):
+            with pytest.raises(StopFiltering):
                 callback(body, message)
             republish.assert_called()
 
@@ -127,7 +129,7 @@ class test_move(AppCase):
             cb.assert_called()
 
 
-class test_start_filter(AppCase):
+class test_start_filter:
 
     def test_start(self):
         with patch('celery.contrib.migrate.eventloop') as evloop:
@@ -174,11 +176,11 @@ class test_start_filter(AppCase):
                     callback(body, Message(body))
                 except StopFiltering:
                     stop_filtering_raised = True
-            self.assertTrue(state.count)
-            self.assertTrue(stop_filtering_raised)
+            assert state.count
+            assert stop_filtering_raised
 
 
-class test_filter_callback(AppCase):
+class test_filter_callback:
 
     def test_filter(self):
         callback = Mock()
@@ -193,57 +195,59 @@ class test_filter_callback(AppCase):
         callback.assert_called_with(t1, message)
 
 
-class test_utils(AppCase):
+def test_task_id_in():
+    assert task_id_in(['A'], {'id': 'A'}, Mock())
+    assert not task_id_in(['A'], {'id': 'B'}, Mock())
+
+
+def test_task_id_eq():
+    assert task_id_eq('A', {'id': 'A'}, Mock())
+    assert not task_id_eq('A', {'id': 'B'}, Mock())
 
-    def test_task_id_in(self):
-        self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
-        self.assertFalse(task_id_in(['A'], {'id': 'B'}, Mock()))
 
-    def test_task_id_eq(self):
-        self.assertTrue(task_id_eq('A', {'id': 'A'}, Mock()))
-        self.assertFalse(task_id_eq('A', {'id': 'B'}, Mock()))
+def test_expand_dest():
+    assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')
+    assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')
 
-    def test_expand_dest(self):
-        self.assertEqual(expand_dest(None, 'foo', 'bar'), ('foo', 'bar'))
-        self.assertEqual(expand_dest(('b', 'x'), 'foo', 'bar'), ('b', 'x'))
 
-    def test_maybe_queue(self):
-        app = Mock()
-        app.amqp.queues = {'foo': 313}
-        self.assertEqual(_maybe_queue(app, 'foo'), 313)
-        self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
+def test_maybe_queue():
+    app = Mock()
+    app.amqp.queues = {'foo': 313}
+    assert _maybe_queue(app, 'foo') == 313
+    assert _maybe_queue(app, Queue('foo')) == Queue('foo')
 
-    @mock.stdouts
-    def test_filter_status(self, stdout, stderr):
+
+def test_filter_status():
+    with mock.stdouts() as (stdout, stderr):
         filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
-        self.assertTrue(stdout.getvalue())
-
-    def test_move_by_taskmap(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_by_taskmap({'add': Queue('foo')})
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertTrue(cb({'task': 'add'}, Mock()))
-
-    def test_move_by_idmap(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_by_idmap({'123f': Queue('foo')})
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertTrue(cb({'id': '123f'}, Mock()))
-
-    def test_move_task_by_id(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_task_by_id('123f', Queue('foo'))
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertEqual(
-                cb({'id': '123f'}, Mock()),
-                Queue('foo'),
-            )
-
-
-class test_migrate_task(AppCase):
+        assert stdout.getvalue()
+
+
+def test_move_by_taskmap():
+    with patch('celery.contrib.migrate.move') as move:
+        move_by_taskmap({'add': Queue('foo')})
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'task': 'add'}, Mock())
+
+
+def test_move_by_idmap():
+    with patch('celery.contrib.migrate.move') as move:
+        move_by_idmap({'123f': Queue('foo')})
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'id': '123f'}, Mock())
+
+
+def test_move_task_by_id():
+    with patch('celery.contrib.migrate.move') as move:
+        move_task_by_id('123f', Queue('foo'))
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'id': '123f'}, Mock()) == Queue('foo')
+
+
+class test_migrate_task:
 
     def test_removes_compression_header(self):
         x = Message('foo', compression='zlib')
@@ -251,18 +255,18 @@ class test_migrate_task(AppCase):
         migrate_task(producer, x.body, x)
         producer.publish.assert_called()
         args, kwargs = producer.publish.call_args
-        self.assertIsInstance(args[0], bytes_t)
-        self.assertNotIn('compression', kwargs['headers'])
-        self.assertEqual(kwargs['compression'], 'zlib')
-        self.assertEqual(kwargs['content_type'], 'application/json')
-        self.assertEqual(kwargs['content_encoding'], 'utf-8')
-        self.assertEqual(kwargs['exchange'], 'exchange')
-        self.assertEqual(kwargs['routing_key'], 'rkey')
+        assert isinstance(args[0], bytes_t)
+        assert 'compression' not in kwargs['headers']
+        assert kwargs['compression'] == 'zlib'
+        assert kwargs['content_type'] == 'application/json'
+        assert kwargs['content_encoding'] == 'utf-8'
+        assert kwargs['exchange'] == 'exchange'
+        assert kwargs['routing_key'] == 'rkey'
 
 
-class test_migrate_tasks(AppCase):
+class test_migrate_tasks:
 
-    def test_migrate(self, name='testcelery'):
+    def test_migrate(self, app, name='testcelery'):
         x = Connection('memory://foo')
         y = Connection('memory://foo')
         # use separate state
@@ -275,25 +279,25 @@ class test_migrate_tasks(AppCase):
         Producer(x).publish('foo', exchange=name, routing_key=name)
         Producer(x).publish('bar', exchange=name, routing_key=name)
         Producer(x).publish('baz', exchange=name, routing_key=name)
-        self.assertTrue(x.default_channel.queues)
-        self.assertFalse(y.default_channel.queues)
+        assert x.default_channel.queues
+        assert not y.default_channel.queues
 
-        migrate_tasks(x, y, accept=['text/plain'], app=self.app)
+        migrate_tasks(x, y, accept=['text/plain'], app=app)
 
         yq = q(y.default_channel)
-        self.assertEqual(yq.get().body, ensure_bytes('foo'))
-        self.assertEqual(yq.get().body, ensure_bytes('bar'))
-        self.assertEqual(yq.get().body, ensure_bytes('baz'))
+        assert yq.get().body == ensure_bytes('foo')
+        assert yq.get().body == ensure_bytes('bar')
+        assert yq.get().body == ensure_bytes('baz')
 
         Producer(x).publish('foo', exchange=name, routing_key=name)
         callback = Mock()
         migrate_tasks(x, y,
-                      callback=callback, accept=['text/plain'], app=self.app)
+                      callback=callback, accept=['text/plain'], app=app)
         callback.assert_called()
         migrate = Mock()
         Producer(x).publish('baz', exchange=name, routing_key=name)
         migrate_tasks(x, y, callback=callback,
-                      migrate=migrate, accept=['text/plain'], app=self.app)
+                      migrate=migrate, accept=['text/plain'], app=app)
         migrate.assert_called()
 
         with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
@@ -303,12 +307,12 @@ class test_migrate_tasks(AppCase):
                     raise ChannelError('some channel error')
                 return 0, 3, 0
             qd.side_effect = effect
-            migrate_tasks(x, y, app=self.app)
+            migrate_tasks(x, y, app=app)
 
         x = Connection('memory://')
         x.default_channel.queues = {}
         y.default_channel.queues = {}
         callback = Mock()
         migrate_tasks(x, y,
-                      callback=callback, accept=['text/plain'], app=self.app)
+                      callback=callback, accept=['text/plain'], app=app)
         callback.assert_not_called()

+ 12 - 10
celery/tests/contrib/test_rdb.py → t/unit/contrib/test_rdb.py

@@ -2,6 +2,9 @@ from __future__ import absolute_import, unicode_literals
 
 import errno
 import socket
+import pytest
+
+from case import Mock, patch, skip
 
 from celery.contrib.rdb import (
     Rdb,
@@ -9,26 +12,25 @@ from celery.contrib.rdb import (
     set_trace,
 )
 from celery.five import WhateverIO
-from celery.tests.case import AppCase, Mock, patch, skip
 
 
 class SockErr(socket.error):
     errno = None
 
 
-class test_Rdb(AppCase):
+class test_Rdb:
 
     @patch('celery.contrib.rdb.Rdb')
     def test_debugger(self, Rdb):
         x = debugger()
-        self.assertTrue(x)
-        self.assertIs(x, debugger())
+        assert x
+        assert x is debugger()
 
     @patch('celery.contrib.rdb.debugger')
     @patch('celery.contrib.rdb._frame')
     def test_set_trace(self, _frame, debugger):
-        self.assertTrue(set_trace(Mock()))
-        self.assertTrue(set_trace())
+        assert set_trace(Mock())
+        assert set_trace()
         debugger.return_value.set_trace.assert_called()
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
@@ -40,7 +42,7 @@ class test_Rdb(AppCase):
         out = WhateverIO()
         with Rdb(out=out) as rdb:
             get_avail_port.assert_called()
-            self.assertIn('helu', out.getvalue())
+            assert 'helu' in out.getvalue()
 
             # set_quit
             with patch('sys.settrace') as settrace:
@@ -54,7 +56,7 @@ class test_Rdb(AppCase):
                     rdb.set_trace(Mock())
                     pset.side_effect = SockErr
                     pset.side_effect.errno = errno.ENOENT
-                    with self.assertRaises(SockErr):
+                    with pytest.raises(SockErr):
                         rdb.set_trace()
 
             # _close_session
@@ -90,11 +92,11 @@ class test_Rdb(AppCase):
 
         err = sock.return_value.bind.side_effect = SockErr()
         err.errno = errno.ENOENT
-        with self.assertRaises(SockErr):
+        with pytest.raises(SockErr):
             with Rdb(out=out):
                 pass
         err.errno = errno.EADDRINUSE
-        with self.assertRaises(Exception):
+        with pytest.raises(Exception):
             with Rdb(out=out):
                 pass
         called = [0]

+ 0 - 0
celery/tests/tasks/__init__.py → t/unit/events/__init__.py


Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini