|  | @@ -16,6 +16,8 @@ from celery.utils.log import (
 | 
											
												
													
														|  |      ColorFormatter,
 |  |      ColorFormatter,
 | 
											
												
													
														|  |      logger as base_logger,
 |  |      logger as base_logger,
 | 
											
												
													
														|  |      get_task_logger,
 |  |      get_task_logger,
 | 
											
												
													
														|  | 
 |  | +    in_sighandler,
 | 
											
												
													
														|  | 
 |  | +    _patch_logger_class,
 | 
											
												
													
														|  |  )
 |  |  )
 | 
											
												
													
														|  |  from celery.tests.case import (
 |  |  from celery.tests.case import (
 | 
											
												
													
														|  |      AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 |  |      AppCase, Case, override_stdouts, wrap_logger, get_handlers,
 | 
											
										
											
												
													
														|  | @@ -61,6 +63,15 @@ class test_ColorFormatter(Case):
 | 
											
												
													
														|  |          if sys.version_info[0] == 2:
 |  |          if sys.version_info[0] == 2:
 | 
											
												
													
														|  |              self.assertTrue(safe_str.called)
 |  |              self.assertTrue(safe_str.called)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +    @patch('logging.Formatter.format')
 | 
											
												
													
														|  | 
 |  | +    def test_format_object(self, _format):
 | 
											
												
													
														|  | 
 |  | +        x = ColorFormatter(object())
 | 
											
												
													
														|  | 
 |  | +        x.use_color = True
 | 
											
												
													
														|  | 
 |  | +        record = Mock()
 | 
											
												
													
														|  | 
 |  | +        record.levelname = 'ERROR'
 | 
											
												
													
														|  | 
 |  | +        record.msg = object()
 | 
											
												
													
														|  | 
 |  | +        self.assertTrue(x.format(record))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      @patch('celery.utils.log.safe_str')
 |  |      @patch('celery.utils.log.safe_str')
 | 
											
												
													
														|  |      def test_format_raises(self, safe_str):
 |  |      def test_format_raises(self, safe_str):
 | 
											
												
													
														|  |          x = ColorFormatter('HELLO')
 |  |          x = ColorFormatter('HELLO')
 | 
											
										
											
												
													
														|  | @@ -231,6 +242,11 @@ class test_default_logger(AppCase):
 | 
											
												
													
														|  |              p.close()
 |  |              p.close()
 | 
											
												
													
														|  |              self.assertFalse(p.isatty())
 |  |              self.assertFalse(p.isatty())
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +            with override_stdouts() as (stdout, stderr):
 | 
											
												
													
														|  | 
 |  | +                with in_sighandler():
 | 
											
												
													
														|  | 
 |  | +                    p.write('foo')
 | 
											
												
													
														|  | 
 |  | +                    self.assertTrue(stderr.getvalue())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      def test_logging_proxy_recurse_protection(self):
 |  |      def test_logging_proxy_recurse_protection(self):
 | 
											
												
													
														|  |          logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
 |  |          logger = self.setup_logger(loglevel=logging.ERROR, logfile=None,
 | 
											
												
													
														|  |                                     root=False)
 |  |                                     root=False)
 | 
											
										
											
												
													
														|  | @@ -269,6 +285,18 @@ class test_task_logger(test_default_logger):
 | 
											
												
													
														|  |          return get_task_logger("test_task_logger")
 |  |          return get_task_logger("test_task_logger")
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +class test_patch_logger_cls(Case):
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def test_patches(self):
 | 
											
												
													
														|  | 
 |  | +        _patch_logger_class()
 | 
											
												
													
														|  | 
 |  | +        self.assertTrue(logging.getLoggerClass()._signal_safe)
 | 
											
												
													
														|  | 
 |  | +        _patch_logger_class()
 | 
											
												
													
														|  | 
 |  | +        self.assertTrue(logging.getLoggerClass()._signal_safe)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        with in_sighandler():
 | 
											
												
													
														|  | 
 |  | +            logging.getLoggerClass().log(get_logger('test'))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  class MockLogger(logging.Logger):
 |  |  class MockLogger(logging.Logger):
 | 
											
												
													
														|  |      _records = None
 |  |      _records = None
 | 
											
												
													
														|  |  
 |  |  
 |