Browse Source

Fixes for celery.contrib.testing

Ask Solem 8 years ago
parent
commit
250e3539cc

+ 4 - 1
.travis.yml

@@ -1,5 +1,5 @@
 language: python
-sudo: false
+sudo: required
 cache: false
 python:
     - '3.5'
@@ -41,3 +41,6 @@ notifications:
       - "chat.freenode.net#celery"
     on_success: change
     on_failure: change
+services:
+    - rabbitmq
+    - redis

+ 1 - 0
celery/app/log.py

@@ -66,6 +66,7 @@ class Logging(object):
 
     def setup(self, loglevel=None, logfile=None, redirect_stdouts=False,
               redirect_level='WARNING', colorize=None, hostname=None):
+        loglevel = mlevel(loglevel)
         handled = self.setup_logging_subsystem(
             loglevel, logfile, colorize=colorize, hostname=hostname,
         )

+ 0 - 1
celery/bootsteps.py

@@ -30,7 +30,6 @@ TERMINATE = 0x3
 
 logger = get_logger(__name__)
 
-
 def _pre(ns, fmt):
     return '| {0}: {1}'.format(ns.alias, fmt)
 

+ 90 - 24
celery/contrib/pytest.py

@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from celery.backends.cache import CacheBackend, DummyClient
 
 from .testing import worker
-from .testing.app import TestApp, setup_default_app_trap
+from .testing.app import TestApp, setup_default_app
 
 NO_WORKER = os.environ.get('NO_WORKER')
 
@@ -18,62 +18,128 @@ NO_WORKER = os.environ.get('NO_WORKER')
 
 
 @contextmanager
-def _create_app(request, **config):
-    test_app = TestApp(set_as_current=False, config=config)
-    with setup_default_app_trap(test_app):
+def _create_app(request, enable_logging=False, use_trap=False, **config):
+    # type: (Any, bool, bool, **Any) -> Celery
+    """Utility context used to setup Celery app for pytest fixtures."""
+    test_app = TestApp(
+        set_as_current=False,
+        enable_logging=enable_logging,
+        config=config,
+    )
+    # request.module is not defined for session
+    _module = getattr(request, 'module', None)
+    _cls = getattr(request, 'cls', None)
+    _function = getattr(request, 'function', None)
+    with setup_default_app(test_app, use_trap=use_trap):
         is_not_contained = any([
-            not getattr(request.module, 'app_contained', True),
-            not getattr(request.cls, 'app_contained', True),
-            not getattr(request.function, 'app_contained', True)
+            not getattr(_module, 'app_contained', True),
+            not getattr(_cls, 'app_contained', True),
+            not getattr(_function, 'app_contained', True)
         ])
         if is_not_contained:
             test_app.set_current()
         yield test_app
 
+@pytest.fixture(scope='session')
+def use_celery_app_trap():
+    return False
+
 
 @pytest.fixture(scope='session')
-def celery_session_app(request):
-    with _create_app(request) as app:
+def celery_session_app(request,
+                       celery_config,
+                       celery_enable_logging,
+                       use_celery_app_trap):
+    # type: (Any, Any, bool, bool) -> Celery
+    """Session Fixture: Return app for session fixtures."""
+    mark = request.node.get_marker('celery')
+    config = dict(celery_config, **mark.kwargs if mark else {})
+    with _create_app(request,
+                     enable_logging=celery_enable_logging,
+                     use_trap=use_celery_app_trap,
+                     **config) as app:
+        if not use_celery_app_trap:
+            app.set_default()
+            app.set_current()
         yield app
 
 
-@pytest.fixture
+@pytest.fixture(scope='session')
+def celery_session_worker(request, celery_session_app,
+                          celery_includes, celery_worker_pool):
+    # type: (Any, Celery, Sequence[str], str) -> WorkController
+    """Session Fixture: Start worker that lives throughout test suite."""
+    if not NO_WORKER:
+        for module in celery_includes:
+            celery_session_app.loader.import_task_module(module)
+        with worker.start_worker(celery_session_app,
+                                 pool=celery_worker_pool) as w:
+            yield w
+
+
+@pytest.fixture(scope='session')
+def celery_enable_logging():
+    return False
+
+
+@pytest.fixture(scope='session')
+def celery_includes():
+    return ()
+
+
+@pytest.fixture(scope='session')
+def celery_worker_pool():
+    return 'solo'
+
+
+@pytest.fixture(scope='session')
 def celery_config():
+    # type: () -> Mapping[str, Any]
+    """Redefine this fixture to configure the test Celery app.
+
+    The config returned by your fixture will then be used
+    to configure the :func:`celery_app` fixture.
+    """
     return {}
 
 
 @pytest.fixture()
-def celery_app(request, celery_config):
+def celery_app(request,
+               celery_config,
+               celery_enable_logging,
+               use_celery_app_trap):
     """Fixture creating a Celery application instance."""
     mark = request.node.get_marker('celery')
     config = dict(celery_config, **mark.kwargs if mark else {})
-    with _create_app(request, **config) as app:
+    with _create_app(request,
+                     enable_logging=celery_enable_logging,
+                     use_trap=use_celery_app_trap,
+                     **config) as app:
         yield app
 
 
 @pytest.fixture()
-def celery_worker(request, celery_app):
-    if not NO_WORKER:
-        worker.start_worker(celery_app)
-
-
-@pytest.fixture(scope='session')
-def celery_session_worker(request, celery_session_app):
+def celery_worker(request, celery_app, celery_includes, celery_worker_pool):
+    # type: (Any, Celery, Sequence[str], str) -> WorkController
+    """Fixture: Start worker in a thread, stop it when the test returns."""
     if not NO_WORKER:
-        worker.start_worker(celery_session_app)
+        for module in celery_includes:
+            celery_app.loader.import_task_module(module)
+        with worker.start_worker(celery_app, pool=celery_worker_pool) as w:
+            yield w
 
 
 @pytest.fixture()
-def depends_on_current_app(app):
+def depends_on_current_app(celery_app):
     """Fixture that sets app as current."""
-    app.set_current()
+    celery_app.set_current()
 
 
 @pytest.fixture(autouse=True)
-def reset_cache_backend_state(app):
+def reset_cache_backend_state(celery_app):
     """Fixture that resets the internal state of the cache result backend."""
     yield
-    backend = app.__dict__.get('backend')
+    backend = celery_app.__dict__.get('backend')
     if backend is not None:
         if isinstance(backend, CacheBackend):
             if isinstance(backend.client, DummyClient):

+ 20 - 11
celery/contrib/testing/app.py

@@ -1,12 +1,10 @@
 from __future__ import absolute_import, unicode_literals
 
-import os
 import weakref
 
 from contextlib import contextmanager
 from copy import deepcopy
 
-from kombu import Queue
 from kombu.utils.imports import symbol_by_name
 
 from celery import Celery
@@ -15,7 +13,7 @@ from celery import _state
 DEFAULT_TEST_CONFIG = {
     'worker_hijack_root_logger': False,
     'worker_log_color': False,
-    'accept_content': 'json',
+    'accept_content': {'json'},
     'enable_utc': True,
     'timezone': 'UTC',
     'broker_url': 'memory://',
@@ -23,7 +21,6 @@ DEFAULT_TEST_CONFIG = {
 }
 
 
-
 class Trap(object):
     """Trap that pretends to be an app but raises an exception instead.
 
@@ -43,7 +40,7 @@ class UnitLogging(symbol_by_name(Celery.log_cls)):
         self.already_setup = True
 
 
-def TestApp(name=None, config=None, set_as_current=False,
+def TestApp(name=None, config=None, enable_logging=False, set_as_current=False,
             log=UnitLogging, backend=None, broker=None, **kwargs):
     """App used for testing."""
     from . import tasks  # noqa
@@ -52,6 +49,7 @@ def TestApp(name=None, config=None, set_as_current=False,
         config.pop('broker_url', None)
     if backend is not None:
         config.pop('result_backend', None)
+    log = None if enable_logging else log
     test_app = Celery(
         name or 'celery.tests',
         set_as_current=set_as_current,
@@ -64,11 +62,7 @@ def TestApp(name=None, config=None, set_as_current=False,
 
 
 @contextmanager
-def setup_default_app_trap(app):
-    prev_current_app = _state.get_current_app()
-    prev_default_app = _state.default_app
-    prev_finalizers = set(_state._on_app_finalizers)
-    prev_apps = weakref.WeakSet(_state._apps)
+def set_trap(app):
     trap = Trap()
     prev_tls = _state._tls
     _state.set_default_app(trap)
@@ -78,8 +72,23 @@ def setup_default_app_trap(app):
     _state._tls = NonTLS()
 
     yield
-    _state.set_default_app(prev_default_app)
     _state._tls = prev_tls
+
+
+@contextmanager
+def setup_default_app(app, use_trap=False):
+    prev_current_app = _state.get_current_app()
+    prev_default_app = _state.default_app
+    prev_finalizers = set(_state._on_app_finalizers)
+    prev_apps = weakref.WeakSet(_state._apps)
+
+    if use_trap:
+        with set_trap(app):
+            yield
+    else:
+        yield
+
+    _state.set_default_app(prev_default_app)
     _state._tls.current_app = prev_current_app
     if app is not prev_current_app:
         app.close()

+ 15 - 6
celery/contrib/testing/manager.py

@@ -30,15 +30,15 @@ def humanize_seconds(secs, prefix='', sep='', now='now', **kwargs):
 
 
 class ManagerMixin(object):
-    ResultMissingError = AssertionError
 
-    def _init_manager(self, app,
-                      block_timeout=30 * 60, stdout=None, stderr=None):
+    def _init_manager(self,
+                      block_timeout=30 * 60, no_join=False,
+                      stdout=None, stderr=None):
         self.stdout = sys.stdout if stdout is None else stdout
         self.stderr = sys.stderr if stderr is None else stderr
         self.connerrors = self.app.connection().recoverable_connection_errors
         self.block_timeout = block_timeout
-        self.progress = None
+        self.no_join = no_join
 
     def remark(self, s, sep='-'):
         print('{0}{1}'.format(sep, s), file=self.stdout)
@@ -108,13 +108,15 @@ class ManagerMixin(object):
                 )
             except self.connerrors as exc:
                 self.remark('join: connection lost: {0!r}'.format(exc), '!')
-        raise self.TaskPredicate('Test failed: Missing task results')
+        raise AssertionError('Test failed: Missing task results')
 
     def inspect(self, timeout=1):
         return self.app.control.inspect(timeout=timeout)
 
     def query_tasks(self, ids, timeout=0.5):
-        for reply in items(self.inspect(timeout).query_task(ids) or []):
+        print('BROKER: %r' % (self.app.connection().as_uri(),))
+        for reply in items(self.inspect(timeout).query_task(*ids) or {}):
+            print('REPLY: %r' % (reply,))
             yield reply
 
     def query_task_states(self, ids, timeout=0.5):
@@ -161,3 +163,10 @@ class ManagerMixin(object):
         if not res:
             raise Sentinel()
         return res
+
+
+class Manager(ManagerMixin):
+
+    def __init__(self, app, **kwargs):
+        self.app = app
+        self._init_manager(**kwargs)

+ 69 - 23
celery/contrib/testing/worker.py

@@ -3,15 +3,17 @@ from __future__ import absolute_import, unicode_literals
 import os
 import threading
 
+from contextlib import contextmanager
+
 from celery import worker
 from celery.result import allow_join_result, _set_task_join_will_block
 from celery.utils.dispatch import Signal
+from celery.utils.nodenames import anon_nodename
 
 test_worker_starting = Signal(providing_args=[])
 test_worker_started = Signal(providing_args=['worker', 'consumer'])
 test_worker_stopped = Signal(providing_args=['worker'])
 
-NO_WORKER = os.environ.get('NO_WORKER')
 WORKER_LOGLEVEL = os.environ.get('WORKER_LOGLEVEL', 'error')
 
 
@@ -30,48 +32,92 @@ class TestWorkController(worker.WorkController):
         self._on_started.wait()
 
 
-def start_worker(app,
-                 concurrency=1,
-                 pool='solo',
-                 loglevel=WORKER_LOGLEVEL,
-                 logfile=None,
-                 WorkController=TestWorkController,
-                 perform_ping_check=True,
-                 ping_task_timeout=3.0,
-                 **kwargs):
-    test_worker_starting.send(sender=app)
-
-    setup_app_for_worker(app)
+@contextmanager
+def start_worker_thread(app,
+                        concurrency=1,
+                        pool='solo',
+                        loglevel=WORKER_LOGLEVEL,
+                        logfile=None,
+                        WorkController=TestWorkController,
+                        **kwargs):
+    setup_app_for_worker(app, loglevel, logfile)
+    print('BROKER: %r' % (app.conf.broker_url,))
+    assert 'celery.ping' in app.tasks
     worker = WorkController(
         app=app,
         concurrency=concurrency,
+        hostname=anon_nodename(),
         pool=pool,
         loglevel=loglevel,
         logfile=logfile,
         # not allowed to override TestWorkController.on_consumer_ready
         ready_callback=None,
+        without_heartbeat=True,
+        without_mingle=True,
+        without_gossip=True,
         **kwargs)
 
     t = threading.Thread(target=worker.start)
     t.start()
-
     worker.ensure_started()
-
-    if perform_ping_check:
-        from .tasks import ping
-        with allow_join_result():
-            assert ping.delay().get(ping_task_timeout=3) == 'pong'
+    print('WORKER STARTED')
     _set_task_join_will_block(False)
 
     yield worker
 
-    worker.stop()
+    print('STOPPING WORKER')
+    from celery.worker import state
+    state.should_terminate = 0
+    print('JOINING WORKER THREAD')
+    t.join(10)
+    state.should_terminate = None
+
+
+@contextmanager
+def start_worker_process(app,
+                         concurrency=1,
+                         pool='solo',
+                         loglevel=WORKER_LOGLEVEL,
+                         logfile=None,
+                         **kwargs):
+    from celery.apps.multi import Cluster, Node
+
+    app.set_current()
+    cluster = Cluster([Node('testworker1@%h')])
+    cluster.start()
+    yield
+    cluster.stopwait()
+
+
+@contextmanager
+def start_worker(app,
+                 concurrency=1,
+                 pool='solo',
+                 loglevel=WORKER_LOGLEVEL,
+                 logfile=None,
+                 perform_ping_check=True,
+                 ping_task_timeout=10.0,
+                 **kwargs):
+    test_worker_starting.send(sender=app)
+
+    with start_worker_thread(app,
+                             concurrency=concurrency,
+                             pool=pool,
+                             loglevel=loglevel,
+                             logfile=logfile,
+                             **kwargs) as worker:
+        if perform_ping_check:
+            from .tasks import ping
+            with allow_join_result():
+                assert ping.delay().get(timeout=ping_task_timeout) == 'pong'
+
+        yield worker
     test_worker_stopped.send(sender=app, worker=worker)
-    t.join()
 
 
-def setup_app_for_worker(app):
+def setup_app_for_worker(app, loglevel, logfile):
     app.finalize()
     app.set_current()
     app.set_default()
-    app.log.setup()
+    type(app.log)._setup = False
+    app.log.setup(loglevel=loglevel, logfile=logfile)

+ 1 - 0
celery/worker/control.py

@@ -110,6 +110,7 @@ def _wanted_config_key(key):
 )
 def query_task(state, ids, **kwargs):
     """Query for task information by id."""
+    print('GET IDS: %r' % (ids,))
     return {
         req.id: (_state_of_task(req), req.info())
         for req in _find_requests_by_id(maybe_list(ids))

+ 26 - 30
t/integration/conftest.py

@@ -2,46 +2,42 @@ from __future__ import absolute_import, unicode_literals
 
 import pytest
 
-from cyanide.suite import ManagerMixin
+from celery.contrib.testing.manager import Manager
 
+@pytest.fixture(scope='session')
+def celery_config():
+    return {
+        'broker_url': 'pyamqp://',
+        'result_backend': 'redis://',
+    }
 
-def _celerymark(app, redis_results=None, **kwargs):
-    if redis_results and not app.conf.result_backend.startswith('redis'):
-        pytest.skip('Test needs Redis result backend.')
 
+@pytest.fixture(scope='session')
+def celery_enable_logging():
+    return True
 
-@pytest.fixture
-def app(request):
-    from .app import app
-    app.finalize()
-    app.set_current()
-    mark = request.node.get_marker('celery')
-    mark = mark and mark.kwargs or {}
-    _celerymark(app, **mark)
-    yield app
 
+@pytest.fixture(scope='session')
+def celery_worker_pool():
+    return 'prefork'
 
-@pytest.fixture
-def manager(app):
-    with CeleryManager(app) as manager:
-        yield manager
 
+@pytest.fixture(scope='session')
+def celery_includes():
+    return {'t.integration.tasks'}
 
-class CeleryManager(ManagerMixin):
 
-    # we don't stop full suite when a task result is missing.
-    TaskPredicate = AssertionError
+@pytest.fixture
+def app(celery_app):
+    yield celery_app
 
-    def __init__(self, app, no_join=False, **kwargs):
-        self.app = app
-        self.no_join = no_join
-        self._init_manager(app, **kwargs)
 
-    def __enter__(self):
-        return self
+@pytest.fixture
+def manager(app, celery_session_worker):
+    return Manager(app)
 
-    def __exit__(self, *exc_info):
-        self.close()
 
-    def close(self):
-        pass
+@pytest.fixture(autouse=True)
+def ZZZZ_set_app_current(app):
+    app.set_current()
+    app.set_default()

+ 44 - 0
t/integration/tasks.py

@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+from time import sleep
+from celery import shared_task
+from celery.utils.log import get_task_logger
+
+logger = get_task_logger(__name__)
+
+
+@shared_task
+def add(x, y):
+    """Add two numbers."""
+    return x + y
+
+
+@shared_task
+def print_unicode(log_message='hå它 valmuefrø', print_message='hiöäüß'):
+    """Task that both logs and print strings containing funny characters."""
+    logger.warning(log_message)
+    print(print_message)
+
+
+@shared_task
+def sleeping(i, **_):
+    """Task sleeping for ``i`` seconds, and returning nothing."""
+    sleep(i)
+
+
+@shared_task(bind=True)
+def ids(self, i):
+    """Returns a tuple of ``root_id``, ``parent_id`` and
+    the argument passed as ``i``."""
+    return self.request.root_id, self.request.parent_id, i
+
+
+@shared_task(bind=True)
+def collect_ids(self, res, i):
+    """Used as a callback in a chain or group where the previous tasks
+    are :task:`ids`: returns a tuple of::
+
+        (previous_result, (root_id, parent_id, i))
+
+    """
+    return res, ids(i)

+ 1 - 4
t/integration/test_canvas.py

@@ -1,10 +1,7 @@
 from __future__ import absolute_import, unicode_literals
-
 import pytest
-
 from celery import chain, group, uuid
-
-from cyanide.tasks import add, collect_ids, ids
+from .tasks import add, collect_ids, ids
 
 
 class test_chain:

+ 1 - 1
t/integration/test_tasks.py

@@ -1,6 +1,6 @@
 from __future__ import absolute_import, unicode_literals
 from celery import group
-from cyanide.tasks import print_unicode, sleeping
+from .tasks import print_unicode, sleeping
 
 
 class test_tasks:

+ 5 - 2
t/unit/conftest.py

@@ -18,7 +18,6 @@ from celery.contrib.testing.app import Trap, TestApp
 from celery.contrib.testing.mocks import (
     TaskMessage, TaskMessage1, task_message_from_sig,
 )
-from celery.contrib.pytest import celery_app  # noqa
 from celery.contrib.pytest import reset_cache_backend_state  # noqa
 from celery.contrib.pytest import depends_on_current_app  # noqa
 
@@ -38,7 +37,7 @@ CASE_LOG_LEVEL_EFFECT = 'Test {0} modified the level of the root logger'
 CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
 
 
-@pytest.fixture
+@pytest.fixture(scope='session')
 def celery_config():
     return {
         #: Don't want log output when running suite.
@@ -65,6 +64,10 @@ def celery_config():
     }
 
 
+@pytest.fixture(scope='session')
+def use_celery_app_trap():
+    return True
+
 
 @decorator
 def assert_signal_called(signal, **expected):

+ 5 - 7
tox.ini

@@ -1,10 +1,6 @@
 [tox]
 envlist =
-    2.7
-    pypy
-    3.4
-    3.5
-    pypy3
+    {2.7,pypy,3.4,3.5,pypy3}-{unit,integration}
     flake8
     flakeplus
     apicheck
@@ -25,8 +21,10 @@ deps=
     flake8,flakeplus,pydocstyle: -r{toxinidir}/requirements/pkgutils.txt
 sitepackages = False
 recreate = False
-commands = pip install -U -r{toxinidir}/requirements/dev.txt
-           py.test -xv
+commands =
+    pip install -U -r{toxinidir}/requirements/dev.txt
+    unit: py.test -xv
+    integration: py.test -xv t/integration
 
 basepython =
     2.7: python2.7