Browse Source

Fixes a bug with using kombu pools and after forkers and cleans up after forkers

Ask Solem 9 years ago
parent
commit
6dac87f0f6
3 changed files with 54 additions and 73 deletions
  1. 27 35
      celery/app/base.py
  2. 8 7
      celery/backends/database/session.py
  3. 19 31
      celery/tests/app/test_app.py

+ 27 - 35
celery/app/base.py

@@ -20,7 +20,7 @@ from amqp import starpromise
 from kombu import pools
 from kombu.clocks import LamportClock
 from kombu.common import oid_from
-from kombu.utils import cached_property, uuid
+from kombu.utils import cached_property, register_after_fork, uuid
 
 from celery import platforms
 from celery import signals
@@ -40,6 +40,7 @@ from celery.utils import gen_task_name
 from celery.utils.dispatch import Signal
 from celery.utils.functional import first, maybe_list, head_from_fun
 from celery.utils.imports import instantiate, symbol_by_name
+from celery.utils.log import get_logger
 from celery.utils.objects import FallbackContext, mro_lookup
 
 from .annotations import prepare as prepare_annotations
@@ -53,13 +54,10 @@ from .utils import (
 # Load all builtin tasks
 from . import builtins  # noqa
 
-try:
-    from billiard.util import register_after_fork
-except ImportError:  # pragma: no cover
-    register_after_fork = None
-
 __all__ = ['Celery']
 
+logger = get_logger(__name__)
+
 _EXECV = os.environ.get('FORKED_BY_MULTIPROCESSING')
 BUILTIN_FIXUPS = {
     'celery.fixups.django:fixup',
@@ -71,8 +69,6 @@ and as such the configuration could not be loaded.
 Please set this variable and make it point to
 a configuration module."""
 
-_after_fork_registered = False
-
 
 def app_has_custom(app, attr):
     return mro_lookup(app.__class__, attr, stop=(Celery, object),
@@ -85,30 +81,11 @@ def _unpickle_appattr(reverse_name, args):
     return get_current_app()._rgetattr(reverse_name)(*args)
 
 
-def _global_after_fork(obj):
-    # Previously every app would call:
-    #    `register_after_fork(app, app._after_fork)`
-    # but this created a leak as `register_after_fork` stores concrete object
-    # references and once registered an object cannot be removed without
-    # touching and iterating over the private afterfork registry list.
-    #
-    # See Issue #1949
-    from celery import _state
-    from multiprocessing import util as mputil
-    for app in _state._apps:
-        try:
-            app._after_fork(obj)
-        except Exception as exc:
-            if mputil._logger:
-                mputil._logger.info(
-                    'after forker raised exception: %r', exc, exc_info=1)
-
-
-def _ensure_after_fork():
-    global _after_fork_registered
-    _after_fork_registered = True
-    if register_after_fork is not None:
-        register_after_fork(_global_after_fork, _global_after_fork)
+def _after_fork_cleanup_app(app):
+    try:
+        app._after_fork()
+    except Exception as exc:
+        logger.info('after forker raised exception: %r', exc, exc_info=1)
 
 
 class PendingConfiguration(UserDict, AttributeDictMixin):
@@ -180,6 +157,7 @@ class Celery(object):
     _pool = None
     _conf = None
     builtin_fixups = BUILTIN_FIXUPS
+    _after_fork_registered = False
 
     #: Signal sent when app is loading configuration.
     on_configure = None
@@ -190,6 +168,9 @@ class Celery(object):
     #: Signal sent after app has been finalized.
     on_after_finalize = None
 
+    #: Signal sent by every new process after fork.
+    on_after_fork = None
+
     def __init__(self, main=None, loader=None, backend=None,
                  amqp=None, events=None, log=None, control=None,
                  set_as_current=True, tasks=None, broker=None, include=None,
@@ -254,6 +235,7 @@ class Celery(object):
             self.on_configure = Signal()
         self.on_after_configure = Signal()
         self.on_after_finalize = Signal()
+        self.on_after_fork = Signal()
 
         self.on_init()
         _register_app(self)
@@ -271,6 +253,12 @@ class Celery(object):
         """Makes this the default app for all threads."""
         set_default_app(self)
 
+    def _ensure_after_fork(self):
+        if not self._after_fork_registered:
+            self._after_fork_registered = True
+            if register_after_fork is not None:
+                register_after_fork(self, _after_fork_cleanup_app)
+
     def __enter__(self):
         return self
 
@@ -828,9 +816,13 @@ class Celery(object):
         self.on_after_configure.send(sender=self, source=self._conf)
         return self._conf
 
-    def _after_fork(self, obj_):
+    def _after_fork(self):
         self._pool = None
-        pools.reset()
+        try:
+            self.__dict__['amqp']._producer_pool = None
+        except (AttributeError, KeyError):
+            pass
+        self.on_after_fork.send(sender=self)
 
     def signature(self, *args, **kwargs):
         """Return a new :class:`~celery.canvas.Signature` bound to this app.
@@ -1007,7 +999,7 @@ class Celery(object):
 
         """
         if self._pool is None:
-            _ensure_after_fork()
+            self._ensure_after_fork()
             limit = self.conf.broker_pool_limit
             pools.set_limit(limit)
             self._pool = pools.connections[self.connection()]

+ 8 - 7
celery/backends/database/session.py

@@ -8,21 +8,22 @@
 """
 from __future__ import absolute_import
 
-try:
-    from billiard.util import register_after_fork
-except ImportError:  # pragma: no cover
-    register_after_fork = None
-
 from sqlalchemy import create_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool
 
+from kombu.utils import register_after_fork
+
 ResultModelBase = declarative_base()
 
 __all__ = ['SessionManager']
 
 
+def _after_fork_cleanup_session(session):
+    session._after_fork()
+
+
 class SessionManager(object):
 
     def __init__(self):
@@ -31,9 +32,9 @@ class SessionManager(object):
         self.forked = False
         self.prepared = False
         if register_after_fork is not None:
-            register_after_fork(self, self._after_fork)
+            register_after_fork(self, _after_fork_cleanup_session)
 
-    def _after_fork(self,):
+    def _after_fork(self):
         self.forked = True
 
     def get_engine(self, dburi, **kwargs):

+ 19 - 31
celery/tests/app/test_app.py

@@ -776,46 +776,34 @@ class test_App(AppCase):
             my_failover_strategy,
         )
 
-    @patch('kombu.pools.reset')
-    def test_after_fork(self, reset):
+    def test_after_fork(self):
         self.app._pool = Mock()
-        self.app._after_fork(self.app)
+        self.app.on_after_fork = Mock(name='on_after_fork')
+        self.app._after_fork()
         self.assertIsNone(self.app._pool)
-        reset.assert_called_with()
-        self.app._after_fork(self.app)
+        self.app.on_after_fork.send.assert_called_with(sender=self.app)
+        self.app._after_fork()
 
     def test_global_after_fork(self):
-        app = Mock(name='app')
-        prev, _state._apps = _state._apps, [app]
-        try:
-            obj = Mock(name='obj')
-            _appbase._global_after_fork(obj)
-            app._after_fork.assert_called_with(obj)
-        finally:
-            _state._apps = prev
-
-    @patch('multiprocessing.util', create=True)
-    def test_global_after_fork__raises(self, util):
-        app = Mock(name='app')
-        prev, _state._apps = _state._apps, [app]
-        try:
-            obj = Mock(name='obj')
-            exc = app._after_fork.side_effect = KeyError()
-            _appbase._global_after_fork(obj)
-            util._logger.info.assert_called_with(
-                'after forker raised exception: %r', exc, exc_info=1)
-            util._logger = None
-            _appbase._global_after_fork(obj)
-        finally:
-            _state._apps = prev
+        self.app._after_fork = Mock(name='_after_fork')
+        _appbase._after_fork_cleanup_app(self.app)
+        self.app._after_fork.assert_called_with()
+
+    @patch('celery.app.base.logger')
+    def test_after_fork_cleanup_app__raises(self, logger):
+        self.app._after_fork = Mock(name='_after_fork')
+        exc = self.app._after_fork.side_effect = KeyError()
+        _appbase._after_fork_cleanup_app(self.app)
+        logger.info.assert_called_with(
+            'after forker raised exception: %r', exc, exc_info=1)
 
     def test_ensure_after_fork__no_multiprocessing(self):
         prev, _appbase.register_after_fork = (
             _appbase.register_after_fork, None)
         try:
-            _appbase._after_fork_registered = False
-            _appbase._ensure_after_fork()
-            self.assertTrue(_appbase._after_fork_registered)
+            self.app._after_fork_registered = False
+            self.app._ensure_after_fork()
+            self.assertTrue(self.app._after_fork_registered)
         finally:
             _appbase.register_after_fork = prev