Browse Source

Test related fixes

Ask Solem 8 years ago
parent
commit
b45c15dcb1
2 changed files with 26 additions and 26 deletions
  1. 16 20
      t/unit/conftest.py
  2. 10 6
      t/unit/worker/test_worker.py

+ 16 - 20
t/unit/conftest.py

@@ -14,14 +14,11 @@ from case import Mock
 from case.utils import decorator
 from kombu import Queue
 
+from celery.backends.cache import CacheBackend, DummyClient
 from celery.contrib.testing.app import Trap, TestApp
 from celery.contrib.testing.mocks import (
     TaskMessage, TaskMessage1, task_message_from_sig,
 )
-from celery.contrib.pytest import reset_cache_backend_state  # noqa
-from celery.contrib.pytest import depends_on_current_app  # noqa
-
-__all__ = ['app', 'reset_cache_backend_state', 'depends_on_current_app']
 
 try:
     WindowsError = WindowsError  # noqa
@@ -40,7 +37,9 @@ CASE_LOG_HANDLER_EFFECT = 'Test {0} modified handlers for the root logger'
 @pytest.fixture(scope='session')
 def celery_config():
     return {
-        #: Don't want log output when running suite.
+        'broker_url': 'memory://',
+        'result_backend': 'cache+memory://',
+
         'task_default_queue': 'testcelery',
         'task_default_exchange': 'testcelery',
         'task_default_routing_key': 'testcelery',
@@ -69,6 +68,18 @@ def use_celery_app_trap():
     return True
 
 
+@pytest.fixture(autouse=True)
+def reset_cache_backend_state(celery_app):
+    """Fixture that resets the internal state of the cache result backend."""
+    yield
+    backend = celery_app.__dict__.get('backend')
+    if backend is not None:
+        if isinstance(backend, CacheBackend):
+            if isinstance(backend.client, DummyClient):
+                backend.client.cache.clear()
+            backend._cache.clear()
+
+
 @decorator
 def assert_signal_called(signal, **expected):
     """Context that verifes signal is called before exiting."""
@@ -172,21 +183,6 @@ def test_cases_shortcuts(request, app, patching, celery_config):
         request.instance.app = None
 
 
-@pytest.fixture(autouse=True)
-def zzzz_test_cases_calls_setup_teardown(request):
-    if request.instance:
-        # we set the .patching attribute for every test class.
-        setup = getattr(request.instance, 'setup', None)
-        # we also call .setup() and .teardown() after every test method.
-        setup and setup()
-
-    yield
-
-    if request.instance:
-        teardown = getattr(request.instance, 'teardown', None)
-        teardown and teardown()
-
-
 @pytest.fixture(autouse=True)
 def sanity_no_shutdown_flags_set():
     yield

+ 10 - 6
t/unit/worker/test_worker.py

@@ -61,7 +61,8 @@ def find_step(obj, typ):
 
 def create_message(channel, **data):
     data.setdefault('id', uuid())
-    m = Message(channel, body=pickle.dumps(dict(**data)),
+    m = Message(body=pickle.dumps(dict(**data)),
+                channel=channel,
                 content_type='application/x-python-serialize',
                 content_encoding='binary',
                 delivery_info={'consumer_tag': 'mock'})
@@ -306,7 +307,7 @@ class test_Consumer(ConsumerCase):
                 raise socket.timeout(10)
 
         c = self.NoopConsumer()
-        c.connection = Connection()
+        c.connection = Connection(self.app.conf.broker_url)
         c.connection.obj = c
         c.qos = QoS(c.task_consumer.qos, 10)
         c.loop(*c.loop_args())
@@ -322,7 +323,7 @@ class test_Consumer(ConsumerCase):
 
         c = self.LoopConsumer()
         c.blueprint.state = RUN
-        conn = c.connection = Connection()
+        conn = c.connection = Connection(self.app.conf.broker_url)
         c.connection.obj = c
         c.qos = QoS(c.task_consumer.qos, 10)
         with pytest.raises(socket.error):
@@ -346,8 +347,9 @@ class test_Consumer(ConsumerCase):
 
         c = self.LoopConsumer()
         c.blueprint.state = RUN
-        c.connection = Connection()
+        c.connection = Connection(self.app.conf.broker_url)
         c.connection.obj = c
+        c.connection.get_heartbeat_interval = Mock(return_value=None)
         c.qos = QoS(c.task_consumer.qos, 10)
 
         c.loop(*c.loop_args())
@@ -651,7 +653,8 @@ class test_Consumer(ConsumerCase):
         init_callback = Mock(name='init_callback')
         c = self.NoopConsumer(init_callback=init_callback)
         c.qos = _QoS()
-        c.connection = Connection()
+        c.connection = Connection(self.app.conf.broker_url)
+        c.connection.get_heartbeat_interval = Mock(return_value=None)
         c.iterations = 0
 
         def raises_KeyError(*args, **kwargs):
@@ -670,7 +673,8 @@ class test_Consumer(ConsumerCase):
         init_callback.reset_mock()
         c = self.NoopConsumer(task_events=False, init_callback=init_callback)
         c.qos = _QoS()
-        c.connection = Connection()
+        c.connection = Connection(self.app.conf.broker_url)
+        c.connection.get_heartbeat_interval = Mock(return_value=None)
         c.loop = Mock(side_effect=socket.error('foo'))
         with pytest.raises(socket.error):
             c.start()