浏览代码

Tests now passing again, override_stdouts now has to redirect __stdout__+__stderr__ as well now that emergency_error prints to __stderr__.

Ask Solem 15 年之前
父节点
当前提交
e8c8a82ac3
共有 5 个文件被更改,包括 20 次插入20 次删除
  1. 4 1
      celery/fields.py
  2. 1 1
      celery/tests/test_backends/test_database.py
  3. 3 3
      celery/tests/test_log.py
  4. 2 2
      celery/tests/test_monitoring.py
  5. 10 13
      celery/tests/utils.py

+ 4 - 1
celery/fields.py

@@ -7,7 +7,10 @@ Custom Django Model Fields.
 from copy import deepcopy
 from copy import deepcopy
 from base64 import b64encode, b64decode
 from base64 import b64encode, b64decode
 from zlib import compress, decompress
 from zlib import compress, decompress
-from celery.seralization import pickle
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
 
 
 from django.db import models
 from django.db import models
 from django.utils.encoding import force_unicode
 from django.utils.encoding import force_unicode

+ 1 - 1
celery/tests/test_backends/test_database.py

@@ -41,7 +41,7 @@ class TestDatabaseBackend(unittest.TestCase):
 
 
         self.assertFalse(b.is_done(tid))
         self.assertFalse(b.is_done(tid))
         self.assertEquals(b.get_status(tid), "PENDING")
         self.assertEquals(b.get_status(tid), "PENDING")
-        self.assertEquals(b.get_result(tid), '')
+        self.assertTrue(b.get_result(tid) is None)
 
 
         b.mark_as_done(tid, 42)
         b.mark_as_done(tid, 42)
         self.assertTrue(b.is_done(tid))
         self.assertTrue(b.is_done(tid))

+ 3 - 3
celery/tests/test_log.py

@@ -6,7 +6,7 @@ import unittest
 import multiprocessing
 import multiprocessing
 from StringIO import StringIO
 from StringIO import StringIO
 from celery.log import setup_logger, emergency_error
 from celery.log import setup_logger, emergency_error
-from celery.tests.utils import OverrideStdout
+from celery.tests.utils import override_stdouts
 from tempfile import mktemp
 from tempfile import mktemp
 
 
 
 
@@ -56,7 +56,7 @@ class TestLog(unittest.TestCase):
         from multiprocessing import get_logger
         from multiprocessing import get_logger
         l = get_logger()
         l = get_logger()
         l.handlers = []
         l.handlers = []
-        with OverrideStdout() as outs:
+        with override_stdouts() as outs:
             stdout, stderr = outs
             stdout, stderr = outs
             l = setup_logger(logfile=stderr, loglevel=logging.INFO)
             l = setup_logger(logfile=stderr, loglevel=logging.INFO)
             l.info("The quick brown fox...")
             l.info("The quick brown fox...")
@@ -71,7 +71,7 @@ class TestLog(unittest.TestCase):
         self.assertTrue(isinstance(l.handlers[0], logging.FileHandler))
         self.assertTrue(isinstance(l.handlers[0], logging.FileHandler))
 
 
     def test_emergency_error_stderr(self):
     def test_emergency_error_stderr(self):
-        with OverrideStdout() as outs:
+        with override_stdouts() as outs:
             stdout, stderr = outs
             stdout, stderr = outs
             emergency_error(None, "The lazy dog crawls under the fast fox")
             emergency_error(None, "The lazy dog crawls under the fast fox")
             self.assertTrue("The lazy dog crawls under the fast fox" in \
             self.assertTrue("The lazy dog crawls under the fast fox" in \

+ 2 - 2
celery/tests/test_monitoring.py

@@ -4,7 +4,7 @@ import time
 from celery.monitoring import TaskTimerStats, Statistics, StatsCollector
 from celery.monitoring import TaskTimerStats, Statistics, StatsCollector
 from carrot.connection import DjangoBrokerConnection
 from carrot.connection import DjangoBrokerConnection
 from celery.messaging import StatsConsumer
 from celery.messaging import StatsConsumer
-from celery.tests.utils import OverrideStdout
+from celery.tests.utils import override_stdouts
 
 
 
 
 class PartialStatistics(Statistics):
 class PartialStatistics(Statistics):
@@ -86,7 +86,7 @@ class TestStatsCollector(unittest.TestCase):
         self.assertEquals(self.s.total_tasks_processed, 3)
         self.assertEquals(self.s.total_tasks_processed, 3)
 
 
         # Report
         # Report
-        with OverrideStdout() as outs:
+        with override_stdouts() as outs:
             stdout, stderr = outs
             stdout, stderr = outs
             self.s.report()
             self.s.report()
             self.assertTrue(
             self.assertTrue(

+ 10 - 13
celery/tests/utils.py

@@ -38,18 +38,15 @@ def mask_modules(*modnames):
     __builtin__.__import__ = realimport
     __builtin__.__import__ = realimport
 
 
 
 
-class OverrideStdout(object):
+@contextmanager
+def override_stdouts():
     """Override ``sys.stdout`` and ``sys.stderr`` with ``StringIO``."""
     """Override ``sys.stdout`` and ``sys.stderr`` with ``StringIO``."""
+    prev_out, prev_err = sys.stdout, sys.stderr
+    mystdout, mystderr = StringIO(), StringIO()
+    sys.stdout = sys.__stdout__ = mystdout
+    sys.stderr = sys.__stderr__ = mystderr
+
+    yield mystdout, mystderr
 
 
-    def __enter__(self):
-        mystdout = StringIO()
-        mystderr = StringIO()
-        sys.stdout = mystdout
-        sys.stderr = mystderr
-        return mystdout, mystderr
-
-    def __exit__(self, e_type, e_value, e_trace):
-        if e_type:
-            raise e_type(e_value)
-        sys.stdout = sys.__stdout__
-        sys.stderr = sys.__stderr__
+    sys.stdout = sys.__stdout__ = prev_out
+    sys.stderr = sys.__stderr__ = prev_err