Parcourir la source

Can only patch logger class once

Ask Solem il y a 11 ans
Parent
commit
4d3c574759
2 fichiers modifiés avec 11 ajouts et 29 suppressions
  1. 2 6
      celery/tests/app/test_log.py
  2. 9 23
      celery/utils/log.py

+ 2 - 6
celery/tests/app/test_log.py

@@ -16,7 +16,7 @@ from celery.utils.log import (
     task_logger,
     in_sighandler,
     logger_isa,
-    _patch_logger_class,
+    ensure_process_aware_logger,
 )
 from celery.tests.case import (
     AppCase, Mock, SkipTest,
@@ -342,11 +342,7 @@ class test_task_logger(test_default_logger):
 class test_patch_logger_cls(AppCase):
 
     def test_patches(self):
-        _patch_logger_class()
-        self.assertTrue(logging.getLoggerClass()._signal_safe)
-        _patch_logger_class()
-        self.assertTrue(logging.getLoggerClass()._signal_safe)
-
+        ensure_process_aware_logger()
         with in_sighandler():
             logging.getLoggerClass().log(get_logger('test'))
 

+ 9 - 23
celery/utils/log.py

@@ -245,10 +245,10 @@ class LoggingProxy(object):
         return False
 
 
-def ensure_process_aware_logger():
+def ensure_process_aware_logger(force=False):
     """Make sure process name is recorded when loggers are used."""
     global _process_aware
-    if not _process_aware:
+    if force or not _process_aware:
         logging._acquireLock()
         try:
             _process_aware = True
@@ -257,12 +257,18 @@ def ensure_process_aware_logger():
                 return
 
             class ProcessAwareLogger(Logger):
+                _signal_safe = True
                 _process_aware = True
 
                 def makeRecord(self, *args, **kwds):
                     record = Logger.makeRecord(self, *args, **kwds)
                     record.processName = current_process()._name
                     return record
+
+                def log(self, *args, **kwargs):
+                    if _in_sighandler:
+                        return
+                    return Logger.log(self, *args, **kwargs)
             logging.setLoggerClass(ProcessAwareLogger)
         finally:
             logging._releaseLock()
@@ -275,24 +281,4 @@ def get_multiprocessing_logger():
 def reset_multiprocessing_logger():
     if mputil and hasattr(mputil, '_logger'):
         mputil._logger = None
-
-
-def _patch_logger_class():
-    """Make sure loggers don't log while in a signal handler."""
-
-    logging._acquireLock()
-    try:
-        OldLoggerClass = logging.getLoggerClass()
-        if not getattr(OldLoggerClass, '_signal_safe', False):
-
-            class SigSafeLogger(OldLoggerClass):
-                _signal_safe = True
-
-                def log(self, *args, **kwargs):
-                    if _in_sighandler:
-                        return
-                    return OldLoggerClass.log(self, *args, **kwargs)
-            logging.setLoggerClass(SigSafeLogger)
-    finally:
-        logging._releaseLock()
-_patch_logger_class()
+ensure_process_aware_logger()