Procházet zdrojové kódy

Use assertWarns and assertWarnsRegex

Ask Solem před 13 roky
rodič
revize
2a82e8518b

+ 2 - 2
celery/apps/worker.py

@@ -131,8 +131,8 @@ class Worker(configurated):
         self.redirect_stdouts_to_logger()
 
         if getattr(os, "getuid", None) and os.getuid() == 0:
-            warnings.warn(
-                "Running celeryd with superuser privileges is discouraged!")
+            warnings.warn(RuntimeWarning(
+                "Running celeryd with superuser privileges is discouraged!"))
 
         if self.discard:
             self.purge_messages()

+ 4 - 3
celery/execute/trace.py

@@ -22,7 +22,8 @@ import os
 import socket
 import sys
 import traceback
-import warnings
+
+from warnings import warn
 
 from .. import current_app
 from .. import states, signals
@@ -219,6 +220,6 @@ def report_internal_error(task, exc):
     _type, _value, _tb = sys.exc_info()
     _value = task.backend.prepare_exception(exc)
     exc_info = ExceptionInfo((_type, _value, _tb))
-    warnings.warn("Exception outside body: %s: %s\n%s" % tuple(
-        map(str, (exc.__class__, exc, exc_info.traceback))))
+    warn(RuntimeWarning(
+        "Exception raised outside body: %r:\n%s" % (exc, exc_info.traceback)))
     return exc_info

+ 14 - 25
celery/tests/test_app/test_loaders.py

@@ -3,18 +3,19 @@ from __future__ import with_statement
 
 import os
 import sys
-import warnings
 
 from celery import task
 from celery import loaders
 from celery.app import app_or_default
-from celery.exceptions import CPendingDeprecationWarning, ImproperlyConfigured
+from celery.exceptions import (
+        CPendingDeprecationWarning,
+        ImproperlyConfigured)
 from celery.loaders import base
 from celery.loaders import default
 from celery.loaders.app import AppLoader
 
+from celery.tests.utils import AppCase, Case
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest, AppCase
 
 
 class ObjectConfig(object):
@@ -68,25 +69,17 @@ class TestLoaders(AppCase):
                           default.Loader)
 
     def test_current_loader(self):
-        warnings.resetwarnings()
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                r'deprecation'):
             self.assertIs(loaders.current_loader(), self.app.loader)
-            warning = log[0].message
-
-            self.assertIsInstance(warning, CPendingDeprecationWarning)
-            self.assertIn("deprecation", warning.args[0])
 
     def test_load_settings(self):
-        warnings.resetwarnings()
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(CPendingDeprecationWarning,
+                r'deprecation'):
             self.assertIs(loaders.load_settings(), self.app.conf)
-            warning = log[0].message
-
-            self.assertIsInstance(warning, CPendingDeprecationWarning)
-            self.assertIn("deprecation", warning.args[0])
 
 
-class TestLoaderBase(unittest.TestCase):
+class TestLoaderBase(Case):
     message_options = {"subject": "Subject",
                        "body": "Body",
                        "sender": "x@x.com",
@@ -131,15 +124,11 @@ class TestLoaderBase(unittest.TestCase):
         MockMail.Mailer.raise_on_send = True
         opts = dict(self.message_options, **self.server_options)
 
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(MockMail.SendmailWarning, r'KeyError'):
             self.loader.mail_admins(fail_silently=True, **opts)
-            warning = log[0].message
-
-            self.assertIsInstance(warning, MockMail.SendmailWarning)
-            self.assertIn("KeyError", warning.args[0])
 
-            with self.assertRaises(KeyError):
-                self.loader.mail_admins(fail_silently=False, **opts)
+        with self.assertRaises(KeyError):
+            self.loader.mail_admins(fail_silently=False, **opts)
 
     def test_mail_admins(self):
         MockMail.Mailer.raise_on_send = False
@@ -159,7 +148,7 @@ class TestLoaderBase(unittest.TestCase):
             self.loader.cmdline_config_parser(["broker.port=foobar"])
 
 
-class TestDefaultLoader(unittest.TestCase):
+class TestDefaultLoader(Case):
 
     def test_wanted_module_item(self):
         l = default.Loader()
@@ -223,7 +212,7 @@ class TestDefaultLoader(unittest.TestCase):
         self.assertTrue(context_executed[0])
 
 
-class test_AppLoader(unittest.TestCase):
+class test_AppLoader(Case):
 
     def setUp(self):
         self.app = app_or_default()

+ 2 - 7
celery/tests/test_bin/test_celeryd.py

@@ -4,7 +4,6 @@ from __future__ import with_statement
 import logging
 import os
 import sys
-import warnings
 
 from functools import wraps
 try:
@@ -25,7 +24,6 @@ from celery.bin.celeryd import WorkerCommand, windows_main, \
                                main as celeryd_main
 from celery.exceptions import ImproperlyConfigured
 
-from celery.tests.compat import catch_warnings
 from celery.tests.utils import (AppCase, WhateverIO, mask_modules,
                                 reset_modules, skip_unless_module)
 
@@ -251,19 +249,16 @@ class test_Worker(AppCase):
         app = current_app
         if app.IS_WINDOWS:
             raise SkipTest("Not applicable on Windows")
-        warnings.resetwarnings()
 
         def getuid():
             return 0
 
         prev, os.getuid = os.getuid, getuid
         try:
-            with catch_warnings(record=True) as log:
+            with self.assertWarnsRegex(RuntimeWarning,
+                    r'superuser privileges is discouraged'):
                 worker = self.Worker()
                 worker.run()
-                self.assertTrue(log)
-                self.assertIn("superuser privileges is discouraged",
-                              log[0].message.args[0])
         finally:
             os.getuid = prev
 

+ 2 - 6
celery/tests/test_compat/test_decorators.py

@@ -1,23 +1,19 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import warnings
-
 from celery.task import base
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
+from celery.tests.utils import Case
 
 
 def add(x, y):
     return x + y
 
 
-class test_decorators(unittest.TestCase):
+class test_decorators(Case):
 
     def setUp(self):
-        warnings.resetwarnings()
-
         with catch_warnings(record=True):
             from celery import decorators
             self.decorators = decorators

+ 7 - 21
celery/tests/test_task/test_task_builtins.py

@@ -1,49 +1,35 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import warnings
-
 from celery.task import ping, PingTask, backend_cleanup
 from celery.exceptions import CDeprecationWarning
-from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
+from celery.tests.utils import Case
 
 
 def some_func(i):
     return i * i
 
 
-class test_deprecated(unittest.TestCase):
+class test_deprecated(Case):
 
     def test_ping(self):
-        warnings.resetwarnings()
-
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(CDeprecationWarning,
+                r'ping task has been deprecated'):
             prev = PingTask.app.conf.CELERY_ALWAYS_EAGER
             PingTask.app.conf.CELERY_ALWAYS_EAGER = True
             try:
-                pong = ping()
-                warning = log[0].message
-                self.assertEqual(pong, "pong")
-                self.assertIsInstance(warning, CDeprecationWarning)
-                self.assertIn("ping task has been deprecated",
-                              warning.args[0])
+                self.assertEqual(ping(), "pong")
             finally:
                 PingTask.app.conf.CELERY_ALWAYS_EAGER = prev
 
     def test_TaskSet_import_from_task_base(self):
-        warnings.resetwarnings()
-
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(CDeprecationWarning, r'is deprecated'):
             from celery.task.base import TaskSet, subtask
             TaskSet()
             subtask(PingTask)
-            for w in (log[0].message, log[1].message):
-                self.assertIsInstance(w, CDeprecationWarning)
-                self.assertIn("is deprecated", w.args[0])
 
 
-class test_backend_cleanup(unittest.TestCase):
+class test_backend_cleanup(Case):
 
     def test_run(self):
         backend_cleanup.apply()

+ 10 - 17
celery/tests/test_task/test_task_sets.py

@@ -2,15 +2,14 @@ from __future__ import absolute_import
 from __future__ import with_statement
 
 import anyjson
-import warnings
 
 from celery import registry
 from celery.app import app_or_default
+from celery.exceptions import CDeprecationWarning
 from celery.task import Task
 from celery.task.sets import subtask, TaskSet
 
-from celery.tests.utils import unittest
-from celery.tests.compat import catch_warnings
+from celery.tests.utils import Case
 
 
 class MockTask(Task):
@@ -28,7 +27,7 @@ class MockTask(Task):
         return (args, kwargs, options)
 
 
-class test_subtask(unittest.TestCase):
+class test_subtask(Case):
 
     def test_behaves_like_type(self):
         s = subtask("tasks.add", (2, 2), {"cache": True},
@@ -101,28 +100,22 @@ class test_subtask(unittest.TestCase):
         self.assertDictEqual(dict(cls(*args)), dict(s))
 
 
-class test_TaskSet(unittest.TestCase):
+class test_TaskSet(Case):
 
     def test_interface__compat(self):
-        warnings.resetwarnings()
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(CDeprecationWarning,
+                r'Using this invocation of TaskSet is deprecated'):
             ts = TaskSet(MockTask, [[(2, 2)], [(4, 4)], [(8, 8)]])
             self.assertListEqual(ts.tasks,
                                  [MockTask.subtask((i, i))
                                     for i in (2, 4, 8)])
-            self.assertIn("Using this invocation of TaskSet is deprecated",
-                          log[0].message.args[0])
-            log[:] = []
+        with self.assertWarnsRegex(CDeprecationWarning,
+                r'TaskSet.task is deprecated'):
             self.assertEqual(ts.task, registry.tasks[MockTask.name])
-            self.assertTrue(log)
-            self.assertIn("TaskSet.task is deprecated",
-                          log[0].message.args[0])
 
-            log[:] = []
+        with self.assertWarnsRegex(CDeprecationWarning,
+                r'TaskSet.task_name is deprecated'):
             self.assertEqual(ts.task_name, MockTask.name)
-            self.assertTrue(log)
-            self.assertIn("TaskSet.task_name is deprecated",
-                          log[0].message.args[0])
 
     def test_task_arg_can_be_iterable__compat(self):
         ts = TaskSet([MockTask.subtask((i, i))

+ 8 - 17
celery/tests/test_utils/test_timer2.py

@@ -3,18 +3,16 @@ from __future__ import with_statement
 
 import sys
 import time
-import warnings
 
 from kombu.tests.utils import redirect_stdouts
 from mock import Mock, patch
 
 import celery.utils.timer2 as timer2
 
-from celery.tests.utils import unittest, skip_if_quick
-from celery.tests.compat import catch_warnings
+from celery.tests.utils import Case, skip_if_quick
 
 
-class test_Entry(unittest.TestCase):
+class test_Entry(Case):
 
     def test_call(self):
         scratch = [None]
@@ -33,7 +31,7 @@ class test_Entry(unittest.TestCase):
         self.assertTrue(tref.cancelled)
 
 
-class test_Schedule(unittest.TestCase):
+class test_Schedule(Case):
 
     def test_handle_error(self):
         from datetime import datetime
@@ -65,7 +63,7 @@ class test_Schedule(unittest.TestCase):
         self.assertIsInstance(exc, OverflowError)
 
 
-class test_Timer(unittest.TestCase):
+class test_Timer(Case):
 
     @skip_if_quick
     def test_enter_after(self):
@@ -117,13 +115,10 @@ class test_Timer(unittest.TestCase):
 
         fun = Mock()
         fun.side_effect = ValueError()
-        warnings.resetwarnings()
 
-        with catch_warnings(record=True) as log:
+        with self.assertWarns(timer2.TimedFunctionFailed):
             t.apply_entry(fun)
             fun.assert_called_with()
-            self.assertTrue(log)
-            self.assertTrue(stderr.getvalue())
 
     @redirect_stdouts
     def test_apply_entry_error_not_handled(self, stdout, stderr):
@@ -132,13 +127,9 @@ class test_Timer(unittest.TestCase):
 
         fun = Mock()
         fun.side_effect = ValueError()
-        warnings.resetwarnings()
-
-        with catch_warnings(record=True) as log:
-            t.apply_entry(fun)
-            fun.assert_called_with()
-            self.assertFalse(log)
-            self.assertFalse(stderr.getvalue())
+        t.apply_entry(fun)
+        fun.assert_called_with()
+        self.assertFalse(stderr.getvalue())
 
     @patch("os._exit")
     def test_thread_crash(self, _exit):

+ 5 - 11
celery/tests/test_worker/__init__.py

@@ -26,9 +26,7 @@ from celery.worker.consumer import QoS, RUN, PREFETCH_COUNT_MAX, CLOSE
 from celery.utils.serialization import pickle
 from celery.utils.timer2 import Timer
 
-from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
-from celery.tests.utils import AppCase
+from celery.tests.utils import AppCase, Case
 
 
 class PlaceHolder(object):
@@ -96,7 +94,7 @@ def create_message(channel, **data):
                    delivery_info={"consumer_tag": "mock"})
 
 
-class test_QoS(unittest.TestCase):
+class test_QoS(Case):
 
     class _QoS(QoS):
         def __init__(self, value):
@@ -202,7 +200,7 @@ class test_QoS(unittest.TestCase):
         qos.set(qos.prev)
 
 
-class test_Consumer(unittest.TestCase):
+class test_Consumer(Case):
 
     def setUp(self):
         self.ready_queue = FastQueue()
@@ -281,10 +279,8 @@ class test_Consumer(unittest.TestCase):
         l.event_dispatcher = Mock()
         l.pidbox_node = MockNode()
 
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
             l.receive_message(m.decode(), m)
-            self.assertTrue(log)
-            self.assertIn("unknown message", log[0].message.args[0])
 
     @patch("celery.utils.timer2.to_timestamp")
     def test_receive_message_eta_OverflowError(self, to_timestamp):
@@ -557,10 +553,8 @@ class test_Consumer(unittest.TestCase):
         l.logger = Mock()
         m.ack = Mock()
         m.ack.side_effect = socket.error("foo")
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(RuntimeWarning, r'unknown message'):
             self.assertFalse(l.receive_message(m.decode(), m))
-            self.assertTrue(log)
-            self.assertIn("unknown message", log[0].message.args[0])
         with self.assertRaises(Empty):
             self.ready_queue.get_nowait()
         self.assertTrue(self.eta_schedule.empty())

+ 6 - 12
celery/tests/test_worker/test_worker_job.py

@@ -32,9 +32,7 @@ from celery.utils.encoding import from_utf8, default_encode
 from celery.worker.job import Request, TaskRequest, execute_and_trace
 from celery.worker.state import revoked
 
-from celery.tests.compat import catch_warnings
-from celery.tests.utils import unittest
-from celery.tests.utils import WhateverIO, wrap_logger
+from celery.tests.utils import Case, WhateverIO, wrap_logger
 
 
 scratch = {"ACK": False}
@@ -77,7 +75,7 @@ def mytask_raising(i, **kwargs):
     raise KeyError(i)
 
 
-class test_default_encode(unittest.TestCase):
+class test_default_encode(Case):
 
     def setUp(self):
         if sys.version_info >= (3, 0):
@@ -101,7 +99,7 @@ class test_default_encode(unittest.TestCase):
             sys.getfilesystemencoding = gfe
 
 
-class test_RetryTaskError(unittest.TestCase):
+class test_RetryTaskError(Case):
 
     def test_retry_task_error(self):
         try:
@@ -111,7 +109,7 @@ class test_RetryTaskError(unittest.TestCase):
             self.assertEqual(ret.exc, exc)
 
 
-class test_trace_task(unittest.TestCase):
+class test_trace_task(Case):
 
     def test_process_cleanup_fails(self):
         backend = mytask.backend
@@ -204,7 +202,7 @@ class MockEventDispatcher(object):
         self.sent.append(event)
 
 
-class test_TaskRequest(unittest.TestCase):
+class test_TaskRequest(Case):
 
     def test_task_wrapper_repr(self):
         tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -473,7 +471,6 @@ class test_TaskRequest(unittest.TestCase):
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        warnings.resetwarnings()
 
         def _error_exec(self, *args, **kwargs):
             raise KeyError("baz")
@@ -483,13 +480,10 @@ class test_TaskRequest(unittest.TestCase):
             raise KeyError("baz")
         raising.request = None
 
-        with catch_warnings(record=True) as log:
+        with self.assertWarnsRegex(RuntimeWarning, r'Exception raised outside'):
             res = execute_and_trace(raising.name, uuid(),
                                     [], {})
             self.assertIsInstance(res, ExceptionInfo)
-            self.assertTrue(log)
-            self.assertIn("Exception outside", log[0].message.args[0])
-            self.assertIn("AttributeError", log[0].message.args[0])
 
     def create_exception(self, exc):
         try:

+ 79 - 2
celery/tests/utils.py

@@ -9,8 +9,10 @@ except AttributeError:
 import importlib
 import logging
 import os
+import re
 import sys
 import time
+import warnings
 try:
     import __builtin__ as builtins
 except ImportError:  # py3k
@@ -19,7 +21,6 @@ except ImportError:  # py3k
 from functools import wraps
 from contextlib import contextmanager
 
-
 import mock
 from nose import SkipTest
 
@@ -27,6 +28,8 @@ from ..app import app_or_default
 from ..utils import noop
 from ..utils.compat import WhateverIO, LoggerAdapter
 
+from .compat import catch_warnings
+
 
 class Mock(mock.Mock):
 
@@ -54,7 +57,81 @@ def skip_unless_module(module):
     return _inner
 
 
-class AppCase(unittest.TestCase):
+# -- adds assertWarns from recent unittest2, not in Python 2.7.
+
+class _AssertRaisesBaseContext(object):
+
+    def __init__(self, expected, test_case, callable_obj=None,
+                 expected_regex=None):
+        self.expected = expected
+        self.failureException = test_case.failureException
+        self.obj_name = None
+        if isinstance(expected_regex, basestring):
+            expected_regex = re.compile(expected_regex)
+        self.expected_regex = expected_regex
+
+
+class _AssertWarnsContext(_AssertRaisesBaseContext):
+    """A context manager used to implement TestCase.assertWarns* methods."""
+
+    def __enter__(self):
+        # The __warningregistry__'s need to be in a pristine state for tests
+        # to work properly.
+        warnings.resetwarnings()
+        for v in sys.modules.values():
+            if getattr(v, '__warningregistry__', None):
+                v.__warningregistry__ = {}
+        self.warnings_manager = catch_warnings(record=True)
+        self.warnings = self.warnings_manager.__enter__()
+        warnings.simplefilter("always", self.expected)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self.warnings_manager.__exit__(exc_type, exc_value, tb)
+        if exc_type is not None:
+            # let unexpected exceptions pass through
+            return
+        try:
+            exc_name = self.expected.__name__
+        except AttributeError:
+            exc_name = str(self.expected)
+        first_matching = None
+        for m in self.warnings:
+            w = m.message
+            if not isinstance(w, self.expected):
+                continue
+            if first_matching is None:
+                first_matching = w
+            if (self.expected_regex is not None and
+                not self.expected_regex.search(str(w))):
+                continue
+            # store warning for later retrieval
+            self.warning = w
+            self.filename = m.filename
+            self.lineno = m.lineno
+            return
+        # Now we simply try to choose a helpful failure message
+        if first_matching is not None:
+            raise self.failureException('%r does not match %r' %
+                     (self.expected_regex.pattern, str(first_matching)))
+        if self.obj_name:
+            raise self.failureException("%s not triggered by %s"
+                % (exc_name, self.obj_name))
+        else:
+            raise self.failureException("%s not triggered"
+                % exc_name )
+
+
+class Case(unittest.TestCase):
+
+    def assertWarns(self, expected_warning):
+        return _AssertWarnsContext(expected_warning, self, None)
+
+    def assertWarnsRegex(self, expected_warning, expected_regex):
+        return _AssertWarnsContext(expected_warning, self,
+                                   None, expected_regex)
+
+class AppCase(Case):
 
     def setUp(self):
         from ..app import current_app