Merge pull request #1899 from celery/more-reliable-sqlalchemy

Make the database backend more reliable
Ionel Cristian Mărieș, 11 years ago
commit 257173166a
3 files changed, 95 insertions(+), 80 deletions(-)
  1. Changelog (+6 -0)
  2. celery/backends/database/__init__.py (+43 -31)
  3. celery/backends/database/session.py (+46 -49)

Changelog (+6 -0)

@@ -19,6 +19,12 @@ new in Celery 3.1.
 
     - Now depends on :ref:`Kombu 3.0.14 <kombu:version-3.0.14>`.
 
+- **Results**:
+
+    Reliability improvements to the SQLAlchemy database backend. Previously the
+    connection from the MainProcess was improperly shared with the workers.
+    (Issue #1786)
+
 - **Redis:** Important note about events (Issue #1882).
 
     There is a new transport option for Redis that enables monitors
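
For readers unfamiliar with the failure mode the changelog describes: a SQLAlchemy engine created before a fork() carries its connection pool into every child process, so the parent and the workers end up multiplexing the same database socket. A minimal sketch of the problem and the remedy, not part of this commit (POSIX-only; the file URI is hypothetical):

    import os

    from sqlalchemy import create_engine

    engine = create_engine('sqlite:///results.db')  # hypothetical DB URI
    engine.connect().close()  # the pool now holds a live connection

    pid = os.fork()
    if pid == 0:
        # Child: the pooled connection was inherited from the parent.
        # dispose() drops it so the child opens its own connections,
        # which is what this patch arranges for the worker processes.
        engine.dispose()
        os._exit(0)
    os.waitpid(pid, 0)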

celery/backends/database/__init__.py (+43 -31)

@@ -8,17 +8,21 @@
 """
 from __future__ import absolute_import
 
+import logging
+from contextlib import contextmanager
 from functools import wraps
 
 from celery import states
+from celery.backends.base import BaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.five import range
 from celery.utils.timeutils import maybe_timedelta
 
-from celery.backends.base import BaseBackend
+from .models import Task
+from .models import TaskSet
+from .session import SessionManager
 
-from .models import Task, TaskSet
-from .session import ResultSession
+logger = logging.getLogger(__name__)
 
 __all__ = ['DatabaseBackend']
 
@@ -33,7 +37,19 @@ def _sqlalchemy_installed():
     return sqlalchemy
 _sqlalchemy_installed()
 
-from sqlalchemy.exc import DatabaseError, OperationalError
+from sqlalchemy.exc import DatabaseError, OperationalError, ResourceClosedError, InvalidRequestError, IntegrityError
+from sqlalchemy.orm.exc import StaleDataError
+
+
+@contextmanager
+def session_cleanup(session):
+    try:
+        yield
+    except Exception:
+        session.rollback()
+        raise
+    finally:
+        session.close()
 
 
 def retry(fun):
@@ -45,7 +61,15 @@ def retry(fun):
         for retries in range(max_retries):
             try:
                 return fun(*args, **kwargs)
-            except (DatabaseError, OperationalError):
+            except (
+                DatabaseError, OperationalError, ResourceClosedError, StaleDataError, InvalidRequestError,
+                IntegrityError
+            ):
+                logger.warning(
+                    "Failed operation %s. Retrying %s more times.",
+                    fun.__name__, max_retries - retries - 1,
+                    exc_info=True,
+                )
                 if retries + 1 >= max_retries:
                     raise
 
@@ -83,8 +107,8 @@ class DatabaseBackend(BaseBackend):
                 'Missing connection string! Do you have '
                 'CELERY_RESULT_DBURI set to a real value?')
 
-    def ResultSession(self):
-        return ResultSession(
+    def ResultSession(self, session_manager=SessionManager()):
+        return session_manager.session_factory(
             dburi=self.dburi,
             short_lived_sessions=self.short_lived_sessions,
             **self.engine_options
@@ -95,8 +119,9 @@ class DatabaseBackend(BaseBackend):
                       traceback=None, max_retries=3, **kwargs):
         """Store return value and status of an executed task."""
         session = self.ResultSession()
-        try:
-            task = session.query(Task).filter(Task.task_id == task_id).first()
+        with session_cleanup(session):
+            task = list(session.query(Task).filter(Task.task_id == task_id))
+            task = task and task[0]
             if not task:
                 task = Task(task_id)
                 session.add(task)
@@ -106,83 +131,70 @@ class DatabaseBackend(BaseBackend):
             task.traceback = traceback
             session.commit()
             return result
-        finally:
-            session.close()
 
     @retry
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         session = self.ResultSession()
-        try:
-            task = session.query(Task).filter(Task.task_id == task_id).first()
-            if task is None:
+        with session_cleanup(session):
+            task = list(session.query(Task).filter(Task.task_id == task_id))
+            task = task and task[0]
+            if not task:
                 task = Task(task_id)
                 task.status = states.PENDING
                 task.result = None
             return task.to_dict()
-        finally:
-            session.close()
 
     @retry
     def _save_group(self, group_id, result):
         """Store the result of an executed group."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             group = TaskSet(group_id, result)
             session.add(group)
             session.flush()
             session.commit()
             return result
-        finally:
-            session.close()
 
     @retry
     def _restore_group(self, group_id):
         """Get metadata for group by id."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             group = session.query(TaskSet).filter(
                 TaskSet.taskset_id == group_id).first()
             if group:
                 return group.to_dict()
-        finally:
-            session.close()
 
     @retry
     def _delete_group(self, group_id):
         """Delete metadata for group by id."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             session.query(TaskSet).filter(
                 TaskSet.taskset_id == group_id).delete()
             session.flush()
             session.commit()
-        finally:
-            session.close()
 
     @retry
     def _forget(self, task_id):
         """Forget about result."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             session.query(Task).filter(Task.task_id == task_id).delete()
             session.commit()
-        finally:
-            session.close()
 
     def cleanup(self):
         """Delete expired metadata."""
         session = self.ResultSession()
         expires = self.expires
         now = self.app.now()
-        try:
+        with session_cleanup(session):
             session.query(Task).filter(
                 Task.date_done < (now - expires)).delete()
             session.query(TaskSet).filter(
                 TaskSet.date_done < (now - expires)).delete()
             session.commit()
-        finally:
-            session.close()
 
     def __reduce__(self, args=(), kwargs={}):
         kwargs.update(
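
The new session_cleanup context manager replaces the repeated try/finally blocks above: the session is rolled back on any exception and closed either way. A minimal usage sketch, assuming this commit is applied (the in-memory SQLite URI is for illustration only):

    from sqlalchemy import text

    from celery.backends.database import session_cleanup
    from celery.backends.database.session import SessionManager

    session = SessionManager().session_factory('sqlite:///')
    with session_cleanup(session):
        session.execute(text('SELECT 1'))  # rolled back on error, closed either way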

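Similarly, the widened @retry decorator now retries on a range of transient SQLAlchemy errors, logging a warning with the remaining attempt count and re-raising once the attempts are exhausted. A condensed, self-contained restatement of the pattern, not the module's exact decorator:

    from functools import wraps

    from sqlalchemy.exc import OperationalError

    def retry_sketch(fun, max_retries=3):
        """Retry ``fun`` on transient DB errors; re-raise the last one."""
        @wraps(fun)
        def _inner(*args, **kwargs):
            for retries in range(max_retries):
                try:
                    return fun(*args, **kwargs)
                except OperationalError:
                    # Give up only once all attempts are used.
                    if retries + 1 >= max_retries:
                        raise
        return _inner
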
celery/backends/database/session.py (+46 -49)

@@ -8,58 +8,55 @@
 """
 from __future__ import absolute_import
 
-from collections import defaultdict
-from multiprocessing.util import register_after_fork
+from billiard.util import register_after_fork
 
 from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import NullPool
 
 ResultModelBase = declarative_base()
 
-_SETUP = defaultdict(lambda: False)
-_ENGINES = {}
-_SESSIONS = {}
-
-__all__ = ['ResultSession', 'get_engine', 'create_session']
-
-
-class _after_fork(object):
-    registered = False
-
-    def __call__(self):
-        self.registered = False  # child must reregister
-        for engine in list(_ENGINES.values()):
-            engine.dispose()
-        _ENGINES.clear()
-        _SESSIONS.clear()
-after_fork = _after_fork()
-
-
-def get_engine(dburi, **kwargs):
-    try:
-        return _ENGINES[dburi]
-    except KeyError:
-        engine = _ENGINES[dburi] = create_engine(dburi, **kwargs)
-        after_fork.registered = True
-        register_after_fork(after_fork, after_fork)
-        return engine
-
-
-def create_session(dburi, short_lived_sessions=False, **kwargs):
-    engine = get_engine(dburi, **kwargs)
-    if short_lived_sessions or dburi not in _SESSIONS:
-        _SESSIONS[dburi] = sessionmaker(bind=engine)
-    return engine, _SESSIONS[dburi]
-
-
-def setup_results(engine):
-    if not _SETUP['results']:
-        ResultModelBase.metadata.create_all(engine)
-        _SETUP['results'] = True
-
-
-def ResultSession(dburi, **kwargs):
-    engine, session = create_session(dburi, **kwargs)
-    setup_results(engine)
-    return session()
+__all__ = ['SessionManager']
+
+
+class SessionManager(object):
+    def __init__(self):
+        self._engines = {}
+        self._sessions = {}
+        self.forked = False
+        self.prepared = False
+        register_after_fork(self, self._after_fork)
+
+    def _after_fork(self):
+        self.forked = True
+
+    def get_engine(self, dburi, **kwargs):
+        if self.forked:
+            try:
+                return self._engines[dburi]
+            except KeyError:
+                engine = self._engines[dburi] = create_engine(dburi, **kwargs)
+                return engine
+        else:
+            kwargs['poolclass'] = NullPool
+            return create_engine(dburi, **kwargs)
+
+    def create_session(self, dburi, short_lived_sessions=False, **kwargs):
+        engine = self.get_engine(dburi, **kwargs)
+        if self.forked:
+            if short_lived_sessions or dburi not in self._sessions:
+                self._sessions[dburi] = sessionmaker(bind=engine)
+            return engine, self._sessions[dburi]
+        else:
+            return engine, sessionmaker(bind=engine)
+
+    def prepare_models(self, engine):
+        if not self.prepared:
+            ResultModelBase.metadata.create_all(engine)
+            self.prepared = True
+
+    def session_factory(self, dburi, **kwargs):
+        engine, session = self.create_session(dburi, **kwargs)
+        self.prepare_models(engine)
+        return session()
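
Two design notes on the new code, sketched below. First, ResultSession takes session_manager=SessionManager() as a default argument; since Python evaluates defaults once at definition time, every DatabaseBackend in the process shares one manager (and its engine cache) by default. Second, the manager only caches engines after the fork flag flips; before that, every engine uses NullPool, so no pooled connection can leak into a child. A sketch that simulates the post-fork state by calling _after_fork() directly (billiard normally triggers it in the child; the SQLite URI is illustrative):

    from celery.backends.database.session import SessionManager

    manager = SessionManager()

    # Pre-fork: every call builds a throwaway NullPool engine.
    e1 = manager.get_engine('sqlite:///')
    e2 = manager.get_engine('sqlite:///')
    assert e1 is not e2

    manager._after_fork()  # what the after-fork hook does in the child

    # Post-fork: engines are cached per dburi, normal pooling applies.
    e3 = manager.get_engine('sqlite:///')
    assert e3 is manager.get_engine('sqlite:///')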