
Testing tools (we're now a pytest plugin!)

Ask Solem 8 years ago
parent
commit 32e2df9489

+ 81 - 0
celery/contrib/pytest.py

@@ -0,0 +1,81 @@
+"""Fixtures and testing utilities for :pypi:`py.test <pytest>`."""
+from __future__ import absolute_import, unicode_literals
+
+import os
+import pytest
+
+from contextlib import contextmanager
+
+from celery.backends.cache import CacheBackend, DummyClient
+
+from .testing import worker
+from .testing.app import TestApp, setup_default_app_trap
+
+NO_WORKER = os.environ.get('NO_WORKER')
+
+# pylint: disable=redefined-outer-name
+# Well, they're called fixtures....
+
+
+@contextmanager
+def _create_app(request, **config):
+    test_app = TestApp(set_as_current=False, config=config)
+    with setup_default_app_trap(test_app):
+        is_not_contained = any([
+            not getattr(request.module, 'app_contained', True),
+            not getattr(request.cls, 'app_contained', True),
+            not getattr(request.function, 'app_contained', True)
+        ])
+        if is_not_contained:
+            test_app.set_current()
+        yield test_app
+
+
+@pytest.fixture(scope='session')
+def celery_session_app(request):
+    """Session fixture: Celery application shared by all tests in the session."""
+    with _create_app(request) as app:
+        yield app
+
+
+@pytest.fixture
+def celery_config():
+    """Redefine this fixture to configure the test Celery app."""
+    return {}
+
+
+@pytest.fixture()
+def celery_app(request, celery_config):
+    """Fixture creating a Celery application instance."""
+    mark = request.node.get_marker('celery')
+    config = dict(celery_config, **mark.kwargs if mark else {})
+    with _create_app(request, **config) as app:
+        yield app
+
+
+@pytest.fixture()
+def celery_worker(request, celery_app):
+    """Fixture: embedded worker running for the duration of the test."""
+    if not NO_WORKER:
+        with worker.start_worker(celery_app) as w:
+            yield w
+    else:
+        yield None
+
+
+@pytest.fixture(scope='session')
+def celery_session_worker(request, celery_session_app):
+    """Session fixture: embedded worker shared by all tests in the session."""
+    if not NO_WORKER:
+        with worker.start_worker(celery_session_app) as w:
+            yield w
+    else:
+        yield None
+
+
+@pytest.fixture()
+def depends_on_current_app(celery_app):
+    """Fixture that sets the test app as the current app."""
+    celery_app.set_current()
+
+
+@pytest.fixture(autouse=True)
+def reset_cache_backend_state(celery_app):
+    """Fixture that resets the internal state of the cache result backend."""
+    yield
+    backend = celery_app.__dict__.get('backend')
+    if backend is not None:
+        if isinstance(backend, CacheBackend):
+            if isinstance(backend.client, DummyClient):
+                backend.client.cache.clear()
+            backend._cache.clear()
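
A sketch of how a downstream project could consume these fixtures once the plugin is registered (see the pytest11 entry point added in setup.py below). The test module name, the mul task and the config override are illustrative, not part of this commit:

    # test_fixtures.py -- hypothetical downstream test module
    import pytest

    @pytest.fixture
    def celery_config():
        # Redefining this fixture overrides the plugin's empty default config.
        return {'result_serializer': 'json'}

    def test_eager_call(celery_app):
        # No worker needed: .apply() runs the task in-process.
        @celery_app.task
        def mul(x, y):
            return x * y
        assert mul.apply(args=(4, 4)).get() == 16

    def test_embedded_worker(celery_app, celery_worker):
        # celery.ping ships with celery.contrib.testing.tasks and is already
        # registered before the embedded worker starts consuming.
        from celery.contrib.testing.tasks import ping
        assert ping.delay().get(timeout=10) == 'pong'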

+ 0 - 0
celery/contrib/testing/__init__.py


+ 87 - 0
celery/contrib/testing/app.py

@@ -0,0 +1,87 @@
+"""Create Celery app instances used for testing."""
+from __future__ import absolute_import, unicode_literals
+
+import os
+import weakref
+
+from contextlib import contextmanager
+from copy import deepcopy
+
+from kombu.utils.imports import symbol_by_name
+
+from celery import Celery
+from celery import _state
+
+DEFAULT_TEST_CONFIG = {
+    'worker_hijack_root_logger': False,
+    'worker_log_color': False,
+    'accept_content': ['json'],
+    'enable_utc': True,
+    'timezone': 'UTC',
+    'broker_url': 'memory://',
+    'result_backend': 'cache+memory://'
+}
+
+
+class Trap(object):
+    """Trap that pretends to be an app but raises an exception instead.
+
+    This is to protect against code that doesn't properly pass app instances,
+    and instead falls back to the current_app.
+    """
+
+    def __getattr__(self, name):
+        raise RuntimeError('Test depends on current_app')
+
+
+class UnitLogging(symbol_by_name(Celery.log_cls)):
+    """Sets up logging for the test application."""
+
+    def __init__(self, *args, **kwargs):
+        super(UnitLogging, self).__init__(*args, **kwargs)
+        self.already_setup = True
+
+
+def TestApp(name=None, config=None, set_as_current=False,
+            log=UnitLogging, backend=None, broker=None, **kwargs):
+    """App used for testing."""
+    from . import tasks  # noqa
+    config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {})
+    if broker is not None:
+        config.pop('broker_url', None)
+    if backend is not None:
+        config.pop('result_backend', None)
+    test_app = Celery(
+        name or 'celery.tests',
+        set_as_current=set_as_current,
+        log=log,
+        broker=broker,
+        backend=backend,
+        **kwargs)
+    test_app.add_defaults(config)
+    return test_app
+
+
+@contextmanager
+def setup_default_app_trap(app):
+    """Contextmanager that installs the trap as the current and default app."""
+    prev_current_app = _state.get_current_app()
+    prev_default_app = _state.default_app
+    prev_finalizers = set(_state._on_app_finalizers)
+    prev_apps = weakref.WeakSet(_state._apps)
+    trap = Trap()
+    prev_tls = _state._tls
+    _state.set_default_app(trap)
+
+    class NonTLS(object):
+        current_app = trap
+    _state._tls = NonTLS()
+
+    try:
+        yield
+    finally:
+        _state.set_default_app(prev_default_app)
+        _state._tls = prev_tls
+        _state._tls.current_app = prev_current_app
+        if app is not prev_current_app:
+            app.close()
+        _state._on_app_finalizers = prev_finalizers
+        _state._apps = prev_apps
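
For a suite that cannot rely on the plugin being registered, the same building blocks can be wired up by hand. A minimal sketch, where the isolated_app fixture name and the task_always_eager override are made up for illustration:

    # conftest.py -- hand-rolled equivalent of the celery_app fixture
    import pytest

    from celery.contrib.testing.app import TestApp, setup_default_app_trap

    @pytest.fixture
    def isolated_app():
        test_app = TestApp(set_as_current=False,
                           config={'task_always_eager': True})
        with setup_default_app_trap(test_app):
            yield test_app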

+ 163 - 0
celery/contrib/testing/manager.py

@@ -0,0 +1,163 @@
+"""Integration testing utilities."""
+from __future__ import absolute_import, print_function, unicode_literals
+
+import socket
+import sys
+
+from collections import defaultdict
+from functools import partial
+from itertools import count
+
+from kombu.utils.functional import retry_over_time
+
+from celery.exceptions import TimeoutError
+from celery.five import items
+from celery.utils.text import truncate
+from celery.utils.time import humanize_seconds as _humanize_seconds
+
+E_STILL_WAITING = 'Still waiting for {0}.  Trying again {when}: {exc!r}'
+
+
+class Sentinel(Exception):
+    """Raised when a condition being waited for isn't satisfied yet."""
+
+
+def humanize_seconds(secs, prefix='', sep='', now='now', **kwargs):
+    s = _humanize_seconds(secs, prefix, sep, now, **kwargs)
+    if s == now and secs > 0:
+        return '{prefix}{sep}{0:.2f} seconds'.format(
+            float(secs), prefix=prefix, sep=sep)
+    return s
+
+
+class ManagerMixin(object):
+    """Mixin adding utilities for waiting on and asserting task state."""
+
+    ResultMissingError = AssertionError
+
+    def _init_manager(self, app, block_timeout=30 * 60,
+                      no_join=False, stdout=None, stderr=None):
+        self.app = app
+        self.stdout = sys.stdout if stdout is None else stdout
+        self.stderr = sys.stderr if stderr is None else stderr
+        self.connerrors = self.app.connection().recoverable_connection_errors
+        self.block_timeout = block_timeout
+        self.no_join = no_join
+        self.progress = None
+
+    def remark(self, s, sep='-'):
+        print('{0}{1}'.format(sep, s), file=self.stdout)
+
+    def warn(self, s):
+        print(s, file=self.stderr)
+
+    def missing_results(self, r):
+        return [res.id for res in r if res.id not in res.backend._cache]
+
+    def wait_for(self, fun, catch,
+                 desc='thing', args=(), kwargs={}, errback=None,
+                 max_retries=10, interval_start=0.1, interval_step=0.5,
+                 interval_max=5.0, emit_warning=False, **options):
+        def on_error(exc, intervals, retries):
+            interval = next(intervals)
+            if emit_warning:
+                self.warn(E_STILL_WAITING.format(
+                    desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
+                ))
+            if errback:
+                errback(exc, interval, retries)
+            return interval
+
+        return self.retry_over_time(
+            fun, catch,
+            args=args, kwargs=kwargs,
+            errback=on_error, max_retries=max_retries,
+            interval_start=interval_start, interval_step=interval_step,
+            **options
+        )
+
+    def ensure_not_for_a_while(self, fun, catch,
+                               desc='thing', max_retries=20,
+                               interval_start=0.1, interval_step=0.02,
+                               interval_max=1.0, emit_warning=False,
+                               **options):
+        try:
+            return self.wait_for(
+                fun, catch, desc=desc, max_retries=max_retries,
+                interval_start=interval_start, interval_step=interval_step,
+                interval_max=interval_max, emit_warning=emit_warning,
+            )
+        except catch:
+            pass
+        else:
+            raise AssertionError('Should not have happened: {0}'.format(desc))
+
+    def retry_over_time(self, *args, **kwargs):
+        return retry_over_time(*args, **kwargs)
+
+    def join(self, r, propagate=False, max_retries=10, **kwargs):
+        if self.no_join:
+            return
+        received = []
+
+        def on_result(task_id, value):
+            received.append(task_id)
+
+        for i in range(max_retries) if max_retries else count(0):
+            received[:] = []
+            try:
+                return r.get(callback=on_result, propagate=propagate, **kwargs)
+            except (socket.timeout, TimeoutError) as exc:
+                waiting_for = self.missing_results(r)
+                self.remark(
+                    'Still waiting for {0}/{1}: [{2}]: {3!r}'.format(
+                        len(r) - len(received), len(r),
+                        truncate(', '.join(waiting_for)), exc), '!',
+                )
+            except self.connerrors as exc:
+                self.remark('join: connection lost: {0!r}'.format(exc), '!')
+        raise self.ResultMissingError('Test failed: Missing task results')
+
+    def inspect(self, timeout=1):
+        return self.app.control.inspect(timeout=timeout)
+
+    def query_tasks(self, ids, timeout=0.5):
+        for reply in items(self.inspect(timeout).query_task(ids) or {}):
+            yield reply
+
+    def query_task_states(self, ids, timeout=0.5):
+        states = defaultdict(set)
+        for hostname, reply in self.query_tasks(ids, timeout=timeout):
+            for task_id, (state, _) in items(reply):
+                states[state].add(task_id)
+        return states
+
+    def assert_accepted(self, ids, interval=0.5,
+                        desc='waiting for tasks to be accepted', **policy):
+        return self.assert_task_worker_state(
+            self.is_accepted, ids, interval=interval, desc=desc, **policy
+        )
+
+    def assert_received(self, ids, interval=0.5,
+                        desc='waiting for tasks to be received', **policy):
+        return self.assert_task_worker_state(
+            self.is_received, ids, interval=interval, desc=desc, **policy
+        )
+
+    def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
+        return self.wait_for(
+            partial(self.true_or_raise, fun, ids, timeout=interval),
+            (Sentinel,), **policy
+        )
+
+    def is_received(self, ids, **kwargs):
+        return self._ids_matches_state(
+            ['reserved', 'active', 'ready'], ids, **kwargs)
+
+    def is_accepted(self, ids, **kwargs):
+        return self._ids_matches_state(['active', 'ready'], ids, **kwargs)
+
+    def _ids_matches_state(self, expected_states, ids, timeout=0.5):
+        states = self.query_task_states(ids, timeout=timeout)
+        return all(
+            any(t in s for s in [states[k] for k in expected_states])
+            for t in ids
+        )
+
+    def true_or_raise(self, fun, *args, **kwargs):
+        res = fun(*args, **kwargs)
+        if not res:
+            raise Sentinel()
+        return res
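
ManagerMixin is meant to be mixed into a concrete manager that hands the app under test to _init_manager; no such class is part of this hunk, so the wiring below is a hypothetical sketch:

    from celery.contrib.testing.manager import ManagerMixin

    class Manager(ManagerMixin):
        """Hypothetical concrete manager for integration tests."""

        def __init__(self, app, **kwargs):
            self._init_manager(app, **kwargs)

    # Typical use from an integration test:
    #   manager = Manager(celery_session_app)
    #   manager.assert_accepted([res.id for res in results])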

+ 86 - 0
celery/contrib/testing/mocks.py

@@ -0,0 +1,86 @@
+"""Useful mocks for unit testing."""
+from __future__ import absolute_import, unicode_literals
+
+import numbers
+
+from datetime import datetime, timedelta
+
+try:
+    from case import Mock
+except ImportError:
+    try:
+        from unittest.mock import Mock
+    except ImportError:
+        from mock import Mock
+
+
+def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
+                errbacks=None, chain=None, shadow=None, utc=None, **options):
+    """Create task message in protocol 2 format."""
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {
+        'id': id,
+        'task': name,
+        'shadow': shadow,
+    }
+    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
+    message.headers.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        (args, kwargs, embed), serializer='json',
+    )
+    message.payload = (args, kwargs, embed)
+    return message
+
+
+def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
+                 errbacks=None, chain=None, **options):
+    """Create task message in protocol 1 format."""
+    from celery import uuid
+    from kombu.serialization import dumps
+    id = id or uuid()
+    message = Mock(name='TaskMessage-{0}'.format(id))
+    message.headers = {}
+    message.payload = {
+        'task': name,
+        'id': id,
+        'args': args,
+        'kwargs': kwargs,
+        'callbacks': callbacks,
+        'errbacks': errbacks,
+    }
+    message.payload.update(options)
+    message.content_type, message.content_encoding, message.body = dumps(
+        message.payload,
+    )
+    return message
+
+
+def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
+    """Create task message from :class:`celery.Signature`."""
+    sig.freeze()
+    callbacks = sig.options.pop('link', None)
+    errbacks = sig.options.pop('link_error', None)
+    countdown = sig.options.pop('countdown', None)
+    if countdown:
+        eta = app.now() + timedelta(seconds=countdown)
+    else:
+        eta = sig.options.pop('eta', None)
+    if eta and isinstance(eta, datetime):
+        eta = eta.isoformat()
+    expires = sig.options.pop('expires', None)
+    if expires and isinstance(expires, numbers.Real):
+        expires = app.now() + timedelta(seconds=expires)
+    if expires and isinstance(expires, datetime):
+        expires = expires.isoformat()
+    return TaskMessage(
+        sig.task, id=sig.id, args=sig.args,
+        kwargs=sig.kwargs,
+        callbacks=[dict(s) for s in callbacks] if callbacks else None,
+        errbacks=[dict(s) for s in errbacks] if errbacks else None,
+        eta=eta,
+        expires=expires,
+        utc=utc,
+        **sig.options
+    )
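
These helpers build message mocks without touching a broker. A small sketch of producing a protocol 2 message straight from a signature, where the add task is illustrative:

    from celery.contrib.testing.app import TestApp
    from celery.contrib.testing.mocks import task_message_from_sig

    app = TestApp()

    @app.task
    def add(x, y):
        return x + y

    # Freeze the signature and serialize it the way a client would.
    message = task_message_from_sig(app, add.s(2, 2))
    assert message.headers['task'] == add.name
    args, kwargs, embed = message.payload
    assert tuple(args) == (2, 2)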

+ 8 - 0
celery/contrib/testing/tasks.py

@@ -0,0 +1,8 @@
+"""Helper tasks for testing (used by the worker ping check)."""
+from __future__ import absolute_import, unicode_literals
+
+from celery import shared_task
+
+
+@shared_task(name='celery.ping')
+def ping():
+    return 'pong'

+ 77 - 0
celery/contrib/testing/worker.py

@@ -0,0 +1,77 @@
+"""Embedded workers for integration tests and fixtures."""
+from __future__ import absolute_import, unicode_literals
+
+import os
+import threading
+
+from contextlib import contextmanager
+
+from celery import worker
+from celery.result import allow_join_result, _set_task_join_will_block
+from celery.utils.dispatch import Signal
+
+test_worker_starting = Signal(providing_args=[])
+test_worker_started = Signal(providing_args=['worker', 'consumer'])
+test_worker_stopped = Signal(providing_args=['worker'])
+
+NO_WORKER = os.environ.get('NO_WORKER')
+WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
+
+
+class TestWorkController(worker.WorkController):
+    """Worker controller that can synchronize on the consumer being started."""
+
+    def __init__(self, *args, **kwargs):
+        self._on_started = threading.Event()
+        super(TestWorkController, self).__init__(*args, **kwargs)
+
+    def on_consumer_ready(self, consumer):
+        self._on_started.set()
+        test_worker_started.send(
+            sender=self.app, worker=self, consumer=consumer)
+
+    def ensure_started(self):
+        self._on_started.wait()
+
+
+@contextmanager
+def start_worker(app,
+                 concurrency=1,
+                 pool='solo',
+                 loglevel=WORKER_LOGLEVEL,
+                 logfile=None,
+                 WorkController=TestWorkController,
+                 perform_ping_check=True,
+                 ping_task_timeout=3.0,
+                 **kwargs):
+    """Start an embedded worker (context manager yielding the worker)."""
+    test_worker_starting.send(sender=app)
+
+    setup_app_for_worker(app)
+    worker = WorkController(
+        app=app,
+        concurrency=concurrency,
+        pool=pool,
+        loglevel=loglevel,
+        logfile=logfile,
+        # not allowed to override TestWorkController.on_consumer_ready
+        ready_callback=None,
+        **kwargs)
+
+    t = threading.Thread(target=worker.start)
+    t.start()
+
+    worker.ensure_started()
+
+    if perform_ping_check:
+        from .tasks import ping
+        with allow_join_result():
+            assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
+    _set_task_join_will_block(False)
+
+    yield worker
+
+    worker.stop()
+    test_worker_stopped.send(sender=app, worker=worker)
+    t.join()
+
+
+def setup_app_for_worker(app):
+    app.finalize()
+    app.set_current()
+    app.set_default()
+    app.log.setup()
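
start_worker can also drive an embedded worker outside of the pytest fixtures, e.g. from a hand-written integration script. A rough sketch assuming the context-manager form above; the echo task is illustrative and the ping check is skipped for brevity:

    from celery.contrib.testing.app import TestApp
    from celery.contrib.testing.worker import start_worker

    app = TestApp()

    @app.task
    def echo(x):
        return x

    # The worker thread consumes from the in-memory broker while the
    # block is active and is stopped and joined on exit.
    with start_worker(app, perform_ping_check=False):
        assert echo.delay('hi').get(timeout=10) == 'hi'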

+ 0 - 224
celery/utils/pytest.py

@@ -1,224 +0,0 @@
-"""Fixtures and testing utilities for :pypi:`py.test <pytest>`."""
-from __future__ import absolute_import, unicode_literals
-
-import numbers
-import os
-import pytest
-import weakref
-
-from copy import deepcopy
-from datetime import datetime, timedelta
-from functools import partial
-
-from case import Mock
-from case.utils import decorator
-from kombu import Queue
-from kombu.utils.imports import symbol_by_name
-
-from celery import Celery
-from celery.app import current_app
-from celery.backends.cache import CacheBackend, DummyClient
-
-# pylint: disable=redefined-outer-name
-# Well, they're called fixtures....
-
-
-CELERY_TEST_CONFIG = {
-    #: Don't want log output when running suite.
-    'worker_hijack_root_logger': False,
-    'worker_log_color': False,
-    'task_default_queue': 'testcelery',
-    'task_default_exchange': 'testcelery',
-    'task_default_routing_key': 'testcelery',
-    'task_queues': (
-        Queue('testcelery', routing_key='testcelery'),
-    ),
-    'accept_content': ('json', 'pickle'),
-    'enable_utc': True,
-    'timezone': 'UTC',
-
-    # Mongo results tests (only executed if installed and running)
-    'mongodb_backend_settings': {
-        'host': os.environ.get('MONGO_HOST') or 'localhost',
-        'port': os.environ.get('MONGO_PORT') or 27017,
-        'database': os.environ.get('MONGO_DB') or 'celery_unittests',
-        'taskmeta_collection': (os.environ.get('MONGO_TASKMETA_COLLECTION') or
-                                'taskmeta_collection'),
-        'user': os.environ.get('MONGO_USER'),
-        'password': os.environ.get('MONGO_PASSWORD'),
-    }
-}
-
-
-class Trap(object):
-    """Trap that pretends to be an app but raises an exception instead.
-
-    This to protect from code that does not properly pass app instances,
-    then falls back to the current_app.
-    """
-
-    def __getattr__(self, name):
-        raise RuntimeError('Test depends on current_app')
-
-
-class UnitLogging(symbol_by_name(Celery.log_cls)):
-    """Sets up logging for the test application."""
-
-    def __init__(self, *args, **kwargs):
-        super(UnitLogging, self).__init__(*args, **kwargs)
-        self.already_setup = True
-
-
-def TestApp(name=None, set_as_current=False, log=UnitLogging,
-            broker='memory://', backend=None, **kwargs):
-    """App used for testing."""
-    test_app = Celery(
-        name or 'celery.tests',
-        set_as_current=set_as_current,
-        log=log, broker=broker,
-        backend=backend or 'cache+memory://', **kwargs)
-    test_app.add_defaults(deepcopy(CELERY_TEST_CONFIG))
-    return test_app
-
-
-@pytest.fixture(autouse=True)
-def app(request):
-    """Fixture creating a Celery application instance."""
-    from celery import _state
-    mark = request.node.get_marker('celery')
-    mark = mark and mark.kwargs or {}
-
-    prev_current_app = current_app()
-    prev_default_app = _state.default_app
-    prev_finalizers = set(_state._on_app_finalizers)
-    prev_apps = weakref.WeakSet(_state._apps)
-    trap = Trap()
-    prev_tls = _state._tls
-    _state.set_default_app(trap)
-
-    class NonTLS(object):
-        current_app = trap
-    _state._tls = NonTLS()
-
-    test_app = TestApp(set_as_current=False, backend=mark.get('backend'))
-    is_not_contained = any([
-        not getattr(request.module, 'app_contained', True),
-        not getattr(request.cls, 'app_contained', True),
-        not getattr(request.function, 'app_contained', True)
-    ])
-    if is_not_contained:
-        test_app.set_current()
-
-    yield test_app
-
-    _state.set_default_app(prev_default_app)
-    _state._tls = prev_tls
-    _state._tls.current_app = prev_current_app
-    if test_app is not prev_current_app:
-        test_app.close()
-    _state._on_app_finalizers = prev_finalizers
-    _state._apps = prev_apps
-
-
-@pytest.fixture()
-def depends_on_current_app(app):
-    """Fixture that sets app as current."""
-    app.set_current()
-
-
-@pytest.fixture(autouse=True)
-def reset_cache_backend_state(app):
-    """Fixture that resets the internal state of the cache result backend."""
-    yield
-    backend = app.__dict__.get('backend')
-    if backend is not None:
-        if isinstance(backend, CacheBackend):
-            if isinstance(backend.client, DummyClient):
-                backend.client.cache.clear()
-            backend._cache.clear()
-
-
-@decorator
-def assert_signal_called(signal, **expected):
-    """Context that verifes signal is called before exiting."""
-    handler = Mock()
-    call_handler = partial(handler)
-    signal.connect(call_handler)
-    try:
-        yield handler
-    finally:
-        signal.disconnect(call_handler)
-    handler.assert_called_with(signal=signal, **expected)
-
-
-def TaskMessage(name, id=None, args=(), kwargs={}, callbacks=None,
-                errbacks=None, chain=None, shadow=None, utc=None, **options):
-    """Create task message in protocol 2 format."""
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {
-        'id': id,
-        'task': name,
-        'shadow': shadow,
-    }
-    embed = {'callbacks': callbacks, 'errbacks': errbacks, 'chain': chain}
-    message.headers.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        (args, kwargs, embed), serializer='json',
-    )
-    message.payload = (args, kwargs, embed)
-    return message
-
-
-def TaskMessage1(name, id=None, args=(), kwargs={}, callbacks=None,
-                 errbacks=None, chain=None, **options):
-    """Create task message in protocol 1 format."""
-    from celery import uuid
-    from kombu.serialization import dumps
-    id = id or uuid()
-    message = Mock(name='TaskMessage-{0}'.format(id))
-    message.headers = {}
-    message.payload = {
-        'task': name,
-        'id': id,
-        'args': args,
-        'kwargs': kwargs,
-        'callbacks': callbacks,
-        'errbacks': errbacks,
-    }
-    message.payload.update(options)
-    message.content_type, message.content_encoding, message.body = dumps(
-        message.payload,
-    )
-    return message
-
-
-def task_message_from_sig(app, sig, utc=True, TaskMessage=TaskMessage):
-    """Create task message from :class:`celery.Signature`."""
-    sig.freeze()
-    callbacks = sig.options.pop('link', None)
-    errbacks = sig.options.pop('link_error', None)
-    countdown = sig.options.pop('countdown', None)
-    if countdown:
-        eta = app.now() + timedelta(seconds=countdown)
-    else:
-        eta = sig.options.pop('eta', None)
-    if eta and isinstance(eta, datetime):
-        eta = eta.isoformat()
-    expires = sig.options.pop('expires', None)
-    if expires and isinstance(expires, numbers.Real):
-        expires = app.now() + timedelta(seconds=expires)
-    if expires and isinstance(expires, datetime):
-        expires = expires.isoformat()
-    return TaskMessage(
-        sig.task, id=sig.id, args=sig.args,
-        kwargs=sig.kwargs,
-        callbacks=[dict(s) for s in callbacks] if callbacks else None,
-        errbacks=[dict(s) for s in errbacks] if errbacks else None,
-        eta=eta,
-        expires=expires,
-        utc=utc,
-        **sig.options
-    )

+ 2 - 2
docs/reference/celery.utils.pytest.rst → docs/reference/celery.contrib.pytest.rst

@@ -8,9 +8,9 @@
 API Reference
 =============
 
-.. currentmodule:: celery.utils.pytest
+.. currentmodule:: celery.contrib.pytest
 
-.. automodule:: celery.utils.pytest
+.. automodule:: celery.contrib.pytest
     :members:
     :undoc-members:
 

+ 1 - 1
docs/reference/index.rst

@@ -28,7 +28,6 @@
     celery.signals
     celery.security
     celery.utils.debug
-    celery.utils.pytest
     celery.exceptions
     celery.loaders
     celery.loaders.app
@@ -37,6 +36,7 @@
     celery.states
     celery.contrib.abortable
     celery.contrib.migrate
+    celery.contrib.pytest
     celery.contrib.sphinx
     celery.contrib.rdb
     celery.events

+ 4 - 1
setup.py

@@ -209,6 +209,9 @@ setuptools.setup(
     entry_points={
         'console_scripts': [
             'celery = celery.__main__:main',
-        ]
+        ],
+        'pytest11': [
+            'celery = celery.contrib.pytest',
+        ],
     },
 )
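
The pytest11 entry point registers the fixtures automatically whenever the package is installed; for a source checkout that hasn't been (re)installed, the plugin can be loaded explicitly instead, either by passing -p celery.contrib.pytest on the command line or with a one-line conftest sketch like:

    # conftest.py in the consuming project (only needed without the entry point)
    pytest_plugins = ['celery.contrib.pytest']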

+ 1 - 1
t/integration/test_canvas.py

@@ -63,7 +63,7 @@ class test_group:
 
 class xxx_chord:
 
-    @pytest.mark.celery(redis_results=1)
+    @pytest.mark.celery(result_backend='redis://')
     def test_parent_ids(self, manager):
         self.assert_parentids_chord()
         self.assert_parentids_chord(uuid(), uuid())

+ 1 - 1
t/unit/backends/test_base.py

@@ -588,7 +588,7 @@ class test_DisabledBackend:
     def test_as_uri(self):
         assert DisabledBackend(self.app).as_uri() == 'disabled://'
 
-    @pytest.mark.celery(backend='disabled')
+    @pytest.mark.celery(result_backend='disabled')
     def test_chord_raises_error(self):
         from celery import chord
         with pytest.raises(NotImplementedError):
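
With the celery_app fixture above, pytest.mark.celery keyword arguments are merged over celery_config, so per-test overrides now use real setting names. A small illustrative sketch; the test name and timezone value are made up:

    import pytest

    @pytest.mark.celery(timezone='Europe/Oslo')
    def test_marker_overrides_config(celery_app):
        # Marker kwargs win over the celery_config fixture defaults.
        assert celery_app.conf.timezone == 'Europe/Oslo'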

+ 57 - 8
t/unit/conftest.py

@@ -7,17 +7,20 @@ import sys
 import threading
 import warnings
 
+from functools import partial
 from importlib import import_module
 
 from case import Mock
+from case.utils import decorator
+from kombu import Queue
 
-from celery.utils.pytest import (
-    CELERY_TEST_CONFIG, Trap, TestApp,
-    assert_signal_called, TaskMessage, TaskMessage1, task_message_from_sig,
+from celery.contrib.testing.app import Trap, TestApp
+from celery.contrib.testing.mocks import (
+    TaskMessage, TaskMessage1, task_message_from_sig,
 )
-from celery.utils.pytest import app  # noqa
-from celery.utils.pytest import reset_cache_backend_state  # noqa
-from celery.utils.pytest import depends_on_current_app  # noqa
+from celery.contrib.pytest import celery_app  # noqa
+from celery.contrib.pytest import reset_cache_backend_state  # noqa
+from celery.contrib.pytest import depends_on_current_app  # noqa
 
 __all__ = ['app', 'reset_cache_backend_state', 'depends_on_current_app']
 
@@ -35,6 +38,52 @@ CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
 CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
 
 
+@pytest.fixture
+def celery_config():
+    return {
+        # Suite-specific settings, applied on top of the plugin defaults.
+        'task_default_queue': 'testcelery',
+        'task_default_exchange': 'testcelery',
+        'task_default_routing_key': 'testcelery',
+        'task_queues': (
+            Queue('testcelery', routing_key='testcelery'),
+        ),
+        'accept_content': ('json', 'pickle'),
+
+        # Mongo results tests (only executed if installed and running)
+        'mongodb_backend_settings': {
+            'host': os.environ.get('MONGO_HOST') or 'localhost',
+            'port': os.environ.get('MONGO_PORT') or 27017,
+            'database': os.environ.get('MONGO_DB') or 'celery_unittests',
+            'taskmeta_collection': (
+                os.environ.get('MONGO_TASKMETA_COLLECTION') or
+                'taskmeta_collection'
+            ),
+            'user': os.environ.get('MONGO_USER'),
+            'password': os.environ.get('MONGO_PASSWORD'),
+        }
+    }
+
+
+@decorator
+def assert_signal_called(signal, **expected):
+    """Context that verifes signal is called before exiting."""
+    handler = Mock()
+    call_handler = partial(handler)
+    signal.connect(call_handler)
+    try:
+        yield handler
+    finally:
+        signal.disconnect(call_handler)
+    handler.assert_called_with(signal=signal, **expected)
+
+
+@pytest.fixture
+def app(celery_app):
+    """Alias fixture so existing tests can keep requesting ``app``."""
+    yield celery_app
+
+
 @pytest.fixture(autouse=True, scope='session')
 def AAA_disable_multiprocessing():
     # pytest-cov breaks if a multiprocessing.Process is started,
@@ -99,7 +148,7 @@ def AAA_reset_CELERY_LOADER_env():
 
 
 @pytest.fixture(autouse=True)
-def test_cases_shortcuts(request, app, patching):
+def test_cases_shortcuts(request, app, patching, celery_config):
     if request.instance:
         @app.task
         def add(x, y):
@@ -112,7 +161,7 @@ def test_cases_shortcuts(request, app, patching):
         request.instance.task_message_from_sig = task_message_from_sig
         request.instance.TaskMessage = TaskMessage
         request.instance.TaskMessage1 = TaskMessage1
-        request.instance.CELERY_TEST_CONFIG = dict(CELERY_TEST_CONFIG)
+        request.instance.CELERY_TEST_CONFIG = celery_config
         request.instance.add = add
         request.instance.patching = patching
     yield