|
@@ -8,18 +8,22 @@
|
|
|
"""
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
+import logging
|
|
|
+from contextlib import contextmanager
|
|
|
from functools import wraps
|
|
|
|
|
|
from celery import states
|
|
|
+from celery.backends.base import BaseBackend
|
|
|
from celery.exceptions import ImproperlyConfigured
|
|
|
from celery.five import range
|
|
|
from celery.utils.timeutils import maybe_timedelta
|
|
|
|
|
|
-from celery.backends.base import BaseBackend
|
|
|
-
|
|
|
-from .models import Task, TaskSet
|
|
|
+from .models import Task
|
|
|
+from .models import TaskSet
|
|
|
from .session import ResultSession
|
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
__all__ = ['DatabaseBackend']
|
|
|
|
|
|
|
|
@@ -33,7 +37,21 @@ def _sqlalchemy_installed():
|
|
|
return sqlalchemy
|
|
|
_sqlalchemy_installed()
|
|
|
|
|
|
-from sqlalchemy.exc import DatabaseError, OperationalError
|
|
|
+from sqlalchemy.exc import DatabaseError, OperationalError, ResourceClosedError
|
|
|
+from sqlalchemy.orm.exc import StaleDataError
|
|
|
+
|
|
|
+
|
|
|
+@contextmanager
|
|
|
+def session_cleanup(session):
|
|
|
+ try:
|
|
|
+ yield
|
|
|
+ except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
|
|
|
+ session.rollback()
|
|
|
+ session.connection().invalidate()
|
|
|
+ session.close()
|
|
|
+ raise
|
|
|
+ else:
|
|
|
+ session.close()
|
|
|
|
|
|
|
|
|
def retry(fun):
|
|
@@ -45,7 +63,12 @@ def retry(fun):
|
|
|
for retries in range(max_retries):
|
|
|
try:
|
|
|
return fun(*args, **kwargs)
|
|
|
- except (DatabaseError, OperationalError):
|
|
|
+ except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
|
|
|
+ logger.critical(
|
|
|
+ "Failed operation %s. Retrying %s more times.",
|
|
|
+ fun.__name__, max_retries - retries - 1,
|
|
|
+ exc_info=True,
|
|
|
+ )
|
|
|
if retries + 1 >= max_retries:
|
|
|
raise
|
|
|
|
|
@@ -95,8 +118,9 @@ class DatabaseBackend(BaseBackend):
|
|
|
traceback=None, max_retries=3, **kwargs):
|
|
|
"""Store return value and status of an executed task."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
- task = session.query(Task).filter(Task.task_id == task_id).first()
|
|
|
+ with session_cleanup(session):
|
|
|
+ task = list(session.query(Task).filter(Task.task_id == task_id))
|
|
|
+ task = task and task[0]
|
|
|
if not task:
|
|
|
task = Task(task_id)
|
|
|
session.add(task)
|
|
@@ -106,83 +130,70 @@ class DatabaseBackend(BaseBackend):
|
|
|
task.traceback = traceback
|
|
|
session.commit()
|
|
|
return result
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
@retry
|
|
|
def _get_task_meta_for(self, task_id):
|
|
|
"""Get task metadata for a task by id."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
- task = session.query(Task).filter(Task.task_id == task_id).first()
|
|
|
- if task is None:
|
|
|
+ with session_cleanup(session):
|
|
|
+ task = list(session.query(Task).filter(Task.task_id == task_id))
|
|
|
+ task = task and task[0]
|
|
|
+ if not task:
|
|
|
task = Task(task_id)
|
|
|
task.status = states.PENDING
|
|
|
task.result = None
|
|
|
return task.to_dict()
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
@retry
|
|
|
def _save_group(self, group_id, result):
|
|
|
"""Store the result of an executed group."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
+ with session_cleanup(session):
|
|
|
group = TaskSet(group_id, result)
|
|
|
session.add(group)
|
|
|
session.flush()
|
|
|
session.commit()
|
|
|
return result
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
@retry
|
|
|
def _restore_group(self, group_id):
|
|
|
"""Get metadata for group by id."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
+ with session_cleanup(session):
|
|
|
group = session.query(TaskSet).filter(
|
|
|
TaskSet.taskset_id == group_id).first()
|
|
|
if group:
|
|
|
return group.to_dict()
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
@retry
|
|
|
def _delete_group(self, group_id):
|
|
|
"""Delete metadata for group by id."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
+ with session_cleanup(session):
|
|
|
session.query(TaskSet).filter(
|
|
|
TaskSet.taskset_id == group_id).delete()
|
|
|
session.flush()
|
|
|
session.commit()
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
@retry
|
|
|
def _forget(self, task_id):
|
|
|
"""Forget about result."""
|
|
|
session = self.ResultSession()
|
|
|
- try:
|
|
|
+ with session_cleanup(session):
|
|
|
session.query(Task).filter(Task.task_id == task_id).delete()
|
|
|
session.commit()
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
def cleanup(self):
|
|
|
"""Delete expired metadata."""
|
|
|
session = self.ResultSession()
|
|
|
expires = self.expires
|
|
|
now = self.app.now()
|
|
|
- try:
|
|
|
+ with session_cleanup(session):
|
|
|
session.query(Task).filter(
|
|
|
Task.date_done < (now - expires)).delete()
|
|
|
session.query(TaskSet).filter(
|
|
|
TaskSet.date_done < (now - expires)).delete()
|
|
|
session.commit()
|
|
|
- finally:
|
|
|
- session.close()
|
|
|
|
|
|
def __reduce__(self, args=(), kwargs={}):
|
|
|
kwargs.update(
|