Explorar el Código

Make the database backend retry operations on ResourceClosedError and StaleDataError too. Make the operations close the connection if failure occurs (can't retry on broken connection). Fixes #1786.

Ionel Cristian Mărieș hace 11 años
padre
commit
f333abc091
Se han modificado 1 ficheros con 40 adiciones y 29 borrados
  1. 40 29
      celery/backends/database/__init__.py

+ 40 - 29
celery/backends/database/__init__.py

@@ -8,18 +8,22 @@
 """
 from __future__ import absolute_import
 
+import logging
+from contextlib import contextmanager
 from functools import wraps
 
 from celery import states
+from celery.backends.base import BaseBackend
 from celery.exceptions import ImproperlyConfigured
 from celery.five import range
 from celery.utils.timeutils import maybe_timedelta
 
-from celery.backends.base import BaseBackend
-
-from .models import Task, TaskSet
+from .models import Task
+from .models import TaskSet
 from .session import ResultSession
 
+logger = logging.getLogger(__name__)
+
 __all__ = ['DatabaseBackend']
 
 
@@ -33,7 +37,21 @@ def _sqlalchemy_installed():
     return sqlalchemy
 _sqlalchemy_installed()
 
-from sqlalchemy.exc import DatabaseError, OperationalError
+from sqlalchemy.exc import DatabaseError, OperationalError, ResourceClosedError
+from sqlalchemy.orm.exc import StaleDataError
+
+
+@contextmanager
+def session_cleanup(session):
+    try:
+        yield
+    except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
+        session.rollback()
+        session.connection().invalidate()
+        session.close()
+        raise
+    else:
+        session.close()
 
 
 def retry(fun):
@@ -45,7 +63,12 @@ def retry(fun):
         for retries in range(max_retries):
             try:
                 return fun(*args, **kwargs)
-            except (DatabaseError, OperationalError):
+            except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
+                logger.critical(
+                    "Failed operation %s. Retrying %s more times.",
+                    fun.__name__, max_retries - retries - 1,
+                    exc_info=True,
+                )
                 if retries + 1 >= max_retries:
                     raise
 
@@ -95,8 +118,9 @@ class DatabaseBackend(BaseBackend):
                       traceback=None, max_retries=3, **kwargs):
         """Store return value and status of an executed task."""
         session = self.ResultSession()
-        try:
-            task = session.query(Task).filter(Task.task_id == task_id).first()
+        with session_cleanup(session):
+            task = list(session.query(Task).filter(Task.task_id == task_id))
+            task = task and task[0]
             if not task:
                 task = Task(task_id)
                 session.add(task)
@@ -106,83 +130,70 @@ class DatabaseBackend(BaseBackend):
             task.traceback = traceback
             session.commit()
             return result
-        finally:
-            session.close()
 
     @retry
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
         session = self.ResultSession()
-        try:
-            task = session.query(Task).filter(Task.task_id == task_id).first()
-            if task is None:
+        with session_cleanup(session):
+            task = list(session.query(Task).filter(Task.task_id == task_id))
+            task = task and task[0]
+            if not task:
                 task = Task(task_id)
                 task.status = states.PENDING
                 task.result = None
             return task.to_dict()
-        finally:
-            session.close()
 
     @retry
     def _save_group(self, group_id, result):
         """Store the result of an executed group."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             group = TaskSet(group_id, result)
             session.add(group)
             session.flush()
             session.commit()
             return result
-        finally:
-            session.close()
 
     @retry
     def _restore_group(self, group_id):
         """Get metadata for group by id."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             group = session.query(TaskSet).filter(
                 TaskSet.taskset_id == group_id).first()
             if group:
                 return group.to_dict()
-        finally:
-            session.close()
 
     @retry
     def _delete_group(self, group_id):
         """Delete metadata for group by id."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             session.query(TaskSet).filter(
                 TaskSet.taskset_id == group_id).delete()
             session.flush()
             session.commit()
-        finally:
-            session.close()
 
     @retry
     def _forget(self, task_id):
         """Forget about result."""
         session = self.ResultSession()
-        try:
+        with session_cleanup(session):
             session.query(Task).filter(Task.task_id == task_id).delete()
             session.commit()
-        finally:
-            session.close()
 
     def cleanup(self):
         """Delete expired metadata."""
         session = self.ResultSession()
         expires = self.expires
         now = self.app.now()
-        try:
+        with session_cleanup(session):
             session.query(Task).filter(
                 Task.date_done < (now - expires)).delete()
             session.query(TaskSet).filter(
                 TaskSet.date_done < (now - expires)).delete()
             session.commit()
-        finally:
-            session.close()
 
     def __reduce__(self, args=(), kwargs={}):
         kwargs.update(