|
@@ -9,6 +9,7 @@
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
from collections import defaultdict
|
|
|
+from multiprocessing.util import register_after_fork
|
|
|
|
|
|
from sqlalchemy import create_engine
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
@@ -23,16 +24,32 @@ _SESSIONS = {}
|
|
|
__all__ = ['ResultSession', 'get_engine', 'create_session']
|
|
|
|
|
|
|
|
|
+class _after_fork(object):
|
|
|
+ registered = False
|
|
|
+
|
|
|
+ def __call__(self):
|
|
|
+ self.registered = False # child must reregister
|
|
|
+ for engine in list(_ENGINES.values()):
|
|
|
+ engine.dispose()
|
|
|
+ _ENGINES.clear()
|
|
|
+ _SESSIONS.clear()
|
|
|
+after_fork = _after_fork()
|
|
|
+
|
|
|
+
|
|
|
def get_engine(dburi, **kwargs):
|
|
|
- if dburi not in _ENGINES:
|
|
|
- _ENGINES[dburi] = create_engine(dburi, **kwargs)
|
|
|
- return _ENGINES[dburi]
|
|
|
+ try:
|
|
|
+ return _ENGINES[dburi]
|
|
|
+ except KeyError:
|
|
|
+ engine = _ENGINES[dburi] = create_engine(dburi, **kwargs)
|
|
|
+ after_fork.registered = True
|
|
|
+ register_after_fork(after_fork, after_fork)
|
|
|
+ return engine
|
|
|
|
|
|
|
|
|
def create_session(dburi, short_lived_sessions=False, **kwargs):
|
|
|
engine = get_engine(dburi, **kwargs)
|
|
|
if short_lived_sessions or dburi not in _SESSIONS:
|
|
|
- _SESSIONS[dburi] = sessionmaker(bind=engine)
|
|
|
+ session = _SESSIONS[dburi] = sessionmaker(bind=engine)
|
|
|
return engine, _SESSIONS[dburi]
|
|
|
|
|
|
|