Browse Source

Use py.test for everything :-)

Ask Solem 9 years ago
parent
commit
29df527147
100 changed files with 3420 additions and 5078 deletions
  1. 4 0
      .gitignore
  2. 12 10
      CONTRIBUTING.rst
  3. 5 4
      Makefile
  4. 2 2
      celery/app/trace.py
  5. 0 96
      celery/tests/__init__.py
  6. 0 301
      celery/tests/app/test_amqp.py
  7. 0 19
      celery/tests/app/test_celery.py
  8. 0 65
      celery/tests/app/test_defaults.py
  9. 0 78
      celery/tests/app/test_registry.py
  10. 0 41
      celery/tests/backends/test_backends.py
  11. 0 349
      celery/tests/case.py
  12. 0 38
      celery/tests/compat_modules/test_decorators.py
  13. 0 129
      celery/tests/concurrency/test_gevent.py
  14. 0 54
      celery/tests/tasks/test_states.py
  15. 0 98
      celery/tests/utils/test_debug.py
  16. 0 20
      celery/tests/utils/test_encoding.py
  17. 0 210
      celery/tests/utils/test_functional.py
  18. 0 58
      celery/tests/utils/test_imports.py
  19. 0 372
      celery/tests/utils/test_local.py
  20. 0 77
      celery/tests/utils/test_serialization.py
  21. 0 27
      celery/tests/utils/test_sysinfo.py
  22. 0 87
      celery/tests/utils/test_term.py
  23. 0 94
      celery/tests/utils/test_text.py
  24. 0 253
      celery/tests/utils/test_timeutils.py
  25. 0 44
      celery/tests/utils/test_utils.py
  26. 7 11
      docs/contributing.rst
  27. 2 2
      docs/internals/guide.rst
  28. 1 1
      docs/whatsnew-4.0.rst
  29. 5 2
      funtests/suite/test_leak.py
  30. 1 0
      requirements/test-ci-base.txt
  31. 2 1
      requirements/test.txt
  32. 3 2
      setup.cfg
  33. 39 35
      setup.py
  34. 0 0
      t/__init__.py
  35. 418 0
      t/conftest.py
  36. 0 0
      t/unit/__init__.py
  37. 0 0
      t/unit/app/__init__.py
  38. 269 0
      t/unit/app/test_amqp.py
  39. 12 14
      t/unit/app/test_annotations.py
  40. 222 246
      t/unit/app/test_app.py
  41. 79 80
      t/unit/app/test_beat.py
  42. 19 17
      t/unit/app/test_builtins.py
  43. 18 0
      t/unit/app/test_celery.py
  44. 50 47
      t/unit/app/test_control.py
  45. 65 0
      t/unit/app/test_defaults.py
  46. 8 10
      t/unit/app/test_exceptions.py
  47. 32 36
      t/unit/app/test_loaders.py
  48. 45 60
      t/unit/app/test_log.py
  49. 72 0
      t/unit/app/test_registry.py
  50. 30 37
      t/unit/app/test_routes.py
  51. 276 328
      t/unit/app/test_schedules.py
  52. 16 20
      t/unit/app/test_utils.py
  53. 0 0
      t/unit/apps/__init__.py
  54. 96 124
      t/unit/apps/test_multi.py
  55. 0 0
      t/unit/backends/__init__.py
  56. 40 47
      t/unit/backends/test_amqp.py
  57. 41 0
      t/unit/backends/test_backends.py
  58. 106 111
      t/unit/backends/test_base.py
  59. 35 37
      t/unit/backends/test_cache.py
  60. 18 15
      t/unit/backends/test_cassandra.py
  61. 6 5
      t/unit/backends/test_consul.py
  62. 24 22
      t/unit/backends/test_couchbase.py
  63. 20 20
      t/unit/backends/test_couchdb.py
  64. 47 49
      t/unit/backends/test_database.py
  65. 16 14
      t/unit/backends/test_elasticsearch.py
  66. 13 16
      t/unit/backends/test_filesystem.py
  67. 80 92
      t/unit/backends/test_mongodb.py
  68. 50 61
      t/unit/backends/test_redis.py
  69. 21 20
      t/unit/backends/test_riak.py
  70. 20 22
      t/unit/backends/test_rpc.py
  71. 0 0
      t/unit/bin/__init__.py
  72. 0 0
      t/unit/bin/celery.py
  73. 0 0
      t/unit/bin/proj/__init__.py
  74. 0 0
      t/unit/bin/proj/app.py
  75. 29 32
      t/unit/bin/test_amqp.py
  76. 130 143
      t/unit/bin/test_base.py
  77. 20 20
      t/unit/bin/test_beat.py
  78. 104 116
      t/unit/bin/test_celery.py
  79. 27 21
      t/unit/bin/test_celeryd_detach.py
  80. 7 7
      t/unit/bin/test_celeryevdump.py
  81. 35 14
      t/unit/bin/test_events.py
  82. 63 84
      t/unit/bin/test_multi.py
  83. 235 239
      t/unit/bin/test_worker.py
  84. 0 0
      t/unit/compat_modules/__init__.py
  85. 16 14
      t/unit/compat_modules/test_compat.py
  86. 13 14
      t/unit/compat_modules/test_compat_utils.py
  87. 37 0
      t/unit/compat_modules/test_decorators.py
  88. 4 3
      t/unit/compat_modules/test_messaging.py
  89. 0 0
      t/unit/concurrency/__init__.py
  90. 32 33
      t/unit/concurrency/test_concurrency.py
  91. 36 38
      t/unit/concurrency/test_eventlet.py
  92. 128 0
      t/unit/concurrency/test_gevent.py
  93. 15 20
      t/unit/concurrency/test_pool.py
  94. 42 52
      t/unit/concurrency/test_prefork.py
  95. 2 3
      t/unit/concurrency/test_solo.py
  96. 0 0
      t/unit/contrib/__init__.py
  97. 6 9
      t/unit/contrib/test_abortable.py
  98. 80 76
      t/unit/contrib/test_migrate.py
  99. 12 10
      t/unit/contrib/test_rdb.py
  100. 0 0
      t/unit/events/__init__.py

+ 4 - 0
.gitignore

@@ -25,3 +25,7 @@ celery/tests/cover/
 .ve*
 .ve*
 cover/
 cover/
 .vagrant/
 .vagrant/
+.cache/
+htmlcov/
+coverage.xml
+test.db

+ 12 - 10
CONTRIBUTING.rst

@@ -464,12 +464,12 @@ dependencies, so install these next:
     $ pip install -U -r requirements/default.txt
     $ pip install -U -r requirements/default.txt
 
 
 After installing the dependencies required, you can now execute
 After installing the dependencies required, you can now execute
-the test suite by calling ``nosetests <nose>``:
+the test suite by calling ``py.test <pytest``:
 ::
 ::
 
 
-    $ nosetests
+    $ py.test
 
 
-Some useful options to ``nosetests`` are:
+Some useful options to ``py.test`` are:
 
 
 * ``-x``
 * ``-x``
 
 
@@ -479,10 +479,6 @@ Some useful options to ``nosetests`` are:
 
 
     Don't capture output
     Don't capture output
 
 
-* ``-nologcapture``
-
-    Don't capture log output.
-
 * ``-v``
 * ``-v``
 
 
     Run with verbose output.
     Run with verbose output.
@@ -491,7 +487,7 @@ If you want to run the tests for a single test file only
 you can do so like this:
 you can do so like this:
 ::
 ::
 
 
-    $ nosetests celery.tests.test_worker.test_worker_job
+    $ py.test t/unit/worker/test_worker_job.py
 
 
 .. _contributing-pull-requests:
 .. _contributing-pull-requests:
 
 
@@ -525,7 +521,7 @@ Installing the ``coverage`` module:
 Code coverage in HTML:
 Code coverage in HTML:
 ::
 ::
 
 
-    $ nosetests --with-coverage --cover-html
+    $ py.test --cov=celery --cov-report=html
 
 
 The coverage output will then be located at
 The coverage output will then be located at
 ``celery/tests/cover/index.html``.
 ``celery/tests/cover/index.html``.
@@ -533,7 +529,7 @@ The coverage output will then be located at
 Code coverage in XML (Cobertura-style):
 Code coverage in XML (Cobertura-style):
 ::
 ::
 
 
-    $ nosetests --with-coverage --cover-xml --cover-xml-file=coverage.xml
+    $ py.test --cov=celery --cov-report=xml
 
 
 The coverage XML output will then be located at ``coverage.xml``
 The coverage XML output will then be located at ``coverage.xml``
 
 
@@ -857,6 +853,12 @@ Ask Solem
 :github: https://github.com/ask
 :github: https://github.com/ask
 :twitter: http://twitter.com/#!/asksol
 :twitter: http://twitter.com/#!/asksol
 
 
+Asif Saif Uddin
+~~~~~~~~~~~~~~~
+
+:github: https://github.com/auvipy
+:twitter: https://twitter.com/#!/auvipy
+
 Dmitry Malinovsky
 Dmitry Malinovsky
 ~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~
 
 

+ 5 - 4
Makefile

@@ -9,6 +9,8 @@ FLAKE8=flake8
 FLAKEPLUS=flakeplus
 FLAKEPLUS=flakeplus
 SPHINX2RST=sphinx2rst
 SPHINX2RST=sphinx2rst
 
 
+TESTDIR=t
+
 SPHINX_DIR=docs/
 SPHINX_DIR=docs/
 SPHINX_BUILDDIR="${SPHINX_DIR}/_build"
 SPHINX_BUILDDIR="${SPHINX_DIR}/_build"
 README=README.rst
 README=README.rst
@@ -84,13 +86,13 @@ configcheck:
 
 
 flakecheck:
 flakecheck:
 	# the only way to enable all-1 errors is to ignore one of them.
 	# the only way to enable all-1 errors is to ignore one of them.
-	$(FLAKE8) --ignore=X999 "$(PROJ)"
+	$(FLAKE8) --ignore=X999 "$(PROJ)" "$(TESTDIR)"
 
 
 flakediag:
 flakediag:
 	-$(MAKE) flakecheck
 	-$(MAKE) flakecheck
 
 
 flakepluscheck:
 flakepluscheck:
-	$(FLAKEPLUS) --$(FLAKEPLUSTARGET) "$(PROJ)"
+	$(FLAKEPLUS) --$(FLAKEPLUSTARGET) "$(PROJ)" "$(TESTDIR)"
 
 
 flakeplusdiag:
 flakeplusdiag:
 	-$(MAKE) flakepluscheck
 	-$(MAKE) flakepluscheck
@@ -138,7 +140,7 @@ test:
 	$(PYTHON) setup.py test
 	$(PYTHON) setup.py test
 
 
 cov:
 cov:
-	$(NOSETESTS) -xv --with-coverage --cover-html --cover-branch
+	py.test -x --cov="$(PROJ)" --cov-report=html
 
 
 build:
 build:
 	$(PYTHON) setup.py sdist bdist_wheel
 	$(PYTHON) setup.py sdist bdist_wheel
@@ -158,4 +160,3 @@ graph: clean-graph $(WORKER_GRAPH)
 
 
 authorcheck:
 authorcheck:
 	git shortlog -se | cut -f2 | extra/release/attribution.py
 	git shortlog -se | cut -f2 | extra/release/attribution.py
-

+ 2 - 2
celery/app/trace.py

@@ -418,13 +418,13 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
                     except EncodeError as exc:
                     except EncodeError as exc:
                         I, R, state, retval = on_error(task_request, exc, uuid)
                         I, R, state, retval = on_error(task_request, exc, uuid)
                     else:
                     else:
+                        Rstr = saferepr(R, resultrepr_maxsize)
+                        T = monotonic() - time_start
                         if task_on_success:
                         if task_on_success:
                             task_on_success(retval, uuid, args, kwargs)
                             task_on_success(retval, uuid, args, kwargs)
                         if success_receivers:
                         if success_receivers:
                             send_success(sender=task, result=retval)
                             send_success(sender=task, result=retval)
                         if _does_info:
                         if _does_info:
-                            T = monotonic() - time_start
-                            Rstr = saferepr(R, resultrepr_maxsize)
                             info(LOG_SUCCESS, {
                             info(LOG_SUCCESS, {
                                 'id': uuid, 'name': name,
                                 'id': uuid, 'name': name,
                                 'return_value': Rstr, 'runtime': T,
                                 'return_value': Rstr, 'runtime': T,

+ 0 - 96
celery/tests/__init__.py

@@ -1,96 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import logging
-import os
-import sys
-import warnings
-
-from importlib import import_module
-
-PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
-
-try:
-    WindowsError = WindowsError  # noqa
-except NameError:
-
-    class WindowsError(Exception):
-        pass
-
-
-def setup():
-    using_coverage = (
-        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
-    )
-    os.environ.update(
-        # warn if config module not found
-        C_WNOCONF='yes',
-        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
-    )
-
-    if using_coverage and not PYPY3:
-        from warnings import catch_warnings
-        with catch_warnings(record=True):
-            import_all_modules()
-        warnings.resetwarnings()
-    from celery.tests.case import Trap
-    from celery._state import set_default_app
-    set_default_app(Trap())
-
-
-def teardown():
-    # Don't want SUBDEBUG log messages at finalization.
-    try:
-        from multiprocessing.util import get_logger
-    except ImportError:
-        pass
-    else:
-        get_logger().setLevel(logging.WARNING)
-
-    # Make sure test database is removed.
-    import os
-    if os.path.exists('test.db'):
-        try:
-            os.remove('test.db')
-        except WindowsError:
-            pass
-
-    # Make sure there are no remaining threads at shutdown.
-    import threading
-    remaining_threads = [thread for thread in threading.enumerate()
-                         if thread.getName() != 'MainThread']
-    if remaining_threads:
-        sys.stderr.write(
-            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
-                remaining_threads))
-
-
-def find_distribution_modules(name=__name__, file=__file__):
-    current_dist_depth = len(name.split('.')) - 1
-    current_dist = os.path.join(os.path.dirname(file),
-                                *([os.pardir] * current_dist_depth))
-    abs = os.path.abspath(current_dist)
-    dist_name = os.path.basename(abs)
-
-    for dirpath, dirnames, filenames in os.walk(abs):
-        package = (dist_name + dirpath[len(abs):]).replace('/', '.')
-        if '__init__.py' in filenames:
-            yield package
-            for filename in filenames:
-                if filename.endswith('.py') and filename != '__init__.py':
-                    yield '.'.join([package, filename])[:-3]
-
-
-def import_all_modules(name=__name__, file=__file__,
-                       skip=('celery.decorators',
-                             'celery.task')):
-    for module in find_distribution_modules(name, file):
-        if not module.startswith(skip):
-            try:
-                import_module(module)
-            except ImportError:
-                pass
-            except OSError as exc:
-                warnings.warn(UserWarning(
-                    'Ignored error importing module {0}: {1!r}'.format(
-                        module, exc,
-                    )))

+ 0 - 301
celery/tests/app/test_amqp.py

@@ -1,301 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from datetime import datetime, timedelta
-
-from kombu import Exchange, Queue
-
-from celery import uuid
-from celery.app.amqp import Queues, utf8dict
-from celery.five import keys
-from celery.utils.time import to_utc
-
-from celery.tests.case import AppCase, Mock
-
-
-class test_TaskConsumer(AppCase):
-
-    def test_accept_content(self):
-        with self.app.pool.acquire(block=True) as conn:
-            self.app.conf.accept_content = ['application/json']
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn).accept,
-                {'application/json'},
-            )
-            self.assertEqual(
-                self.app.amqp.TaskConsumer(conn, accept=['json']).accept,
-                {'application/json'},
-            )
-
-
-class test_ProducerPool(AppCase):
-
-    def test_setup_nolimit(self):
-        self.app.conf.broker_pool_limit = None
-        try:
-            delattr(self.app, '_pool')
-        except AttributeError:
-            pass
-        self.app.amqp._producer_pool = None
-        pool = self.app.amqp.producer_pool
-        self.assertEqual(pool.limit, self.app.pool.limit)
-        self.assertFalse(pool._resource.queue)
-
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-        r1.release()
-        r2.release()
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-
-    def test_setup(self):
-        self.app.conf.broker_pool_limit = 2
-        try:
-            delattr(self.app, '_pool')
-        except AttributeError:
-            pass
-        self.app.amqp._producer_pool = None
-        pool = self.app.amqp.producer_pool
-        self.assertEqual(pool.limit, self.app.pool.limit)
-        self.assertTrue(pool._resource.queue)
-
-        p1 = r1 = pool.acquire()
-        p2 = r2 = pool.acquire()
-        r1.release()
-        r2.release()
-        r1 = pool.acquire()
-        r2 = pool.acquire()
-        self.assertIs(p2, r1)
-        self.assertIs(p1, r2)
-        r1.release()
-        r2.release()
-
-
-class test_Queues(AppCase):
-
-    def test_queues_format(self):
-        self.app.amqp.queues._consume_from = {}
-        self.assertEqual(self.app.amqp.queues.format(), '')
-
-    def test_with_defaults(self):
-        self.assertEqual(Queues(None), {})
-
-    def test_add(self):
-        q = Queues()
-        q.add('foo', exchange='ex', routing_key='rk')
-        self.assertIn('foo', q)
-        self.assertIsInstance(q['foo'], Queue)
-        self.assertEqual(q['foo'].routing_key, 'rk')
-
-    def test_with_ha_policy(self):
-        qn = Queues(ha_policy=None, create_missing=False)
-        qn.add('xyz')
-        self.assertIsNone(qn['xyz'].queue_arguments)
-
-        qn.add('xyx', queue_arguments={'x-foo': 'bar'})
-        self.assertEqual(qn['xyx'].queue_arguments, {'x-foo': 'bar'})
-
-        q = Queues(ha_policy='all', create_missing=False)
-        q.add(Queue('foo'))
-        self.assertEqual(q['foo'].queue_arguments, {'x-ha-policy': 'all'})
-
-        qq = Queue('xyx2', queue_arguments={'x-foo': 'bari'})
-        q.add(qq)
-        self.assertEqual(q['xyx2'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-foo': 'bari',
-        })
-
-        q2 = Queues(ha_policy=['A', 'B', 'C'], create_missing=False)
-        q2.add(Queue('foo'))
-        self.assertEqual(q2['foo'].queue_arguments, {
-            'x-ha-policy': 'nodes',
-            'x-ha-policy-params': ['A', 'B', 'C'],
-        })
-
-    def test_select_add(self):
-        q = Queues()
-        q.select(['foo', 'bar'])
-        q.select_add('baz')
-        self.assertItemsEqual(keys(q._consume_from), ['foo', 'bar', 'baz'])
-
-    def test_deselect(self):
-        q = Queues()
-        q.select(['foo', 'bar'])
-        q.deselect('bar')
-        self.assertItemsEqual(keys(q._consume_from), ['foo'])
-
-    def test_with_ha_policy_compat(self):
-        q = Queues(ha_policy='all')
-        q.add('bar')
-        self.assertEqual(q['bar'].queue_arguments, {'x-ha-policy': 'all'})
-
-    def test_add_default_exchange(self):
-        ex = Exchange('fff', 'fanout')
-        q = Queues(default_exchange=ex)
-        q.add(Queue('foo'))
-        self.assertEqual(q['foo'].exchange.name, '')
-
-    def test_alias(self):
-        q = Queues()
-        q.add(Queue('foo', alias='barfoo'))
-        self.assertIs(q['barfoo'], q['foo'])
-
-    def test_with_max_priority(self):
-        qs1 = Queues(max_priority=10)
-        qs1.add('foo')
-        self.assertEqual(qs1['foo'].queue_arguments, {'x-max-priority': 10})
-
-        q1 = Queue('xyx', queue_arguments={'x-max-priority': 3})
-        qs1.add(q1)
-        self.assertEqual(qs1['xyx'].queue_arguments, {
-            'x-max-priority': 3,
-        })
-
-        q1 = Queue('moo', queue_arguments=None)
-        qs1.add(q1)
-        self.assertEqual(qs1['moo'].queue_arguments, {
-            'x-max-priority': 10,
-        })
-
-        qs2 = Queues(ha_policy='all', max_priority=5)
-        qs2.add('bar')
-        self.assertEqual(qs2['bar'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-max-priority': 5
-        })
-
-        q2 = Queue('xyx2', queue_arguments={'x-max-priority': 2})
-        qs2.add(q2)
-        self.assertEqual(qs2['xyx2'].queue_arguments, {
-            'x-ha-policy': 'all',
-            'x-max-priority': 2,
-        })
-
-        qs3 = Queues(max_priority=None)
-        qs3.add('foo2')
-        self.assertEqual(qs3['foo2'].queue_arguments, None)
-
-        q3 = Queue('xyx3', queue_arguments={'x-max-priority': 7})
-        qs3.add(q3)
-        self.assertEqual(qs3['xyx3'].queue_arguments, {
-            'x-max-priority': 7,
-        })
-
-
-class test_AMQP(AppCase):
-
-    def setup(self):
-        self.simple_message = self.app.amqp.as_task_v2(
-            uuid(), 'foo', create_sent_event=True,
-        )
-
-    def test_Queues__with_ha_policy(self):
-        x = self.app.amqp.Queues({}, ha_policy='all')
-        self.assertEqual(x.ha_policy, 'all')
-
-    def test_Queues__with_max_priority(self):
-        x = self.app.amqp.Queues({}, max_priority=23)
-        self.assertEqual(x.max_priority, 23)
-
-    def test_send_task_message__no_kwargs(self):
-        self.app.amqp.send_task_message(Mock(), 'foo', self.simple_message)
-
-    def test_send_task_message__properties(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, foo=1, retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['foo'], 1)
-
-    def test_send_task_message__headers(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, headers={'x1x': 'y2x'},
-            retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['headers']['x1x'], 'y2x')
-
-    def test_send_task_message__queue_string(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, queue='foo', retry=False,
-        )
-        kwargs = prod.publish.call_args[1]
-        self.assertEqual(kwargs['routing_key'], 'foo')
-        self.assertEqual(kwargs['exchange'], '')
-
-    def test_send_event_exchange_string(self):
-        evd = Mock(name='evd')
-        self.app.amqp.send_task_message(
-            Mock(), 'foo', self.simple_message, retry=False,
-            exchange='xyz', routing_key='xyb',
-            event_dispatcher=evd,
-        )
-        evd.publish.assert_called()
-        event = evd.publish.call_args[0][1]
-        self.assertEqual(event['routing_key'], 'xyb')
-        self.assertEqual(event['exchange'], 'xyz')
-
-    def test_send_task_message__with_delivery_mode(self):
-        prod = Mock(name='producer')
-        self.app.amqp.send_task_message(
-            prod, 'foo', self.simple_message, delivery_mode=33, retry=False,
-        )
-        self.assertEqual(prod.publish.call_args[1]['delivery_mode'], 33)
-
-    def test_routes(self):
-        r1 = self.app.amqp.routes
-        r2 = self.app.amqp.routes
-        self.assertIs(r1, r2)
-
-
-class test_as_task_v2(AppCase):
-
-    def test_raises_if_args_is_not_tuple(self):
-        with self.assertRaises(TypeError):
-            self.app.amqp.as_task_v2(uuid(), 'foo', args='123')
-
-    def test_raises_if_kwargs_is_not_mapping(self):
-        with self.assertRaises(TypeError):
-            self.app.amqp.as_task_v2(uuid(), 'foo', kwargs=(1, 2, 3))
-
-    def test_countdown_to_eta(self):
-        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo', countdown=10, now=now,
-        )
-        self.assertEqual(
-            m.headers['eta'],
-            (now + timedelta(seconds=10)).isoformat(),
-        )
-
-    def test_expires_to_datetime(self):
-        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo', expires=30, now=now,
-        )
-        self.assertEqual(
-            m.headers['expires'],
-            (now + timedelta(seconds=30)).isoformat(),
-        )
-
-    def test_callbacks_errbacks_chord(self):
-
-        @self.app.task
-        def t(i):
-            pass
-
-        m = self.app.amqp.as_task_v2(
-            uuid(), 'foo',
-            callbacks=[t.s(1), t.s(2)],
-            errbacks=[t.s(3), t.s(4)],
-            chord=t.s(5),
-        )
-        _, _, embed = m.body
-        self.assertListEqual(
-            embed['callbacks'], [utf8dict(t.s(1)), utf8dict(t.s(2))],
-        )
-        self.assertListEqual(
-            embed['errbacks'], [utf8dict(t.s(3)), utf8dict(t.s(4))],
-        )
-        self.assertEqual(embed['chord'], utf8dict(t.s(5)))

+ 0 - 19
celery/tests/app/test_celery.py

@@ -1,19 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.tests.case import AppCase
-
-import celery
-
-
-class test_celery_package(AppCase):
-
-    def test_version(self):
-        self.assertTrue(celery.VERSION)
-        self.assertGreaterEqual(len(celery.VERSION), 3)
-        celery.VERSION = (0, 3, 0)
-        self.assertGreaterEqual(celery.__version__.count('.'), 2)
-
-    def test_meta(self):
-        for m in ('__author__', '__contact__', '__homepage__',
-                  '__docformat__'):
-            self.assertTrue(getattr(celery, m, None))

+ 0 - 65
celery/tests/app/test_defaults.py

@@ -1,65 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from importlib import import_module
-
-from celery.app.defaults import (
-    _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY,
-    DEFAULTS, NAMESPACES, SETTING_KEYS
-)
-from celery.five import values
-
-from celery.tests.case import AppCase, mock
-
-
-class test_defaults(AppCase):
-
-    def setup(self):
-        self._prev = sys.modules.pop('celery.app.defaults', None)
-
-    def teardown(self):
-        if self._prev:
-            sys.modules['celery.app.defaults'] = self._prev
-
-    def test_option_repr(self):
-        self.assertTrue(repr(NAMESPACES['broker']['url']))
-
-    def test_any(self):
-        val = object()
-        self.assertIs(self.defaults.Option.typemap['any'](val), val)
-
-    @mock.sys_platform('darwin')
-    @mock.pypy_version((1, 4, 0))
-    def test_default_pool_pypy_14(self):
-        self.assertEqual(self.defaults.DEFAULT_POOL, 'solo')
-
-    @mock.sys_platform('darwin')
-    @mock.pypy_version((1, 5, 0))
-    def test_default_pool_pypy_15(self):
-        self.assertEqual(self.defaults.DEFAULT_POOL, 'prefork')
-
-    def test_compat_indices(self):
-        self.assertFalse(any(key.isupper() for key in DEFAULTS))
-        self.assertFalse(any(key.islower() for key in _OLD_DEFAULTS))
-        self.assertFalse(any(key.isupper() for key in _TO_OLD_KEY))
-        self.assertFalse(any(key.islower() for key in _TO_NEW_KEY))
-        self.assertFalse(any(key.isupper() for key in SETTING_KEYS))
-        self.assertFalse(any(key.islower() for key in _OLD_SETTING_KEYS))
-        self.assertFalse(any(value.isupper() for value in values(_TO_NEW_KEY)))
-        self.assertFalse(any(value.islower() for value in values(_TO_OLD_KEY)))
-
-        for key in _TO_NEW_KEY:
-            self.assertIn(key, _OLD_SETTING_KEYS)
-        for key in _TO_OLD_KEY:
-            self.assertIn(key, SETTING_KEYS)
-
-    def test_find(self):
-        find = self.defaults.find
-
-        self.assertEqual(find('default_queue')[2].default, 'celery')
-        self.assertEqual(find('task_default_exchange')[2], 'celery')
-
-    @property
-    def defaults(self):
-        return import_module('celery.app.defaults')

+ 0 - 78
celery/tests/app/test_registry.py

@@ -1,78 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.app.registry import _unpickle_task, _unpickle_task_v2
-from celery.tests.case import AppCase, depends_on_current_app
-
-
-def returns():
-    return 1
-
-
-class test_unpickle_task(AppCase):
-
-    @depends_on_current_app
-    def test_unpickle_v1(self):
-        self.app.tasks['txfoo'] = 'bar'
-        self.assertEqual(_unpickle_task('txfoo'), 'bar')
-
-    @depends_on_current_app
-    def test_unpickle_v2(self):
-        self.app.tasks['txfoo1'] = 'bar1'
-        self.assertEqual(_unpickle_task_v2('txfoo1'), 'bar1')
-        self.assertEqual(_unpickle_task_v2('txfoo1', module='celery'), 'bar1')
-
-
-class test_TaskRegistry(AppCase):
-
-    def setup(self):
-        self.mytask = self.app.task(name='A', shared=False)(returns)
-        self.myperiodic = self.app.task(
-            name='B', shared=False, type='periodic',
-        )(returns)
-
-    def test_NotRegistered_str(self):
-        self.assertTrue(repr(self.app.tasks.NotRegistered('tasks.add')))
-
-    def assertRegisterUnregisterCls(self, r, task):
-        r.unregister(task)
-        with self.assertRaises(r.NotRegistered):
-            r.unregister(task)
-        r.register(task)
-        self.assertIn(task.name, r)
-
-    def assertRegisterUnregisterFunc(self, r, task, task_name):
-        with self.assertRaises(r.NotRegistered):
-            r.unregister(task_name)
-        r.register(task, task_name)
-        self.assertIn(task_name, r)
-
-    def test_task_registry(self):
-        r = self.app._tasks
-        self.assertIsInstance(r, dict, 'TaskRegistry is mapping')
-
-        self.assertRegisterUnregisterCls(r, self.mytask)
-        self.assertRegisterUnregisterCls(r, self.myperiodic)
-
-        r.register(self.myperiodic)
-        r.unregister(self.myperiodic.name)
-        self.assertNotIn(self.myperiodic, r)
-        r.register(self.myperiodic)
-
-        tasks = dict(r)
-        self.assertIs(tasks.get(self.mytask.name), self.mytask)
-        self.assertIs(tasks.get(self.myperiodic.name), self.myperiodic)
-
-        self.assertIs(r[self.mytask.name], self.mytask)
-        self.assertIs(r[self.myperiodic.name], self.myperiodic)
-
-        r.unregister(self.mytask)
-        self.assertNotIn(self.mytask.name, r)
-        r.unregister(self.myperiodic)
-        self.assertNotIn(self.myperiodic.name, r)
-
-        self.assertTrue(self.mytask.run())
-        self.assertTrue(self.myperiodic.run())
-
-    def test_compat(self):
-        self.assertTrue(self.app.tasks.regular())
-        self.assertTrue(self.app.tasks.periodic())

+ 0 - 41
celery/tests/backends/test_backends.py

@@ -1,41 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery import backends
-from celery.backends.amqp import AMQPBackend
-from celery.backends.cache import CacheBackend
-from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, depends_on_current_app, patch
-
-
-class test_backends(AppCase):
-
-    def test_get_backend_aliases(self):
-        expects = [('amqp://', AMQPBackend),
-                   ('cache+memory://', CacheBackend)]
-
-        for url, expect_cls in expects:
-            backend, url = backends.get_backend_by_url(url, self.app.loader)
-            self.assertIsInstance(
-                backend(app=self.app, url=url),
-                expect_cls,
-            )
-
-    def test_unknown_backend(self):
-        with self.assertRaises(ImportError):
-            backends.get_backend_cls('fasodaopjeqijwqe', self.app.loader)
-
-    @depends_on_current_app
-    def test_default_backend(self):
-        self.assertEqual(backends.default_backend, self.app.backend)
-
-    def test_backend_by_url(self, url='redis://localhost/1'):
-        from celery.backends.redis import RedisBackend
-        backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, RedisBackend)
-        self.assertEqual(url_, url)
-
-    def test_sym_raises_ValuError(self):
-        with patch('celery.backends.symbol_by_name') as sbn:
-            sbn.side_effect = ValueError()
-            with self.assertRaises(ImproperlyConfigured):
-                backends.get_backend_cls('xxx.xxx:foo', self.app.loader)

+ 0 - 349
celery/tests/case.py

@@ -1,349 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import importlib
-import inspect
-import logging
-import numbers
-import os
-import sys
-import threading
-
-from copy import deepcopy
-from datetime import datetime, timedelta
-from functools import partial
-
-from kombu import Queue
-from kombu.utils.imports import symbol_by_name
-from vine.utils import wraps
-
-from celery import Celery
-from celery.app import current_app
-from celery.backends.cache import CacheBackend, DummyClient
-from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning
-from celery.utils.imports import qualname
-
-from case import (
-    ANY, ContextMock, MagicMock, Mock, call, mock, skip, patch, sentinel,
-)
-from case import Case as _Case
-from case.utils import decorator
-
-__all__ = [
-    'ANY', 'ContextMock', 'MagicMock', 'Mock',
-    'call', 'mock', 'skip', 'patch', 'sentinel',
-
-    'AppCase', 'TaskMessage', 'TaskMessage1',
-    'depends_on_current_app', 'assert_signal_called', 'task_message_from_sig',
-]
-
-CASE_REDEFINES_SETUP = """\
-{name} (subclass of AppCase) redefines private "setUp", should be: "setup"\
-"""
-CASE_REDEFINES_TEARDOWN = """\
-{name} (subclass of AppCase) redefines private "tearDown", \
-should be: "teardown"\
-"""
-CASE_LOG_REDIRECT_EFFECT = """\
-Test {0} didn't disable LoggingProxy for {1}\
-"""
-CASE_LOG_LEVEL_EFFECT = """\
-Test {0} Modified the level of the root logger\
-"""
-CASE_LOG_HANDLER_EFFECT = """\
-Test {0} Modified handlers for the root logger\
-"""
-
-CELERY_TEST_CONFIG = {
-    #: Don't want log output when running suite.
-    'worker_hijack_root_logger': False,
-    'worker_log_color': False,
-    'task_default_queue': 'testcelery',
-    'task_default_exchange': 'testcelery',
-    'task_default_routing_key': 'testcelery',
-    'task_queues': (
-        Queue('testcelery', routing_key='testcelery'),
-    ),
-    'accept_content': ('json', 'pickle'),
-    'enable_utc': True,
-    'timezone': 'UTC',
-
-    # Mongo results tests (only executed if installed and running)
-    'mongodb_backend_settings': {
-        'host': os.environ.get('MONGO_HOST') or 'localhost',
-        'port': os.environ.get('MONGO_PORT') or 27017,
-        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
-        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
-                                'taskmeta_collection'),
-        'user': os.environ.get('MONGO_USER'),
-        'password': os.environ.get('MONGO_PASSWORD'),
-    }
-}
-
-
-class Case(_Case):
-    DeprecationWarning = CDeprecationWarning
-    PendingDeprecationWarning = CPendingDeprecationWarning
-
-
-class Trap(object):
-
-    def __getattr__(self, name):
-        raise RuntimeError('Test depends on current_app')
-
-
-class UnitLogging(symbol_by_name(Celery.log_cls)):
-
-    def __init__(self, *args, **kwargs):
-        super(UnitLogging, self).__init__(*args, **kwargs)
-        self.already_setup = True
-
-
-def UnitApp(name=None, set_as_current=False, log=UnitLogging,
-            broker='memory://', backend='cache+memory://', **kwargs):
-    app = Celery(name or 'celery.tests',
-                 set_as_current=set_as_current,
-                 log=log, broker=broker, backend=backend,
-                 **kwargs)
-    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
-    return app
-
-
-def alive_threads():
-    return [thread for thread in threading.enumerate() if thread.is_alive()]
-
-
-def depends_on_current_app(fun):
-    if inspect.isclass(fun):
-        fun.contained = False
-    else:
-        @wraps(fun)
-        def __inner(self, *args, **kwargs):
-            self.app.set_current()
-            return fun(self, *args, **kwargs)
-        return __inner
-
-
-class AppCase(Case):
-    contained = True
-    _threads_at_startup = [None]
-
-    def __init__(self, *args, **kwargs):
-        super(AppCase, self).__init__(*args, **kwargs)
-        setUp = self.__class__.__dict__.get('setUp')
-        tearDown = self.__class__.__dict__.get('tearDown')
-        if setUp and not hasattr(setUp, '__wrapped__'):
-            raise RuntimeError(
-                CASE_REDEFINES_SETUP.format(name=qualname(self)),
-            )
-        if tearDown and not hasattr(tearDown, '__wrapped__'):
-            raise RuntimeError(
-                CASE_REDEFINES_TEARDOWN.format(name=qualname(self)),
-            )
-
-    def Celery(self, *args, **kwargs):
-        return UnitApp(*args, **kwargs)
-
-    def threads_at_startup(self):
-        if self._threads_at_startup[0] is None:
-            self._threads_at_startup[0] = alive_threads()
-        return self._threads_at_startup[0]
-
-    def setUp(self):
-        self._threads_at_setup = self.threads_at_startup()
-        from celery import _state
-        from celery import result
-        self._prev_res_join_block = result.task_join_will_block
-        self._prev_state_join_block = _state.task_join_will_block
-        result.task_join_will_block = \
-            _state.task_join_will_block = lambda: False
-        self._current_app = current_app()
-        self._default_app = _state.default_app
-        trap = Trap()
-        self._prev_tls = _state._tls
-        _state.set_default_app(trap)
-
-        class NonTLS(object):
-            current_app = trap
-        _state._tls = NonTLS()
-
-        self.app = self.Celery(set_as_current=False)
-        if not self.contained:
-            self.app.set_current()
-        root = logging.getLogger()
-        self.__rootlevel = root.level
-        self.__roothandlers = root.handlers
-        _state._set_task_join_will_block(False)
-        try:
-            self.setup()
-        except:
-            self._teardown_app()
-            raise
-
-    def _teardown_app(self):
-        from celery import _state
-        from celery import result
-        from celery.utils.log import LoggingProxy
-        assert sys.stdout
-        assert sys.stderr
-        assert sys.__stdout__
-        assert sys.__stderr__
-        this = self._get_test_name()
-        result.task_join_will_block = self._prev_res_join_block
-        _state.task_join_will_block = self._prev_state_join_block
-        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
-                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
-            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
-        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
-                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
-            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
-        backend = self.app.__dict__.get('backend')
-        if backend is not None:
-            if isinstance(backend, CacheBackend):
-                if isinstance(backend.client, DummyClient):
-                    backend.client.cache.clear()
-                backend._cache.clear()
-        from celery import _state
-        _state._set_task_join_will_block(False)
-
-        _state.set_default_app(self._default_app)
-        _state._tls = self._prev_tls
-        _state._tls.current_app = self._current_app
-        if self.app is not self._current_app:
-            self.app.close()
-        self.app = None
-        self.assertEqual(self._threads_at_setup, alive_threads())
-
-        # Make sure no test left the shutdown flags enabled.
-        from celery.worker import state as worker_state
-        # check for EX_OK
-        self.assertIsNot(worker_state.should_stop, False)
-        self.assertIsNot(worker_state.should_terminate, False)
-        # check for other true values
-        self.assertFalse(worker_state.should_stop)
-        self.assertFalse(worker_state.should_terminate)
-
-    def _get_test_name(self):
-        return '.'.join([self.__class__.__name__, self._testMethodName])
-
-    def tearDown(self):
-        try:
-            self.teardown()
-        finally:
-            self._teardown_app()
-        self.assert_no_logging_side_effect()
-
-    def assert_no_logging_side_effect(self):
-        this = self._get_test_name()
-        root = logging.getLogger()
-        if root.level != self.__rootlevel:
-            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
-        if root.handlers != self.__roothandlers:
-            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
-
-    def assert_signal_called(self, signal, **expected):
-        return assert_signal_called(signal, **expected)
-
-    def setup(self):
-        pass
-
-    def teardown(self):
-        pass
-
-
-@decorator
-def assert_signal_called(signal, **expected):
-    handler = Mock()
-    call_handler = partial(handler)
-    signal.connect(call_handler)
-    try:
-        yield handler
-    finally:
-        signal.disconnect(call_handler)
-    handler.assert_called_with(signal=signal, **expected)
-
-
-def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
-                errbacks=None, chain=None, shadow=None, utc=None, **options):
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {
-        'id': id,
-        'task': name,
-        'shadow': shadow,
-    }
-    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
-    message.headers.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        (args, kwargs, embed), serializer='json',
-    )
-    message.payload = (args, kwargs, embed)
-    return message
-
-
-def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
-                 errbacks=None, chain=None, **options):
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {}
-    message.payload = {
-        'task': name,
-        'id': id,
-        'args': args,
-        'kwargs': kwargs,
-        'callbacks': callbacks,
-        'errbacks': errbacks,
-    }
-    message.payload.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        message.payload,
-    )
-    return message
-
-
-def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
-    sig.freeze()
-    callbacks = sig.options.pop('link', None)
-    errbacks = sig.options.pop('link_error', None)
-    countdown = sig.options.pop('countdown', None)
-    if countdown:
-        eta = app.now() + timedelta(seconds=countdown)
-    else:
-        eta = sig.options.pop('eta', None)
-    if eta and isinstance(eta, datetime):
-        eta = eta.isoformat()
-    expires = sig.options.pop('expires', None)
-    if expires and isinstance(expires, numbers.Real):
-        expires = app.now() + timedelta(seconds=expires)
-    if expires and isinstance(expires, datetime):
-        expires = expires.isoformat()
-    return TaskMessage(
-        sig.task, id=sig.id, args=sig.args,
-        kwargs=sig.kwargs,
-        callbacks=[dict(s) for s in callbacks] if callbacks else None,
-        errbacks=[dict(s) for s in errbacks] if errbacks else None,
-        eta=eta,
-        expires=expires,
-        utc=utc,
-        **sig.options
-    )
-
-
-def _old_patch(module, name, mocked):
-    module = importlib.import_module(module)
-
-    def _patch(fun):
-
-        @wraps(fun)
-        def __patched(*args, **kwargs):
-            prev = getattr(module, name)
-            setattr(module, name, mocked)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                setattr(module, name, prev)
-        return __patched
-    return _patch

+ 0 - 38
celery/tests/compat_modules/test_decorators.py

@@ -1,38 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import warnings
-
-from celery.task import base
-
-from celery.tests.case import AppCase, depends_on_current_app
-
-
-def add(x, y):
-    return x + y
-
-
-@depends_on_current_app
-class test_decorators(AppCase):
-
-    def test_task_alias(self):
-        from celery import task
-        self.assertTrue(task.__file__)
-        self.assertTrue(task(add))
-
-    def setup(self):
-        with warnings.catch_warnings(record=True):
-            from celery import decorators
-            self.decorators = decorators
-
-    def assertCompatDecorator(self, decorator, type, **opts):
-        task = decorator(**opts)(add)
-        self.assertEqual(task(8, 8), 16)
-        self.assertIsInstance(task, type)
-
-    def test_task(self):
-        self.assertCompatDecorator(self.decorators.task, base.BaseTask)
-
-    def test_periodic_task(self):
-        self.assertCompatDecorator(self.decorators.periodic_task,
-                                   base.BaseTask,
-                                   run_every=1)

+ 0 - 129
celery/tests/concurrency/test_gevent.py

@@ -1,129 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.concurrency.gevent import (
-    Timer,
-    TaskPool,
-    apply_timeout,
-)
-
-from celery.tests.case import AppCase, Mock, patch, skip
-
-gevent_modules = (
-    'gevent',
-    'gevent.monkey',
-    'gevent.greenlet',
-    'gevent.pool',
-    'greenlet',
-)
-
-
-@skip.if_pypy()
-class GeventCase(AppCase):
-
-    def setup(self):
-        self.mock_modules(*gevent_modules)
-
-
-class test_gevent_patch(GeventCase):
-
-    def test_is_patched(self):
-        with patch('gevent.monkey.patch_all', create=True) as patch_all:
-            import gevent
-            gevent.version_info = (1, 0, 0)
-            from celery import maybe_patch_concurrency
-            maybe_patch_concurrency(['x', '-P', 'gevent'])
-            patch_all.assert_called()
-
-
-class test_Timer(GeventCase):
-
-    def setup(self):
-        GeventCase.setup(self)
-        self.greenlet = self.patch('gevent.greenlet')
-        self.GreenletExit = self.patch('gevent.greenlet.GreenletExit')
-
-    def test_sched(self):
-        self.greenlet.Greenlet = object
-        x = Timer()
-        self.greenlet.Greenlet = Mock()
-        x._Greenlet.spawn_later = Mock()
-        x._GreenletExit = KeyError
-        entry = Mock()
-        g = x._enter(1, 0, entry)
-        self.assertTrue(x.queue)
-
-        x._entry_exit(g)
-        g.kill.assert_called_with()
-        self.assertFalse(x._queue)
-
-        x._queue.add(g)
-        x.clear()
-        x._queue.add(g)
-        g.kill.side_effect = KeyError()
-        x.clear()
-
-        g = x._Greenlet()
-        g.cancel()
-
-
-class test_TaskPool(GeventCase):
-
-    def setup(self):
-        GeventCase.setup(self)
-        self.spawn_raw = self.patch('gevent.spawn_raw')
-        self.Pool = self.patch('gevent.pool.Pool')
-
-    def test_pool(self):
-        x = TaskPool()
-        x.on_start()
-        x.on_stop()
-        x.on_apply(Mock())
-        x._pool = None
-        x.on_stop()
-
-        x._pool = Mock()
-        x._pool._semaphore.counter = 1
-        x._pool.size = 1
-        x.grow()
-        self.assertEqual(x._pool.size, 2)
-        self.assertEqual(x._pool._semaphore.counter, 2)
-        x.shrink()
-        self.assertEqual(x._pool.size, 1)
-        self.assertEqual(x._pool._semaphore.counter, 1)
-
-        x._pool = [4, 5, 6]
-        self.assertEqual(x.num_processes, 3)
-
-
-class test_apply_timeout(AppCase):
-
-    def test_apply_timeout(self):
-
-            class Timeout(Exception):
-                value = None
-
-                def __init__(self, value):
-                    self.__class__.value = value
-
-                def __enter__(self):
-                    return self
-
-                def __exit__(self, *exc_info):
-                    pass
-            timeout_callback = Mock(name='timeout_callback')
-            apply_target = Mock(name='apply_target')
-            apply_timeout(
-                Mock(), timeout=10, callback=Mock(name='callback'),
-                timeout_callback=timeout_callback,
-                apply_target=apply_target, Timeout=Timeout,
-            )
-            self.assertEqual(Timeout.value, 10)
-            apply_target.assert_called()
-
-            apply_target.side_effect = Timeout(10)
-            apply_timeout(
-                Mock(), timeout=10, callback=Mock(),
-                timeout_callback=timeout_callback,
-                apply_target=apply_target, Timeout=Timeout,
-            )
-            timeout_callback.assert_called_with(False, 10)

+ 0 - 54
celery/tests/tasks/test_states.py

@@ -1,54 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery import states
-
-from celery.tests.case import Case
-
-
-class test_state_precedence(Case):
-
-    def test_gt(self):
-        self.assertGreater(
-            states.state(states.SUCCESS), states.state(states.PENDING),
-        )
-        self.assertGreater(
-            states.state(states.FAILURE), states.state(states.RECEIVED),
-        )
-        self.assertGreater(
-            states.state(states.REVOKED), states.state(states.STARTED),
-        )
-        self.assertGreater(
-            states.state(states.SUCCESS), states.state('CRASHED'),
-        )
-        self.assertGreater(
-            states.state(states.FAILURE), states.state('CRASHED'),
-        )
-        self.assertLessEqual(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-
-    def test_lt(self):
-        self.assertLess(
-            states.state(states.PENDING), states.state(states.SUCCESS),
-        )
-        self.assertLess(
-            states.state(states.RECEIVED), states.state(states.FAILURE),
-        )
-        self.assertLess(
-            states.state(states.STARTED), states.state(states.REVOKED),
-        )
-        self.assertLess(
-            states.state('CRASHED'), states.state(states.SUCCESS),
-        )
-        self.assertLess(
-            states.state('CRASHED'), states.state(states.FAILURE),
-        )
-        self.assertLess(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-        self.assertLessEqual(
-            states.state(states.REVOKED), states.state('CRASHED'),
-        )
-        self.assertGreaterEqual(
-            states.state('CRASHED'), states.state(states.REVOKED),
-        )

+ 0 - 98
celery/tests/utils/test_debug.py

@@ -1,98 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import debug
-
-from celery.tests.case import Case, Mock, patch
-
-
-class test_on_blocking(Case):
-
-    @patch('inspect.getframeinfo')
-    def test_on_blocking(self, getframeinfo):
-        frame = Mock(name='frame')
-        with self.assertRaises(RuntimeError):
-            debug._on_blocking(1, frame)
-            getframeinfo.assert_called_with(frame)
-
-
-class test_blockdetection(Case):
-
-    @patch('celery.utils.debug.signals')
-    def test_context(self, signals):
-        with debug.blockdetection(10):
-            signals.arm_alarm.assert_called_with(10)
-            signals.__setitem__.assert_called_with('ALRM', debug._on_blocking)
-        signals.__setitem__.assert_called_with('ALRM', signals['ALRM'])
-        signals.reset_alarm.assert_called_with()
-
-
-class test_sample_mem(Case):
-
-    @patch('celery.utils.debug.mem_rss')
-    def test_sample_mem(self, mem_rss):
-        prev, debug._mem_sample = debug._mem_sample, []
-        try:
-            debug.sample_mem()
-            self.assertIs(debug._mem_sample[0], mem_rss())
-        finally:
-            debug._mem_sample = prev
-
-
-class test_sample(Case):
-
-    def test_sample(self):
-        x = list(range(100))
-        self.assertEqual(
-            list(debug.sample(x, 10)),
-            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90],
-        )
-        x = list(range(91))
-        self.assertEqual(
-            list(debug.sample(x, 10)),
-            [0, 9, 18, 27, 36, 45, 54, 63, 72, 81],
-        )
-
-
-class test_hfloat(Case):
-
-    def test_hfloat(self):
-        self.assertEqual(str(debug.hfloat(10, 5)), '10')
-        self.assertEqual(str(debug.hfloat(10.45645234234, 5)), '10.456')
-
-
-class test_humanbytes(Case):
-
-    def test_humanbytes(self):
-        self.assertEqual(debug.humanbytes(2 ** 20), '1MB')
-        self.assertEqual(debug.humanbytes(4 * 2 ** 20), '4MB')
-        self.assertEqual(debug.humanbytes(2 ** 16), '64KB')
-        self.assertEqual(debug.humanbytes(2 ** 16), '64KB')
-        self.assertEqual(debug.humanbytes(2 ** 8), '256b')
-
-
-class test_mem_rss(Case):
-
-    @patch('celery.utils.debug.ps')
-    @patch('celery.utils.debug.humanbytes')
-    def test_mem_rss(self, humanbytes, ps):
-        ret = debug.mem_rss()
-        ps.assert_called_with()
-        ps().memory_info.assert_called_with()
-        humanbytes.assert_called_with(ps().memory_info().rss)
-        self.assertIs(ret, humanbytes())
-        ps.return_value = None
-        self.assertIsNone(debug.mem_rss())
-
-
-class test_ps(Case):
-
-    @patch('celery.utils.debug.Process')
-    @patch('os.getpid')
-    def test_ps(self, getpid, Process):
-        prev, debug._process = debug._process, None
-        try:
-            debug.ps()
-            Process.assert_called_with(getpid())
-            self.assertIs(debug._process, Process())
-        finally:
-            debug._process = prev

+ 0 - 20
celery/tests/utils/test_encoding.py

@@ -1,20 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import encoding
-from celery.tests.case import Case
-
-
-class test_encoding(Case):
-
-    def test_safe_str(self):
-        self.assertTrue(encoding.safe_str(object()))
-        self.assertTrue(encoding.safe_str('foo'))
-
-    def test_safe_repr(self):
-        self.assertTrue(encoding.safe_repr(object()))
-
-        class foo(object):
-            def __repr__(self):
-                raise ValueError('foo')
-
-        self.assertTrue(encoding.safe_repr(foo()))

+ 0 - 210
celery/tests/utils/test_functional.py

@@ -1,210 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu.utils.functional import lazy
-
-from celery.five import range, nextfun
-from celery.utils.functional import (
-    DummyContext,
-    fun_takes_argument,
-    head_from_fun,
-    firstmethod,
-    first,
-    maybe_list,
-    mlazy,
-    padlist,
-    regen,
-)
-
-from celery.tests.case import Case
-
-
-class test_DummyContext(Case):
-
-    def test_context(self):
-        with DummyContext():
-            pass
-        with self.assertRaises(KeyError):
-            with DummyContext():
-                raise KeyError()
-
-
-class test_utils(Case):
-
-    def test_padlist(self):
-        self.assertListEqual(
-            padlist(['George', 'Costanza', 'NYC'], 3),
-            ['George', 'Costanza', 'NYC'],
-        )
-        self.assertListEqual(
-            padlist(['George', 'Costanza'], 3),
-            ['George', 'Costanza', None],
-        )
-        self.assertListEqual(
-            padlist(['George', 'Costanza', 'NYC'], 4, default='Earth'),
-            ['George', 'Costanza', 'NYC', 'Earth'],
-        )
-
-    def test_firstmethod_AttributeError(self):
-        self.assertIsNone(firstmethod('foo')([object()]))
-
-    def test_firstmethod_handles_lazy(self):
-
-        class A(object):
-
-            def __init__(self, value=None):
-                self.value = value
-
-            def m(self):
-                return self.value
-
-        self.assertEqual('four', firstmethod('m')([
-            A(), A(), A(), A('four'), A('five')]))
-        self.assertEqual('four', firstmethod('m')([
-            A(), A(), A(), lazy(lambda: A('four')), A('five')]))
-
-    def test_first(self):
-        iterations = [0]
-
-        def predicate(value):
-            iterations[0] += 1
-            if value == 5:
-                return True
-            return False
-
-        self.assertEqual(5, first(predicate, range(10)))
-        self.assertEqual(iterations[0], 6)
-
-        iterations[0] = 0
-        self.assertIsNone(first(predicate, range(10, 20)))
-        self.assertEqual(iterations[0], 10)
-
-    def test_maybe_list(self):
-        self.assertEqual(maybe_list(1), [1])
-        self.assertEqual(maybe_list([1]), [1])
-        self.assertIsNone(maybe_list(None))
-
-
-class test_mlazy(Case):
-
-    def test_is_memoized(self):
-
-        it = iter(range(20, 30))
-        p = mlazy(nextfun(it))
-        self.assertEqual(p(), 20)
-        self.assertTrue(p.evaluated)
-        self.assertEqual(p(), 20)
-        self.assertEqual(repr(p), '20')
-
-
-class test_regen(Case):
-
-    def test_regen_list(self):
-        l = [1, 2]
-        r = regen(iter(l))
-        self.assertIs(regen(l), l)
-        self.assertEqual(r, l)
-        self.assertEqual(r, l)
-        self.assertEqual(r.__length_hint__(), 0)
-
-        fun, args = r.__reduce__()
-        self.assertEqual(fun(*args), l)
-
-    def test_regen_gen(self):
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[7], 7)
-        self.assertEqual(g[6], 6)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g[4], 4)
-        self.assertEqual(g[3], 3)
-        self.assertEqual(g[2], 2)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g.data, list(range(10)))
-        self.assertEqual(g[8], 8)
-        self.assertEqual(g[0], 0)
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[0], 0)
-        self.assertEqual(g[1], 1)
-        self.assertEqual(g.data, list(range(10)))
-        g = regen(iter([1]))
-        self.assertEqual(g[0], 1)
-        with self.assertRaises(IndexError):
-            g[1]
-        self.assertEqual(g.data, [1])
-
-        g = regen(iter(list(range(10))))
-        self.assertEqual(g[-1], 9)
-        self.assertEqual(g[-2], 8)
-        self.assertEqual(g[-3], 7)
-        self.assertEqual(g[-4], 6)
-        self.assertEqual(g[-5], 5)
-        self.assertEqual(g[5], 5)
-        self.assertEqual(g.data, list(range(10)))
-
-        self.assertListEqual(list(iter(g)), list(range(10)))
-
-
-class test_head_from_fun(Case):
-
-    def test_from_cls(self):
-        class X(object):
-            def __call__(x, y, kwarg=1):
-                pass
-
-        g = head_from_fun(X())
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-    def test_from_fun(self):
-        def f(x, y, kwarg=1):
-            pass
-        g = head_from_fun(f)
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-    def test_from_fun_with_hints(self):
-        local = {}
-        fun = ('def f_hints(x: int, y: int, kwarg: int=1):'
-               '    pass')
-        try:
-            exec(fun, {}, local)
-        except SyntaxError:
-            # py2
-            return
-        f_hints = local['f_hints']
-
-        g = head_from_fun(f_hints)
-        with self.assertRaises(TypeError):
-            g(1)
-        g(1, 2)
-        g(1, 2, kwarg=3)
-
-
-class test_fun_takes_argument(Case):
-
-    def test_starkwargs(self):
-        self.assertTrue(fun_takes_argument('foo', lambda **kw: 1))
-
-    def test_named(self):
-        self.assertTrue(fun_takes_argument('foo', lambda a, foo, bar: 1))
-
-        def fun(a, b, c, d):
-            return 1
-
-        self.assertTrue(fun_takes_argument('foo', fun, position=4))
-
-    def test_starargs(self):
-        self.assertTrue(fun_takes_argument('foo', lambda a, *args: 1))
-
-    def test_does_not(self):
-        self.assertFalse(fun_takes_argument('foo', lambda a, bar, baz: 1))
-        self.assertFalse(fun_takes_argument('foo', lambda: 1))
-
-        def fun(a, b, foo):
-            return 1
-
-        self.assertFalse(fun_takes_argument('foo', fun, position=4))

+ 0 - 58
celery/tests/utils/test_imports.py

@@ -1,58 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.five import bytes_if_py2
-
-from celery.utils.imports import (
-    NotAPackage,
-    qualname,
-    gen_task_name,
-    reload_from_cwd,
-    module_file,
-    find_module,
-)
-
-from celery.tests.case import Case, Mock, patch
-
-
-class test_import_utils(Case):
-
-    def test_find_module(self):
-        self.assertTrue(find_module('celery'))
-        imp = Mock()
-        imp.return_value = None
-        with self.assertRaises(NotAPackage):
-            find_module('foo.bar.baz', imp=imp)
-        self.assertTrue(find_module('celery.worker.request'))
-
-    def test_qualname(self):
-        Class = type(bytes_if_py2('Fox'), (object,), {
-            '__module__': 'quick.brown',
-        })
-        self.assertEqual(qualname(Class), 'quick.brown.Fox')
-        self.assertEqual(qualname(Class()), 'quick.brown.Fox')
-
-    @patch('celery.utils.imports.reload')
-    def test_reload_from_cwd(self, reload):
-        reload_from_cwd('foo')
-        reload.assert_called()
-
-    def test_reload_from_cwd_custom_reloader(self):
-        reload = Mock()
-        reload_from_cwd('foo', reload)
-        reload.assert_called()
-
-    def test_module_file(self):
-        m1 = Mock()
-        m1.__file__ = '/opt/foo/xyz.pyc'
-        self.assertEqual(module_file(m1), '/opt/foo/xyz.py')
-        m2 = Mock()
-        m2.__file__ = '/opt/foo/xyz.py'
-        self.assertEqual(module_file(m1), '/opt/foo/xyz.py')
-
-
-class test_gen_task_name(Case):
-
-    def test_no_module(self):
-        app = Mock()
-        app.name == '__main__'
-        self.assertTrue(gen_task_name(app, 'foo', 'axsadaewe'))

+ 0 - 372
celery/tests/utils/test_local.py

@@ -1,372 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from celery.five import python_2_unicode_compatible, string, long_t
-from celery.local import (
-    Proxy,
-    PromiseProxy,
-    maybe_evaluate,
-    try_import,
-)
-from celery.tests.case import Case, Mock, skip
-
-PY3 = sys.version_info[0] == 3
-
-
-class test_try_import(Case):
-
-    def test_imports(self):
-        self.assertTrue(try_import(__name__))
-
-    def test_when_default(self):
-        default = object()
-        self.assertIs(try_import('foobar.awqewqe.asdwqewq', default), default)
-
-
-class test_Proxy(Case):
-
-    def test_std_class_attributes(self):
-        self.assertEqual(Proxy.__name__, 'Proxy')
-        self.assertEqual(Proxy.__module__, 'celery.local')
-        self.assertIsInstance(Proxy.__doc__, str)
-
-    def test_doc(self):
-        def real():
-            pass
-        x = Proxy(real, __doc__='foo')
-        self.assertEqual(x.__doc__, 'foo')
-
-    def test_name(self):
-
-        def real():
-            """real function"""
-            return 'REAL'
-
-        x = Proxy(lambda: real, name='xyz')
-        self.assertEqual(x.__name__, 'xyz')
-
-        y = Proxy(lambda: real)
-        self.assertEqual(y.__name__, 'real')
-
-        self.assertEqual(x.__doc__, 'real function')
-
-        self.assertEqual(x.__class__, type(real))
-        self.assertEqual(x.__dict__, real.__dict__)
-        self.assertEqual(repr(x), repr(real))
-        self.assertTrue(x.__module__)
-
-    def test_get_current_local(self):
-        x = Proxy(lambda: 10)
-        object.__setattr__(x, '_Proxy_local', Mock())
-        self.assertTrue(x._get_current_object())
-
-    def test_bool(self):
-
-        class X(object):
-
-            def __bool__(self):
-                return False
-            __nonzero__ = __bool__
-
-        x = Proxy(lambda: X())
-        self.assertFalse(x)
-
-    def test_slots(self):
-
-        class X(object):
-            __slots__ = ()
-
-        x = Proxy(X)
-        with self.assertRaises(AttributeError):
-            x.__dict__
-
-    @skip.if_python3()
-    def test_unicode(self):
-
-        @python_2_unicode_compatible
-        class X(object):
-
-            def __unicode__(self):
-                return 'UNICODE'
-            __str__ = __unicode__
-
-            def __repr__(self):
-                return 'REPR'
-
-        x = Proxy(lambda: X())
-        self.assertEqual(string(x), 'UNICODE')
-        del(X.__unicode__)
-        del(X.__str__)
-        self.assertEqual(string(x), 'REPR')
-
-    def test_dir(self):
-
-        class X(object):
-
-            def __dir__(self):
-                return ['a', 'b', 'c']
-
-        x = Proxy(lambda: X())
-        self.assertListEqual(dir(x), ['a', 'b', 'c'])
-
-        class Y(object):
-
-            def __dir__(self):
-                raise RuntimeError()
-        y = Proxy(lambda: Y())
-        self.assertListEqual(dir(y), [])
-
-    def test_getsetdel_attr(self):
-
-        class X(object):
-            a = 1
-            b = 2
-            c = 3
-
-            def __dir__(self):
-                return ['a', 'b', 'c']
-
-        v = X()
-
-        x = Proxy(lambda: v)
-        self.assertListEqual(x.__members__, ['a', 'b', 'c'])
-        self.assertEqual(x.a, 1)
-        self.assertEqual(x.b, 2)
-        self.assertEqual(x.c, 3)
-
-        setattr(x, 'a', 10)
-        self.assertEqual(x.a, 10)
-
-        del(x.a)
-        self.assertEqual(x.a, 1)
-
-    def test_dictproxy(self):
-        v = {}
-        x = Proxy(lambda: v)
-        x['foo'] = 42
-        self.assertEqual(x['foo'], 42)
-        self.assertEqual(len(x), 1)
-        self.assertIn('foo', x)
-        del(x['foo'])
-        with self.assertRaises(KeyError):
-            x['foo']
-        self.assertTrue(iter(x))
-
-    def test_listproxy(self):
-        v = []
-        x = Proxy(lambda: v)
-        x.append(1)
-        x.extend([2, 3, 4])
-        self.assertEqual(x[0], 1)
-        self.assertEqual(x[:-1], [1, 2, 3])
-        del(x[-1])
-        self.assertEqual(x[:-1], [1, 2])
-        x[0] = 10
-        self.assertEqual(x[0], 10)
-        self.assertIn(10, x)
-        self.assertEqual(len(x), 3)
-        self.assertTrue(iter(x))
-        x[0:2] = [1, 2]
-        del(x[0:2])
-        self.assertTrue(str(x))
-        if sys.version_info[0] < 3:
-            self.assertEqual(x.__cmp__(object()), -1)
-
-    def test_complex_cast(self):
-
-        class O(object):
-
-            def __complex__(self):
-                return complex(10.333)
-
-        o = Proxy(O)
-        self.assertEqual(o.__complex__(), complex(10.333))
-
-    def test_index(self):
-
-        class O(object):
-
-            def __index__(self):
-                return 1
-
-        o = Proxy(O)
-        self.assertEqual(o.__index__(), 1)
-
-    def test_coerce(self):
-
-        class O(object):
-
-            def __coerce__(self, other):
-                return self, other
-
-        o = Proxy(O)
-        self.assertTrue(o.__coerce__(3))
-
-    def test_int(self):
-        self.assertEqual(Proxy(lambda: 10) + 1, Proxy(lambda: 11))
-        self.assertEqual(Proxy(lambda: 10) - 1, Proxy(lambda: 9))
-        self.assertEqual(Proxy(lambda: 10) * 2, Proxy(lambda: 20))
-        self.assertEqual(Proxy(lambda: 10) ** 2, Proxy(lambda: 100))
-        self.assertEqual(Proxy(lambda: 20) / 2, Proxy(lambda: 10))
-        self.assertEqual(Proxy(lambda: 20) // 2, Proxy(lambda: 10))
-        self.assertEqual(Proxy(lambda: 11) % 2, Proxy(lambda: 1))
-        self.assertEqual(Proxy(lambda: 10) << 2, Proxy(lambda: 40))
-        self.assertEqual(Proxy(lambda: 10) >> 2, Proxy(lambda: 2))
-        self.assertEqual(Proxy(lambda: 10) ^ 7, Proxy(lambda: 13))
-        self.assertEqual(Proxy(lambda: 10) | 40, Proxy(lambda: 42))
-        self.assertEqual(~Proxy(lambda: 10), Proxy(lambda: -11))
-        self.assertEqual(-Proxy(lambda: 10), Proxy(lambda: -10))
-        self.assertEqual(+Proxy(lambda: -10), Proxy(lambda: -10))
-        self.assertTrue(Proxy(lambda: 10) < Proxy(lambda: 20))
-        self.assertTrue(Proxy(lambda: 20) > Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) >= Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) <= Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 10) == Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 20) != Proxy(lambda: 10))
-        self.assertTrue(Proxy(lambda: 100).__divmod__(30))
-        self.assertTrue(Proxy(lambda: 100).__truediv__(30))
-        self.assertTrue(abs(Proxy(lambda: -100)))
-
-        x = Proxy(lambda: 10)
-        x -= 1
-        self.assertEqual(x, 9)
-        x = Proxy(lambda: 9)
-        x += 1
-        self.assertEqual(x, 10)
-        x = Proxy(lambda: 10)
-        x *= 2
-        self.assertEqual(x, 20)
-        x = Proxy(lambda: 20)
-        x /= 2
-        self.assertEqual(x, 10)
-        x = Proxy(lambda: 10)
-        x %= 2
-        self.assertEqual(x, 0)
-        x = Proxy(lambda: 10)
-        x <<= 3
-        self.assertEqual(x, 80)
-        x = Proxy(lambda: 80)
-        x >>= 4
-        self.assertEqual(x, 5)
-        x = Proxy(lambda: 5)
-        x ^= 1
-        self.assertEqual(x, 4)
-        x = Proxy(lambda: 4)
-        x **= 4
-        self.assertEqual(x, 256)
-        x = Proxy(lambda: 256)
-        x //= 2
-        self.assertEqual(x, 128)
-        x = Proxy(lambda: 128)
-        x |= 2
-        self.assertEqual(x, 130)
-        x = Proxy(lambda: 130)
-        x &= 10
-        self.assertEqual(x, 2)
-
-        x = Proxy(lambda: 10)
-        self.assertEqual(type(x.__float__()), float)
-        self.assertEqual(type(x.__int__()), int)
-        if not PY3:
-            self.assertEqual(type(x.__long__()), long_t)
-        self.assertTrue(hex(x))
-        self.assertTrue(oct(x))
-
-    def test_hash(self):
-
-        class X(object):
-
-            def __hash__(self):
-                return 1234
-
-        self.assertEqual(hash(Proxy(lambda: X())), 1234)
-
-    def test_call(self):
-
-        class X(object):
-
-            def __call__(self):
-                return 1234
-
-        self.assertEqual(Proxy(lambda: X())(), 1234)
-
-    def test_context(self):
-
-        class X(object):
-            entered = exited = False
-
-            def __enter__(self):
-                self.entered = True
-                return 1234
-
-            def __exit__(self, *exc_info):
-                self.exited = True
-
-        v = X()
-        x = Proxy(lambda: v)
-        with x as val:
-            self.assertEqual(val, 1234)
-        self.assertTrue(x.entered)
-        self.assertTrue(x.exited)
-
-    def test_reduce(self):
-
-        class X(object):
-
-            def __reduce__(self):
-                return 123
-
-        x = Proxy(lambda: X())
-        self.assertEqual(x.__reduce__(), 123)
-
-
-class test_PromiseProxy(Case):
-
-    def test_only_evaluated_once(self):
-
-        class X(object):
-            attr = 123
-            evals = 0
-
-            def __init__(self):
-                self.__class__.evals += 1
-
-        p = PromiseProxy(X)
-        self.assertEqual(p.attr, 123)
-        self.assertEqual(p.attr, 123)
-        self.assertEqual(X.evals, 1)
-
-    def test_callbacks(self):
-        source = Mock(name='source')
-        p = PromiseProxy(source)
-        cbA = Mock(name='cbA')
-        cbB = Mock(name='cbB')
-        cbC = Mock(name='cbC')
-        p.__then__(cbA, p)
-        p.__then__(cbB, p)
-        self.assertFalse(p.__evaluated__())
-        self.assertTrue(object.__getattribute__(p, '__pending__'))
-
-        self.assertTrue(repr(p))
-        self.assertTrue(p.__evaluated__())
-        with self.assertRaises(AttributeError):
-            object.__getattribute__(p, '__pending__')
-        cbA.assert_called_with(p)
-        cbB.assert_called_with(p)
-
-        self.assertTrue(p.__evaluated__())
-        p.__then__(cbC, p)
-        cbC.assert_called_with(p)
-
-        with self.assertRaises(AttributeError):
-            object.__getattribute__(p, '__pending__')
-
-    def test_maybe_evaluate(self):
-        x = PromiseProxy(lambda: 30)
-        self.assertFalse(x.__evaluated__())
-        self.assertEqual(maybe_evaluate(x), 30)
-        self.assertEqual(maybe_evaluate(x), 30)
-
-        self.assertEqual(maybe_evaluate(30), 30)
-        self.assertTrue(x.__evaluated__())

+ 0 - 77
celery/tests/utils/test_serialization.py

@@ -1,77 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pytz
-import sys
-
-from datetime import datetime, date, time, timedelta
-
-from kombu import Queue
-
-from celery.utils.serialization import (
-    UnpickleableExceptionWrapper,
-    get_pickleable_etype,
-    jsonify,
-)
-
-from celery.tests.case import Case, Mock, mock
-
-
-class test_AAPickle(Case):
-
-    def test_no_cpickle(self):
-        prev = sys.modules.pop('celery.utils.serialization', None)
-        try:
-            with mock.mask_modules('cPickle'):
-                from celery.utils.serialization import pickle
-                import pickle as orig_pickle
-                self.assertIs(pickle.dumps, orig_pickle.dumps)
-        finally:
-            sys.modules['celery.utils.serialization'] = prev
-
-
-class test_UnpickleExceptionWrapper(Case):
-
-    def test_init(self):
-        x = UnpickleableExceptionWrapper('foo', 'Bar', [10, lambda x: x])
-        self.assertTrue(x.exc_args)
-        self.assertEqual(len(x.exc_args), 2)
-
-
-class test_get_pickleable_etype(Case):
-
-    def test_get_pickleable_etype(self):
-
-        class Unpickleable(Exception):
-            def __reduce__(self):
-                raise ValueError('foo')
-
-        self.assertIs(get_pickleable_etype(Unpickleable), Exception)
-
-
-class test_jsonify(Case):
-
-    def test_simple(self):
-        self.assertTrue(jsonify(Queue('foo')))
-        self.assertTrue(jsonify(['foo', 'bar', 'baz']))
-        self.assertTrue(jsonify({'foo': 'bar'}))
-        self.assertTrue(jsonify(datetime.utcnow()))
-        self.assertTrue(jsonify(datetime.utcnow().replace(tzinfo=pytz.utc)))
-        self.assertTrue(jsonify(datetime.utcnow().replace(microsecond=0)))
-        self.assertTrue(jsonify(date(2012, 1, 1)))
-        self.assertTrue(jsonify(time(hour=1, minute=30)))
-        self.assertTrue(jsonify(time(hour=1, minute=30, microsecond=3)))
-        self.assertTrue(jsonify(timedelta(seconds=30)))
-        self.assertTrue(jsonify(10))
-        self.assertTrue(jsonify(10.3))
-        self.assertTrue(jsonify('hello'))
-
-        unknown_type_filter = Mock()
-        obj = object()
-        self.assertIs(
-            jsonify(obj, unknown_type_filter=unknown_type_filter),
-            unknown_type_filter.return_value,
-        )
-        unknown_type_filter.assert_called_with(obj)
-
-        with self.assertRaises(ValueError):
-            jsonify(obj)

+ 0 - 27
celery/tests/utils/test_sysinfo.py

@@ -1,27 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils.sysinfo import load_average, df
-
-from celery.tests.case import Case, patch, skip
-
-
-@skip.unless_symbol('os.getloadavg')
-class test_load_average(Case):
-
-    def test_avg(self):
-        with patch('os.getloadavg') as getloadavg:
-            getloadavg.return_value = 0.54736328125, 0.6357421875, 0.69921875
-            l = load_average()
-            self.assertTrue(l)
-            self.assertEqual(l, (0.55, 0.64, 0.7))
-
-
-@skip.unless_symbol('posix.statvfs_result')
-class test_df(Case):
-
-    def test_df(self):
-        x = df('/')
-        self.assertTrue(x.total_blocks)
-        self.assertTrue(x.available)
-        self.assertTrue(x.capacity)
-        self.assertTrue(x.stat)

+ 0 - 87
celery/tests/utils/test_term.py

@@ -1,87 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from celery.utils import term
-from celery.utils.term import colored, fg
-from celery.five import text_t
-
-from celery.tests.case import Case, skip
-
-
-@skip.if_win32()
-class test_colored(Case):
-
-    def setUp(self):
-        self._prev_encoding = sys.getdefaultencoding
-
-        def getdefaultencoding():
-            return 'utf-8'
-
-        sys.getdefaultencoding = getdefaultencoding
-
-    def tearDown(self):
-        sys.getdefaultencoding = self._prev_encoding
-
-    def test_colors(self):
-        colors = (
-            ('black', term.BLACK),
-            ('red', term.RED),
-            ('green', term.GREEN),
-            ('yellow', term.YELLOW),
-            ('blue', term.BLUE),
-            ('magenta', term.MAGENTA),
-            ('cyan', term.CYAN),
-            ('white', term.WHITE),
-        )
-
-        for name, key in colors:
-            self.assertIn(fg(30 + key), str(colored().names[name]('foo')))
-
-        self.assertTrue(str(colored().bold('f')))
-        self.assertTrue(str(colored().underline('f')))
-        self.assertTrue(str(colored().blink('f')))
-        self.assertTrue(str(colored().reverse('f')))
-        self.assertTrue(str(colored().bright('f')))
-        self.assertTrue(str(colored().ired('f')))
-        self.assertTrue(str(colored().igreen('f')))
-        self.assertTrue(str(colored().iyellow('f')))
-        self.assertTrue(str(colored().iblue('f')))
-        self.assertTrue(str(colored().imagenta('f')))
-        self.assertTrue(str(colored().icyan('f')))
-        self.assertTrue(str(colored().iwhite('f')))
-        self.assertTrue(str(colored().reset('f')))
-
-        self.assertTrue(text_t(colored().green('∂bar')))
-
-        self.assertTrue(
-            colored().red('éefoo') + colored().green('∂bar'))
-
-        self.assertEqual(
-            colored().red('foo').no_color(), 'foo')
-
-        self.assertTrue(
-            repr(colored().blue('åfoo')))
-
-        self.assertIn("''", repr(colored()))
-
-        c = colored()
-        s = c.red('foo', c.blue('bar'), c.green('baz'))
-        self.assertTrue(s.no_color())
-
-        c._fold_no_color(s, 'øfoo')
-        c._fold_no_color('fooå', s)
-
-        c = colored().red('åfoo')
-        self.assertEqual(
-            c._add(c, 'baræ'),
-            '\x1b[1;31m\xe5foo\x1b[0mbar\xe6',
-        )
-
-        c2 = colored().blue('ƒƒz')
-        c3 = c._add(c, c2)
-        self.assertEqual(
-            c3,
-            '\x1b[1;31m\xe5foo\x1b[0m\x1b[1;34m\u0192\u0192z\x1b[0m',
-        )

+ 0 - 94
celery/tests/utils/test_text.py

@@ -1,94 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils.text import (
-    abbr,
-    abbrtask,
-    ensure_newlines,
-    indent,
-    pretty,
-    truncate,
-    truncate_bytes,
-)
-
-from celery.tests.case import AppCase, Case
-
-RANDTEXT = """\
-The quick brown
-fox jumps
-over the
-lazy dog\
-"""
-
-RANDTEXT_RES = """\
-    The quick brown
-    fox jumps
-    over the
-    lazy dog\
-"""
-
-QUEUES = {
-    'queue1': {
-        'exchange': 'exchange1',
-        'exchange_type': 'type1',
-        'routing_key': 'bind1',
-    },
-    'queue2': {
-        'exchange': 'exchange2',
-        'exchange_type': 'type2',
-        'routing_key': 'bind2',
-    },
-}
-
-
-QUEUE_FORMAT1 = '.> queue1           exchange=exchange1(type1) key=bind1'
-QUEUE_FORMAT2 = '.> queue2           exchange=exchange2(type2) key=bind2'
-
-
-class test_Info(AppCase):
-
-    def test_textindent(self):
-        self.assertEqual(indent(RANDTEXT, 4), RANDTEXT_RES)
-
-    def test_format_queues(self):
-        self.app.amqp.queues = self.app.amqp.Queues(QUEUES)
-        self.assertEqual(sorted(self.app.amqp.queues.format().split('\n')),
-                         sorted([QUEUE_FORMAT1, QUEUE_FORMAT2]))
-
-    def test_ensure_newlines(self):
-        self.assertEqual(
-            len(ensure_newlines('foo\nbar\nbaz\n').splitlines()), 3,
-        )
-        self.assertEqual(
-            len(ensure_newlines('foo\nbar').splitlines()), 2,
-        )
-
-
-class test_utils(Case):
-
-    def test_truncate_text(self):
-        self.assertEqual(truncate('ABCDEFGHI', 3), 'ABC...')
-        self.assertEqual(truncate('ABCDEFGHI', 10), 'ABCDEFGHI')
-
-    def test_truncate_bytes(self):
-        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 3), b'ABC...')
-        self.assertEqual(truncate_bytes(b'ABCDEFGHI', 10), b'ABCDEFGHI')
-
-    def test_abbr(self):
-        self.assertEqual(abbr(None, 3), '???')
-        self.assertEqual(abbr('ABCDEFGHI', 6), 'ABC...')
-        self.assertEqual(abbr('ABCDEFGHI', 20), 'ABCDEFGHI')
-        self.assertEqual(abbr('ABCDEFGHI', 6, None), 'ABCDEF')
-
-    def test_abbrtask(self):
-        self.assertEqual(abbrtask(None, 3), '???')
-        self.assertEqual(
-            abbrtask('feeds.tasks.refresh', 10),
-            '[.]refresh',
-        )
-        self.assertEqual(
-            abbrtask('feeds.tasks.refresh', 30),
-            'feeds.tasks.refresh',
-        )
-
-    def test_pretty(self):
-        self.assertTrue(pretty(('a', 'b', 'c')))

+ 0 - 253
celery/tests/utils/test_timeutils.py

@@ -1,253 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pytz
-
-from datetime import datetime, timedelta, tzinfo
-from pytz import AmbiguousTimeError
-
-from celery.utils.time import (
-    delta_resolution,
-    humanize_seconds,
-    maybe_iso8601,
-    maybe_timedelta,
-    timezone,
-    rate,
-    remaining,
-    make_aware,
-    maybe_make_aware,
-    localize,
-    LocalTimezone,
-    ffwd,
-    utcoffset,
-)
-from celery.utils.iso8601 import parse_iso8601
-from celery.tests.case import Case, Mock, patch
-
-
-class test_LocalTimezone(Case):
-
-    def test_daylight(self):
-        with patch('celery.utils.time._time') as time:
-            time.timezone = 3600
-            time.daylight = False
-            x = LocalTimezone()
-            self.assertEqual(x.STDOFFSET, timedelta(seconds=-3600))
-            self.assertEqual(x.DSTOFFSET, x.STDOFFSET)
-            time.daylight = True
-            time.altzone = 3600
-            y = LocalTimezone()
-            self.assertEqual(y.STDOFFSET, timedelta(seconds=-3600))
-            self.assertEqual(y.DSTOFFSET, timedelta(seconds=-3600))
-
-            self.assertTrue(repr(y))
-
-            y._isdst = Mock()
-            y._isdst.return_value = True
-            self.assertTrue(y.utcoffset(datetime.now()))
-            self.assertFalse(y.dst(datetime.now()))
-            y._isdst.return_value = False
-            self.assertTrue(y.utcoffset(datetime.now()))
-            self.assertFalse(y.dst(datetime.now()))
-
-            self.assertTrue(y.tzname(datetime.now()))
-
-
-class test_iso8601(Case):
-
-    def test_parse_with_timezone(self):
-        d = datetime.utcnow().replace(tzinfo=pytz.utc)
-        self.assertEqual(parse_iso8601(d.isoformat()), d)
-        # 2013-06-07T20:12:51.775877+00:00
-        iso = d.isoformat()
-        iso1 = iso.replace('+00:00', '-01:00')
-        d1 = parse_iso8601(iso1)
-        self.assertEqual(d1.tzinfo._minutes, -60)
-        iso2 = iso.replace('+00:00', '+01:00')
-        d2 = parse_iso8601(iso2)
-        self.assertEqual(d2.tzinfo._minutes, +60)
-        iso3 = iso.replace('+00:00', 'Z')
-        d3 = parse_iso8601(iso3)
-        self.assertEqual(d3.tzinfo, pytz.UTC)
-
-
-class test_time_utils(Case):
-
-    def test_delta_resolution(self):
-        D = delta_resolution
-        dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
-        deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
-                    (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
-                    (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
-                    (timedelta(seconds=2), dt))
-        for delta, shoulda in deltamap:
-            self.assertEqual(D(dt, delta), shoulda)
-
-    def test_humanize_seconds(self):
-        t = ((4 * 60 * 60 * 24, '4.00 days'),
-             (1 * 60 * 60 * 24, '1.00 day'),
-             (4 * 60 * 60, '4.00 hours'),
-             (1 * 60 * 60, '1.00 hour'),
-             (4 * 60, '4.00 minutes'),
-             (1 * 60, '1.00 minute'),
-             (4, '4.00 seconds'),
-             (1, '1.00 second'),
-             (4.3567631221, '4.36 seconds'),
-             (0, 'now'))
-
-        for seconds, human in t:
-            self.assertEqual(humanize_seconds(seconds), human)
-
-        self.assertEqual(humanize_seconds(4, prefix='about '),
-                         'about 4.00 seconds')
-
-    def test_maybe_iso8601_datetime(self):
-        now = datetime.now()
-        self.assertIs(maybe_iso8601(now), now)
-
-    def test_maybe_timedelta(self):
-        D = maybe_timedelta
-
-        for i in (30, 30.6):
-            self.assertEqual(D(i), timedelta(seconds=i))
-
-        self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
-
-    def test_remaining_relative(self):
-        remaining(datetime.utcnow(), timedelta(hours=1), relative=True)
-
-
-class test_timezone(Case):
-
-    def test_get_timezone_with_pytz(self):
-        self.assertTrue(timezone.get_timezone('UTC'))
-
-    def test_tz_or_local(self):
-        self.assertEqual(timezone.tz_or_local(), timezone.local)
-        self.assertTrue(timezone.tz_or_local(timezone.utc))
-
-    def test_to_local(self):
-        self.assertTrue(
-            timezone.to_local(make_aware(datetime.utcnow(), timezone.utc)),
-        )
-        self.assertTrue(
-            timezone.to_local(datetime.utcnow())
-        )
-
-    def test_to_local_fallback(self):
-        self.assertTrue(
-            timezone.to_local_fallback(
-                make_aware(datetime.utcnow(), timezone.utc)),
-        )
-        self.assertTrue(
-            timezone.to_local_fallback(datetime.utcnow())
-        )
-
-
-class test_make_aware(Case):
-
-    def test_tz_without_localize(self):
-        tz = tzinfo()
-        self.assertFalse(hasattr(tz, 'localize'))
-        wtz = make_aware(datetime.utcnow(), tz)
-        self.assertEqual(wtz.tzinfo, tz)
-
-    def test_when_has_localize(self):
-
-        class tzz(tzinfo):
-            raises = False
-
-            def localize(self, dt, is_dst=None):
-                self.localized = True
-                if self.raises and is_dst is None:
-                    self.raised = True
-                    raise AmbiguousTimeError()
-                return 1  # needed by min() in Python 3 (None not hashable)
-
-        tz = tzz()
-        make_aware(datetime.utcnow(), tz)
-        self.assertTrue(tz.localized)
-
-        tz2 = tzz()
-        tz2.raises = True
-        make_aware(datetime.utcnow(), tz2)
-        self.assertTrue(tz2.localized)
-        self.assertTrue(tz2.raised)
-
-    def test_maybe_make_aware(self):
-        aware = datetime.utcnow().replace(tzinfo=timezone.utc)
-        self.assertTrue(maybe_make_aware(aware), timezone.utc)
-        naive = datetime.utcnow()
-        self.assertTrue(maybe_make_aware(naive))
-
-
-class test_localize(Case):
-
-    def test_tz_without_normalize(self):
-        tz = tzinfo()
-        self.assertFalse(hasattr(tz, 'normalize'))
-        self.assertTrue(localize(make_aware(datetime.utcnow(), tz), tz))
-
-    def test_when_has_normalize(self):
-
-        class tzz(tzinfo):
-            raises = None
-
-            def normalize(self, dt, **kwargs):
-                self.normalized = True
-                if self.raises and kwargs and kwargs.get('is_dst') is None:
-                    self.raised = True
-                    raise self.raises
-                return 1  # needed by min() in Python 3 (None not hashable)
-
-        tz = tzz()
-        localize(make_aware(datetime.utcnow(), tz), tz)
-        self.assertTrue(tz.normalized)
-
-        tz2 = tzz()
-        tz2.raises = AmbiguousTimeError()
-        localize(make_aware(datetime.utcnow(), tz2), tz2)
-        self.assertTrue(tz2.normalized)
-        self.assertTrue(tz2.raised)
-
-        tz3 = tzz()
-        tz3.raises = TypeError()
-        localize(make_aware(datetime.utcnow(), tz3), tz3)
-        self.assertTrue(tz3.normalized)
-        self.assertTrue(tz3.raised)
-
-
-class test_rate_limit_string(Case):
-
-    def test_conversion(self):
-        self.assertEqual(rate(999), 999)
-        self.assertEqual(rate(7.5), 7.5)
-        self.assertEqual(rate('2.5/s'), 2.5)
-        self.assertEqual(rate('1456/s'), 1456)
-        self.assertEqual(rate('100/m'),
-                         100 / 60.0)
-        self.assertEqual(rate('10/h'),
-                         10 / 60.0 / 60.0)
-
-        for zero in (0, None, '0', '0/m', '0/h', '0/s', '0.0/s'):
-            self.assertEqual(rate(zero), 0)
-
-
-class test_ffwd(Case):
-
-    def test_repr(self):
-        x = ffwd(year=2012)
-        self.assertTrue(repr(x))
-
-    def test_radd_with_unknown_gives_NotImplemented(self):
-        x = ffwd(year=2012)
-        self.assertEqual(x.__radd__(object()), NotImplemented)
-
-
-class test_utcoffset(Case):
-
-    def test_utcoffset(self):
-        with patch('celery.utils.time._time') as _time:
-            _time.daylight = True
-            self.assertIsNotNone(utcoffset(time=_time))
-            _time.daylight = False
-            self.assertIsNotNone(utcoffset(time=_time))

+ 0 - 44
celery/tests/utils/test_utils.py

@@ -1,44 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from celery.utils import chunks, cached_property
-
-from celery.tests.case import Case
-
-
-class test_chunks(Case):
-
-    def test_chunks(self):
-
-        # n == 2
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 2)
-        self.assertListEqual(
-            list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10]],
-        )
-
-        # n == 3
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), 3)
-        self.assertListEqual(
-            list(x),
-            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]],
-        )
-
-        # n == 2 (exact)
-        x = chunks(iter([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 2)
-        self.assertListEqual(
-            list(x),
-            [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
-        )
-
-
-class test_utils(Case):
-
-    def test_cached_property(self):
-
-        def fun(obj):
-            return fun.value
-
-        x = cached_property(fun)
-        self.assertIs(x.__get__(None), x)
-        self.assertIs(x.__set__(None, None), x)
-        self.assertIs(x.__delete__(None), x)

+ 7 - 11
docs/contributing.rst

@@ -471,13 +471,13 @@ dependencies, so install these next:
     $ pip install -U -r requirements/default.txt
     $ pip install -U -r requirements/default.txt
 
 
 After installing the dependencies required, you can now execute
 After installing the dependencies required, you can now execute
-the test suite by calling :pypi:`nosetests <nose>`:
+the test suite by calling :pypi:`py.test <pytest`:
 
 
 .. code-block:: console
 .. code-block:: console
 
 
-    $ nosetests
+    $ py.test
 
 
-Some useful options to :command:`nosetests` are:
+Some useful options to :command:`py.test` are:
 
 
 * ``-x``
 * ``-x``
 
 
@@ -487,10 +487,6 @@ Some useful options to :command:`nosetests` are:
 
 
     Don't capture output
     Don't capture output
 
 
-* ``-nologcapture``
-
-    Don't capture log output.
-
 * ``-v``
 * ``-v``
 
 
     Run with verbose output.
     Run with verbose output.
@@ -500,7 +496,7 @@ you can do so like this:
 
 
 .. code-block:: console
 .. code-block:: console
 
 
-    $ nosetests celery.tests.test_worker.test_worker_job
+    $ py.test t/unit/worker/test_worker_job.py
 
 
 .. _contributing-pull-requests:
 .. _contributing-pull-requests:
 
 
@@ -536,16 +532,16 @@ Code coverage in HTML:
 
 
 .. code-block:: console
 .. code-block:: console
 
 
-    $ nosetests --with-coverage --cover-html
+    $ py.test --cov=celery --cov-report=html
 
 
 The coverage output will then be located at
 The coverage output will then be located at
-:file:`celery/tests/cover/index.html`.
+:file:`cover/index.html`.
 
 
 Code coverage in XML (Cobertura-style):
 Code coverage in XML (Cobertura-style):
 
 
 .. code-block:: console
 .. code-block:: console
 
 
-    $ nosetests --with-coverage --cover-xml --cover-xml-file=coverage.xml
+    $ py.test --cov=celery --cov-report=xml
 
 
 The coverage XML output will then be located at :file:`coverage.xml`
 The coverage XML output will then be located at :file:`coverage.xml`
 
 

+ 2 - 2
docs/internals/guide.rst

@@ -292,9 +292,9 @@ Module Overview
 
 
     single-mode interface to creating tasks, and controlling workers.
     single-mode interface to creating tasks, and controlling workers.
 
 
-- celery.tests
+- t.unit (int distribution)
 
 
-    The unittest suite.
+    The unit test suite.
 
 
 - celery.utils
 - celery.utils
 
 

+ 1 - 1
docs/whatsnew-4.0.rst

@@ -1004,7 +1004,7 @@ Instead of using router classes you can now simply define a function:
             return {'queue': 'hipri'}
             return {'queue': 'hipri'}
 
 
 If you don't need the arguments you can use start arguments, just make
 If you don't need the arguments you can use start arguments, just make
-sure you always also accept star arguments so that we've the ability
+sure you always also accept star arguments so that we have the ability
 to add more features in the future:
 to add more features in the future:
 
 
 .. code-block:: python
 .. code-block:: python

+ 5 - 2
funtests/suite/test_leak.py

@@ -5,10 +5,13 @@ import os
 import sys
 import sys
 import shlex
 import shlex
 import subprocess
 import subprocess
+import unittest
+
+from case import Case
+from case.skip import SkipTest
 
 
 from celery import current_app
 from celery import current_app
 from celery.five import range
 from celery.five import range
-from celery.tests.case import SkipTest, unittest
 
 
 import suite  # noqa
 import suite  # noqa
 
 
@@ -25,7 +28,7 @@ class Sizes(list):
         return sum(self) / len(self)
         return sum(self) / len(self)
 
 
 
 
-class LeakFunCase(unittest.TestCase):
+class LeakFunCase(Case):
 
 
     def setUp(self):
     def setUp(self):
         self.app = current_app
         self.app = current_app

+ 1 - 0
requirements/test-ci-base.txt

@@ -1,4 +1,5 @@
 coverage>=3.0
 coverage>=3.0
+pytest-cov
 codecov
 codecov
 -r extras/redis.txt
 -r extras/redis.txt
 -r extras/sqlalchemy.txt
 -r extras/sqlalchemy.txt

+ 2 - 1
requirements/test.txt

@@ -1 +1,2 @@
-case>=1.2.2
+case>=1.3.0
+pytest

+ 3 - 2
setup.cfg

@@ -1,5 +1,6 @@
-[nosetests]
-where = celery/tests
+[tool:pytest]
+testpaths = t/
+python_classes = test_*
 
 
 [build_sphinx]
 [build_sphinx]
 source-dir = docs/
 source-dir = docs/

+ 39 - 35
setup.py

@@ -1,12 +1,14 @@
 #!/usr/bin/env python
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 
 
-from setuptools import setup, find_packages
-
+import codecs
 import os
 import os
 import re
 import re
 import sys
 import sys
-import codecs
+
+import setuptools
+import setuptools.command.test
+
 
 
 try:
 try:
     import platform
     import platform
@@ -15,6 +17,8 @@ except (AttributeError, ImportError):
     def _pyimp():
     def _pyimp():
         return 'Python'
         return 'Python'
 
 
+NAME = 'celery'
+
 E_UNSUPPORTED_PYTHON = """
 E_UNSUPPORTED_PYTHON = """
 ----------------------------------------
 ----------------------------------------
  Celery 4.0 requires %s %s or later
  Celery 4.0 requires %s %s or later
@@ -80,10 +84,6 @@ except:
 finally:
 finally:
     sys.path[:] = orig_path
     sys.path[:] = orig_path
 
 
-NAME = 'celery'
-entrypoints = {}
-extra = {}
-
 # -*- Classifiers -*-
 # -*- Classifiers -*-
 
 
 classes = """
 classes = """
@@ -154,6 +154,10 @@ def _reqs(*f):
 def reqs(*f):
 def reqs(*f):
     return [req for subreq in _reqs(*f) for req in subreq]
     return [req for subreq in _reqs(*f) for req in subreq]
 
 
+
+def extras(*p):
+    return reqs('extras', *p)
+
 install_requires = reqs('default.txt')
 install_requires = reqs('default.txt')
 if JYTHON:
 if JYTHON:
     install_requires.extend(reqs('jython.txt'))
     install_requires.extend(reqs('jython.txt'))
@@ -165,46 +169,46 @@ if os.path.exists('README.rst'):
 else:
 else:
     long_description = 'See http://pypi.python.org/pypi/celery'
     long_description = 'See http://pypi.python.org/pypi/celery'
 
 
-# -*- Entry Points -*- #
-
-console_scripts = entrypoints['console_scripts'] = [
-    'celery = celery.__main__:main',
-]
+# -*- %%% -*-
 
 
-# -*- Extras -*-
 
 
+class pytest(setuptools.command.test.test):
+    user_options = [('pytest-args=', 'a', 'Arguments to pass to py.test')]
 
 
-def extras(*p):
-    return reqs('extras', *p)
+    def initialize_options(self):
+        setuptools.command.test.test.initialize_options(self)
+        self.pytest_args = []
 
 
-# Celery specific
-features = set([
-    'auth', 'cassandra', 'elasticsearch', 'memcache', 'pymemcache',
-    'couchbase', 'eventlet', 'gevent', 'msgpack', 'yaml',
-    'redis', 'sqs', 'couchdb', 'riak', 'zookeeper', 'solar',
-    'sqlalchemy', 'librabbitmq', 'pyro', 'slmq', 'tblib', 'consul'
-])
-extras_require = dict((x, extras(x + '.txt')) for x in features)
-extra['extras_require'] = extras_require
+    def run_tests(self):
+        import pytest
+        sys.exit(pytest.main(self.pytest_args))
 
 
-# -*- %%% -*-
-
-setup(
+setuptools.setup(
     name=NAME,
     name=NAME,
+    packages=setuptools.find_packages(exclude=['t', 't.*']),
     version=meta['version'],
     version=meta['version'],
     description=meta['doc'],
     description=meta['doc'],
+    long_description=long_description,
     author=meta['author'],
     author=meta['author'],
     author_email=meta['contact'],
     author_email=meta['contact'],
-    url=meta['homepage'],
     platforms=['any'],
     platforms=['any'],
     license='BSD',
     license='BSD',
-    packages=find_packages(),
-    include_package_data=True,
-    zip_safe=False,
+    url=meta['homepage'],
     install_requires=install_requires,
     install_requires=install_requires,
     tests_require=reqs('test.txt'),
     tests_require=reqs('test.txt'),
-    test_suite='nose.collector',
+    extras_require=dict((x, extras(x + '.txt')) for x in set([
+        'auth', 'cassandra', 'elasticsearch', 'memcache', 'pymemcache',
+        'couchbase', 'eventlet', 'gevent', 'msgpack', 'yaml',
+        'redis', 'sqs', 'couchdb', 'riak', 'zookeeper', 'solar',
+        'sqlalchemy', 'librabbitmq', 'pyro', 'slmq', 'tblib', 'consul'
+    ])),
     classifiers=classifiers,
     classifiers=classifiers,
-    entry_points=entrypoints,
-    long_description=long_description,
-    **extra)
+    entry_points={
+        'console_scripts': [
+            'celery = celery.__main__:main',
+        ]
+    },
+    cmdclass={'test': pytest},
+    include_package_data=True,
+    zip_safe=False,
+)

+ 0 - 0
celery/tests/app/__init__.py → t/__init__.py


+ 418 - 0
t/conftest.py

@@ -0,0 +1,418 @@
+from __future__ import absolute_import, unicode_literals
+
+import logging
+import numbers
+import os
+import pytest
+import sys
+import threading
+import warnings
+import weakref
+
+from copy import deepcopy
+from datetime import datetime, timedelta
+from functools import partial
+from importlib import import_module
+
+from case import Mock
+from case.utils import decorator
+from kombu import Queue
+from kombu.utils.imports import symbol_by_name
+
+from celery import Celery
+from celery.app import current_app
+from celery.backends.cache import CacheBackend, DummyClient
+
+try:
+    WindowsError = WindowsError  # noqa
+except NameError:
+
+    class WindowsError(Exception):
+        pass
+
+PYPY3 = getattr(sys, 'pypy_version_info', None) and sys.version_info[0] > 3
+
+CASE_LOG_REDIRECT_EFFECT = 'Test {0} didn\'t disable LoggingProxy for {1}'
+CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
+CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
+
+CELERY_TEST_CONFIG = {
+    #: Don't want log output when running suite.
+    'worker_hijack_root_logger': False,
+    'worker_log_color': False,
+    'task_default_queue': 'testcelery',
+    'task_default_exchange': 'testcelery',
+    'task_default_routing_key': 'testcelery',
+    'task_queues': (
+        Queue('testcelery', routing_key='testcelery'),
+    ),
+    'accept_content': ('json', 'pickle'),
+    'enable_utc': True,
+    'timezone': 'UTC',
+
+    # Mongo results tests (only executed if installed and running)
+    'mongodb_backend_settings': {
+        'host': os.environ.get('MONGO_HOST') or 'localhost',
+        'port': os.environ.get('MONGO_PORT') or 27017,
+        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
+        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
+                                'taskmeta_collection'),
+        'user': os.environ.get('MONGO_USER'),
+        'password': os.environ.get('MONGO_PASSWORD'),
+    }
+}
+
+
+class Trap(object):
+
+    def __getattr__(self, name):
+        raise RuntimeError('Test depends on current_app')
+
+
+class UnitLogging(symbol_by_name(Celery.log_cls)):
+
+    def __init__(self, *args, **kwargs):
+        super(UnitLogging, self).__init__(*args, **kwargs)
+        self.already_setup = True
+
+
+def TestApp(name=None, set_as_current=False, log=UnitLogging,
+            broker='memory://', backend='cache+memory://', **kwargs):
+    app = Celery(name or 'celery.tests',
+                 set_as_current=set_as_current,
+                 log=log, broker=broker, backend=backend,
+                 **kwargs)
+    app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
+    return app
+
+
+def alive_threads():
+    return [thread for thread in threading.enumerate() if thread.is_alive()]
+
+
+@pytest.fixture(autouse=True)
+def task_join_will_not_block(request):
+    from celery import _state
+    from celery import result
+    prev_res_join_block = result.task_join_will_block
+    _state.orig_task_join_will_block = _state.task_join_will_block
+    prev_state_join_block = _state.task_join_will_block
+    result.task_join_will_block = \
+        _state.task_join_will_block = lambda: False
+    _state._set_task_join_will_block(False)
+
+    def fin():
+        result.task_join_will_block = prev_res_join_block
+        _state.task_join_will_block = prev_state_join_block
+        _state._set_task_join_will_block(False)
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(scope='session', autouse=True)
+def record_threads_at_startup(request):
+    try:
+        request.session._threads_at_startup
+    except AttributeError:
+        request.session._threads_at_startup = alive_threads()
+
+
+@pytest.fixture(autouse=True)
+def threads_not_lingering(request):
+    def fin():
+        assert request.session._threads_at_startup == alive_threads()
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def app(request):
+    from celery import _state
+    prev_current_app = current_app()
+    prev_default_app = _state.default_app
+    prev_finalizers = set(_state._on_app_finalizers)
+    prev_apps = weakref.WeakSet(_state._apps)
+    trap = Trap()
+    prev_tls = _state._tls
+    _state.set_default_app(trap)
+
+    class NonTLS(object):
+        current_app = trap
+    _state._tls = NonTLS()
+
+    app = TestApp(set_as_current=False)
+    is_not_contained = any([
+        not getattr(request.module, 'app_contained', True),
+        not getattr(request.cls, 'app_contained', True),
+        not getattr(request.function, 'app_contained', True)
+    ])
+    if is_not_contained:
+        app.set_current()
+
+    def fin():
+        _state.set_default_app(prev_default_app)
+        _state._tls = prev_tls
+        _state._tls.current_app = prev_current_app
+        if app is not prev_current_app:
+            app.close()
+        _state._on_app_finalizers = prev_finalizers
+        _state._apps = prev_apps
+    request.addfinalizer(fin)
+    return app
+
+
+@pytest.fixture()
+def depends_on_current_app(app):
+    app.set_current()
+
+
+@pytest.fixture(autouse=True)
+def test_cases_shortcuts(request, app, patching):
+    if request.instance:
+        @app.task
+        def add(x, y):
+            return x + y
+
+        # IMPORTANT: We set an .app attribute for every test case class.
+        request.instance.app = app
+        request.instance.Celery = TestApp
+        request.instance.assert_signal_called = assert_signal_called
+        request.instance.task_message_from_sig = task_message_from_sig
+        request.instance.TaskMessage = TaskMessage
+        request.instance.TaskMessage1 = TaskMessage1
+        request.instance.CELERY_TEST_CONFIG = dict(CELERY_TEST_CONFIG)
+        request.instance.add = add
+        request.instance.patching = patching
+
+        def fin():
+            request.instance.app = None
+        request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def zzzz_test_cases_calls_setup_teardown(request):
+    if request.instance:
+        # we set the .patching attribute for every test class.
+        setup = getattr(request.instance, 'setup', None)
+        # we also call .setup() and .teardown() after every test method.
+        teardown = getattr(request.instance, 'teardown', None)
+        setup and setup()
+        teardown and request.addfinalizer(teardown)
+
+
+@pytest.fixture(autouse=True)
+def sanity_no_shutdown_flags_set(request):
+    def fin():
+        # Make sure no test left the shutdown flags enabled.
+        from celery.worker import state as worker_state
+        # check for EX_OK
+        assert worker_state.should_stop is not False
+        assert worker_state.should_terminate is not False
+        # check for other true values
+        assert not worker_state.should_stop
+        assert not worker_state.should_terminate
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def reset_cache_backend_state(request, app):
+    def fin():
+        backend = app.__dict__.get('backend')
+        if backend is not None:
+            if isinstance(backend, CacheBackend):
+                if isinstance(backend.client, DummyClient):
+                    backend.client.cache.clear()
+                backend._cache.clear()
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def sanity_stdouts(request):
+    def fin():
+        from celery.utils.log import LoggingProxy
+        assert sys.stdout
+        assert sys.stderr
+        assert sys.__stdout__
+        assert sys.__stderr__
+        this = request.node.name
+        if isinstance(sys.stdout, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stdout__, (LoggingProxy, Mock)):
+            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stdout'))
+        if isinstance(sys.stderr, (LoggingProxy, Mock)) or \
+                isinstance(sys.__stderr__, (LoggingProxy, Mock)):
+            raise RuntimeError(CASE_LOG_REDIRECT_EFFECT.format(this, 'stderr'))
+    request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def sanity_logging_side_effects(request):
+    root = logging.getLogger()
+    rootlevel = root.level
+    roothandlers = root.handlers
+
+    def fin():
+        this = request.node.name
+        root_now = logging.getLogger()
+        if root_now.level != rootlevel:
+            raise RuntimeError(CASE_LOG_LEVEL_EFFECT.format(this))
+        if root_now.handlers != roothandlers:
+            raise RuntimeError(CASE_LOG_HANDLER_EFFECT.format(this))
+    request.addfinalizer(fin)
+
+
+def setup_session(scope='session'):
+    using_coverage = (
+        os.environ.get('COVER_ALL_MODULES') or '--with-coverage' in sys.argv
+    )
+    os.environ.update(
+        # warn if config module not found
+        C_WNOCONF='yes',
+        KOMBU_DISABLE_LIMIT_PROTECTION='yes',
+    )
+
+    if using_coverage and not PYPY3:
+        from warnings import catch_warnings
+        with catch_warnings(record=True):
+            import_all_modules()
+        warnings.resetwarnings()
+    from celery._state import set_default_app
+    set_default_app(Trap())
+
+
+def teardown():
+    # Don't want SUBDEBUG log messages at finalization.
+    try:
+        from multiprocessing.util import get_logger
+    except ImportError:
+        pass
+    else:
+        get_logger().setLevel(logging.WARNING)
+
+    # Make sure test database is removed.
+    import os
+    if os.path.exists('test.db'):
+        try:
+            os.remove('test.db')
+        except WindowsError:
+            pass
+
+    # Make sure there are no remaining threads at shutdown.
+    import threading
+    remaining_threads = [thread for thread in threading.enumerate()
+                         if thread.getName() != 'MainThread']
+    if remaining_threads:
+        sys.stderr.write(
+            '\n\n**WARNING**: Remaining threads at teardown: %r...\n' % (
+                remaining_threads))
+
+
+def find_distribution_modules(name=__name__, file=__file__):
+    current_dist_depth = len(name.split('.')) - 1
+    current_dist = os.path.join(os.path.dirname(file),
+                                *([os.pardir] * current_dist_depth))
+    abs = os.path.abspath(current_dist)
+    dist_name = os.path.basename(abs)
+
+    for dirpath, dirnames, filenames in os.walk(abs):
+        package = (dist_name + dirpath[len(abs):]).replace('/', '.')
+        if '__init__.py' in filenames:
+            yield package
+            for filename in filenames:
+                if filename.endswith('.py') and filename != '__init__.py':
+                    yield '.'.join([package, filename])[:-3]
+
+
+def import_all_modules(name=__name__, file=__file__,
+                       skip=('celery.decorators',
+                             'celery.task')):
+    for module in find_distribution_modules(name, file):
+        if not module.startswith(skip):
+            try:
+                import_module(module)
+            except ImportError:
+                pass
+            except OSError as exc:
+                warnings.warn(UserWarning(
+                    'Ignored error importing module {0}: {1!r}'.format(
+                        module, exc,
+                    )))
+
+
+@decorator
+def assert_signal_called(signal, **expected):
+    handler = Mock()
+    call_handler = partial(handler)
+    signal.connect(call_handler)
+    try:
+        yield handler
+    finally:
+        signal.disconnect(call_handler)
+    handler.assert_called_with(signal=signal, **expected)
+
+
+def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
+                errbacks=None, chain=None, shadow=None, utc=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {
+        'id': id,
+        'task': name,
+        'shadow': shadow,
+    }
+    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
+    message.headers.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        (args, kwargs, embed), serializer='json',
+    )
+    message.payload = (args, kwargs, embed)
+    return message
+
+
+def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
+                 errbacks=None, chain=None, **options):
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {}
+    message.payload = {
+        'task': name,
+        'id': id,
+        'args': args,
+        'kwargs': kwargs,
+        'callbacks': callbacks,
+        'errbacks': errbacks,
+    }
+    message.payload.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        message.payload,
+    )
+    return message
+
+
+def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
+    sig.freeze()
+    callbacks = sig.options.pop('link', None)
+    errbacks = sig.options.pop('link_error', None)
+    countdown = sig.options.pop('countdown', None)
+    if countdown:
+        eta = app.now() + timedelta(seconds=countdown)
+    else:
+        eta = sig.options.pop('eta', None)
+    if eta and isinstance(eta, datetime):
+        eta = eta.isoformat()
+    expires = sig.options.pop('expires', None)
+    if expires and isinstance(expires, numbers.Real):
+        expires = app.now() + timedelta(seconds=expires)
+    if expires and isinstance(expires, datetime):
+        expires = expires.isoformat()
+    return TaskMessage(
+        sig.task, id=sig.id, args=sig.args,
+        kwargs=sig.kwargs,
+        callbacks=[dict(s) for s in callbacks] if callbacks else None,
+        errbacks=[dict(s) for s in errbacks] if errbacks else None,
+        eta=eta,
+        expires=expires,
+        utc=utc,
+        **sig.options
+    )

+ 0 - 0
celery/tests/apps/__init__.py → t/unit/__init__.py


+ 0 - 0
celery/tests/backends/__init__.py → t/unit/app/__init__.py


+ 269 - 0
t/unit/app/test_amqp.py

@@ -0,0 +1,269 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from datetime import datetime, timedelta
+
+from case import Mock
+from kombu import Exchange, Queue
+
+from celery import uuid
+from celery.app.amqp import Queues, utf8dict
+from celery.five import keys
+from celery.utils.time import to_utc
+
+
+class test_TaskConsumer:
+
+    def test_accept_content(self, app):
+        with app.pool.acquire(block=True) as con:
+            app.conf.accept_content = ['application/json']
+            assert app.amqp.TaskConsumer(con).accept == {
+                'application/json',
+            }
+            assert app.amqp.TaskConsumer(con, accept=['json']).accept == {
+                'application/json',
+            }
+
+
+class test_ProducerPool:
+
+    def test_setup_nolimit(self, app):
+        app.conf.broker_pool_limit = None
+        try:
+            delattr(app, '_pool')
+        except AttributeError:
+            pass
+        app.amqp._producer_pool = None
+        pool = app.amqp.producer_pool
+        assert pool.limit == app.pool.limit
+        assert not pool._resource.queue
+
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+        r1.release()
+        r2.release()
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+
+    def test_setup(self, app):
+        app.conf.broker_pool_limit = 2
+        try:
+            delattr(app, '_pool')
+        except AttributeError:
+            pass
+        app.amqp._producer_pool = None
+        pool = app.amqp.producer_pool
+        assert pool.limit == app.pool.limit
+        assert pool._resource.queue
+
+        p1 = r1 = pool.acquire()
+        p2 = r2 = pool.acquire()
+        r1.release()
+        r2.release()
+        r1 = pool.acquire()
+        r2 = pool.acquire()
+        assert p2 is r1
+        assert p1 is r2
+        r1.release()
+        r2.release()
+
+
+class test_Queues:
+
+    def test_queues_format(self):
+        self.app.amqp.queues._consume_from = {}
+        assert self.app.amqp.queues.format() == ''
+
+    def test_with_defaults(self):
+        assert Queues(None) == {}
+
+    def test_add(self):
+        q = Queues()
+        q.add('foo', exchange='ex', routing_key='rk')
+        assert 'foo' in q
+        assert isinstance(q['foo'], Queue)
+        assert q['foo'].routing_key == 'rk'
+
+    @pytest.mark.parametrize('ha_policy,qname,q,qargs,expected', [
+        (None, 'xyz', 'xyz', None, None),
+        (None, 'xyz', 'xyz', {'x-foo': 'bar'}, {'x-foo': 'bar'}),
+        ('all', 'foo', Queue('foo'), None, {'x-ha-policy': 'all'}),
+        ('all', 'xyx2',
+         Queue('xyx2', queue_arguments={'x-foo': 'bari'}),
+         None,
+         {'x-ha-policy': 'all', 'x-foo': 'bari'}),
+        (['A', 'B', 'C'], 'foo', Queue('foo'), None, {
+            'x-ha-policy': 'nodes',
+            'x-ha-policy-params': ['A', 'B', 'C']}),
+    ])
+    def test_with_ha_policy(self, ha_policy, qname, q, qargs, expected):
+        queues = Queues(ha_policy=ha_policy, create_missing=False)
+        queues.add(q, queue_arguments=qargs)
+        assert queues[qname].queue_arguments == expected
+
+    def test_select_add(self):
+        q = Queues()
+        q.select(['foo', 'bar'])
+        q.select_add('baz')
+        assert sorted(keys(q._consume_from)) == ['bar', 'baz', 'foo']
+
+    def test_deselect(self):
+        q = Queues()
+        q.select(['foo', 'bar'])
+        q.deselect('bar')
+        assert sorted(keys(q._consume_from)) == ['foo']
+
+    def test_with_ha_policy_compat(self):
+        q = Queues(ha_policy='all')
+        q.add('bar')
+        assert q['bar'].queue_arguments == {'x-ha-policy': 'all'}
+
+    def test_add_default_exchange(self):
+        ex = Exchange('fff', 'fanout')
+        q = Queues(default_exchange=ex)
+        q.add(Queue('foo'))
+        assert q['foo'].exchange.name == ''
+
+    def test_alias(self):
+        q = Queues()
+        q.add(Queue('foo', alias='barfoo'))
+        assert q['barfoo'] is q['foo']
+
+    @pytest.mark.parametrize('queues_kwargs,qname,q,expected', [
+        (dict(max_priority=10),
+         'foo', 'foo', {'x-max-priority': 10}),
+        (dict(max_priority=10),
+         'xyz', Queue('xyz', queue_arguments={'x-max-priority': 3}),
+         {'x-max-priority': 3}),
+        (dict(max_priority=10),
+         'moo', Queue('moo', queue_arguments=None),
+         {'x-max-priority': 10}),
+        (dict(ha_policy='all', max_priority=5),
+         'bar', 'bar',
+         {'x-ha-policy': 'all', 'x-max-priority': 5}),
+        (dict(ha_policy='all', max_priority=5),
+         'xyx2', Queue('xyx2', queue_arguments={'x-max-priority': 2}),
+         {'x-ha-policy': 'all', 'x-max-priority': 2}),
+        (dict(max_priority=None),
+         'foo2', 'foo2',
+         None),
+        (dict(max_priority=None),
+         'xyx3', Queue('xyx3', queue_arguments={'x-max-priority': 7}),
+         {'x-max-priority': 7}),
+
+    ])
+    def test_with_max_priority(self, queues_kwargs, qname, q, expected):
+        queues = Queues(**queues_kwargs)
+        queues.add(q)
+        assert queues[qname].queue_arguments == expected
+
+
+class test_AMQP:
+
+    def setup(self):
+        self.simple_message = self.app.amqp.as_task_v2(
+            uuid(), 'foo', create_sent_event=True,
+        )
+
+    def test_Queues__with_ha_policy(self):
+        x = self.app.amqp.Queues({}, ha_policy='all')
+        assert x.ha_policy == 'all'
+
+    def test_Queues__with_max_priority(self):
+        x = self.app.amqp.Queues({}, max_priority=23)
+        assert x.max_priority == 23
+
+    def test_send_task_message__no_kwargs(self):
+        self.app.amqp.send_task_message(Mock(), 'foo', self.simple_message)
+
+    def test_send_task_message__properties(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, foo=1, retry=False,
+        )
+        assert prod.publish.call_args[1]['foo'] == 1
+
+    def test_send_task_message__headers(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, headers={'x1x': 'y2x'},
+            retry=False,
+        )
+        assert prod.publish.call_args[1]['headers']['x1x'] == 'y2x'
+
+    def test_send_task_message__queue_string(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, queue='foo', retry=False,
+        )
+        kwargs = prod.publish.call_args[1]
+        assert kwargs['routing_key'] == 'foo'
+        assert kwargs['exchange'] == ''
+
+    def test_send_event_exchange_string(self):
+        evd = Mock(name='evd')
+        self.app.amqp.send_task_message(
+            Mock(), 'foo', self.simple_message, retry=False,
+            exchange='xyz', routing_key='xyb',
+            event_dispatcher=evd,
+        )
+        evd.publish.assert_called()
+        event = evd.publish.call_args[0][1]
+        assert event['routing_key'] == 'xyb'
+        assert event['exchange'] == 'xyz'
+
+    def test_send_task_message__with_delivery_mode(self):
+        prod = Mock(name='producer')
+        self.app.amqp.send_task_message(
+            prod, 'foo', self.simple_message, delivery_mode=33, retry=False,
+        )
+        assert prod.publish.call_args[1]['delivery_mode'] == 33
+
+    def test_routes(self):
+        r1 = self.app.amqp.routes
+        r2 = self.app.amqp.routes
+        assert r1 is r2
+
+
+class test_as_task_v2:
+
+    def test_raises_if_args_is_not_tuple(self):
+        with pytest.raises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', args='123')
+
+    def test_raises_if_kwargs_is_not_mapping(self):
+        with pytest.raises(TypeError):
+            self.app.amqp.as_task_v2(uuid(), 'foo', kwargs=(1, 2, 3))
+
+    def test_countdown_to_eta(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', countdown=10, now=now,
+        )
+        assert m.headers['eta'] == (now + timedelta(seconds=10)).isoformat()
+
+    def test_expires_to_datetime(self):
+        now = to_utc(datetime.utcnow()).astimezone(self.app.timezone)
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo', expires=30, now=now,
+        )
+        assert m.headers['expires'] == (
+            now + timedelta(seconds=30)).isoformat()
+
+    def test_callbacks_errbacks_chord(self):
+
+        @self.app.task
+        def t(i):
+            pass
+
+        m = self.app.amqp.as_task_v2(
+            uuid(), 'foo',
+            callbacks=[t.s(1), t.s(2)],
+            errbacks=[t.s(3), t.s(4)],
+            chord=t.s(5),
+        )
+        _, _, embed = m.body
+        assert embed['callbacks'] == [utf8dict(t.s(1)), utf8dict(t.s(2))]
+        assert embed['errbacks'] == [utf8dict(t.s(3)), utf8dict(t.s(4))]
+        assert embed['chord'] == utf8dict(t.s(5))

+ 12 - 14
celery/tests/app/test_annotations.py → t/unit/app/test_annotations.py

@@ -3,14 +3,12 @@ from __future__ import absolute_import, unicode_literals
 from celery.app.annotations import MapAnnotation, prepare
 from celery.app.annotations import MapAnnotation, prepare
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 
 
-from celery.tests.case import AppCase
-
 
 
 class MyAnnotation(object):
 class MyAnnotation(object):
     foo = 65
     foo = 65
 
 
 
 
-class AnnotationCase(AppCase):
+class AnnotationCase:
 
 
     def setup(self):
     def setup(self):
         @self.app.task(shared=False)
         @self.app.task(shared=False)
@@ -28,29 +26,29 @@ class test_MapAnnotation(AnnotationCase):
 
 
     def test_annotate(self):
     def test_annotate(self):
         x = MapAnnotation({self.add.name: {'foo': 1}})
         x = MapAnnotation({self.add.name: {'foo': 1}})
-        self.assertDictEqual(x.annotate(self.add), {'foo': 1})
-        self.assertIsNone(x.annotate(self.mul))
+        assert x.annotate(self.add) == {'foo': 1}
+        assert x.annotate(self.mul) is None
 
 
     def test_annotate_any(self):
     def test_annotate_any(self):
         x = MapAnnotation({'*': {'foo': 2}})
         x = MapAnnotation({'*': {'foo': 2}})
-        self.assertDictEqual(x.annotate_any(), {'foo': 2})
+        assert x.annotate_any() == {'foo': 2}
 
 
         x = MapAnnotation()
         x = MapAnnotation()
-        self.assertIsNone(x.annotate_any())
+        assert x.annotate_any() is None
 
 
 
 
 class test_prepare(AnnotationCase):
 class test_prepare(AnnotationCase):
 
 
     def test_dict_to_MapAnnotation(self):
     def test_dict_to_MapAnnotation(self):
         x = prepare({self.add.name: {'foo': 3}})
         x = prepare({self.add.name: {'foo': 3}})
-        self.assertIsInstance(x[0], MapAnnotation)
+        assert isinstance(x[0], MapAnnotation)
 
 
     def test_returns_list(self):
     def test_returns_list(self):
-        self.assertListEqual(prepare(1), [1])
-        self.assertListEqual(prepare([1]), [1])
-        self.assertListEqual(prepare((1,)), [1])
-        self.assertEqual(prepare(None), ())
+        assert prepare(1) == [1]
+        assert prepare([1]) == [1]
+        assert prepare((1,)) == [1]
+        assert prepare(None) == ()
 
 
     def test_evalutes_qualnames(self):
     def test_evalutes_qualnames(self):
-        self.assertEqual(prepare(qualname(MyAnnotation))[0]().foo, 65)
-        self.assertEqual(prepare([qualname(MyAnnotation)])[0]().foo, 65)
+        assert prepare(qualname(MyAnnotation))[0]().foo == 65
+        assert prepare([qualname(MyAnnotation)])[0]().foo == 65

+ 222 - 246
celery/tests/app/test_app.py → t/unit/app/test_app.py

@@ -1,12 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import gc
 import gc
-import os
 import itertools
 import itertools
+import os
+import pytest
 
 
 from copy import deepcopy
 from copy import deepcopy
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
+from case import ContextMock, Mock, mock, patch
 from vine import promise
 from vine import promise
 
 
 from celery import Celery
 from celery import Celery
@@ -16,7 +18,7 @@ from celery import _state
 from celery.app import base as _appbase
 from celery.app import base as _appbase
 from celery.app import defaults
 from celery.app import defaults
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.five import keys
+from celery.five import items, keys
 from celery.loaders.base import unconfigured
 from celery.loaders.base import unconfigured
 from celery.platforms import pyimplementation
 from celery.platforms import pyimplementation
 from celery.utils.collections import DictAttribute
 from celery.utils.collections import DictAttribute
@@ -24,17 +26,6 @@ from celery.utils.serialization import pickle
 from celery.utils.time import timezone
 from celery.utils.time import timezone
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import (
-    CELERY_TEST_CONFIG,
-    AppCase,
-    Mock,
-    Case,
-    ContextMock,
-    depends_on_current_app,
-    mock,
-    patch,
-)
-
 THIS_IS_A_KEY = 'this is a value'
 THIS_IS_A_KEY = 'this is a value'
 
 
 
 
@@ -54,37 +45,32 @@ class ObjectConfig2(object):
     UNDERSTAND_ME = True
     UNDERSTAND_ME = True
 
 
 
 
-def _get_test_config():
-    return deepcopy(CELERY_TEST_CONFIG)
-test_config = _get_test_config()
-
-
-class test_module(AppCase):
+class test_module:
 
 
     def test_default_app(self):
     def test_default_app(self):
-        self.assertEqual(_app.default_app, _state.default_app)
+        assert _app.default_app == _state.default_app
 
 
-    def test_bugreport(self):
-        self.assertTrue(_app.bugreport(app=self.app))
+    def test_bugreport(self, app):
+        assert _app.bugreport(app=app)
 
 
 
 
-class test_task_join_will_block(Case):
+class test_task_join_will_block:
 
 
-    def test_task_join_will_block(self):
-        prev, _state._task_join_will_block = _state._task_join_will_block, 0
-        try:
-            self.assertEqual(_state._task_join_will_block, 0)
-            _state._set_task_join_will_block(True)
-            print(_state.task_join_will_block)
-            self.assertTrue(_state.task_join_will_block())
-        finally:
-            _state._task_join_will_block = prev
+    def test_task_join_will_block(self, patching):
+        patching('celery._state._task_join_will_block', 0)
+        assert _state._task_join_will_block == 0
+        _state._set_task_join_will_block(True)
+        assert _state._task_join_will_block is True
+        # fixture 'app' sets this, so need to use orig_ function
+        # set there by that fixture.
+        res = _state.orig_task_join_will_block()
+        assert res is True
 
 
 
 
-class test_App(AppCase):
+class test_App:
 
 
     def setup(self):
     def setup(self):
-        self.app.add_defaults(test_config)
+        self.app.add_defaults(deepcopy(self.CELERY_TEST_CONFIG))
 
 
     def test_task_autofinalize_disabled(self):
     def test_task_autofinalize_disabled(self):
         with self.Celery('xyzibari', autofinalize=False) as app:
         with self.Celery('xyzibari', autofinalize=False) as app:
@@ -92,7 +78,7 @@ class test_App(AppCase):
             def ttafd():
             def ttafd():
                 return 42
                 return 42
 
 
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 ttafd()
                 ttafd()
 
 
         with self.Celery('xyzibari', autofinalize=False) as app:
         with self.Celery('xyzibari', autofinalize=False) as app:
@@ -101,14 +87,14 @@ class test_App(AppCase):
                 return 42
                 return 42
 
 
             app.finalize()
             app.finalize()
-            self.assertEqual(ttafd2(), 42)
+            assert ttafd2() == 42
 
 
     def test_registry_autofinalize_disabled(self):
     def test_registry_autofinalize_disabled(self):
         with self.Celery('xyzibari', autofinalize=False) as app:
         with self.Celery('xyzibari', autofinalize=False) as app:
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 app.tasks['celery.chain']
                 app.tasks['celery.chain']
             app.finalize()
             app.finalize()
-            self.assertTrue(app.tasks['celery.chain'])
+            assert app.tasks['celery.chain']
 
 
     def test_task(self):
     def test_task(self):
         with self.Celery('foozibari') as app:
         with self.Celery('foozibari') as app:
@@ -118,20 +104,20 @@ class test_App(AppCase):
 
 
             fun.__module__ = '__main__'
             fun.__module__ = '__main__'
             task = app.task(fun)
             task = app.task(fun)
-            self.assertEqual(task.name, app.main + '.fun')
+            assert task.name == app.main + '.fun'
 
 
     def test_task_too_many_args(self):
     def test_task_too_many_args(self):
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             self.app.task(Mock(name='fun'), True)
             self.app.task(Mock(name='fun'), True)
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             self.app.task(Mock(name='fun'), True, 1, 2)
             self.app.task(Mock(name='fun'), True, 1, 2)
 
 
     def test_with_config_source(self):
     def test_with_config_source(self):
         with self.Celery(config_source=ObjectConfig) as app:
         with self.Celery(config_source=ObjectConfig) as app:
-            self.assertEqual(app.conf.FOO, 1)
-            self.assertEqual(app.conf.BAR, 2)
+            assert app.conf.FOO == 1
+            assert app.conf.BAR == 2
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_task_windows_execv(self):
     def test_task_windows_execv(self):
         prev, _appbase.USING_EXECV = _appbase.USING_EXECV, True
         prev, _appbase.USING_EXECV = _appbase.USING_EXECV, True
         try:
         try:
@@ -139,55 +125,55 @@ class test_App(AppCase):
             def foo():
             def foo():
                 pass
                 pass
 
 
-            self.assertTrue(foo._get_current_object())  # is proxy
+            assert foo._get_current_object()  # is proxy
 
 
         finally:
         finally:
             _appbase.USING_EXECV = prev
             _appbase.USING_EXECV = prev
         assert not _appbase.USING_EXECV
         assert not _appbase.USING_EXECV
 
 
     def test_task_takes_no_args(self):
     def test_task_takes_no_args(self):
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             @self.app.task(1)
             @self.app.task(1)
             def foo():
             def foo():
                 pass
                 pass
 
 
     def test_add_defaults(self):
     def test_add_defaults(self):
-        self.assertFalse(self.app.configured)
+        assert not self.app.configured
         _conf = {'FOO': 300}
         _conf = {'FOO': 300}
 
 
         def conf():
         def conf():
             return _conf
             return _conf
 
 
         self.app.add_defaults(conf)
         self.app.add_defaults(conf)
-        self.assertIn(conf, self.app._pending_defaults)
-        self.assertFalse(self.app.configured)
-        self.assertEqual(self.app.conf.FOO, 300)
-        self.assertTrue(self.app.configured)
-        self.assertFalse(self.app._pending_defaults)
+        assert conf in self.app._pending_defaults
+        assert not self.app.configured
+        assert self.app.conf.FOO == 300
+        assert self.app.configured
+        assert not self.app._pending_defaults
 
 
         # defaults not pickled
         # defaults not pickled
         appr = loads(dumps(self.app))
         appr = loads(dumps(self.app))
-        with self.assertRaises(AttributeError):
+        with pytest.raises(AttributeError):
             appr.conf.FOO
             appr.conf.FOO
 
 
         # add more defaults after configured
         # add more defaults after configured
         conf2 = {'FOO': 'BAR'}
         conf2 = {'FOO': 'BAR'}
         self.app.add_defaults(conf2)
         self.app.add_defaults(conf2)
-        self.assertEqual(self.app.conf.FOO, 'BAR')
+        assert self.app.conf.FOO == 'BAR'
 
 
-        self.assertIn(_conf, self.app.conf.defaults)
-        self.assertIn(conf2, self.app.conf.defaults)
+        assert _conf in self.app.conf.defaults
+        assert conf2 in self.app.conf.defaults
 
 
     def test_connection_or_acquire(self):
     def test_connection_or_acquire(self):
         with self.app.connection_or_acquire(block=True):
         with self.app.connection_or_acquire(block=True):
-            self.assertTrue(self.app.pool._dirty)
+            assert self.app.pool._dirty
 
 
         with self.app.connection_or_acquire(pool=False):
         with self.app.connection_or_acquire(pool=False):
-            self.assertFalse(self.app.pool._dirty)
+            assert not self.app.pool._dirty
 
 
     def test_using_v1_reduce(self):
     def test_using_v1_reduce(self):
         self.app._using_v1_reduce = True
         self.app._using_v1_reduce = True
-        self.assertTrue(loads(dumps(self.app)))
+        assert loads(dumps(self.app))
 
 
     def test_autodiscover_tasks_force(self):
     def test_autodiscover_tasks_force(self):
         self.app.loader.autodiscover_tasks = Mock()
         self.app.loader.autodiscover_tasks = Mock()
@@ -215,9 +201,9 @@ class test_App(AppCase):
             self.app.autodiscover_tasks(lazy_list)
             self.app.autodiscover_tasks(lazy_list)
             import_modules.connect.assert_called()
             import_modules.connect.assert_called()
             prom = import_modules.connect.call_args[0][0]
             prom = import_modules.connect.call_args[0][0]
-            self.assertIsInstance(prom, promise)
-            self.assertEqual(prom.fun, self.app._autodiscover_tasks)
-            self.assertEqual(prom.args[0](), [1, 2, 3])
+            assert isinstance(prom, promise)
+            assert prom.fun == self.app._autodiscover_tasks
+            assert prom.args[0](), [1, 2 == 3]
 
 
     def test_autodiscover_tasks__no_packages(self):
     def test_autodiscover_tasks__no_packages(self):
         fixup1 = Mock(name='fixup')
         fixup1 = Mock(name='fixup')
@@ -231,28 +217,28 @@ class test_App(AppCase):
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
             ['A', 'B', 'C', 'D', 'E', 'F'], related_name='tasks',
         )
         )
 
 
-    @mock.environ('CELERY_BROKER_URL', '')
-    def test_with_broker(self):
+    def test_with_broker(self, patching):
+        patching.setenv('CELERY_BROKER_URL', '')
         with self.Celery(broker='foo://baribaz') as app:
         with self.Celery(broker='foo://baribaz') as app:
-            self.assertEqual(app.conf.broker_url, 'foo://baribaz')
+            assert app.conf.broker_url == 'foo://baribaz'
 
 
     def test_pending_configuration__setattr(self):
     def test_pending_configuration__setattr(self):
         with self.Celery(broker='foo://bar') as app:
         with self.Celery(broker='foo://bar') as app:
             app.conf.task_default_delivery_mode = 44
             app.conf.task_default_delivery_mode = 44
             app.conf.worker_agent = 'foo:Bar'
             app.conf.worker_agent = 'foo:Bar'
-            self.assertFalse(app.configured)
-            self.assertEqual(app.conf.worker_agent, 'foo:Bar')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app._preconf['worker_agent'], 'foo:Bar')
+            assert not app.configured
+            assert app.conf.worker_agent == 'foo:Bar'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app._preconf['worker_agent'] == 'foo:Bar'
 
 
-            self.assertTrue(app.configured)
+            assert app.configured
             reapp = pickle.loads(pickle.dumps(app))
             reapp = pickle.loads(pickle.dumps(app))
-            self.assertEqual(reapp._preconf['worker_agent'], 'foo:Bar')
-            self.assertFalse(reapp.configured)
-            self.assertEqual(reapp.conf.worker_agent, 'foo:Bar')
-            self.assertTrue(reapp.configured)
-            self.assertEqual(reapp.conf.broker_url, 'foo://bar')
-            self.assertEqual(reapp._preconf['worker_agent'], 'foo:Bar')
+            assert reapp._preconf['worker_agent'] == 'foo:Bar'
+            assert not reapp.configured
+            assert reapp.conf.worker_agent == 'foo:Bar'
+            assert reapp.configured
+            assert reapp.conf.broker_url == 'foo://bar'
+            assert reapp._preconf['worker_agent'] == 'foo:Bar'
 
 
     def test_pending_configuration__update(self):
     def test_pending_configuration__update(self):
         with self.Celery(broker='foo://bar') as app:
         with self.Celery(broker='foo://bar') as app:
@@ -260,10 +246,10 @@ class test_App(AppCase):
                 task_default_delivery_mode=44,
                 task_default_delivery_mode=44,
                 worker_agent='foo:Bar',
                 worker_agent='foo:Bar',
             )
             )
-            self.assertFalse(app.configured)
-            self.assertEqual(app.conf.worker_agent, 'foo:Bar')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app._preconf['worker_agent'], 'foo:Bar')
+            assert not app.configured
+            assert app.conf.worker_agent == 'foo:Bar'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app._preconf['worker_agent'] == 'foo:Bar'
 
 
     def test_pending_configuration__compat_settings(self):
     def test_pending_configuration__compat_settings(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -272,11 +258,11 @@ class test_App(AppCase):
                 CELERY_DEFAULT_DELIVERY_MODE=63,
                 CELERY_DEFAULT_DELIVERY_MODE=63,
                 CELERYD_AGENT='foo:Barz',
                 CELERYD_AGENT='foo:Barz',
             )
             )
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.task_default_delivery_mode, 63)
-            self.assertEqual(app.conf.worker_agent, 'foo:Barz')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app.conf.result_backend, 'foo')
+            assert app.conf.task_always_eager == 4
+            assert app.conf.task_default_delivery_mode == 63
+            assert app.conf.worker_agent == 'foo:Barz'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app.conf.result_backend == 'foo'
 
 
     def test_pending_configuration__compat_settings_mixing(self):
     def test_pending_configuration__compat_settings_mixing(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -286,8 +272,8 @@ class test_App(AppCase):
                 CELERYD_AGENT='foo:Barz',
                 CELERYD_AGENT='foo:Barz',
                 worker_consumer='foo:Fooz',
                 worker_consumer='foo:Fooz',
             )
             )
-            with self.assertRaises(ImproperlyConfigured):
-                self.assertEqual(app.conf.task_always_eager, 4)
+            with pytest.raises(ImproperlyConfigured):
+                assert app.conf.task_always_eager == 4
 
 
     def test_pending_configuration__django_settings(self):
     def test_pending_configuration__django_settings(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -297,13 +283,13 @@ class test_App(AppCase):
                 CELERY_WORKER_AGENT='foo:Barz',
                 CELERY_WORKER_AGENT='foo:Barz',
                 CELERY_RESULT_SERIALIZER='pickle',
                 CELERY_RESULT_SERIALIZER='pickle',
             )), namespace='CELERY')
             )), namespace='CELERY')
-            self.assertEqual(app.conf.result_serializer, 'pickle')
-            self.assertEqual(app.conf.CELERY_RESULT_SERIALIZER, 'pickle')
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.task_default_delivery_mode, 63)
-            self.assertEqual(app.conf.worker_agent, 'foo:Barz')
-            self.assertEqual(app.conf.broker_url, 'foo://bar')
-            self.assertEqual(app.conf.result_backend, 'foo')
+            assert app.conf.result_serializer == 'pickle'
+            assert app.conf.CELERY_RESULT_SERIALIZER == 'pickle'
+            assert app.conf.task_always_eager == 4
+            assert app.conf.task_default_delivery_mode == 63
+            assert app.conf.worker_agent == 'foo:Barz'
+            assert app.conf.broker_url == 'foo://bar'
+            assert app.conf.result_backend == 'foo'
 
 
     def test_pending_configuration__compat_settings_mixing_new(self):
     def test_pending_configuration__compat_settings_mixing_new(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -314,8 +300,8 @@ class test_App(AppCase):
                 CELERYD_CONSUMER='foo:Fooz',
                 CELERYD_CONSUMER='foo:Fooz',
                 CELERYD_POOL='foo:Xuzzy',
                 CELERYD_POOL='foo:Xuzzy',
             )
             )
-            with self.assertRaises(ImproperlyConfigured):
-                self.assertEqual(app.conf.worker_consumer, 'foo:Fooz')
+            with pytest.raises(ImproperlyConfigured):
+                assert app.conf.worker_consumer == 'foo:Fooz'
 
 
     def test_pending_configuration__compat_settings_mixing_alt(self):
     def test_pending_configuration__compat_settings_mixing_alt(self):
         with self.Celery(broker='foo://bar', backend='foo') as app:
         with self.Celery(broker='foo://bar', backend='foo') as app:
@@ -328,52 +314,52 @@ class test_App(AppCase):
                 CELERYD_POOL='foo:Xuzzy',
                 CELERYD_POOL='foo:Xuzzy',
                 worker_pool='foo:Xuzzy'
                 worker_pool='foo:Xuzzy'
             )
             )
-            self.assertEqual(app.conf.task_always_eager, 4)
-            self.assertEqual(app.conf.worker_pool, 'foo:Xuzzy')
+            assert app.conf.task_always_eager == 4
+            assert app.conf.worker_pool == 'foo:Xuzzy'
 
 
     def test_pending_configuration__setdefault(self):
     def test_pending_configuration__setdefault(self):
         with self.Celery(broker='foo://bar') as app:
         with self.Celery(broker='foo://bar') as app:
             app.conf.setdefault('worker_agent', 'foo:Bar')
             app.conf.setdefault('worker_agent', 'foo:Bar')
-            self.assertFalse(app.configured)
+            assert not app.configured
 
 
     def test_pending_configuration__iter(self):
     def test_pending_configuration__iter(self):
         with self.Celery(broker='foo://bar') as app:
         with self.Celery(broker='foo://bar') as app:
             app.conf.worker_agent = 'foo:Bar'
             app.conf.worker_agent = 'foo:Bar'
-            self.assertFalse(app.configured)
-            self.assertTrue(list(keys(app.conf)))
-            self.assertFalse(app.configured)
-            self.assertIn('worker_agent', app.conf)
-            self.assertFalse(app.configured)
-            self.assertTrue(dict(app.conf))
-            self.assertTrue(app.configured)
+            assert not app.configured
+            assert list(keys(app.conf))
+            assert not app.configured
+            assert 'worker_agent' in app.conf
+            assert not app.configured
+            assert dict(app.conf)
+            assert app.configured
 
 
     def test_pending_configuration__raises_ImproperlyConfigured(self):
     def test_pending_configuration__raises_ImproperlyConfigured(self):
         with self.Celery(set_as_current=False) as app:
         with self.Celery(set_as_current=False) as app:
             app.conf.worker_agent = 'foo://bar'
             app.conf.worker_agent = 'foo://bar'
             app.conf.task_default_delivery_mode = 44
             app.conf.task_default_delivery_mode = 44
             app.conf.CELERY_ALWAYS_EAGER = 5
             app.conf.CELERY_ALWAYS_EAGER = 5
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 app.finalize()
                 app.finalize()
 
 
         with self.Celery() as app:
         with self.Celery() as app:
-            self.assertFalse(self.app.conf.task_always_eager)
+            assert not self.app.conf.task_always_eager
 
 
     def test_repr(self):
     def test_repr(self):
-        self.assertTrue(repr(self.app))
+        assert repr(self.app)
 
 
     def test_custom_task_registry(self):
     def test_custom_task_registry(self):
         with self.Celery(tasks=self.app.tasks) as app2:
         with self.Celery(tasks=self.app.tasks) as app2:
-            self.assertIs(app2.tasks, self.app.tasks)
+            assert app2.tasks is self.app.tasks
 
 
     def test_include_argument(self):
     def test_include_argument(self):
         with self.Celery(include=('foo', 'bar.foo')) as app:
         with self.Celery(include=('foo', 'bar.foo')) as app:
-            self.assertEqual(app.conf.include, ('foo', 'bar.foo'))
+            assert app.conf.include, ('foo' == 'bar.foo')
 
 
     def test_set_as_current(self):
     def test_set_as_current(self):
         current = _state._tls.current_app
         current = _state._tls.current_app
         try:
         try:
             app = self.Celery(set_as_current=True)
             app = self.Celery(set_as_current=True)
-            self.assertIs(_state._tls.current_app, app)
+            assert _state._tls.current_app is app
         finally:
         finally:
             _state._tls.current_app = current
             _state._tls.current_app = current
 
 
@@ -384,7 +370,7 @@ class test_App(AppCase):
 
 
         _state._task_stack.push(foo)
         _state._task_stack.push(foo)
         try:
         try:
-            self.assertEqual(self.app.current_task.name, foo.name)
+            assert self.app.current_task.name == foo.name
         finally:
         finally:
             _state._task_stack.pop()
             _state._task_stack.pop()
 
 
@@ -433,7 +419,7 @@ class test_App(AppCase):
                 def foo():
                 def foo():
                     pass
                     pass
 
 
-                self.assertEqual(foo.name, 'xuzzy.foo')
+                assert foo.name == 'xuzzy.foo'
         finally:
         finally:
             _imports.MP_MAIN_FILE = None
             _imports.MP_MAIN_FILE = None
 
 
@@ -458,7 +444,7 @@ class test_App(AppCase):
             adX.name: {'@__call__': deco}
             adX.name: {'@__call__': deco}
         }
         }
         adX.bind(self.app)
         adX.bind(self.app)
-        self.assertIs(adX.app, self.app)
+        assert adX.app is self.app
 
 
         i = adX()
         i = adX()
         i(2, 4, x=3)
         i(2, 4, x=3)
@@ -472,9 +458,9 @@ class test_App(AppCase):
         def aawsX(x, y):
         def aawsX(x, y):
             pass
             pass
 
 
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             aawsX.apply_async(())
             aawsX.apply_async(())
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             aawsX.apply_async((2,))
             aawsX.apply_async((2,))
 
 
         with patch('celery.app.amqp.AMQP.create_task_message') as create:
         with patch('celery.app.amqp.AMQP.create_task_message') as create:
@@ -482,7 +468,7 @@ class test_App(AppCase):
                 create.return_value = Mock(), Mock(), Mock(), Mock()
                 create.return_value = Mock(), Mock(), Mock(), Mock()
                 aawsX.apply_async((4, 5))
                 aawsX.apply_async((4, 5))
                 args = create.call_args[0][2]
                 args = create.call_args[0][2]
-                self.assertEqual(args, ('hello', 4, 5))
+                assert args, ('hello', 4 == 5)
                 send.assert_called()
                 send.assert_called()
 
 
     def test_apply_async_adds_children(self):
     def test_apply_async_adds_children(self):
@@ -501,7 +487,7 @@ class test_App(AppCase):
             a3cX1.push_request(called_directly=False)
             a3cX1.push_request(called_directly=False)
             try:
             try:
                 res = a3cX2.apply_async(add_to_parent=True)
                 res = a3cX2.apply_async(add_to_parent=True)
-                self.assertIn(res, a3cX1.request.children)
+                assert res in a3cX1.request.children
             finally:
             finally:
                 a3cX1.pop_request()
                 a3cX1.pop_request()
         finally:
         finally:
@@ -512,9 +498,10 @@ class test_App(AppCase):
                        THE_MII_MAR='jars')
                        THE_MII_MAR='jars')
         self.app.conf.update(changes)
         self.app.conf.update(changes)
         saved = pickle.dumps(self.app)
         saved = pickle.dumps(self.app)
-        self.assertLess(len(saved), 2048)
+        assert len(saved) < 2048
         restored = pickle.loads(saved)
         restored = pickle.loads(saved)
-        self.assertDictContainsSubset(changes, restored.conf)
+        for key, value in items(changes):
+            assert restored.conf[key] == value
 
 
     def test_worker_main(self):
     def test_worker_main(self):
         from celery.bin import worker as worker_bin
         from celery.bin import worker as worker_bin
@@ -527,33 +514,33 @@ class test_App(AppCase):
         prev, worker_bin.worker = worker_bin.worker, worker
         prev, worker_bin.worker = worker_bin.worker, worker
         try:
         try:
             ret = self.app.worker_main(argv=['--version'])
             ret = self.app.worker_main(argv=['--version'])
-            self.assertListEqual(ret, ['--version'])
+            assert ret == ['--version']
         finally:
         finally:
             worker_bin.worker = prev
             worker_bin.worker = prev
 
 
     def test_config_from_envvar(self):
     def test_config_from_envvar(self):
-        os.environ['CELERYTEST_CONFIG_OBJECT'] = 'celery.tests.app.test_app'
+        os.environ['CELERYTEST_CONFIG_OBJECT'] = 't.unit.app.test_app'
         self.app.config_from_envvar('CELERYTEST_CONFIG_OBJECT')
         self.app.config_from_envvar('CELERYTEST_CONFIG_OBJECT')
-        self.assertEqual(self.app.conf.THIS_IS_A_KEY, 'this is a value')
+        assert self.app.conf.THIS_IS_A_KEY == 'this is a value'
 
 
     def assert_config2(self):
     def assert_config2(self):
-        self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
-        self.assertTrue(self.app.conf.MOMENT_TO_STOP)
-        self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
-        self.assertFalse(self.app.conf.WANT_ME_TO)
-        self.assertTrue(self.app.conf.UNDERSTAND_ME)
+        assert self.app.conf.LEAVE_FOR_WORK
+        assert self.app.conf.MOMENT_TO_STOP
+        assert self.app.conf.CALL_ME_BACK == 123456789
+        assert not self.app.conf.WANT_ME_TO
+        assert self.app.conf.UNDERSTAND_ME
 
 
     def test_config_from_object__lazy(self):
     def test_config_from_object__lazy(self):
         conf = ObjectConfig2()
         conf = ObjectConfig2()
         self.app.config_from_object(conf)
         self.app.config_from_object(conf)
-        self.assertIs(self.app.loader._conf, unconfigured)
-        self.assertIs(self.app._config_source, conf)
+        assert self.app.loader._conf is unconfigured
+        assert self.app._config_source is conf
 
 
         self.assert_config2()
         self.assert_config2()
 
 
     def test_config_from_object__force(self):
     def test_config_from_object__force(self):
         self.app.config_from_object(ObjectConfig2(), force=True)
         self.app.config_from_object(ObjectConfig2(), force=True)
-        self.assertTrue(self.app.loader._conf)
+        assert self.app.loader._conf
 
 
         self.assert_config2()
         self.assert_config2()
 
 
@@ -565,10 +552,10 @@ class test_App(AppCase):
             CELERY_TASK_PUBLISH_RETRY = False
             CELERY_TASK_PUBLISH_RETRY = False
 
 
         self.app.config_from_object(Config)
         self.app.config_from_object(Config)
-        self.assertEqual(self.app.conf.task_always_eager, 44)
-        self.assertEqual(self.app.conf.CELERY_ALWAYS_EAGER, 44)
-        self.assertFalse(self.app.conf.task_publish_retry)
-        self.assertEqual(self.app.conf.task_default_routing_key, 'testcelery')
+        assert self.app.conf.task_always_eager == 44
+        assert self.app.conf.CELERY_ALWAYS_EAGER == 44
+        assert not self.app.conf.task_publish_retry
+        assert self.app.conf.task_default_routing_key == 'testcelery'
 
 
     def test_config_from_object__supports_old_names(self):
     def test_config_from_object__supports_old_names(self):
 
 
@@ -577,11 +564,11 @@ class test_App(AppCase):
             task_default_delivery_mode = 301
             task_default_delivery_mode = 301
 
 
         self.app.config_from_object(Config())
         self.app.config_from_object(Config())
-        self.assertEqual(self.app.conf.CELERY_ALWAYS_EAGER, 45)
-        self.assertEqual(self.app.conf.task_always_eager, 45)
-        self.assertEqual(self.app.conf.CELERY_DEFAULT_DELIVERY_MODE, 301)
-        self.assertEqual(self.app.conf.task_default_delivery_mode, 301)
-        self.assertEqual(self.app.conf.task_default_routing_key, 'testcelery')
+        assert self.app.conf.CELERY_ALWAYS_EAGER == 45
+        assert self.app.conf.task_always_eager == 45
+        assert self.app.conf.CELERY_DEFAULT_DELIVERY_MODE == 301
+        assert self.app.conf.task_default_delivery_mode == 301
+        assert self.app.conf.task_default_routing_key == 'testcelery'
 
 
     def test_config_from_object__namespace_uppercase(self):
     def test_config_from_object__namespace_uppercase(self):
 
 
@@ -590,7 +577,7 @@ class test_App(AppCase):
             CELERY_TASK_DEFAULT_DELIVERY_MODE = 301
             CELERY_TASK_DEFAULT_DELIVERY_MODE = 301
 
 
         self.app.config_from_object(Config(), namespace='CELERY')
         self.app.config_from_object(Config(), namespace='CELERY')
-        self.assertEqual(self.app.conf.task_always_eager, 44)
+        assert self.app.conf.task_always_eager == 44
 
 
     def test_config_from_object__namespace_lowercase(self):
     def test_config_from_object__namespace_lowercase(self):
 
 
@@ -599,7 +586,7 @@ class test_App(AppCase):
             celery_task_default_delivery_mode = 301
             celery_task_default_delivery_mode = 301
 
 
         self.app.config_from_object(Config(), namespace='celery')
         self.app.config_from_object(Config(), namespace='celery')
-        self.assertEqual(self.app.conf.task_always_eager, 44)
+        assert self.app.conf.task_always_eager == 44
 
 
     def test_config_from_object__mixing_new_and_old(self):
     def test_config_from_object__mixing_new_and_old(self):
 
 
@@ -610,11 +597,10 @@ class test_App(AppCase):
             beat_schedule = '/foo/schedule'
             beat_schedule = '/foo/schedule'
             CELERY_DEFAULT_DELIVERY_MODE = 301
             CELERY_DEFAULT_DELIVERY_MODE = 301
 
 
-        with self.assertRaises(ImproperlyConfigured) as exc:
+        with pytest.raises(ImproperlyConfigured) as exc:
             self.app.config_from_object(Config(), force=True)
             self.app.config_from_object(Config(), force=True)
-            self.assertTrue(
-                exc.args[0].startswith('CELERY_DEFAULT_DELIVERY_MODE'))
-            self.assertIn('task_default_delivery_mode', exc.args[0])
+            assert exc.args[0].startswith('CELERY_DEFAULT_DELIVERY_MODE')
+            assert 'task_default_delivery_mode' in exc.args[0]
 
 
     def test_config_from_object__mixing_old_and_new(self):
     def test_config_from_object__mixing_old_and_new(self):
 
 
@@ -625,11 +611,10 @@ class test_App(AppCase):
             CELERYBEAT_SCHEDULE = '/foo/schedule'
             CELERYBEAT_SCHEDULE = '/foo/schedule'
             task_default_delivery_mode = 301
             task_default_delivery_mode = 301
 
 
-        with self.assertRaises(ImproperlyConfigured) as exc:
+        with pytest.raises(ImproperlyConfigured) as exc:
             self.app.config_from_object(Config(), force=True)
             self.app.config_from_object(Config(), force=True)
-            self.assertTrue(
-                exc.args[0].startswith('task_default_delivery_mode'))
-            self.assertIn('CELERY_DEFAULT_DELIVERY_MODE', exc.args[0])
+            assert exc.args[0].startswith('task_default_delivery_mode')
+            assert 'CELERY_DEFAULT_DELIVERY_MODE' in exc.args[0]
 
 
     def test_config_from_cmdline(self):
     def test_config_from_cmdline(self):
         cmdline = ['task_always_eager=no',
         cmdline = ['task_always_eager=no',
@@ -639,139 +624,130 @@ class test_App(AppCase):
                    '.foobarint=(int)300',
                    '.foobarint=(int)300',
                    'sqlalchemy_engine_options=(dict){"foo": "bar"}']
                    'sqlalchemy_engine_options=(dict){"foo": "bar"}']
         self.app.config_from_cmdline(cmdline, namespace='worker')
         self.app.config_from_cmdline(cmdline, namespace='worker')
-        self.assertFalse(self.app.conf.task_always_eager)
-        self.assertEqual(self.app.conf.result_backend, '/dev/null')
-        self.assertEqual(self.app.conf.worker_prefetch_multiplier, 368)
-        self.assertEqual(self.app.conf.worker_foobarstring, '300')
-        self.assertEqual(self.app.conf.worker_foobarint, 300)
-        self.assertDictEqual(self.app.conf.sqlalchemy_engine_options,
-                             {'foo': 'bar'})
+        assert not self.app.conf.task_always_eager
+        assert self.app.conf.result_backend == '/dev/null'
+        assert self.app.conf.worker_prefetch_multiplier == 368
+        assert self.app.conf.worker_foobarstring == '300'
+        assert self.app.conf.worker_foobarint == 300
+        assert self.app.conf.sqlalchemy_engine_options == {'foo': 'bar'}
 
 
     def test_setting__broker_transport_options(self):
     def test_setting__broker_transport_options(self):
 
 
         _args = {'foo': 'bar', 'spam': 'baz'}
         _args = {'foo': 'bar', 'spam': 'baz'}
 
 
         self.app.config_from_object(Bunch())
         self.app.config_from_object(Bunch())
-        self.assertEqual(self.app.conf.broker_transport_options, {})
+        assert self.app.conf.broker_transport_options == {}
 
 
         self.app.config_from_object(Bunch(broker_transport_options=_args))
         self.app.config_from_object(Bunch(broker_transport_options=_args))
-        self.assertEqual(self.app.conf.broker_transport_options, _args)
+        assert self.app.conf.broker_transport_options == _args
 
 
     def test_Windows_log_color_disabled(self):
     def test_Windows_log_color_disabled(self):
         self.app.IS_WINDOWS = True
         self.app.IS_WINDOWS = True
-        self.assertFalse(self.app.log.supports_color(True))
+        assert not self.app.log.supports_color(True)
 
 
     def test_WorkController(self):
     def test_WorkController(self):
         x = self.app.WorkController
         x = self.app.WorkController
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
 
 
     def test_Worker(self):
     def test_Worker(self):
         x = self.app.Worker
         x = self.app.Worker
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_AsyncResult(self):
     def test_AsyncResult(self):
         x = self.app.AsyncResult('1')
         x = self.app.AsyncResult('1')
-        self.assertIs(x.app, self.app)
+        assert x.app is self.app
         r = loads(dumps(x))
         r = loads(dumps(x))
         # not set as current, so ends up as default app after reduce
         # not set as current, so ends up as default app after reduce
-        self.assertIs(r.app, current_app._get_current_object())
+        assert r.app is current_app._get_current_object()
 
 
     def test_get_active_apps(self):
     def test_get_active_apps(self):
-        self.assertTrue(list(_state._get_active_apps()))
+        assert list(_state._get_active_apps())
 
 
         app1 = self.Celery()
         app1 = self.Celery()
         appid = id(app1)
         appid = id(app1)
-        self.assertIn(app1, _state._get_active_apps())
+        assert app1 in _state._get_active_apps()
         app1.close()
         app1.close()
         del(app1)
         del(app1)
 
 
         gc.collect()
         gc.collect()
 
 
         # weakref removed from list when app goes out of scope.
         # weakref removed from list when app goes out of scope.
-        with self.assertRaises(StopIteration):
+        with pytest.raises(StopIteration):
             next(app for app in _state._get_active_apps() if id(app) == appid)
             next(app for app in _state._get_active_apps() if id(app) == appid)
 
 
     def test_config_from_envvar_more(self, key='CELERY_HARNESS_CFG1'):
     def test_config_from_envvar_more(self, key='CELERY_HARNESS_CFG1'):
-        self.assertFalse(
-            self.app.config_from_envvar(
-                'HDSAJIHWIQHEWQU', force=True, silent=True),
-        )
-        with self.assertRaises(ImproperlyConfigured):
+        assert not self.app.config_from_envvar(
+            'HDSAJIHWIQHEWQU', force=True, silent=True)
+        with pytest.raises(ImproperlyConfigured):
             self.app.config_from_envvar(
             self.app.config_from_envvar(
                 'HDSAJIHWIQHEWQU', force=True, silent=False,
                 'HDSAJIHWIQHEWQU', force=True, silent=False,
             )
             )
         os.environ[key] = __name__ + '.object_config'
         os.environ[key] = __name__ + '.object_config'
-        self.assertTrue(self.app.config_from_envvar(key, force=True))
-        self.assertEqual(self.app.conf['FOO'], 1)
-        self.assertEqual(self.app.conf['BAR'], 2)
+        assert self.app.config_from_envvar(key, force=True)
+        assert self.app.conf['FOO'] == 1
+        assert self.app.conf['BAR'] == 2
 
 
         os.environ[key] = 'unknown_asdwqe.asdwqewqe'
         os.environ[key] = 'unknown_asdwqe.asdwqewqe'
-        with self.assertRaises(ImportError):
+        with pytest.raises(ImportError):
             self.app.config_from_envvar(key, silent=False)
             self.app.config_from_envvar(key, silent=False)
-        self.assertFalse(
-            self.app.config_from_envvar(key, force=True, silent=True),
-        )
+        assert not self.app.config_from_envvar(key, force=True, silent=True)
 
 
         os.environ[key] = __name__ + '.dict_config'
         os.environ[key] = __name__ + '.dict_config'
-        self.assertTrue(self.app.config_from_envvar(key, force=True))
-        self.assertEqual(self.app.conf['FOO'], 10)
-        self.assertEqual(self.app.conf['BAR'], 20)
+        assert self.app.config_from_envvar(key, force=True)
+        assert self.app.conf['FOO'] == 10
+        assert self.app.conf['BAR'] == 20
 
 
     @patch('celery.bin.celery.CeleryCommand.execute_from_commandline')
     @patch('celery.bin.celery.CeleryCommand.execute_from_commandline')
     def test_start(self, execute):
     def test_start(self, execute):
         self.app.start()
         self.app.start()
         execute.assert_called()
         execute.assert_called()
 
 
-    def test_amqp_get_broker_info(self):
-        self.assertDictContainsSubset(
-            {'hostname': 'localhost',
-             'userid': 'guest',
-             'password': 'guest',
-             'virtual_host': '/'},
-            self.app.connection('pyamqp://').info(),
-        )
-        self.app.conf.broker_port = 1978
-        self.app.conf.broker_vhost = 'foo'
-        self.assertDictContainsSubset(
-            {'port': 1978, 'virtual_host': 'foo'},
-            self.app.connection('pyamqp://:1978/foo').info(),
-        )
-        conn = self.app.connection('pyamqp:////value')
-        self.assertDictContainsSubset({'virtual_host': '/value'},
-                                      conn.info())
+    @pytest.mark.parametrize('url,expected_fields', [
+        ('pyamqp://', {
+            'hostname': 'localhost',
+            'userid': 'guest',
+            'password': 'guest',
+            'virtual_host': '/',
+        }),
+        ('pyamqp://:1978/foo', {
+            'port': 1978,
+            'virtual_host': 'foo',
+        }),
+        ('pyamqp:////value', {
+            'virtual_host': '/value',
+        })
+    ])
+    def test_amqp_get_broker_info(self, url, expected_fields):
+        info = self.app.connection(url).info()
+        for key, expected_value in items(expected_fields):
+            assert info[key] == expected_value
 
 
     def test_amqp_failover_strategy_selection(self):
     def test_amqp_failover_strategy_selection(self):
         # Test passing in a string and make sure the string
         # Test passing in a string and make sure the string
         # gets there untouched
         # gets there untouched
         self.app.conf.broker_failover_strategy = 'foo-bar'
         self.app.conf.broker_failover_strategy = 'foo-bar'
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            'foo-bar',
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == 'foo-bar'
 
 
         # Try passing in None
         # Try passing in None
         self.app.conf.broker_failover_strategy = None
         self.app.conf.broker_failover_strategy = None
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            itertools.cycle,
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == itertools.cycle
 
 
         # Test passing in a method
         # Test passing in a method
         def my_failover_strategy(it):
         def my_failover_strategy(it):
             yield True
             yield True
 
 
         self.app.conf.broker_failover_strategy = my_failover_strategy
         self.app.conf.broker_failover_strategy = my_failover_strategy
-        self.assertEqual(
-            self.app.connection('amqp:////value').failover_strategy,
-            my_failover_strategy,
-        )
+        assert self.app.connection('amqp:////value') \
+                       .failover_strategy == my_failover_strategy
 
 
     def test_after_fork(self):
     def test_after_fork(self):
         self.app._pool = Mock()
         self.app._pool = Mock()
         self.app.on_after_fork = Mock(name='on_after_fork')
         self.app.on_after_fork = Mock(name='on_after_fork')
         self.app._after_fork()
         self.app._after_fork()
-        self.assertIsNone(self.app._pool)
+        assert self.app._pool is None
         self.app.on_after_fork.send.assert_called_with(sender=self.app)
         self.app.on_after_fork.send.assert_called_with(sender=self.app)
         self.app._after_fork()
         self.app._after_fork()
 
 
@@ -794,21 +770,21 @@ class test_App(AppCase):
         try:
         try:
             self.app._after_fork_registered = False
             self.app._after_fork_registered = False
             self.app._ensure_after_fork()
             self.app._ensure_after_fork()
-            self.assertTrue(self.app._after_fork_registered)
+            assert self.app._after_fork_registered
         finally:
         finally:
             _appbase.register_after_fork = prev
             _appbase.register_after_fork = prev
 
 
     def test_canvas(self):
     def test_canvas(self):
-        self.assertTrue(self.app.canvas.Signature)
+        assert self.app.canvas.Signature
 
 
     def test_signature(self):
     def test_signature(self):
         sig = self.app.signature('foo', (1, 2))
         sig = self.app.signature('foo', (1, 2))
-        self.assertIs(sig.app, self.app)
+        assert sig.app is self.app
 
 
     def test_timezone__none_set(self):
     def test_timezone__none_set(self):
         self.app.conf.timezone = None
         self.app.conf.timezone = None
         tz = self.app.timezone
         tz = self.app.timezone
-        self.assertEqual(tz, timezone.get_timezone('UTC'))
+        assert tz == timezone.get_timezone('UTC')
 
 
     def test_compat_on_configure(self):
     def test_compat_on_configure(self):
         _on_configure = Mock(name='on_configure')
         _on_configure = Mock(name='on_configure')
@@ -837,22 +813,22 @@ class test_App(AppCase):
             10, self.app.signature('add', (2, 2)),
             10, self.app.signature('add', (2, 2)),
             name='add1', expires=3,
             name='add1', expires=3,
         )
         )
-        self.assertTrue(self.app._pending_periodic_tasks)
+        assert self.app._pending_periodic_tasks
         assert not self.app.configured
         assert not self.app.configured
 
 
         sig2 = add.s(4, 4)
         sig2 = add.s(4, 4)
-        self.assertTrue(self.app.configured)
+        assert self.app.configured
         self.app.add_periodic_task(20, sig2, name='add2', expires=4)
         self.app.add_periodic_task(20, sig2, name='add2', expires=4)
-        self.assertIn('add1', self.app.conf.beat_schedule)
-        self.assertIn('add2', self.app.conf.beat_schedule)
+        assert 'add1' in self.app.conf.beat_schedule
+        assert 'add2' in self.app.conf.beat_schedule
 
 
     def test_pool_no_multiprocessing(self):
     def test_pool_no_multiprocessing(self):
         with mock.mask_modules('multiprocessing.util'):
         with mock.mask_modules('multiprocessing.util'):
             pool = self.app.pool
             pool = self.app.pool
-            self.assertIs(pool, self.app._pool)
+            assert pool is self.app._pool
 
 
     def test_bugreport(self):
     def test_bugreport(self):
-        self.assertTrue(self.app.bugreport())
+        assert self.app.bugreport()
 
 
     def test_send_task__connection_provided(self):
     def test_send_task__connection_provided(self):
         connection = Mock(name='connection')
         connection = Mock(name='connection')
@@ -896,8 +872,8 @@ class test_App(AppCase):
             exchange='moo_exchange', routing_key='moo_exchange',
             exchange='moo_exchange', routing_key='moo_exchange',
             event_dispatcher=dispatcher,
             event_dispatcher=dispatcher,
         )
         )
-        self.assertTrue(dispatcher.sent)
-        self.assertEqual(dispatcher.sent[0][0], 'task-sent')
+        assert dispatcher.sent
+        assert dispatcher.sent[0][0] == 'task-sent'
         self.app.amqp.send_task_message(
         self.app.amqp.send_task_message(
             prod, 'footask', message, event_dispatcher=dispatcher,
             prod, 'footask', message, event_dispatcher=dispatcher,
             exchange='bar_exchange', routing_key='bar_exchange',
             exchange='bar_exchange', routing_key='bar_exchange',
@@ -909,56 +885,56 @@ class test_App(AppCase):
         self.app.amqp.queues.select.assert_called_with({'foo', 'bar'})
         self.app.amqp.queues.select.assert_called_with({'foo', 'bar'})
 
 
 
 
-class test_defaults(AppCase):
+class test_defaults:
 
 
     def test_strtobool(self):
     def test_strtobool(self):
         for s in ('false', 'no', '0'):
         for s in ('false', 'no', '0'):
-            self.assertFalse(defaults.strtobool(s))
+            assert not defaults.strtobool(s)
         for s in ('true', 'yes', '1'):
         for s in ('true', 'yes', '1'):
-            self.assertTrue(defaults.strtobool(s))
-        with self.assertRaises(TypeError):
+            assert defaults.strtobool(s)
+        with pytest.raises(TypeError):
             defaults.strtobool('unsure')
             defaults.strtobool('unsure')
 
 
 
 
-class test_debugging_utils(AppCase):
+class test_debugging_utils:
 
 
     def test_enable_disable_trace(self):
     def test_enable_disable_trace(self):
         try:
         try:
             _app.enable_trace()
             _app.enable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default_trace)
+            assert _app.app_or_default == _app._app_or_default_trace
             _app.disable_trace()
             _app.disable_trace()
-            self.assertEqual(_app.app_or_default, _app._app_or_default)
+            assert _app.app_or_default == _app._app_or_default
         finally:
         finally:
             _app.disable_trace()
             _app.disable_trace()
 
 
 
 
-class test_pyimplementation(AppCase):
+class test_pyimplementation:
 
 
     def test_platform_python_implementation(self):
     def test_platform_python_implementation(self):
         with mock.platform_pyimp(lambda: 'Xython'):
         with mock.platform_pyimp(lambda: 'Xython'):
-            self.assertEqual(pyimplementation(), 'Xython')
+            assert pyimplementation() == 'Xython'
 
 
     def test_platform_jython(self):
     def test_platform_jython(self):
         with mock.platform_pyimp():
         with mock.platform_pyimp():
             with mock.sys_platform('java 1.6.51'):
             with mock.sys_platform('java 1.6.51'):
-                self.assertIn('Jython', pyimplementation())
+                assert 'Jython' in pyimplementation()
 
 
     def test_platform_pypy(self):
     def test_platform_pypy(self):
         with mock.platform_pyimp():
         with mock.platform_pyimp():
             with mock.sys_platform('darwin'):
             with mock.sys_platform('darwin'):
                 with mock.pypy_version((1, 4, 3)):
                 with mock.pypy_version((1, 4, 3)):
-                    self.assertIn('PyPy', pyimplementation())
+                    assert 'PyPy' in pyimplementation()
                 with mock.pypy_version((1, 4, 3, 'a4')):
                 with mock.pypy_version((1, 4, 3, 'a4')):
-                    self.assertIn('PyPy', pyimplementation())
+                    assert 'PyPy' in pyimplementation()
 
 
     def test_platform_fallback(self):
     def test_platform_fallback(self):
         with mock.platform_pyimp():
         with mock.platform_pyimp():
             with mock.sys_platform('darwin'):
             with mock.sys_platform('darwin'):
                 with mock.pypy_version():
                 with mock.pypy_version():
-                    self.assertEqual('CPython', pyimplementation())
+                    assert 'CPython' == pyimplementation()
 
 
 
 
-class test_shared_task(AppCase):
+class test_shared_task:
 
 
     def test_registers_to_all_apps(self):
     def test_registers_to_all_apps(self):
         with self.Celery('xproj', set_as_current=True) as xproj:
         with self.Celery('xproj', set_as_current=True) as xproj:
@@ -972,16 +948,16 @@ class test_shared_task(AppCase):
             def bar():
             def bar():
                 return 84
                 return 84
 
 
-            self.assertIs(foo.app, xproj)
-            self.assertIs(bar.app, xproj)
-            self.assertTrue(foo._get_current_object())
+            assert foo.app is xproj
+            assert bar.app is xproj
+            assert foo._get_current_object()
 
 
             with self.Celery('yproj', set_as_current=True) as yproj:
             with self.Celery('yproj', set_as_current=True) as yproj:
-                self.assertIs(foo.app, yproj)
-                self.assertIs(bar.app, yproj)
+                assert foo.app is yproj
+                assert bar.app is yproj
 
 
                 @shared_task()
                 @shared_task()
                 def baz():
                 def baz():
                     return 168
                     return 168
 
 
-                self.assertIs(baz.app, yproj)
+                assert baz.app is yproj

+ 79 - 80
celery/tests/app/test_beat.py → t/unit/app/test_beat.py

@@ -1,18 +1,19 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import errno
 import errno
+import pytest
 
 
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from pickle import dumps, loads
 from pickle import dumps, loads
 
 
+from case import Mock, call, patch, skip
+
 from celery import beat
 from celery import beat
 from celery import uuid
 from celery import uuid
 from celery.five import keys, string_t
 from celery.five import keys, string_t
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import AppCase, Mock, call, patch, skip
-
 
 
 class MockShelve(dict):
 class MockShelve(dict):
     closed = False
     closed = False
@@ -39,7 +40,7 @@ class MockService(object):
         self.stopped = True
         self.stopped = True
 
 
 
 
-class test_ScheduleEntry(AppCase):
+class test_ScheduleEntry:
     Entry = beat.ScheduleEntry
     Entry = beat.ScheduleEntry
 
 
     def create_entry(self, **kwargs):
     def create_entry(self, **kwargs):
@@ -54,38 +55,38 @@ class test_ScheduleEntry(AppCase):
 
 
     def test_next(self):
     def test_next(self):
         entry = self.create_entry(schedule=10)
         entry = self.create_entry(schedule=10)
-        self.assertTrue(entry.last_run_at)
-        self.assertIsInstance(entry.last_run_at, datetime)
-        self.assertEqual(entry.total_run_count, 0)
+        assert entry.last_run_at
+        assert isinstance(entry.last_run_at, datetime)
+        assert entry.total_run_count == 0
 
 
         next_run_at = entry.last_run_at + timedelta(seconds=10)
         next_run_at = entry.last_run_at + timedelta(seconds=10)
         next_entry = entry.next(next_run_at)
         next_entry = entry.next(next_run_at)
-        self.assertGreaterEqual(next_entry.last_run_at, next_run_at)
-        self.assertEqual(next_entry.total_run_count, 1)
+        assert next_entry.last_run_at >= next_run_at
+        assert next_entry.total_run_count == 1
 
 
     def test_is_due(self):
     def test_is_due(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
         entry = self.create_entry(schedule=timedelta(seconds=10))
-        self.assertIs(entry.app, self.app)
-        self.assertIs(entry.schedule.app, self.app)
+        assert entry.app is self.app
+        assert entry.schedule.app is self.app
         due1, next_time_to_run1 = entry.is_due()
         due1, next_time_to_run1 = entry.is_due()
-        self.assertFalse(due1)
-        self.assertGreater(next_time_to_run1, 9)
+        assert not due1
+        assert next_time_to_run1 > 9
 
 
         next_run_at = entry.last_run_at - timedelta(seconds=10)
         next_run_at = entry.last_run_at - timedelta(seconds=10)
         next_entry = entry.next(next_run_at)
         next_entry = entry.next(next_run_at)
         due2, next_time_to_run2 = next_entry.is_due()
         due2, next_time_to_run2 = next_entry.is_due()
-        self.assertTrue(due2)
-        self.assertGreater(next_time_to_run2, 9)
+        assert due2
+        assert next_time_to_run2 > 9
 
 
     def test_repr(self):
     def test_repr(self):
         entry = self.create_entry()
         entry = self.create_entry()
-        self.assertIn('<ScheduleEntry:', repr(entry))
+        assert '<ScheduleEntry:' in repr(entry)
 
 
     def test_reduce(self):
     def test_reduce(self):
         entry = self.create_entry(schedule=timedelta(seconds=10))
         entry = self.create_entry(schedule=timedelta(seconds=10))
         fun, args = entry.__reduce__()
         fun, args = entry.__reduce__()
         res = fun(*args)
         res = fun(*args)
-        self.assertEqual(res.schedule, entry.schedule)
+        assert res.schedule == entry.schedule
 
 
     def test_lt(self):
     def test_lt(self):
         e1 = self.create_entry(schedule=timedelta(seconds=10))
         e1 = self.create_entry(schedule=timedelta(seconds=10))
@@ -99,20 +100,20 @@ class test_ScheduleEntry(AppCase):
 
 
     def test_update(self):
     def test_update(self):
         entry = self.create_entry()
         entry = self.create_entry()
-        self.assertEqual(entry.schedule, timedelta(seconds=10))
-        self.assertTupleEqual(entry.args, (2, 2))
-        self.assertDictEqual(entry.kwargs, {})
-        self.assertDictEqual(entry.options, {'routing_key': 'cpu'})
+        assert entry.schedule == timedelta(seconds=10)
+        assert entry.args == (2, 2)
+        assert entry.kwargs == {}
+        assert entry.options == {'routing_key': 'cpu'}
 
 
         entry2 = self.create_entry(schedule=timedelta(minutes=20),
         entry2 = self.create_entry(schedule=timedelta(minutes=20),
                                    args=(16, 16),
                                    args=(16, 16),
                                    kwargs={'callback': 'foo.bar.baz'},
                                    kwargs={'callback': 'foo.bar.baz'},
                                    options={'routing_key': 'urgent'})
                                    options={'routing_key': 'urgent'})
         entry.update(entry2)
         entry.update(entry2)
-        self.assertEqual(entry.schedule, schedule(timedelta(minutes=20)))
-        self.assertTupleEqual(entry.args, (16, 16))
-        self.assertDictEqual(entry.kwargs, {'callback': 'foo.bar.baz'})
-        self.assertDictEqual(entry.options, {'routing_key': 'urgent'})
+        assert entry.schedule == schedule(timedelta(minutes=20))
+        assert entry.args == (16, 16)
+        assert entry.kwargs == {'callback': 'foo.bar.baz'}
+        assert entry.options == {'routing_key': 'urgent'}
 
 
 
 
 class mScheduler(beat.Scheduler):
 class mScheduler(beat.Scheduler):
@@ -157,12 +158,12 @@ always_due = mocked_schedule(True, 1)
 always_pending = mocked_schedule(False, 1)
 always_pending = mocked_schedule(False, 1)
 
 
 
 
-class test_Scheduler(AppCase):
+class test_Scheduler:
 
 
     def test_custom_schedule_dict(self):
     def test_custom_schedule_dict(self):
         custom = {'foo': 'bar'}
         custom = {'foo': 'bar'}
         scheduler = mScheduler(app=self.app, schedule=custom, lazy=True)
         scheduler = mScheduler(app=self.app, schedule=custom, lazy=True)
-        self.assertIs(scheduler.data, custom)
+        assert scheduler.data is custom
 
 
     def test_apply_async_uses_registered_task_instances(self):
     def test_apply_async_uses_registered_task_instances(self):
 
 
@@ -204,11 +205,11 @@ class test_Scheduler(AppCase):
         not_sync.apply_async = Mock()
         not_sync.apply_async = Mock()
 
 
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
-        self.assertEqual(s.sync_every_tasks, 2)
+        assert s.sync_every_tasks == 2
         s._do_sync = Mock()
         s._do_sync = Mock()
 
 
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
-        self.assertEqual(s._tasks_since_sync, 1)
+        assert s._tasks_since_sync == 1
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s._do_sync.assert_called_with()
         s._do_sync.assert_called_with()
 
 
@@ -223,10 +224,10 @@ class test_Scheduler(AppCase):
         not_sync.apply_async = Mock()
         not_sync.apply_async = Mock()
 
 
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
-        self.assertEqual(s.sync_every_tasks, 1)
+        assert s.sync_every_tasks == 1
 
 
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
         s.apply_async(s.Entry(task=not_sync.name, app=self.app))
-        self.assertEqual(s._tasks_since_sync, 0)
+        assert s._tasks_since_sync == 0
 
 
         self.app.conf.beat_sync_every = 0
         self.app.conf.beat_sync_every = 0
 
 
@@ -238,25 +239,23 @@ class test_Scheduler(AppCase):
 
 
     def test_info(self):
     def test_info(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
-        self.assertIsInstance(scheduler.info, string_t)
+        assert isinstance(scheduler.info, string_t)
 
 
     def test_maybe_entry(self):
     def test_maybe_entry(self):
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
         entry = s.Entry(name='add every', task='tasks.add', app=self.app)
         entry = s.Entry(name='add every', task='tasks.add', app=self.app)
-        self.assertIs(s._maybe_entry(entry.name, entry), entry)
-        self.assertTrue(s._maybe_entry('add every', {
-            'task': 'tasks.add',
-        }))
+        assert s._maybe_entry(entry.name, entry) is entry
+        assert s._maybe_entry('add every', {'task': 'tasks.add'})
 
 
     def test_set_schedule(self):
     def test_set_schedule(self):
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
         s.schedule = {'foo': 'bar'}
         s.schedule = {'foo': 'bar'}
-        self.assertEqual(s.data, {'foo': 'bar'})
+        assert s.data == {'foo': 'bar'}
 
 
     @patch('kombu.connection.Connection.ensure_connection')
     @patch('kombu.connection.Connection.ensure_connection')
     def test_ensure_connection_error_handler(self, ensure):
     def test_ensure_connection_error_handler(self, ensure):
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
-        self.assertTrue(s._ensure_connected())
+        assert s._ensure_connected()
         ensure.assert_called()
         ensure.assert_called()
         callback = ensure.call_args[0][0]
         callback = ensure.call_args[0][0]
 
 
@@ -267,19 +266,19 @@ class test_Scheduler(AppCase):
         self.app.conf.beat_schedule = {}
         self.app.conf.beat_schedule = {}
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
         s.install_default_entries({})
         s.install_default_entries({})
-        self.assertNotIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' not in s.data
         self.app.backend.supports_autoexpire = False
         self.app.backend.supports_autoexpire = False
 
 
         self.app.conf.result_expires = 30
         self.app.conf.result_expires = 30
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
         s.install_default_entries({})
         s.install_default_entries({})
-        self.assertIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' in s.data
 
 
         self.app.backend.supports_autoexpire = True
         self.app.backend.supports_autoexpire = True
         self.app.conf.result_expires = 31
         self.app.conf.result_expires = 31
         s = mScheduler(app=self.app)
         s = mScheduler(app=self.app)
         s.install_default_entries({})
         s.install_default_entries({})
-        self.assertNotIn('celery.backend_cleanup', s.data)
+        assert 'celery.backend_cleanup' not in s.data
 
 
     def test_due_tick(self):
     def test_due_tick(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
@@ -287,28 +286,28 @@ class test_Scheduler(AppCase):
                       schedule=always_due,
                       schedule=always_due,
                       args=(1, 2),
                       args=(1, 2),
                       kwargs={'foo': 'bar'})
                       kwargs={'foo': 'bar'})
-        self.assertEqual(scheduler.tick(), 0)
+        assert scheduler.tick() == 0
 
 
     @patch('celery.beat.error')
     @patch('celery.beat.error')
     def test_due_tick_SchedulingError(self, error):
     def test_due_tick_SchedulingError(self, error):
         scheduler = mSchedulerSchedulingError(app=self.app)
         scheduler = mSchedulerSchedulingError(app=self.app)
         scheduler.add(name='test_due_tick_SchedulingError',
         scheduler.add(name='test_due_tick_SchedulingError',
                       schedule=always_due)
                       schedule=always_due)
-        self.assertEqual(scheduler.tick(), 0)
+        assert scheduler.tick() == 0
         error.assert_called()
         error.assert_called()
 
 
     def test_pending_tick(self):
     def test_pending_tick(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_pending_tick',
         scheduler.add(name='test_pending_tick',
                       schedule=always_pending)
                       schedule=always_pending)
-        self.assertEqual(scheduler.tick(), 1 - 0.010)
+        assert scheduler.tick() == 1 - 0.010
 
 
     def test_honors_max_interval(self):
     def test_honors_max_interval(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
         maxi = scheduler.max_interval
         maxi = scheduler.max_interval
         scheduler.add(name='test_honors_max_interval',
         scheduler.add(name='test_honors_max_interval',
                       schedule=mocked_schedule(False, maxi * 4))
                       schedule=mocked_schedule(False, maxi * 4))
-        self.assertEqual(scheduler.tick(), maxi)
+        assert scheduler.tick() == maxi
 
 
     def test_ticks(self):
     def test_ticks(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
@@ -317,13 +316,13 @@ class test_Scheduler(AppCase):
                  {'schedule': mocked_schedule(False, j)})
                  {'schedule': mocked_schedule(False, j)})
                  for i, j in enumerate(nums))
                  for i, j in enumerate(nums))
         scheduler.update_from_dict(s)
         scheduler.update_from_dict(s)
-        self.assertEqual(scheduler.tick(), min(nums) - 0.010)
+        assert scheduler.tick() == min(nums) - 0.010
 
 
     def test_schedule_no_remain(self):
     def test_schedule_no_remain(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
         scheduler.add(name='test_schedule_no_remain',
         scheduler.add(name='test_schedule_no_remain',
                       schedule=mocked_schedule(False, None))
                       schedule=mocked_schedule(False, None))
-        self.assertEqual(scheduler.tick(), scheduler.max_interval)
+        assert scheduler.tick() == scheduler.max_interval
 
 
     def test_interface(self):
     def test_interface(self):
         scheduler = mScheduler(app=self.app)
         scheduler = mScheduler(app=self.app)
@@ -340,9 +339,9 @@ class test_Scheduler(AppCase):
                             'baz': {'schedule': mocked_schedule(True, 10)}})
                             'baz': {'schedule': mocked_schedule(True, 10)}})
         a.merge_inplace(b.schedule)
         a.merge_inplace(b.schedule)
 
 
-        self.assertNotIn('foo', a.schedule)
-        self.assertIn('baz', a.schedule)
-        self.assertEqual(a.schedule['bar'].schedule._next_run_at, 40)
+        assert 'foo' not in a.schedule
+        assert 'baz' in a.schedule
+        assert a.schedule['bar'].schedule._next_run_at == 40
 
 
 
 
 def create_persistent_scheduler(shelv=None):
 def create_persistent_scheduler(shelv=None):
@@ -367,7 +366,7 @@ def create_persistent_scheduler(shelv=None):
     return MockPersistentScheduler, shelv
     return MockPersistentScheduler, shelv
 
 
 
 
-class test_PersistentScheduler(AppCase):
+class test_PersistentScheduler:
 
 
     @patch('os.remove')
     @patch('os.remove')
     def test_remove_db(self, remove):
     def test_remove_db(self, remove):
@@ -382,7 +381,7 @@ class test_PersistentScheduler(AppCase):
         remove.side_effect = err
         remove.side_effect = err
         s._remove_db()
         s._remove_db()
         err.errno = errno.EPERM
         err.errno = errno.EPERM
-        with self.assertRaises(OSError):
+        with pytest.raises(OSError):
             s._remove_db()
             s._remove_db()
 
 
     def test_setup_schedule(self):
     def test_setup_schedule(self):
@@ -420,11 +419,11 @@ class test_PersistentScheduler(AppCase):
         )
         )
         s._store = {'entries': {}}
         s._store = {'entries': {}}
         s.schedule = {'foo': 'bar'}
         s.schedule = {'foo': 'bar'}
-        self.assertDictEqual(s.schedule, {'foo': 'bar'})
-        self.assertDictEqual(s._store['entries'], s.schedule)
+        assert s.schedule == {'foo': 'bar'}
+        assert s._store['entries'] == s.schedule
 
 
 
 
-class test_Service(AppCase):
+class test_Service:
 
 
     def get_service(self):
     def get_service(self):
         Scheduler, mock_shelve = create_persistent_scheduler()
         Scheduler, mock_shelve = create_persistent_scheduler()
@@ -432,26 +431,26 @@ class test_Service(AppCase):
 
 
     def test_pickleable(self):
     def test_pickleable(self):
         s = beat.Service(app=self.app, scheduler_cls=Mock)
         s = beat.Service(app=self.app, scheduler_cls=Mock)
-        self.assertTrue(loads(dumps(s)))
+        assert loads(dumps(s))
 
 
     def test_start(self):
     def test_start(self):
         s, sh = self.get_service()
         s, sh = self.get_service()
         schedule = s.scheduler.schedule
         schedule = s.scheduler.schedule
-        self.assertIsInstance(schedule, dict)
-        self.assertIsInstance(s.scheduler, beat.Scheduler)
+        assert isinstance(schedule, dict)
+        assert isinstance(s.scheduler, beat.Scheduler)
         scheduled = list(schedule.keys())
         scheduled = list(schedule.keys())
         for task_name in keys(sh['entries']):
         for task_name in keys(sh['entries']):
-            self.assertIn(task_name, scheduled)
+            assert task_name in scheduled
 
 
         s.sync()
         s.sync()
-        self.assertTrue(sh.closed)
-        self.assertTrue(sh.synced)
-        self.assertTrue(s._is_stopped.isSet())
+        assert sh.closed
+        assert sh.synced
+        assert s._is_stopped.isSet()
         s.sync()
         s.sync()
         s.stop(wait=False)
         s.stop(wait=False)
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
         s.stop(wait=True)
         s.stop(wait=True)
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
 
         p = s.scheduler._store
         p = s.scheduler._store
         s.scheduler._store = None
         s.scheduler._store = None
@@ -474,24 +473,24 @@ class test_Service(AppCase):
         s, sh = self.get_service()
         s, sh = self.get_service()
         s.scheduler.tick_raises_exit = True
         s.scheduler.tick_raises_exit = True
         s.start()
         s.start()
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
 
     def test_start_manages_one_tick_before_shutdown(self):
     def test_start_manages_one_tick_before_shutdown(self):
         s, sh = self.get_service()
         s, sh = self.get_service()
         s.scheduler.shutdown_service = s
         s.scheduler.shutdown_service = s
         s.start()
         s.start()
-        self.assertTrue(s._is_shutdown.isSet())
+        assert s._is_shutdown.isSet()
 
 
 
 
-class test_EmbeddedService(AppCase):
+class test_EmbeddedService:
 
 
     @skip.unless_module('_multiprocessing', name='multiprocessing')
     @skip.unless_module('_multiprocessing', name='multiprocessing')
     def test_start_stop_process(self):
     def test_start_stop_process(self):
         from billiard.process import Process
         from billiard.process import Process
 
 
         s = beat.EmbeddedService(self.app)
         s = beat.EmbeddedService(self.app)
-        self.assertIsInstance(s, Process)
-        self.assertIsInstance(s.service, beat.Service)
+        assert isinstance(s, Process)
+        assert isinstance(s.service, beat.Service)
         s.service = MockService()
         s.service = MockService()
 
 
         class _Popen(object):
         class _Popen(object):
@@ -502,43 +501,43 @@ class test_EmbeddedService(AppCase):
 
 
         with patch('celery.platforms.close_open_fds'):
         with patch('celery.platforms.close_open_fds'):
             s.run()
             s.run()
-        self.assertTrue(s.service.started)
+        assert s.service.started
 
 
         s._popen = _Popen()
         s._popen = _Popen()
         s.stop()
         s.stop()
-        self.assertTrue(s.service.stopped)
-        self.assertTrue(s._popen.terminated)
+        assert s.service.stopped
+        assert s._popen.terminated
 
 
     def test_start_stop_threaded(self):
     def test_start_stop_threaded(self):
         s = beat.EmbeddedService(self.app, thread=True)
         s = beat.EmbeddedService(self.app, thread=True)
         from threading import Thread
         from threading import Thread
-        self.assertIsInstance(s, Thread)
-        self.assertIsInstance(s.service, beat.Service)
+        assert isinstance(s, Thread)
+        assert isinstance(s.service, beat.Service)
         s.service = MockService()
         s.service = MockService()
 
 
         s.run()
         s.run()
-        self.assertTrue(s.service.started)
+        assert s.service.started
 
 
         s.stop()
         s.stop()
-        self.assertTrue(s.service.stopped)
+        assert s.service.stopped
 
 
 
 
-class test_schedule(AppCase):
+class test_schedule:
 
 
     def test_maybe_make_aware(self):
     def test_maybe_make_aware(self):
         x = schedule(10, app=self.app)
         x = schedule(10, app=self.app)
         x.utc_enabled = True
         x.utc_enabled = True
         d = x.maybe_make_aware(datetime.utcnow())
         d = x.maybe_make_aware(datetime.utcnow())
-        self.assertTrue(d.tzinfo)
+        assert d.tzinfo
         x.utc_enabled = False
         x.utc_enabled = False
         d2 = x.maybe_make_aware(datetime.utcnow())
         d2 = x.maybe_make_aware(datetime.utcnow())
-        self.assertTrue(d2.tzinfo)
+        assert d2.tzinfo
 
 
     def test_to_local(self):
     def test_to_local(self):
         x = schedule(10, app=self.app)
         x = schedule(10, app=self.app)
         x.utc_enabled = True
         x.utc_enabled = True
         d = x.to_local(datetime.utcnow())
         d = x.to_local(datetime.utcnow())
-        self.assertIsNone(d.tzinfo)
+        assert d.tzinfo is None
         x.utc_enabled = False
         x.utc_enabled = False
         d = x.to_local(datetime.utcnow())
         d = x.to_local(datetime.utcnow())
-        self.assertTrue(d.tzinfo)
+        assert d.tzinfo

+ 19 - 17
celery/tests/app/test_builtins.py → t/unit/app/test_builtins.py

@@ -1,14 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import ContextMock, Mock, patch
+
 from celery import group, chord
 from celery import group, chord
 from celery.app import builtins
 from celery.app import builtins
 from celery.five import range
 from celery.five import range
 from celery.utils.functional import pass1
 from celery.utils.functional import pass1
 
 
-from celery.tests.case import AppCase, ContextMock, Mock, patch
-
 
 
-class BuiltinsCase(AppCase):
+class BuiltinsCase:
 
 
     def setup(self):
     def setup(self):
         @self.app.task(shared=False)
         @self.app.task(shared=False)
@@ -38,10 +40,10 @@ class test_accumulate(BuiltinsCase):
         self.accumulate = self.app.tasks['celery.accumulate']
         self.accumulate = self.app.tasks['celery.accumulate']
 
 
     def test_with_index(self):
     def test_with_index(self):
-        self.assertEqual(self.accumulate(1, 2, 3, 4, index=0), 1)
+        assert self.accumulate(1, 2, 3, 4, index=0) == 1
 
 
     def test_no_index(self):
     def test_no_index(self):
-        self.assertEqual(self.accumulate(1, 2, 3, 4), (1, 2, 3, 4))
+        assert self.accumulate(1, 2, 3, 4), (1, 2, 3 == 4)
 
 
 
 
 class test_map(BuiltinsCase):
 class test_map(BuiltinsCase):
@@ -55,7 +57,7 @@ class test_map(BuiltinsCase):
         res = self.app.tasks['celery.map'](
         res = self.app.tasks['celery.map'](
             map_mul, [(2, 2), (4, 4), (8, 8)],
             map_mul, [(2, 2), (4, 4), (8, 8)],
         )
         )
-        self.assertEqual(res, [4, 16, 64])
+        assert res, [4, 16 == 64]
 
 
 
 
 class test_starmap(BuiltinsCase):
 class test_starmap(BuiltinsCase):
@@ -69,7 +71,7 @@ class test_starmap(BuiltinsCase):
         res = self.app.tasks['celery.starmap'](
         res = self.app.tasks['celery.starmap'](
             smap_mul, [(2, 2), (4, 4), (8, 8)],
             smap_mul, [(2, 2), (4, 4), (8, 8)],
         )
         )
-        self.assertEqual(res, [4, 16, 64])
+        assert res, [4, 16 == 64]
 
 
 
 
 class test_chunks(BuiltinsCase):
 class test_chunks(BuiltinsCase):
@@ -90,13 +92,13 @@ class test_chunks(BuiltinsCase):
 class test_group(BuiltinsCase):
 class test_group(BuiltinsCase):
 
 
     def setup(self):
     def setup(self):
-        self.maybe_signature = self.patch('celery.canvas.maybe_signature')
+        self.maybe_signature = self.patching('celery.canvas.maybe_signature')
         self.maybe_signature.side_effect = pass1
         self.maybe_signature.side_effect = pass1
         self.app.producer_or_acquire = Mock()
         self.app.producer_or_acquire = Mock()
         self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
         self.app.producer_or_acquire.attach_mock(ContextMock(), 'return_value')
         self.app.conf.task_always_eager = True
         self.app.conf.task_always_eager = True
         self.task = builtins.add_group_task(self.app)
         self.task = builtins.add_group_task(self.app)
-        super(test_group, self).setup()
+        BuiltinsCase.setup(self)
 
 
     def test_apply_async_eager(self):
     def test_apply_async_eager(self):
         self.task.apply = Mock(name='apply')
         self.task.apply = Mock(name='apply')
@@ -135,7 +137,7 @@ class test_chain(BuiltinsCase):
         self.task = builtins.add_chain_task(self.app)
         self.task = builtins.add_chain_task(self.app)
 
 
     def test_not_implemented(self):
     def test_not_implemented(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.task()
             self.task()
 
 
 
 
@@ -143,13 +145,13 @@ class test_chord(BuiltinsCase):
 
 
     def setup(self):
     def setup(self):
         self.task = builtins.add_chord_task(self.app)
         self.task = builtins.add_chord_task(self.app)
-        super(test_chord, self).setup()
+        BuiltinsCase.setup(self)
 
 
     def test_apply_async(self):
     def test_apply_async(self):
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async()
         r = x.apply_async()
-        self.assertTrue(r)
-        self.assertTrue(r.parent)
+        assert r
+        assert r.parent
 
 
     def test_run_header_not_group(self):
     def test_run_header_not_group(self):
         self.task([self.add.s(i, i) for i in range(10)], self.xsum.s())
         self.task([self.add.s(i, i) for i in range(10)], self.xsum.s())
@@ -161,22 +163,22 @@ class test_chord(BuiltinsCase):
         x.apply_async(group_id='some_group_id')
         x.apply_async(group_id='some_group_id')
         x.run.assert_called()
         x.run.assert_called()
         resbody = x.run.call_args[0][1]
         resbody = x.run.call_args[0][1]
-        self.assertEqual(resbody.options['group_id'], 'some_group_id')
+        assert resbody.options['group_id'] == 'some_group_id'
         x2 = chord([self.add.s(i, i) for i in range(10)], body=body)
         x2 = chord([self.add.s(i, i) for i in range(10)], body=body)
         x2.run = Mock(name='chord.run(x2)')
         x2.run = Mock(name='chord.run(x2)')
         x2.apply_async(chord='some_chord_id')
         x2.apply_async(chord='some_chord_id')
         x2.run.assert_called()
         x2.run.assert_called()
         resbody = x2.run.call_args[0][1]
         resbody = x2.run.call_args[0][1]
-        self.assertEqual(resbody.options['chord'], 'some_chord_id')
+        assert resbody.options['chord'] == 'some_chord_id'
 
 
     def test_apply_eager(self):
     def test_apply_eager(self):
         self.app.conf.task_always_eager = True
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async()
         r = x.apply_async()
-        self.assertEqual(r.get(), 90)
+        assert r.get() == 90
 
 
     def test_apply_eager_with_arguments(self):
     def test_apply_eager_with_arguments(self):
         self.app.conf.task_always_eager = True
         self.app.conf.task_always_eager = True
         x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
         x = chord([self.add.s(i) for i in range(10)], body=self.xsum.s())
         r = x.apply_async([1])
         r = x.apply_async([1])
-        self.assertEqual(r.get(), 55)
+        assert r.get() == 55

+ 18 - 0
t/unit/app/test_celery.py

@@ -0,0 +1,18 @@
+from __future__ import absolute_import, unicode_literals
+
+import celery
+import pytest
+
+
+def test_version():
+    assert celery.VERSION
+    assert len(celery.VERSION) >= 3
+    celery.VERSION = (0, 3, 0)
+    assert celery.__version__.count('.') >= 2
+
+
+@pytest.mark.parametrize('attr', [
+    '__author__', '__contact__', '__homepage__', '__docformat__',
+])
+def test_meta(attr):
+    assert getattr(celery, attr, None)

+ 50 - 47
celery/tests/app/test_control.py → t/unit/app/test_control.py

@@ -1,12 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from kombu.pidbox import Mailbox
 from kombu.pidbox import Mailbox
 from vine.utils import wraps
 from vine.utils import wraps
 
 
 from celery import uuid
 from celery import uuid
 from celery.app import control
 from celery.app import control
 from celery.exceptions import DuplicateNodenameWarning
 from celery.exceptions import DuplicateNodenameWarning
-from celery.tests.case import AppCase
 
 
 
 
 class MockMailbox(Mailbox):
 class MockMailbox(Mailbox):
@@ -38,7 +39,7 @@ def with_mock_broadcast(fun):
     return _resets
     return _resets
 
 
 
 
-class test_flatten_reply(AppCase):
+class test_flatten_reply:
 
 
     def test_flatten_reply(self):
     def test_flatten_reply(self):
         reply = [
         reply = [
@@ -46,18 +47,16 @@ class test_flatten_reply(AppCase):
             {'foo@example.com': {'hello': 20}},
             {'foo@example.com': {'hello': 20}},
             {'bar@example.com': {'hello': 30}}
             {'bar@example.com': {'hello': 30}}
         ]
         ]
-        with self.assertWarns(DuplicateNodenameWarning) as w:
+        with pytest.warns(DuplicateNodenameWarning) as w:
             nodes = control.flatten_reply(reply)
             nodes = control.flatten_reply(reply)
 
 
-        self.assertIn(
-            'Received multiple replies from node name: foo@example.com.',
-            str(w.warning)
-        )
-        self.assertIn('foo@example.com', nodes)
-        self.assertIn('bar@example.com', nodes)
+        assert 'Received multiple replies from node name: {0}.'.format(
+            next(iter(reply[0]))) in str(w[0].message.args[0])
+        assert 'foo@example.com' in nodes
+        assert 'bar@example.com' in nodes
 
 
 
 
-class test_inspect(AppCase):
+class test_inspect:
 
 
     def setup(self):
     def setup(self):
         self.c = Control(app=self.app)
         self.c = Control(app=self.app)
@@ -65,91 +64,95 @@ class test_inspect(AppCase):
         self.i = self.c.inspect()
         self.i = self.c.inspect()
 
 
     def test_prepare_reply(self):
     def test_prepare_reply(self):
-        self.assertDictEqual(self.i._prepare([{'w1': {'ok': 1}},
-                                              {'w2': {'ok': 1}}]),
-                             {'w1': {'ok': 1}, 'w2': {'ok': 1}})
+        reply = self.i._prepare([
+            {'w1': {'ok': 1}},
+            {'w2': {'ok': 1}},
+        ])
+        assert reply == {
+            'w1': {'ok': 1},
+            'w2': {'ok': 1},
+        }
 
 
         i = self.c.inspect(destination='w1')
         i = self.c.inspect(destination='w1')
-        self.assertEqual(i._prepare([{'w1': {'ok': 1}}]),
-                         {'ok': 1})
+        assert i._prepare([{'w1': {'ok': 1}}]) == {'ok': 1}
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_active(self):
     def test_active(self):
         self.i.active()
         self.i.active()
-        self.assertIn('active', MockMailbox.sent)
+        assert 'active' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_clock(self):
     def test_clock(self):
         self.i.clock()
         self.i.clock()
-        self.assertIn('clock', MockMailbox.sent)
+        assert 'clock' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_conf(self):
     def test_conf(self):
         self.i.conf()
         self.i.conf()
-        self.assertIn('conf', MockMailbox.sent)
+        assert 'conf' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_hello(self):
     def test_hello(self):
         self.i.hello('george@vandelay.com')
         self.i.hello('george@vandelay.com')
-        self.assertIn('hello', MockMailbox.sent)
+        assert 'hello' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_memsample(self):
     def test_memsample(self):
         self.i.memsample()
         self.i.memsample()
-        self.assertIn('memsample', MockMailbox.sent)
+        assert 'memsample' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_memdump(self):
     def test_memdump(self):
         self.i.memdump()
         self.i.memdump()
-        self.assertIn('memdump', MockMailbox.sent)
+        assert 'memdump' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_objgraph(self):
     def test_objgraph(self):
         self.i.objgraph()
         self.i.objgraph()
-        self.assertIn('objgraph', MockMailbox.sent)
+        assert 'objgraph' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_scheduled(self):
     def test_scheduled(self):
         self.i.scheduled()
         self.i.scheduled()
-        self.assertIn('scheduled', MockMailbox.sent)
+        assert 'scheduled' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_reserved(self):
     def test_reserved(self):
         self.i.reserved()
         self.i.reserved()
-        self.assertIn('reserved', MockMailbox.sent)
+        assert 'reserved' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_stats(self):
     def test_stats(self):
         self.i.stats()
         self.i.stats()
-        self.assertIn('stats', MockMailbox.sent)
+        assert 'stats' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoked(self):
     def test_revoked(self):
         self.i.revoked()
         self.i.revoked()
-        self.assertIn('revoked', MockMailbox.sent)
+        assert 'revoked' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_tasks(self):
     def test_tasks(self):
         self.i.registered()
         self.i.registered()
-        self.assertIn('registered', MockMailbox.sent)
+        assert 'registered' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_ping(self):
     def test_ping(self):
         self.i.ping()
         self.i.ping()
-        self.assertIn('ping', MockMailbox.sent)
+        assert 'ping' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_active_queues(self):
     def test_active_queues(self):
         self.i.active_queues()
         self.i.active_queues()
-        self.assertIn('active_queues', MockMailbox.sent)
+        assert 'active_queues' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_report(self):
     def test_report(self):
         self.i.report()
         self.i.report()
-        self.assertIn('report', MockMailbox.sent)
+        assert 'report' in MockMailbox.sent
 
 
 
 
-class test_Broadcast(AppCase):
+class test_Broadcast:
 
 
     def setup(self):
     def setup(self):
         self.control = Control(app=self.app)
         self.control = Control(app=self.app)
@@ -166,80 +169,80 @@ class test_Broadcast(AppCase):
     @with_mock_broadcast
     @with_mock_broadcast
     def test_broadcast(self):
     def test_broadcast(self):
         self.control.broadcast('foobarbaz', arguments=[])
         self.control.broadcast('foobarbaz', arguments=[])
-        self.assertIn('foobarbaz', MockMailbox.sent)
+        assert 'foobarbaz' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_broadcast_limit(self):
     def test_broadcast_limit(self):
         self.control.broadcast(
         self.control.broadcast(
             'foobarbaz1', arguments=[], limit=None, destination=[1, 2, 3],
             'foobarbaz1', arguments=[], limit=None, destination=[1, 2, 3],
         )
         )
-        self.assertIn('foobarbaz1', MockMailbox.sent)
+        assert 'foobarbaz1' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_broadcast_validate(self):
     def test_broadcast_validate(self):
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             self.control.broadcast('foobarbaz2',
             self.control.broadcast('foobarbaz2',
                                    destination='foo')
                                    destination='foo')
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_rate_limit(self):
     def test_rate_limit(self):
         self.control.rate_limit(self.mytask.name, '100/m')
         self.control.rate_limit(self.mytask.name, '100/m')
-        self.assertIn('rate_limit', MockMailbox.sent)
+        assert 'rate_limit' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_time_limit(self):
     def test_time_limit(self):
         self.control.time_limit(self.mytask.name, soft=10, hard=20)
         self.control.time_limit(self.mytask.name, soft=10, hard=20)
-        self.assertIn('time_limit', MockMailbox.sent)
+        assert 'time_limit' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_add_consumer(self):
     def test_add_consumer(self):
         self.control.add_consumer('foo')
         self.control.add_consumer('foo')
-        self.assertIn('add_consumer', MockMailbox.sent)
+        assert 'add_consumer' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_cancel_consumer(self):
     def test_cancel_consumer(self):
         self.control.cancel_consumer('foo')
         self.control.cancel_consumer('foo')
-        self.assertIn('cancel_consumer', MockMailbox.sent)
+        assert 'cancel_consumer' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_enable_events(self):
     def test_enable_events(self):
         self.control.enable_events()
         self.control.enable_events()
-        self.assertIn('enable_events', MockMailbox.sent)
+        assert 'enable_events' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_disable_events(self):
     def test_disable_events(self):
         self.control.disable_events()
         self.control.disable_events()
-        self.assertIn('disable_events', MockMailbox.sent)
+        assert 'disable_events' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke(self):
     def test_revoke(self):
         self.control.revoke('foozbaaz')
         self.control.revoke('foozbaaz')
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_ping(self):
     def test_ping(self):
         self.control.ping()
         self.control.ping()
-        self.assertIn('ping', MockMailbox.sent)
+        assert 'ping' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_election(self):
     def test_election(self):
         self.control.election('some_id', 'topic', 'action')
         self.control.election('some_id', 'topic', 'action')
-        self.assertIn('election', MockMailbox.sent)
+        assert 'election' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_pool_grow(self):
     def test_pool_grow(self):
         self.control.pool_grow(2)
         self.control.pool_grow(2)
-        self.assertIn('pool_grow', MockMailbox.sent)
+        assert 'pool_grow' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_pool_shrink(self):
     def test_pool_shrink(self):
         self.control.pool_shrink(2)
         self.control.pool_shrink(2)
-        self.assertIn('pool_shrink', MockMailbox.sent)
+        assert 'pool_shrink' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke_from_result(self):
     def test_revoke_from_result(self):
         self.app.AsyncResult('foozbazzbar').revoke()
         self.app.AsyncResult('foozbazzbar').revoke()
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent
 
 
     @with_mock_broadcast
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
     def test_revoke_from_resultset(self):
@@ -247,4 +250,4 @@ class test_Broadcast(AppCase):
                                  [self.app.AsyncResult(x)
                                  [self.app.AsyncResult(x)
                                   for x in [uuid() for i in range(10)]])
                                   for x in [uuid() for i in range(10)]])
         r.revoke()
         r.revoke()
-        self.assertIn('revoke', MockMailbox.sent)
+        assert 'revoke' in MockMailbox.sent

+ 65 - 0
t/unit/app/test_defaults.py

@@ -0,0 +1,65 @@
+from __future__ import absolute_import, unicode_literals
+
+import sys
+
+from importlib import import_module
+
+from case import mock
+
+from celery.app.defaults import (
+    _OLD_DEFAULTS, _OLD_SETTING_KEYS, _TO_NEW_KEY, _TO_OLD_KEY,
+    DEFAULTS, NAMESPACES, SETTING_KEYS
+)
+from celery.five import values
+
+
+class test_defaults:
+
+    def setup(self):
+        self._prev = sys.modules.pop('celery.app.defaults', None)
+
+    def teardown(self):
+        if self._prev:
+            sys.modules['celery.app.defaults'] = self._prev
+
+    def test_option_repr(self):
+        assert repr(NAMESPACES['broker']['url'])
+
+    def test_any(self):
+        val = object()
+        assert self.defaults.Option.typemap['any'](val) is val
+
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 4, 0))
+    def test_default_pool_pypy_14(self):
+        assert self.defaults.DEFAULT_POOL == 'solo'
+
+    @mock.sys_platform('darwin')
+    @mock.pypy_version((1, 5, 0))
+    def test_default_pool_pypy_15(self):
+        assert self.defaults.DEFAULT_POOL == 'prefork'
+
+    def test_compat_indices(self):
+        assert not any(key.isupper() for key in DEFAULTS)
+        assert not any(key.islower() for key in _OLD_DEFAULTS)
+        assert not any(key.isupper() for key in _TO_OLD_KEY)
+        assert not any(key.islower() for key in _TO_NEW_KEY)
+        assert not any(key.isupper() for key in SETTING_KEYS)
+        assert not any(key.islower() for key in _OLD_SETTING_KEYS)
+        assert not any(value.isupper() for value in values(_TO_NEW_KEY))
+        assert not any(value.islower() for value in values(_TO_OLD_KEY))
+
+        for key in _TO_NEW_KEY:
+            assert key in _OLD_SETTING_KEYS
+        for key in _TO_OLD_KEY:
+            assert key in SETTING_KEYS
+
+    def test_find(self):
+        find = self.defaults.find
+
+        assert find('default_queue')[2].default == 'celery'
+        assert find('task_default_exchange')[2] == 'celery'
+
+    @property
+    def defaults(self):
+        return import_module('celery.app.defaults')

+ 8 - 10
celery/tests/app/test_exceptions.py → t/unit/app/test_exceptions.py

@@ -6,30 +6,28 @@ from datetime import datetime
 
 
 from celery.exceptions import Reject, Retry
 from celery.exceptions import Reject, Retry
 
 
-from celery.tests.case import AppCase
 
 
-
-class test_Retry(AppCase):
+class test_Retry:
 
 
     def test_when_datetime(self):
     def test_when_datetime(self):
         x = Retry('foo', KeyError(), when=datetime.utcnow())
         x = Retry('foo', KeyError(), when=datetime.utcnow())
-        self.assertTrue(x.humanize())
+        assert x.humanize()
 
 
     def test_pickleable(self):
     def test_pickleable(self):
         x = Retry('foo', KeyError(), when=datetime.utcnow())
         x = Retry('foo', KeyError(), when=datetime.utcnow())
-        self.assertTrue(pickle.loads(pickle.dumps(x)))
+        assert pickle.loads(pickle.dumps(x))
 
 
 
 
-class test_Reject(AppCase):
+class test_Reject:
 
 
     def test_attrs(self):
     def test_attrs(self):
         x = Reject('foo', requeue=True)
         x = Reject('foo', requeue=True)
-        self.assertEqual(x.reason, 'foo')
-        self.assertTrue(x.requeue)
+        assert x.reason == 'foo'
+        assert x.requeue
 
 
     def test_repr(self):
     def test_repr(self):
-        self.assertTrue(repr(Reject('foo', True)))
+        assert repr(Reject('foo', True))
 
 
     def test_pickleable(self):
     def test_pickleable(self):
         x = Retry('foo', True)
         x = Retry('foo', True)
-        self.assertTrue(pickle.loads(pickle.dumps(x)))
+        assert pickle.loads(pickle.dumps(x))

+ 32 - 36
celery/tests/app/test_loaders.py → t/unit/app/test_loaders.py

@@ -1,9 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import os
 import os
+import pytest
 import sys
 import sys
 import warnings
 import warnings
 
 
+from case import Mock, mock, patch
+
 from celery import loaders
 from celery import loaders
 from celery.exceptions import NotConfigured
 from celery.exceptions import NotConfigured
 from celery.five import bytes_if_py2
 from celery.five import bytes_if_py2
@@ -12,8 +15,6 @@ from celery.loaders import default
 from celery.loaders.app import AppLoader
 from celery.loaders.app import AppLoader
 from celery.utils.imports import NotAPackage
 from celery.utils.imports import NotAPackage
 
 
-from celery.tests.case import AppCase, Case, Mock, mock, patch
-
 
 
 class DummyLoader(base.BaseLoader):
 class DummyLoader(base.BaseLoader):
 
 
@@ -21,14 +22,13 @@ class DummyLoader(base.BaseLoader):
         return {'foo': 'bar', 'imports': ('os', 'sys')}
         return {'foo': 'bar', 'imports': ('os', 'sys')}
 
 
 
 
-class test_loaders(AppCase):
+class test_loaders:
 
 
     def test_get_loader_cls(self):
     def test_get_loader_cls(self):
-        self.assertEqual(loaders.get_loader_cls('default'),
-                         default.Loader)
+        assert loaders.get_loader_cls('default') is default.Loader
 
 
 
 
-class test_LoaderBase(AppCase):
+class test_LoaderBase:
     message_options = {'subject': 'Subject',
     message_options = {'subject': 'Subject',
                        'body': 'Body',
                        'body': 'Body',
                        'sender': 'x@x.com',
                        'sender': 'x@x.com',
@@ -47,25 +47,23 @@ class test_LoaderBase(AppCase):
         self.loader.on_worker_init()
         self.loader.on_worker_init()
 
 
     def test_now(self):
     def test_now(self):
-        self.assertTrue(self.loader.now(utc=True))
-        self.assertTrue(self.loader.now(utc=False))
+        assert self.loader.now(utc=True)
+        assert self.loader.now(utc=False)
 
 
     def test_read_configuration_no_env(self):
     def test_read_configuration_no_env(self):
-        self.assertIsNone(
-            base.BaseLoader(app=self.app).read_configuration(
-                'FOO_X_S_WE_WQ_Q_WE'),
-        )
+        assert base.BaseLoader(app=self.app).read_configuration(
+            'FOO_X_S_WE_WQ_Q_WE') is None
 
 
     def test_autodiscovery(self):
     def test_autodiscovery(self):
         with patch('celery.loaders.base.autodiscover_tasks') as auto:
         with patch('celery.loaders.base.autodiscover_tasks') as auto:
             auto.return_value = [Mock()]
             auto.return_value = [Mock()]
             auto.return_value[0].__name__ = 'moo'
             auto.return_value[0].__name__ = 'moo'
             self.loader.autodiscover_tasks(['A', 'B'])
             self.loader.autodiscover_tasks(['A', 'B'])
-            self.assertIn('moo', self.loader.task_modules)
+            assert 'moo' in self.loader.task_modules
             self.loader.task_modules.discard('moo')
             self.loader.task_modules.discard('moo')
 
 
     def test_import_task_module(self):
     def test_import_task_module(self):
-        self.assertEqual(sys, self.loader.import_task_module('sys'))
+        assert sys == self.loader.import_task_module('sys')
 
 
     def test_init_worker_process(self):
     def test_init_worker_process(self):
         self.loader.on_worker_process_init()
         self.loader.on_worker_process_init()
@@ -79,18 +77,16 @@ class test_LoaderBase(AppCase):
         self.loader.import_from_cwd.assert_called_with('module_name')
         self.loader.import_from_cwd.assert_called_with('module_name')
 
 
     def test_conf_property(self):
     def test_conf_property(self):
-        self.assertEqual(self.loader.conf['foo'], 'bar')
-        self.assertEqual(self.loader._conf['foo'], 'bar')
-        self.assertEqual(self.loader.conf['foo'], 'bar')
+        assert self.loader.conf['foo'] == 'bar'
+        assert self.loader._conf['foo'] == 'bar'
+        assert self.loader.conf['foo'] == 'bar'
 
 
     def test_import_default_modules(self):
     def test_import_default_modules(self):
         def modnames(l):
         def modnames(l):
             return [m.__name__ for m in l]
             return [m.__name__ for m in l]
         self.app.conf.imports = ('os', 'sys')
         self.app.conf.imports = ('os', 'sys')
-        self.assertEqual(
-            sorted(modnames(self.loader.import_default_modules())),
-            sorted(modnames([os, sys])),
-        )
+        assert (sorted(modnames(self.loader.import_default_modules())) ==
+                sorted(modnames([os, sys])))
 
 
     def test_import_from_cwd_custom_imp(self):
     def test_import_from_cwd_custom_imp(self):
         imp = Mock(name='imp')
         imp = Mock(name='imp')
@@ -98,17 +94,17 @@ class test_LoaderBase(AppCase):
         imp.assert_called()
         imp.assert_called()
 
 
     def test_cmdline_config_ValueError(self):
     def test_cmdline_config_ValueError(self):
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             self.loader.cmdline_config_parser(['broker.port=foobar'])
             self.loader.cmdline_config_parser(['broker.port=foobar'])
 
 
 
 
-class test_DefaultLoader(AppCase):
+class test_DefaultLoader:
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
     def test_read_configuration_not_a_package(self, find_module):
     def test_read_configuration_not_a_package(self, find_module):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
-        with self.assertRaises(NotAPackage):
+        with pytest.raises(NotAPackage):
             l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
@@ -116,7 +112,7 @@ class test_DefaultLoader(AppCase):
     def test_read_configuration_py_in_name(self, find_module):
     def test_read_configuration_py_in_name(self, find_module):
         find_module.side_effect = NotAPackage()
         find_module.side_effect = NotAPackage()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
-        with self.assertRaises(NotAPackage):
+        with pytest.raises(NotAPackage):
             l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=False)
 
 
     @patch('celery.loaders.base.find_module')
     @patch('celery.loaders.base.find_module')
@@ -124,7 +120,7 @@ class test_DefaultLoader(AppCase):
         default.C_WNOCONF = True
         default.C_WNOCONF = True
         find_module.side_effect = ImportError()
         find_module.side_effect = ImportError()
         l = default.Loader(app=self.app)
         l = default.Loader(app=self.app)
-        with self.assertWarnsRegex(NotConfigured, r'make sure it exists'):
+        with pytest.warns(NotConfigured):
             l.read_configuration(fail_silently=True)
             l.read_configuration(fail_silently=True)
         default.C_WNOCONF = False
         default.C_WNOCONF = False
         l.read_configuration(fail_silently=True)
         l.read_configuration(fail_silently=True)
@@ -145,9 +141,9 @@ class test_DefaultLoader(AppCase):
             l = default.Loader(app=self.app)
             l = default.Loader(app=self.app)
             l.find_module = Mock(name='find_module')
             l.find_module = Mock(name='find_module')
             settings = l.read_configuration(fail_silently=False)
             settings = l.read_configuration(fail_silently=False)
-            self.assertTupleEqual(settings.imports, ('os', 'sys'))
+            assert settings.imports == ('os', 'sys')
             settings = l.read_configuration(fail_silently=False)
             settings = l.read_configuration(fail_silently=False)
-            self.assertTupleEqual(settings.imports, ('os', 'sys'))
+            assert settings.imports == ('os', 'sys')
             l.on_worker_init()
             l.on_worker_init()
         finally:
         finally:
             if prevconfig:
             if prevconfig:
@@ -160,7 +156,7 @@ class test_DefaultLoader(AppCase):
         )
         )
         try:
         try:
             l = default.Loader(app=self.app)
             l = default.Loader(app=self.app)
-            with self.assertRaises(ImportError):
+            with pytest.raises(ImportError):
                 l.read_configuration(fail_silently=False)
                 l.read_configuration(fail_silently=False)
             l.read_configuration(fail_silently=True)
             l.read_configuration(fail_silently=True)
         finally:
         finally:
@@ -179,11 +175,11 @@ class test_DefaultLoader(AppCase):
         celery = sys.modules.pop('celery', None)
         celery = sys.modules.pop('celery', None)
         sys.modules.pop('celery.five', None)
         sys.modules.pop('celery.five', None)
         try:
         try:
-            self.assertTrue(l.import_from_cwd('celery'))
+            assert l.import_from_cwd('celery')
             sys.modules.pop('celery', None)
             sys.modules.pop('celery', None)
             sys.modules.pop('celery.five', None)
             sys.modules.pop('celery.five', None)
             sys.path.insert(0, os.getcwd())
             sys.path.insert(0, os.getcwd())
-            self.assertTrue(l.import_from_cwd('celery'))
+            assert l.import_from_cwd('celery')
         finally:
         finally:
             sys.path = old_path
             sys.path = old_path
             sys.modules['celery'] = celery
             sys.modules['celery'] = celery
@@ -198,12 +194,12 @@ class test_DefaultLoader(AppCase):
 
 
         with warnings.catch_warnings(record=True):
         with warnings.catch_warnings(record=True):
             l = _Loader(app=self.app)
             l = _Loader(app=self.app)
-            self.assertFalse(l.configured)
+            assert not l.configured
             context_executed[0] = True
             context_executed[0] = True
-        self.assertTrue(context_executed[0])
+        assert context_executed[0]
 
 
 
 
-class test_AppLoader(AppCase):
+class test_AppLoader:
 
 
     def setup(self):
     def setup(self):
         self.loader = AppLoader(app=self.app)
         self.loader = AppLoader(app=self.app)
@@ -212,10 +208,10 @@ class test_AppLoader(AppCase):
         self.app.conf.imports = ('subprocess',)
         self.app.conf.imports = ('subprocess',)
         sys.modules.pop('subprocess', None)
         sys.modules.pop('subprocess', None)
         self.loader.init_worker()
         self.loader.init_worker()
-        self.assertIn('subprocess', sys.modules)
+        assert 'subprocess' in sys.modules
 
 
 
 
-class test_autodiscovery(Case):
+class test_autodiscovery:
 
 
     def test_autodiscover_tasks(self):
     def test_autodiscover_tasks(self):
         base._RACE_PROTECTION = True
         base._RACE_PROTECTION = True

+ 45 - 60
celery/tests/app/test_log.py → t/unit/app/test_log.py

@@ -1,12 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-import sys
 import logging
 import logging
+import pytest
+import sys
 
 
 from collections import defaultdict
 from collections import defaultdict
 from io import StringIO
 from io import StringIO
 from tempfile import mktemp
 from tempfile import mktemp
 
 
+from case import Mock, mock, patch, skip
+from case.utils import get_logger_handlers
+
 from celery import signals
 from celery import signals
 from celery import uuid
 from celery import uuid
 from celery.app.log import TaskFormatter
 from celery.app.log import TaskFormatter
@@ -22,11 +26,8 @@ from celery.utils.log import (
     logger_isa,
     logger_isa,
 )
 )
 
 
-from case.utils import get_logger_handlers
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 
 
-class test_TaskFormatter(AppCase):
+class test_TaskFormatter:
 
 
     def test_no_task(self):
     def test_no_task(self):
         class Record(object):
         class Record(object):
@@ -40,39 +41,39 @@ class test_TaskFormatter(AppCase):
         record = Record()
         record = Record()
         x = TaskFormatter()
         x = TaskFormatter()
         x.format(record)
         x.format(record)
-        self.assertEqual(record.task_name, '???')
-        self.assertEqual(record.task_id, '???')
+        assert record.task_name == '???'
+        assert record.task_id == '???'
 
 
 
 
-class test_logger_isa(AppCase):
+class test_logger_isa:
 
 
     def test_isa(self):
     def test_isa(self):
         x = get_task_logger('Z1george')
         x = get_task_logger('Z1george')
-        self.assertTrue(logger_isa(x, task_logger))
+        assert logger_isa(x, task_logger)
         prev_x, x.parent = x.parent, None
         prev_x, x.parent = x.parent, None
         try:
         try:
-            self.assertFalse(logger_isa(x, task_logger))
+            assert not logger_isa(x, task_logger)
         finally:
         finally:
             x.parent = prev_x
             x.parent = prev_x
 
 
         y = get_task_logger('Z1elaine')
         y = get_task_logger('Z1elaine')
         y.parent = x
         y.parent = x
-        self.assertTrue(logger_isa(y, task_logger))
-        self.assertTrue(logger_isa(y, x))
-        self.assertTrue(logger_isa(y, y))
+        assert logger_isa(y, task_logger)
+        assert logger_isa(y, x)
+        assert logger_isa(y, y)
 
 
         z = get_task_logger('Z1jerry')
         z = get_task_logger('Z1jerry')
         z.parent = y
         z.parent = y
-        self.assertTrue(logger_isa(z, task_logger))
-        self.assertTrue(logger_isa(z, y))
-        self.assertTrue(logger_isa(z, x))
-        self.assertTrue(logger_isa(z, z))
+        assert logger_isa(z, task_logger)
+        assert logger_isa(z, y)
+        assert logger_isa(z, x)
+        assert logger_isa(z, z)
 
 
     def test_recursive(self):
     def test_recursive(self):
         x = get_task_logger('X1foo')
         x = get_task_logger('X1foo')
         prev, x.parent = x.parent, x
         prev, x.parent = x.parent, x
         try:
         try:
-            with self.assertRaises(RuntimeError):
+            with pytest.raises(RuntimeError):
                 logger_isa(x, task_logger)
                 logger_isa(x, task_logger)
         finally:
         finally:
             x.parent = prev
             x.parent = prev
@@ -83,7 +84,7 @@ class test_logger_isa(AppCase):
         try:
         try:
             prev_z, z.parent = z.parent, y
             prev_z, z.parent = z.parent, y
             try:
             try:
-                with self.assertRaises(RuntimeError):
+                with pytest.raises(RuntimeError):
                     logger_isa(y, task_logger)
                     logger_isa(y, task_logger)
             finally:
             finally:
                 z.parent = prev_z
                 z.parent = prev_z
@@ -91,7 +92,7 @@ class test_logger_isa(AppCase):
             y.parent = prev_y
             y.parent = prev_y
 
 
 
 
-class test_ColorFormatter(AppCase):
+class test_ColorFormatter:
 
 
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
     @patch('logging.Formatter.formatException')
     @patch('logging.Formatter.formatException')
@@ -99,7 +100,7 @@ class test_ColorFormatter(AppCase):
         x = ColorFormatter()
         x = ColorFormatter()
         value = KeyError()
         value = KeyError()
         fe.return_value = value
         fe.return_value = value
-        self.assertIs(x.formatException(value), value)
+        assert x.formatException(value) is value
         fe.assert_called()
         fe.assert_called()
         safe_str.assert_not_called()
         safe_str.assert_not_called()
 
 
@@ -111,7 +112,7 @@ class test_ColorFormatter(AppCase):
         try:
         try:
             raise Exception()
             raise Exception()
         except Exception:
         except Exception:
-            self.assertTrue(x.formatException(sys.exc_info()))
+            assert x.formatException(sys.exc_info())
         if sys.version_info[0] == 2:
         if sys.version_info[0] == 2:
             safe_str.assert_called()
             safe_str.assert_called()
 
 
@@ -122,7 +123,7 @@ class test_ColorFormatter(AppCase):
         record = Mock()
         record = Mock()
         record.levelname = 'ERROR'
         record.levelname = 'ERROR'
         record.msg = object()
         record.msg = object()
-        self.assertTrue(x.format(record))
+        assert x.format(record)
 
 
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
     def test_format_raises(self, safe_str):
     def test_format_raises(self, safe_str):
@@ -153,8 +154,8 @@ class test_ColorFormatter(AppCase):
         safe_str.return_value = record
         safe_str.return_value = record
 
 
         msg = x.format(record)
         msg = x.format(record)
-        self.assertIn('<Unrepresentable', msg)
-        self.assertEqual(safe_str.call_count, 1)
+        assert '<Unrepresentable' in msg
+        assert safe_str.call_count == 1
 
 
     @skip.if_python3()
     @skip.if_python3()
     @patch('celery.utils.log.safe_str')
     @patch('celery.utils.log.safe_str')
@@ -165,10 +166,10 @@ class test_ColorFormatter(AppCase):
         record.msg = 'HELLO'
         record.msg = 'HELLO'
         record.exc_text = 'error text'
         record.exc_text = 'error text'
         x.format(record)
         x.format(record)
-        self.assertEqual(safe_str.call_count, 1)
+        assert safe_str.call_count == 1
 
 
 
 
-class test_default_logger(AppCase):
+class test_default_logger:
 
 
     def setup(self):
     def setup(self):
         self.setup_logger = self.app.log.setup_logger
         self.setup_logger = self.app.log.setup_logger
@@ -178,11 +179,11 @@ class test_default_logger(AppCase):
 
 
     def test_get_logger_sets_parent(self):
     def test_get_logger_sets_parent(self):
         logger = get_logger('celery.test_get_logger')
         logger = get_logger('celery.test_get_logger')
-        self.assertEqual(logger.parent.name, base_logger.name)
+        assert logger.parent.name == base_logger.name
 
 
     def test_get_logger_root(self):
     def test_get_logger_root(self):
         logger = get_logger(base_logger.name)
         logger = get_logger(base_logger.name)
-        self.assertIs(logger.parent, logging.root)
+        assert logger.parent is logging.root
 
 
     @mock.restore_logging()
     @mock.restore_logging()
     def test_setup_logging_subsystem_misc(self):
     def test_setup_logging_subsystem_misc(self):
@@ -194,7 +195,7 @@ class test_default_logger(AppCase):
         self.app.log.setup_logging_subsystem()
         self.app.log.setup_logging_subsystem()
 
 
     def test_get_default_logger(self):
     def test_get_default_logger(self):
-        self.assertTrue(self.app.log.get_default_logger())
+        assert self.app.log.get_default_logger()
 
 
     def test_configure_logger(self):
     def test_configure_logger(self):
         logger = self.app.log.get_default_logger()
         logger = self.app.log.get_default_logger()
@@ -212,20 +213,6 @@ class test_default_logger(AppCase):
         with mock.mask_modules('billiard.util'):
         with mock.mask_modules('billiard.util'):
             self.app.log.setup_logging_subsystem()
             self.app.log.setup_logging_subsystem()
 
 
-    def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
-
-        with mock.wrap_logger(logger, loglevel=loglevel) as sio:
-            logger.log(loglevel, logmsg)
-            return sio.getvalue().strip()
-
-    def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
-        val = self._assertLog(logger, logmsg, loglevel=loglevel)
-        return self.assertEqual(val, logmsg, reason)
-
-    def assertDidLogFalse(self, logger, logmsg, reason, loglevel=None):
-        val = self._assertLog(logger, logmsg, loglevel=loglevel)
-        return self.assertFalse(val, reason)
-
     @mock.restore_logging()
     @mock.restore_logging()
     def test_setup_logger(self):
     def test_setup_logger(self):
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
@@ -234,10 +221,9 @@ class test_default_logger(AppCase):
         self.app.log.already_setup = False
         self.app.log.already_setup = False
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
         logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
                                    root=False, colorize=None)
                                    root=False, colorize=None)
-        self.assertIs(
-            get_logger_handlers(logger)[0].stream, sys.__stderr__,
-            'setup_logger logs to stderr without logfile argument.',
-        )
+        # setup_logger logs to stderr without logfile argument.
+        assert (get_logger_handlers(logger)[0].stream is
+                sys.__stderr__)
 
 
     @mock.restore_logging()
     @mock.restore_logging()
     def test_setup_logger_no_handlers_stream(self):
     def test_setup_logger_no_handlers_stream(self):
@@ -249,7 +235,7 @@ class test_default_logger(AppCase):
             l = self.setup_logger(logfile=sys.stderr,
             l = self.setup_logger(logfile=sys.stderr,
                                   loglevel=logging.INFO, root=False)
                                   loglevel=logging.INFO, root=False)
             l.info('The quick brown fox...')
             l.info('The quick brown fox...')
-            self.assertIn('The quick brown fox...', stderr.getvalue())
+            assert 'The quick brown fox...' in stderr.getvalue()
 
 
     @patch('os.fstat')
     @patch('os.fstat')
     def test_setup_logger_no_handlers_file(self, *args):
     def test_setup_logger_no_handlers_file(self, *args):
@@ -272,10 +258,9 @@ class test_default_logger(AppCase):
                 l = self.setup_logger(
                 l = self.setup_logger(
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                     logfile=tempfile, loglevel=logging.INFO, root=False,
                 )
                 )
-                self.assertIsInstance(
-                    get_logger_handlers(l)[0], logging.FileHandler,
-                )
-                self.assertIn(tempfile, files)
+                assert isinstance(get_logger_handlers(l)[0],
+                                  logging.FileHandler)
+                assert tempfile in files
 
 
     @mock.restore_logging()
     @mock.restore_logging()
     def test_redirect_stdouts(self):
     def test_redirect_stdouts(self):
@@ -287,7 +272,7 @@ class test_default_logger(AppCase):
                     logger, loglevel=logging.ERROR,
                     logger, loglevel=logging.ERROR,
                 )
                 )
                 logger.error('foo')
                 logger.error('foo')
-                self.assertIn('foo', sio.getvalue())
+                assert 'foo' in sio.getvalue()
                 self.app.log.redirect_stdouts_to_logger(
                 self.app.log.redirect_stdouts_to_logger(
                     logger, stdout=False, stderr=False,
                     logger, stdout=False, stderr=False,
                 )
                 )
@@ -303,22 +288,22 @@ class test_default_logger(AppCase):
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p = LoggingProxy(logger, loglevel=logging.ERROR)
             p.close()
             p.close()
             p.write('foo')
             p.write('foo')
-            self.assertNotIn('foo', sio.getvalue())
+            assert 'foo' not in sio.getvalue()
             p.closed = False
             p.closed = False
             p.write('foo')
             p.write('foo')
-            self.assertIn('foo', sio.getvalue())
+            assert 'foo' in sio.getvalue()
             lines = ['baz', 'xuzzy']
             lines = ['baz', 'xuzzy']
             p.writelines(lines)
             p.writelines(lines)
             for line in lines:
             for line in lines:
-                self.assertIn(line, sio.getvalue())
+                assert line in sio.getvalue()
             p.flush()
             p.flush()
             p.close()
             p.close()
-            self.assertFalse(p.isatty())
+            assert not p.isatty()
 
 
             with mock.stdouts() as (stdout, stderr):
             with mock.stdouts() as (stdout, stderr):
                 with in_sighandler():
                 with in_sighandler():
                     p.write('foo')
                     p.write('foo')
-                    self.assertTrue(stderr.getvalue())
+                    assert stderr.getvalue()
 
 
     @mock.restore_logging()
     @mock.restore_logging()
     def test_logging_proxy_recurse_protection(self):
     def test_logging_proxy_recurse_protection(self):
@@ -327,7 +312,7 @@ class test_default_logger(AppCase):
         p = LoggingProxy(logger, loglevel=logging.ERROR)
         p = LoggingProxy(logger, loglevel=logging.ERROR)
         p._thread.recurse_protection = True
         p._thread.recurse_protection = True
         try:
         try:
-            self.assertIsNone(p.write('FOOFO'))
+            assert p.write('FOOFO') is None
         finally:
         finally:
             p._thread.recurse_protection = False
             p._thread.recurse_protection = False
 
 

+ 72 - 0
t/unit/app/test_registry.py

@@ -0,0 +1,72 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from celery.app.registry import _unpickle_task, _unpickle_task_v2
+
+
+def returns():
+    return 1
+
+
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_unpickle_task:
+
+    def test_unpickle_v1(self, app):
+        app.tasks['txfoo'] = 'bar'
+        assert _unpickle_task('txfoo') == 'bar'
+
+    def test_unpickle_v2(self, app):
+        app.tasks['txfoo1'] = 'bar1'
+        assert _unpickle_task_v2('txfoo1') == 'bar1'
+        assert _unpickle_task_v2('txfoo1', module='celery') == 'bar1'
+
+
+class test_TaskRegistry:
+
+    def setup(self):
+        self.mytask = self.app.task(name='A', shared=False)(returns)
+        self.myperiodic = self.app.task(
+            name='B', shared=False, type='periodic',
+        )(returns)
+
+    def test_NotRegistered_str(self):
+        assert repr(self.app.tasks.NotRegistered('tasks.add'))
+
+    def assert_register_unregister_cls(self, r, task):
+        r.unregister(task)
+        with pytest.raises(r.NotRegistered):
+            r.unregister(task)
+        r.register(task)
+        assert task.name in r
+
+    def test_task_registry(self):
+        r = self.app._tasks
+        assert isinstance(r, dict)
+
+        self.assert_register_unregister_cls(r, self.mytask)
+        self.assert_register_unregister_cls(r, self.myperiodic)
+
+        r.register(self.myperiodic)
+        r.unregister(self.myperiodic.name)
+        assert self.myperiodic not in r
+        r.register(self.myperiodic)
+
+        tasks = dict(r)
+        assert tasks.get(self.mytask.name) is self.mytask
+        assert tasks.get(self.myperiodic.name) is self.myperiodic
+
+        assert r[self.mytask.name] is self.mytask
+        assert r[self.myperiodic.name] is self.myperiodic
+
+        r.unregister(self.mytask)
+        assert self.mytask.name not in r
+        r.unregister(self.myperiodic)
+        assert self.myperiodic.name not in r
+
+        assert self.mytask.run()
+        assert self.myperiodic.run()
+
+    def test_compat(self):
+        assert self.app.tasks.regular()
+        assert self.app.tasks.periodic()

+ 30 - 37
celery/tests/app/test_routes.py → t/unit/app/test_routes.py

@@ -1,14 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import ANY, Mock
 from kombu import Exchange, Queue
 from kombu import Exchange, Queue
 from kombu.utils.functional import maybe_evaluate
 from kombu.utils.functional import maybe_evaluate
 
 
 from celery.app import routes
 from celery.app import routes
 from celery.exceptions import QueueNotFound
 from celery.exceptions import QueueNotFound
+from celery.five import items
 from celery.utils.imports import qualname
 from celery.utils.imports import qualname
 
 
-from celery.tests.case import ANY, AppCase, Mock
-
 
 
 def Router(app, *args, **kwargs):
 def Router(app, *args, **kwargs):
     return routes.Router(*args, app=app, **kwargs)
     return routes.Router(*args, app=app, **kwargs)
@@ -25,7 +27,7 @@ def set_queues(app, **queues):
     app.amqp.queues = app.amqp.Queues(queues)
     app.amqp.queues = app.amqp.Queues(queues)
 
 
 
 
-class RouteCase(AppCase):
+class RouteCase:
 
 
     def setup(self):
     def setup(self):
         self.a_queue = {
         self.a_queue = {
@@ -51,10 +53,7 @@ class RouteCase(AppCase):
 
 
     def assert_routes_to_queue(self, queue, router, name,
     def assert_routes_to_queue(self, queue, router, name,
                                args=[], kwargs={}, options={}):
                                args=[], kwargs={}, options={}):
-        self.assertEqual(
-            router.route(options, name, args, kwargs)['queue'].name,
-            queue,
-        )
+        assert router.route(options, name, args, kwargs)['queue'].name == queue
 
 
     def assert_routes_to_default_queue(self, router, name, *args, **kwargs):
     def assert_routes_to_default_queue(self, router, name, *args, **kwargs):
         self.assert_routes_to_queue(
         self.assert_routes_to_queue(
@@ -67,38 +66,32 @@ class test_MapRoute(RouteCase):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         expand = E(self.app, self.app.amqp.queues)
         expand = E(self.app, self.app.amqp.queues)
         route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
         route = routes.MapRoute({self.mytask.name: {'queue': 'foo'}})
-        self.assertEqual(
-            expand(route(self.mytask.name))['queue'].name,
-            'foo',
-        )
-        self.assertIsNone(route('celery.awesome'))
+        assert expand(route(self.mytask.name))['queue'].name == 'foo'
+        assert route('celery.awesome') is None
 
 
     def test_route_for_task(self):
     def test_route_for_task(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         expand = E(self.app, self.app.amqp.queues)
         expand = E(self.app, self.app.amqp.queues)
         route = routes.MapRoute({self.mytask.name: self.b_queue})
         route = routes.MapRoute({self.mytask.name: self.b_queue})
-        self.assertDictContainsSubset(
-            self.b_queue,
-            expand(route(self.mytask.name)),
-        )
-        self.assertIsNone(route('celery.awesome'))
+        eroute = expand(route(self.mytask.name))
+        for key, value in items(self.b_queue):
+            assert eroute[key] == value
+        assert route('celery.awesome') is None
 
 
     def test_route_for_task__glob(self):
     def test_route_for_task__glob(self):
         route = routes.MapRoute([
         route = routes.MapRoute([
             ('proj.tasks.*', 'routeA'),
             ('proj.tasks.*', 'routeA'),
             ('demoapp.tasks.bar.*', {'exchange': 'routeB'}),
             ('demoapp.tasks.bar.*', {'exchange': 'routeB'}),
         ])
         ])
-        self.assertDictEqual(route('proj.tasks.foo'), {'queue': 'routeA'})
-        self.assertDictEqual(route('demoapp.tasks.bar.moo'), {
-            'exchange': 'routeB',
-        })
-        self.assertIsNone(route('demoapp.foo.bar.moo'))
+        assert route('proj.tasks.foo') == {'queue': 'routeA'}
+        assert route('demoapp.tasks.bar.moo') == {'exchange': 'routeB'}
+        assert route('demoapp.foo.bar.moo') is None
 
 
     def test_expand_route_not_found(self):
     def test_expand_route_not_found(self):
         expand = E(self.app, self.app.amqp.Queues(
         expand = E(self.app, self.app.amqp.Queues(
                    self.app.conf.task_queues, False))
                    self.app.conf.task_queues, False))
         route = routes.MapRoute({'a': {'queue': 'x'}})
         route = routes.MapRoute({'a': {'queue': 'x'}})
-        with self.assertRaises(QueueNotFound):
+        with pytest.raises(QueueNotFound):
             expand(route('a'))
             expand(route('a'))
 
 
 
 
@@ -106,7 +99,7 @@ class test_lookup_route(RouteCase):
 
 
     def test_init_queues(self):
     def test_init_queues(self):
         router = Router(self.app, queues=None)
         router = Router(self.app, queues=None)
-        self.assertDictEqual(router.queues, {})
+        assert router.queues == {}
 
 
     def test_lookup_takes_first(self):
     def test_lookup_takes_first(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
@@ -131,22 +124,22 @@ class test_lookup_route(RouteCase):
             self.mytask.name,
             self.mytask.name,
             args=[1, 2], kwargs={},
             args=[1, 2], kwargs={},
         )
         )
-        self.assertEqual(route['queue'].name, 'testq')
-        self.assertEqual(route['queue'].exchange, Exchange('testq'))
-        self.assertEqual(route['queue'].routing_key, 'testq')
-        self.assertEqual(route['immediate'], False)
+        assert route['queue'].name == 'testq'
+        assert route['queue'].exchange == Exchange('testq')
+        assert route['queue'].routing_key == 'testq'
+        assert route['immediate'] is False
 
 
     def test_expand_destination_string(self):
     def test_expand_destination_string(self):
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         set_queues(self.app, foo=self.a_queue, bar=self.b_queue)
         x = Router(self.app, {}, self.app.amqp.queues)
         x = Router(self.app, {}, self.app.amqp.queues)
         dest = x.expand_destination('foo')
         dest = x.expand_destination('foo')
-        self.assertEqual(dest['queue'].name, 'foo')
+        assert dest['queue'].name == 'foo'
 
 
     def test_expand_destination__Queue(self):
     def test_expand_destination__Queue(self):
         queue = Queue('foo')
         queue = Queue('foo')
         x = Router(self.app, {}, self.app.amqp.queues)
         x = Router(self.app, {}, self.app.amqp.queues)
         dest = x.expand_destination({'queue': queue})
         dest = x.expand_destination({'queue': queue})
-        self.assertIs(dest['queue'], queue)
+        assert dest['queue'] is queue
 
 
     def test_lookup_paths_traversed(self):
     def test_lookup_paths_traversed(self):
         self.simple_queue_setup()
         self.simple_queue_setup()
@@ -179,7 +172,7 @@ class test_lookup_route(RouteCase):
             task=self.mytask,
             task=self.mytask,
         )
         )
         options = step.call_args[0][3]
         options = step.call_args[0][3]
-        self.assertEqual(options['priority'], 3)
+        assert options['priority'] == 3
 
 
     def test_compat_router_classes__called_with(self):
     def test_compat_router_classes__called_with(self):
         self.simple_queue_setup()
         self.simple_queue_setup()
@@ -205,7 +198,7 @@ class TestRouter(object):
             return 'bar'
             return 'bar'
 
 
 
 
-class test_prepare(AppCase):
+class test_prepare:
 
 
     def test_prepare(self):
     def test_prepare(self):
         o = object()
         o = object()
@@ -215,13 +208,13 @@ class test_prepare(AppCase):
             o,
             o,
         ]
         ]
         p = routes.prepare(R)
         p = routes.prepare(R)
-        self.assertIsInstance(p[0], routes.MapRoute)
-        self.assertIsInstance(maybe_evaluate(p[1]), TestRouter)
-        self.assertIs(p[2], o)
+        assert isinstance(p[0], routes.MapRoute)
+        assert isinstance(maybe_evaluate(p[1]), TestRouter)
+        assert p[2] is o
 
 
-        self.assertEqual(routes.prepare(o), [o])
+        assert routes.prepare(o) == [o]
 
 
     def test_prepare_item_is_dict(self):
     def test_prepare_item_is_dict(self):
         R = {'foo': 'bar'}
         R = {'foo': 'bar'}
         p = routes.prepare(R)
         p = routes.prepare(R)
-        self.assertIsInstance(p[0], routes.MapRoute)
+        assert isinstance(p[0], routes.MapRoute)

File diff suppressed because it is too large
+ 276 - 328
t/unit/app/test_schedules.py


+ 16 - 20
celery/tests/app/test_utils.py → t/unit/app/test_utils.py

@@ -2,45 +2,41 @@ from __future__ import absolute_import, unicode_literals
 
 
 from collections import Mapping, MutableMapping
 from collections import Mapping, MutableMapping
 
 
-from celery.app.utils import Settings, filter_hidden_settings, bugreport
+from case import Mock
 
 
-from celery.tests.case import AppCase, Mock
+from celery.app.utils import Settings, filter_hidden_settings, bugreport
 
 
 
 
-class test_Settings(AppCase):
+class test_Settings:
 
 
     def test_is_mapping(self):
     def test_is_mapping(self):
         """Settings should be a collections.Mapping"""
         """Settings should be a collections.Mapping"""
-        self.assertTrue(issubclass(Settings, Mapping))
+        assert issubclass(Settings, Mapping)
 
 
     def test_is_mutable_mapping(self):
     def test_is_mutable_mapping(self):
         """Settings should be a collections.MutableMapping"""
         """Settings should be a collections.MutableMapping"""
-        self.assertTrue(issubclass(Settings, MutableMapping))
+        assert issubclass(Settings, MutableMapping)
 
 
     def test_find(self):
     def test_find(self):
-        self.assertTrue(self.app.conf.find_option('always_eager'))
+        assert self.app.conf.find_option('always_eager')
 
 
     def test_get_by_parts(self):
     def test_get_by_parts(self):
         self.app.conf.task_do_this_and_that = 303
         self.app.conf.task_do_this_and_that = 303
-        self.assertEqual(
-            self.app.conf.get_by_parts('task', 'do', 'this', 'and', 'that'),
-            303,
-        )
+        assert self.app.conf.get_by_parts(
+            'task', 'do', 'this', 'and', 'that') == 303
 
 
     def test_find_value_for_key(self):
     def test_find_value_for_key(self):
-        self.assertEqual(
-            self.app.conf.find_value_for_key('always_eager'),
-            False,
-        )
+        assert self.app.conf.find_value_for_key(
+            'always_eager') is False
 
 
     def test_table(self):
     def test_table(self):
-        self.assertTrue(self.app.conf.table(with_defaults=True))
-        self.assertTrue(self.app.conf.table(with_defaults=False))
-        self.assertTrue(self.app.conf.table(censored=False))
-        self.assertTrue(self.app.conf.table(censored=True))
+        assert self.app.conf.table(with_defaults=True)
+        assert self.app.conf.table(with_defaults=False)
+        assert self.app.conf.table(censored=False)
+        assert self.app.conf.table(censored=True)
 
 
 
 
-class test_filter_hidden_settings(AppCase):
+class test_filter_hidden_settings:
 
 
     def test_handles_non_string_keys(self):
     def test_handles_non_string_keys(self):
         """filter_hidden_settings shouldn't raise an exception when handling
         """filter_hidden_settings shouldn't raise an exception when handling
@@ -56,7 +52,7 @@ class test_filter_hidden_settings(AppCase):
         filter_hidden_settings(conf)
         filter_hidden_settings(conf)
 
 
 
 
-class test_bugreport(AppCase):
+class test_bugreport:
 
 
     def test_no_conn_driver_info(self):
     def test_no_conn_driver_info(self):
         self.app.connection = Mock()
         self.app.connection = Mock()

+ 0 - 0
celery/tests/bin/__init__.py → t/unit/apps/__init__.py


+ 96 - 124
celery/tests/apps/test_multi.py → t/unit/apps/test_multi.py

@@ -1,59 +1,61 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import errno
 import errno
+import pytest
 import signal
 import signal
 import sys
 import sys
 
 
+from case import Mock, call, patch, skip
+
 from celery.apps.multi import (
 from celery.apps.multi import (
     Cluster, MultiParser, NamespacedOptionParser, Node, format_opt,
     Cluster, MultiParser, NamespacedOptionParser, Node, format_opt,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, call, patch, skip
-
 
 
-class test_functions(AppCase):
+class test_functions:
 
 
     def test_parse_ns_range(self):
     def test_parse_ns_range(self):
         m = MultiParser()
         m = MultiParser()
-        self.assertEqual(m._parse_ns_range('1-3', True), ['1', '2', '3'])
-        self.assertEqual(m._parse_ns_range('1-3', False), ['1-3'])
-        self.assertEqual(m._parse_ns_range(
-            '1-3,10,11,20', True),
-            ['1', '2', '3', '10', '11', '20'],
-        )
+        assert m._parse_ns_range('1-3', True), ['1', '2' == '3']
+        assert m._parse_ns_range('1-3', False) == ['1-3']
+        assert m._parse_ns_range('1-3,10,11,20', True) == [
+            '1', '2', '3', '10', '11', '20',
+        ]
 
 
     def test_format_opt(self):
     def test_format_opt(self):
-        self.assertEqual(format_opt('--foo', None), '--foo')
-        self.assertEqual(format_opt('-c', 1), '-c 1')
-        self.assertEqual(format_opt('--log', 'foo'), '--log=foo')
+        assert format_opt('--foo', None) == '--foo'
+        assert format_opt('-c', 1) == '-c 1'
+        assert format_opt('--log', 'foo') == '--log=foo'
 
 
 
 
-class test_NamespacedOptionParser(AppCase):
+class test_NamespacedOptionParser:
 
 
     def test_parse(self):
     def test_parse(self):
         x = NamespacedOptionParser(['-c:1,3', '4'])
         x = NamespacedOptionParser(['-c:1,3', '4'])
         x.parse()
         x.parse()
-        self.assertEqual(x.namespaces.get('1,3'), {'-c': '4'})
+        assert x.namespaces.get('1,3') == {'-c': '4'}
         x = NamespacedOptionParser(['-c:jerry,elaine', '5',
         x = NamespacedOptionParser(['-c:jerry,elaine', '5',
                                     '--loglevel:kramer=DEBUG',
                                     '--loglevel:kramer=DEBUG',
                                     '--flag',
                                     '--flag',
                                     '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                     '--logfile=foo', '-Q', 'bar', 'a', 'b',
                                     '--', '.disable_rate_limits=1'])
                                     '--', '.disable_rate_limits=1'])
         x.parse()
         x.parse()
-        self.assertEqual(x.options, {'--logfile': 'foo',
-                                     '-Q': 'bar',
-                                     '--flag': None})
-        self.assertEqual(x.values, ['a', 'b'])
-        self.assertEqual(x.namespaces.get('jerry,elaine'), {'-c': '5'})
-        self.assertEqual(x.namespaces.get('kramer'), {'--loglevel': 'DEBUG'})
-        self.assertEqual(x.passthrough, '-- .disable_rate_limits=1')
+        assert x.options == {
+            '--logfile': 'foo',
+            '-Q': 'bar',
+            '--flag': None,
+        }
+        assert x.values, ['a' == 'b']
+        assert x.namespaces.get('jerry,elaine') == {'-c': '5'}
+        assert x.namespaces.get('kramer') == {'--loglevel': 'DEBUG'}
+        assert x.passthrough == '-- .disable_rate_limits=1'
 
 
 
 
 def multi_args(p, *args, **kwargs):
 def multi_args(p, *args, **kwargs):
     return MultiParser(*args, **kwargs).parse(p)
     return MultiParser(*args, **kwargs).parse(p)
 
 
 
 
-class test_multi_args(AppCase):
+class test_multi_args:
 
 
     @patch('celery.apps.multi.gethostname')
     @patch('celery.apps.multi.gethostname')
     def test_parse(self, gethostname):
     def test_parse(self, gethostname):
@@ -72,14 +74,14 @@ class test_multi_args(AppCase):
         nodes = list(it)
         nodes = list(it)
 
 
         def assert_line_in(name, args):
         def assert_line_in(name, args):
-            self.assertIn(name, {n.name for n in nodes})
+            assert name in {n.name for n in nodes}
             argv = None
             argv = None
             for node in nodes:
             for node in nodes:
                 if node.name == name:
                 if node.name == name:
                     argv = node.argv
                     argv = node.argv
-            self.assertTrue(argv)
+            assert argv
             for arg in args:
             for arg in args:
-                self.assertIn(arg, argv)
+                assert arg in argv
 
 
         assert_line_in(
         assert_line_in(
             '*P*jerry@*S*',
             '*P*jerry@*S*',
@@ -100,11 +102,11 @@ class test_multi_args(AppCase):
              '-- .disable_rate_limits=1', '*AP*'],
              '-- .disable_rate_limits=1', '*AP*'],
         )
         )
         expand = nodes[0].expander
         expand = nodes[0].expander
-        self.assertEqual(expand('%h'), '*P*jerry@*S*')
-        self.assertEqual(expand('%n'), '*P*jerry')
+        assert expand('%h') == '*P*jerry@*S*'
+        assert expand('%n') == '*P*jerry'
         nodes2 = list(multi_args(p, cmd='COMMAND', append='',
         nodes2 = list(multi_args(p, cmd='COMMAND', append='',
                       prefix='*P*', suffix='*S*'))
                       prefix='*P*', suffix='*S*'))
-        self.assertEqual(nodes2[0].argv[-1], '-- .disable_rate_limits=1')
+        assert nodes2[0].argv[-1] == '-- .disable_rate_limits=1'
 
 
         p2 = NamespacedOptionParser(['10', '-c:1', '5'])
         p2 = NamespacedOptionParser(['10', '-c:1', '5'])
         p2.parse()
         p2.parse()
@@ -118,58 +120,47 @@ class test_multi_args(AppCase):
                 '',
                 '',
             )
             )
 
 
-        self.assertEqual(len(nodes3), 10)
-        self.assertEqual(nodes3[0].name, 'celery1@example.com')
-        self.assertTupleEqual(
-            nodes3[0].argv,
-            ('COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1'),
-        )
+        assert len(nodes3) == 10
+        assert nodes3[0].name == 'celery1@example.com'
+        assert nodes3[0].argv == (
+            'COMMAND', '-c 5', '-n celery1@example.com') + _args('celery1')
         for i, worker in enumerate(nodes3[1:]):
         for i, worker in enumerate(nodes3[1:]):
-            self.assertEqual(worker.name, 'celery%s@example.com' % (i + 2))
-            self.assertTupleEqual(
-                worker.argv,
-                (('COMMAND', '-n celery%s@example.com' % (i + 2)) +
-                 _args('celery%s' % (i + 2))),
-            )
+            assert worker.name == 'celery%s@example.com' % (i + 2)
+            node_i = 'celery%s' % (i + 2,)
+            assert worker.argv == (
+                'COMMAND',
+                '-n %s@example.com' % (node_i,)) + _args(node_i)
 
 
         nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
         nodes4 = list(multi_args(p2, cmd='COMMAND', suffix='""'))
-        self.assertEqual(len(nodes4), 10)
-        self.assertEqual(nodes4[0].name, 'celery1@')
-        self.assertTupleEqual(
-            nodes4[0].argv,
-            ('COMMAND', '-c 5', '-n celery1@') + _args('celery1'),
-        )
+        assert len(nodes4) == 10
+        assert nodes4[0].name == 'celery1@'
+        assert nodes4[0].argv == (
+            'COMMAND', '-c 5', '-n celery1@') + _args('celery1')
 
 
         p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
         p3 = NamespacedOptionParser(['foo@', '-c:foo', '5'])
         p3.parse()
         p3.parse()
         nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
         nodes5 = list(multi_args(p3, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes5[0].name, 'foo@')
-        self.assertTupleEqual(
-            nodes5[0].argv,
-            ('COMMAND', '-c 5', '-n foo@') + _args('foo'),
-        )
+        assert nodes5[0].name == 'foo@'
+        assert nodes5[0].argv == (
+            'COMMAND', '-c 5', '-n foo@') + _args('foo')
 
 
         p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
         p4 = NamespacedOptionParser(['foo', '-Q:1', 'test'])
         p4.parse()
         p4.parse()
         nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
         nodes6 = list(multi_args(p4, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes6[0].name, 'foo@')
-        self.assertTupleEqual(
-            nodes6[0].argv,
-            ('COMMAND', '-Q test', '-n foo@') + _args('foo'),
-        )
+        assert nodes6[0].name == 'foo@'
+        assert nodes6[0].argv == (
+            'COMMAND', '-Q test', '-n foo@') + _args('foo')
 
 
         p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
         p5 = NamespacedOptionParser(['foo@bar', '-Q:1', 'test'])
         p5.parse()
         p5.parse()
         nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
         nodes7 = list(multi_args(p5, cmd='COMMAND', suffix='""'))
-        self.assertEqual(nodes7[0].name, 'foo@bar')
-        self.assertTupleEqual(
-            nodes7[0].argv,
-            ('COMMAND', '-Q test', '-n foo@bar') + _args('foo'),
-        )
+        assert nodes7[0].name == 'foo@bar'
+        assert nodes7[0].argv == (
+            'COMMAND', '-Q test', '-n foo@bar') + _args('foo')
 
 
         p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
         p6 = NamespacedOptionParser(['foo@bar', '-Q:0', 'test'])
         p6.parse()
         p6.parse()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             list(multi_args(p6))
             list(multi_args(p6))
 
 
     def test_optmerge(self):
     def test_optmerge(self):
@@ -177,10 +168,10 @@ class test_multi_args(AppCase):
         p.parse()
         p.parse()
         p.options = {'x': 'y'}
         p.options = {'x': 'y'}
         r = p.optmerge('foo')
         r = p.optmerge('foo')
-        self.assertEqual(r['x'], 'y')
+        assert r['x'] == 'y'
 
 
 
 
-class test_Node(AppCase):
+class test_Node:
 
 
     def setup(self):
     def setup(self):
         self.p = Mock(name='p')
         self.p = Mock(name='p')
@@ -198,7 +189,7 @@ class test_Node(AppCase):
             'foo@bar.com',
             'foo@bar.com',
             max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
             max_tasks_per_child=30, A='foo', Q='q1,q2', O='fair',
         )
         )
-        self.assertItemsEqual(n.argv, (
+        assert sorted(n.argv) == sorted([
             '-m celery worker --detach',
             '-m celery worker --detach',
             '-A foo',
             '-A foo',
             '--executable={0}'.format(n.executable),
             '--executable={0}'.format(n.executable),
@@ -209,31 +200,31 @@ class test_Node(AppCase):
             '--max-tasks-per-child=30',
             '--max-tasks-per-child=30',
             '--pidfile=foo.pid',
             '--pidfile=foo.pid',
             '',
             '',
-        ))
+        ])
 
 
     @patch('os.kill')
     @patch('os.kill')
     def test_send(self, kill):
     def test_send(self, kill):
-        self.assertTrue(self.node.send(9))
+        assert self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
         kill.assert_called_with(self.node.pid, 9)
 
 
     @patch('os.kill')
     @patch('os.kill')
     def test_send__ESRCH(self, kill):
     def test_send__ESRCH(self, kill):
         kill.side_effect = OSError()
         kill.side_effect = OSError()
         kill.side_effect.errno = errno.ESRCH
         kill.side_effect.errno = errno.ESRCH
-        self.assertFalse(self.node.send(9))
+        assert not self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
         kill.assert_called_with(self.node.pid, 9)
 
 
     @patch('os.kill')
     @patch('os.kill')
     def test_send__error(self, kill):
     def test_send__error(self, kill):
         kill.side_effect = OSError()
         kill.side_effect = OSError()
         kill.side_effect.errno = errno.ENOENT
         kill.side_effect.errno = errno.ENOENT
-        with self.assertRaises(OSError):
+        with pytest.raises(OSError):
             self.node.send(9)
             self.node.send(9)
         kill.assert_called_with(self.node.pid, 9)
         kill.assert_called_with(self.node.pid, 9)
 
 
     def test_alive(self):
     def test_alive(self):
         self.node.send = Mock(name='send')
         self.node.send = Mock(name='send')
-        self.assertIs(self.node.alive(), self.node.send.return_value)
+        assert self.node.alive() is self.node.send.return_value
         self.node.send.assert_called_with(0)
         self.node.send.assert_called_with(0)
 
 
     def test_start(self):
     def test_start(self):
@@ -270,40 +261,32 @@ class test_Node(AppCase):
         )
         )
 
 
     def test_handle_process_exit(self):
     def test_handle_process_exit(self):
-        self.assertEqual(
-            self.node.handle_process_exit(0),
-            0,
-        )
+        assert self.node.handle_process_exit(0) == 0
 
 
     def test_handle_process_exit__failure(self):
     def test_handle_process_exit__failure(self):
         on_failure = Mock(name='on_failure')
         on_failure = Mock(name='on_failure')
-        self.assertEqual(
-            self.node.handle_process_exit(9, on_failure=on_failure),
-            9,
-        )
+        assert self.node.handle_process_exit(9, on_failure=on_failure) == 9
         on_failure.assert_called_with(self.node, 9)
         on_failure.assert_called_with(self.node, 9)
 
 
     def test_handle_process_exit__signalled(self):
     def test_handle_process_exit__signalled(self):
         on_signalled = Mock(name='on_signalled')
         on_signalled = Mock(name='on_signalled')
-        self.assertEqual(
-            self.node.handle_process_exit(-9, on_signalled=on_signalled),
-            9,
-        )
+        assert self.node.handle_process_exit(
+            -9, on_signalled=on_signalled) == 9
         on_signalled.assert_called_with(self.node, 9)
         on_signalled.assert_called_with(self.node, 9)
 
 
     def test_logfile(self):
     def test_logfile(self):
-        self.assertEqual(self.node.logfile, self.expander.return_value)
+        assert self.node.logfile == self.expander.return_value
         self.expander.assert_called_with('%n%I.log')
         self.expander.assert_called_with('%n%I.log')
 
 
 
 
-class test_Cluster(AppCase):
+class test_Cluster:
 
 
     def setup(self):
     def setup(self):
-        self.Popen = self.patch('celery.apps.multi.Popen')
-        self.kill = self.patch('os.kill')
-        self.gethostname = self.patch('celery.apps.multi.gethostname')
+        self.Popen = self.patching('celery.apps.multi.Popen')
+        self.kill = self.patching('os.kill')
+        self.gethostname = self.patching('celery.apps.multi.gethostname')
         self.gethostname.return_value = 'example.com'
         self.gethostname.return_value = 'example.com'
-        self.Pidfile = self.patch('celery.apps.multi.Pidfile')
+        self.Pidfile = self.patching('celery.apps.multi.Pidfile')
         self.cluster = Cluster(
         self.cluster = Cluster(
             [Node('foo@example.com'),
             [Node('foo@example.com'),
              Node('bar@example.com'),
              Node('bar@example.com'),
@@ -326,10 +309,10 @@ class test_Cluster(AppCase):
         )
         )
 
 
     def test_len(self):
     def test_len(self):
-        self.assertEqual(len(self.cluster), 3)
+        assert len(self.cluster) == 3
 
 
     def test_getitem(self):
     def test_getitem(self):
-        self.assertEqual(self.cluster[0].name, 'foo@example.com')
+        assert self.cluster[0].name == 'foo@example.com'
 
 
     def test_start(self):
     def test_start(self):
         self.cluster.start_node = Mock(name='start_node')
         self.cluster.start_node = Mock(name='start_node')
@@ -341,10 +324,8 @@ class test_Cluster(AppCase):
     def test_start_node(self):
     def test_start_node(self):
         self.cluster._start_node = Mock(name='_start_node')
         self.cluster._start_node = Mock(name='_start_node')
         node = self.cluster[0]
         node = self.cluster[0]
-        self.assertIs(
-            self.cluster.start_node(node),
-            self.cluster._start_node.return_value,
-        )
+        assert (self.cluster.start_node(node) is
+                self.cluster._start_node.return_value)
         self.cluster.on_node_start.assert_called_with(node)
         self.cluster.on_node_start.assert_called_with(node)
         self.cluster._start_node.assert_called_with(node)
         self.cluster._start_node.assert_called_with(node)
         self.cluster.on_node_status.assert_called_with(
         self.cluster.on_node_status.assert_called_with(
@@ -354,10 +335,7 @@ class test_Cluster(AppCase):
     def test__start_node(self):
     def test__start_node(self):
         node = self.cluster[0]
         node = self.cluster[0]
         node.start = Mock(name='node.start')
         node.start = Mock(name='node.start')
-        self.assertIs(
-            self.cluster._start_node(node),
-            node.start.return_value,
-        )
+        assert self.cluster._start_node(node) is node.start.return_value
         node.start.assert_called_with(
         node.start.assert_called_with(
             self.cluster.env,
             self.cluster.env,
             on_spawn=self.cluster.on_child_spawn,
             on_spawn=self.cluster.on_child_spawn,
@@ -394,33 +372,27 @@ class test_Cluster(AppCase):
         ])
         ])
         nodes = p.getpids(on_down=callback)
         nodes = p.getpids(on_down=callback)
         node_0, node_1 = nodes
         node_0, node_1 = nodes
-        self.assertEqual(node_0.name, 'foo@e.com')
-        self.assertEqual(
-            sorted(node_0.argv),
-            sorted([
-                '',
-                '--executable={0}'.format(node_0.executable),
-                '--logfile=foo%I.log',
-                '--pidfile=foo.pid',
-                '-m celery worker --detach',
-                '-n foo@e.com',
-            ]),
-        )
-        self.assertEqual(node_0.pid, 10)
+        assert node_0.name == 'foo@e.com'
+        assert sorted(node_0.argv) == sorted([
+            '',
+            '--executable={0}'.format(node_0.executable),
+            '--logfile=foo%I.log',
+            '--pidfile=foo.pid',
+            '-m celery worker --detach',
+            '-n foo@e.com',
+        ])
+        assert node_0.pid == 10
 
 
-        self.assertEqual(node_1.name, 'bar@e.com')
-        self.assertEqual(
-            sorted(node_1.argv),
-            sorted([
-                '',
-                '--executable={0}'.format(node_1.executable),
-                '--logfile=bar%I.log',
-                '--pidfile=bar.pid',
-                '-m celery worker --detach',
-                '-n bar@e.com',
-            ]),
-        )
-        self.assertEqual(node_1.pid, 11)
+        assert node_1.name == 'bar@e.com'
+        assert sorted(node_1.argv) == sorted([
+            '',
+            '--executable={0}'.format(node_1.executable),
+            '--logfile=bar%I.log',
+            '--pidfile=bar.pid',
+            '-m celery worker --detach',
+            '-n bar@e.com',
+        ])
+        assert node_1.pid == 11
 
 
         # without callback, should work
         # without callback, should work
         nodes = p.getpids('celery worker')
         nodes = p.getpids('celery worker')

+ 0 - 0
celery/tests/compat_modules/__init__.py → t/unit/backends/__init__.py


+ 40 - 47
celery/tests/backends/test_amqp.py → t/unit/backends/test_amqp.py

@@ -1,11 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import pickle
 import pickle
+import pytest
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 from datetime import timedelta
 from datetime import timedelta
 from pickle import dumps, loads
 from pickle import dumps, loads
 
 
+from case import Mock, mock
 from billiard.einfo import ExceptionInfo
 from billiard.einfo import ExceptionInfo
 
 
 from celery import states
 from celery import states
@@ -14,8 +16,6 @@ from celery.backends.amqp import AMQPBackend
 from celery.five import Empty, Queue, range
 from celery.five import Empty, Queue, range
 from celery.result import AsyncResult
 from celery.result import AsyncResult
 
 
-from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
-
 
 
 class SomeClass(object):
 class SomeClass(object):
 
 
@@ -23,7 +23,7 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
-class test_AMQPBackend(AppCase):
+class test_AMQPBackend:
 
 
     def setup(self):
     def setup(self):
         self.app.conf.result_cache_max = 100
         self.app.conf.result_cache_max = 100
@@ -35,9 +35,8 @@ class test_AMQPBackend(AppCase):
     def test_destination_for(self):
     def test_destination_for(self):
         b = self.create_backend()
         b = self.create_backend()
         request = Mock()
         request = Mock()
-        self.assertTupleEqual(
-            b.destination_for('id', request),
-            (b.rkey('id'), request.correlation_id),
+        assert b.destination_for('id', request) == (
+            b.rkey('id'), request.correlation_id,
         )
         )
 
 
     def test_store_result__no_routing_key(self):
     def test_store_result__no_routing_key(self):
@@ -53,14 +52,14 @@ class test_AMQPBackend(AppCase):
         tid = uuid()
         tid = uuid()
 
 
         tb1.mark_as_done(tid, 42)
         tb1.mark_as_done(tid, 42)
-        self.assertEqual(tb2.get_state(tid), states.SUCCESS)
-        self.assertEqual(tb2.get_result(tid), 42)
-        self.assertTrue(tb2._cache.get(tid))
-        self.assertTrue(tb2.get_result(tid), 42)
+        assert tb2.get_state(tid) == states.SUCCESS
+        assert tb2.get_result(tid) == 42
+        assert tb2._cache.get(tid)
+        assert tb2.get_result(tid), 42
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_pickleable(self):
     def test_pickleable(self):
-        self.assertTrue(loads(dumps(self.create_backend())))
+        assert loads(dumps(self.create_backend()))
 
 
     def test_revive(self):
     def test_revive(self):
         tb = self.create_backend()
         tb = self.create_backend()
@@ -75,8 +74,8 @@ class test_AMQPBackend(AppCase):
         tb1.mark_as_done(tid2, result)
         tb1.mark_as_done(tid2, result)
         # is serialized properly.
         # is serialized properly.
         rindb = tb2.get_result(tid2)
         rindb = tb2.get_result(tid2)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
         tb1 = self.create_backend()
         tb1 = self.create_backend()
@@ -88,27 +87,27 @@ class test_AMQPBackend(AppCase):
         except KeyError as exception:
         except KeyError as exception:
             einfo = ExceptionInfo()
             einfo = ExceptionInfo()
             tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
             tb1.mark_as_failure(tid3, exception, traceback=einfo.traceback)
-            self.assertEqual(tb2.get_state(tid3), states.FAILURE)
-            self.assertIsInstance(tb2.get_result(tid3), KeyError)
-            self.assertEqual(tb2.get_traceback(tid3), einfo.traceback)
+            assert tb2.get_state(tid3) == states.FAILURE
+            assert isinstance(tb2.get_result(tid3), KeyError)
+            assert tb2.get_traceback(tid3) == einfo.traceback
 
 
     def test_repair_uuid(self):
     def test_repair_uuid(self):
         from celery.backends.amqp import repair_uuid
         from celery.backends.amqp import repair_uuid
         for i in range(10):
         for i in range(10):
             tid = uuid()
             tid = uuid()
-            self.assertEqual(repair_uuid(tid.replace('-', '')), tid)
+            assert repair_uuid(tid.replace('-', '')) == tid
 
 
     def test_expires_is_int(self):
     def test_expires_is_int(self):
         b = self.create_backend(expires=48)
         b = self.create_backend(expires=48)
-        self.assertEqual(b.queue_arguments.get('x-expires'), 48 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 48 * 1000.0
 
 
     def test_expires_is_float(self):
     def test_expires_is_float(self):
         b = self.create_backend(expires=48.3)
         b = self.create_backend(expires=48.3)
-        self.assertEqual(b.queue_arguments.get('x-expires'), 48.3 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 48.3 * 1000.0
 
 
     def test_expires_is_timedelta(self):
     def test_expires_is_timedelta(self):
         b = self.create_backend(expires=timedelta(minutes=1))
         b = self.create_backend(expires=timedelta(minutes=1))
-        self.assertEqual(b.queue_arguments.get('x-expires'), 60 * 1000.0)
+        assert b.queue_arguments.get('x-expires') == 60 * 1000.0
 
 
     @mock.sleepdeprived()
     @mock.sleepdeprived()
     def test_store_result_retries(self):
     def test_store_result_retries(self):
@@ -125,22 +124,19 @@ class test_AMQPBackend(AppCase):
         from celery.app.amqp import Producer
         from celery.app.amqp import Producer
         prod, Producer.publish = Producer.publish, publish
         prod, Producer.publish = Producer.publish, publish
         try:
         try:
-            with self.assertRaises(KeyError):
+            with pytest.raises(KeyError):
                 backend.retry_policy['max_retries'] = None
                 backend.retry_policy['max_retries'] = None
                 backend.store_result('foo', 'bar', 'STARTED')
                 backend.store_result('foo', 'bar', 'STARTED')
 
 
-            with self.assertRaises(KeyError):
+            with pytest.raises(KeyError):
                 backend.retry_policy['max_retries'] = 10
                 backend.retry_policy['max_retries'] = 10
                 backend.store_result('foo', 'bar', 'STARTED')
                 backend.store_result('foo', 'bar', 'STARTED')
         finally:
         finally:
             Producer.publish = prod
             Producer.publish = prod
 
 
-    def assertState(self, retval, state):
-        self.assertEqual(retval['status'], state)
-
     def test_poll_no_messages(self):
     def test_poll_no_messages(self):
         b = self.create_backend()
         b = self.create_backend()
-        self.assertState(b.get_task_meta(uuid()), states.PENDING)
+        assert b.get_task_meta(uuid())['status'] == states.PENDING
 
 
     @contextmanager
     @contextmanager
     def _result_context(self):
     def _result_context(self):
@@ -199,7 +195,7 @@ class test_AMQPBackend(AppCase):
         with self._result_context() as (results, backend, Message):
         with self._result_context() as (results, backend, Message):
             for i in range(1001):
             for i in range(1001):
                 results.put(Message(task_id='id', status=states.RECEIVED))
                 results.put(Message(task_id='id', status=states.RECEIVED))
-            with self.assertRaises(backend.BacklogLimitExceeded):
+            with pytest.raises(backend.BacklogLimitExceeded):
                 backend.get_task_meta('id')
                 backend.get_task_meta('id')
 
 
     def test_poll_result(self):
     def test_poll_result(self):
@@ -214,27 +210,24 @@ class test_AMQPBackend(AppCase):
             for state_message in state_messages:
             for state_message in state_messages:
                 results.put(state_message)
                 results.put(state_message)
             r1 = backend.get_task_meta(tid)
             r1 = backend.get_task_meta(tid)
-            self.assertDictContainsSubset(
-                {'status': states.FAILURE, 'seq': 3}, r1,
-                'FFWDs to the last state',
-            )
+            # FFWDs to the last state.
+            assert r1['status'] == states.FAILURE
+            assert r1['seq'] == 3
 
 
             # Caches last known state.
             # Caches last known state.
             tid = uuid()
             tid = uuid()
             results.put(Message(task_id=tid))
             results.put(Message(task_id=tid))
             backend.get_task_meta(tid)
             backend.get_task_meta(tid)
-            self.assertIn(tid, backend._cache, 'Caches last known state')
+            assert tid, backend._cache in 'Caches last known state'
 
 
-            self.assertTrue(state_messages[-1].requeued)
+            assert state_messages[-1].requeued
 
 
             # Returns cache if no new states.
             # Returns cache if no new states.
             results.queue.clear()
             results.queue.clear()
             assert not results.qsize()
             assert not results.qsize()
             backend._cache[tid] = 'hello'
             backend._cache[tid] = 'hello'
-            self.assertEqual(
-                backend.get_task_meta(tid), 'hello',
-                'Returns cache if no new states',
-            )
+            # returns cache if no new states.
+            assert backend.get_task_meta(tid) == 'hello'
 
 
     def test_drain_events_decodes_exceptions_in_meta(self):
     def test_drain_events_decodes_exceptions_in_meta(self):
         tid = uuid()
         tid = uuid()
@@ -242,39 +235,39 @@ class test_AMQPBackend(AppCase):
         b.store_result(tid, RuntimeError('aap'), states.FAILURE)
         b.store_result(tid, RuntimeError('aap'), states.FAILURE)
         result = AsyncResult(tid, backend=b)
         result = AsyncResult(tid, backend=b)
 
 
-        with self.assertRaises(Exception) as cm:
+        with pytest.raises(Exception) as excinfo:
             result.get()
             result.get()
 
 
-        self.assertEqual(cm.exception.__class__.__name__, 'RuntimeError')
-        self.assertEqual(str(cm.exception), 'aap')
+        assert excinfo.value.__class__.__name__ == 'RuntimeError'
+        assert str(excinfo.value) == 'aap'
 
 
     def test_no_expires(self):
     def test_no_expires(self):
         b = self.create_backend(expires=None)
         b = self.create_backend(expires=None)
         app = self.app
         app = self.app
         app.conf.result_expires = None
         app.conf.result_expires = None
         b = self.create_backend(expires=None)
         b = self.create_backend(expires=None)
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             b.queue_arguments['x-expires']
             b.queue_arguments['x-expires']
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         self.create_backend().process_cleanup()
         self.create_backend().process_cleanup()
 
 
     def test_reload_task_result(self):
     def test_reload_task_result(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().reload_task_result('x')
             self.create_backend().reload_task_result('x')
 
 
     def test_reload_group_result(self):
     def test_reload_group_result(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().reload_group_result('x')
             self.create_backend().reload_group_result('x')
 
 
     def test_save_group(self):
     def test_save_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().save_group('x', 'x')
             self.create_backend().save_group('x', 'x')
 
 
     def test_restore_group(self):
     def test_restore_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().restore_group('x')
             self.create_backend().restore_group('x')
 
 
     def test_delete_group(self):
     def test_delete_group(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.create_backend().delete_group('x')
             self.create_backend().delete_group('x')

+ 41 - 0
t/unit/backends/test_backends.py

@@ -0,0 +1,41 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from case import patch
+
+from celery import backends
+from celery.backends.amqp import AMQPBackend
+from celery.backends.cache import CacheBackend
+from celery.exceptions import ImproperlyConfigured
+
+
+class test_backends:
+
+    @pytest.mark.parametrize('url,expect_cls', [
+        ('amqp://', AMQPBackend),
+        ('cache+memory://', CacheBackend),
+    ])
+    def test_get_backend_aliases(self, url, expect_cls, app):
+        backend, url = backends.get_backend_by_url(url, app.loader)
+        assert isinstance(backend(app=app, url=url), expect_cls)
+
+    def test_unknown_backend(self, app):
+        with pytest.raises(ImportError):
+            backends.get_backend_cls('fasodaopjeqijwqe', app.loader)
+
+    @pytest.mark.usefixtures('depends_on_current_app')
+    def test_default_backend(self, app):
+        assert backends.default_backend == app.backend
+
+    def test_backend_by_url(self, app, url='redis://localhost/1'):
+        from celery.backends.redis import RedisBackend
+        backend, url_ = backends.get_backend_by_url(url, app.loader)
+        assert backend is RedisBackend
+        assert url_ == url
+
+    def test_sym_raises_ValuError(self, app):
+        with patch('celery.backends.symbol_by_name') as sbn:
+            sbn.side_effect = ValueError()
+            with pytest.raises(ImproperlyConfigured):
+                backends.get_backend_cls('xxx.xxx:foo', app.loader)

+ 106 - 111
celery/tests/backends/test_base.py → t/unit/backends/test_base.py

@@ -1,17 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
 import sys
 import sys
 import types
 import types
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
-from celery.exceptions import ChordError, TimeoutError
-from celery.five import items, bytes_if_py2, range
-from celery.utils import serialization
-from celery.utils.serialization import subclass_exception
-from celery.utils.serialization import find_pickleable_exception as fnpe
-from celery.utils.serialization import UnpickleableExceptionWrapper
-from celery.utils.serialization import get_pickleable_exception as gpe
+from case import ANY, Mock, call, patch, skip
 
 
 from celery import states
 from celery import states
 from celery import group, uuid
 from celery import group, uuid
@@ -21,10 +16,15 @@ from celery.backends.base import (
     DisabledBackend,
     DisabledBackend,
     _nulldict,
     _nulldict,
 )
 )
+from celery.exceptions import ChordError, TimeoutError
+from celery.five import items, bytes_if_py2, range
 from celery.result import result_from_tuple
 from celery.result import result_from_tuple
+from celery.utils import serialization
 from celery.utils.functional import pass1
 from celery.utils.functional import pass1
-
-from celery.tests.case import ANY, AppCase, Case, Mock, call, patch, skip
+from celery.utils.serialization import subclass_exception
+from celery.utils.serialization import find_pickleable_exception as fnpe
+from celery.utils.serialization import UnpickleableExceptionWrapper
+from celery.utils.serialization import get_pickleable_exception as gpe
 
 
 
 
 class wrapobject(object):
 class wrapobject(object):
@@ -47,7 +47,7 @@ Lookalike = subclass_exception(
 )
 )
 
 
 
 
-class test_nulldict(Case):
+class test_nulldict:
 
 
     def test_nulldict(self):
     def test_nulldict(self):
         x = _nulldict()
         x = _nulldict()
@@ -56,25 +56,24 @@ class test_nulldict(Case):
         x.setdefault('foo', 3)
         x.setdefault('foo', 3)
 
 
 
 
-class test_serialization(AppCase):
+class test_serialization:
 
 
     def test_create_exception_cls(self):
     def test_create_exception_cls(self):
-        self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
-        self.assertTrue(serialization.create_exception_cls('FooError', 'm',
-                                                           KeyError))
+        assert serialization.create_exception_cls('FooError', 'm')
+        assert serialization.create_exception_cls('FooError', 'm', KeyError)
 
 
 
 
-class test_BaseBackend_interface(AppCase):
+class test_BaseBackend_interface:
 
 
     def setup(self):
     def setup(self):
         self.b = BaseBackend(self.app)
         self.b = BaseBackend(self.app)
 
 
     def test__forget(self):
     def test__forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.b._forget('SOMExx-N0Nex1stant-IDxx-')
             self.b._forget('SOMExx-N0Nex1stant-IDxx-')
 
 
     def test_forget(self):
     def test_forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.b.forget('SOMExx-N0nex1stant-IDxx-')
             self.b.forget('SOMExx-N0nex1stant-IDxx-')
 
 
     def test_on_chord_part_return(self):
     def test_on_chord_part_return(self):
@@ -86,29 +85,29 @@ class test_BaseBackend_interface(AppCase):
             group(app=self.app), (), 'dakj221', None,
             group(app=self.app), (), 'dakj221', None,
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
             result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
         )
         )
-        self.assertTrue(self.app.tasks[unlock].apply_async.call_count)
+        assert self.app.tasks[unlock].apply_async.call_count
 
 
 
 
-class test_exception_pickle(AppCase):
+class test_exception_pickle:
 
 
     @skip.if_python3(reason='does not support old style classes')
     @skip.if_python3(reason='does not support old style classes')
     @skip.if_pypy()
     @skip.if_pypy()
     def test_oldstyle(self):
     def test_oldstyle(self):
-        self.assertTrue(fnpe(Oldstyle()))
+        assert fnpe(Oldstyle())
 
 
     def test_BaseException(self):
     def test_BaseException(self):
-        self.assertIsNone(fnpe(Exception()))
+        assert fnpe(Exception()) is None
 
 
     def test_get_pickleable_exception(self):
     def test_get_pickleable_exception(self):
         exc = Exception('foo')
         exc = Exception('foo')
-        self.assertEqual(gpe(exc), exc)
+        assert gpe(exc) == exc
 
 
     def test_unpickleable(self):
     def test_unpickleable(self):
-        self.assertIsInstance(fnpe(Unpickleable()), KeyError)
-        self.assertIsNone(fnpe(Impossible()))
+        assert isinstance(fnpe(Unpickleable()), KeyError)
+        assert fnpe(Impossible()) is None
 
 
 
 
-class test_prepare_exception(AppCase):
+class test_prepare_exception:
 
 
     def setup(self):
     def setup(self):
         self.b = BaseBackend(self.app)
         self.b = BaseBackend(self.app)
@@ -116,28 +115,28 @@ class test_prepare_exception(AppCase):
     def test_unpickleable(self):
     def test_unpickleable(self):
         self.b.serializer = 'pickle'
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
         x = self.b.prepare_exception(Unpickleable(1, 2, 'foo'))
-        self.assertIsInstance(x, KeyError)
+        assert isinstance(x, KeyError)
         y = self.b.exception_to_python(x)
         y = self.b.exception_to_python(x)
-        self.assertIsInstance(y, KeyError)
+        assert isinstance(y, KeyError)
 
 
     def test_impossible(self):
     def test_impossible(self):
         self.b.serializer = 'pickle'
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(Impossible())
         x = self.b.prepare_exception(Impossible())
-        self.assertIsInstance(x, UnpickleableExceptionWrapper)
-        self.assertTrue(str(x))
+        assert isinstance(x, UnpickleableExceptionWrapper)
+        assert str(x)
         y = self.b.exception_to_python(x)
         y = self.b.exception_to_python(x)
-        self.assertEqual(y.__class__.__name__, 'Impossible')
+        assert y.__class__.__name__ == 'Impossible'
         if sys.version_info < (2, 5):
         if sys.version_info < (2, 5):
-            self.assertTrue(y.__class__.__module__)
+            assert y.__class__.__module__
         else:
         else:
-            self.assertEqual(y.__class__.__module__, 'foo.module')
+            assert y.__class__.__module__ == 'foo.module'
 
 
     def test_regular(self):
     def test_regular(self):
         self.b.serializer = 'pickle'
         self.b.serializer = 'pickle'
         x = self.b.prepare_exception(KeyError('baz'))
         x = self.b.prepare_exception(KeyError('baz'))
-        self.assertIsInstance(x, KeyError)
+        assert isinstance(x, KeyError)
         y = self.b.exception_to_python(x)
         y = self.b.exception_to_python(x)
-        self.assertIsInstance(y, KeyError)
+        assert isinstance(y, KeyError)
 
 
 
 
 class KVBackend(KeyValueStoreBackend):
 class KVBackend(KeyValueStoreBackend):
@@ -181,22 +180,22 @@ class DictBackend(BaseBackend):
         self._data.pop(group_id, None)
         self._data.pop(group_id, None)
 
 
 
 
-class test_BaseBackend_dict(AppCase):
+class test_BaseBackend_dict:
 
 
     def setup(self):
     def setup(self):
         self.b = DictBackend(app=self.app)
         self.b = DictBackend(app=self.app)
 
 
     def test_delete_group(self):
     def test_delete_group(self):
         self.b.delete_group('can-delete')
         self.b.delete_group('can-delete')
-        self.assertNotIn('can-delete', self.b._data)
+        assert 'can-delete' not in self.b._data
 
 
     def test_prepare_exception_json(self):
     def test_prepare_exception_json(self):
         x = DictBackend(self.app, serializer='json')
         x = DictBackend(self.app, serializer='json')
         e = x.prepare_exception(KeyError('foo'))
         e = x.prepare_exception(KeyError('foo'))
-        self.assertIn('exc_type', e)
+        assert 'exc_type' in e
         e = x.exception_to_python(e)
         e = x.exception_to_python(e)
-        self.assertEqual(e.__class__.__name__, 'KeyError')
-        self.assertEqual(str(e).strip('u'), "'foo'")
+        assert e.__class__.__name__ == 'KeyError'
+        assert str(e).strip('u') == "'foo'"
 
 
     def test_save_group(self):
     def test_save_group(self):
         b = BaseBackend(self.app)
         b = BaseBackend(self.app)
@@ -206,20 +205,20 @@ class test_BaseBackend_dict(AppCase):
 
 
     def test_add_to_chord_interface(self):
     def test_add_to_chord_interface(self):
         b = BaseBackend(self.app)
         b = BaseBackend(self.app)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             b.add_to_chord('group_id', 'sig')
             b.add_to_chord('group_id', 'sig')
 
 
     def test_forget_interface(self):
     def test_forget_interface(self):
         b = BaseBackend(self.app)
         b = BaseBackend(self.app)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             b.forget('foo')
             b.forget('foo')
 
 
     def test_restore_group(self):
     def test_restore_group(self):
-        self.assertIsNone(self.b.restore_group('missing'))
-        self.assertIsNone(self.b.restore_group('missing'))
-        self.assertEqual(self.b.restore_group('exists'), 'group')
-        self.assertEqual(self.b.restore_group('exists'), 'group')
-        self.assertEqual(self.b.restore_group('exists', cache=False), 'group')
+        assert self.b.restore_group('missing') is None
+        assert self.b.restore_group('missing') is None
+        assert self.b.restore_group('exists') == 'group'
+        assert self.b.restore_group('exists') == 'group'
+        assert self.b.restore_group('exists', cache=False) == 'group'
 
 
     def test_reload_group_result(self):
     def test_reload_group_result(self):
         self.b._cache = {}
         self.b._cache = {}
@@ -239,29 +238,29 @@ class test_BaseBackend_dict(AppCase):
             self.b.fail_from_current_stack('task_id')
             self.b.fail_from_current_stack('task_id')
             self.b.mark_as_failure.assert_called()
             self.b.mark_as_failure.assert_called()
             args = self.b.mark_as_failure.call_args[0]
             args = self.b.mark_as_failure.call_args[0]
-            self.assertEqual(args[0], 'task_id')
-            self.assertIs(args[1], exc)
-            self.assertTrue(args[2])
+            assert args[0] == 'task_id'
+            assert args[1] is exc
+            assert args[2]
 
 
     def test_prepare_value_serializes_group_result(self):
     def test_prepare_value_serializes_group_result(self):
         self.b.serializer = 'json'
         self.b.serializer = 'json'
         g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
         g = self.app.GroupResult('group_id', [self.app.AsyncResult('foo')])
         v = self.b.prepare_value(g)
         v = self.b.prepare_value(g)
-        self.assertIsInstance(v, (list, tuple))
-        self.assertEqual(result_from_tuple(v, app=self.app), g)
+        assert isinstance(v, (list, tuple))
+        assert result_from_tuple(v, app=self.app) == g
 
 
         v2 = self.b.prepare_value(g[0])
         v2 = self.b.prepare_value(g[0])
-        self.assertIsInstance(v2, (list, tuple))
-        self.assertEqual(result_from_tuple(v2, app=self.app), g[0])
+        assert isinstance(v2, (list, tuple))
+        assert result_from_tuple(v2, app=self.app) == g[0]
 
 
         self.b.serializer = 'pickle'
         self.b.serializer = 'pickle'
-        self.assertIsInstance(self.b.prepare_value(g), self.app.GroupResult)
+        assert isinstance(self.b.prepare_value(g), self.app.GroupResult)
 
 
     def test_is_cached(self):
     def test_is_cached(self):
         b = BaseBackend(app=self.app, max_cached_results=1)
         b = BaseBackend(app=self.app, max_cached_results=1)
         b._cache['foo'] = 1
         b._cache['foo'] = 1
-        self.assertTrue(b.is_cached('foo'))
-        self.assertFalse(b.is_cached('false'))
+        assert b.is_cached('foo')
+        assert not b.is_cached('false')
 
 
     def test_mark_as_done__chord(self):
     def test_mark_as_done__chord(self):
         b = BaseBackend(app=self.app)
         b = BaseBackend(app=self.app)
@@ -297,7 +296,7 @@ class test_BaseBackend_dict(AppCase):
         callback.options = {'link_error': []}
         callback.options = {'link_error': []}
         task = self.app.tasks[callback.task] = Mock()
         task = self.app.tasks[callback.task] = Mock()
         b.fail_from_current_stack = Mock()
         b.fail_from_current_stack = Mock()
-        group = self.patch('celery.group')
+        group = self.patching('celery.group')
         group.side_effect = exc
         group.side_effect = exc
         b.chord_error_from_stack(callback, exc=ValueError())
         b.chord_error_from_stack(callback, exc=ValueError())
         task.backend.fail_from_current_stack.assert_called_with(
         task.backend.fail_from_current_stack.assert_called_with(
@@ -305,15 +304,15 @@ class test_BaseBackend_dict(AppCase):
 
 
     def test_exception_to_python_when_None(self):
     def test_exception_to_python_when_None(self):
         b = BaseBackend(app=self.app)
         b = BaseBackend(app=self.app)
-        self.assertIsNone(b.exception_to_python(None))
+        assert b.exception_to_python(None) is None
 
 
     def test_wait_for__on_interval(self):
     def test_wait_for__on_interval(self):
-        self.patch('time.sleep')
+        self.patching('time.sleep')
         b = BaseBackend(app=self.app)
         b = BaseBackend(app=self.app)
         b._get_task_meta_for = Mock()
         b._get_task_meta_for = Mock()
         b._get_task_meta_for.return_value = {'status': states.PENDING}
         b._get_task_meta_for.return_value = {'status': states.PENDING}
         callback = Mock(name='callback')
         callback = Mock(name='callback')
-        with self.assertRaises(TimeoutError):
+        with pytest.raises(TimeoutError):
             b.wait_for(task_id='1', on_interval=callback, timeout=1)
             b.wait_for(task_id='1', on_interval=callback, timeout=1)
         callback.assert_called_with()
         callback.assert_called_with()
 
 
@@ -324,12 +323,12 @@ class test_BaseBackend_dict(AppCase):
         b = BaseBackend(app=self.app)
         b = BaseBackend(app=self.app)
         b._get_task_meta_for = Mock()
         b._get_task_meta_for = Mock()
         b._get_task_meta_for.return_value = {}
         b._get_task_meta_for.return_value = {}
-        self.assertIsNone(b.get_children('id'))
+        assert b.get_children('id') is None
         b._get_task_meta_for.return_value = {'children': 3}
         b._get_task_meta_for.return_value = {'children': 3}
-        self.assertEqual(b.get_children('id'), 3)
+        assert b.get_children('id') == 3
 
 
 
 
-class test_KeyValueStoreBackend(AppCase):
+class test_KeyValueStoreBackend:
 
 
     def setup(self):
     def setup(self):
         self.b = KVBackend(app=self.app)
         self.b = KVBackend(app=self.app)
@@ -341,15 +340,15 @@ class test_KeyValueStoreBackend(AppCase):
     def test_get_store_delete_result(self):
     def test_get_store_delete_result(self):
         tid = uuid()
         tid = uuid()
         self.b.mark_as_done(tid, 'Hello world')
         self.b.mark_as_done(tid, 'Hello world')
-        self.assertEqual(self.b.get_result(tid), 'Hello world')
-        self.assertEqual(self.b.get_state(tid), states.SUCCESS)
+        assert self.b.get_result(tid) == 'Hello world'
+        assert self.b.get_state(tid) == states.SUCCESS
         self.b.forget(tid)
         self.b.forget(tid)
-        self.assertEqual(self.b.get_state(tid), states.PENDING)
+        assert self.b.get_state(tid) == states.PENDING
 
 
     def test_strip_prefix(self):
     def test_strip_prefix(self):
         x = self.b.get_key_for_task('x1b34')
         x = self.b.get_key_for_task('x1b34')
-        self.assertEqual(self.b._strip_prefix(x), 'x1b34')
-        self.assertEqual(self.b._strip_prefix('x1b34'), 'x1b34')
+        assert self.b._strip_prefix(x) == 'x1b34'
+        assert self.b._strip_prefix('x1b34') == 'x1b34'
 
 
     def test_get_many(self):
     def test_get_many(self):
         for is_dict in True, False:
         for is_dict in True, False:
@@ -359,17 +358,17 @@ class test_KeyValueStoreBackend(AppCase):
                 self.b.mark_as_done(id, i)
                 self.b.mark_as_done(id, i)
             it = self.b.get_many(list(ids))
             it = self.b.get_many(list(ids))
             for i, (got_id, got_state) in enumerate(it):
             for i, (got_id, got_state) in enumerate(it):
-                self.assertEqual(got_state['result'], ids[got_id])
-            self.assertEqual(i, 9)
-            self.assertTrue(list(self.b.get_many(list(ids))))
+                assert got_state['result'] == ids[got_id]
+            assert i == 9
+            assert list(self.b.get_many(list(ids)))
 
 
             self.b._cache.clear()
             self.b._cache.clear()
             callback = Mock(name='callback')
             callback = Mock(name='callback')
             it = self.b.get_many(list(ids), on_message=callback)
             it = self.b.get_many(list(ids), on_message=callback)
             for i, (got_id, got_state) in enumerate(it):
             for i, (got_id, got_state) in enumerate(it):
-                self.assertEqual(got_state['result'], ids[got_id])
-            self.assertEqual(i, 9)
-            self.assertTrue(list(self.b.get_many(list(ids))))
+                assert got_state['result'] == ids[got_id]
+            assert i == 9
+            assert list(self.b.get_many(list(ids)))
             callback.assert_has_calls([
             callback.assert_has_calls([
                 call(ANY) for id in ids
                 call(ANY) for id in ids
             ])
             ])
@@ -377,7 +376,7 @@ class test_KeyValueStoreBackend(AppCase):
     def test_get_many_times_out(self):
     def test_get_many_times_out(self):
         tasks = [uuid() for _ in range(4)]
         tasks = [uuid() for _ in range(4)]
         self.b._cache[tasks[1]] = {'status': 'PENDING'}
         self.b._cache[tasks[1]] = {'status': 'PENDING'}
-        with self.assertRaises(self.b.TimeoutError):
+        with pytest.raises(self.b.TimeoutError):
             list(self.b.get_many(tasks, timeout=0.01, interval=0.01))
             list(self.b.get_many(tasks, timeout=0.01, interval=0.01))
 
 
     def test_chord_part_return_no_gid(self):
     def test_chord_part_return_no_gid(self):
@@ -390,9 +389,8 @@ class test_KeyValueStoreBackend(AppCase):
         self.b.get_key_for_chord.side_effect = AssertionError(
         self.b.get_key_for_chord.side_effect = AssertionError(
             'should not get here',
             'should not get here',
         )
         )
-        self.assertIsNone(
-            self.b.on_chord_part_return(task.request, state, result),
-        )
+        assert self.b.on_chord_part_return(
+            task.request, state, result) is None
 
 
     @patch('celery.backends.base.GroupResult')
     @patch('celery.backends.base.GroupResult')
     @patch('celery.backends.base.maybe_signature')
     @patch('celery.backends.base.maybe_signature')
@@ -429,14 +427,11 @@ class test_KeyValueStoreBackend(AppCase):
     def test_filter_ready(self):
     def test_filter_ready(self):
         self.b.decode_result = Mock()
         self.b.decode_result = Mock()
         self.b.decode_result.side_effect = pass1
         self.b.decode_result.side_effect = pass1
-        self.assertEqual(
-            len(list(self.b._filter_ready([
-                (1, {'status': states.RETRY}),
-                (2, {'status': states.FAILURE}),
-                (3, {'status': states.SUCCESS}),
-            ]))),
-            2,
-        )
+        assert len(list(self.b._filter_ready([
+            (1, {'status': states.RETRY}),
+            (2, {'status': states.FAILURE}),
+            (3, {'status': states.SUCCESS}),
+        ]))) == 2
 
 
     @contextmanager
     @contextmanager
     def _chord_part_context(self, b):
     def _chord_part_context(self, b):
@@ -484,8 +479,8 @@ class test_KeyValueStoreBackend(AppCase):
             self.b.fail_from_current_stack.assert_called()
             self.b.fail_from_current_stack.assert_called()
             args = self.b.fail_from_current_stack.call_args
             args = self.b.fail_from_current_stack.call_args
             exc = args[1]['exc']
             exc = args[1]['exc']
-            self.assertIsInstance(exc, ChordError)
-            self.assertIn('foo', str(exc))
+            assert isinstance(exc, ChordError)
+            assert 'foo' in str(exc)
 
 
     def test_chord_part_return_join_raises_task(self):
     def test_chord_part_return_join_raises_task(self):
         b = KVBackend(serializer='pickle', app=self.app)
         b = KVBackend(serializer='pickle', app=self.app)
@@ -498,8 +493,8 @@ class test_KeyValueStoreBackend(AppCase):
             b.fail_from_current_stack.assert_called()
             b.fail_from_current_stack.assert_called()
             args = b.fail_from_current_stack.call_args
             args = b.fail_from_current_stack.call_args
             exc = args[1]['exc']
             exc = args[1]['exc']
-            self.assertIsInstance(exc, ChordError)
-            self.assertIn('Dependency culprit raised', str(exc))
+            assert isinstance(exc, ChordError)
+            assert 'Dependency culprit raised' in str(exc)
 
 
     def test_restore_group_from_json(self):
     def test_restore_group_from_json(self):
         b = KVBackend(serializer='json', app=self.app)
         b = KVBackend(serializer='json', app=self.app)
@@ -509,7 +504,7 @@ class test_KeyValueStoreBackend(AppCase):
         )
         )
         b._save_group(g.id, g)
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         g2 = b._restore_group(g.id)['result']
-        self.assertEqual(g2, g)
+        assert g2 == g
 
 
     def test_restore_group_from_pickle(self):
     def test_restore_group_from_pickle(self):
         b = KVBackend(serializer='pickle', app=self.app)
         b = KVBackend(serializer='pickle', app=self.app)
@@ -519,7 +514,7 @@ class test_KeyValueStoreBackend(AppCase):
         )
         )
         b._save_group(g.id, g)
         b._save_group(g.id, g)
         g2 = b._restore_group(g.id)['result']
         g2 = b._restore_group(g.id)['result']
-        self.assertEqual(g2, g)
+        assert g2 == g
 
 
     def test_chord_apply_fallback(self):
     def test_chord_apply_fallback(self):
         self.b.implements_incr = False
         self.b.implements_incr = False
@@ -533,8 +528,8 @@ class test_KeyValueStoreBackend(AppCase):
         )
         )
 
 
     def test_get_missing_meta(self):
     def test_get_missing_meta(self):
-        self.assertIsNone(self.b.get_result('xxx-missing'))
-        self.assertEqual(self.b.get_state('xxx-missing'), states.PENDING)
+        assert self.b.get_result('xxx-missing') is None
+        assert self.b.get_state('xxx-missing') == states.PENDING
 
 
     def test_save_restore_delete_group(self):
     def test_save_restore_delete_group(self):
         tid = uuid()
         tid = uuid()
@@ -543,58 +538,58 @@ class test_KeyValueStoreBackend(AppCase):
         )
         )
         self.b.save_group(tid, tsr)
         self.b.save_group(tid, tsr)
         self.b.restore_group(tid)
         self.b.restore_group(tid)
-        self.assertEqual(self.b.restore_group(tid), tsr)
+        assert self.b.restore_group(tid) == tsr
         self.b.delete_group(tid)
         self.b.delete_group(tid)
-        self.assertIsNone(self.b.restore_group(tid))
+        assert self.b.restore_group(tid) is None
 
 
     def test_restore_missing_group(self):
     def test_restore_missing_group(self):
-        self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
+        assert self.b.restore_group('xxx-nonexistant') is None
 
 
 
 
-class test_KeyValueStoreBackend_interface(AppCase):
+class test_KeyValueStoreBackend_interface:
 
 
     def test_get(self):
     def test_get(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).get('a')
             KeyValueStoreBackend(self.app).get('a')
 
 
     def test_set(self):
     def test_set(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).set('a', 1)
             KeyValueStoreBackend(self.app).set('a', 1)
 
 
     def test_incr(self):
     def test_incr(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).incr('a')
             KeyValueStoreBackend(self.app).incr('a')
 
 
     def test_cleanup(self):
     def test_cleanup(self):
-        self.assertFalse(KeyValueStoreBackend(self.app).cleanup())
+        assert not KeyValueStoreBackend(self.app).cleanup()
 
 
     def test_delete(self):
     def test_delete(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).delete('a')
             KeyValueStoreBackend(self.app).delete('a')
 
 
     def test_mget(self):
     def test_mget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).mget(['a'])
             KeyValueStoreBackend(self.app).mget(['a'])
 
 
     def test_forget(self):
     def test_forget(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             KeyValueStoreBackend(self.app).forget('a')
             KeyValueStoreBackend(self.app).forget('a')
 
 
 
 
-class test_DisabledBackend(AppCase):
+class test_DisabledBackend:
 
 
     def test_store_result(self):
     def test_store_result(self):
         DisabledBackend(self.app).store_result()
         DisabledBackend(self.app).store_result()
 
 
     def test_is_disabled(self):
     def test_is_disabled(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             DisabledBackend(self.app).get_state('foo')
             DisabledBackend(self.app).get_state('foo')
 
 
     def test_as_uri(self):
     def test_as_uri(self):
-        self.assertEqual(DisabledBackend(self.app).as_uri(), 'disabled://')
+        assert DisabledBackend(self.app).as_uri() == 'disabled://'
 
 
 
 
-class test_as_uri(AppCase):
+class test_as_uri:
 
 
     def setup(self):
     def setup(self):
         self.b = BaseBackend(
         self.b = BaseBackend(
@@ -603,7 +598,7 @@ class test_as_uri(AppCase):
         )
         )
 
 
     def test_as_uri_include_password(self):
     def test_as_uri_include_password(self):
-        self.assertEqual(self.b.as_uri(True), self.b.url)
+        assert self.b.as_uri(True) == self.b.url
 
 
     def test_as_uri_exclude_password(self):
     def test_as_uri_exclude_password(self):
-        self.assertEqual(self.b.as_uri(), 'sch://uuuu:**@hostname.dom/')
+        assert self.b.as_uri() == 'sch://uuuu:**@hostname.dom/'

+ 35 - 37
celery/tests/backends/test_cache.py → t/unit/backends/test_cache.py

@@ -1,10 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
 import sys
 import sys
 import types
 import types
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
+from case import Mock, mock, patch, skip
 from kombu.utils.encoding import str_to_bytes, ensure_bytes
 from kombu.utils.encoding import str_to_bytes, ensure_bytes
 
 
 from celery import states
 from celery import states
@@ -13,8 +15,6 @@ from celery.backends.cache import CacheBackend, DummyClient, backends
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.five import items, bytes_if_py2, string, text_t
 from celery.five import items, bytes_if_py2, string, text_t
 
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 PY3 = sys.version_info[0] == 3
 PY3 = sys.version_info[0] == 3
 
 
 
 
@@ -24,7 +24,7 @@ class SomeClass(object):
         self.data = data
         self.data = data
 
 
 
 
-class test_CacheBackend(AppCase):
+class test_CacheBackend:
 
 
     def setup(self):
     def setup(self):
         self.app.conf.result_serializer = 'pickle'
         self.app.conf.result_serializer = 'pickle'
@@ -38,32 +38,32 @@ class test_CacheBackend(AppCase):
 
 
     def test_no_backend(self):
     def test_no_backend(self):
         self.app.conf.cache_backend = None
         self.app.conf.cache_backend = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CacheBackend(backend=None, app=self.app)
             CacheBackend(backend=None, app=self.app)
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
-        self.assertEqual(self.tb.get_state(self.tid), states.PENDING)
-        self.assertIsNone(self.tb.get_result(self.tid))
+        assert self.tb.get_state(self.tid) == states.PENDING
+        assert self.tb.get_result(self.tid) is None
 
 
         self.tb.mark_as_done(self.tid, 42)
         self.tb.mark_as_done(self.tid, 42)
-        self.assertEqual(self.tb.get_state(self.tid), states.SUCCESS)
-        self.assertEqual(self.tb.get_result(self.tid), 42)
+        assert self.tb.get_state(self.tid) == states.SUCCESS
+        assert self.tb.get_result(self.tid) == 42
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
         result = {'foo': 'baz', 'bar': SomeClass(12345)}
         self.tb.mark_as_done(self.tid, result)
         self.tb.mark_as_done(self.tid, result)
         # is serialized properly.
         # is serialized properly.
         rindb = self.tb.get_result(self.tid)
         rindb = self.tb.get_result(self.tid)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
         try:
         try:
             raise KeyError('foo')
             raise KeyError('foo')
         except KeyError as exception:
         except KeyError as exception:
             self.tb.mark_as_failure(self.tid, exception)
             self.tb.mark_as_failure(self.tid, exception)
-            self.assertEqual(self.tb.get_state(self.tid), states.FAILURE)
-            self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
+            assert self.tb.get_state(self.tid) == states.FAILURE
+            assert isinstance(self.tb.get_result(self.tid), KeyError)
 
 
     def test_apply_chord(self):
     def test_apply_chord(self):
         tb = CacheBackend(backend='memory://', app=self.app)
         tb = CacheBackend(backend='memory://', app=self.app)
@@ -99,48 +99,47 @@ class test_CacheBackend(AppCase):
         self.tb.set('foo', 1)
         self.tb.set('foo', 1)
         self.tb.set('bar', 2)
         self.tb.set('bar', 2)
 
 
-        self.assertDictEqual(self.tb.mget(['foo', 'bar']),
-                             {'foo': 1, 'bar': 2})
+        assert self.tb.mget(['foo', 'bar']) == {'foo': 1, 'bar': 2}
 
 
     def test_forget(self):
     def test_forget(self):
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
         self.tb.mark_as_done(self.tid, {'foo': 'bar'})
         x = self.app.AsyncResult(self.tid, backend=self.tb)
         x = self.app.AsyncResult(self.tid, backend=self.tb)
         x.forget()
         x.forget()
-        self.assertIsNone(x.result)
+        assert x.result is None
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         self.tb.process_cleanup()
         self.tb.process_cleanup()
 
 
     def test_expires_as_int(self):
     def test_expires_as_int(self):
         tb = CacheBackend(backend='memory://', expires=10, app=self.app)
         tb = CacheBackend(backend='memory://', expires=10, app=self.app)
-        self.assertEqual(tb.expires, 10)
+        assert tb.expires == 10
 
 
     def test_unknown_backend_raises_ImproperlyConfigured(self):
     def test_unknown_backend_raises_ImproperlyConfigured(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CacheBackend(backend='unknown://', app=self.app)
             CacheBackend(backend='unknown://', app=self.app)
 
 
     def test_as_uri_no_servers(self):
     def test_as_uri_no_servers(self):
-        self.assertEqual(self.tb.as_uri(), 'memory:///')
+        assert self.tb.as_uri() == 'memory:///'
 
 
     def test_as_uri_one_server(self):
     def test_as_uri_one_server(self):
         backend = 'memcache://127.0.0.1:11211/'
         backend = 'memcache://127.0.0.1:11211/'
         b = CacheBackend(backend=backend, app=self.app)
         b = CacheBackend(backend=backend, app=self.app)
-        self.assertEqual(b.as_uri(), backend)
+        assert b.as_uri() == backend
 
 
     def test_as_uri_multiple_servers(self):
     def test_as_uri_multiple_servers(self):
         backend = 'memcache://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         backend = 'memcache://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         b = CacheBackend(backend=backend, app=self.app)
         b = CacheBackend(backend=backend, app=self.app)
-        self.assertEqual(b.as_uri(), backend)
+        assert b.as_uri() == backend
 
 
-    @mock.stdouts
     @skip.unless_module('memcached', name='python-memcached')
     @skip.unless_module('memcached', name='python-memcached')
-    def test_regression_worker_startup_info(self, stdout, stderr):
+    def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
             'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
         )
         )
         worker = self.app.Worker()
         worker = self.app.Worker()
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
+        with mock.stdouts():
+            worker.on_start()
+            assert worker.startup_info()
 
 
 
 
 class MyMemcachedStringEncodingError(Exception):
 class MyMemcachedStringEncodingError(Exception):
@@ -190,15 +189,14 @@ class MockCacheMixin(object):
                 sys.modules['pylibmc'] = prev
                 sys.modules['pylibmc'] = prev
 
 
 
 
-class test_get_best_memcache(AppCase, MockCacheMixin):
+class test_get_best_memcache(MockCacheMixin):
 
 
     def test_pylibmc(self):
     def test_pylibmc(self):
         with self.mock_pylibmc():
         with self.mock_pylibmc():
             with mock.reset_modules('celery.backends.cache'):
             with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
-                self.assertEqual(cache.get_best_memcache()[0].__module__,
-                                 'pylibmc')
+                assert cache.get_best_memcache()[0].__module__ == 'pylibmc'
 
 
     def test_memcache(self):
     def test_memcache(self):
         with self.mock_memcache():
         with self.mock_memcache():
@@ -206,15 +204,15 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
                 with mock.mask_modules('pylibmc'):
                 with mock.mask_modules('pylibmc'):
                     from celery.backends import cache
                     from celery.backends import cache
                     cache._imp = [None]
                     cache._imp = [None]
-                    self.assertEqual(cache.get_best_memcache()[0]().__module__,
-                                     'memcache')
+                    assert (cache.get_best_memcache()[0]().__module__ ==
+                            'memcache')
 
 
     def test_no_implementations(self):
     def test_no_implementations(self):
         with mock.mask_modules('pylibmc', 'memcache'):
         with mock.mask_modules('pylibmc', 'memcache'):
             with mock.reset_modules('celery.backends.cache'):
             with mock.reset_modules('celery.backends.cache'):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
-                with self.assertRaises(ImproperlyConfigured):
+                with pytest.raises(ImproperlyConfigured):
                     cache.get_best_memcache()
                     cache.get_best_memcache()
 
 
     def test_cached(self):
     def test_cached(self):
@@ -223,17 +221,17 @@ class test_get_best_memcache(AppCase, MockCacheMixin):
                 from celery.backends import cache
                 from celery.backends import cache
                 cache._imp = [None]
                 cache._imp = [None]
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
                 cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
-                self.assertTrue(cache._imp[0])
+                assert cache._imp[0]
                 cache.get_best_memcache()[0]()
                 cache.get_best_memcache()[0]()
 
 
     def test_backends(self):
     def test_backends(self):
         from celery.backends.cache import backends
         from celery.backends.cache import backends
         with self.mock_memcache():
         with self.mock_memcache():
             for name, fun in items(backends):
             for name, fun in items(backends):
-                self.assertTrue(fun())
+                assert fun()
 
 
 
 
-class test_memcache_key(AppCase, MockCacheMixin):
+class test_memcache_key(MockCacheMixin):
 
 
     def test_memcache_unicode_key(self):
     def test_memcache_unicode_key(self):
         with self.mock_memcache():
         with self.mock_memcache():
@@ -244,7 +242,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     task_id, result = string(uuid()), 42
                     task_id, result = string(uuid()), 42
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, state=states.SUCCESS)
                     b.store_result(task_id, result, state=states.SUCCESS)
-                    self.assertEqual(b.get_result(task_id), result)
+                    assert b.get_result(task_id) == result
 
 
     def test_memcache_bytes_key(self):
     def test_memcache_bytes_key(self):
         with self.mock_memcache():
         with self.mock_memcache():
@@ -255,7 +253,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                     task_id, result = str_to_bytes(uuid()), 42
                     task_id, result = str_to_bytes(uuid()), 42
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b = cache.CacheBackend(backend='memcache', app=self.app)
                     b.store_result(task_id, result, state=states.SUCCESS)
                     b.store_result(task_id, result, state=states.SUCCESS)
-                    self.assertEqual(b.get_result(task_id), result)
+                    assert b.get_result(task_id) == result
 
 
     def test_pylibmc_unicode_key(self):
     def test_pylibmc_unicode_key(self):
         with mock.reset_modules('celery.backends.cache'):
         with mock.reset_modules('celery.backends.cache'):
@@ -265,7 +263,7 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 task_id, result = string(uuid()), 42
                 task_id, result = string(uuid()), 42
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, state=states.SUCCESS)
                 b.store_result(task_id, result, state=states.SUCCESS)
-                self.assertEqual(b.get_result(task_id), result)
+                assert b.get_result(task_id) == result
 
 
     def test_pylibmc_bytes_key(self):
     def test_pylibmc_bytes_key(self):
         with mock.reset_modules('celery.backends.cache'):
         with mock.reset_modules('celery.backends.cache'):
@@ -275,4 +273,4 @@ class test_memcache_key(AppCase, MockCacheMixin):
                 task_id, result = str_to_bytes(uuid()), 42
                 task_id, result = str_to_bytes(uuid()), 42
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b = cache.CacheBackend(backend='memcache', app=self.app)
                 b.store_result(task_id, result, state=states.SUCCESS)
                 b.store_result(task_id, result, state=states.SUCCESS)
-                self.assertEqual(b.get_result(task_id), result)
+                assert b.get_result(task_id) == result

+ 18 - 15
celery/tests/backends/test_cassandra.py → t/unit/backends/test_cassandra.py

@@ -1,18 +1,21 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from pickle import loads, dumps
 from pickle import loads, dumps
 from datetime import datetime
 from datetime import datetime
 
 
+from case import Mock, mock
+
 from celery import states
 from celery import states
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
-from celery.tests.case import AppCase, Mock, depends_on_current_app, mock
 
 
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 CASSANDRA_MODULES = ['cassandra', 'cassandra.auth', 'cassandra.cluster']
 
 
 
 
 @mock.module(*CASSANDRA_MODULES)
 @mock.module(*CASSANDRA_MODULES)
-class test_CassandraBackend(AppCase):
+class test_CassandraBackend:
 
 
     def setup(self):
     def setup(self):
         self.app.conf.update(
         self.app.conf.update(
@@ -27,7 +30,7 @@ class test_CassandraBackend(AppCase):
         from celery.backends import cassandra as mod
         from celery.backends import cassandra as mod
         prev, mod.cassandra = mod.cassandra, None
         prev, mod.cassandra = mod.cassandra, None
         try:
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 mod.CassandraBackend(app=self.app)
                 mod.CassandraBackend(app=self.app)
         finally:
         finally:
             mod.cassandra = prev
             mod.cassandra = prev
@@ -48,16 +51,16 @@ class test_CassandraBackend(AppCase):
         mod.CassandraBackend(app=self.app)
         mod.CassandraBackend(app=self.app)
 
 
         # no servers raises ImproperlyConfigured
         # no servers raises ImproperlyConfigured
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             self.app.conf.cassandra_servers = None
             self.app.conf.cassandra_servers = None
             mod.CassandraBackend(
             mod.CassandraBackend(
                 app=self.app, keyspace='b', column_family='c',
                 app=self.app, keyspace='b', column_family='c',
             )
             )
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self, *modules):
     def test_reduce(self, *modules):
         from celery.backends.cassandra import CassandraBackend
         from celery.backends.cassandra import CassandraBackend
-        self.assertTrue(loads(dumps(CassandraBackend(app=self.app))))
+        assert loads(dumps(CassandraBackend(app=self.app)))
 
 
     def test_get_task_meta_for(self, *modules):
     def test_get_task_meta_for(self, *modules):
         from celery.backends import cassandra as mod
         from celery.backends import cassandra as mod
@@ -72,11 +75,11 @@ class test_CassandraBackend(AppCase):
         ]
         ]
         x.decode = Mock()
         x.decode = Mock()
         meta = x._get_task_meta_for('task_id')
         meta = x._get_task_meta_for('task_id')
-        self.assertEqual(meta['status'], states.SUCCESS)
+        assert meta['status'] == states.SUCCESS
 
 
         x._session.execute.return_value = []
         x._session.execute.return_value = []
         meta = x._get_task_meta_for('task_id')
         meta = x._get_task_meta_for('task_id')
-        self.assertEqual(meta['status'], states.PENDING)
+        assert meta['status'] == states.PENDING
 
 
     def test_store_result(self, *modules):
     def test_store_result(self, *modules):
         from celery.backends import cassandra as mod
         from celery.backends import cassandra as mod
@@ -93,8 +96,8 @@ class test_CassandraBackend(AppCase):
         x = mod.CassandraBackend(app=self.app)
         x = mod.CassandraBackend(app=self.app)
         x.process_cleanup()
         x.process_cleanup()
 
 
-        self.assertIsNone(x._connection)
-        self.assertIsNone(x._session)
+        assert x._connection is None
+        assert x._session is None
 
 
     def test_timeouting_cluster(self):
     def test_timeouting_cluster(self):
         # Tests behavior when Cluster.connect raises
         # Tests behavior when Cluster.connect raises
@@ -121,10 +124,10 @@ class test_CassandraBackend(AppCase):
 
 
         x = mod.CassandraBackend(app=self.app)
         x = mod.CassandraBackend(app=self.app)
 
 
-        with self.assertRaises(OTOExc):
+        with pytest.raises(OTOExc):
             x._store_result('task_id', 'result', states.SUCCESS)
             x._store_result('task_id', 'result', states.SUCCESS)
-        self.assertIsNone(x._connection)
-        self.assertIsNone(x._session)
+        assert x._connection is None
+        assert x._session is None
 
 
         x.process_cleanup()  # shouldn't raise
         x.process_cleanup()  # shouldn't raise
 
 
@@ -156,7 +159,7 @@ class test_CassandraBackend(AppCase):
             x._store_result('task_id', 'result', states.SUCCESS)
             x._store_result('task_id', 'result', states.SUCCESS)
             x.process_cleanup()
             x.process_cleanup()
 
 
-        self.assertEquals(RAMHoggingCluster.objects_alive, 0)
+        assert RAMHoggingCluster.objects_alive == 0
 
 
     def test_auth_provider(self):
     def test_auth_provider(self):
         # Ensure valid auth_provider works properly, and invalid one raises
         # Ensure valid auth_provider works properly, and invalid one raises
@@ -181,5 +184,5 @@ class test_CassandraBackend(AppCase):
         self.app.conf.cassandra_auth_kwargs = {
         self.app.conf.cassandra_auth_kwargs = {
             'username': 'Jack'
             'username': 'Jack'
         }
         }
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             mod.CassandraBackend(app=self.app)
             mod.CassandraBackend(app=self.app)

+ 6 - 5
celery/tests/backends/test_consul.py → t/unit/backends/test_consul.py

@@ -1,25 +1,26 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from celery.tests.case import AppCase, Mock, skip
+from case import Mock, skip
+
 from celery.backends.consul import ConsulBackend
 from celery.backends.consul import ConsulBackend
 
 
 
 
 @skip.unless_module('consul')
 @skip.unless_module('consul')
-class test_ConsulBackend(AppCase):
+class test_ConsulBackend:
 
 
     def setup(self):
     def setup(self):
         self.backend = ConsulBackend(
         self.backend = ConsulBackend(
             app=self.app, url='consul://localhost:800')
             app=self.app, url='consul://localhost:800')
 
 
     def test_supports_autoexpire(self):
     def test_supports_autoexpire(self):
-        self.assertTrue(self.backend.supports_autoexpire)
+        assert self.backend.supports_autoexpire
 
 
     def test_consul_consistency(self):
     def test_consul_consistency(self):
-        self.assertEqual('consistent', self.backend.consistency)
+        assert self.backend.consistency == 'consistent'
 
 
     def test_get(self):
     def test_get(self):
         index = 100
         index = 100
         data = {'Key': 'test-consul-1', 'Value': 'mypayload'}
         data = {'Key': 'test-consul-1', 'Value': 'mypayload'}
         self.backend.client = Mock(name='c.client')
         self.backend.client = Mock(name='c.client')
         self.backend.client.kv.get.return_value = (index, data)
         self.backend.client.kv.get.return_value = (index, data)
-        self.assertEqual(self.backend.get(data['Key']), 'mypayload')
+        assert self.backend.get(data['Key']) == 'mypayload'

+ 24 - 22
celery/tests/backends/test_couchbase.py → t/unit/backends/test_couchbase.py

@@ -1,14 +1,16 @@
 """Tests for the CouchbaseBackend."""
 """Tests for the CouchbaseBackend."""
-
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from kombu.utils.encoding import str_t
 from kombu.utils.encoding import str_t
 
 
+from case import MagicMock, Mock, patch, sentinel, skip
+
 from celery.backends import couchbase as module
 from celery.backends import couchbase as module
 from celery.backends.couchbase import CouchbaseBackend
 from celery.backends.couchbase import CouchbaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
-from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 
 try:
 try:
     import couchbase
     import couchbase
@@ -19,7 +21,7 @@ COUCHBASE_BUCKET = 'celery_bucket'
 
 
 
 
 @skip.unless_module('couchbase')
 @skip.unless_module('couchbase')
-class test_CouchbaseBackend(AppCase):
+class test_CouchbaseBackend:
 
 
     def setup(self):
     def setup(self):
         self.backend = CouchbaseBackend(app=self.app)
         self.backend = CouchbaseBackend(app=self.app)
@@ -27,14 +29,14 @@ class test_CouchbaseBackend(AppCase):
     def test_init_no_couchbase(self):
     def test_init_no_couchbase(self):
         prev, module.Couchbase = module.Couchbase, None
         prev, module.Couchbase = module.Couchbase, None
         try:
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 CouchbaseBackend(app=self.app)
                 CouchbaseBackend(app=self.app)
         finally:
         finally:
             module.Couchbase = prev
             module.Couchbase = prev
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
         self.app.conf.couchbase_backend_settings = []
         self.app.conf.couchbase_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             CouchbaseBackend(app=self.app)
             CouchbaseBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
@@ -47,7 +49,7 @@ class test_CouchbaseBackend(AppCase):
 
 
             connection = self.backend._get_connection()
             connection = self.backend._get_connection()
 
 
-            self.assertEqual(sentinel._connection, connection)
+            assert sentinel._connection == connection
             mock_Connection.assert_not_called()
             mock_Connection.assert_not_called()
 
 
     def test_get(self):
     def test_get(self):
@@ -57,7 +59,7 @@ class test_CouchbaseBackend(AppCase):
         mocked_get = x._connection.get = Mock()
         mocked_get = x._connection.get = Mock()
         mocked_get.return_value.value = sentinel.retval
         mocked_get.return_value.value = sentinel.retval
         # should return None
         # should return None
-        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        assert x.get('1f3fab') == sentinel.retval
         x._connection.get.assert_called_once_with('1f3fab')
         x._connection.get.assert_called_once_with('1f3fab')
 
 
     def test_set(self):
     def test_set(self):
@@ -66,7 +68,7 @@ class test_CouchbaseBackend(AppCase):
         x._connection = MagicMock()
         x._connection = MagicMock()
         x._connection.set = MagicMock()
         x._connection.set = MagicMock()
         # should return None
         # should return None
-        self.assertIsNone(x.set(sentinel.key, sentinel.value))
+        assert x.set(sentinel.key, sentinel.value) is None
 
 
     def test_delete(self):
     def test_delete(self):
         self.app.conf.couchbase_backend_settings = {}
         self.app.conf.couchbase_backend_settings = {}
@@ -75,7 +77,7 @@ class test_CouchbaseBackend(AppCase):
         mocked_delete = x._connection.delete = Mock()
         mocked_delete = x._connection.delete = Mock()
         mocked_delete.return_value = None
         mocked_delete.return_value = None
         # should return None
         # should return None
-        self.assertIsNone(x.delete('1f3fab'))
+        assert x.delete('1f3fab') is None
         x._connection.delete.assert_called_once_with('1f3fab')
         x._connection.delete.assert_called_once_with('1f3fab')
 
 
     def test_config_params(self):
     def test_config_params(self):
@@ -87,27 +89,27 @@ class test_CouchbaseBackend(AppCase):
             'port': '1234',
             'port': '1234',
         }
         }
         x = CouchbaseBackend(app=self.app)
         x = CouchbaseBackend(app=self.app)
-        self.assertEqual(x.bucket, 'mycoolbucket')
-        self.assertEqual(x.host, ['here.host.com', 'there.host.com'],)
-        self.assertEqual(x.username, 'johndoe',)
-        self.assertEqual(x.password, 'mysecret')
-        self.assertEqual(x.port, 1234)
+        assert x.bucket == 'mycoolbucket'
+        assert x.host == ['here.host.com', 'there.host.com']
+        assert x.username == 'johndoe'
+        assert x.password == 'mysecret'
+        assert x.port == 1234
 
 
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
     def test_backend_by_url(self, url='couchbase://myhost/mycoolbucket'):
         from celery.backends.couchbase import CouchbaseBackend
         from celery.backends.couchbase import CouchbaseBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, CouchbaseBackend)
-        self.assertEqual(url_, url)
+        assert backend is CouchbaseBackend
+        assert url_ == url
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         url = 'couchbase://johndoe:mysecret@myhost:123/mycoolbucket'
         with self.Celery(backend=url) as app:
         with self.Celery(backend=url) as app:
             x = app.backend
             x = app.backend
-            self.assertEqual(x.bucket, 'mycoolbucket')
-            self.assertEqual(x.host, 'myhost')
-            self.assertEqual(x.username, 'johndoe')
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 123)
+            assert x.bucket == 'mycoolbucket'
+            assert x.host == 'myhost'
+            assert x.username == 'johndoe'
+            assert x.password == 'mysecret'
+            assert x.port == 123
 
 
     def test_correct_key_types(self):
     def test_correct_key_types(self):
         keys = [
         keys = [
@@ -119,4 +121,4 @@ class test_CouchbaseBackend(AppCase):
             self.backend.get_key_for_group('group_id', 'key'),
             self.backend.get_key_for_group('group_id', 'key'),
         ]
         ]
         for key in keys:
         for key in keys:
-            self.assertIsInstance(key, str_t)
+            assert isinstance(key, str_t)

+ 20 - 20
celery/tests/backends/test_couchdb.py → t/unit/backends/test_couchdb.py

@@ -1,10 +1,13 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import MagicMock, Mock, sentinel, skip
+
 from celery.backends import couchdb as module
 from celery.backends import couchdb as module
 from celery.backends.couchdb import CouchBackend
 from celery.backends.couchdb import CouchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 from celery import backends
 from celery import backends
-from celery.tests.case import AppCase, Mock, patch, sentinel, skip
 
 
 try:
 try:
     import pycouchdb
     import pycouchdb
@@ -15,28 +18,26 @@ COUCHDB_CONTAINER = 'celery_container'
 
 
 
 
 @skip.unless_module('pycouchdb')
 @skip.unless_module('pycouchdb')
-class test_CouchBackend(AppCase):
+class test_CouchBackend:
 
 
     def setup(self):
     def setup(self):
+        self.Server = self.patching('pycouchdb.Server')
         self.backend = CouchBackend(app=self.app)
         self.backend = CouchBackend(app=self.app)
 
 
     def test_init_no_pycouchdb(self):
     def test_init_no_pycouchdb(self):
         """test init no pycouchdb raises"""
         """test init no pycouchdb raises"""
         prev, module.pycouchdb = module.pycouchdb, None
         prev, module.pycouchdb = module.pycouchdb, None
         try:
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 CouchBackend(app=self.app)
                 CouchBackend(app=self.app)
         finally:
         finally:
             module.pycouchdb = prev
             module.pycouchdb = prev
 
 
     def test_get_container_exists(self):
     def test_get_container_exists(self):
-        with patch('pycouchdb.client.Database') as mock_Connection:
             self.backend._connection = sentinel._connection
             self.backend._connection = sentinel._connection
-
-            connection = self.backend._get_connection()
-
-            self.assertEqual(sentinel._connection, connection)
-            mock_Connection.assert_not_called()
+            connection = self.backend.connection
+            assert connection is sentinel._connection
+            self.Server.assert_not_called()
 
 
     def test_get(self):
     def test_get(self):
         """test_get
         """test_get
@@ -48,10 +49,9 @@ class test_CouchBackend(AppCase):
         """
         """
         x = CouchBackend(app=self.app)
         x = CouchBackend(app=self.app)
         x._connection = Mock()
         x._connection = Mock()
-        mocked_get = x._connection.get = Mock()
-        mocked_get.return_value = sentinel.retval
+        get = x._connection.get = MagicMock()
         # should return None
         # should return None
-        self.assertEqual(x.get('1f3fab'), sentinel.retval)
+        assert x.get('1f3fab') == get.return_value['value']
         x._connection.get.assert_called_once_with('1f3fab')
         x._connection.get.assert_called_once_with('1f3fab')
 
 
     def test_delete(self):
     def test_delete(self):
@@ -67,21 +67,21 @@ class test_CouchBackend(AppCase):
         mocked_delete = x._connection.delete = Mock()
         mocked_delete = x._connection.delete = Mock()
         mocked_delete.return_value = None
         mocked_delete.return_value = None
         # should return None
         # should return None
-        self.assertIsNone(x.delete('1f3fab'))
+        assert x.delete('1f3fab') is None
         x._connection.delete.assert_called_once_with('1f3fab')
         x._connection.delete.assert_called_once_with('1f3fab')
 
 
     def test_backend_by_url(self, url='couchdb://myhost/mycoolcontainer'):
     def test_backend_by_url(self, url='couchdb://myhost/mycoolcontainer'):
         from celery.backends.couchdb import CouchBackend
         from celery.backends.couchdb import CouchBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, CouchBackend)
-        self.assertEqual(url_, url)
+        assert backend is CouchBackend
+        assert url_ == url
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
         url = 'couchdb://johndoe:mysecret@myhost:123/mycoolcontainer'
         url = 'couchdb://johndoe:mysecret@myhost:123/mycoolcontainer'
         with self.Celery(backend=url) as app:
         with self.Celery(backend=url) as app:
             x = app.backend
             x = app.backend
-            self.assertEqual(x.container, 'mycoolcontainer')
-            self.assertEqual(x.host, 'myhost')
-            self.assertEqual(x.username, 'johndoe')
-            self.assertEqual(x.password, 'mysecret')
-            self.assertEqual(x.port, 123)
+            assert x.container == 'mycoolcontainer'
+            assert x.host == 'myhost'
+            assert x.username == 'johndoe'
+            assert x.password == 'mysecret'
+            assert x.port == 123

+ 47 - 49
celery/tests/backends/test_database.py → t/unit/backends/test_database.py

@@ -1,17 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-from datetime import datetime
+import pytest
 
 
+from datetime import datetime
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
+from case import Mock, patch, skip
+
 from celery import states
 from celery import states
 from celery import uuid
 from celery import uuid
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, patch, skip,
-)
-
 try:
 try:
     import sqlalchemy  # noqa
     import sqlalchemy  # noqa
 except ImportError:
 except ImportError:
@@ -33,7 +32,7 @@ class SomeClass(object):
 
 
 
 
 @skip.unless_module('sqlalchemy')
 @skip.unless_module('sqlalchemy')
-class test_session_cleanup(AppCase):
+class test_session_cleanup:
 
 
     def test_context(self):
     def test_context(self):
         session = Mock(name='session')
         session = Mock(name='session')
@@ -43,7 +42,7 @@ class test_session_cleanup(AppCase):
 
 
     def test_context_raises(self):
     def test_context_raises(self):
         session = Mock(name='session')
         session = Mock(name='session')
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             with session_cleanup(session):
             with session_cleanup(session):
                 raise KeyError()
                 raise KeyError()
         session.rollback.assert_called_with()
         session.rollback.assert_called_with()
@@ -53,7 +52,7 @@ class test_session_cleanup(AppCase):
 @skip.unless_module('sqlalchemy')
 @skip.unless_module('sqlalchemy')
 @skip.if_pypy()
 @skip.if_pypy()
 @skip.if_jython()
 @skip.if_jython()
-class test_DatabaseBackend(AppCase):
+class test_DatabaseBackend:
 
 
     def setup(self):
     def setup(self):
         self.uri = 'sqlite:///test.db'
         self.uri = 'sqlite:///test.db'
@@ -69,39 +68,38 @@ class test_DatabaseBackend(AppCase):
             calls[0] += 1
             calls[0] += 1
             raise DatabaseError(1, 2, 3)
             raise DatabaseError(1, 2, 3)
 
 
-        with self.assertRaises(DatabaseError):
+        with pytest.raises(DatabaseError):
             raises(max_retries=5)
             raises(max_retries=5)
-        self.assertEqual(calls[0], 5)
+        assert calls[0] == 5
 
 
     def test_missing_dburi_raises_ImproperlyConfigured(self):
     def test_missing_dburi_raises_ImproperlyConfigured(self):
         self.app.conf.sqlalchemy_dburi = None
         self.app.conf.sqlalchemy_dburi = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             DatabaseBackend(app=self.app)
             DatabaseBackend(app=self.app)
 
 
     def test_missing_task_id_is_PENDING(self):
     def test_missing_task_id_is_PENDING(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertEqual(tb.get_state('xxx-does-not-exist'), states.PENDING)
+        assert tb.get_state('xxx-does-not-exist') == states.PENDING
 
 
     def test_missing_task_meta_is_dict_with_pending(self):
     def test_missing_task_meta_is_dict_with_pending(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertDictContainsSubset({
-            'status': states.PENDING,
-            'task_id': 'xxx-does-not-exist-at-all',
-            'result': None,
-            'traceback': None
-        }, tb.get_task_meta('xxx-does-not-exist-at-all'))
+        meta = tb.get_task_meta('xxx-does-not-exist-at-all')
+        assert meta['status'] == states.PENDING
+        assert meta['task_id'] == 'xxx-does-not-exist-at-all'
+        assert meta['result'] is None
+        assert meta['traceback'] is None
 
 
     def test_mark_as_done(self):
     def test_mark_as_done(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid = uuid()
         tid = uuid()
 
 
-        self.assertEqual(tb.get_state(tid), states.PENDING)
-        self.assertIsNone(tb.get_result(tid))
+        assert tb.get_state(tid) == states.PENDING
+        assert tb.get_result(tid) is None
 
 
         tb.mark_as_done(tid, 42)
         tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_state(tid), states.SUCCESS)
-        self.assertEqual(tb.get_result(tid), 42)
+        assert tb.get_state(tid) == states.SUCCESS
+        assert tb.get_result(tid) == 42
 
 
     def test_is_pickled(self):
     def test_is_pickled(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -111,20 +109,20 @@ class test_DatabaseBackend(AppCase):
         tb.mark_as_done(tid2, result)
         tb.mark_as_done(tid2, result)
         # is serialized properly.
         # is serialized properly.
         rindb = tb.get_result(tid2)
         rindb = tb.get_result(tid2)
-        self.assertEqual(rindb.get('foo'), 'baz')
-        self.assertEqual(rindb.get('bar').data, 12345)
+        assert rindb.get('foo') == 'baz'
+        assert rindb.get('bar').data == 12345
 
 
     def test_mark_as_started(self):
     def test_mark_as_started(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_started(tid)
         tb.mark_as_started(tid)
-        self.assertEqual(tb.get_state(tid), states.STARTED)
+        assert tb.get_state(tid) == states.STARTED
 
 
     def test_mark_as_revoked(self):
     def test_mark_as_revoked(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
         tid = uuid()
         tid = uuid()
         tb.mark_as_revoked(tid)
         tb.mark_as_revoked(tid)
-        self.assertEqual(tb.get_state(tid), states.REVOKED)
+        assert tb.get_state(tid) == states.REVOKED
 
 
     def test_mark_as_retry(self):
     def test_mark_as_retry(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -135,9 +133,9 @@ class test_DatabaseBackend(AppCase):
             import traceback
             import traceback
             trace = '\n'.join(traceback.format_stack())
             trace = '\n'.join(traceback.format_stack())
             tb.mark_as_retry(tid, exception, traceback=trace)
             tb.mark_as_retry(tid, exception, traceback=trace)
-            self.assertEqual(tb.get_state(tid), states.RETRY)
-            self.assertIsInstance(tb.get_result(tid), KeyError)
-            self.assertEqual(tb.get_traceback(tid), trace)
+            assert tb.get_state(tid) == states.RETRY
+            assert isinstance(tb.get_result(tid), KeyError)
+            assert tb.get_traceback(tid) == trace
 
 
     def test_mark_as_failure(self):
     def test_mark_as_failure(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -149,9 +147,9 @@ class test_DatabaseBackend(AppCase):
             import traceback
             import traceback
             trace = '\n'.join(traceback.format_stack())
             trace = '\n'.join(traceback.format_stack())
             tb.mark_as_failure(tid3, exception, traceback=trace)
             tb.mark_as_failure(tid3, exception, traceback=trace)
-            self.assertEqual(tb.get_state(tid3), states.FAILURE)
-            self.assertIsInstance(tb.get_result(tid3), KeyError)
-            self.assertEqual(tb.get_traceback(tid3), trace)
+            assert tb.get_state(tid3) == states.FAILURE
+            assert isinstance(tb.get_result(tid3), KeyError)
+            assert tb.get_traceback(tid3) == trace
 
 
     def test_forget(self):
     def test_forget(self):
         tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
         tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
@@ -160,31 +158,31 @@ class test_DatabaseBackend(AppCase):
         tb.mark_as_done(tid, {'foo': 'bar'})
         tb.mark_as_done(tid, {'foo': 'bar'})
         x = self.app.AsyncResult(tid, backend=tb)
         x = self.app.AsyncResult(tid, backend=tb)
         x.forget()
         x.forget()
-        self.assertIsNone(x.result)
+        assert x.result is None
 
 
     def test_process_cleanup(self):
     def test_process_cleanup(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
         tb.process_cleanup()
         tb.process_cleanup()
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self):
     def test_reduce(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
-        self.assertTrue(loads(dumps(tb)))
+        assert loads(dumps(tb))
 
 
     def test_save__restore__delete_group(self):
     def test_save__restore__delete_group(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
 
 
         tid = uuid()
         tid = uuid()
         res = {'something': 'special'}
         res = {'something': 'special'}
-        self.assertEqual(tb.save_group(tid, res), res)
+        assert tb.save_group(tid, res) == res
 
 
         res2 = tb.restore_group(tid)
         res2 = tb.restore_group(tid)
-        self.assertEqual(res2, res)
+        assert res2 == res
 
 
         tb.delete_group(tid)
         tb.delete_group(tid)
-        self.assertIsNone(tb.restore_group(tid))
+        assert tb.restore_group(tid) is None
 
 
-        self.assertIsNone(tb.restore_group('xxx-nonexisting-id'))
+        assert tb.restore_group('xxx-nonexisting-id') is None
 
 
     def test_cleanup(self):
     def test_cleanup(self):
         tb = DatabaseBackend(self.uri, app=self.app)
         tb = DatabaseBackend(self.uri, app=self.app)
@@ -202,20 +200,20 @@ class test_DatabaseBackend(AppCase):
         tb.cleanup()
         tb.cleanup()
 
 
     def test_Task__repr__(self):
     def test_Task__repr__(self):
-        self.assertIn('foo', repr(Task('foo')))
+        assert 'foo' in repr(Task('foo'))
 
 
     def test_TaskSet__repr__(self):
     def test_TaskSet__repr__(self):
-        self.assertIn('foo', repr(TaskSet('foo', None)))
+        assert 'foo', repr(TaskSet('foo' in None))
 
 
 
 
 @skip.unless_module('sqlalchemy')
 @skip.unless_module('sqlalchemy')
-class test_SessionManager(AppCase):
+class test_SessionManager:
 
 
     def test_after_fork(self):
     def test_after_fork(self):
         s = SessionManager()
         s = SessionManager()
-        self.assertFalse(s.forked)
+        assert not s.forked
         s._after_fork()
         s._after_fork()
-        self.assertTrue(s.forked)
+        assert s.forked
 
 
     @patch('celery.backends.database.session.create_engine')
     @patch('celery.backends.database.session.create_engine')
     def test_get_engine_forked(self, create_engine):
     def test_get_engine_forked(self, create_engine):
@@ -223,9 +221,9 @@ class test_SessionManager(AppCase):
         s._after_fork()
         s._after_fork()
         engine = s.get_engine('dburi', foo=1)
         engine = s.get_engine('dburi', foo=1)
         create_engine.assert_called_with('dburi', foo=1)
         create_engine.assert_called_with('dburi', foo=1)
-        self.assertIs(engine, create_engine())
+        assert engine is create_engine()
         engine2 = s.get_engine('dburi', foo=1)
         engine2 = s.get_engine('dburi', foo=1)
-        self.assertIs(engine2, engine)
+        assert engine2 is engine
 
 
     @patch('celery.backends.database.session.sessionmaker')
     @patch('celery.backends.database.session.sessionmaker')
     def test_create_session_forked(self, sessionmaker):
     def test_create_session_forked(self, sessionmaker):
@@ -234,16 +232,16 @@ class test_SessionManager(AppCase):
         s._after_fork()
         s._after_fork()
         engine, session = s.create_session('dburi', short_lived_sessions=True)
         engine, session = s.create_session('dburi', short_lived_sessions=True)
         sessionmaker.assert_called_with(bind=s.get_engine())
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIs(session, sessionmaker())
+        assert session is sessionmaker()
         sessionmaker.return_value = Mock(name='new')
         sessionmaker.return_value = Mock(name='new')
         engine, session2 = s.create_session('dburi', short_lived_sessions=True)
         engine, session2 = s.create_session('dburi', short_lived_sessions=True)
         sessionmaker.assert_called_with(bind=s.get_engine())
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIsNot(session2, session)
+        assert session2 is not session
         sessionmaker.return_value = Mock(name='new2')
         sessionmaker.return_value = Mock(name='new2')
         engine, session3 = s.create_session(
         engine, session3 = s.create_session(
             'dburi', short_lived_sessions=False)
             'dburi', short_lived_sessions=False)
         sessionmaker.assert_called_with(bind=s.get_engine())
         sessionmaker.assert_called_with(bind=s.get_engine())
-        self.assertIs(session3, session2)
+        assert session3 is session2
 
 
     def test_coverage_madness(self):
     def test_coverage_madness(self):
         prev, session.register_after_fork = (
         prev, session.register_after_fork = (

+ 16 - 14
celery/tests/backends/test_elasticsearch.py → t/unit/backends/test_elasticsearch.py

@@ -1,15 +1,17 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import Mock, sentinel, skip
+
 from celery import backends
 from celery import backends
 from celery.backends import elasticsearch as module
 from celery.backends import elasticsearch as module
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.backends.elasticsearch import ElasticsearchBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
-from celery.tests.case import AppCase, Mock, sentinel, skip
-
 
 
 @skip.unless_module('elasticsearch')
 @skip.unless_module('elasticsearch')
-class test_ElasticsearchBackend(AppCase):
+class test_ElasticsearchBackend:
 
 
     def setup(self):
     def setup(self):
         self.backend = ElasticsearchBackend(app=self.app)
         self.backend = ElasticsearchBackend(app=self.app)
@@ -17,7 +19,7 @@ class test_ElasticsearchBackend(AppCase):
     def test_init_no_elasticsearch(self):
     def test_init_no_elasticsearch(self):
         prev, module.elasticsearch = module.elasticsearch, None
         prev, module.elasticsearch = module.elasticsearch, None
         try:
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 ElasticsearchBackend(app=self.app)
                 ElasticsearchBackend(app=self.app)
         finally:
         finally:
             module.elasticsearch = prev
             module.elasticsearch = prev
@@ -31,7 +33,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.get.return_value = r
         x._server.get.return_value = r
         dict_result = x.get(sentinel.task_id)
         dict_result = x.get(sentinel.task_id)
 
 
-        self.assertEqual(dict_result, sentinel.result)
+        assert dict_result == sentinel.result
         x._server.get.assert_called_once_with(
         x._server.get.assert_called_once_with(
             doc_type=x.doc_type,
             doc_type=x.doc_type,
             id=sentinel.task_id,
             id=sentinel.task_id,
@@ -45,7 +47,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.get.return_value = sentinel.result
         x._server.get.return_value = sentinel.result
         none_result = x.get(sentinel.task_id)
         none_result = x.get(sentinel.task_id)
 
 
-        self.assertEqual(none_result, None)
+        assert none_result is None
         x._server.get.assert_called_once_with(
         x._server.get.assert_called_once_with(
             doc_type=x.doc_type,
             doc_type=x.doc_type,
             id=sentinel.task_id,
             id=sentinel.task_id,
@@ -58,7 +60,7 @@ class test_ElasticsearchBackend(AppCase):
         x._server.delete = Mock()
         x._server.delete = Mock()
         x._server.delete.return_value = sentinel.result
         x._server.delete.return_value = sentinel.result
 
 
-        self.assertIsNone(x.delete(sentinel.task_id), sentinel.result)
+        assert x.delete(sentinel.task_id) is None
         x._server.delete.assert_called_once_with(
         x._server.delete.assert_called_once_with(
             doc_type=x.doc_type,
             doc_type=x.doc_type,
             id=sentinel.task_id,
             id=sentinel.task_id,
@@ -68,16 +70,16 @@ class test_ElasticsearchBackend(AppCase):
     def test_backend_by_url(self, url='elasticsearch://localhost:9200/index'):
     def test_backend_by_url(self, url='elasticsearch://localhost:9200/index'):
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
 
 
-        self.assertIs(backend, ElasticsearchBackend)
-        self.assertEqual(url_, url)
+        assert backend is ElasticsearchBackend
+        assert url_ == url
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
         url = 'elasticsearch://localhost:9200/index/doc_type'
         url = 'elasticsearch://localhost:9200/index/doc_type'
         with self.Celery(backend=url) as app:
         with self.Celery(backend=url) as app:
             x = app.backend
             x = app.backend
 
 
-            self.assertEqual(x.index, 'index')
-            self.assertEqual(x.doc_type, 'doc_type')
-            self.assertEqual(x.scheme, 'elasticsearch')
-            self.assertEqual(x.host, 'localhost')
-            self.assertEqual(x.port, 9200)
+            assert x.index == 'index'
+            assert x.doc_type == 'doc_type'
+            assert x.scheme == 'elasticsearch'
+            assert x.host == 'localhost'
+            assert x.port == 9200

+ 13 - 16
celery/tests/backends/test_filesystem.py → t/unit/backends/test_filesystem.py

@@ -2,54 +2,51 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import os
 import os
-import shutil
+import pytest
 import tempfile
 import tempfile
 
 
+from case import skip
+
 from celery import uuid
 from celery import uuid
 from celery import states
 from celery import states
 from celery.backends.filesystem import FilesystemBackend
 from celery.backends.filesystem import FilesystemBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
 
 
-from celery.tests.case import AppCase, skip
-
 
 
 @skip.if_win32()
 @skip.if_win32()
-class test_FilesystemBackend(AppCase):
+class test_FilesystemBackend:
 
 
     def setup(self):
     def setup(self):
         self.directory = tempfile.mkdtemp()
         self.directory = tempfile.mkdtemp()
         self.url = 'file://' + self.directory
         self.url = 'file://' + self.directory
         self.path = self.directory.encode('ascii')
         self.path = self.directory.encode('ascii')
 
 
-    def teardown(self):
-        shutil.rmtree(self.directory)
-
     def test_a_path_is_required(self):
     def test_a_path_is_required(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             FilesystemBackend(app=self.app)
             FilesystemBackend(app=self.app)
 
 
     def test_a_path_in_url(self):
     def test_a_path_in_url(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
-        self.assertEqual(tb.path, self.path)
+        assert tb.path == self.path
 
 
     def test_path_is_incorrect(self):
     def test_path_is_incorrect(self):
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             FilesystemBackend(app=self.app, url=self.url + '-incorrect')
             FilesystemBackend(app=self.app, url=self.url + '-incorrect')
 
 
     def test_missing_task_is_PENDING(self):
     def test_missing_task_is_PENDING(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
-        self.assertEqual(tb.get_state('xxx-does-not-exist'), states.PENDING)
+        assert tb.get_state('xxx-does-not-exist') == states.PENDING
 
 
     def test_mark_as_done_writes_file(self):
     def test_mark_as_done_writes_file(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb.mark_as_done(uuid(), 42)
         tb.mark_as_done(uuid(), 42)
-        self.assertEqual(len(os.listdir(self.directory)), 1)
+        assert len(os.listdir(self.directory)) == 1
 
 
     def test_done_task_is_SUCCESS(self):
     def test_done_task_is_SUCCESS(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, 42)
         tb.mark_as_done(tid, 42)
-        self.assertEqual(tb.get_state(tid), states.SUCCESS)
+        assert tb.get_state(tid) == states.SUCCESS
 
 
     def test_correct_result(self):
     def test_correct_result(self):
         data = {'foo': 'bar'}
         data = {'foo': 'bar'}
@@ -57,7 +54,7 @@ class test_FilesystemBackend(AppCase):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, data)
         tb.mark_as_done(tid, data)
-        self.assertEqual(tb.get_result(tid), data)
+        assert tb.get_result(tid) == data
 
 
     def test_get_many(self):
     def test_get_many(self):
         data = {uuid(): 'foo', uuid(): 'bar', uuid(): 'baz'}
         data = {uuid(): 'foo', uuid(): 'bar', uuid(): 'baz'}
@@ -67,11 +64,11 @@ class test_FilesystemBackend(AppCase):
             tb.mark_as_done(key, value)
             tb.mark_as_done(key, value)
 
 
         for key, result in tb.get_many(data.keys()):
         for key, result in tb.get_many(data.keys()):
-            self.assertEqual(result['result'], data[key])
+            assert result['result'] == data[key]
 
 
     def test_forget_deletes_file(self):
     def test_forget_deletes_file(self):
         tb = FilesystemBackend(app=self.app, url=self.url)
         tb = FilesystemBackend(app=self.app, url=self.url)
         tid = uuid()
         tid = uuid()
         tb.mark_as_done(tid, 42)
         tb.mark_as_done(tid, 42)
         tb.forget(tid)
         tb.forget(tid)
-        self.assertEqual(len(os.listdir(self.directory)), 0)
+        assert len(os.listdir(self.directory)) == 0

+ 80 - 92
celery/tests/backends/test_mongodb.py → t/unit/backends/test_mongodb.py

@@ -1,20 +1,18 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 import datetime
 import datetime
 
 
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
+from case import ANY, MagicMock, Mock, mock, patch, sentinel, skip
 from kombu.exceptions import EncodeError
 from kombu.exceptions import EncodeError
 
 
 from celery import uuid
 from celery import uuid
 from celery import states
 from celery import states
-from celery.backends import mongodb as module
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.backends.mongodb import InvalidDocument, MongoBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import (
-    ANY, AppCase, MagicMock, Mock,
-    mock, depends_on_current_app, patch, sentinel, skip,
-)
 
 
 COLLECTION = 'taskmeta_celery'
 COLLECTION = 'taskmeta_celery'
 TASK_ID = uuid()
 TASK_ID = uuid()
@@ -28,7 +26,7 @@ MONGODB_GROUP_COLLECTION = 'group_collection1'
 
 
 
 
 @skip.unless_module('pymongo')
 @skip.unless_module('pymongo')
-class test_MongoBackend(AppCase):
+class test_MongoBackend:
 
 
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
     default_url = 'mongodb://uuuu:pwpw@hostname.dom/database'
     replica_set_url = (
     replica_set_url = (
@@ -42,31 +40,20 @@ class test_MongoBackend(AppCase):
     )
     )
 
 
     def setup(self):
     def setup(self):
-        R = self._reset = {}
-        R['encode'], MongoBackend.encode = MongoBackend.encode, Mock()
-        R['decode'], MongoBackend.decode = MongoBackend.decode, Mock()
-        R['Binary'], module.Binary = module.Binary, Mock()
-        R['datetime'], datetime.datetime = datetime.datetime, Mock()
-
+        self.patching('celery.backends.mongodb.MongoBackend.encode')
+        self.patching('celery.backends.mongodb.MongoBackend.decode')
+        self.patching('celery.backends.mongodb.Binary')
+        self.patching('datetime.datetime')
         self.backend = MongoBackend(app=self.app, url=self.default_url)
         self.backend = MongoBackend(app=self.app, url=self.default_url)
 
 
-    def teardown(self):
-        MongoBackend.encode = self._reset['encode']
-        MongoBackend.decode = self._reset['decode']
-        module.Binary = self._reset['Binary']
-        datetime.datetime = self._reset['datetime']
-
-    def test_init_no_mongodb(self):
-        prev, module.pymongo = module.pymongo, None
-        try:
-            with self.assertRaises(ImproperlyConfigured):
-                MongoBackend(app=self.app)
-        finally:
-            module.pymongo = prev
+    def test_init_no_mongodb(self, patching):
+        patching('celery.backends.mongodb.pymongo', None)
+        with pytest.raises(ImproperlyConfigured):
+            MongoBackend(app=self.app)
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
         self.app.conf.mongodb_backend_settings = []
         self.app.conf.mongodb_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             MongoBackend(app=self.app)
             MongoBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
@@ -81,14 +68,14 @@ class test_MongoBackend(AppCase):
         # uri
         # uri
         uri = 'mongodb://localhost:27017'
         uri = 'mongodb://localhost:27017'
         mb = MongoBackend(app=self.app, url=uri)
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['localhost:27017'])
-        self.assertEqual(mb.options, mb._prepare_client_options())
-        self.assertEqual(mb.database_name, 'celery')
+        assert mb.mongo_host == ['localhost:27017']
+        assert mb.options == mb._prepare_client_options()
+        assert mb.database_name == 'celery'
 
 
         # uri with database name
         # uri with database name
         uri = 'mongodb://localhost:27017/celerydb'
         uri = 'mongodb://localhost:27017/celerydb'
         mb = MongoBackend(app=self.app, url=uri)
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.database_name, 'celerydb')
+        assert mb.database_name == 'celerydb'
 
 
         # uri with user, password, database name, replica set
         # uri with user, password, database name, replica set
         uri = ('mongodb://'
         uri = ('mongodb://'
@@ -98,15 +85,18 @@ class test_MongoBackend(AppCase):
                'mongo3.example.com:27017/'
                'mongo3.example.com:27017/'
                'celerydatabase?replicaSet=rs0')
                'celerydatabase?replicaSet=rs0')
         mb = MongoBackend(app=self.app, url=uri)
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['mongo1.example.com:27017',
-                                         'mongo2.example.com:27017',
-                                         'mongo3.example.com:27017'])
-        self.assertEqual(
-            mb.options, dict(mb._prepare_client_options(), replicaset='rs0'),
+        assert mb.mongo_host == [
+            'mongo1.example.com:27017',
+            'mongo2.example.com:27017',
+            'mongo3.example.com:27017',
+        ]
+        assert mb.options == dict(
+            mb._prepare_client_options(),
+            replicaset='rs0',
         )
         )
-        self.assertEqual(mb.user, 'celeryuser')
-        self.assertEqual(mb.password, 'celerypassword')
-        self.assertEqual(mb.database_name, 'celerydatabase')
+        assert mb.user == 'celeryuser'
+        assert mb.password == 'celerypassword'
+        assert mb.database_name == 'celerydatabase'
 
 
         # same uri, change some parameters in backend settings
         # same uri, change some parameters in backend settings
         self.app.conf.mongodb_backend_settings = {
         self.app.conf.mongodb_backend_settings = {
@@ -118,23 +108,26 @@ class test_MongoBackend(AppCase):
             },
             },
         }
         }
         mb = MongoBackend(app=self.app, url=uri)
         mb = MongoBackend(app=self.app, url=uri)
-        self.assertEqual(mb.mongo_host, ['mongo1.example.com:27017',
-                                         'mongo2.example.com:27017',
-                                         'mongo3.example.com:27017'])
-        self.assertEqual(
-            mb.options, dict(mb._prepare_client_options(),
-                             replicaset='rs1', socketKeepAlive=True),
+        assert mb.mongo_host == [
+            'mongo1.example.com:27017',
+            'mongo2.example.com:27017',
+            'mongo3.example.com:27017',
+        ]
+        assert mb.options == dict(
+            mb._prepare_client_options(),
+            replicaset='rs1',
+            socketKeepAlive=True,
         )
         )
-        self.assertEqual(mb.user, 'backenduser')
-        self.assertEqual(mb.password, 'celerypassword')
-        self.assertEqual(mb.database_name, 'another_db')
+        assert mb.user == 'backenduser'
+        assert mb.password == 'celerypassword'
+        assert mb.database_name == 'another_db'
 
 
         mb = MongoBackend(app=self.app, url='mongodb://')
         mb = MongoBackend(app=self.app, url='mongodb://')
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     def test_reduce(self):
     def test_reduce(self):
         x = MongoBackend(app=self.app)
         x = MongoBackend(app=self.app)
-        self.assertTrue(loads(dumps(x)))
+        assert loads(dumps(x))
 
 
     def test_get_connection_connection_exists(self):
     def test_get_connection_connection_exists(self):
         with patch('pymongo.MongoClient') as mock_Connection:
         with patch('pymongo.MongoClient') as mock_Connection:
@@ -142,7 +135,7 @@ class test_MongoBackend(AppCase):
 
 
             connection = self.backend._get_connection()
             connection = self.backend._get_connection()
 
 
-            self.assertEqual(sentinel._connection, connection)
+            assert sentinel._connection == connection
             mock_Connection.assert_not_called()
             mock_Connection.assert_not_called()
 
 
     def test_get_connection_no_connection_host(self):
     def test_get_connection_no_connection_host(self):
@@ -157,7 +150,7 @@ class test_MongoBackend(AppCase):
                 host='mongodb://localhost:27017',
                 host='mongodb://localhost:27017',
                 **self.backend._prepare_client_options()
                 **self.backend._prepare_client_options()
             )
             )
-            self.assertEqual(sentinel.connection, connection)
+            assert sentinel.connection == connection
 
 
     def test_get_connection_no_connection_mongodb_uri(self):
     def test_get_connection_no_connection_mongodb_uri(self):
         with patch('pymongo.MongoClient') as mock_Connection:
         with patch('pymongo.MongoClient') as mock_Connection:
@@ -171,7 +164,7 @@ class test_MongoBackend(AppCase):
             mock_Connection.assert_called_once_with(
             mock_Connection.assert_called_once_with(
                 host=mongodb_uri, **self.backend._prepare_client_options()
                 host=mongodb_uri, **self.backend._prepare_client_options()
             )
             )
-            self.assertEqual(sentinel.connection, connection)
+            assert sentinel.connection == connection
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_connection')
     @patch('celery.backends.mongodb.MongoBackend._get_connection')
     def test_get_database_no_existing(self, mock_get_connection):
     def test_get_database_no_existing(self, mock_get_connection):
@@ -186,8 +179,8 @@ class test_MongoBackend(AppCase):
 
 
         database = self.backend.database
         database = self.backend.database
 
 
-        self.assertTrue(database is mock_database)
-        self.assertTrue(self.backend.__dict__['database'] is mock_database)
+        assert database is mock_database
+        assert self.backend.__dict__['database'] is mock_database
         mock_database.authenticate.assert_called_once_with(
         mock_database.authenticate.assert_called_once_with(
             MONGODB_USER, MONGODB_PASSWORD)
             MONGODB_USER, MONGODB_PASSWORD)
 
 
@@ -204,9 +197,9 @@ class test_MongoBackend(AppCase):
 
 
         database = self.backend.database
         database = self.backend.database
 
 
-        self.assertTrue(database is mock_database)
+        assert database is mock_database
         mock_database.authenticate.assert_not_called()
         mock_database.authenticate.assert_not_called()
-        self.assertTrue(self.backend.__dict__['database'] is mock_database)
+        assert self.backend.__dict__['database'] is mock_database
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_store_result(self, mock_get_database):
     def test_store_result(self, mock_get_database):
@@ -224,16 +217,15 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         mock_collection.save.assert_called_once_with(ANY)
         mock_collection.save.assert_called_once_with(ANY)
-        self.assertEqual(sentinel.result, ret_val)
+        assert sentinel.result == ret_val
 
 
         mock_collection.save.side_effect = InvalidDocument()
         mock_collection.save.side_effect = InvalidDocument()
-        with self.assertRaises(EncodeError):
+        with pytest.raises(EncodeError):
             self.backend._store_result(
             self.backend._store_result(
                 sentinel.task_id, sentinel.result, sentinel.status)
                 sentinel.task_id, sentinel.result, sentinel.status)
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_get_task_meta_for(self, mock_get_database):
     def test_get_task_meta_for(self, mock_get_database):
-        datetime.datetime = self._reset['datetime']
         self.backend.taskmeta_collection = MONGODB_COLLECTION
         self.backend.taskmeta_collection = MONGODB_COLLECTION
 
 
         mock_database = MagicMock(spec=['__getitem__', '__setitem__'])
         mock_database = MagicMock(spec=['__getitem__', '__setitem__'])
@@ -247,11 +239,10 @@ class test_MongoBackend(AppCase):
 
 
         mock_get_database.assert_called_once_with()
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
-        self.assertEqual(
-            list(sorted(['status', 'task_id', 'date_done', 'traceback',
-                         'result', 'children'])),
-            list(sorted(ret_val.keys())),
-        )
+        assert list(sorted([
+            'status', 'task_id', 'date_done',
+            'traceback', 'result', 'children',
+        ])) == list(sorted(ret_val.keys()))
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_get_task_meta_for_no_result(self, mock_get_database):
     def test_get_task_meta_for_no_result(self, mock_get_database):
@@ -268,7 +259,7 @@ class test_MongoBackend(AppCase):
 
 
         mock_get_database.assert_called_once_with()
         mock_get_database.assert_called_once_with()
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
         mock_database.__getitem__.assert_called_once_with(MONGODB_COLLECTION)
-        self.assertEqual({'status': states.PENDING, 'result': None}, ret_val)
+        assert {'status': states.PENDING, 'result': None} == ret_val
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_save_group(self, mock_get_database):
     def test_save_group(self, mock_get_database):
@@ -288,7 +279,7 @@ class test_MongoBackend(AppCase):
             MONGODB_GROUP_COLLECTION,
             MONGODB_GROUP_COLLECTION,
         )
         )
         mock_collection.save.assert_called_once_with(ANY)
         mock_collection.save.assert_called_once_with(ANY)
-        self.assertEqual(res, ret_val)
+        assert res == ret_val
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_restore_group(self, mock_get_database):
     def test_restore_group(self, mock_get_database):
@@ -311,10 +302,8 @@ class test_MongoBackend(AppCase):
         mock_get_database.assert_called_once_with()
         mock_get_database.assert_called_once_with()
         mock_collection.find_one.assert_called_once_with(
         mock_collection.find_one.assert_called_once_with(
             {'_id': sentinel.taskset_id})
             {'_id': sentinel.taskset_id})
-        self.assertItemsEqual(
-            ['date_done', 'result', 'task_id'],
-            list(ret_val.keys()),
-        )
+        assert (sorted(['date_done', 'result', 'task_id']) ==
+                sorted(list(ret_val.keys())))
 
 
         mock_collection.find_one.return_value = None
         mock_collection.find_one.return_value = None
         self.backend._restore_group(sentinel.taskset_id)
         self.backend._restore_group(sentinel.taskset_id)
@@ -355,7 +344,6 @@ class test_MongoBackend(AppCase):
 
 
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     @patch('celery.backends.mongodb.MongoBackend._get_database')
     def test_cleanup(self, mock_get_database):
     def test_cleanup(self, mock_get_database):
-        datetime.datetime = self._reset['datetime']
         self.backend.taskmeta_collection = MONGODB_COLLECTION
         self.backend.taskmeta_collection = MONGODB_COLLECTION
         self.backend.groupmeta_collection = MONGODB_GROUP_COLLECTION
         self.backend.groupmeta_collection = MONGODB_GROUP_COLLECTION
 
 
@@ -381,56 +369,56 @@ class test_MongoBackend(AppCase):
         db.authenticate.return_value = False
         db.authenticate.return_value = False
         x.user = 'jerry'
         x.user = 'jerry'
         x.password = 'cere4l'
         x.password = 'cere4l'
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             x._get_database()
             x._get_database()
         db.authenticate.assert_called_with('jerry', 'cere4l')
         db.authenticate.assert_called_with('jerry', 'cere4l')
 
 
     def test_prepare_client_options(self):
     def test_prepare_client_options(self):
         with patch('pymongo.version_tuple', new=(3, 0, 3)):
         with patch('pymongo.version_tuple', new=(3, 0, 3)):
             options = self.backend._prepare_client_options()
             options = self.backend._prepare_client_options()
-            self.assertDictEqual(options, {
+            assert options == {
                 'maxPoolSize': self.backend.max_pool_size
                 'maxPoolSize': self.backend.max_pool_size
-            })
+            }
 
 
     def test_as_uri_include_password(self):
     def test_as_uri_include_password(self):
-        self.assertEqual(self.backend.as_uri(True), self.default_url)
+        assert self.backend.as_uri(True) == self.default_url
 
 
     def test_as_uri_exclude_password(self):
     def test_as_uri_exclude_password(self):
-        self.assertEqual(self.backend.as_uri(), self.sanitized_default_url)
+        assert self.backend.as_uri() == self.sanitized_default_url
 
 
     def test_as_uri_include_password_replica_set(self):
     def test_as_uri_include_password_replica_set(self):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
-        self.assertEqual(backend.as_uri(True), self.replica_set_url)
+        assert backend.as_uri(True) == self.replica_set_url
 
 
     def test_as_uri_exclude_password_replica_set(self):
     def test_as_uri_exclude_password_replica_set(self):
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
         backend = MongoBackend(app=self.app, url=self.replica_set_url)
-        self.assertEqual(backend.as_uri(), self.sanitized_replica_set_url)
+        assert backend.as_uri() == self.sanitized_replica_set_url
 
 
-    @mock.stdouts
-    def test_regression_worker_startup_info(self, stdout, stderr):
+    def test_regression_worker_startup_info(self):
         self.app.conf.result_backend = (
         self.app.conf.result_backend = (
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             'mongodb://user:password@host0.com:43437,host1.com:43437'
             '/work4us?replicaSet=rs&ssl=true'
             '/work4us?replicaSet=rs&ssl=true'
         )
         )
         worker = self.app.Worker()
         worker = self.app.Worker()
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
+        with mock.stdouts():
+            worker.on_start()
+            assert worker.startup_info()
 
 
 
 
 @skip.unless_module('pymongo')
 @skip.unless_module('pymongo')
-class test_MongoBackend_no_mock(AppCase):
+class test_MongoBackend_no_mock:
 
 
-    def test_encode_decode(self):
-        backend = MongoBackend(app=self.app)
+    def test_encode_decode(self, app):
+        backend = MongoBackend(app=app)
         data = {'foo': 1}
         data = {'foo': 1}
-        self.assertTrue(backend.decode(backend.encode(data)))
+        assert backend.decode(backend.encode(data))
         backend.serializer = 'bson'
         backend.serializer = 'bson'
-        self.assertEquals(backend.encode(data), data)
-        self.assertEquals(backend.decode(data), data)
+        assert backend.encode(data) == data
+        assert backend.decode(data) == data
 
 
-    def test_de(self):
-        backend = MongoBackend(app=self.app)
+    def test_de(self, app):
+        backend = MongoBackend(app=app)
         data = {'foo': 1}
         data = {'foo': 1}
-        self.assertTrue(backend.encode(data))
+        assert backend.encode(data)
         backend.serializer = 'bson'
         backend.serializer = 'bson'
-        self.assertEquals(backend.encode(data), data)
+        assert backend.encode(data) == data

+ 50 - 61
celery/tests/backends/test_redis.py → t/unit/backends/test_redis.py

@@ -1,21 +1,22 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from datetime import timedelta
 from datetime import timedelta
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 from pickle import loads, dumps
 from pickle import loads, dumps
 
 
+from case import ANY, ContextMock, Mock, mock, call, patch, skip
+
 from celery import signature
 from celery import signature
 from celery import states
 from celery import states
 from celery import uuid
 from celery import uuid
 from celery.canvas import Signature
 from celery.canvas import Signature
-from celery.exceptions import ChordError, ImproperlyConfigured
-from celery.utils.collections import AttributeDict
-
-from celery.tests.case import (
-    ANY, AppCase, ContextMock, Mock, mock,
-    call, depends_on_current_app, patch, skip,
+from celery.exceptions import (
+    ChordError, CPendingDeprecationWarning, ImproperlyConfigured,
 )
 )
+from celery.utils.collections import AttributeDict
 
 
 
 
 def raise_on_second_call(mock, exc, *retval):
 def raise_on_second_call(mock, exc, *retval):
@@ -122,7 +123,7 @@ class redis(object):
             pass
             pass
 
 
 
 
-class test_RedisBackend(AppCase):
+class test_RedisBackend:
 
 
     def get_backend(self):
     def get_backend(self):
         from celery.backends.redis import RedisBackend
         from celery.backends.redis import RedisBackend
@@ -141,54 +142,52 @@ class test_RedisBackend(AppCase):
         self.E_LOST = self.get_E_LOST()
         self.E_LOST = self.get_E_LOST()
         self.b = self.Backend(app=self.app)
         self.b = self.Backend(app=self.app)
 
 
-    @depends_on_current_app
+    @pytest.mark.usefixtures('depends_on_current_app')
     @skip.unless_module('redis')
     @skip.unless_module('redis')
     def test_reduce(self):
     def test_reduce(self):
         from celery.backends.redis import RedisBackend
         from celery.backends.redis import RedisBackend
         x = RedisBackend(app=self.app)
         x = RedisBackend(app=self.app)
-        self.assertTrue(loads(dumps(x)))
+        assert loads(dumps(x))
 
 
     def test_no_redis(self):
     def test_no_redis(self):
         self.Backend.redis = None
         self.Backend.redis = None
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             self.Backend(app=self.app)
             self.Backend(app=self.app)
 
 
     def test_url(self):
     def test_url(self):
         x = self.Backend(
         x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
             'redis://:bosco@vandelay.com:123//1', app=self.app,
         )
         )
-        self.assertTrue(x.connparams)
-        self.assertEqual(x.connparams['host'], 'vandelay.com')
-        self.assertEqual(x.connparams['db'], 1)
-        self.assertEqual(x.connparams['port'], 123)
-        self.assertEqual(x.connparams['password'], 'bosco')
+        assert x.connparams
+        assert x.connparams['host'] == 'vandelay.com'
+        assert x.connparams['db'] == 1
+        assert x.connparams['port'] == 123
+        assert x.connparams['password'] == 'bosco'
 
 
     def test_socket_url(self):
     def test_socket_url(self):
         x = self.Backend(
         x = self.Backend(
             'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
             'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
         )
         )
-        self.assertTrue(x.connparams)
-        self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
-        self.assertIs(
-            x.connparams['connection_class'],
-            redis.UnixDomainSocketConnection,
-        )
-        self.assertNotIn('host', x.connparams)
-        self.assertNotIn('port', x.connparams)
-        self.assertEqual(x.connparams['db'], 3)
+        assert x.connparams
+        assert x.connparams['path'] == '/tmp/redis.sock'
+        assert (x.connparams['connection_class'] is
+                redis.UnixDomainSocketConnection)
+        assert 'host' not in x.connparams
+        assert 'port' not in x.connparams
+        assert x.connparams['db'] == 3
 
 
     def test_compat_propertie(self):
     def test_compat_propertie(self):
         x = self.Backend(
         x = self.Backend(
             'redis://:bosco@vandelay.com:123//1', app=self.app,
             'redis://:bosco@vandelay.com:123//1', app=self.app,
         )
         )
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.host, 'vandelay.com')
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.db, 1)
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.port, 123)
-        with self.assertPendingDeprecation():
-            self.assertEqual(x.password, 'bosco')
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.host == 'vandelay.com'
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.db == 1
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.port == 123
+        with pytest.warns(CPendingDeprecationWarning):
+            assert x.password == 'bosco'
 
 
     def test_conf_raises_KeyError(self):
     def test_conf_raises_KeyError(self):
         self.app.conf = AttributeDict({
         self.app.conf = AttributeDict({
@@ -203,17 +202,11 @@ class test_RedisBackend(AppCase):
     def test_on_connection_error(self, error):
     def test_on_connection_error(self, error):
         intervals = iter([10, 20, 30])
         intervals = iter([10, 20, 30])
         exc = KeyError()
         exc = KeyError()
-        self.assertEqual(
-            self.b.on_connection_error(None, exc, intervals, 1), 10,
-        )
+        assert self.b.on_connection_error(None, exc, intervals, 1) == 10
         error.assert_called_with(self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
         error.assert_called_with(self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
-        self.assertEqual(
-            self.b.on_connection_error(10, exc, intervals, 2), 20,
-        )
+        assert self.b.on_connection_error(10, exc, intervals, 2) == 20
         error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
         error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
-        self.assertEqual(
-            self.b.on_connection_error(10, exc, intervals, 3), 30,
-        )
+        assert self.b.on_connection_error(10, exc, intervals, 3) == 30
         error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
         error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
 
 
     def test_incr(self):
     def test_incr(self):
@@ -229,7 +222,6 @@ class test_RedisBackend(AppCase):
     def test_apply_chord(self):
     def test_apply_chord(self):
         header = Mock(name='header')
         header = Mock(name='header')
         header.results = [Mock(name='t1'), Mock(name='t2')]
         header.results = [Mock(name='t1'), Mock(name='t2')]
-        print(self.b.apply_chord,)
         self.b.apply_chord(
         self.b.apply_chord(
             header, (1, 2), 'gid', None,
             header, (1, 2), 'gid', None,
             options={'max_retries': 10},
             options={'max_retries': 10},
@@ -241,7 +233,7 @@ class test_RedisBackend(AppCase):
         decode = Mock(name='decode')
         decode = Mock(name='decode')
         exc = KeyError()
         exc = KeyError()
         tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
         tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
-        with self.assertRaises(ChordError):
+        with pytest.raises(ChordError):
             self.b._unpack_chord_result(tup, decode)
             self.b._unpack_chord_result(tup, decode)
         decode.assert_called_with(tup)
         decode.assert_called_with(tup)
         self.b.exception_to_python.assert_called_with(exc)
         self.b.exception_to_python.assert_called_with(exc)
@@ -250,27 +242,27 @@ class test_RedisBackend(AppCase):
         tup = decode.return_value = (2, 'id2', states.RETRY, exc)
         tup = decode.return_value = (2, 'id2', states.RETRY, exc)
         ret = self.b._unpack_chord_result(tup, decode)
         ret = self.b._unpack_chord_result(tup, decode)
         self.b.exception_to_python.assert_called_with(exc)
         self.b.exception_to_python.assert_called_with(exc)
-        self.assertIs(ret, self.b.exception_to_python())
+        assert ret is self.b.exception_to_python()
 
 
     def test_on_chord_part_return_no_gid_or_tid(self):
     def test_on_chord_part_return_no_gid_or_tid(self):
         request = Mock(name='request')
         request = Mock(name='request')
         request.id = request.group = None
         request.id = request.group = None
-        self.assertIsNone(self.b.on_chord_part_return(request, 'SUCCESS', 10))
+        assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
 
 
     def test_ConnectionPool(self):
     def test_ConnectionPool(self):
         self.b.redis = Mock(name='redis')
         self.b.redis = Mock(name='redis')
-        self.assertIsNone(self.b._ConnectionPool)
-        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
-        self.assertIs(self.b.ConnectionPool, self.b.redis.ConnectionPool)
+        assert self.b._ConnectionPool is None
+        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
+        assert self.b.ConnectionPool is self.b.redis.ConnectionPool
 
 
     def test_expires_defaults_to_config(self):
     def test_expires_defaults_to_config(self):
         self.app.conf.result_expires = 10
         self.app.conf.result_expires = 10
         b = self.Backend(expires=None, app=self.app)
         b = self.Backend(expires=None, app=self.app)
-        self.assertEqual(b.expires, 10)
+        assert b.expires == 10
 
 
     def test_expires_is_int(self):
     def test_expires_is_int(self):
         b = self.Backend(expires=48, app=self.app)
         b = self.Backend(expires=48, app=self.app)
-        self.assertEqual(b.expires, 48)
+        assert b.expires == 48
 
 
     def test_add_to_chord(self):
     def test_add_to_chord(self):
         b = self.Backend('redis://', app=self.app)
         b = self.Backend('redis://', app=self.app)
@@ -280,17 +272,14 @@ class test_RedisBackend(AppCase):
 
 
     def test_expires_is_None(self):
     def test_expires_is_None(self):
         b = self.Backend(expires=None, app=self.app)
         b = self.Backend(expires=None, app=self.app)
-        self.assertEqual(
-            b.expires,
-            self.app.conf.result_expires.total_seconds(),
-        )
+        assert b.expires == self.app.conf.result_expires.total_seconds()
 
 
     def test_expires_is_timedelta(self):
     def test_expires_is_timedelta(self):
         b = self.Backend(expires=timedelta(minutes=1), app=self.app)
         b = self.Backend(expires=timedelta(minutes=1), app=self.app)
-        self.assertEqual(b.expires, 60)
+        assert b.expires == 60
 
 
     def test_mget(self):
     def test_mget(self):
-        self.assertTrue(self.b.mget(['a', 'b', 'c']))
+        assert self.b.mget(['a', 'b', 'c'])
         self.b.client.mget.assert_called_with(['a', 'b', 'c'])
         self.b.client.mget.assert_called_with(['a', 'b', 'c'])
 
 
     def test_set_no_expire(self):
     def test_set_no_expire(self):
@@ -314,9 +303,9 @@ class test_RedisBackend(AppCase):
 
 
         for i in range(10):
         for i in range(10):
             self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
             self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
-            self.assertTrue(self.b.client.rpush.call_count)
+            assert self.b.client.rpush.call_count
             self.b.client.rpush.reset_mock()
             self.b.client.rpush.reset_mock()
-        self.assertTrue(self.b.client.lrange.call_count)
+        assert self.b.client.lrange.call_count
         jkey = self.b.get_key_for_group('group_id', '.j')
         jkey = self.b.get_key_for_group('group_id', '.j')
         tkey = self.b.get_key_for_group('group_id', '.t')
         tkey = self.b.get_key_for_group('group_id', '.t')
         self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
         self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
@@ -383,10 +372,10 @@ class test_RedisBackend(AppCase):
     def test_get_set_forget(self):
     def test_get_set_forget(self):
         tid = uuid()
         tid = uuid()
         self.b.store_result(tid, 42, states.SUCCESS)
         self.b.store_result(tid, 42, states.SUCCESS)
-        self.assertEqual(self.b.get_state(tid), states.SUCCESS)
-        self.assertEqual(self.b.get_result(tid), 42)
+        assert self.b.get_state(tid) == states.SUCCESS
+        assert self.b.get_result(tid) == 42
         self.b.forget(tid)
         self.b.forget(tid)
-        self.assertEqual(self.b.get_state(tid), states.PENDING)
+        assert self.b.get_state(tid) == states.PENDING
 
 
     def test_set_expires(self):
     def test_set_expires(self):
         self.b = self.Backend(expires=512, app=self.app)
         self.b = self.Backend(expires=512, app=self.app)

+ 21 - 20
celery/tests/backends/test_riak.py → t/unit/backends/test_riak.py

@@ -1,17 +1,19 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
-
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import MagicMock, Mock, patch, sentinel, skip
+
 from celery.backends import riak as module
 from celery.backends import riak as module
 from celery.backends.riak import RiakBackend
 from celery.backends.riak import RiakBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.exceptions import ImproperlyConfigured
-from celery.tests.case import AppCase, MagicMock, Mock, patch, sentinel, skip
 
 
 RIAK_BUCKET = 'riak_bucket'
 RIAK_BUCKET = 'riak_bucket'
 
 
 
 
 @skip.unless_module('riak')
 @skip.unless_module('riak')
-class test_RiakBackend(AppCase):
+class test_RiakBackend:
 
 
     def setup(self):
     def setup(self):
         self.app.conf.result_backend = 'riak://'
         self.app.conf.result_backend = 'riak://'
@@ -23,28 +25,27 @@ class test_RiakBackend(AppCase):
     def test_init_no_riak(self):
     def test_init_no_riak(self):
         prev, module.riak = module.riak, None
         prev, module.riak = module.riak, None
         try:
         try:
-            with self.assertRaises(ImproperlyConfigured):
+            with pytest.raises(ImproperlyConfigured):
                 RiakBackend(app=self.app)
                 RiakBackend(app=self.app)
         finally:
         finally:
             module.riak = prev
             module.riak = prev
 
 
     def test_init_no_settings(self):
     def test_init_no_settings(self):
         self.app.conf.riak_backend_settings = []
         self.app.conf.riak_backend_settings = []
-        with self.assertRaises(ImproperlyConfigured):
+        with pytest.raises(ImproperlyConfigured):
             RiakBackend(app=self.app)
             RiakBackend(app=self.app)
 
 
     def test_init_settings_is_None(self):
     def test_init_settings_is_None(self):
         self.app.conf.riak_backend_settings = None
         self.app.conf.riak_backend_settings = None
-        self.assertTrue(self.app.backend)
+        assert self.app.backend
 
 
     def test_get_client_client_exists(self):
     def test_get_client_client_exists(self):
         with patch('riak.client.RiakClient') as mock_connection:
         with patch('riak.client.RiakClient') as mock_connection:
             self.backend._client = sentinel._client
             self.backend._client = sentinel._client
-
             mocked_is_alive = self.backend._client.is_alive = Mock()
             mocked_is_alive = self.backend._client.is_alive = Mock()
             mocked_is_alive.return_value.value = True
             mocked_is_alive.return_value.value = True
             client = self.backend._get_client()
             client = self.backend._get_client()
-            self.assertEquals(sentinel._client, client)
+            assert sentinel._client == client
             mock_connection.assert_not_called()
             mock_connection.assert_not_called()
 
 
     def test_get(self):
     def test_get(self):
@@ -54,7 +55,7 @@ class test_RiakBackend(AppCase):
         mocked_get = self.backend._bucket.get = Mock(name='bucket.get')
         mocked_get = self.backend._bucket.get = Mock(name='bucket.get')
         mocked_get.return_value.data = sentinel.retval
         mocked_get.return_value.data = sentinel.retval
         # should return None
         # should return None
-        self.assertEqual(self.backend.get('1f3fab'), sentinel.retval)
+        assert self.backend.get('1f3fab') == sentinel.retval
         self.backend._bucket.get.assert_called_once_with('1f3fab')
         self.backend._bucket.get.assert_called_once_with('1f3fab')
 
 
     def test_set(self):
     def test_set(self):
@@ -63,7 +64,7 @@ class test_RiakBackend(AppCase):
         self.backend._bucket = MagicMock()
         self.backend._bucket = MagicMock()
         self.backend._bucket.set = MagicMock()
         self.backend._bucket.set = MagicMock()
         # should return None
         # should return None
-        self.assertIsNone(self.backend.set(sentinel.key, sentinel.value))
+        assert self.backend.set(sentinel.key, sentinel.value) is None
 
 
     def test_delete(self):
     def test_delete(self):
         self.app.conf.couchbase_backend_settings = {}
         self.app.conf.couchbase_backend_settings = {}
@@ -73,7 +74,7 @@ class test_RiakBackend(AppCase):
         mocked_delete = self.backend._client.delete = Mock('client.delete')
         mocked_delete = self.backend._client.delete = Mock('client.delete')
         mocked_delete.return_value = None
         mocked_delete.return_value = None
         # should return None
         # should return None
-        self.assertIsNone(self.backend.delete('1f3fab'))
+        assert self.backend.delete('1f3fab') is None
         self.backend._bucket.delete.assert_called_once_with('1f3fab')
         self.backend._bucket.delete.assert_called_once_with('1f3fab')
 
 
     def test_config_params(self):
     def test_config_params(self):
@@ -82,22 +83,22 @@ class test_RiakBackend(AppCase):
             'host': 'there.host.com',
             'host': 'there.host.com',
             'port': '1234',
             'port': '1234',
         }
         }
-        self.assertEqual(self.backend.bucket_name, 'mycoolbucket')
-        self.assertEqual(self.backend.host, 'there.host.com')
-        self.assertEqual(self.backend.port, 1234)
+        assert self.backend.bucket_name == 'mycoolbucket'
+        assert self.backend.host == 'there.host.com'
+        assert self.backend.port == 1234
 
 
     def test_backend_by_url(self, url='riak://myhost/mycoolbucket'):
     def test_backend_by_url(self, url='riak://myhost/mycoolbucket'):
         from celery import backends
         from celery import backends
         from celery.backends.riak import RiakBackend
         from celery.backends.riak import RiakBackend
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
         backend, url_ = backends.get_backend_by_url(url, self.app.loader)
-        self.assertIs(backend, RiakBackend)
-        self.assertEqual(url_, url)
+        assert backend is RiakBackend
+        assert url_ == url
 
 
     def test_backend_params_by_url(self):
     def test_backend_params_by_url(self):
         self.app.conf.result_backend = 'riak://myhost:123/mycoolbucket'
         self.app.conf.result_backend = 'riak://myhost:123/mycoolbucket'
-        self.assertEqual(self.backend.bucket_name, 'mycoolbucket')
-        self.assertEqual(self.backend.host, 'myhost')
-        self.assertEqual(self.backend.port, 123)
+        assert self.backend.bucket_name == 'mycoolbucket'
+        assert self.backend.host == 'myhost'
+        assert self.backend.port == 123
 
 
     def test_non_ASCII_bucket_raises(self):
     def test_non_ASCII_bucket_raises(self):
         self.app.conf.riak_backend_settings = {
         self.app.conf.riak_backend_settings = {
@@ -105,5 +106,5 @@ class test_RiakBackend(AppCase):
             'host': 'there.host.com',
             'host': 'there.host.com',
             'port': '1234',
             'port': '1234',
         }
         }
-        with self.assertRaises(ValueError):
+        with pytest.raises(ValueError):
             RiakBackend(app=self.app)
             RiakBackend(app=self.app)

+ 20 - 22
celery/tests/backends/test_rpc.py → t/unit/backends/test_rpc.py

@@ -1,12 +1,14 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import Mock, patch
+
 from celery.backends.rpc import RPCBackend
 from celery.backends.rpc import RPCBackend
 from celery._state import _task_stack
 from celery._state import _task_stack
 
 
-from celery.tests.case import AppCase, Mock, patch
-
 
 
-class test_RPCBackend(AppCase):
+class test_RPCBackend:
 
 
     def setup(self):
     def setup(self):
         self.b = RPCBackend(app=self.app)
         self.b = RPCBackend(app=self.app)
@@ -14,8 +16,8 @@ class test_RPCBackend(AppCase):
     def test_oid(self):
     def test_oid(self):
         oid = self.b.oid
         oid = self.b.oid
         oid2 = self.b.oid
         oid2 = self.b.oid
-        self.assertEqual(oid, oid2)
-        self.assertEqual(oid, self.app.oid)
+        assert oid == oid2
+        assert oid == self.app.oid
 
 
     def test_interface(self):
     def test_interface(self):
         self.b.on_reply_declare('task_id')
         self.b.on_reply_declare('task_id')
@@ -24,38 +26,34 @@ class test_RPCBackend(AppCase):
         req = Mock(name='request')
         req = Mock(name='request')
         req.reply_to = 'reply_to'
         req.reply_to = 'reply_to'
         req.correlation_id = 'corid'
         req.correlation_id = 'corid'
-        self.assertTupleEqual(
-            self.b.destination_for('task_id', req),
-            ('reply_to', 'corid'),
-        )
+        assert self.b.destination_for('task_id', req) == ('reply_to', 'corid')
         task = Mock()
         task = Mock()
         _task_stack.push(task)
         _task_stack.push(task)
         try:
         try:
             task.request.reply_to = 'reply_to'
             task.request.reply_to = 'reply_to'
             task.request.correlation_id = 'corid'
             task.request.correlation_id = 'corid'
-            self.assertTupleEqual(
-                self.b.destination_for('task_id', None),
-                ('reply_to', 'corid'),
+            assert self.b.destination_for('task_id', None) == (
+                'reply_to', 'corid',
             )
             )
         finally:
         finally:
             _task_stack.pop()
             _task_stack.pop()
 
 
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             self.b.destination_for('task_id', None)
             self.b.destination_for('task_id', None)
 
 
     def test_rkey(self):
     def test_rkey(self):
-        self.assertEqual(self.b.rkey('id1'), 'id1')
+        assert self.b.rkey('id1') == 'id1'
 
 
     def test_binding(self):
     def test_binding(self):
         queue = self.b.binding
         queue = self.b.binding
-        self.assertEqual(queue.name, self.b.oid)
-        self.assertEqual(queue.exchange, self.b.exchange)
-        self.assertEqual(queue.routing_key, self.b.oid)
-        self.assertFalse(queue.durable)
-        self.assertTrue(queue.auto_delete)
+        assert queue.name == self.b.oid
+        assert queue.exchange == self.b.exchange
+        assert queue.routing_key == self.b.oid
+        assert not queue.durable
+        assert queue.auto_delete
 
 
     def test_create_binding(self):
     def test_create_binding(self):
-        self.assertEqual(self.b._create_binding('id'), self.b.binding)
+        assert self.b._create_binding('id') == self.b.binding
 
 
     def test_on_task_call(self):
     def test_on_task_call(self):
         with patch('celery.backends.rpc.maybe_declare') as md:
         with patch('celery.backends.rpc.maybe_declare') as md:
@@ -68,5 +66,5 @@ class test_RPCBackend(AppCase):
 
 
     def test_create_exchange(self):
     def test_create_exchange(self):
         ex = self.b._create_exchange('name')
         ex = self.b._create_exchange('name')
-        self.assertIsInstance(ex, self.b.Exchange)
-        self.assertEqual(ex.name, '')
+        assert isinstance(ex, self.b.Exchange)
+        assert ex.name == ''

+ 0 - 0
celery/tests/concurrency/__init__.py → t/unit/bin/__init__.py


+ 0 - 0
celery/tests/bin/celery.py → t/unit/bin/celery.py


+ 0 - 0
celery/tests/bin/proj/__init__.py → t/unit/bin/proj/__init__.py


+ 0 - 0
celery/tests/bin/proj/app.py → t/unit/bin/proj/app.py


+ 29 - 32
celery/tests/bin/test_amqp.py → t/unit/bin/test_amqp.py

@@ -1,5 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import Mock, patch
+
 from celery.bin.amqp import (
 from celery.bin.amqp import (
     AMQPAdmin,
     AMQPAdmin,
     AMQShell,
     AMQShell,
@@ -9,10 +13,8 @@ from celery.bin.amqp import (
 )
 )
 from celery.five import WhateverIO
 from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, patch
-
 
 
-class test_AMQShell(AppCase):
+class test_AMQShell:
 
 
     def setup(self):
     def setup(self):
         self.fh = WhateverIO()
         self.fh = WhateverIO()
@@ -24,54 +26,54 @@ class test_AMQShell(AppCase):
 
 
     def test_queue_declare(self):
     def test_queue_declare(self):
         self.shell.onecmd('queue.declare foo')
         self.shell.onecmd('queue.declare foo')
-        self.assertIn('ok', self.fh.getvalue())
+        assert 'ok' in self.fh.getvalue()
 
 
     def test_missing_command(self):
     def test_missing_command(self):
         self.shell.onecmd('foo foo')
         self.shell.onecmd('foo foo')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
 
     def RV(self):
     def RV(self):
         raise Exception(self.fh.getvalue())
         raise Exception(self.fh.getvalue())
 
 
     def test_spec_format_response(self):
     def test_spec_format_response(self):
         spec = self.shell.amqp['exchange.declare']
         spec = self.shell.amqp['exchange.declare']
-        self.assertEqual(spec.format_response(None), 'ok.')
-        self.assertEqual(spec.format_response('NO'), 'NO')
+        assert spec.format_response(None) == 'ok.'
+        assert spec.format_response('NO') == 'NO'
 
 
     def test_missing_namespace(self):
     def test_missing_namespace(self):
         self.shell.onecmd('ns.cmd arg')
         self.shell.onecmd('ns.cmd arg')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
 
     def test_help(self):
     def test_help(self):
         self.shell.onecmd('help')
         self.shell.onecmd('help')
-        self.assertIn('Example:', self.fh.getvalue())
+        assert 'Example:' in self.fh.getvalue()
 
 
     def test_help_command(self):
     def test_help_command(self):
         self.shell.onecmd('help queue.declare')
         self.shell.onecmd('help queue.declare')
-        self.assertIn('passive:no', self.fh.getvalue())
+        assert 'passive:no' in self.fh.getvalue()
 
 
     def test_help_unknown_command(self):
     def test_help_unknown_command(self):
         self.shell.onecmd('help foo.baz')
         self.shell.onecmd('help foo.baz')
-        self.assertIn('unknown syntax', self.fh.getvalue())
+        assert 'unknown syntax' in self.fh.getvalue()
 
 
     def test_onecmd_error(self):
     def test_onecmd_error(self):
         self.shell.dispatch = Mock()
         self.shell.dispatch = Mock()
         self.shell.dispatch.side_effect = MemoryError()
         self.shell.dispatch.side_effect = MemoryError()
         self.shell.say = Mock()
         self.shell.say = Mock()
-        self.assertFalse(self.shell.needs_reconnect)
+        assert not self.shell.needs_reconnect
         self.shell.onecmd('hello')
         self.shell.onecmd('hello')
         self.shell.say.assert_called()
         self.shell.say.assert_called()
-        self.assertTrue(self.shell.needs_reconnect)
+        assert self.shell.needs_reconnect
 
 
     def test_exit(self):
     def test_exit(self):
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             self.shell.onecmd('exit')
             self.shell.onecmd('exit')
-        self.assertIn("don't leave!", self.fh.getvalue())
+        assert "don't leave!" in self.fh.getvalue()
 
 
     def test_note_silent(self):
     def test_note_silent(self):
         self.shell.silent = True
         self.shell.silent = True
         self.shell.note('foo bar')
         self.shell.note('foo bar')
-        self.assertNotIn('foo bar', self.fh.getvalue())
+        assert 'foo bar' not in self.fh.getvalue()
 
 
     def test_reconnect(self):
     def test_reconnect(self):
         self.shell.onecmd('queue.declare foo')
         self.shell.onecmd('queue.declare foo')
@@ -79,14 +81,9 @@ class test_AMQShell(AppCase):
         self.shell.onecmd('queue.delete foo')
         self.shell.onecmd('queue.delete foo')
 
 
     def test_completenames(self):
     def test_completenames(self):
-        self.assertEqual(
-            self.shell.completenames('queue.dec'),
-            ['queue.declare'],
-        )
-        self.assertEqual(
-            sorted(self.shell.completenames('declare')),
-            sorted(['queue.declare', 'exchange.declare']),
-        )
+        assert self.shell.completenames('queue.dec') == ['queue.declare']
+        assert (sorted(self.shell.completenames('declare')) ==
+                sorted(['queue.declare', 'exchange.declare']))
 
 
     def test_empty_line(self):
     def test_empty_line(self):
         self.shell.emptyline = Mock()
         self.shell.emptyline = Mock()
@@ -98,10 +95,10 @@ class test_AMQShell(AppCase):
 
 
     def test_respond(self):
     def test_respond(self):
         self.shell.respond({'foo': 'bar'})
         self.shell.respond({'foo': 'bar'})
-        self.assertIn('foo', self.fh.getvalue())
+        assert 'foo' in self.fh.getvalue()
 
 
     def test_prompt(self):
     def test_prompt(self):
-        self.assertTrue(self.shell.prompt)
+        assert self.shell.prompt
 
 
     def test_no_returns(self):
     def test_no_returns(self):
         self.shell.onecmd('queue.declare foo')
         self.shell.onecmd('queue.declare foo')
@@ -114,20 +111,20 @@ class test_AMQShell(AppCase):
         m.body = 'the quick brown fox'
         m.body = 'the quick brown fox'
         m.properties = {'a': 1}
         m.properties = {'a': 1}
         m.delivery_info = {'exchange': 'bar'}
         m.delivery_info = {'exchange': 'bar'}
-        self.assertTrue(dump_message(m))
+        assert dump_message(m)
 
 
     def test_dump_message_no_message(self):
     def test_dump_message_no_message(self):
-        self.assertIn('No messages in queue', dump_message(None))
+        assert 'No messages in queue' in dump_message(None)
 
 
     def test_note(self):
     def test_note(self):
         self.adm.silent = True
         self.adm.silent = True
         self.adm.note('FOO')
         self.adm.note('FOO')
-        self.assertNotIn('FOO', self.fh.getvalue())
+        assert 'FOO' not in self.fh.getvalue()
 
 
     def test_run(self):
     def test_run(self):
         a = self.create_adm('queue.declare', 'foo')
         a = self.create_adm('queue.declare', 'foo')
         a.run()
         a.run()
-        self.assertIn('ok', self.fh.getvalue())
+        assert 'ok' in self.fh.getvalue()
 
 
     def test_run_loop(self):
     def test_run_loop(self):
         a = self.create_adm()
         a = self.create_adm()
@@ -139,7 +136,7 @@ class test_AMQShell(AppCase):
 
 
         shell.cmdloop.side_effect = KeyboardInterrupt()
         shell.cmdloop.side_effect = KeyboardInterrupt()
         a.run()
         a.run()
-        self.assertIn('bibi', self.fh.getvalue())
+        assert 'bibi' in self.fh.getvalue()
 
 
     @patch('celery.bin.amqp.amqp')
     @patch('celery.bin.amqp.amqp')
     def test_main(self, Command):
     def test_main(self, Command):
@@ -151,4 +148,4 @@ class test_AMQShell(AppCase):
     def test_command(self, cls):
     def test_command(self, cls):
         x = amqp(app=self.app)
         x = amqp(app=self.app)
         x.run()
         x.run()
-        self.assertIs(cls.call_args[1]['app'], self.app)
+        assert cls.call_args[1]['app'] is self.app

+ 130 - 143
celery/tests/bin/test_base.py → t/unit/bin/test_base.py

@@ -1,6 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import os
 import os
+import pytest
+
+from case import Mock, mock, patch
 
 
 from celery.bin.base import (
 from celery.bin.base import (
     Command,
     Command,
@@ -11,10 +14,6 @@ from celery.bin.base import (
 from celery.five import bytes_if_py2
 from celery.five import bytes_if_py2
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import (
-    AppCase, Mock, depends_on_current_app, mock, patch,
-)
-
 
 
 class MyApp(object):
 class MyApp(object):
     user_options = {'preload': None}
     user_options = {'preload': None}
@@ -33,7 +32,7 @@ class MockCommand(Command):
         return args, kwargs
         return args, kwargs
 
 
 
 
-class test_Extensions(AppCase):
+class test_Extensions:
 
 
     def test_load(self):
     def test_load(self):
         with patch('pkg_resources.iter_entry_points') as iterep:
         with patch('pkg_resources.iter_entry_points') as iterep:
@@ -58,28 +57,28 @@ class test_Extensions(AppCase):
 
 
             with patch('celery.utils.imports.symbol_by_name') as symbyname:
             with patch('celery.utils.imports.symbol_by_name') as symbyname:
                 symbyname.side_effect = KeyError('foo')
                 symbyname.side_effect = KeyError('foo')
-                with self.assertRaises(KeyError):
+                with pytest.raises(KeyError):
                     e.load()
                     e.load()
 
 
 
 
-class test_HelpFormatter(AppCase):
+class test_HelpFormatter:
 
 
     def test_format_epilog(self):
     def test_format_epilog(self):
         f = HelpFormatter()
         f = HelpFormatter()
-        self.assertTrue(f.format_epilog('hello'))
-        self.assertFalse(f.format_epilog(''))
+        assert f.format_epilog('hello')
+        assert not f.format_epilog('')
 
 
     def test_format_description(self):
     def test_format_description(self):
         f = HelpFormatter()
         f = HelpFormatter()
-        self.assertTrue(f.format_description('hello'))
+        assert f.format_description('hello')
 
 
 
 
-class test_Command(AppCase):
+class test_Command:
 
 
     def test_get_options(self):
     def test_get_options(self):
         cmd = Command()
         cmd = Command()
         cmd.option_list = (1, 2, 3)
         cmd.option_list = (1, 2, 3)
-        self.assertTupleEqual(cmd.get_options(), (1, 2, 3))
+        assert cmd.get_options() == (1, 2, 3)
 
 
     def test_custom_description(self):
     def test_custom_description(self):
 
 
@@ -87,12 +86,12 @@ class test_Command(AppCase):
             description = 'foo'
             description = 'foo'
 
 
         c = C()
         c = C()
-        self.assertEqual(c.description, 'foo')
+        assert c.description == 'foo'
 
 
     def test_register_callbacks(self):
     def test_register_callbacks(self):
         c = Command(on_error=8, on_usage_error=9)
         c = Command(on_error=8, on_usage_error=9)
-        self.assertEqual(c.on_error, 8)
-        self.assertEqual(c.on_usage_error, 9)
+        assert c.on_error == 8
+        assert c.on_usage_error == 9
 
 
     def test_run_raises_UsageError(self):
     def test_run_raises_UsageError(self):
         cb = Mock()
         cb = Mock()
@@ -101,7 +100,7 @@ class test_Command(AppCase):
         c.run = Mock()
         c.run = Mock()
         exc = c.run.side_effect = c.UsageError('foo', status=3)
         exc = c.run.side_effect = c.UsageError('foo', status=3)
 
 
-        self.assertEqual(c(), exc.status)
+        assert c() == exc.status
         cb.assert_called_with(exc)
         cb.assert_called_with(exc)
         c.verify_args.assert_called_with(())
         c.verify_args.assert_called_with(())
 
 
@@ -119,238 +118,226 @@ class test_Command(AppCase):
             pass
             pass
         c.run = run
         c.run = run
 
 
-        with self.assertRaises(c.UsageError):
+        with pytest.raises(c.UsageError):
             c.verify_args((1,))
             c.verify_args((1,))
         c.verify_args((1, 2, 3))
         c.verify_args((1, 2, 3))
 
 
     def test_run_interface(self):
     def test_run_interface(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             Command().run()
             Command().run()
 
 
     @patch('sys.stdout')
     @patch('sys.stdout')
     def test_early_version(self, stdout):
     def test_early_version(self, stdout):
         cmd = Command()
         cmd = Command()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             cmd.early_version(['--version'])
             cmd.early_version(['--version'])
 
 
-    def test_execute_from_commandline(self):
-        cmd = MockCommand(app=self.app)
+    def test_execute_from_commandline(self, app):
+        cmd = MockCommand(app=app)
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
         args1, kwargs1 = cmd.execute_from_commandline()     # sys.argv
-        self.assertTupleEqual(args1, cmd.mock_args)
-        self.assertDictContainsSubset({'foo': 'bar'}, kwargs1)
-        self.assertTrue(kwargs1.get('prog_name'))
+        assert args1 == cmd.mock_args
+        assert kwargs1['foo'] == 'bar'
+        assert kwargs1.get('prog_name')
         args2, kwargs2 = cmd.execute_from_commandline(['foo'])   # pass list
         args2, kwargs2 = cmd.execute_from_commandline(['foo'])   # pass list
-        self.assertTupleEqual(args2, cmd.mock_args)
-        self.assertDictContainsSubset({'foo': 'bar', 'prog_name': 'foo'},
-                                      kwargs2)
-
-    @mock.stdouts
-    def test_with_bogus_args(self, _, stderr):
-        cmd = MockCommand(app=self.app)
-        cmd.supports_args = False
-        with self.assertRaises(SystemExit):
-            cmd.execute_from_commandline(argv=['--bogus'])
-        self.assertTrue(stderr.getvalue())
-        self.assertIn('Unrecognized', stderr.getvalue())
-
-    def test_with_custom_config_module(self):
+        assert args2 == cmd.mock_args
+        assert kwargs2['foo'] == 'bar'
+        assert kwargs2['prog_name'] == 'foo'
+
+    def test_with_bogus_args(self, app):
+        with mock.stdouts() as (_, stderr):
+            cmd = MockCommand(app=app)
+            cmd.supports_args = False
+            with pytest.raises(SystemExit):
+                cmd.execute_from_commandline(argv=['--bogus'])
+            assert stderr.getvalue()
+            assert 'Unrecognized' in stderr.getvalue()
+
+    def test_with_custom_config_module(self, app):
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)
         prev = os.environ.pop('CELERY_CONFIG_MODULE', None)
         try:
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--config=foo.bar.baz'])
             cmd.setup_app_from_commandline(['--config=foo.bar.baz'])
-            self.assertEqual(os.environ.get('CELERY_CONFIG_MODULE'),
-                             'foo.bar.baz')
+            assert os.environ.get('CELERY_CONFIG_MODULE') == 'foo.bar.baz'
         finally:
         finally:
             if prev:
             if prev:
                 os.environ['CELERY_CONFIG_MODULE'] = prev
                 os.environ['CELERY_CONFIG_MODULE'] = prev
             else:
             else:
                 os.environ.pop('CELERY_CONFIG_MODULE', None)
                 os.environ.pop('CELERY_CONFIG_MODULE', None)
 
 
-    def test_with_custom_broker(self):
+    def test_with_custom_broker(self, app):
         prev = os.environ.pop('CELERY_BROKER_URL', None)
         prev = os.environ.pop('CELERY_BROKER_URL', None)
         try:
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--broker=xyzza://'])
             cmd.setup_app_from_commandline(['--broker=xyzza://'])
-            self.assertEqual(
-                os.environ.get('CELERY_BROKER_URL'), 'xyzza://',
-            )
+            assert os.environ.get('CELERY_BROKER_URL') == 'xyzza://'
         finally:
         finally:
             if prev:
             if prev:
                 os.environ['CELERY_BROKER_URL'] = prev
                 os.environ['CELERY_BROKER_URL'] = prev
             else:
             else:
                 os.environ.pop('CELERY_BROKER_URL', None)
                 os.environ.pop('CELERY_BROKER_URL', None)
 
 
-    def test_with_custom_app(self):
-        cmd = MockCommand(app=self.app)
-        app = '.'.join([__name__, 'APP'])
-        cmd.setup_app_from_commandline(['--app=%s' % (app,),
+    def test_with_custom_app(self, app):
+        cmd = MockCommand(app=app)
+        appstr = '.'.join([__name__, 'APP'])
+        cmd.setup_app_from_commandline(['--app=%s' % (appstr,),
                                         '--loglevel=INFO'])
                                         '--loglevel=INFO'])
-        self.assertIs(cmd.app, APP)
-        cmd.setup_app_from_commandline(['-A', app,
+        assert cmd.app is APP
+        cmd.setup_app_from_commandline(['-A', appstr,
                                         '--loglevel=INFO'])
                                         '--loglevel=INFO'])
-        self.assertIs(cmd.app, APP)
+        assert cmd.app is APP
 
 
-    def test_setup_app_sets_quiet(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_sets_quiet(self, app):
+        cmd = MockCommand(app=app)
         cmd.setup_app_from_commandline(['-q'])
         cmd.setup_app_from_commandline(['-q'])
-        self.assertTrue(cmd.quiet)
-        cmd2 = MockCommand(app=self.app)
+        assert cmd.quiet
+        cmd2 = MockCommand(app=app)
         cmd2.setup_app_from_commandline(['--quiet'])
         cmd2.setup_app_from_commandline(['--quiet'])
-        self.assertTrue(cmd2.quiet)
+        assert cmd2.quiet
 
 
-    def test_setup_app_sets_chdir(self):
+    def test_setup_app_sets_chdir(self, app):
         with patch('os.chdir') as chdir:
         with patch('os.chdir') as chdir:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--workdir=/opt'])
             cmd.setup_app_from_commandline(['--workdir=/opt'])
             chdir.assert_called_with('/opt')
             chdir.assert_called_with('/opt')
 
 
-    def test_setup_app_sets_loader(self):
+    def test_setup_app_sets_loader(self, app):
         prev = os.environ.get('CELERY_LOADER')
         prev = os.environ.get('CELERY_LOADER')
         try:
         try:
-            cmd = MockCommand(app=self.app)
+            cmd = MockCommand(app=app)
             cmd.setup_app_from_commandline(['--loader=X.Y:Z'])
             cmd.setup_app_from_commandline(['--loader=X.Y:Z'])
-            self.assertEqual(os.environ['CELERY_LOADER'], 'X.Y:Z')
+            assert os.environ['CELERY_LOADER'] == 'X.Y:Z'
         finally:
         finally:
             if prev is not None:
             if prev is not None:
                 os.environ['CELERY_LOADER'] = prev
                 os.environ['CELERY_LOADER'] = prev
 
 
-    def test_setup_app_no_respect(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_no_respect(self, app):
+        cmd = MockCommand(app=app)
         cmd.respects_app_option = False
         cmd.respects_app_option = False
         with patch('celery.bin.base.Celery') as cp:
         with patch('celery.bin.base.Celery') as cp:
             cmd.setup_app_from_commandline(['--app=x.y:z'])
             cmd.setup_app_from_commandline(['--app=x.y:z'])
             cp.assert_called()
             cp.assert_called()
 
 
-    def test_setup_app_custom_app(self):
-        cmd = MockCommand(app=self.app)
+    def test_setup_app_custom_app(self, app):
+        cmd = MockCommand(app=app)
         app = cmd.app = Mock()
         app = cmd.app = Mock()
         app.user_options = {'preload': None}
         app.user_options = {'preload': None}
         cmd.setup_app_from_commandline([])
         cmd.setup_app_from_commandline([])
-        self.assertEqual(cmd.app, app)
-
-    def test_find_app_suspects(self):
-        cmd = MockCommand(app=self.app)
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj:hello'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.hello'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app:app'))
-        self.assertTrue(cmd.find_app('celery.tests.bin.proj.app.app'))
-        with self.assertRaises(AttributeError):
-            cmd.find_app('celery.tests.bin')
-
-        with self.assertRaises(AttributeError):
+        assert cmd.app == app
+
+    def test_find_app_suspects(self, app):
+        cmd = MockCommand(app=app)
+        assert cmd.find_app('t.unit.bin.proj.app')
+        assert cmd.find_app('t.unit.bin.proj')
+        assert cmd.find_app('t.unit.bin.proj:hello')
+        assert cmd.find_app('t.unit.bin.proj.hello')
+        assert cmd.find_app('t.unit.bin.proj.app:app')
+        assert cmd.find_app('t.unit.bin.proj.app.app')
+        with pytest.raises(AttributeError):
+            cmd.find_app('t.unit.bin')
+
+        with pytest.raises(AttributeError):
             cmd.find_app(__name__)
             cmd.find_app(__name__)
 
 
-    def test_ask(self):
+    def test_ask(self, app, patching):
         try:
         try:
-            input = self.patch('celery.bin.base.input')
+            input = patching('celery.bin.base.input')
         except AttributeError:
         except AttributeError:
-            input = self.patch('builtins.input')
-        cmd = MockCommand(app=self.app)
+            input = patching('builtins.input')
+        cmd = MockCommand(app=app)
         input.return_value = 'yes'
         input.return_value = 'yes'
-        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'yes')
+        assert cmd.ask('q', ('yes', 'no'), 'no') == 'yes'
         input.return_value = 'nop'
         input.return_value = 'nop'
-        self.assertEqual(cmd.ask('q', ('yes', 'no'), 'no'), 'no')
+        assert cmd.ask('q', ('yes', 'no'), 'no') == 'no'
 
 
-    def test_host_format(self):
-        cmd = MockCommand(app=self.app)
+    def test_host_format(self, app):
+        cmd = MockCommand(app=app)
         with patch('celery.utils.nodenames.gethostname') as hn:
         with patch('celery.utils.nodenames.gethostname') as hn:
             hn.return_value = 'blacktron.example.com'
             hn.return_value = 'blacktron.example.com'
-            self.assertEqual(cmd.host_format(''), '')
-            self.assertEqual(
-                cmd.host_format('celery@%h'),
-                'celery@blacktron.example.com',
-            )
-            self.assertEqual(
-                cmd.host_format('celery@%d'),
-                'celery@example.com',
-            )
-            self.assertEqual(
-                cmd.host_format('celery@%n'),
-                'celery@blacktron',
-            )
-
-    def test_say_chat_quiet(self):
-        cmd = MockCommand(app=self.app)
+            assert cmd.host_format('') == ''
+            assert (cmd.host_format('celery@%h') ==
+                    'celery@blacktron.example.com')
+            assert cmd.host_format('celery@%d') == 'celery@example.com'
+            assert cmd.host_format('celery@%n') == 'celery@blacktron'
+
+    def test_say_chat_quiet(self, app):
+        cmd = MockCommand(app=app)
         cmd.quiet = True
         cmd.quiet = True
-        self.assertIsNone(cmd.say_chat('<-', 'foo', 'foo'))
+        assert cmd.say_chat('<-', 'foo', 'foo') is None
 
 
-    def test_say_chat_show_body(self):
-        cmd = MockCommand(app=self.app)
+    def test_say_chat_show_body(self, app):
+        cmd = MockCommand(app=app)
         cmd.out = Mock()
         cmd.out = Mock()
         cmd.show_body = True
         cmd.show_body = True
         cmd.say_chat('->', 'foo', 'body')
         cmd.say_chat('->', 'foo', 'body')
         cmd.out.assert_called_with('body')
         cmd.out.assert_called_with('body')
 
 
-    def test_say_chat_no_body(self):
-        cmd = MockCommand(app=self.app)
+    def test_say_chat_no_body(self, app):
+        cmd = MockCommand(app=app)
         cmd.out = Mock()
         cmd.out = Mock()
         cmd.show_body = False
         cmd.show_body = False
         cmd.say_chat('->', 'foo', 'body')
         cmd.say_chat('->', 'foo', 'body')
 
 
-    @depends_on_current_app
-    def test_with_cmdline_config(self):
-        cmd = MockCommand(app=self.app)
+    @pytest.mark.usefixtures('depends_on_current_app')
+    def test_with_cmdline_config(self, app):
+        cmd = MockCommand(app=app)
         cmd.enable_config_from_cmdline = True
         cmd.enable_config_from_cmdline = True
         cmd.namespace = 'worker'
         cmd.namespace = 'worker'
         rest = cmd.setup_app_from_commandline(argv=[
         rest = cmd.setup_app_from_commandline(argv=[
             '--loglevel=INFO', '--',
             '--loglevel=INFO', '--',
             'broker.url=amqp://broker.example.com',
             'broker.url=amqp://broker.example.com',
             '.prefetch_multiplier=100'])
             '.prefetch_multiplier=100'])
-        self.assertEqual(cmd.app.conf.broker_url,
-                         'amqp://broker.example.com')
-        self.assertEqual(cmd.app.conf.worker_prefetch_multiplier, 100)
-        self.assertListEqual(rest, ['--loglevel=INFO'])
+        assert cmd.app.conf.broker_url == 'amqp://broker.example.com'
+        assert cmd.app.conf.worker_prefetch_multiplier == 100
+        assert rest == ['--loglevel=INFO']
 
 
         cmd.app = None
         cmd.app = None
         cmd.get_app = Mock(name='get_app')
         cmd.get_app = Mock(name='get_app')
-        cmd.get_app.return_value = self.app
-        self.app.user_options['preload'] = [
+        cmd.get_app.return_value = app
+        app.user_options['preload'] = [
             Option('--foo', action='store_true'),
             Option('--foo', action='store_true'),
         ]
         ]
         cmd.setup_app_from_commandline(argv=[
         cmd.setup_app_from_commandline(argv=[
             '--foo', '--loglevel=INFO', '--',
             '--foo', '--loglevel=INFO', '--',
             'broker.url=amqp://broker.example.com',
             'broker.url=amqp://broker.example.com',
             '.prefetch_multiplier=100'])
             '.prefetch_multiplier=100'])
-        self.assertIs(cmd.app, cmd.get_app())
+        assert cmd.app is cmd.get_app()
 
 
-    def test_preparse_options__required_short(self):
-        cmd = MockCommand(app=self.app)
-        with self.assertRaises(ValueError):
+    def test_preparse_options__required_short(self, app):
+        cmd = MockCommand(app=app)
+        with pytest.raises(ValueError):
             cmd.preparse_options(
             cmd.preparse_options(
                 ['a', '-f'], [Option('-f', action='store')])
                 ['a', '-f'], [Option('-f', action='store')])
 
 
-    def test_preparse_options__longopt_whitespace(self):
-        cmd = MockCommand(app=self.app)
+    def test_preparse_options__longopt_whitespace(self, app):
+        cmd = MockCommand(app=app)
         cmd.preparse_options(
         cmd.preparse_options(
             ['a', '--foo', 'val'], [Option('--foo', action='store')])
             ['a', '--foo', 'val'], [Option('--foo', action='store')])
 
 
-    def test_preparse_options__shortopt_store_true(self):
-        cmd = MockCommand(app=self.app)
+    def test_preparse_options__shortopt_store_true(self, app):
+        cmd = MockCommand(app=app)
         cmd.preparse_options(
         cmd.preparse_options(
             ['a', '--foo'], [Option('--foo', action='store_true')])
             ['a', '--foo'], [Option('--foo', action='store_true')])
 
 
-    def test_get_default_app(self):
-        self.patch('celery._state.get_current_app')
-        cmd = MockCommand(app=self.app)
+    def test_get_default_app(self, app, patching):
+        patching('celery._state.get_current_app')
+        cmd = MockCommand(app=app)
         from celery._state import get_current_app
         from celery._state import get_current_app
-        self.assertIs(cmd._get_default_app(), get_current_app())
+        assert cmd._get_default_app() is get_current_app()
 
 
-    def test_set_colored(self):
-        cmd = MockCommand(app=self.app)
+    def test_set_colored(self, app):
+        cmd = MockCommand(app=app)
         cmd.colored = 'foo'
         cmd.colored = 'foo'
-        self.assertEqual(cmd.colored, 'foo')
+        assert cmd.colored == 'foo'
 
 
-    def test_set_no_color(self):
-        cmd = MockCommand(app=self.app)
+    def test_set_no_color(self, app):
+        cmd = MockCommand(app=app)
         cmd.no_color = False
         cmd.no_color = False
         _ = cmd.colored  # noqa
         _ = cmd.colored  # noqa
         cmd.no_color = True
         cmd.no_color = True
-        self.assertFalse(cmd.colored.enabled)
+        assert not cmd.colored.enabled
 
 
-    def test_find_app(self):
-        cmd = MockCommand(app=self.app)
+    def test_find_app(self, app):
+        cmd = MockCommand(app=app)
         with patch('celery.utils.imports.symbol_by_name') as sbn:
         with patch('celery.utils.imports.symbol_by_name') as sbn:
             from types import ModuleType
             from types import ModuleType
             x = ModuleType(bytes_if_py2('proj'))
             x = ModuleType(bytes_if_py2('proj'))
@@ -365,13 +352,13 @@ class test_Command(AppCase):
                 return x
                 return x
             sbn.side_effect = on_sbn
             sbn.side_effect = on_sbn
             x.__path__ = [True]
             x.__path__ = [True]
-            self.assertEqual(cmd.find_app('proj'), 'quick brown fox')
+            assert cmd.find_app('proj') == 'quick brown fox'
 
 
     def test_parse_preload_options_shortopt(self):
     def test_parse_preload_options_shortopt(self):
         cmd = Command()
         cmd = Command()
         cmd.preload_options = (Option('-s', action='store', dest='silent'),)
         cmd.preload_options = (Option('-s', action='store', dest='silent'),)
         acc = cmd.parse_preload_options(['-s', 'yes'])
         acc = cmd.parse_preload_options(['-s', 'yes'])
-        self.assertEqual(acc.get('silent'), 'yes')
+        assert acc.get('silent') == 'yes'
 
 
     def test_parse_preload_options_with_equals_and_append(self):
     def test_parse_preload_options_with_equals_and_append(self):
         cmd = Command()
         cmd = Command()
@@ -379,7 +366,7 @@ class test_Command(AppCase):
         cmd.preload_options = (opt,)
         cmd.preload_options = (opt,)
         acc = cmd.parse_preload_options(['--zoom=1', '--zoom=2'])
         acc = cmd.parse_preload_options(['--zoom=1', '--zoom=2'])
 
 
-        self.assertEqual(acc, {'zoom': ['1', '2']})
+        assert acc, {'zoom': ['1' == '2']}
 
 
     def test_parse_preload_options_without_equals_and_append(self):
     def test_parse_preload_options_without_equals_and_append(self):
         cmd = Command()
         cmd = Command()
@@ -387,4 +374,4 @@ class test_Command(AppCase):
         cmd.preload_options = (opt,)
         cmd.preload_options = (opt,)
         acc = cmd.parse_preload_options(['--zoom', '1', '--zoom', '2'])
         acc = cmd.parse_preload_options(['--zoom', '1', '--zoom', '2'])
 
 
-        self.assertEqual(acc, {'zoom': ['1', '2']})
+        assert acc, {'zoom': ['1' == '2']}

+ 20 - 20
celery/tests/bin/test_beat.py → t/unit/bin/test_beat.py

@@ -1,15 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import logging
 import logging
+import pytest
 import sys
 import sys
 
 
+from case import Mock, mock, patch
+
 from celery import beat
 from celery import beat
 from celery import platforms
 from celery import platforms
 from celery.bin import beat as beat_bin
 from celery.bin import beat as beat_bin
 from celery.apps import beat as beatapp
 from celery.apps import beat as beatapp
 
 
-from celery.tests.case import AppCase, Mock, mock, patch
-
 
 
 def MockBeat(*args, **kwargs):
 def MockBeat(*args, **kwargs):
     class _Beat(beatapp.Beat):
     class _Beat(beatapp.Beat):
@@ -23,16 +24,16 @@ def MockBeat(*args, **kwargs):
     return b
     return b
 
 
 
 
-class test_Beat(AppCase):
+class test_Beat:
 
 
     def test_loglevel_string(self):
     def test_loglevel_string(self):
         b = beatapp.Beat(app=self.app, loglevel='DEBUG',
         b = beatapp.Beat(app=self.app, loglevel='DEBUG',
                          redirect_stdouts=False)
                          redirect_stdouts=False)
-        self.assertEqual(b.loglevel, logging.DEBUG)
+        assert b.loglevel == logging.DEBUG
 
 
         b2 = beatapp.Beat(app=self.app, loglevel=logging.DEBUG,
         b2 = beatapp.Beat(app=self.app, loglevel=logging.DEBUG,
                           redirect_stdouts=False)
                           redirect_stdouts=False)
-        self.assertEqual(b2.loglevel, logging.DEBUG)
+        assert b2.loglevel == logging.DEBUG
 
 
     def test_colorize(self):
     def test_colorize(self):
         self.app.log.setup = Mock()
         self.app.log.setup = Mock()
@@ -40,7 +41,7 @@ class test_Beat(AppCase):
                          redirect_stdouts=False)
                          redirect_stdouts=False)
         b.setup_logging()
         b.setup_logging()
         self.app.log.setup.assert_called()
         self.app.log.setup.assert_called()
-        self.assertEqual(self.app.log.setup.call_args[1]['colorize'], False)
+        assert not self.app.log.setup.call_args[1]['colorize']
 
 
     def test_init_loader(self):
     def test_init_loader(self):
         b = beatapp.Beat(app=self.app, redirect_stdouts=False)
         b = beatapp.Beat(app=self.app, redirect_stdouts=False)
@@ -78,7 +79,7 @@ class test_Beat(AppCase):
         clock.start = Mock(name='beat.Service().start')
         clock.start = Mock(name='beat.Service().start')
         clock.sync = Mock(name='beat.Service().sync')
         clock.sync = Mock(name='beat.Service().sync')
         handlers = self.psig(b.install_sync_handler, clock)
         handlers = self.psig(b.install_sync_handler, clock)
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             handlers['SIGINT']('SIGINT', object())
             handlers['SIGINT']('SIGINT', object())
         clock.sync.assert_called_with()
         clock.sync.assert_called_with()
 
 
@@ -93,40 +94,39 @@ class test_Beat(AppCase):
         b.redirect_stdouts = False
         b.redirect_stdouts = False
         b.app.log.already_setup = False
         b.app.log.already_setup = False
         b.setup_logging()
         b.setup_logging()
-        with self.assertRaises(AttributeError):
+        with pytest.raises(AttributeError):
             sys.stdout.logger
             sys.stdout.logger
 
 
     import sys
     import sys
     orig_stdout = sys.__stdout__
     orig_stdout = sys.__stdout__
 
 
     @patch('celery.apps.beat.logger')
     @patch('celery.apps.beat.logger')
-    @mock.restore_logging()
-    @mock.stdouts
-    def test_logs_errors(self, logger, stdout, stderr):
+    def test_logs_errors(self, logger):
         b = MockBeat(
         b = MockBeat(
             app=self.app, redirect_stdouts=False, socket_timeout=None,
             app=self.app, redirect_stdouts=False, socket_timeout=None,
         )
         )
         b.install_sync_handler = Mock('beat.install_sync_handler')
         b.install_sync_handler = Mock('beat.install_sync_handler')
         b.install_sync_handler.side_effect = RuntimeError('xxx')
         b.install_sync_handler.side_effect = RuntimeError('xxx')
-        with self.assertRaises(RuntimeError):
-            b.start_scheduler()
+        with mock.restore_logging():
+            with pytest.raises(RuntimeError):
+                b.start_scheduler()
         logger.critical.assert_called()
         logger.critical.assert_called()
 
 
     @patch('celery.platforms.create_pidlock')
     @patch('celery.platforms.create_pidlock')
-    @mock.stdouts
-    def test_using_pidfile(self, create_pidlock, stdout, stderr):
+    def test_using_pidfile(self, create_pidlock):
         b = MockBeat(app=self.app, pidfile='pidfilelockfilepid',
         b = MockBeat(app=self.app, pidfile='pidfilelockfilepid',
                      socket_timeout=None, redirect_stdouts=False)
                      socket_timeout=None, redirect_stdouts=False)
         b.install_sync_handler = Mock(name='beat.install_sync_handler')
         b.install_sync_handler = Mock(name='beat.install_sync_handler')
-        b.start_scheduler()
+        with mock.stdouts():
+            b.start_scheduler()
         create_pidlock.assert_called()
         create_pidlock.assert_called()
 
 
 
 
-class test_div(AppCase):
+class test_div:
 
 
     def setup(self):
     def setup(self):
-        self.Beat = self.app.Beat = self.patch('celery.apps.beat.Beat')
-        self.detached = self.patch('celery.bin.beat.detached')
+        self.Beat = self.app.Beat = self.patching('celery.apps.beat.Beat')
+        self.detached = self.patching('celery.bin.beat.detached')
         self.Beat.__name__ = 'Beat'
         self.Beat.__name__ = 'Beat'
 
 
     def test_main(self):
     def test_main(self):
@@ -144,4 +144,4 @@ class test_div(AppCase):
         cmd = beat_bin.beat()
         cmd = beat_bin.beat()
         cmd.app = self.app
         cmd.app = self.app
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
         options, args = cmd.parse_options('celery beat', ['-s', 'foo'])
-        self.assertEqual(options.schedule, 'foo')
+        assert options.schedule == 'foo'

+ 104 - 116
celery/tests/bin/test_celery.py → t/unit/bin/test_celery.py

@@ -1,9 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
 import sys
 import sys
 
 
 from datetime import datetime
 from datetime import datetime
 
 
+from case import Mock, patch
 from kombu.utils.json import dumps
 from kombu.utils.json import dumps
 
 
 from celery import __main__
 from celery import __main__
@@ -31,10 +33,8 @@ from celery.bin.celery import (
 from celery.five import WhateverIO
 from celery.five import WhateverIO
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 from celery.platforms import EX_FAILURE, EX_USAGE, EX_OK
 
 
-from celery.tests.case import AppCase, Mock, patch
 
 
-
-class test__main__(AppCase):
+class test__main__:
 
 
     def test_main(self):
     def test_main(self):
         with patch('celery.__main__.maybe_patch_concurrency') as mpc:
         with patch('celery.__main__.maybe_patch_concurrency') as mpc:
@@ -55,13 +55,13 @@ class test__main__(AppCase):
                     sys.argv = prev
                     sys.argv = prev
 
 
 
 
-class test_Command(AppCase):
+class test_Command:
 
 
     def test_Error_repr(self):
     def test_Error_repr(self):
         x = Error('something happened')
         x = Error('something happened')
-        self.assertIsNotNone(x.status)
-        self.assertTrue(x.reason)
-        self.assertTrue(str(x))
+        assert x.status is not None
+        assert x.reason
+        assert str(x)
 
 
     def setup(self):
     def setup(self):
         self.out = WhateverIO()
         self.out = WhateverIO()
@@ -83,58 +83,52 @@ class test_Command(AppCase):
             pass
             pass
 
 
         self.cmd.run = ok_run
         self.cmd.run = ok_run
-        self.assertEqual(self.cmd(), EX_OK)
+        assert self.cmd() == EX_OK
 
 
         def error_run():
         def error_run():
             raise Error('error', EX_FAILURE)
             raise Error('error', EX_FAILURE)
         self.cmd.run = error_run
         self.cmd.run = error_run
-        self.assertEqual(self.cmd(), EX_FAILURE)
+        assert self.cmd() == EX_FAILURE
 
 
     def test_run_from_argv(self):
     def test_run_from_argv(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             self.cmd.run_from_argv('prog', ['foo', 'bar'])
             self.cmd.run_from_argv('prog', ['foo', 'bar'])
 
 
     def test_pretty_list(self):
     def test_pretty_list(self):
-        self.assertEqual(self.cmd.pretty([])[1], '- empty -')
-        self.assertIn('bar', self.cmd.pretty(['foo', 'bar'])[1])
+        assert self.cmd.pretty([])[1] == '- empty -'
+        assert 'bar', self.cmd.pretty(['foo' in 'bar'][1])
 
 
-    def test_pretty_dict(self):
-        self.assertIn(
-            'OK',
-            str(self.cmd.pretty({'ok': 'the quick brown fox'})[0]),
-        )
-        self.assertIn(
-            'ERROR',
-            str(self.cmd.pretty({'error': 'the quick brown fox'})[0]),
-        )
+    def test_pretty_dict(self, text='the quick brown fox'):
+        assert 'OK' in str(self.cmd.pretty({'ok': text})[0])
+        assert 'ERROR' in str(self.cmd.pretty({'error': text})[0])
 
 
     def test_pretty(self):
     def test_pretty(self):
-        self.assertIn('OK', str(self.cmd.pretty('the quick brown')))
-        self.assertIn('OK', str(self.cmd.pretty(object())))
-        self.assertIn('OK', str(self.cmd.pretty({'foo': 'bar'})))
+        assert 'OK' in str(self.cmd.pretty('the quick brown'))
+        assert 'OK' in str(self.cmd.pretty(object()))
+        assert 'OK' in str(self.cmd.pretty({'foo': 'bar'}))
 
 
 
 
-class test_list(AppCase):
+class test_list:
 
 
     def test_list_bindings_no_support(self):
     def test_list_bindings_no_support(self):
         l = list_(app=self.app, stderr=WhateverIO())
         l = list_(app=self.app, stderr=WhateverIO())
         management = Mock()
         management = Mock()
         management.get_bindings.side_effect = NotImplementedError()
         management.get_bindings.side_effect = NotImplementedError()
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.list_bindings(management)
             l.list_bindings(management)
 
 
     def test_run(self):
     def test_run(self):
         l = list_(app=self.app, stderr=WhateverIO())
         l = list_(app=self.app, stderr=WhateverIO())
         l.run('bindings')
         l.run('bindings')
 
 
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.run(None)
             l.run(None)
 
 
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             l.run('foo')
             l.run('foo')
 
 
 
 
-class test_call(AppCase):
+class test_call:
 
 
     def setup(self):
     def setup(self):
 
 
@@ -152,22 +146,22 @@ class test_call(AppCase):
         a.run(self.add.name,
         a.run(self.add.name,
               args=dumps([4, 4]),
               args=dumps([4, 4]),
               kwargs=dumps({'x': 2, 'y': 2}))
               kwargs=dumps({'x': 2, 'y': 2}))
-        self.assertEqual(send_task.call_args[1]['args'], [4, 4])
-        self.assertEqual(send_task.call_args[1]['kwargs'], {'x': 2, 'y': 2})
+        assert send_task.call_args[1]['args'], [4 == 4]
+        assert send_task.call_args[1]['kwargs'] == {'x': 2, 'y': 2}
 
 
         a.run(self.add.name, expires=10, countdown=10)
         a.run(self.add.name, expires=10, countdown=10)
-        self.assertEqual(send_task.call_args[1]['expires'], 10)
-        self.assertEqual(send_task.call_args[1]['countdown'], 10)
+        assert send_task.call_args[1]['expires'] == 10
+        assert send_task.call_args[1]['countdown'] == 10
 
 
         now = datetime.now()
         now = datetime.now()
         iso = now.isoformat()
         iso = now.isoformat()
         a.run(self.add.name, expires=iso)
         a.run(self.add.name, expires=iso)
-        self.assertEqual(send_task.call_args[1]['expires'], now)
-        with self.assertRaises(ValueError):
+        assert send_task.call_args[1]['expires'] == now
+        with pytest.raises(ValueError):
             a.run(self.add.name, expires='foobaribazibar')
             a.run(self.add.name, expires='foobaribazibar')
 
 
 
 
-class test_purge(AppCase):
+class test_purge:
 
 
     def test_run(self):
     def test_run(self):
         out = WhateverIO()
         out = WhateverIO()
@@ -175,11 +169,11 @@ class test_purge(AppCase):
         a._purge = Mock(name='_purge')
         a._purge = Mock(name='_purge')
         a._purge.return_value = 0
         a._purge.return_value = 0
         a.run(force=True)
         a.run(force=True)
-        self.assertIn('No messages purged', out.getvalue())
+        assert 'No messages purged' in out.getvalue()
 
 
         a._purge.return_value = 100
         a._purge.return_value = 100
         a.run(force=True)
         a.run(force=True)
-        self.assertIn('100 messages', out.getvalue())
+        assert '100 messages' in out.getvalue()
 
 
         a.out = Mock(name='out')
         a.out = Mock(name='out')
         a.ask = Mock(name='ask')
         a.ask = Mock(name='ask')
@@ -189,7 +183,7 @@ class test_purge(AppCase):
         a.run(force=False)
         a.run(force=False)
 
 
 
 
-class test_result(AppCase):
+class test_result:
 
 
     def setup(self):
     def setup(self):
 
 
@@ -204,18 +198,18 @@ class test_result(AppCase):
             r = result(app=self.app, stdout=out)
             r = result(app=self.app, stdout=out)
             get.return_value = 'Jerry'
             get.return_value = 'Jerry'
             r.run('id')
             r.run('id')
-            self.assertIn('Jerry', out.getvalue())
+            assert 'Jerry' in out.getvalue()
 
 
             get.return_value = 'Elaine'
             get.return_value = 'Elaine'
             r.run('id', task=self.add.name)
             r.run('id', task=self.add.name)
-            self.assertIn('Elaine', out.getvalue())
+            assert 'Elaine' in out.getvalue()
 
 
             with patch('celery.result.AsyncResult.traceback') as tb:
             with patch('celery.result.AsyncResult.traceback') as tb:
                 r.run('id', task=self.add.name, traceback=True)
                 r.run('id', task=self.add.name, traceback=True)
-                self.assertIn(str(tb), out.getvalue())
+                assert str(tb) in out.getvalue()
 
 
 
 
-class test_status(AppCase):
+class test_status:
 
 
     @patch('celery.bin.celery.inspect')
     @patch('celery.bin.celery.inspect')
     def test_run(self, inspect_):
     def test_run(self, inspect_):
@@ -223,22 +217,22 @@ class test_status(AppCase):
         ins = inspect_.return_value = Mock()
         ins = inspect_.return_value = Mock()
         ins.run.return_value = []
         ins.run.return_value = []
         s = status(self.app, stdout=out, stderr=err)
         s = status(self.app, stdout=out, stderr=err)
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             s.run()
             s.run()
 
 
         ins.run.return_value = ['a', 'b', 'c']
         ins.run.return_value = ['a', 'b', 'c']
         s.run()
         s.run()
-        self.assertIn('3 nodes online', out.getvalue())
+        assert '3 nodes online' in out.getvalue()
         s.run(quiet=True)
         s.run(quiet=True)
 
 
 
 
-class test_migrate(AppCase):
+class test_migrate:
 
 
     @patch('celery.contrib.migrate.migrate_tasks')
     @patch('celery.contrib.migrate.migrate_tasks')
     def test_run(self, migrate_tasks):
     def test_run(self, migrate_tasks):
         out = WhateverIO()
         out = WhateverIO()
         m = migrate(app=self.app, stdout=out, stderr=WhateverIO())
         m = migrate(app=self.app, stdout=out, stderr=WhateverIO())
-        with self.assertRaises(TypeError):
+        with pytest.raises(TypeError):
             m.run()
             m.run()
         migrate_tasks.assert_not_called()
         migrate_tasks.assert_not_called()
 
 
@@ -249,61 +243,61 @@ class test_migrate(AppCase):
         state.count = 10
         state.count = 10
         state.strtotal = 30
         state.strtotal = 30
         m.on_migrate_task(state, {'task': 'tasks.add', 'id': 'ID'}, None)
         m.on_migrate_task(state, {'task': 'tasks.add', 'id': 'ID'}, None)
-        self.assertIn('10/30', out.getvalue())
+        assert '10/30' in out.getvalue()
 
 
 
 
-class test_report(AppCase):
+class test_report:
 
 
     def test_run(self):
     def test_run(self):
         out = WhateverIO()
         out = WhateverIO()
         r = report(app=self.app, stdout=out)
         r = report(app=self.app, stdout=out)
-        self.assertEqual(r.run(), EX_OK)
-        self.assertTrue(out.getvalue())
+        assert r.run() == EX_OK
+        assert out.getvalue()
 
 
 
 
-class test_help(AppCase):
+class test_help:
 
 
     def test_run(self):
     def test_run(self):
         out = WhateverIO()
         out = WhateverIO()
         h = help(app=self.app, stdout=out)
         h = help(app=self.app, stdout=out)
         h.parser = Mock()
         h.parser = Mock()
-        self.assertEqual(h.run(), EX_USAGE)
-        self.assertTrue(out.getvalue())
-        self.assertTrue(h.usage('help'))
+        assert h.run() == EX_USAGE
+        assert out.getvalue()
+        assert h.usage('help')
         h.parser.print_help.assert_called_with()
         h.parser.print_help.assert_called_with()
 
 
 
 
-class test_CeleryCommand(AppCase):
+class test_CeleryCommand:
 
 
     def test_execute_from_commandline(self):
     def test_execute_from_commandline(self):
         x = CeleryCommand(app=self.app)
         x = CeleryCommand(app=self.app)
         x.handle_argv = Mock()
         x.handle_argv = Mock()
         x.handle_argv.return_value = 1
         x.handle_argv.return_value = 1
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
             x.execute_from_commandline()
 
 
         x.handle_argv.return_value = True
         x.handle_argv.return_value = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
             x.execute_from_commandline()
 
 
         x.handle_argv.side_effect = KeyboardInterrupt()
         x.handle_argv.side_effect = KeyboardInterrupt()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline()
             x.execute_from_commandline()
 
 
         x.respects_app_option = True
         x.respects_app_option = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline(['celery', 'multi'])
             x.execute_from_commandline(['celery', 'multi'])
-        self.assertFalse(x.respects_app_option)
+        assert not x.respects_app_option
         x.respects_app_option = True
         x.respects_app_option = True
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             x.execute_from_commandline(['manage.py', 'celery', 'multi'])
             x.execute_from_commandline(['manage.py', 'celery', 'multi'])
-        self.assertFalse(x.respects_app_option)
+        assert not x.respects_app_option
 
 
     def test_with_pool_option(self):
     def test_with_pool_option(self):
         x = CeleryCommand(app=self.app)
         x = CeleryCommand(app=self.app)
-        self.assertIsNone(x.with_pool_option(['celery', 'events']))
-        self.assertTrue(x.with_pool_option(['celery', 'worker']))
-        self.assertTrue(x.with_pool_option(['manage.py', 'celery', 'worker']))
+        assert x.with_pool_option(['celery', 'events']) is None
+        assert x.with_pool_option(['celery', 'worker'])
+        assert x.with_pool_option(['manage.py', 'celery', 'worker'])
 
 
     def test_load_extensions_no_commands(self):
     def test_load_extensions_no_commands(self):
         with patch('celery.bin.celery.Extensions') as Ext:
         with patch('celery.bin.celery.Extensions') as Ext:
@@ -327,35 +321,33 @@ class test_CeleryCommand(AppCase):
                 mod.command_classes = prev
                 mod.command_classes = prev
 
 
     def test_determine_exit_status(self):
     def test_determine_exit_status(self):
-        self.assertEqual(determine_exit_status('true'), EX_OK)
-        self.assertEqual(determine_exit_status(''), EX_FAILURE)
+        assert determine_exit_status('true') == EX_OK
+        assert determine_exit_status('') == EX_FAILURE
 
 
     def test_relocate_args_from_start(self):
     def test_relocate_args_from_start(self):
         x = CeleryCommand(app=self.app)
         x = CeleryCommand(app=self.app)
-        self.assertEqual(x._relocate_args_from_start(None), [])
-        self.assertEqual(
-            x._relocate_args_from_start(
-                ['-l', 'debug', 'worker', '-c', '3', '--foo'],
-            ),
-            ['worker', '-c', '3', '--foo', '-l', 'debug'],
-        )
-        self.assertEqual(
-            x._relocate_args_from_start(
-                ['--pool=gevent', '-l', 'debug', 'worker', '--foo', '-c', '3'],
-            ),
-            ['worker', '--foo', '-c', '3', '--pool=gevent', '-l', 'debug'],
-        )
-        self.assertEqual(
-            x._relocate_args_from_start(['foo', '--foo=1']),
-            ['foo', '--foo=1'],
-        )
+        assert x._relocate_args_from_start(None) == []
+        relargs1 = x._relocate_args_from_start([
+            '-l', 'debug', 'worker', '-c', '3', '--foo',
+        ])
+        assert relargs1 == ['worker', '-c', '3', '--foo', '-l', 'debug']
+        relargs2 = x._relocate_args_from_start([
+            '--pool=gevent', '-l', 'debug', 'worker', '--foo', '-c', '3',
+        ])
+        assert relargs2 == [
+            'worker', '--foo', '-c', '3',
+            '--pool=gevent', '-l', 'debug',
+        ]
+        assert x._relocate_args_from_start(['foo', '--foo=1']) == [
+            'foo', '--foo=1',
+        ]
 
 
     def test_register_command(self):
     def test_register_command(self):
         prev, CeleryCommand.commands = dict(CeleryCommand.commands), {}
         prev, CeleryCommand.commands = dict(CeleryCommand.commands), {}
         try:
         try:
             fun = Mock(name='fun')
             fun = Mock(name='fun')
             CeleryCommand.register_command(fun, name='foo')
             CeleryCommand.register_command(fun, name='foo')
-            self.assertIs(CeleryCommand.commands['foo'], fun)
+            assert CeleryCommand.commands['foo'] is fun
         finally:
         finally:
             CeleryCommand.commands = prev
             CeleryCommand.commands = prev
 
 
@@ -411,46 +403,42 @@ class test_CeleryCommand(AppCase):
         main = Mock(name='__main__')
         main = Mock(name='__main__')
         main.__file__ = '/opt/foo.py'
         main.__file__ = '/opt/foo.py'
         with patch.dict(sys.modules, __main__=main):
         with patch.dict(sys.modules, __main__=main):
-            self.assertEqual(x.prepare_prog_name('__main__.py'), '/opt/foo.py')
-            self.assertEqual(x.prepare_prog_name('celery'), 'celery')
+            assert x.prepare_prog_name('__main__.py') == '/opt/foo.py'
+            assert x.prepare_prog_name('celery') == 'celery'
 
 
 
 
-class test_RemoteControl(AppCase):
+class test_RemoteControl:
 
 
     def test_call_interface(self):
     def test_call_interface(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             _RemoteControl(app=self.app).call()
             _RemoteControl(app=self.app).call()
 
 
 
 
-class test_inspect(AppCase):
+class test_inspect:
 
 
     def test_usage(self):
     def test_usage(self):
-        self.assertTrue(inspect(app=self.app).usage('foo'))
+        assert inspect(app=self.app).usage('foo')
 
 
     def test_command_info(self):
     def test_command_info(self):
         i = inspect(app=self.app)
         i = inspect(app=self.app)
-        self.assertTrue(i.get_command_info(
+        assert i.get_command_info(
             'ping', help=True, color=i.colored.red, app=self.app,
             'ping', help=True, color=i.colored.red, app=self.app,
-        ))
+        )
 
 
     def test_list_commands_color(self):
     def test_list_commands_color(self):
         i = inspect(app=self.app)
         i = inspect(app=self.app)
-        self.assertTrue(i.list_commands(
-            help=True, color=i.colored.red, app=self.app,
-        ))
-        self.assertTrue(i.list_commands(
-            help=False, color=None, app=self.app,
-        ))
+        assert i.list_commands(help=True, color=i.colored.red, app=self.app)
+        assert i.list_commands(help=False, color=None, app=self.app)
 
 
     def test_epilog(self):
     def test_epilog(self):
-        self.assertTrue(inspect(app=self.app).epilog)
+        assert inspect(app=self.app).epilog
 
 
     def test_do_call_method_sql_transport_type(self):
     def test_do_call_method_sql_transport_type(self):
         self.app.connection = Mock()
         self.app.connection = Mock()
         conn = self.app.connection.return_value = Mock(name='Connection')
         conn = self.app.connection.return_value = Mock(name='Connection')
         conn.transport.driver_type = 'sql'
         conn.transport.driver_type = 'sql'
         i = inspect(app=self.app)
         i = inspect(app=self.app)
-        with self.assertRaises(i.Error):
+        with pytest.raises(i.Error):
             i.do_call_method(['ping'])
             i.do_call_method(['ping'])
 
 
     def test_say_directions(self):
     def test_say_directions(self):
@@ -472,22 +460,22 @@ class test_inspect(AppCase):
     def test_run(self, real):
     def test_run(self, real):
         out = WhateverIO()
         out = WhateverIO()
         i = inspect(app=self.app, stdout=out)
         i = inspect(app=self.app, stdout=out)
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run()
             i.run()
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('help')
             i.run('help')
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('xyzzybaz')
             i.run('xyzzybaz')
 
 
         i.run('ping')
         i.run('ping')
         real.assert_called()
         real.assert_called()
         i.run('ping', destination='foo,bar')
         i.run('ping', destination='foo,bar')
-        self.assertEqual(real.call_args[1]['destination'], ['foo', 'bar'])
-        self.assertEqual(real.call_args[1]['timeout'], 0.2)
+        assert real.call_args[1]['destination'], ['foo' == 'bar']
+        assert real.call_args[1]['timeout'] == 0.2
         callback = real.call_args[1]['callback']
         callback = real.call_args[1]['callback']
 
 
         callback({'foo': {'ok': 'pong'}})
         callback({'foo': {'ok': 'pong'}})
-        self.assertIn('OK', out.getvalue())
+        assert 'OK' in out.getvalue()
 
 
         with patch('celery.bin.celery.dumps') as dumps:
         with patch('celery.bin.celery.dumps') as dumps:
             i.run('ping', json=True)
             i.run('ping', json=True)
@@ -495,17 +483,17 @@ class test_inspect(AppCase):
 
 
         instance = real.return_value = Mock()
         instance = real.return_value = Mock()
         instance._request.return_value = None
         instance._request.return_value = None
-        with self.assertRaises(Error):
+        with pytest.raises(Error):
             i.run('ping')
             i.run('ping')
 
 
         out.seek(0)
         out.seek(0)
         out.truncate()
         out.truncate()
         i.quiet = True
         i.quiet = True
         i.say_chat('<-', 'hello')
         i.say_chat('<-', 'hello')
-        self.assertFalse(out.getvalue())
+        assert not out.getvalue()
 
 
 
 
-class test_control(AppCase):
+class test_control:
 
 
     def control(self, patch_call, *args, **kwargs):
     def control(self, patch_call, *args, **kwargs):
         kwargs.setdefault('app', Mock(name='app'))
         kwargs.setdefault('app', Mock(name='app'))
@@ -521,10 +509,10 @@ class test_control(AppCase):
             'foo', arguments={'kw': 2}, reply=True)
             'foo', arguments={'kw': 2}, reply=True)
 
 
 
 
-class test_multi(AppCase):
+class test_multi:
 
 
     def test_get_options(self):
     def test_get_options(self):
-        self.assertIsNone(multi(app=self.app).get_options())
+        assert multi(app=self.app).get_options() is None
 
 
     def test_run_from_argv(self):
     def test_run_from_argv(self):
         with patch('celery.bin.multi.MultiTool') as MultiTool:
         with patch('celery.bin.multi.MultiTool') as MultiTool:
@@ -535,7 +523,7 @@ class test_multi(AppCase):
             )
             )
 
 
 
 
-class test_main(AppCase):
+class test_main:
 
 
     @patch('celery.bin.celery.CeleryCommand')
     @patch('celery.bin.celery.CeleryCommand')
     def test_main(self, Command):
     def test_main(self, Command):
@@ -551,11 +539,11 @@ class test_main(AppCase):
         cmd.execute_from_commandline.assert_called_with(None)
         cmd.execute_from_commandline.assert_called_with(None)
 
 
 
 
-class test_compat(AppCase):
+class test_compat:
 
 
     def test_compat_command_decorator(self):
     def test_compat_command_decorator(self):
         with patch('celery.bin.celery.CeleryCommand') as CC:
         with patch('celery.bin.celery.CeleryCommand') as CC:
-            self.assertEqual(command(), CC.register_command)
+            assert command() == CC.register_command
             fun = Mock(name='fun')
             fun = Mock(name='fun')
             command(fun)
             command(fun)
             CC.register_command.assert_called_with(fun)
             CC.register_command.assert_called_with(fun)

+ 27 - 21
celery/tests/bin/test_celeryd_detach.py → t/unit/bin/test_celeryd_detach.py

@@ -1,5 +1,9 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
+from case import Mock, mock, patch
+
 from celery.platforms import IS_WINDOWS
 from celery.platforms import IS_WINDOWS
 from celery.bin.celeryd_detach import (
 from celery.bin.celeryd_detach import (
     detach,
     detach,
@@ -7,11 +11,9 @@ from celery.bin.celeryd_detach import (
     main,
     main,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, mock, patch
-
 
 
 if not IS_WINDOWS:
 if not IS_WINDOWS:
-    class test_detached(AppCase):
+    class test_detached:
 
 
         @patch('celery.bin.celeryd_detach.detached')
         @patch('celery.bin.celeryd_detach.detached')
         @patch('os.execv')
         @patch('os.execv')
@@ -44,9 +46,9 @@ if not IS_WINDOWS:
             logger.critical.assert_called()
             logger.critical.assert_called()
             setup_logs.assert_called_with(
             setup_logs.assert_called_with(
                 'ERROR', '/var/log', hostname='foo@example.com')
                 'ERROR', '/var/log', hostname='foo@example.com')
-            self.assertEqual(r, 1)
+            assert r == 1
 
 
-            self.patch('celery.current_app')
+            self.patching('celery.current_app')
             from celery import current_app
             from celery import current_app
             r = detach(
             r = detach(
                 '/bin/boo', ['a', 'b', 'c'],
                 '/bin/boo', ['a', 'b', 'c'],
@@ -57,7 +59,7 @@ if not IS_WINDOWS:
             )
             )
 
 
 
 
-class test_PartialOptionParser(AppCase):
+class test_PartialOptionParser:
 
 
     def test_parser(self):
     def test_parser(self):
         x = detached_celeryd(self.app)
         x = detached_celeryd(self.app)
@@ -66,23 +68,23 @@ class test_PartialOptionParser(AppCase):
             '--logfile=foo', '--fake', '--enable',
             '--logfile=foo', '--fake', '--enable',
             'a', 'b', '-c1', '-d', '2',
             'a', 'b', '-c1', '-d', '2',
         ])
         ])
-        self.assertEqual(options.logfile, 'foo')
-        self.assertEqual(values, ['a', 'b'])
-        self.assertEqual(p.leftovers, ['--enable', '-c1', '-d', '2'])
+        assert options.logfile == 'foo'
+        assert values, ['a' == 'b']
+        assert p.leftovers, ['--enable', '-c1', '-d' == '2']
         options, values = p.parse_args([
         options, values = p.parse_args([
             '--fake', '--enable',
             '--fake', '--enable',
             '--pidfile=/var/pid/foo.pid',
             '--pidfile=/var/pid/foo.pid',
             'a', 'b', '-c1', '-d', '2',
             'a', 'b', '-c1', '-d', '2',
         ])
         ])
-        self.assertEqual(options.pidfile, '/var/pid/foo.pid')
+        assert options.pidfile == '/var/pid/foo.pid'
 
 
         with mock.stdouts():
         with mock.stdouts():
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--logfile'])
                 p.parse_args(['--logfile'])
             p.get_option('--logfile').nargs = 2
             p.get_option('--logfile').nargs = 2
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--logfile=a'])
                 p.parse_args(['--logfile=a'])
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 p.parse_args(['--fake=abc'])
                 p.parse_args(['--fake=abc'])
 
 
         assert p.get_option('--logfile').nargs == 2
         assert p.get_option('--logfile').nargs == 2
@@ -90,18 +92,22 @@ class test_PartialOptionParser(AppCase):
         p.get_option('--logfile').nargs = 1
         p.get_option('--logfile').nargs = 1
 
 
 
 
-class test_Command(AppCase):
-    argv = ['--foobar=10,2', '-c', '1',
-            '--logfile=/var/log', '-lDEBUG',
-            '--', '.disable_rate_limits=1']
+class test_Command:
+    argv = [
+        '--foobar=10,2', '-c', '1',
+        '--logfile=/var/log', '-lDEBUG',
+        '--', '.disable_rate_limits=1',
+    ]
 
 
     def test_parse_options(self):
     def test_parse_options(self):
         x = detached_celeryd(app=self.app)
         x = detached_celeryd(app=self.app)
         o, v, l = x.parse_options('cd', self.argv)
         o, v, l = x.parse_options('cd', self.argv)
-        self.assertEqual(o.logfile, '/var/log')
-        self.assertEqual(l, ['--foobar=10,2', '-c', '1',
-                             '-lDEBUG', '--logfile=/var/log',
-                             '--pidfile=celeryd.pid'])
+        assert o.logfile == '/var/log'
+        assert l == [
+            '--foobar=10,2', '-c', '1',
+            '-lDEBUG', '--logfile=/var/log',
+            '--pidfile=celeryd.pid',
+        ]
         x.parse_options('cd', [])  # no args
         x.parse_options('cd', [])  # no args
 
 
     @patch('sys.exit')
     @patch('sys.exit')

+ 7 - 7
celery/tests/bin/test_celeryevdump.py → t/unit/bin/test_celeryevdump.py

@@ -2,6 +2,8 @@ from __future__ import absolute_import, unicode_literals
 
 
 from time import time
 from time import time
 
 
+from case import Mock, patch
+
 from celery.events.dumper import (
 from celery.events.dumper import (
     humanize_type,
     humanize_type,
     Dumper,
     Dumper,
@@ -9,23 +11,21 @@ from celery.events.dumper import (
 )
 )
 from celery.five import WhateverIO
 from celery.five import WhateverIO
 
 
-from celery.tests.case import AppCase, Mock, patch
-
 
 
-class test_Dumper(AppCase):
+class test_Dumper:
 
 
     def setup(self):
     def setup(self):
         self.out = WhateverIO()
         self.out = WhateverIO()
         self.dumper = Dumper(out=self.out)
         self.dumper = Dumper(out=self.out)
 
 
     def test_humanize_type(self):
     def test_humanize_type(self):
-        self.assertEqual(humanize_type('worker-offline'), 'shutdown')
-        self.assertEqual(humanize_type('task-started'), 'task started')
+        assert humanize_type('worker-offline') == 'shutdown'
+        assert humanize_type('task-started') == 'task started'
 
 
     def test_format_task_event(self):
     def test_format_task_event(self):
         self.dumper.format_task_event(
         self.dumper.format_task_event(
             'worker@example.com', time(), 'task-started', 'tasks.add', {})
             'worker@example.com', time(), 'task-started', 'tasks.add', {})
-        self.assertTrue(self.out.getvalue())
+        assert self.out.getvalue()
 
 
     def test_on_event(self):
     def test_on_event(self):
         event = {
         event = {
@@ -37,7 +37,7 @@ class test_Dumper(AppCase):
             'kwargs': '{}',
             'kwargs': '{}',
         }
         }
         self.dumper.on_event(dict(event, type='task-received'))
         self.dumper.on_event(dict(event, type='task-received'))
-        self.assertTrue(self.out.getvalue())
+        assert self.out.getvalue()
         self.dumper.on_event(dict(event, type='task-revoked'))
         self.dumper.on_event(dict(event, type='task-revoked'))
         self.dumper.on_event(dict(event, type='worker-online'))
         self.dumper.on_event(dict(event, type='worker-online'))
 
 

+ 35 - 14
celery/tests/bin/test_events.py → t/unit/bin/test_events.py

@@ -1,8 +1,29 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import importlib
+
+from functools import wraps
+
+from case import patch, skip
+
 from celery.bin import events
 from celery.bin import events
 
 
-from celery.tests.case import AppCase, patch, _old_patch, skip
+
+def _old_patch(module, name, mocked):
+    module = importlib.import_module(module)
+
+    def _patch(fun):
+
+        @wraps(fun)
+        def __patched(*args, **kwargs):
+            prev = getattr(module, name)
+            setattr(module, name, mocked)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                setattr(module, name, prev)
+        return __patched
+    return _patch
 
 
 
 
 class MockCommand(object):
 class MockCommand(object):
@@ -17,7 +38,7 @@ def proctitle(prog, info=None):
 proctitle.last = ()
 proctitle.last = ()
 
 
 
 
-class test_events(AppCase):
+class test_events:
 
 
     def setup(self):
     def setup(self):
         self.ev = events.events(app=self.app)
         self.ev = events.events(app=self.app)
@@ -26,8 +47,8 @@ class test_events(AppCase):
                 lambda **kw: 'me dumper, you?')
                 lambda **kw: 'me dumper, you?')
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     def test_run_dump(self):
     def test_run_dump(self):
-        self.assertEqual(self.ev.run(dump=True), 'me dumper, you?')
-        self.assertIn('celery events:dump', proctitle.last[0])
+        assert self.ev.run(dump=True), 'me dumper == you?'
+        assert 'celery events:dump' in proctitle.last[0]
 
 
     @skip.unless_module('curses', import_errors=(ImportError, OSError))
     @skip.unless_module('curses', import_errors=(ImportError, OSError))
     def test_run_top(self):
     def test_run_top(self):
@@ -35,8 +56,8 @@ class test_events(AppCase):
                     lambda **kw: 'me top, you?')
                     lambda **kw: 'me top, you?')
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)
         @_old_patch('celery.bin.events', 'set_process_title', proctitle)
         def _inner():
         def _inner():
-            self.assertEqual(self.ev.run(), 'me top, you?')
-            self.assertIn('celery events:top', proctitle.last[0])
+            assert self.ev.run(), 'me top == you?'
+            assert 'celery events:top' in proctitle.last[0]
         return _inner()
         return _inner()
 
 
     @_old_patch('celery.events.snapshot', 'evcam',
     @_old_patch('celery.events.snapshot', 'evcam',
@@ -44,12 +65,12 @@ class test_events(AppCase):
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     @_old_patch('celery.bin.events', 'set_process_title', proctitle)
     def test_run_cam(self):
     def test_run_cam(self):
         a, kw = self.ev.run(camera='foo.bar.baz', logfile='logfile')
         a, kw = self.ev.run(camera='foo.bar.baz', logfile='logfile')
-        self.assertEqual(a[0], 'foo.bar.baz')
-        self.assertEqual(kw['freq'], 1.0)
-        self.assertIsNone(kw['maxrate'])
-        self.assertEqual(kw['loglevel'], 'INFO')
-        self.assertEqual(kw['logfile'], 'logfile')
-        self.assertIn('celery events:cam', proctitle.last[0])
+        assert a[0] == 'foo.bar.baz'
+        assert kw['freq'] == 1.0
+        assert kw['maxrate'] is None
+        assert kw['loglevel'] == 'INFO'
+        assert kw['logfile'] == 'logfile'
+        assert 'celery events:cam' in proctitle.last[0]
 
 
     @patch('celery.events.snapshot.evcam')
     @patch('celery.events.snapshot.evcam')
     @patch('celery.bin.events.detached')
     @patch('celery.bin.events.detached')
@@ -60,10 +81,10 @@ class test_events(AppCase):
         evcam.assert_called()
         evcam.assert_called()
 
 
     def test_get_options(self):
     def test_get_options(self):
-        self.assertFalse(self.ev.get_options())
+        assert not self.ev.get_options()
 
 
     @_old_patch('celery.bin.events', 'events', MockCommand)
     @_old_patch('celery.bin.events', 'events', MockCommand)
     def test_main(self):
     def test_main(self):
         MockCommand.executed = []
         MockCommand.executed = []
         events.main()
         events.main()
-        self.assertTrue(MockCommand.executed)
+        assert MockCommand.executed

+ 63 - 84
celery/tests/bin/test_multi.py → t/unit/bin/test_multi.py

@@ -1,19 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
 import signal
 import signal
 import sys
 import sys
 
 
-from celery.bin.multi import (
-    main,
-    MultiTool,
-    __doc__ as doc,
-)
-from celery.five import WhateverIO
+from case import Mock, patch
 
 
-from celery.tests.case import AppCase, Mock, patch
+from celery.bin.multi import main, MultiTool, __doc__ as doc
+from celery.five import WhateverIO
 
 
 
 
-class test_MultiTool(AppCase):
+class test_MultiTool:
 
 
     def setup(self):
     def setup(self):
         self.fh = WhateverIO()
         self.fh = WhateverIO()
@@ -47,7 +44,7 @@ class test_MultiTool(AppCase):
     def assert_sig_argument(self, args, expected):
     def assert_sig_argument(self, args, expected):
         p = self.t.OptionParser(args)
         p = self.t.OptionParser(args)
         p.parse()
         p.parse()
-        self.assertEqual(self.t._find_sig_argument(p), expected)
+        assert self.t._find_sig_argument(p) == expected
 
 
     def test_execute_from_commandline(self):
     def test_execute_from_commandline(self):
         self.t.call_command = Mock(name='call_command')
         self.t.call_command = Mock(name='call_command')
@@ -55,51 +52,43 @@ class test_MultiTool(AppCase):
             'multi start --verbose 10 --foo'.split(),
             'multi start --verbose 10 --foo'.split(),
             cmd='X',
             cmd='X',
         )
         )
-        self.assertEqual(self.t.cmd, 'X')
-        self.assertEqual(self.t.prog_name, 'multi')
+        assert self.t.cmd == 'X'
+        assert self.t.prog_name == 'multi'
         self.t.call_command.assert_called_with('start', ['10', '--foo'])
         self.t.call_command.assert_called_with('start', ['10', '--foo'])
 
 
     def test_execute_from_commandline__arguments(self):
     def test_execute_from_commandline__arguments(self):
-        self.assertTrue(self.t.execute_from_commandline('multi'.split()))
-        self.assertTrue(self.t.execute_from_commandline('multi -bar'.split()))
+        assert self.t.execute_from_commandline('multi'.split())
+        assert self.t.execute_from_commandline('multi -bar'.split())
 
 
     def test_call_command(self):
     def test_call_command(self):
         cmd = self.t.commands['foo'] = Mock(name='foo')
         cmd = self.t.commands['foo'] = Mock(name='foo')
         self.t.retcode = 303
         self.t.retcode = 303
-        self.assertIs(
-            self.t.call_command('foo', ['1', '2', '--foo=3']),
-            cmd.return_value,
-        )
+        assert (self.t.call_command('foo', ['1', '2', '--foo=3']) is
+                cmd.return_value)
         cmd.assert_called_with('1', '2', '--foo=3')
         cmd.assert_called_with('1', '2', '--foo=3')
 
 
     def test_call_command__error(self):
     def test_call_command__error(self):
-        self.assertEqual(
-            self.t.call_command('asdqwewqe', ['1', '2']),
-            1,
-        )
+        assert self.t.call_command('asdqwewqe', ['1', '2']) == 1
         self.t.carp.assert_called()
         self.t.carp.assert_called()
 
 
     def test_handle_reserved_options(self):
     def test_handle_reserved_options(self):
-        self.assertListEqual(
-            self.t._handle_reserved_options(
-                ['a', '-q', 'b', '--no-color', 'c']),
-            ['a', 'b', 'c'],
-        )
+        assert self.t._handle_reserved_options(
+            ['a', '-q', 'b', '--no-color', 'c']) == ['a', 'b', 'c']
 
 
     def test_start(self):
     def test_start(self):
         self.cluster.start.return_value = [0, 0, 1, 0]
         self.cluster.start.return_value = [0, 0, 1, 0]
-        self.assertTrue(self.t.start('10', '-A', 'proj'))
+        assert self.t.start('10', '-A', 'proj')
         self.t.splash.assert_called_with()
         self.t.splash.assert_called_with()
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.cluster.start.assert_called_with()
         self.cluster.start.assert_called_with()
 
 
     def test_start__exitcodes(self):
     def test_start__exitcodes(self):
         self.cluster.start.return_value = [0, 0, 0]
         self.cluster.start.return_value = [0, 0, 0]
-        self.assertFalse(self.t.start('foo', 'bar', 'baz'))
+        assert not self.t.start('foo', 'bar', 'baz')
         self.cluster.start.assert_called_with()
         self.cluster.start.assert_called_with()
 
 
         self.cluster.start.return_value = [0, 1, 0]
         self.cluster.start.return_value = [0, 1, 0]
-        self.assertTrue(self.t.start('foo', 'bar', 'baz'))
+        assert self.t.start('foo', 'bar', 'baz')
 
 
     def test_stop(self):
     def test_stop(self):
         self.t.stop('10', '-A', 'proj', retry=3)
         self.t.stop('10', '-A', 'proj', retry=3)
@@ -130,17 +119,15 @@ class test_MultiTool(AppCase):
     def test_get(self):
     def test_get(self):
         node = self.cluster.find.return_value = Mock(name='node')
         node = self.cluster.find.return_value = Mock(name='node')
         node.argv = ['A', 'B', 'C']
         node.argv = ['A', 'B', 'C']
-        self.assertIs(
-            self.t.get('wanted', '10', '-A', 'proj'),
-            self.t.ok.return_value,
-        )
+        assert (self.t.get('wanted', '10', '-A', 'proj') is
+                self.t.ok.return_value)
         self.cluster.find.assert_called_with('wanted')
         self.cluster.find.assert_called_with('wanted')
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.t.cluster_from_argv.assert_called_with(('10', '-A', 'proj'))
         self.t.ok.assert_called_with(' '.join(node.argv))
         self.t.ok.assert_called_with(' '.join(node.argv))
 
 
     def test_get__KeyError(self):
     def test_get__KeyError(self):
         self.cluster.find.side_effect = KeyError()
         self.cluster.find.side_effect = KeyError()
-        self.assertTrue(self.t.get('wanted', '10', '-A', 'proj'))
+        assert self.t.get('wanted', '10', '-A', 'proj')
 
 
     def test_show(self):
     def test_show(self):
         nodes = self.t.cluster_from_argv.return_value = [
         nodes = self.t.cluster_from_argv.return_value = [
@@ -150,10 +137,7 @@ class test_MultiTool(AppCase):
         nodes[0].argv_with_executable = ['python', 'foo', 'bar']
         nodes[0].argv_with_executable = ['python', 'foo', 'bar']
         nodes[1].argv_with_executable = ['python', 'xuzzy', 'baz']
         nodes[1].argv_with_executable = ['python', 'xuzzy', 'baz']
 
 
-        self.assertIs(
-            self.t.show('10', '-A', 'proj'),
-            self.t.ok.return_value,
-        )
+        assert self.t.show('10', '-A', 'proj') is self.t.ok.return_value
         self.t.ok.assert_called_with(
         self.t.ok.assert_called_with(
             '\n'.join(' '.join(node.argv_with_executable) for node in nodes))
             '\n'.join(' '.join(node.argv_with_executable) for node in nodes))
 
 
@@ -169,7 +153,7 @@ class test_MultiTool(AppCase):
         node1.expander.return_value = 'A'
         node1.expander.return_value = 'A'
         node2.expander.return_value = 'B'
         node2.expander.return_value = 'B'
         nodes = self.t.cluster_from_argv.return_value = [node1, node2]
         nodes = self.t.cluster_from_argv.return_value = [node1, node2]
-        self.assertIs(self.t.expand('%p', '10'), self.t.ok.return_value)
+        assert self.t.expand('%p', '10') is self.t.ok.return_value
         self.t.cluster_from_argv.assert_called_with(('10',))
         self.t.cluster_from_argv.assert_called_with(('10',))
         for node in nodes:
         for node in nodes:
             node.expander.assert_called_with('%p')
             node.expander.assert_called_with('%p')
@@ -196,26 +180,23 @@ class test_MultiTool(AppCase):
     def test_Cluster(self):
     def test_Cluster(self):
         m = MultiTool()
         m = MultiTool()
         c = m.cluster_from_argv(['A', 'B', 'C'])
         c = m.cluster_from_argv(['A', 'B', 'C'])
-        self.assertIs(c.env, m.env)
-        self.assertEqual(c.cmd, 'celery worker')
-        self.assertEqual(c.on_stopping_preamble, m.on_stopping_preamble)
-        self.assertEqual(c.on_send_signal, m.on_send_signal)
-        self.assertEqual(c.on_still_waiting_for, m.on_still_waiting_for)
-        self.assertEqual(
-            c.on_still_waiting_progress,
-            m.on_still_waiting_progress,
-        )
-        self.assertEqual(c.on_still_waiting_end, m.on_still_waiting_end)
-        self.assertEqual(c.on_node_start, m.on_node_start)
-        self.assertEqual(c.on_node_restart, m.on_node_restart)
-        self.assertEqual(c.on_node_shutdown_ok, m.on_node_shutdown_ok)
-        self.assertEqual(c.on_node_status, m.on_node_status)
-        self.assertEqual(c.on_node_signal_dead, m.on_node_signal_dead)
-        self.assertEqual(c.on_node_signal, m.on_node_signal)
-        self.assertEqual(c.on_node_down, m.on_node_down)
-        self.assertEqual(c.on_child_spawn, m.on_child_spawn)
-        self.assertEqual(c.on_child_signalled, m.on_child_signalled)
-        self.assertEqual(c.on_child_failure, m.on_child_failure)
+        assert c.env is m.env
+        assert c.cmd == 'celery worker'
+        assert c.on_stopping_preamble == m.on_stopping_preamble
+        assert c.on_send_signal == m.on_send_signal
+        assert c.on_still_waiting_for == m.on_still_waiting_for
+        assert c.on_still_waiting_progress == m.on_still_waiting_progress
+        assert c.on_still_waiting_end == m.on_still_waiting_end
+        assert c.on_node_start == m.on_node_start
+        assert c.on_node_restart == m.on_node_restart
+        assert c.on_node_shutdown_ok == m.on_node_shutdown_ok
+        assert c.on_node_status == m.on_node_status
+        assert c.on_node_signal_dead == m.on_node_signal_dead
+        assert c.on_node_signal == m.on_node_signal
+        assert c.on_node_down == m.on_node_down
+        assert c.on_child_spawn == m.on_child_spawn
+        assert c.on_child_signalled == m.on_child_signalled
+        assert c.on_child_failure == m.on_child_failure
 
 
     def test_on_stopping_preamble(self):
     def test_on_stopping_preamble(self):
         self.t.on_stopping_preamble([])
         self.t.on_stopping_preamble([])
@@ -271,12 +252,12 @@ class test_MultiTool(AppCase):
         self.t.on_child_failure(Mock(), Mock())
         self.t.on_child_failure(Mock(), Mock())
 
 
     def test_constant_strings(self):
     def test_constant_strings(self):
-        self.assertTrue(self.t.OK)
-        self.assertTrue(self.t.DOWN)
-        self.assertTrue(self.t.FAILED)
+        assert self.t.OK
+        assert self.t.DOWN
+        assert self.t.FAILED
 
 
 
 
-class test_MultiTool_functional(AppCase):
+class test_MultiTool_functional:
 
 
     def setup(self):
     def setup(self):
         self.fh = WhateverIO()
         self.fh = WhateverIO()
@@ -285,12 +266,12 @@ class test_MultiTool_functional(AppCase):
 
 
     def test_note(self):
     def test_note(self):
         self.t.note('hello world')
         self.t.note('hello world')
-        self.assertEqual(self.fh.getvalue(), 'hello world\n')
+        assert self.fh.getvalue() == 'hello world\n'
 
 
     def test_note_quiet(self):
     def test_note_quiet(self):
         self.t.quiet = True
         self.t.quiet = True
         self.t.note('hello world')
         self.t.note('hello world')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
 
     def test_carp(self):
     def test_carp(self):
         self.t.say = Mock()
         self.t.say = Mock()
@@ -300,61 +281,59 @@ class test_MultiTool_functional(AppCase):
     def test_info(self):
     def test_info(self):
         self.t.verbose = True
         self.t.verbose = True
         self.t.info('hello info')
         self.t.info('hello info')
-        self.assertEqual(self.fh.getvalue(), 'hello info\n')
+        assert self.fh.getvalue() == 'hello info\n'
 
 
     def test_info_not_verbose(self):
     def test_info_not_verbose(self):
         self.t.verbose = False
         self.t.verbose = False
         self.t.info('hello info')
         self.t.info('hello info')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
 
     def test_error(self):
     def test_error(self):
         self.t.carp = Mock()
         self.t.carp = Mock()
         self.t.usage = Mock()
         self.t.usage = Mock()
-        self.assertEqual(self.t.error('foo'), 1)
+        assert self.t.error('foo') == 1
         self.t.carp.assert_called_with('foo')
         self.t.carp.assert_called_with('foo')
         self.t.usage.assert_called_with()
         self.t.usage.assert_called_with()
 
 
         self.t.carp = Mock()
         self.t.carp = Mock()
-        self.assertEqual(self.t.error(), 1)
+        assert self.t.error() == 1
         self.t.carp.assert_not_called()
         self.t.carp.assert_not_called()
 
 
     def test_nosplash(self):
     def test_nosplash(self):
         self.t.nosplash = True
         self.t.nosplash = True
         self.t.splash()
         self.t.splash()
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
 
 
     def test_splash(self):
     def test_splash(self):
         self.t.nosplash = False
         self.t.nosplash = False
         self.t.splash()
         self.t.splash()
-        self.assertIn('celery multi', self.fh.getvalue())
+        assert 'celery multi' in self.fh.getvalue()
 
 
     def test_usage(self):
     def test_usage(self):
         self.t.usage()
         self.t.usage()
-        self.assertTrue(self.fh.getvalue())
+        assert self.fh.getvalue()
 
 
     def test_help(self):
     def test_help(self):
         self.t.help([])
         self.t.help([])
-        self.assertIn(doc, self.fh.getvalue())
+        assert doc in self.fh.getvalue()
 
 
     def test_expand(self):
     def test_expand(self):
         self.t.expand('foo%n', 'ask', 'klask', 'dask')
         self.t.expand('foo%n', 'ask', 'klask', 'dask')
-        self.assertEqual(
-            self.fh.getvalue(), 'fooask\nfooklask\nfoodask\n',
-        )
+        assert self.fh.getvalue() == 'fooask\nfooklask\nfoodask\n'
 
 
     @patch('celery.apps.multi.gethostname')
     @patch('celery.apps.multi.gethostname')
     def test_get(self, gethostname):
     def test_get(self, gethostname):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
         self.t.get('xuzzy@e.com', 'foo', 'bar', 'baz')
         self.t.get('xuzzy@e.com', 'foo', 'bar', 'baz')
-        self.assertFalse(self.fh.getvalue())
+        assert not self.fh.getvalue()
         self.t.get('foo@e.com', 'foo', 'bar', 'baz')
         self.t.get('foo@e.com', 'foo', 'bar', 'baz')
-        self.assertTrue(self.fh.getvalue())
+        assert self.fh.getvalue()
 
 
     @patch('celery.apps.multi.gethostname')
     @patch('celery.apps.multi.gethostname')
     def test_names(self, gethostname):
     def test_names(self, gethostname):
         gethostname.return_value = 'e.com'
         gethostname.return_value = 'e.com'
         self.t.names('foo', 'bar', 'baz')
         self.t.names('foo', 'bar', 'baz')
-        self.assertIn('foo@e.com\nbar@e.com\nbaz@e.com', self.fh.getvalue())
+        assert 'foo@e.com\nbar@e.com\nbaz@e.com' in self.fh.getvalue()
 
 
     def test_execute_from_commandline(self):
     def test_execute_from_commandline(self):
         start = self.t.commands['start'] = Mock()
         start = self.t.commands['start'] = Mock()
@@ -379,14 +358,14 @@ class test_MultiTool_functional(AppCase):
             ['multi', 'start', 'foo',
             ['multi', 'start', 'foo',
              '--nosplash', '--quiet', '-q', '--verbose', '--no-color'],
              '--nosplash', '--quiet', '-q', '--verbose', '--no-color'],
         )
         )
-        self.assertTrue(self.t.nosplash)
-        self.assertTrue(self.t.quiet)
-        self.assertTrue(self.t.verbose)
-        self.assertTrue(self.t.no_color)
+        assert self.t.nosplash
+        assert self.t.quiet
+        assert self.t.verbose
+        assert self.t.no_color
 
 
     @patch('celery.bin.multi.MultiTool')
     @patch('celery.bin.multi.MultiTool')
     def test_main(self, MultiTool):
     def test_main(self, MultiTool):
         m = MultiTool.return_value = Mock()
         m = MultiTool.return_value = Mock()
-        with self.assertRaises(SystemExit):
+        with pytest.raises(SystemExit):
             main()
             main()
         m.execute_from_commandline.assert_called_with(sys.argv)
         m.execute_from_commandline.assert_called_with(sys.argv)

+ 235 - 239
celery/tests/bin/test_worker.py → t/unit/bin/test_worker.py

@@ -2,9 +2,11 @@ from __future__ import absolute_import, unicode_literals
 
 
 import logging
 import logging
 import os
 import os
+import pytest
 import sys
 import sys
 
 
 from billiard.process import current_process
 from billiard.process import current_process
+from case import Mock, mock, patch, skip
 from kombu import Exchange, Queue
 from kombu import Exchange, Queue
 
 
 from celery import platforms
 from celery import platforms
@@ -18,13 +20,10 @@ from celery.exceptions import (
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.platforms import EX_FAILURE, EX_OK
 from celery.worker import state
 from celery.worker import state
 
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
 
 
-
-class WorkerAppCase(AppCase):
-
-    def teardown(self):
-        trace.reset_worker_optimizations()
+@pytest.fixture(autouse=True)
+def reset_worker_optimizations(request):
+    request.addfinalizer(trace.reset_worker_optimizations)
 
 
 
 
 class Worker(cd.Worker):
 class Worker(cd.Worker):
@@ -34,34 +33,34 @@ class Worker(cd.Worker):
         self.on_start()
         self.on_start()
 
 
 
 
-class test_Worker(WorkerAppCase):
+class test_Worker:
     Worker = Worker
     Worker = Worker
 
 
-    @mock.stdouts
-    def test_queues_string(self, stdout, stderr):
-        w = self.app.Worker()
-        w.setup_queues('foo,bar,baz')
-        self.assertIn('foo', self.app.amqp.queues)
-
-    @mock.stdouts
-    def test_cpu_count(self, stdout, stderr):
-        with patch('celery.worker.cpu_count') as cpu_count:
-            cpu_count.side_effect = NotImplementedError()
-            w = self.app.Worker(concurrency=None)
-            self.assertEqual(w.concurrency, 2)
-        w = self.app.Worker(concurrency=5)
-        self.assertEqual(w.concurrency, 5)
-
-    @mock.stdouts
-    def test_windows_B_option(self, stdout, stderr):
-        self.app.IS_WINDOWS = True
-        with self.assertRaises(SystemExit):
-            worker(app=self.app).run(beat=True)
+    def test_queues_string(self):
+        with mock.stdouts():
+            w = self.app.Worker()
+            w.setup_queues('foo,bar,baz')
+            assert 'foo' in self.app.amqp.queues
+
+    def test_cpu_count(self):
+        with mock.stdouts():
+            with patch('celery.worker.cpu_count') as cpu_count:
+                cpu_count.side_effect = NotImplementedError()
+                w = self.app.Worker(concurrency=None)
+                assert w.concurrency == 2
+            w = self.app.Worker(concurrency=5)
+            assert w.concurrency == 5
+
+    def test_windows_B_option(self):
+        with mock.stdouts():
+            self.app.IS_WINDOWS = True
+            with pytest.raises(SystemExit):
+                worker(app=self.app).run(beat=True)
 
 
     def test_setup_concurrency_very_early(self):
     def test_setup_concurrency_very_early(self):
         x = worker()
         x = worker()
         x.run = Mock()
         x.run = Mock()
-        with self.assertRaises(ImportError):
+        with pytest.raises(ImportError):
             x.execute_from_commandline(['worker', '-P', 'xyzybox'])
             x.execute_from_commandline(['worker', '-P', 'xyzybox'])
 
 
     def test_run_from_argv_basic(self):
     def test_run_from_argv_basic(self):
@@ -80,15 +79,15 @@ class test_Worker(WorkerAppCase):
         with patch('celery.bin.worker.detached_celeryd') as detached:
         with patch('celery.bin.worker.detached_celeryd') as detached:
             x.maybe_detach([])
             x.maybe_detach([])
             detached.assert_not_called()
             detached.assert_not_called()
-            with self.assertRaises(SystemExit):
+            with pytest.raises(SystemExit):
                 x.maybe_detach(['--detach'])
                 x.maybe_detach(['--detach'])
             detached.assert_called()
             detached.assert_called()
 
 
-    @mock.stdouts
-    def test_invalid_loglevel_gives_error(self, stdout, stderr):
-        x = worker(app=self.app)
-        with self.assertRaises(SystemExit):
-            x.run(loglevel='GRIM_REAPER')
+    def test_invalid_loglevel_gives_error(self):
+        with mock.stdouts():
+            x = worker(app=self.app)
+            with pytest.raises(SystemExit):
+                x.run(loglevel='GRIM_REAPER')
 
 
     def test_no_loglevel(self):
     def test_no_loglevel(self):
         self.app.Worker = Mock()
         self.app.Worker = Mock()
@@ -96,25 +95,24 @@ class test_Worker(WorkerAppCase):
 
 
     def test_tasklist(self):
     def test_tasklist(self):
         worker = self.app.Worker()
         worker = self.app.Worker()
-        self.assertTrue(worker.app.tasks)
-        self.assertTrue(worker.app.finalized)
-        self.assertTrue(worker.tasklist(include_builtins=True))
+        assert worker.app.tasks
+        assert worker.app.finalized
+        assert worker.tasklist(include_builtins=True)
         worker.tasklist(include_builtins=False)
         worker.tasklist(include_builtins=False)
 
 
     def test_extra_info(self):
     def test_extra_info(self):
         worker = self.app.Worker()
         worker = self.app.Worker()
         worker.loglevel = logging.WARNING
         worker.loglevel = logging.WARNING
-        self.assertFalse(worker.extra_info())
+        assert not worker.extra_info()
         worker.loglevel = logging.INFO
         worker.loglevel = logging.INFO
-        self.assertTrue(worker.extra_info())
+        assert worker.extra_info()
 
 
-    @mock.stdouts
-    def test_loglevel_string(self, stdout, stderr):
-        worker = self.Worker(app=self.app, loglevel='INFO')
-        self.assertEqual(worker.loglevel, logging.INFO)
+    def test_loglevel_string(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app, loglevel='INFO')
+            assert worker.loglevel == logging.INFO
 
 
-    @mock.stdouts
-    def test_run_worker(self, stdout, stderr):
+    def test_run_worker(self, patching):
         handlers = {}
         handlers = {}
 
 
         class Signals(platforms.Signals):
         class Signals(platforms.Signals):
@@ -122,158 +120,156 @@ class test_Worker(WorkerAppCase):
             def __setitem__(self, sig, handler):
             def __setitem__(self, sig, handler):
                 handlers[sig] = handler
                 handlers[sig] = handler
 
 
-        p = platforms.signals
-        platforms.signals = Signals()
-        try:
+        patching.setattr('celery.platforms.signals', Signals())
+        with mock.stdouts():
             w = self.Worker(app=self.app)
             w = self.Worker(app=self.app)
             w._isatty = False
             w._isatty = False
             w.on_start()
             w.on_start()
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
             for sig in 'SIGINT', 'SIGHUP', 'SIGTERM':
-                self.assertIn(sig, handlers)
+                assert sig in handlers
 
 
             handlers.clear()
             handlers.clear()
             w = self.Worker(app=self.app)
             w = self.Worker(app=self.app)
             w._isatty = True
             w._isatty = True
             w.on_start()
             w.on_start()
             for sig in 'SIGINT', 'SIGTERM':
             for sig in 'SIGINT', 'SIGTERM':
-                self.assertIn(sig, handlers)
-            self.assertNotIn('SIGHUP', handlers)
-        finally:
-            platforms.signals = p
+                assert sig in handlers
+            assert 'SIGHUP' not in handlers
 
 
-    @mock.stdouts
-    def test_startup_info(self, stdout, stderr):
-        worker = self.Worker(app=self.app)
-        worker.on_start()
-        self.assertTrue(worker.startup_info())
-        worker.loglevel = logging.DEBUG
-        self.assertTrue(worker.startup_info())
-        worker.loglevel = logging.INFO
-        self.assertTrue(worker.startup_info())
-
-        prev_loader = self.app.loader
-        worker = self.Worker(app=self.app, queues='foo,bar,baz,xuzzy,do,re,mi')
-        with patch('celery.apps.worker.qualname') as qualname:
-            qualname.return_value = 'acme.backed_beans.Loader'
-            self.assertTrue(worker.startup_info())
-
-        with patch('celery.apps.worker.qualname') as qualname:
-            qualname.return_value = 'celery.loaders.Loader'
-            self.assertTrue(worker.startup_info())
-
-        from celery.loaders.app import AppLoader
-        self.app.loader = AppLoader(app=self.app)
-        self.assertTrue(worker.startup_info())
+    def test_startup_info(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert worker.startup_info()
+            worker.loglevel = logging.DEBUG
+            assert worker.startup_info()
+            worker.loglevel = logging.INFO
+            assert worker.startup_info()
+
+            prev_loader = self.app.loader
+            worker = self.Worker(
+                app=self.app,
+                queues='foo,bar,baz,xuzzy,do,re,mi',
+            )
+            with patch('celery.apps.worker.qualname') as qualname:
+                qualname.return_value = 'acme.backed_beans.Loader'
+                assert worker.startup_info()
+
+            with patch('celery.apps.worker.qualname') as qualname:
+                qualname.return_value = 'celery.loaders.Loader'
+                assert worker.startup_info()
+
+            from celery.loaders.app import AppLoader
+            self.app.loader = AppLoader(app=self.app)
+            assert worker.startup_info()
+
+            self.app.loader = prev_loader
+            worker.task_events = True
+            assert worker.startup_info()
+
+            # test when there are too few output lines
+            # to draft the ascii art onto
+            prev, cd.ARTLINES = cd.ARTLINES, ['the quick brown fox']
+            try:
+                assert worker.startup_info()
+            finally:
+                cd.ARTLINES = prev
 
 
-        self.app.loader = prev_loader
-        worker.task_events = True
-        self.assertTrue(worker.startup_info())
+    def test_run(self):
+        with mock.stdouts():
+            self.Worker(app=self.app).on_start()
+            self.Worker(app=self.app, purge=True).on_start()
+            worker = self.Worker(app=self.app)
+            worker.on_start()
 
 
-        # test when there are too few output lines
-        # to draft the ascii art onto
-        prev, cd.ARTLINES = cd.ARTLINES, ['the quick brown fox']
-        try:
-            self.assertTrue(worker.startup_info())
-        finally:
-            cd.ARTLINES = prev
-
-    @mock.stdouts
-    def test_run(self, stdout, stderr):
-        self.Worker(app=self.app).on_start()
-        self.Worker(app=self.app, purge=True).on_start()
-        worker = self.Worker(app=self.app)
-        worker.on_start()
-
-    @mock.stdouts
-    def test_purge_messages(self, stdout, stderr):
-        self.Worker(app=self.app).purge_messages()
-
-    @mock.stdouts
-    def test_init_queues(self, stdout, stderr):
-        app = self.app
-        c = app.conf
-        app.amqp.queues = app.amqp.Queues({
-            'celery': {'exchange': 'celery',
-                       'routing_key': 'celery'},
-            'video': {'exchange': 'video',
-                      'routing_key': 'video'},
-        })
-        worker = self.Worker(app=self.app)
-        worker.setup_queues(['video'])
-        self.assertIn('video', app.amqp.queues)
-        self.assertIn('video', app.amqp.queues.consume_from)
-        self.assertIn('celery', app.amqp.queues)
-        self.assertNotIn('celery', app.amqp.queues.consume_from)
-
-        c.task_create_missing_queues = False
-        del(app.amqp.queues)
-        with self.assertRaises(ImproperlyConfigured):
-            self.Worker(app=self.app).setup_queues(['image'])
-        del(app.amqp.queues)
-        c.task_create_missing_queues = True
-        worker = self.Worker(app=self.app)
-        worker.setup_queues(['image'])
-        self.assertIn('image', app.amqp.queues.consume_from)
-        self.assertEqual(
-            Queue('image', Exchange('image'), routing_key='image'),
-            app.amqp.queues['image'],
-        )
+    def test_purge_messages(self):
+        with mock.stdouts():
+            self.Worker(app=self.app).purge_messages()
+
+    def test_init_queues(self):
+        with mock.stdouts():
+            app = self.app
+            c = app.conf
+            app.amqp.queues = app.amqp.Queues({
+                'celery': {
+                    'exchange': 'celery',
+                    'routing_key': 'celery',
+                },
+                'video': {
+                    'exchange': 'video',
+                    'routing_key': 'video',
+                },
+            })
+            worker = self.Worker(app=self.app)
+            worker.setup_queues(['video'])
+            assert 'video' in app.amqp.queues
+            assert 'video' in app.amqp.queues.consume_from
+            assert 'celery' in app.amqp.queues
+            assert 'celery' not in app.amqp.queues.consume_from
+
+            c.task_create_missing_queues = False
+            del(app.amqp.queues)
+            with pytest.raises(ImproperlyConfigured):
+                self.Worker(app=self.app).setup_queues(['image'])
+            del(app.amqp.queues)
+            c.task_create_missing_queues = True
+            worker = self.Worker(app=self.app)
+            worker.setup_queues(['image'])
+            assert 'image' in app.amqp.queues.consume_from
+            assert app.amqp.queues['image'] == Queue(
+                'image', Exchange('image'),
+                routing_key='image',
+            )
 
 
     def test_include_argument(self):
     def test_include_argument(self):
         worker1 = self.Worker(app=self.app, include='os')
         worker1 = self.Worker(app=self.app, include='os')
-        self.assertListEqual(worker1.include, ['os'])
+        assert worker1.include == ['os']
         worker2 = self.Worker(app=self.app,
         worker2 = self.Worker(app=self.app,
                               include='os,sys')
                               include='os,sys')
-        self.assertListEqual(worker2.include, ['os', 'sys'])
+        assert worker2.include == ['os', 'sys']
         self.Worker(app=self.app, include=['os', 'sys'])
         self.Worker(app=self.app, include=['os', 'sys'])
 
 
-    @mock.stdouts
-    def test_unknown_loglevel(self, stdout, stderr):
-        with self.assertRaises(SystemExit):
-            worker(app=self.app).run(loglevel='ALIEN')
-        worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
-        self.assertEqual(worker1.loglevel, 0xFFFF)
+    def test_unknown_loglevel(self):
+        with mock.stdouts():
+            with pytest.raises(SystemExit):
+                worker(app=self.app).run(loglevel='ALIEN')
+            worker1 = self.Worker(app=self.app, loglevel=0xFFFF)
+            assert worker1.loglevel == 0xFFFF
 
 
     @patch('os._exit')
     @patch('os._exit')
     @skip.if_win32()
     @skip.if_win32()
-    @mock.stdouts
-    def test_warns_if_running_as_privileged_user(self, _exit, stdout, stderr):
-        with patch('os.getuid') as getuid:
+    def test_warns_if_running_as_privileged_user(self, _exit, patching):
+        getuid = patching('os.getuid')
+
+        with mock.stdouts() as (_, stderr):
             getuid.return_value = 0
             getuid.return_value = 0
             self.app.conf.accept_content = ['pickle']
             self.app.conf.accept_content = ['pickle']
             worker = self.Worker(app=self.app)
             worker = self.Worker(app=self.app)
             worker.on_start()
             worker.on_start()
             _exit.assert_called_with(1)
             _exit.assert_called_with(1)
-            from celery import platforms
-            platforms.C_FORCE_ROOT = True
-            try:
-                with self.assertWarnsRegex(
-                        RuntimeWarning,
-                        r'absolutely not recommended'):
-                    worker = self.Worker(app=self.app)
-                    worker.on_start()
-            finally:
-                platforms.C_FORCE_ROOT = False
+            patching.setattr('celery.platforms.C_FORCE_ROOT', True)
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert 'a very bad idea' in stderr.getvalue()
+            patching.setattr('celery.platforms.C_FORCE_ROOT', False)
             self.app.conf.accept_content = ['json']
             self.app.conf.accept_content = ['json']
-            with self.assertWarnsRegex(
-                    RuntimeWarning,
-                    r'absolutely not recommended'):
-                worker = self.Worker(app=self.app)
-                worker.on_start()
-
-    @mock.stdouts
-    def test_redirect_stdouts(self, stdout, stderr):
-        self.Worker(app=self.app, redirect_stdouts=False)
-        with self.assertRaises(AttributeError):
-            sys.stdout.logger
-
-    @mock.stdouts
-    def test_on_start_custom_logging(self, stdout, stderr):
-        self.app.log.redirect_stdouts = Mock()
-        worker = self.Worker(app=self.app, redirect_stoutds=True)
-        worker._custom_logging = True
-        worker.on_start()
-        self.app.log.redirect_stdouts.assert_not_called()
+            worker = self.Worker(app=self.app)
+            worker.on_start()
+            assert 'superuser' in stderr.getvalue()
+
+    def test_redirect_stdouts(self):
+        with mock.stdouts():
+            self.Worker(app=self.app, redirect_stdouts=False)
+            with pytest.raises(AttributeError):
+                sys.stdout.logger
+
+    def test_on_start_custom_logging(self):
+        with mock.stdouts():
+            self.app.log.redirect_stdouts = Mock()
+            worker = self.Worker(app=self.app, redirect_stoutds=True)
+            worker._custom_logging = True
+            worker.on_start()
+            self.app.log.redirect_stdouts.assert_not_called()
 
 
     def test_setup_logging_no_color(self):
     def test_setup_logging_no_color(self):
         worker = self.Worker(
         worker = self.Worker(
@@ -282,15 +278,15 @@ class test_Worker(WorkerAppCase):
         prev, self.app.log.setup = self.app.log.setup, Mock()
         prev, self.app.log.setup = self.app.log.setup, Mock()
         try:
         try:
             worker.setup_logging()
             worker.setup_logging()
-            self.assertFalse(self.app.log.setup.call_args[1]['colorize'])
+            assert not self.app.log.setup.call_args[1]['colorize']
         finally:
         finally:
             self.app.log.setup = prev
             self.app.log.setup = prev
 
 
-    @mock.stdouts
-    def test_startup_info_pool_is_str(self, stdout, stderr):
-        worker = self.Worker(app=self.app, redirect_stdouts=False)
-        worker.pool_cls = 'foo'
-        worker.startup_info()
+    def test_startup_info_pool_is_str(self):
+        with mock.stdouts():
+            worker = self.Worker(app=self.app, redirect_stdouts=False)
+            worker.pool_cls = 'foo'
+            worker.startup_info()
 
 
     def test_redirect_stdouts_already_handled(self):
     def test_redirect_stdouts_already_handled(self):
         logging_setup = [False]
         logging_setup = [False]
@@ -303,14 +299,13 @@ class test_Worker(WorkerAppCase):
             worker = self.Worker(app=self.app, redirect_stdouts=False)
             worker = self.Worker(app=self.app, redirect_stdouts=False)
             worker.app.log.already_setup = False
             worker.app.log.already_setup = False
             worker.setup_logging()
             worker.setup_logging()
-            self.assertTrue(logging_setup[0])
-            with self.assertRaises(AttributeError):
+            assert logging_setup[0]
+            with pytest.raises(AttributeError):
                 sys.stdout.logger
                 sys.stdout.logger
         finally:
         finally:
             signals.setup_logging.disconnect(on_logging_setup)
             signals.setup_logging.disconnect(on_logging_setup)
 
 
-    @mock.stdouts
-    def test_platform_tweaks_macOS(self, stdout, stderr):
+    def test_platform_tweaks_macOS(self):
 
 
         class macOSWorker(Worker):
         class macOSWorker(Worker):
             proxy_workaround_installed = False
             proxy_workaround_installed = False
@@ -318,27 +313,27 @@ class test_Worker(WorkerAppCase):
             def macOS_proxy_detection_workaround(self):
             def macOS_proxy_detection_workaround(self):
                 self.proxy_workaround_installed = True
                 self.proxy_workaround_installed = True
 
 
-        worker = macOSWorker(app=self.app, redirect_stdouts=False)
+        with mock.stdouts():
+            worker = macOSWorker(app=self.app, redirect_stdouts=False)
 
 
-        def install_HUP_nosupport(controller):
-            controller.hup_not_supported_installed = True
+            def install_HUP_nosupport(controller):
+                controller.hup_not_supported_installed = True
 
 
-        class Controller(object):
-            pass
+            class Controller(object):
+                pass
 
 
-        prev = cd.install_HUP_not_supported_handler
-        cd.install_HUP_not_supported_handler = install_HUP_nosupport
-        try:
-            worker.app.IS_macOS = True
-            controller = Controller()
-            worker.install_platform_tweaks(controller)
-            self.assertTrue(controller.hup_not_supported_installed)
-            self.assertTrue(worker.proxy_workaround_installed)
-        finally:
-            cd.install_HUP_not_supported_handler = prev
+            prev = cd.install_HUP_not_supported_handler
+            cd.install_HUP_not_supported_handler = install_HUP_nosupport
+            try:
+                worker.app.IS_macOS = True
+                controller = Controller()
+                worker.install_platform_tweaks(controller)
+                assert controller.hup_not_supported_installed
+                assert worker.proxy_workaround_installed
+            finally:
+                cd.install_HUP_not_supported_handler = prev
 
 
-    @mock.stdouts
-    def test_general_platform_tweaks(self, stdout, stderr):
+    def test_general_platform_tweaks(self):
 
 
         restart_worker_handler_installed = [False]
         restart_worker_handler_installed = [False]
 
 
@@ -348,33 +343,34 @@ class test_Worker(WorkerAppCase):
         class Controller(object):
         class Controller(object):
             pass
             pass
 
 
-        prev = cd.install_worker_restart_handler
-        cd.install_worker_restart_handler = install_worker_restart_handler
-        try:
-            worker = self.Worker(app=self.app)
-            worker.app.IS_macOS = False
-            worker.install_platform_tweaks(Controller())
-            self.assertTrue(restart_worker_handler_installed[0])
-        finally:
-            cd.install_worker_restart_handler = prev
+        with mock.stdouts():
+            prev = cd.install_worker_restart_handler
+            cd.install_worker_restart_handler = install_worker_restart_handler
+            try:
+                worker = self.Worker(app=self.app)
+                worker.app.IS_macOS = False
+                worker.install_platform_tweaks(Controller())
+                assert restart_worker_handler_installed[0]
+            finally:
+                cd.install_worker_restart_handler = prev
 
 
-    @mock.stdouts
-    def test_on_consumer_ready(self, stdout, stderr):
+    def test_on_consumer_ready(self):
         worker_ready_sent = [False]
         worker_ready_sent = [False]
 
 
         @signals.worker_ready.connect
         @signals.worker_ready.connect
         def on_worker_ready(**kwargs):
         def on_worker_ready(**kwargs):
             worker_ready_sent[0] = True
             worker_ready_sent[0] = True
 
 
-        self.Worker(app=self.app).on_consumer_ready(object())
-        self.assertTrue(worker_ready_sent[0])
+        with mock.stdouts():
+            self.Worker(app=self.app).on_consumer_ready(object())
+            assert worker_ready_sent[0]
 
 
 
 
 @mock.stdouts
 @mock.stdouts
-class test_funs(WorkerAppCase):
+class test_funs:
 
 
     def test_active_thread_count(self):
     def test_active_thread_count(self):
-        self.assertTrue(cd.active_thread_count())
+        assert cd.active_thread_count()
 
 
     @skip.unless_module('setproctitle')
     @skip.unless_module('setproctitle')
     def test_set_process_status(self):
     def test_set_process_status(self):
@@ -382,16 +378,16 @@ class test_funs(WorkerAppCase):
         prev1, sys.argv = sys.argv, ['Arg0']
         prev1, sys.argv = sys.argv, ['Arg0']
         try:
         try:
             st = worker.set_process_status('Running')
             st = worker.set_process_status('Running')
-            self.assertIn('celeryd', st)
-            self.assertIn('xyzza', st)
-            self.assertIn('Running', st)
+            assert 'celeryd' in st
+            assert 'xyzza' in st
+            assert 'Running' in st
             prev2, sys.argv = sys.argv, ['Arg0', 'Arg1']
             prev2, sys.argv = sys.argv, ['Arg0', 'Arg1']
             try:
             try:
                 st = worker.set_process_status('Running')
                 st = worker.set_process_status('Running')
-                self.assertIn('celeryd', st)
-                self.assertIn('xyzza', st)
-                self.assertIn('Running', st)
-                self.assertIn('Arg1', st)
+                assert 'celeryd' in st
+                assert 'xyzza' in st
+                assert 'Running' in st
+                assert 'Arg1' in st
             finally:
             finally:
                 sys.argv = prev2
                 sys.argv = prev2
         finally:
         finally:
@@ -402,8 +398,8 @@ class test_funs(WorkerAppCase):
         cmd.app = self.app
         cmd.app = self.app
         opts, args = cmd.parse_options('worker', ['--concurrency=512',
         opts, args = cmd.parse_options('worker', ['--concurrency=512',
                                        '--heartbeat-interval=10'])
                                        '--heartbeat-interval=10'])
-        self.assertEqual(opts.concurrency, 512)
-        self.assertEqual(opts.heartbeat_interval, 10)
+        assert opts.concurrency == 512
+        assert opts.heartbeat_interval == 10
 
 
     def test_main(self):
     def test_main(self):
         p, cd.Worker = cd.Worker, Worker
         p, cd.Worker = cd.Worker, Worker
@@ -416,7 +412,7 @@ class test_funs(WorkerAppCase):
 
 
 
 
 @mock.stdouts
 @mock.stdouts
-class test_signal_handlers(WorkerAppCase):
+class test_signal_handlers:
 
 
     class _Worker(object):
     class _Worker(object):
         stopped = False
         stopped = False
@@ -459,16 +455,16 @@ class test_signal_handlers(WorkerAppCase):
             p, platforms.signals = platforms.signals, Signals()
             p, platforms.signals = platforms.signals, Signals()
             try:
             try:
                 handlers['SIGINT']('SIGINT', object())
                 handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_stop)
-                self.assertEqual(state.should_stop, EX_FAILURE)
+                assert state.should_stop
+                assert state.should_stop == EX_FAILURE
             finally:
             finally:
                 platforms.signals = p
                 platforms.signals = p
                 state.should_stop = None
                 state.should_stop = None
 
 
             try:
             try:
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_terminate)
-                self.assertEqual(state.should_terminate, EX_FAILURE)
+                assert state.should_terminate
+                assert state.should_terminate == EX_FAILURE
             finally:
             finally:
                 state.should_terminate = None
                 state.should_terminate = None
 
 
@@ -476,12 +472,12 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             c.return_value = 1
             p, platforms.signals = platforms.signals, Signals()
             p, platforms.signals = platforms.signals, Signals()
             try:
             try:
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
                     handlers['SIGINT']('SIGINT', object())
             finally:
             finally:
                 platforms.signals = p
                 platforms.signals = p
 
 
-            with self.assertRaises(WorkerTerminate):
+            with pytest.raises(WorkerTerminate):
                 next_handlers['SIGINT']('SIGINT', object())
                 next_handlers['SIGINT']('SIGINT', object())
 
 
     @skip.unless_module('multiprocessing')
     @skip.unless_module('multiprocessing')
@@ -494,7 +490,7 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_int_handler, worker)
                 handlers = self.psig(cd.install_worker_int_handler, worker)
                 handlers['SIGINT']('SIGINT', object())
                 handlers['SIGINT']('SIGINT', object())
-                self.assertTrue(state.should_stop)
+                assert state.should_stop
             finally:
             finally:
                 process.name = name
                 process.name = name
                 state.should_stop = None
                 state.should_stop = None
@@ -504,7 +500,7 @@ class test_signal_handlers(WorkerAppCase):
             try:
             try:
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_int_handler, worker)
                 handlers = self.psig(cd.install_worker_int_handler, worker)
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGINT']('SIGINT', object())
                     handlers['SIGINT']('SIGINT', object())
             finally:
             finally:
                 process.name = name
                 process.name = name
@@ -527,7 +523,7 @@ class test_signal_handlers(WorkerAppCase):
                     cd.install_worker_term_hard_handler, worker)
                     cd.install_worker_term_hard_handler, worker)
                 try:
                 try:
                     handlers['SIGQUIT']('SIGQUIT', object())
                     handlers['SIGQUIT']('SIGQUIT', object())
-                    self.assertTrue(state.should_terminate)
+                    assert state.should_terminate
                 finally:
                 finally:
                     state.should_terminate = None
                     state.should_terminate = None
             with patch('celery.apps.worker.active_thread_count') as c:
             with patch('celery.apps.worker.active_thread_count') as c:
@@ -536,7 +532,7 @@ class test_signal_handlers(WorkerAppCase):
                 handlers = self.psig(
                 handlers = self.psig(
                     cd.install_worker_term_hard_handler, worker)
                     cd.install_worker_term_hard_handler, worker)
                 try:
                 try:
-                    with self.assertRaises(WorkerTerminate):
+                    with pytest.raises(WorkerTerminate):
                         handlers['SIGQUIT']('SIGQUIT', object())
                         handlers['SIGQUIT']('SIGQUIT', object())
                 finally:
                 finally:
                     state.should_terminate = None
                     state.should_terminate = None
@@ -550,7 +546,7 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_term_handler, worker)
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
             try:
                 handlers['SIGTERM']('SIGTERM', object())
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertEqual(state.should_stop, EX_OK)
+                assert state.should_stop == EX_OK
             finally:
             finally:
                 state.should_stop = None
                 state.should_stop = None
 
 
@@ -560,7 +556,7 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_handler, worker)
             handlers = self.psig(cd.install_worker_term_handler, worker)
             try:
             try:
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
             finally:
             finally:
                 state.should_stop = None
                 state.should_stop = None
@@ -570,7 +566,7 @@ class test_signal_handlers(WorkerAppCase):
     @skip.if_jython()
     @skip.if_jython()
     def test_worker_cry_handler(self, stderr):
     def test_worker_cry_handler(self, stderr):
         handlers = self.psig(cd.install_cry_handler)
         handlers = self.psig(cd.install_cry_handler)
-        self.assertIsNone(handlers['SIGUSR1']('SIGUSR1', object()))
+        assert handlers['SIGUSR1']('SIGUSR1', object()) is None
         stderr.write.assert_called()
         stderr.write.assert_called()
 
 
     @skip.unless_module('multiprocessing')
     @skip.unless_module('multiprocessing')
@@ -583,12 +579,12 @@ class test_signal_handlers(WorkerAppCase):
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers['SIGTERM']('SIGTERM', object())
                 handlers['SIGTERM']('SIGTERM', object())
-                self.assertEqual(state.should_stop, EX_OK)
+                assert state.should_stop == EX_OK
             with patch('celery.apps.worker.active_thread_count') as c:
             with patch('celery.apps.worker.active_thread_count') as c:
                 c.return_value = 1
                 c.return_value = 1
                 worker = self._Worker()
                 worker = self._Worker()
                 handlers = self.psig(cd.install_worker_term_handler, worker)
                 handlers = self.psig(cd.install_worker_term_handler, worker)
-                with self.assertRaises(WorkerShutdown):
+                with pytest.raises(WorkerShutdown):
                     handlers['SIGTERM']('SIGTERM', object())
                     handlers['SIGTERM']('SIGTERM', object())
         finally:
         finally:
             process.name = name
             process.name = name
@@ -609,11 +605,11 @@ class test_signal_handlers(WorkerAppCase):
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers = self.psig(cd.install_worker_restart_handler, worker)
             handlers['SIGHUP']('SIGHUP', object())
             handlers['SIGHUP']('SIGHUP', object())
-            self.assertEqual(state.should_stop, EX_OK)
+            assert state.should_stop == EX_OK
             register.assert_called()
             register.assert_called()
             callback = register.call_args[0][0]
             callback = register.call_args[0][0]
             callback()
             callback()
-            self.assertTrue(argv)
+            assert argv
         finally:
         finally:
             os.execv = execv
             os.execv = execv
             state.should_stop = None
             state.should_stop = None
@@ -625,7 +621,7 @@ class test_signal_handlers(WorkerAppCase):
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
             try:
             try:
                 handlers['SIGQUIT']('SIGQUIT', object())
                 handlers['SIGQUIT']('SIGQUIT', object())
-                self.assertTrue(state.should_terminate)
+                assert state.should_terminate
             finally:
             finally:
                 state.should_terminate = None
                 state.should_terminate = None
 
 
@@ -634,5 +630,5 @@ class test_signal_handlers(WorkerAppCase):
             c.return_value = 1
             c.return_value = 1
             worker = self._Worker()
             worker = self._Worker()
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
             handlers = self.psig(cd.install_worker_term_hard_handler, worker)
-            with self.assertRaises(WorkerTerminate):
+            with pytest.raises(WorkerTerminate):
                 handlers['SIGQUIT']('SIGQUIT', object())
                 handlers['SIGQUIT']('SIGQUIT', object())

+ 0 - 0
celery/tests/contrib/__init__.py → t/unit/compat_modules/__init__.py


+ 16 - 14
celery/tests/compat_modules/test_compat.py → t/unit/compat_modules/test_compat.py

@@ -1,20 +1,22 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from datetime import timedelta
 from datetime import timedelta
 
 
+from celery.five import bytes_if_py2
 from celery.schedules import schedule
 from celery.schedules import schedule
 from celery.task import (
 from celery.task import (
     periodic_task,
     periodic_task,
     PeriodicTask
     PeriodicTask
 )
 )
 
 
-from celery.tests.case import AppCase, depends_on_current_app  # noqa
-
 
 
-@depends_on_current_app
-class test_periodic_tasks(AppCase):
+class test_periodic_tasks:
 
 
     def setup(self):
     def setup(self):
+        self.app.set_current()  # @depends_on_current_app
+
         @periodic_task(app=self.app, shared=False,
         @periodic_task(app=self.app, shared=False,
                        run_every=schedule(timedelta(hours=1), app=self.app))
                        run_every=schedule(timedelta(hours=1), app=self.app))
         def my_periodic():
         def my_periodic():
@@ -25,32 +27,32 @@ class test_periodic_tasks(AppCase):
         return self.app.now()
         return self.app.now()
 
 
     def test_must_have_run_every(self):
     def test_must_have_run_every(self):
-        with self.assertRaises(NotImplementedError):
-            type('Foo', (PeriodicTask,), {'__module__': __name__})
+        with pytest.raises(NotImplementedError):
+            type(bytes_if_py2('Foo'), (PeriodicTask,), {
+                '__module__': __name__,
+            })
 
 
     def test_remaining_estimate(self):
     def test_remaining_estimate(self):
         s = self.my_periodic.run_every
         s = self.my_periodic.run_every
-        self.assertIsInstance(
+        assert isinstance(
             s.remaining_estimate(s.maybe_make_aware(self.now())),
             s.remaining_estimate(s.maybe_make_aware(self.now())),
             timedelta)
             timedelta)
 
 
     def test_is_due_not_due(self):
     def test_is_due_not_due(self):
         due, remaining = self.my_periodic.run_every.is_due(self.now())
         due, remaining = self.my_periodic.run_every.is_due(self.now())
-        self.assertFalse(due)
+        assert not due
         # This assertion may fail if executed in the
         # This assertion may fail if executed in the
         # first minute of an hour, thus 59 instead of 60
         # first minute of an hour, thus 59 instead of 60
-        self.assertGreater(remaining, 59)
+        assert remaining > 59
 
 
     def test_is_due(self):
     def test_is_due(self):
         p = self.my_periodic
         p = self.my_periodic
         due, remaining = p.run_every.is_due(
         due, remaining = p.run_every.is_due(
             self.now() - p.run_every.run_every,
             self.now() - p.run_every.run_every,
         )
         )
-        self.assertTrue(due)
-        self.assertEqual(
-            remaining, p.run_every.run_every.total_seconds(),
-        )
+        assert due
+        assert remaining == p.run_every.run_every.total_seconds()
 
 
     def test_schedule_repr(self):
     def test_schedule_repr(self):
         p = self.my_periodic
         p = self.my_periodic
-        self.assertTrue(repr(p.run_every))
+        assert repr(p.run_every)

+ 13 - 14
celery/tests/compat_modules/test_compat_utils.py → t/unit/compat_modules/test_compat_utils.py

@@ -1,39 +1,38 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import celery
 import celery
+import pytest
 
 
 from celery.app.task import Task as ModernTask
 from celery.app.task import Task as ModernTask
 from celery.task.base import Task as CompatTask
 from celery.task.base import Task as CompatTask
 
 
-from celery.tests.case import AppCase, depends_on_current_app
 
 
-
-@depends_on_current_app
-class test_MagicModule(AppCase):
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_MagicModule:
 
 
     def test_class_property_set_without_type(self):
     def test_class_property_set_without_type(self):
-        self.assertTrue(ModernTask.__dict__['app'].__get__(CompatTask()))
+        assert ModernTask.__dict__['app'].__get__(CompatTask())
 
 
     def test_class_property_set_on_class(self):
     def test_class_property_set_on_class(self):
-        self.assertIs(ModernTask.__dict__['app'].__set__(None, None),
-                      ModernTask.__dict__['app'])
+        assert (ModernTask.__dict__['app'].__set__(None, None) is
+                ModernTask.__dict__['app'])
 
 
-    def test_class_property_set(self):
+    def test_class_property_set(self, app):
 
 
         class X(CompatTask):
         class X(CompatTask):
             pass
             pass
-        ModernTask.__dict__['app'].__set__(X(), self.app)
-        self.assertIs(X.app, self.app)
+        ModernTask.__dict__['app'].__set__(X(), app)
+        assert X.app is app
 
 
     def test_dir(self):
     def test_dir(self):
-        self.assertTrue(dir(celery.messaging))
+        assert dir(celery.messaging)
 
 
     def test_direct(self):
     def test_direct(self):
-        self.assertTrue(celery.task)
+        assert celery.task
 
 
     def test_app_attrs(self):
     def test_app_attrs(self):
-        self.assertEqual(celery.task.control.broadcast,
-                         celery.current_app.control.broadcast)
+        assert (celery.task.control.broadcast ==
+                celery.current_app.control.broadcast)
 
 
     def test_decorators_task(self):
     def test_decorators_task(self):
         @celery.decorators.task
         @celery.decorators.task

+ 37 - 0
t/unit/compat_modules/test_decorators.py

@@ -0,0 +1,37 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+import warnings
+
+from celery.task import base
+
+
+def add(x, y):
+    return x + y
+
+
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_decorators:
+
+    def test_task_alias(self):
+        from celery import task
+        assert task.__file__
+        assert task(add)
+
+    def setup(self):
+        with warnings.catch_warnings(record=True):
+            from celery import decorators
+            self.decorators = decorators
+
+    def assert_compat_decorator(self, decorator, type, **opts):
+        task = decorator(**opts)(add)
+        assert task(8, 8) == 16
+        assert isinstance(task, type)
+
+    def test_task(self):
+        self.assert_compat_decorator(self.decorators.task, base.BaseTask)
+
+    def test_periodic_task(self):
+        self.assert_compat_decorator(
+            self.decorators.periodic_task, base.BaseTask, run_every=1,
+        )

+ 4 - 3
celery/tests/compat_modules/test_messaging.py → t/unit/compat_modules/test_messaging.py

@@ -1,11 +1,12 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from celery import messaging
 from celery import messaging
-from celery.tests.case import AppCase, depends_on_current_app
 
 
 
 
-@depends_on_current_app
-class test_compat_messaging_module(AppCase):
+@pytest.mark.usefixtures('depends_on_current_app')
+class test_compat_messaging_module:
 
 
     def test_get_consume_set(self):
     def test_get_consume_set(self):
         conn = messaging.establish_connection()
         conn = messaging.establish_connection()

+ 0 - 0
celery/tests/events/__init__.py → t/unit/concurrency/__init__.py


+ 32 - 33
celery/tests/concurrency/test_concurrency.py → t/unit/concurrency/test_concurrency.py

@@ -1,15 +1,17 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 import os
 import os
+import pytest
 
 
 from itertools import count
 from itertools import count
 
 
+from case import Mock, patch
+
 from celery.concurrency.base import apply_target, BasePool
 from celery.concurrency.base import apply_target, BasePool
 from celery.exceptions import WorkerShutdown, WorkerTerminate
 from celery.exceptions import WorkerShutdown, WorkerTerminate
-from celery.tests.case import AppCase, Mock, patch
 
 
 
 
-class test_BasePool(AppCase):
+class test_BasePool:
 
 
     def test_apply_target(self):
     def test_apply_target(self):
 
 
@@ -29,14 +31,12 @@ class test_BasePool(AppCase):
                      callback=gen_callback('callback'),
                      callback=gen_callback('callback'),
                      accept_callback=gen_callback('accept_callback'))
                      accept_callback=gen_callback('accept_callback'))
 
 
-        self.assertDictContainsSubset(
-            {'target': (1, (8, 16)), 'callback': (2, (42,))},
-            scratch,
-        )
+        assert scratch['target'] == (1, (8, 16))
+        assert scratch['callback'] == (2, (42,))
         pa1 = scratch['accept_callback']
         pa1 = scratch['accept_callback']
-        self.assertEqual(0, pa1[0])
-        self.assertEqual(pa1[1][0], os.getpid())
-        self.assertTrue(pa1[1][1])
+        assert pa1[0] == 0
+        assert pa1[1][0] == os.getpid()
+        assert pa1[1][1]
 
 
         # No accept callback
         # No accept callback
         scratch.clear()
         scratch.clear()
@@ -44,32 +44,33 @@ class test_BasePool(AppCase):
                      args=(8, 16),
                      args=(8, 16),
                      callback=gen_callback('callback'),
                      callback=gen_callback('callback'),
                      accept_callback=None)
                      accept_callback=None)
-        self.assertDictEqual(scratch,
-                             {'target': (3, (8, 16)),
-                              'callback': (4, (42,))})
+        assert scratch == {
+            'target': (3, (8, 16)),
+            'callback': (4, (42,)),
+        }
 
 
     def test_apply_target__propagate(self):
     def test_apply_target__propagate(self):
         target = Mock(name='target')
         target = Mock(name='target')
         target.side_effect = KeyError()
         target.side_effect = KeyError()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target, propagate=(KeyError,))
             apply_target(target, propagate=(KeyError,))
 
 
     def test_apply_target__raises(self):
     def test_apply_target__raises(self):
         target = Mock(name='target')
         target = Mock(name='target')
         target.side_effect = KeyError()
         target.side_effect = KeyError()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target)
             apply_target(target)
 
 
     def test_apply_target__raises_WorkerShutdown(self):
     def test_apply_target__raises_WorkerShutdown(self):
         target = Mock(name='target')
         target = Mock(name='target')
         target.side_effect = WorkerShutdown()
         target.side_effect = WorkerShutdown()
-        with self.assertRaises(WorkerShutdown):
+        with pytest.raises(WorkerShutdown):
             apply_target(target)
             apply_target(target)
 
 
     def test_apply_target__raises_WorkerTerminate(self):
     def test_apply_target__raises_WorkerTerminate(self):
         target = Mock(name='target')
         target = Mock(name='target')
         target.side_effect = WorkerTerminate()
         target.side_effect = WorkerTerminate()
-        with self.assertRaises(WorkerTerminate):
+        with pytest.raises(WorkerTerminate):
             apply_target(target)
             apply_target(target)
 
 
     def test_apply_target__raises_BaseException(self):
     def test_apply_target__raises_BaseException(self):
@@ -85,7 +86,7 @@ class test_BasePool(AppCase):
         callback = Mock(name='callback')
         callback = Mock(name='callback')
         reraise.side_effect = KeyError()
         reraise.side_effect = KeyError()
         target.side_effect = BaseException()
         target.side_effect = BaseException()
-        with self.assertRaises(KeyError):
+        with pytest.raises(KeyError):
             apply_target(target, callback=callback)
             apply_target(target, callback=callback)
         callback.assert_not_called()
         callback.assert_not_called()
 
 
@@ -95,7 +96,7 @@ class test_BasePool(AppCase):
         x.apply_async(object)
         x.apply_async(object)
 
 
     def test_num_processes(self):
     def test_num_processes(self):
-        self.assertEqual(BasePool(7).num_processes, 7)
+        assert BasePool(7).num_processes == 7
 
 
     def test_interface_on_start(self):
     def test_interface_on_start(self):
         BasePool(10).on_start()
         BasePool(10).on_start()
@@ -107,22 +108,22 @@ class test_BasePool(AppCase):
         BasePool(10).on_apply()
         BasePool(10).on_apply()
 
 
     def test_interface_info(self):
     def test_interface_info(self):
-        self.assertDictEqual(BasePool(10).info, {
+        assert BasePool(10).info == {
             'max-concurrency': 10,
             'max-concurrency': 10,
-        })
+        }
 
 
     def test_interface_flush(self):
     def test_interface_flush(self):
-        self.assertIsNone(BasePool(10).flush())
+        assert BasePool(10).flush() is None
 
 
     def test_active(self):
     def test_active(self):
         p = BasePool(10)
         p = BasePool(10)
-        self.assertFalse(p.active)
+        assert not p.active
         p._state = p.RUN
         p._state = p.RUN
-        self.assertTrue(p.active)
+        assert p.active
 
 
     def test_restart(self):
     def test_restart(self):
         p = BasePool(10)
         p = BasePool(10)
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             p.restart()
             p.restart()
 
 
     def test_interface_on_terminate(self):
     def test_interface_on_terminate(self):
@@ -130,29 +131,27 @@ class test_BasePool(AppCase):
         p.on_terminate()
         p.on_terminate()
 
 
     def test_interface_terminate_job(self):
     def test_interface_terminate_job(self):
-        with self.assertRaises(NotImplementedError):
+        with pytest.raises(NotImplementedError):
             BasePool(10).terminate_job(101)
             BasePool(10).terminate_job(101)
 
 
     def test_interface_did_start_ok(self):
     def test_interface_did_start_ok(self):
-        self.assertTrue(BasePool(10).did_start_ok())
+        assert BasePool(10).did_start_ok()
 
 
     def test_interface_register_with_event_loop(self):
     def test_interface_register_with_event_loop(self):
-        self.assertIsNone(
-            BasePool(10).register_with_event_loop(Mock()),
-        )
+        assert BasePool(10).register_with_event_loop(Mock()) is None
 
 
     def test_interface_on_soft_timeout(self):
     def test_interface_on_soft_timeout(self):
-        self.assertIsNone(BasePool(10).on_soft_timeout(Mock()))
+        assert BasePool(10).on_soft_timeout(Mock()) is None
 
 
     def test_interface_on_hard_timeout(self):
     def test_interface_on_hard_timeout(self):
-        self.assertIsNone(BasePool(10).on_hard_timeout(Mock()))
+        assert BasePool(10).on_hard_timeout(Mock()) is None
 
 
     def test_interface_close(self):
     def test_interface_close(self):
         p = BasePool(10)
         p = BasePool(10)
         p.on_close = Mock()
         p.on_close = Mock()
         p.close()
         p.close()
-        self.assertEqual(p._state, p.CLOSE)
+        assert p._state == p.CLOSE
         p.on_close.assert_called_with()
         p.on_close.assert_called_with()
 
 
     def test_interface_no_close(self):
     def test_interface_no_close(self):
-        self.assertIsNone(BasePool(10).on_close())
+        assert BasePool(10).on_close() is None

+ 36 - 38
celery/tests/concurrency/test_eventlet.py → t/unit/concurrency/test_eventlet.py

@@ -1,25 +1,34 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
-import os
+import pytest
 import sys
 import sys
 
 
+from case import Mock, patch, skip
+
 from celery.concurrency.eventlet import (
 from celery.concurrency.eventlet import (
     apply_target,
     apply_target,
     Timer,
     Timer,
     TaskPool,
     TaskPool,
 )
 )
 
 
-from celery.tests.case import AppCase, Mock, patch, skip
+eventlet_modules = (
+    'eventlet',
+    'eventlet.debug',
+    'eventlet.greenthread',
+    'eventlet.greenpool',
+    'greenlet',
+)
 
 
 
 
 @skip.if_pypy()
 @skip.if_pypy()
-class EventletCase(AppCase):
+class EventletCase:
 
 
     def setup(self):
     def setup(self):
-        self.mock_modules(*eventlet_modules)
+        self.patching.modules(*eventlet_modules)
 
 
     def teardown(self):
     def teardown(self):
-        for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]:
+        for mod in [mod for mod in sys.modules
+                    if mod.startswith('eventlet')]:
             try:
             try:
                 del(sys.modules[mod])
                 del(sys.modules[mod])
             except KeyError:
             except KeyError:
@@ -36,45 +45,34 @@ class test_aaa_eventlet_patch(EventletCase):
 
 
     @patch('eventlet.debug.hub_blocking_detection', create=True)
     @patch('eventlet.debug.hub_blocking_detection', create=True)
     @patch('eventlet.monkey_patch', create=True)
     @patch('eventlet.monkey_patch', create=True)
-    def test_aaa_blockdetecet(self, monkey_patch, hub_blocking_detection):
-        os.environ['EVENTLET_NOBLOCK'] = '10.3'
-        try:
-            from celery import maybe_patch_concurrency
-            maybe_patch_concurrency(['x', '-P', 'eventlet'])
-            monkey_patch.assert_called_with()
-            hub_blocking_detection.assert_called_with(10.3, 10.3)
-        finally:
-            os.environ.pop('EVENTLET_NOBLOCK', None)
-
-
-eventlet_modules = (
-    'eventlet',
-    'eventlet.debug',
-    'eventlet.greenthread',
-    'eventlet.greenpool',
-    'greenlet',
-)
+    def test_aaa_blockdetecet(
+            self, monkey_patch, hub_blocking_detection, patching):
+        patching.setenv('EVENTLET_NOBLOCK', '10.3')
+        from celery import maybe_patch_concurrency
+        maybe_patch_concurrency(['x', '-P', 'eventlet'])
+        monkey_patch.assert_called_with()
+        hub_blocking_detection.assert_called_with(10.3, 10.3)
 
 
 
 
 class test_Timer(EventletCase):
 class test_Timer(EventletCase):
 
 
-    def setup(self):
-        EventletCase.setup(self)
-        self.spawn_after = self.patch('eventlet.greenthread.spawn_after')
-        self.GreenletExit = self.patch('greenlet.GreenletExit')
+    @pytest.fixture(autouse=True)
+    def setup_patches(self, patching):
+        self.spawn_after = patching('eventlet.greenthread.spawn_after')
+        self.GreenletExit = patching('greenlet.GreenletExit')
 
 
     def test_sched(self):
     def test_sched(self):
         x = Timer()
         x = Timer()
         x.GreenletExit = KeyError
         x.GreenletExit = KeyError
         entry = Mock()
         entry = Mock()
         g = x._enter(1, 0, entry)
         g = x._enter(1, 0, entry)
-        self.assertTrue(x.queue)
+        assert x.queue
 
 
         x._entry_exit(g, entry)
         x._entry_exit(g, entry)
         g.wait.side_effect = KeyError()
         g.wait.side_effect = KeyError()
         x._entry_exit(g, entry)
         x._entry_exit(g, entry)
         entry.cancel.assert_called_with()
         entry.cancel.assert_called_with()
-        self.assertFalse(x._queue)
+        assert not x._queue
 
 
         x._queue.add(g)
         x._queue.add(g)
         x.clear()
         x.clear()
@@ -94,10 +92,10 @@ class test_Timer(EventletCase):
 
 
 class test_TaskPool(EventletCase):
 class test_TaskPool(EventletCase):
 
 
-    def setup(self):
-        EventletCase.setup(self)
-        self.GreenPool = self.patch('eventlet.greenpool.GreenPool')
-        self.greenthread = self.patch('eventlet.greenthread')
+    @pytest.fixture(autouse=True)
+    def setup_patches(self, patching):
+        self.GreenPool = patching('eventlet.greenpool.GreenPool')
+        self.greenthread = patching('eventlet.greenthread')
 
 
     def test_pool(self):
     def test_pool(self):
         x = TaskPool()
         x = TaskPool()
@@ -106,7 +104,7 @@ class test_TaskPool(EventletCase):
         x.on_apply(Mock())
         x.on_apply(Mock())
         x._pool = None
         x._pool = None
         x.on_stop()
         x.on_stop()
-        self.assertTrue(x.getpid())
+        assert x.getpid()
 
 
     @patch('celery.concurrency.eventlet.base')
     @patch('celery.concurrency.eventlet.base')
     def test_apply_target(self, base):
     def test_apply_target(self, base):
@@ -117,21 +115,21 @@ class test_TaskPool(EventletCase):
         x = TaskPool(10)
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
         x._pool = Mock(name='_pool')
         x.grow(2)
         x.grow(2)
-        self.assertEqual(x.limit, 12)
+        assert x.limit == 12
         x._pool.resize.assert_called_with(12)
         x._pool.resize.assert_called_with(12)
 
 
     def test_shrink(self):
     def test_shrink(self):
         x = TaskPool(10)
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
         x._pool = Mock(name='_pool')
         x.shrink(2)
         x.shrink(2)
-        self.assertEqual(x.limit, 8)
+        assert x.limit == 8
         x._pool.resize.assert_called_with(8)
         x._pool.resize.assert_called_with(8)
 
 
     def test_get_info(self):
     def test_get_info(self):
         x = TaskPool(10)
         x = TaskPool(10)
         x._pool = Mock(name='_pool')
         x._pool = Mock(name='_pool')
-        self.assertDictEqual(x._get_info(), {
+        assert x._get_info() == {
             'max-concurrency': 10,
             'max-concurrency': 10,
             'free-threads': x._pool.free(),
             'free-threads': x._pool.free(),
             'running-threads': x._pool.running(),
             'running-threads': x._pool.running(),
-        })
+        }

+ 128 - 0
t/unit/concurrency/test_gevent.py

@@ -0,0 +1,128 @@
+from __future__ import absolute_import, unicode_literals
+
+from case import Mock, skip
+
+from celery.concurrency.gevent import (
+    Timer,
+    TaskPool,
+    apply_timeout,
+)
+
+gevent_modules = (
+    'gevent',
+    'gevent.monkey',
+    'gevent.greenlet',
+    'gevent.pool',
+    'greenlet',
+)
+
+
+@skip.if_pypy()
+class test_gevent_patch:
+
+    def test_is_patched(self):
+        self.patching.modules(*gevent_modules)
+        patch_all = self.patching('gevent.monkey.patch_all')
+        import gevent
+        gevent.version_info = (1, 0, 0)
+        from celery import maybe_patch_concurrency
+        maybe_patch_concurrency(['x', '-P', 'gevent'])
+        patch_all.assert_called()
+
+
+@skip.if_pypy()
+class test_Timer:
+
+    def setup(self):
+        self.patching.modules(*gevent_modules)
+        self.greenlet = self.patching('gevent.greenlet')
+        self.GreenletExit = self.patching('gevent.greenlet.GreenletExit')
+
+    def test_sched(self):
+        self.greenlet.Greenlet = object
+        x = Timer()
+        self.greenlet.Greenlet = Mock()
+        x._Greenlet.spawn_later = Mock()
+        x._GreenletExit = KeyError
+        entry = Mock()
+        g = x._enter(1, 0, entry)
+        assert x.queue
+
+        x._entry_exit(g)
+        g.kill.assert_called_with()
+        assert not x._queue
+
+        x._queue.add(g)
+        x.clear()
+        x._queue.add(g)
+        g.kill.side_effect = KeyError()
+        x.clear()
+
+        g = x._Greenlet()
+        g.cancel()
+
+
+@skip.if_pypy()
+class test_TaskPool:
+
+    def setup(self):
+        self.patching.modules(*gevent_modules)
+        self.spawn_raw = self.patching('gevent.spawn_raw')
+        self.Pool = self.patching('gevent.pool.Pool')
+
+    def test_pool(self):
+        x = TaskPool()
+        x.on_start()
+        x.on_stop()
+        x.on_apply(Mock())
+        x._pool = None
+        x.on_stop()
+
+        x._pool = Mock()
+        x._pool._semaphore.counter = 1
+        x._pool.size = 1
+        x.grow()
+        assert x._pool.size == 2
+        assert x._pool._semaphore.counter == 2
+        x.shrink()
+        assert x._pool.size, 1
+        assert x._pool._semaphore.counter == 1
+
+        x._pool = [4, 5, 6]
+        assert x.num_processes == 3
+
+
+@skip.if_pypy()
+class test_apply_timeout:
+
+    def test_apply_timeout(self):
+        self.patching.modules(*gevent_modules)
+
+        class Timeout(Exception):
+            value = None
+
+            def __init__(self, value):
+                self.__class__.value = value
+
+            def __enter__(self):
+                return self
+
+            def __exit__(self, *exc_info):
+                pass
+        timeout_callback = Mock(name='timeout_callback')
+        apply_target = Mock(name='apply_target')
+        apply_timeout(
+            Mock(), timeout=10, callback=Mock(name='callback'),
+            timeout_callback=timeout_callback,
+            apply_target=apply_target, Timeout=Timeout,
+        )
+        assert Timeout.value == 10
+        apply_target.assert_called()
+
+        apply_target.side_effect = Timeout(10)
+        apply_timeout(
+            Mock(), timeout=10, callback=Mock(),
+            timeout_callback=timeout_callback,
+            apply_target=apply_target, Timeout=Timeout,
+        )
+        timeout_callback.assert_called_with(False, 10)

+ 15 - 20
celery/tests/concurrency/test_pool.py → t/unit/concurrency/test_pool.py

@@ -3,9 +3,9 @@ from __future__ import absolute_import, unicode_literals
 import time
 import time
 import itertools
 import itertools
 
 
-from billiard.einfo import ExceptionInfo
+from case import skip
 
 
-from celery.tests.case import AppCase, skip
+from billiard.einfo import ExceptionInfo
 
 
 
 
 def do_something(i):
 def do_something(i):
@@ -24,7 +24,7 @@ def raise_something(i):
 
 
 
 
 @skip.unless_module('multiprocessing')
 @skip.unless_module('multiprocessing')
-class test_TaskPool(AppCase):
+class test_TaskPool:
 
 
     def setup(self):
     def setup(self):
         from celery.concurrency.prefork import TaskPool
         from celery.concurrency.prefork import TaskPool
@@ -32,8 +32,8 @@ class test_TaskPool(AppCase):
 
 
     def test_attrs(self):
     def test_attrs(self):
         p = self.TaskPool(2)
         p = self.TaskPool(2)
-        self.assertEqual(p.limit, 2)
-        self.assertIsNone(p._pool)
+        assert p.limit == 2
+        assert p._pool is None
 
 
     def x_apply(self):
     def x_apply(self):
         p = self.TaskPool(2)
         p = self.TaskPool(2)
@@ -52,28 +52,23 @@ class test_TaskPool(AppCase):
         res2 = p.apply_async(raise_something, args=[10], errback=myerrback)
         res2 = p.apply_async(raise_something, args=[10], errback=myerrback)
         res3 = p.apply_async(do_something, args=[20], callback=mycallback)
         res3 = p.apply_async(do_something, args=[20], callback=mycallback)
 
 
-        self.assertEqual(res.get(), 100)
+        assert res.get() == 100
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 100},
-                                      scratchpad.get(0))
+        assert scratchpad.get(0)['ret_value'] == 100
 
 
-        self.assertIsInstance(res2.get(), ExceptionInfo)
-        self.assertTrue(scratchpad.get(1))
+        assert isinstance(res2.get(), ExceptionInfo)
+        assert scratchpad.get(1)
         time.sleep(1)
         time.sleep(1)
-        self.assertIsInstance(scratchpad[1]['ret_value'],
-                              ExceptionInfo)
-        self.assertEqual(scratchpad[1]['ret_value'].exception.args,
-                         ('FOO EXCEPTION',))
+        assert isinstance(scratchpad[1]['ret_value'], ExceptionInfo)
+        assert scratchpad[1]['ret_value'].exception.args == ('FOO EXCEPTION',)
 
 
-        self.assertEqual(res3.get(), 400)
+        assert res3.get() == 400
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 400},
-                                      scratchpad.get(2))
+        assert scratchpad.get(2)['ret_value'] == 400
 
 
         res3 = p.apply_async(do_something, args=[30], callback=mycallback)
         res3 = p.apply_async(do_something, args=[30], callback=mycallback)
 
 
-        self.assertEqual(res3.get(), 900)
+        assert res3.get() == 900
         time.sleep(0.5)
         time.sleep(0.5)
-        self.assertDictContainsSubset({'ret_value': 900},
-                                      scratchpad.get(3))
+        assert scratchpad.get(3)['ret_value'] == 900
         p.stop()
         p.stop()

+ 42 - 52
celery/tests/concurrency/test_prefork.py → t/unit/concurrency/test_prefork.py

@@ -2,18 +2,19 @@ from __future__ import absolute_import, unicode_literals
 
 
 import errno
 import errno
 import os
 import os
+import pytest
 import socket
 import socket
 
 
 from itertools import cycle
 from itertools import cycle
 
 
+from case import Mock, mock, patch, skip
+
 from celery.app.defaults import DEFAULTS
 from celery.app.defaults import DEFAULTS
 from celery.five import range
 from celery.five import range
 from celery.utils.collections import AttributeDict
 from celery.utils.collections import AttributeDict
 from celery.utils.functional import noop
 from celery.utils.functional import noop
 from celery.utils.objects import Bunch
 from celery.utils.objects import Bunch
 
 
-from celery.tests.case import AppCase, Mock, mock, patch, skip
-
 try:
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import prefork as mp
     from celery.concurrency import asynpool
     from celery.concurrency import asynpool
@@ -53,7 +54,7 @@ class MockResult(object):
         return self.value
         return self.value
 
 
 
 
-class test_process_initializer(AppCase):
+class test_process_initializer:
 
 
     @patch('celery.platforms.signals')
     @patch('celery.platforms.signals')
     @patch('celery.platforms.set_mp_process_title')
     @patch('celery.platforms.set_mp_process_title')
@@ -78,9 +79,9 @@ class test_process_initializer(AppCase):
                 process_initializer(app, 'awesome.worker.com')
                 process_initializer(app, 'awesome.worker.com')
                 _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                 _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                 _signals.reset.assert_any_call(*WORKER_SIGRESET)
                 _signals.reset.assert_any_call(*WORKER_SIGRESET)
-                self.assertTrue(app.loader.init_worker.call_count)
+                assert app.loader.init_worker.call_count
                 on_worker_process_init.assert_called()
                 on_worker_process_init.assert_called()
-                self.assertIs(_tls.current_app, app)
+                assert _tls.current_app is app
                 set_mp_process_title.assert_called_with(
                 set_mp_process_title.assert_called_with(
                     'celeryd', hostname='awesome.worker.com',
                     'celeryd', hostname='awesome.worker.com',
                 )
                 )
@@ -101,7 +102,7 @@ class test_process_initializer(AppCase):
                     os.environ.pop('CELERY_LOG_FILE', None)
                     os.environ.pop('CELERY_LOG_FILE', None)
 
 
 
 
-class test_process_destructor(AppCase):
+class test_process_destructor:
 
 
     @patch('celery.concurrency.prefork.signals')
     @patch('celery.concurrency.prefork.signals')
     def test_process_destructor(self, signals):
     def test_process_destructor(self, signals):
@@ -181,13 +182,9 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
     Pool = BlockingPool = ExeMockPool
 
 
 
 
-@skip.unless_module('multiprocessing')
-class PoolCase(AppCase):
-    pass
-
-
 @skip.if_win32()
 @skip.if_win32()
-class test_AsynPool(PoolCase):
+@skip.unless_module('multiprocessing')
+class test_AsynPool:
 
 
     def test_gen_not_started(self):
     def test_gen_not_started(self):
 
 
@@ -195,11 +192,11 @@ class test_AsynPool(PoolCase):
             yield 1
             yield 1
             yield 2
             yield 2
         g = gen()
         g = gen()
-        self.assertTrue(asynpool.gen_not_started(g))
+        assert asynpool.gen_not_started(g)
         next(g)
         next(g)
-        self.assertFalse(asynpool.gen_not_started(g))
+        assert not asynpool.gen_not_started(g)
         list(g)
         list(g)
-        self.assertFalse(asynpool.gen_not_started(g))
+        assert not asynpool.gen_not_started(g)
 
 
     @patch('select.select', create=True)
     @patch('select.select', create=True)
     def test_select(self, __select):
     def test_select(self, __select):
@@ -208,15 +205,11 @@ class test_AsynPool(PoolCase):
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll = poller.return_value = Mock(name='poll.poll')
             poll.return_value = {3}, set(), 0
             poll.return_value = {3}, set(), 0
-            self.assertEqual(
-                asynpool._select({3}, poll=poll),
-                ({3}, set(), 0),
-            )
+            assert asynpool._select({3}, poll=poll) == ({3}, set(), 0)
 
 
             poll.return_value = {3}, set(), 0
             poll.return_value = {3}, set(), 0
-            self.assertEqual(
-                asynpool._select({3}, None, {3}, poll=poll),
-                ({3}, set(), 0),
+            assert asynpool._select({3}, None, {3}, poll=poll) == (
+                {3}, set(), 0,
             )
             )
 
 
             eintr = socket.error()
             eintr = socket.error()
@@ -224,11 +217,8 @@ class test_AsynPool(PoolCase):
             poll.side_effect = eintr
             poll.side_effect = eintr
 
 
             readers = {3}
             readers = {3}
-            self.assertEqual(
-                asynpool._select(readers, poll=poll),
-                (set(), set(), 1),
-            )
-            self.assertIn(3, readers)
+            assert asynpool._select(readers, poll=poll) == (set(), set(), 1)
+            assert 3 in readers
 
 
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll = poller.return_value = Mock(name='poll.poll')
@@ -236,16 +226,15 @@ class test_AsynPool(PoolCase):
             with patch('select.select') as selcheck:
             with patch('select.select') as selcheck:
                 selcheck.side_effect = ebadf
                 selcheck.side_effect = ebadf
                 readers = {3}
                 readers = {3}
-                self.assertEqual(
-                    asynpool._select(readers, poll=poll),
-                    (set(), set(), 1),
+                assert asynpool._select(readers, poll=poll) == (
+                    set(), set(), 1,
                 )
                 )
-                self.assertNotIn(3, readers)
+                assert 3 not in readers
 
 
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
             poll = poller.return_value = Mock(name='poll.poll')
             poll = poller.return_value = Mock(name='poll.poll')
             poll.side_effect = MemoryError()
             poll.side_effect = MemoryError()
-            with self.assertRaises(MemoryError):
+            with pytest.raises(MemoryError):
                 asynpool._select({1}, poll=poll)
                 asynpool._select({1}, poll=poll)
 
 
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
@@ -256,7 +245,7 @@ class test_AsynPool(PoolCase):
                     selcheck.side_effect = MemoryError()
                     selcheck.side_effect = MemoryError()
                     raise ebadf
                     raise ebadf
                 poll.side_effect = se
                 poll.side_effect = se
-                with self.assertRaises(MemoryError):
+                with pytest.raises(MemoryError):
                     asynpool._select({3}, poll=poll)
                     asynpool._select({3}, poll=poll)
 
 
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
@@ -268,7 +257,7 @@ class test_AsynPool(PoolCase):
                     selcheck.side_effect.errno = 1321
                     selcheck.side_effect.errno = 1321
                     raise ebadf
                     raise ebadf
                 poll.side_effect = se2
                 poll.side_effect = se2
-                with self.assertRaises(socket.error):
+                with pytest.raises(socket.error):
                     asynpool._select({3}, poll=poll)
                     asynpool._select({3}, poll=poll)
 
 
         with patch('select.poll', create=True) as poller:
         with patch('select.poll', create=True) as poller:
@@ -276,14 +265,14 @@ class test_AsynPool(PoolCase):
 
 
             poll.side_effect = socket.error()
             poll.side_effect = socket.error()
             poll.side_effect.errno = 34134
             poll.side_effect.errno = 34134
-            with self.assertRaises(socket.error):
+            with pytest.raises(socket.error):
                 asynpool._select({3}, poll=poll)
                 asynpool._select({3}, poll=poll)
 
 
     def test_promise(self):
     def test_promise(self):
         fun = Mock()
         fun = Mock()
         x = asynpool.promise(fun, (1,), {'foo': 1})
         x = asynpool.promise(fun, (1,), {'foo': 1})
         x()
         x()
-        self.assertTrue(x.ready)
+        assert x.ready
         fun.assert_called_with(1, foo=1)
         fun.assert_called_with(1, foo=1)
 
 
     def test_Worker(self):
     def test_Worker(self):
@@ -293,7 +282,8 @@ class test_AsynPool(PoolCase):
 
 
 
 
 @skip.if_win32()
 @skip.if_win32()
-class test_ResultHandler(PoolCase):
+@skip.unless_module('multiprocessing')
+class test_ResultHandler:
 
 
     def test_process_result(self):
     def test_process_result(self):
         x = asynpool.ResultHandler(
         x = asynpool.ResultHandler(
@@ -303,7 +293,7 @@ class test_ResultHandler(PoolCase):
             on_process_alive=Mock(),
             on_process_alive=Mock(),
             on_job_ready=Mock(),
             on_job_ready=Mock(),
         )
         )
-        self.assertTrue(x)
+        assert x
         hub = Mock(name='hub')
         hub = Mock(name='hub')
         recv = x._recv_message = Mock(name='recv_message')
         recv = x._recv_message = Mock(name='recv_message')
         recv.return_value = iter([])
         recv.return_value = iter([])
@@ -319,25 +309,25 @@ class test_ResultHandler(PoolCase):
         )
         )
 
 
 
 
-class test_TaskPool(PoolCase):
+class test_TaskPool:
 
 
     def test_start(self):
     def test_start(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()
-        self.assertTrue(pool._pool.started)
-        self.assertEqual(pool._pool._state, asynpool.RUN)
+        assert pool._pool.started
+        assert pool._pool._state == asynpool.RUN
 
 
         _pool = pool._pool
         _pool = pool._pool
         pool.stop()
         pool.stop()
-        self.assertTrue(_pool.closed)
-        self.assertTrue(_pool.joined)
+        assert _pool.closed
+        assert _pool.joined
         pool.stop()
         pool.stop()
 
 
         pool.start()
         pool.start()
         _pool = pool._pool
         _pool = pool._pool
         pool.terminate()
         pool.terminate()
         pool.terminate()
         pool.terminate()
-        self.assertTrue(_pool.terminated)
+        assert _pool.terminated
 
 
     def test_restart(self):
     def test_restart(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
@@ -349,7 +339,7 @@ class test_TaskPool(PoolCase):
     def test_did_start_ok(self):
     def test_did_start_ok(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool._pool = Mock(name='pool')
         pool._pool = Mock(name='pool')
-        self.assertIs(pool.did_start_ok(), pool._pool.did_start_ok())
+        assert pool.did_start_ok() is pool._pool.did_start_ok()
 
 
     def test_register_with_event_loop(self):
     def test_register_with_event_loop(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
@@ -380,11 +370,11 @@ class test_TaskPool(PoolCase):
     def test_grow_shrink(self):
     def test_grow_shrink(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()
-        self.assertEqual(pool._pool._processes, 10)
+        assert pool._pool._processes == 10
         pool.grow()
         pool.grow()
-        self.assertEqual(pool._pool._processes, 11)
+        assert pool._pool._processes == 11
         pool.shrink(2)
         pool.shrink(2)
-        self.assertEqual(pool._pool._processes, 9)
+        assert pool._pool._processes == 9
 
 
     def test_info(self):
     def test_info(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
@@ -400,11 +390,11 @@ class test_TaskPool(PoolCase):
                 return {}
                 return {}
         pool._pool = _Pool()
         pool._pool = _Pool()
         info = pool.info
         info = pool.info
-        self.assertEqual(info['max-concurrency'], pool.limit)
-        self.assertEqual(info['max-tasks-per-child'], 'N/A')
-        self.assertEqual(info['timeouts'], (5, 10))
+        assert info['max-concurrency'] == pool.limit
+        assert info['max-tasks-per-child'] == 'N/A'
+        assert info['timeouts'] == (5, 10)
 
 
     def test_num_processes(self):
     def test_num_processes(self):
         pool = TaskPool(7)
         pool = TaskPool(7)
         pool.start()
         pool.start()
-        self.assertEqual(pool.num_processes, 7)
+        assert pool.num_processes == 7

+ 2 - 3
celery/tests/concurrency/test_solo.py → t/unit/concurrency/test_solo.py

@@ -4,10 +4,9 @@ import operator
 
 
 from celery.concurrency import solo
 from celery.concurrency import solo
 from celery.utils.functional import noop
 from celery.utils.functional import noop
-from celery.tests.case import AppCase
 
 
 
 
-class test_solo_TaskPool(AppCase):
+class test_solo_TaskPool:
 
 
     def test_on_start(self):
     def test_on_start(self):
         x = solo.TaskPool()
         x = solo.TaskPool()
@@ -21,4 +20,4 @@ class test_solo_TaskPool(AppCase):
     def test_info(self):
     def test_info(self):
         x = solo.TaskPool()
         x = solo.TaskPool()
         x.on_start()
         x.on_start()
-        self.assertTrue(x.info)
+        assert x.info

+ 0 - 0
celery/tests/fixups/__init__.py → t/unit/contrib/__init__.py


+ 6 - 9
celery/tests/contrib/test_abortable.py → t/unit/contrib/test_abortable.py

@@ -1,13 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
 from celery.contrib.abortable import AbortableTask, AbortableAsyncResult
-from celery.tests.case import AppCase
 
 
 
 
-class test_AbortableTask(AppCase):
+class test_AbortableTask:
 
 
     def setup(self):
     def setup(self):
-
         @self.app.task(base=AbortableTask, shared=False)
         @self.app.task(base=AbortableTask, shared=False)
         def abortable():
         def abortable():
             return True
             return True
@@ -16,16 +14,15 @@ class test_AbortableTask(AppCase):
     def test_async_result_is_abortable(self):
     def test_async_result_is_abortable(self):
         result = self.abortable.apply_async()
         result = self.abortable.apply_async()
         tid = result.id
         tid = result.id
-        self.assertIsInstance(
-            self.abortable.AsyncResult(tid), AbortableAsyncResult,
-        )
+        assert isinstance(
+            self.abortable.AsyncResult(tid), AbortableAsyncResult)
 
 
     def test_is_not_aborted(self):
     def test_is_not_aborted(self):
         self.abortable.push_request()
         self.abortable.push_request()
         try:
         try:
             result = self.abortable.apply_async()
             result = self.abortable.apply_async()
             tid = result.id
             tid = result.id
-            self.assertFalse(self.abortable.is_aborted(task_id=tid))
+            assert not self.abortable.is_aborted(task_id=tid)
         finally:
         finally:
             self.abortable.pop_request()
             self.abortable.pop_request()
 
 
@@ -34,7 +31,7 @@ class test_AbortableTask(AppCase):
         self.abortable.push_request()
         self.abortable.push_request()
         try:
         try:
             self.abortable.request.id = 'foo'
             self.abortable.request.id = 'foo'
-            self.assertFalse(self.abortable.is_aborted())
+            assert not self.abortable.is_aborted()
         finally:
         finally:
             self.abortable.pop_request()
             self.abortable.pop_request()
 
 
@@ -44,6 +41,6 @@ class test_AbortableTask(AppCase):
             result = self.abortable.apply_async()
             result = self.abortable.apply_async()
             result.abort()
             result.abort()
             tid = result.id
             tid = result.id
-            self.assertTrue(self.abortable.is_aborted(task_id=tid))
+            assert self.abortable.is_aborted(task_id=tid)
         finally:
         finally:
             self.abortable.pop_request()
             self.abortable.pop_request()

+ 80 - 76
celery/tests/contrib/test_migrate.py → t/unit/contrib/test_migrate.py

@@ -1,8 +1,11 @@
 from __future__ import absolute_import, unicode_literals
 from __future__ import absolute_import, unicode_literals
 
 
+import pytest
+
 from contextlib import contextmanager
 from contextlib import contextmanager
 
 
 from amqp import ChannelError
 from amqp import ChannelError
+from case import Mock, mock, patch
 
 
 from kombu import Connection, Producer, Queue, Exchange
 from kombu import Connection, Producer, Queue, Exchange
 
 
@@ -26,7 +29,6 @@ from celery.contrib.migrate import (
     move,
     move,
 )
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Mock, mock, patch
 
 
 # hack to ignore error at shutdown
 # hack to ignore error at shutdown
 QoS.restore_at_shutdown = False
 QoS.restore_at_shutdown = False
@@ -52,22 +54,22 @@ def Message(body, exchange='exchange', routing_key='rkey',
     )
     )
 
 
 
 
-class test_State(AppCase):
+class test_State:
 
 
     def test_strtotal(self):
     def test_strtotal(self):
         x = State()
         x = State()
-        self.assertEqual(x.strtotal, '?')
+        assert x.strtotal == '?'
         x.total_apx = 100
         x.total_apx = 100
-        self.assertEqual(x.strtotal, '100')
+        assert x.strtotal == '100'
 
 
     def test_repr(self):
     def test_repr(self):
         x = State()
         x = State()
-        self.assertTrue(repr(x))
+        assert repr(x)
         x.filtered = 'foo'
         x.filtered = 'foo'
-        self.assertTrue(repr(x))
+        assert repr(x)
 
 
 
 
-class test_move(AppCase):
+class test_move:
 
 
     @contextmanager
     @contextmanager
     def move_context(self, **kwargs):
     def move_context(self, **kwargs):
@@ -113,7 +115,7 @@ class test_move(AppCase):
         with self.move_context(limit=1) as (callback, pred, republish):
         with self.move_context(limit=1) as (callback, pred, republish):
             pred.return_value = 'foo'
             pred.return_value = 'foo'
             body, message = self.msgpair()
             body, message = self.msgpair()
-            with self.assertRaises(StopFiltering):
+            with pytest.raises(StopFiltering):
                 callback(body, message)
                 callback(body, message)
             republish.assert_called()
             republish.assert_called()
 
 
@@ -127,7 +129,7 @@ class test_move(AppCase):
             cb.assert_called()
             cb.assert_called()
 
 
 
 
-class test_start_filter(AppCase):
+class test_start_filter:
 
 
     def test_start(self):
     def test_start(self):
         with patch('celery.contrib.migrate.eventloop') as evloop:
         with patch('celery.contrib.migrate.eventloop') as evloop:
@@ -174,11 +176,11 @@ class test_start_filter(AppCase):
                     callback(body, Message(body))
                     callback(body, Message(body))
                 except StopFiltering:
                 except StopFiltering:
                     stop_filtering_raised = True
                     stop_filtering_raised = True
-            self.assertTrue(state.count)
-            self.assertTrue(stop_filtering_raised)
+            assert state.count
+            assert stop_filtering_raised
 
 
 
 
-class test_filter_callback(AppCase):
+class test_filter_callback:
 
 
     def test_filter(self):
     def test_filter(self):
         callback = Mock()
         callback = Mock()
@@ -193,57 +195,59 @@ class test_filter_callback(AppCase):
         callback.assert_called_with(t1, message)
         callback.assert_called_with(t1, message)
 
 
 
 
-class test_utils(AppCase):
+def test_task_id_in():
+    assert task_id_in(['A'], {'id': 'A'}, Mock())
+    assert not task_id_in(['A'], {'id': 'B'}, Mock())
+
+
+def test_task_id_eq():
+    assert task_id_eq('A', {'id': 'A'}, Mock())
+    assert not task_id_eq('A', {'id': 'B'}, Mock())
 
 
-    def test_task_id_in(self):
-        self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
-        self.assertFalse(task_id_in(['A'], {'id': 'B'}, Mock()))
 
 
-    def test_task_id_eq(self):
-        self.assertTrue(task_id_eq('A', {'id': 'A'}, Mock()))
-        self.assertFalse(task_id_eq('A', {'id': 'B'}, Mock()))
+def test_expand_dest():
+    assert expand_dest(None, 'foo', 'bar') == ('foo', 'bar')
+    assert expand_dest(('b', 'x'), 'foo', 'bar') == ('b', 'x')
 
 
-    def test_expand_dest(self):
-        self.assertEqual(expand_dest(None, 'foo', 'bar'), ('foo', 'bar'))
-        self.assertEqual(expand_dest(('b', 'x'), 'foo', 'bar'), ('b', 'x'))
 
 
-    def test_maybe_queue(self):
-        app = Mock()
-        app.amqp.queues = {'foo': 313}
-        self.assertEqual(_maybe_queue(app, 'foo'), 313)
-        self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
+def test_maybe_queue():
+    app = Mock()
+    app.amqp.queues = {'foo': 313}
+    assert _maybe_queue(app, 'foo') == 313
+    assert _maybe_queue(app, Queue('foo')) == Queue('foo')
 
 
-    @mock.stdouts
-    def test_filter_status(self, stdout, stderr):
+
+def test_filter_status():
+    with mock.stdouts() as (stdout, stderr):
         filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
         filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
-        self.assertTrue(stdout.getvalue())
-
-    def test_move_by_taskmap(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_by_taskmap({'add': Queue('foo')})
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertTrue(cb({'task': 'add'}, Mock()))
-
-    def test_move_by_idmap(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_by_idmap({'123f': Queue('foo')})
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertTrue(cb({'id': '123f'}, Mock()))
-
-    def test_move_task_by_id(self):
-        with patch('celery.contrib.migrate.move') as move:
-            move_task_by_id('123f', Queue('foo'))
-            move.assert_called()
-            cb = move.call_args[0][0]
-            self.assertEqual(
-                cb({'id': '123f'}, Mock()),
-                Queue('foo'),
-            )
-
-
-class test_migrate_task(AppCase):
+        assert stdout.getvalue()
+
+
+def test_move_by_taskmap():
+    with patch('celery.contrib.migrate.move') as move:
+        move_by_taskmap({'add': Queue('foo')})
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'task': 'add'}, Mock())
+
+
+def test_move_by_idmap():
+    with patch('celery.contrib.migrate.move') as move:
+        move_by_idmap({'123f': Queue('foo')})
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'id': '123f'}, Mock())
+
+
+def test_move_task_by_id():
+    with patch('celery.contrib.migrate.move') as move:
+        move_task_by_id('123f', Queue('foo'))
+        move.assert_called()
+        cb = move.call_args[0][0]
+        assert cb({'id': '123f'}, Mock()) == Queue('foo')
+
+
+class test_migrate_task:
 
 
     def test_removes_compression_header(self):
     def test_removes_compression_header(self):
         x = Message('foo', compression='zlib')
         x = Message('foo', compression='zlib')
@@ -251,18 +255,18 @@ class test_migrate_task(AppCase):
         migrate_task(producer, x.body, x)
         migrate_task(producer, x.body, x)
         producer.publish.assert_called()
         producer.publish.assert_called()
         args, kwargs = producer.publish.call_args
         args, kwargs = producer.publish.call_args
-        self.assertIsInstance(args[0], bytes_t)
-        self.assertNotIn('compression', kwargs['headers'])
-        self.assertEqual(kwargs['compression'], 'zlib')
-        self.assertEqual(kwargs['content_type'], 'application/json')
-        self.assertEqual(kwargs['content_encoding'], 'utf-8')
-        self.assertEqual(kwargs['exchange'], 'exchange')
-        self.assertEqual(kwargs['routing_key'], 'rkey')
+        assert isinstance(args[0], bytes_t)
+        assert 'compression' not in kwargs['headers']
+        assert kwargs['compression'] == 'zlib'
+        assert kwargs['content_type'] == 'application/json'
+        assert kwargs['content_encoding'] == 'utf-8'
+        assert kwargs['exchange'] == 'exchange'
+        assert kwargs['routing_key'] == 'rkey'
 
 
 
 
-class test_migrate_tasks(AppCase):
+class test_migrate_tasks:
 
 
-    def test_migrate(self, name='testcelery'):
+    def test_migrate(self, app, name='testcelery'):
         x = Connection('memory://foo')
         x = Connection('memory://foo')
         y = Connection('memory://foo')
         y = Connection('memory://foo')
         # use separate state
         # use separate state
@@ -275,25 +279,25 @@ class test_migrate_tasks(AppCase):
         Producer(x).publish('foo', exchange=name, routing_key=name)
         Producer(x).publish('foo', exchange=name, routing_key=name)
         Producer(x).publish('bar', exchange=name, routing_key=name)
         Producer(x).publish('bar', exchange=name, routing_key=name)
         Producer(x).publish('baz', exchange=name, routing_key=name)
         Producer(x).publish('baz', exchange=name, routing_key=name)
-        self.assertTrue(x.default_channel.queues)
-        self.assertFalse(y.default_channel.queues)
+        assert x.default_channel.queues
+        assert not y.default_channel.queues
 
 
-        migrate_tasks(x, y, accept=['text/plain'], app=self.app)
+        migrate_tasks(x, y, accept=['text/plain'], app=app)
 
 
         yq = q(y.default_channel)
         yq = q(y.default_channel)
-        self.assertEqual(yq.get().body, ensure_bytes('foo'))
-        self.assertEqual(yq.get().body, ensure_bytes('bar'))
-        self.assertEqual(yq.get().body, ensure_bytes('baz'))
+        assert yq.get().body == ensure_bytes('foo')
+        assert yq.get().body == ensure_bytes('bar')
+        assert yq.get().body == ensure_bytes('baz')
 
 
         Producer(x).publish('foo', exchange=name, routing_key=name)
         Producer(x).publish('foo', exchange=name, routing_key=name)
         callback = Mock()
         callback = Mock()
         migrate_tasks(x, y,
         migrate_tasks(x, y,
-                      callback=callback, accept=['text/plain'], app=self.app)
+                      callback=callback, accept=['text/plain'], app=app)
         callback.assert_called()
         callback.assert_called()
         migrate = Mock()
         migrate = Mock()
         Producer(x).publish('baz', exchange=name, routing_key=name)
         Producer(x).publish('baz', exchange=name, routing_key=name)
         migrate_tasks(x, y, callback=callback,
         migrate_tasks(x, y, callback=callback,
-                      migrate=migrate, accept=['text/plain'], app=self.app)
+                      migrate=migrate, accept=['text/plain'], app=app)
         migrate.assert_called()
         migrate.assert_called()
 
 
         with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
         with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
@@ -303,12 +307,12 @@ class test_migrate_tasks(AppCase):
                     raise ChannelError('some channel error')
                     raise ChannelError('some channel error')
                 return 0, 3, 0
                 return 0, 3, 0
             qd.side_effect = effect
             qd.side_effect = effect
-            migrate_tasks(x, y, app=self.app)
+            migrate_tasks(x, y, app=app)
 
 
         x = Connection('memory://')
         x = Connection('memory://')
         x.default_channel.queues = {}
         x.default_channel.queues = {}
         y.default_channel.queues = {}
         y.default_channel.queues = {}
         callback = Mock()
         callback = Mock()
         migrate_tasks(x, y,
         migrate_tasks(x, y,
-                      callback=callback, accept=['text/plain'], app=self.app)
+                      callback=callback, accept=['text/plain'], app=app)
         callback.assert_not_called()
         callback.assert_not_called()

+ 12 - 10
celery/tests/contrib/test_rdb.py → t/unit/contrib/test_rdb.py

@@ -2,6 +2,9 @@ from __future__ import absolute_import, unicode_literals
 
 
 import errno
 import errno
 import socket
 import socket
+import pytest
+
+from case import Mock, patch, skip
 
 
 from celery.contrib.rdb import (
 from celery.contrib.rdb import (
     Rdb,
     Rdb,
@@ -9,26 +12,25 @@ from celery.contrib.rdb import (
     set_trace,
     set_trace,
 )
 )
 from celery.five import WhateverIO
 from celery.five import WhateverIO
-from celery.tests.case import AppCase, Mock, patch, skip
 
 
 
 
 class SockErr(socket.error):
 class SockErr(socket.error):
     errno = None
     errno = None
 
 
 
 
-class test_Rdb(AppCase):
+class test_Rdb:
 
 
     @patch('celery.contrib.rdb.Rdb')
     @patch('celery.contrib.rdb.Rdb')
     def test_debugger(self, Rdb):
     def test_debugger(self, Rdb):
         x = debugger()
         x = debugger()
-        self.assertTrue(x)
-        self.assertIs(x, debugger())
+        assert x
+        assert x is debugger()
 
 
     @patch('celery.contrib.rdb.debugger')
     @patch('celery.contrib.rdb.debugger')
     @patch('celery.contrib.rdb._frame')
     @patch('celery.contrib.rdb._frame')
     def test_set_trace(self, _frame, debugger):
     def test_set_trace(self, _frame, debugger):
-        self.assertTrue(set_trace(Mock()))
-        self.assertTrue(set_trace())
+        assert set_trace(Mock())
+        assert set_trace()
         debugger.return_value.set_trace.assert_called()
         debugger.return_value.set_trace.assert_called()
 
 
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
     @patch('celery.contrib.rdb.Rdb.get_avail_port')
@@ -40,7 +42,7 @@ class test_Rdb(AppCase):
         out = WhateverIO()
         out = WhateverIO()
         with Rdb(out=out) as rdb:
         with Rdb(out=out) as rdb:
             get_avail_port.assert_called()
             get_avail_port.assert_called()
-            self.assertIn('helu', out.getvalue())
+            assert 'helu' in out.getvalue()
 
 
             # set_quit
             # set_quit
             with patch('sys.settrace') as settrace:
             with patch('sys.settrace') as settrace:
@@ -54,7 +56,7 @@ class test_Rdb(AppCase):
                     rdb.set_trace(Mock())
                     rdb.set_trace(Mock())
                     pset.side_effect = SockErr
                     pset.side_effect = SockErr
                     pset.side_effect.errno = errno.ENOENT
                     pset.side_effect.errno = errno.ENOENT
-                    with self.assertRaises(SockErr):
+                    with pytest.raises(SockErr):
                         rdb.set_trace()
                         rdb.set_trace()
 
 
             # _close_session
             # _close_session
@@ -90,11 +92,11 @@ class test_Rdb(AppCase):
 
 
         err = sock.return_value.bind.side_effect = SockErr()
         err = sock.return_value.bind.side_effect = SockErr()
         err.errno = errno.ENOENT
         err.errno = errno.ENOENT
-        with self.assertRaises(SockErr):
+        with pytest.raises(SockErr):
             with Rdb(out=out):
             with Rdb(out=out):
                 pass
                 pass
         err.errno = errno.EADDRINUSE
         err.errno = errno.EADDRINUSE
-        with self.assertRaises(Exception):
+        with pytest.raises(Exception):
             with Rdb(out=out):
             with Rdb(out=out):
                 pass
                 pass
         called = [0]
         called = [0]

+ 0 - 0
celery/tests/tasks/__init__.py → t/unit/events/__init__.py


Some files were not shown because too many files changed in this diff