Просмотр исходного кода

Implement Fake with statement for tests

Ask Solem 15 лет назад
Родитель
Сommit
8c38eba699
4 измененных файлов с 96 добавлено и 62 удалено
  1. 61 45
      celery/tests/test_log.py
  2. 9 6
      celery/tests/test_worker.py
  3. 11 9
      celery/tests/test_worker_job.py
  4. 15 2
      celery/tests/utils.py

+ 61 - 45
celery/tests/test_log.py

@@ -1,4 +1,4 @@
-from __future__ import with_statement, generators
+from __future__ import generators
 
 import os
 import sys
@@ -16,7 +16,7 @@ from carrot.utils import rpartition
 
 from celery.log import (setup_logger, emergency_error,
                         redirect_stdouts_to_logger, LoggingProxy)
-from celery.tests.utils import override_stdouts
+from celery.tests.utils import override_stdouts, execute_context
 
 
 @contextmanager
@@ -35,10 +35,12 @@ class TestLog(unittest.TestCase):
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
-        sio = wrap_logger(logger, loglevel=loglevel)
-        logger.log(loglevel, logmsg)
-        
-        return sio.getvalue().strip()
+        def with_wrap_logger(sio):
+            logger.log(loglevel, logmsg)
+            return sio.getvalue().strip()
+
+        context = wrap_logger(logger, loglevel=loglevel)
+        execute_context(context, with_wrap_logger)
 
     def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
@@ -73,11 +75,16 @@ class TestLog(unittest.TestCase):
         from multiprocessing import get_logger
         l = get_logger()
         l.handlers = []
-        outs = override_stdouts()
-        stdout, stderr = outs
-        l = setup_logger(logfile=stderr, loglevel=logging.INFO)
-        l.info("The quick brown fox...")
-        self.assertTrue("The quick brown fox..." in stderr.getvalue())
+
+        def with_override_stdouts(outs):
+            stdout, stderr = outs
+            l = setup_logger(logfile=stderr, loglevel=logging.INFO)
+            l.info("The quick brown fox...")
+            self.assertTrue("The quick brown fox..." in stderr.getvalue())
+
+        context = override_stdouts()
+        execute_context(context, with_override_stdouts)
+
 
     def test_setup_logger_no_handlers_file(self):
         from multiprocessing import get_logger
@@ -89,50 +96,59 @@ class TestLog(unittest.TestCase):
 
     def test_emergency_error_stderr(self):
         outs = override_stdouts()
-        stdout, stderr = outs
-        emergency_error(None, "The lazy dog crawls under the fast fox")
-        self.assertTrue("The lazy dog crawls under the fast fox" in \
-                            stderr.getvalue())
+
+        def with_override_stdouts(outs):
+            stdout, stderr = outs
+            emergency_error(None, "The lazy dog crawls under the fast fox")
+            self.assertTrue("The lazy dog crawls under the fast fox" in
+                                stderr.getvalue())
+
+        context = override_stdouts()
+        execute_context(context, with_override_stdouts)
 
     def test_emergency_error_file(self):
         tempfile = mktemp(suffix="unittest", prefix="celery")
         emergency_error(tempfile, "Vandelay Industries")
         tempfilefh = open(tempfile, "r")
-        self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
-        tempfilefh.close()
-        os.unlink(tempfile)
+        try:
+            self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
+        finally:
+            tempfilefh.close()
+            os.unlink(tempfile)
 
     def test_redirect_stdouts(self):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
-        did_exc = None
         try:
-            sio = wrap_logger(logger)
-            redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
-            logger.error("foo")
-            self.assertTrue("foo" in sio.getvalue())
-        except Exception, e:
-            did_exc = e
+            def with_wrap_logger(sio):
+                redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
+                logger.error("foo")
+                self.assertTrue("foo" in sio.getvalue())
+
+            context = wrap_logger(logger)
+            execute_context(context, with_wrap_logger)
+        finally:
+            sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
 
-        sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
-        
-        if did_exc:
-            raise did_exc
 
     def test_logging_proxy(self):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
-        sio = wrap_logger(logger)
-        p = LoggingProxy(logger)
-        p.close()
-        p.write("foo")
-        self.assertTrue("foo" not in sio.getvalue())
-        p.closed = False
-        p.write("foo")
-        self.assertTrue("foo" in sio.getvalue())
-        lines = ["baz", "xuzzy"]
-        p.writelines(lines)
-        for line in lines:
-            self.assertTrue(line in sio.getvalue())
-        p.flush()
-        p.close()
-        self.assertFalse(p.isatty())
-        self.assertTrue(p.fileno() is None)
+
+        def with_wrap_logger(sio):
+            p = LoggingProxy(logger)
+            p.close()
+            p.write("foo")
+            self.assertTrue("foo" not in sio.getvalue())
+            p.closed = False
+            p.write("foo")
+            self.assertTrue("foo" in sio.getvalue())
+            lines = ["baz", "xuzzy"]
+            p.writelines(lines)
+            for line in lines:
+                self.assertTrue(line in sio.getvalue())
+            p.flush()
+            p.close()
+            self.assertFalse(p.isatty())
+            self.assertTrue(p.fileno() is None)
+
+        context = wrap_logger(logger)
+        execute_context(context, with_wrap_logger)

+ 9 - 6
celery/tests/test_worker.py

@@ -1,5 +1,3 @@
-from __future__ import with_statement
-
 import unittest
 from Queue import Queue, Empty
 from datetime import datetime, timedelta
@@ -11,6 +9,7 @@ from billiard.serialization import pickle
 
 from celery import conf
 from celery.utils import gen_unique_id, noop
+from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
 from celery.worker import WorkController
 from celery.worker.listener import CarrotListener, RUN, CLOSE
@@ -222,10 +221,14 @@ class TestCarrotListener(unittest.TestCase):
         m = create_message(backend, unknown={"baz": "!!!"})
         l.event_dispatcher = MockEventDispatcher()
         l.control_dispatch = MockControlDispatch()
-        log = catch_warnings(record=True)
-        l.receive_message(m.decode(), m)
-        self.assertTrue(log)
-        self.assertTrue("unknown message" in log[0].message.args[0])
+
+        def with_catch_warnings(log):
+            l.receive_message(m.decode(), m)
+            self.assertTrue(log)
+            self.assertTrue("unknown message" in log[0].message.args[0])
+
+        context = catch_warnings(record=True)
+        execute_context(context, with_catch_warnings)
 
     def test_receieve_message(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,

+ 11 - 9
celery/tests/test_worker_job.py

@@ -1,6 +1,4 @@
 # -*- coding: utf-8 -*-
-from __future__ import with_statement
-
 import sys
 import logging
 import unittest
@@ -14,6 +12,7 @@ from celery import states
 from celery.log import setup_logger
 from celery.task.base import Task
 from celery.utils import gen_unique_id
+from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
 from celery.models import TaskMeta
 from celery.result import AsyncResult
@@ -226,13 +225,16 @@ class TestTaskWrapper(unittest.TestCase):
 
         WorkerTaskTrace.execute = _error_exec
         try:
-            log = catch_warnings(record=True)
-            res = execute_and_trace(mytask.name, gen_unique_id(),
-                                    [4], {})
-            self.assertTrue(isinstance(res, ExceptionInfo))
-            self.assertTrue(log)
-            self.assertTrue("Exception outside" in log[0].message.args[0])
-            self.assertTrue("KeyError" in log[0].message.args[0])
+            def with_catch_warnings(log):
+                res = execute_and_trace(mytask.name, gen_unique_id(),
+                                        [4], {})
+                self.assertTrue(isinstance(res, ExceptionInfo))
+                self.assertTrue(log)
+                self.assertTrue("Exception outside" in log[0].message.args[0])
+                self.assertTrue("KeyError" in log[0].message.args[0])
+
+            context = catch_warnings(record=True)
+            execute_context(context, with_catch_warnings)
         finally:
             WorkerTaskTrace.execute = old_exec
 

+ 15 - 2
celery/tests/utils.py

@@ -34,11 +34,24 @@ class GeneratorContextManager(object):
                 if sys.exc_info()[1] is not value:
                     raise
 
-def fallback_contextmanager(func):
+def fallback_contextmanager(fun):
     def helper(*args, **kwds):
-        return GeneratorContextManager(func(*args, **kwds))
+        return GeneratorContextManager(fun(*args, **kwds))
     return helper
 
+
+def execute_context(context, fun):
+    val = context.__enter__()
+    exc_info = (None, None, None)
+    retval = None
+    try:
+        retval = fun(val)
+    except:
+        exc_info = sys.exc_info()
+    context.__exit__(*exc_info)
+    return retval
+
+
 try:
     from contextlib import contextmanager
 except ImportError: