Browse Source

getting rid of 'with' and replacing contextlib with a fallback

Jesper Noehr 15 years ago
parent
commit
8a63c291f2
4 changed files with 99 additions and 50 deletions
  1. 50 38
      celery/tests/test_log.py
  2. 4 4
      celery/tests/test_worker.py
  3. 7 7
      celery/tests/test_worker_job.py
  4. 38 1
      celery/tests/utils.py

+ 50 - 38
celery/tests/test_log.py

@@ -6,7 +6,11 @@ import logging
 import unittest
 from tempfile import mktemp
 from StringIO import StringIO
-from contextlib import contextmanager
+
+try:
+    from contextlib import contextmanager
+except ImportError:
+    from celery.tests.utils import fallback_contextmanager as contextmanager
 
 from carrot.utils import rpartition
 
@@ -31,9 +35,10 @@ class TestLog(unittest.TestCase):
 
     def _assertLog(self, logger, logmsg, loglevel=logging.ERROR):
 
-        with wrap_logger(logger, loglevel=loglevel) as sio:
-            logger.log(loglevel, logmsg)
-            return sio.getvalue().strip()
+        sio = wrap_logger(logger, loglevel=loglevel)
+        logger.log(loglevel, logmsg)
+        
+        return sio.getvalue().strip()
 
     def assertDidLogTrue(self, logger, logmsg, reason, loglevel=None):
         val = self._assertLog(logger, logmsg, loglevel=loglevel)
@@ -68,11 +73,11 @@ class TestLog(unittest.TestCase):
         from multiprocessing import get_logger
         l = get_logger()
         l.handlers = []
-        with override_stdouts() as outs:
-            stdout, stderr = outs
-            l = setup_logger(logfile=stderr, loglevel=logging.INFO)
-            l.info("The quick brown fox...")
-            self.assertTrue("The quick brown fox..." in stderr.getvalue())
+        outs = override_stdouts()
+        stdout, stderr = outs
+        l = setup_logger(logfile=stderr, loglevel=logging.INFO)
+        l.info("The quick brown fox...")
+        self.assertTrue("The quick brown fox..." in stderr.getvalue())
 
     def test_setup_logger_no_handlers_file(self):
         from multiprocessing import get_logger
@@ -83,44 +88,51 @@ class TestLog(unittest.TestCase):
         self.assertTrue(isinstance(l.handlers[0], logging.FileHandler))
 
     def test_emergency_error_stderr(self):
-        with override_stdouts() as outs:
-            stdout, stderr = outs
-            emergency_error(None, "The lazy dog crawls under the fast fox")
-            self.assertTrue("The lazy dog crawls under the fast fox" in \
-                                stderr.getvalue())
+        outs = override_stdouts()
+        stdout, stderr = outs
+        emergency_error(None, "The lazy dog crawls under the fast fox")
+        self.assertTrue("The lazy dog crawls under the fast fox" in \
+                            stderr.getvalue())
 
     def test_emergency_error_file(self):
         tempfile = mktemp(suffix="unittest", prefix="celery")
         emergency_error(tempfile, "Vandelay Industries")
-        with open(tempfile, "r") as tempfilefh:
-            self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
+        tempfilefh = open(tempfile, "r")
+        self.assertTrue("Vandelay Industries" in "".join(tempfilefh))
+        tempfilefh.close()
         os.unlink(tempfile)
 
     def test_redirect_stdouts(self):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
+        did_exc = None
         try:
-            with wrap_logger(logger) as sio:
-                redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
-                logger.error("foo")
-                self.assertTrue("foo" in sio.getvalue())
-        finally:
-            sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
+            sio = wrap_logger(logger)
+            redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
+            logger.error("foo")
+            self.assertTrue("foo" in sio.getvalue())
+        except Exception, e:
+            did_exc = e
+
+        sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
+        
+        if did_exc:
+            raise did_exc
 
     def test_logging_proxy(self):
         logger = setup_logger(loglevel=logging.ERROR, logfile=None)
-        with wrap_logger(logger) as sio:
-            p = LoggingProxy(logger)
-            p.close()
-            p.write("foo")
-            self.assertTrue("foo" not in sio.getvalue())
-            p.closed = False
-            p.write("foo")
-            self.assertTrue("foo" in sio.getvalue())
-            lines = ["baz", "xuzzy"]
-            p.writelines(lines)
-            for line in lines:
-                self.assertTrue(line in sio.getvalue())
-            p.flush()
-            p.close()
-            self.assertFalse(p.isatty())
-            self.assertTrue(p.fileno() is None)
+        sio = wrap_logger(logger)
+        p = LoggingProxy(logger)
+        p.close()
+        p.write("foo")
+        self.assertTrue("foo" not in sio.getvalue())
+        p.closed = False
+        p.write("foo")
+        self.assertTrue("foo" in sio.getvalue())
+        lines = ["baz", "xuzzy"]
+        p.writelines(lines)
+        for line in lines:
+            self.assertTrue(line in sio.getvalue())
+        p.flush()
+        p.close()
+        self.assertFalse(p.isatty())
+        self.assertTrue(p.fileno() is None)

+ 4 - 4
celery/tests/test_worker.py

@@ -222,10 +222,10 @@ class TestCarrotListener(unittest.TestCase):
         m = create_message(backend, unknown={"baz": "!!!"})
         l.event_dispatcher = MockEventDispatcher()
         l.control_dispatch = MockControlDispatch()
-        with catch_warnings(record=True) as log:
-                l.receive_message(m.decode(), m)
-                self.assertTrue(log)
-                self.assertTrue("unknown message" in log[0].message.args[0])
+        log catch_warnings(record=True)
+        l.receive_message(m.decode(), m)
+        self.assertTrue(log)
+        self.assertTrue("unknown message" in log[0].message.args[0])
 
     def test_receieve_message(self):
         l = CarrotListener(self.ready_queue, self.eta_schedule, self.logger,

+ 7 - 7
celery/tests/test_worker_job.py

@@ -226,13 +226,13 @@ class TestTaskWrapper(unittest.TestCase):
 
         WorkerTaskTrace.execute = _error_exec
         try:
-            with catch_warnings(record=True) as log:
-                res = execute_and_trace(mytask.name, gen_unique_id(),
-                                        [4], {})
-                self.assertTrue(isinstance(res, ExceptionInfo))
-                self.assertTrue(log)
-                self.assertTrue("Exception outside" in log[0].message.args[0])
-                self.assertTrue("KeyError" in log[0].message.args[0])
+            log catch_warnings(record=True)
+            res = execute_and_trace(mytask.name, gen_unique_id(),
+                                    [4], {})
+            self.assertTrue(isinstance(res, ExceptionInfo))
+            self.assertTrue(log)
+            self.assertTrue("Exception outside" in log[0].message.args[0])
+            self.assertTrue("KeyError" in log[0].message.args[0])
         finally:
             WorkerTaskTrace.execute = old_exec
 

+ 38 - 1
celery/tests/utils.py

@@ -5,7 +5,44 @@ import sys
 import __builtin__
 from StringIO import StringIO
 from functools import wraps
-from contextlib import contextmanager
+
+class GeneratorContextManager(object):
+    def __init__(self, gen):
+        self.gen = gen
+
+    def __enter__(self):
+        try:
+            return self.gen.next()
+        except StopIteration:
+            raise RuntimeError("generator didn't yield")
+
+    def __exit__(self, type, value, traceback):
+        if type is None:
+            try:
+                self.gen.next()
+            except StopIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            try:
+                self.gen.throw(type, value, traceback)
+                raise RuntimeError("generator didn't stop after throw()")
+            except StopIteration:
+                return True
+            except:
+                if sys.exc_info()[1] is not value:
+                    raise
+
+def fallback_contextmanager(func):
+    def helper(*args, **kwds):
+        return GeneratorContextManager(func(*args, **kwds))
+    return helper
+
+try:
+    from contextlib import contextmanager
+except ImportError:
+    contextmanager = fallback_contextmanager
 
 from celery.utils import noop