瀏覽代碼

get_logger/get_task_logger now makes sure the right parent is set.

This doesn't fix #1404 but seems appropriate to include anyway.
Ask Solem 11 年之前
父節點
當前提交
3c18e76d00
共有 3 個文件被更改,包括 67 次插入3 次删除
  1. 0 1
      celery/app/__init__.py
  2. 49 0
      celery/tests/app/test_log.py
  3. 18 2
      celery/utils/log.py

+ 0 - 1
celery/app/__init__.py

@@ -57,7 +57,6 @@ push_current_task = _task_stack.push
 pop_current_task = _task_stack.pop
 
 
-
 def bugreport(app=None):
     return (app or current_app()).bugreport()
 

+ 49 - 0
celery/tests/app/test_log.py

@@ -16,7 +16,9 @@ from celery.utils.log import (
     ColorFormatter,
     logger as base_logger,
     get_task_logger,
+    task_logger,
     in_sighandler,
+    logger_isa,
     _patch_logger_class,
 )
 from celery.tests.case import (
@@ -43,6 +45,53 @@ class test_TaskFormatter(AppCase):
         self.assertEqual(record.task_id, '???')
 
 
+class test_logger_isa(AppCase):
+
+    def test_isa(self):
+        x = get_task_logger('Z1george')
+        self.assertTrue(logger_isa(x, task_logger))
+        prev_x, x.parent = x.parent, None
+        try:
+            self.assertFalse(logger_isa(x, task_logger))
+        finally:
+            x.parent = prev_x
+
+        y = get_task_logger('Z1elaine')
+        y.parent = x
+        self.assertTrue(logger_isa(y, task_logger))
+        self.assertTrue(logger_isa(y, x))
+        self.assertTrue(logger_isa(y, y))
+
+        z = get_task_logger('Z1jerry')
+        z.parent = y
+        self.assertTrue(logger_isa(z, task_logger))
+        self.assertTrue(logger_isa(z, y))
+        self.assertTrue(logger_isa(z, x))
+        self.assertTrue(logger_isa(z, z))
+
+    def test_recursive(self):
+        x = get_task_logger('X1foo')
+        prev, x.parent = x.parent, x
+        try:
+            with self.assertRaises(RuntimeError):
+                logger_isa(x, task_logger)
+        finally:
+            x.parent = prev
+
+        y = get_task_logger('X2foo')
+        z = get_task_logger('X2foo')
+        prev_y, y.parent = y.parent, z
+        try:
+            prev_z, z.parent = z.parent, y
+            try:
+                with self.assertRaises(RuntimeError):
+                    logger_isa(y, task_logger)
+            finally:
+                z.parent = prev_z
+        finally:
+            y.parent = prev_y
+
+
 class test_ColorFormatter(AppCase):
 
     @patch('celery.utils.log.safe_str')

+ 18 - 2
celery/utils/log.py

@@ -59,10 +59,26 @@ def in_sighandler():
         set_in_sighandler(False)
 
 
+def logger_isa(l, p):
+    this, seen = l, set()
+    while this:
+        if this == p:
+            return True
+        else:
+            if this in seen:
+                raise RuntimeError(
+                    'Logger {0!r} parents recursive'.format(l),
+                )
+            seen.add(this)
+            this = this.parent
+    return False
+
+
 def get_logger(name):
     l = _get_logger(name)
     if logging.root not in (l, l.parent) and l is not base_logger:
-        l.parent = base_logger
+        if not logger_isa(l, base_logger):
+            l.parent = base_logger
     return l
 task_logger = get_logger('celery.task')
 worker_logger = get_logger('celery.worker')
@@ -70,7 +86,7 @@ worker_logger = get_logger('celery.worker')
 
 def get_task_logger(name):
     logger = get_logger(name)
-    if logger.parent is logging.root:
+    if not logger_isa(logger, task_logger):
         logger.parent = task_logger
     return logger