|  | @@ -8,18 +8,22 @@
 | 
	
		
			
				|  |  |  """
 | 
	
		
			
				|  |  |  from __future__ import absolute_import
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import logging
 | 
	
		
			
				|  |  | +from contextlib import contextmanager
 | 
	
		
			
				|  |  |  from functools import wraps
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from celery import states
 | 
	
		
			
				|  |  | +from celery.backends.base import BaseBackend
 | 
	
		
			
				|  |  |  from celery.exceptions import ImproperlyConfigured
 | 
	
		
			
				|  |  |  from celery.five import range
 | 
	
		
			
				|  |  |  from celery.utils.timeutils import maybe_timedelta
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from celery.backends.base import BaseBackend
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -from .models import Task, TaskSet
 | 
	
		
			
				|  |  | +from .models import Task
 | 
	
		
			
				|  |  | +from .models import TaskSet
 | 
	
		
			
				|  |  |  from .session import ResultSession
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +logger = logging.getLogger(__name__)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  __all__ = ['DatabaseBackend']
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -33,7 +37,21 @@ def _sqlalchemy_installed():
 | 
	
		
			
				|  |  |      return sqlalchemy
 | 
	
		
			
				|  |  |  _sqlalchemy_installed()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from sqlalchemy.exc import DatabaseError, OperationalError
 | 
	
		
			
				|  |  | +from sqlalchemy.exc import DatabaseError, OperationalError, ResourceClosedError
 | 
	
		
			
				|  |  | +from sqlalchemy.orm.exc import StaleDataError
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@contextmanager
 | 
	
		
			
				|  |  | +def session_cleanup(session):
 | 
	
		
			
				|  |  | +    try:
 | 
	
		
			
				|  |  | +        yield
 | 
	
		
			
				|  |  | +    except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
 | 
	
		
			
				|  |  | +        session.rollback()
 | 
	
		
			
				|  |  | +        session.connection().invalidate()
 | 
	
		
			
				|  |  | +        session.close()
 | 
	
		
			
				|  |  | +        raise
 | 
	
		
			
				|  |  | +    else:
 | 
	
		
			
				|  |  | +        session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  def retry(fun):
 | 
	
	
		
			
				|  | @@ -45,7 +63,12 @@ def retry(fun):
 | 
	
		
			
				|  |  |          for retries in range(max_retries):
 | 
	
		
			
				|  |  |              try:
 | 
	
		
			
				|  |  |                  return fun(*args, **kwargs)
 | 
	
		
			
				|  |  | -            except (DatabaseError, OperationalError):
 | 
	
		
			
				|  |  | +            except (DatabaseError, OperationalError, ResourceClosedError, StaleDataError):
 | 
	
		
			
				|  |  | +                logger.critical(
 | 
	
		
			
				|  |  | +                    "Failed operation %s. Retrying %s more times.",
 | 
	
		
			
				|  |  | +                    fun.__name__, max_retries - retries - 1,
 | 
	
		
			
				|  |  | +                    exc_info=True,
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  |                  if retries + 1 >= max_retries:
 | 
	
		
			
				|  |  |                      raise
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -95,8 +118,9 @@ class DatabaseBackend(BaseBackend):
 | 
	
		
			
				|  |  |                        traceback=None, max_retries=3, **kwargs):
 | 
	
		
			
				|  |  |          """Store return value and status of an executed task."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            task = session.query(Task).filter(Task.task_id == task_id).first()
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  | +            task = list(session.query(Task).filter(Task.task_id == task_id))
 | 
	
		
			
				|  |  | +            task = task and task[0]
 | 
	
		
			
				|  |  |              if not task:
 | 
	
		
			
				|  |  |                  task = Task(task_id)
 | 
	
		
			
				|  |  |                  session.add(task)
 | 
	
	
		
			
				|  | @@ -106,83 +130,70 @@ class DatabaseBackend(BaseBackend):
 | 
	
		
			
				|  |  |              task.traceback = traceback
 | 
	
		
			
				|  |  |              session.commit()
 | 
	
		
			
				|  |  |              return result
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @retry
 | 
	
		
			
				|  |  |      def _get_task_meta_for(self, task_id):
 | 
	
		
			
				|  |  |          """Get task metadata for a task by id."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            task = session.query(Task).filter(Task.task_id == task_id).first()
 | 
	
		
			
				|  |  | -            if task is None:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  | +            task = list(session.query(Task).filter(Task.task_id == task_id))
 | 
	
		
			
				|  |  | +            task = task and task[0]
 | 
	
		
			
				|  |  | +            if not task:
 | 
	
		
			
				|  |  |                  task = Task(task_id)
 | 
	
		
			
				|  |  |                  task.status = states.PENDING
 | 
	
		
			
				|  |  |                  task.result = None
 | 
	
		
			
				|  |  |              return task.to_dict()
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @retry
 | 
	
		
			
				|  |  |      def _save_group(self, group_id, result):
 | 
	
		
			
				|  |  |          """Store the result of an executed group."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  |              group = TaskSet(group_id, result)
 | 
	
		
			
				|  |  |              session.add(group)
 | 
	
		
			
				|  |  |              session.flush()
 | 
	
		
			
				|  |  |              session.commit()
 | 
	
		
			
				|  |  |              return result
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @retry
 | 
	
		
			
				|  |  |      def _restore_group(self, group_id):
 | 
	
		
			
				|  |  |          """Get metadata for group by id."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  |              group = session.query(TaskSet).filter(
 | 
	
		
			
				|  |  |                  TaskSet.taskset_id == group_id).first()
 | 
	
		
			
				|  |  |              if group:
 | 
	
		
			
				|  |  |                  return group.to_dict()
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @retry
 | 
	
		
			
				|  |  |      def _delete_group(self, group_id):
 | 
	
		
			
				|  |  |          """Delete metadata for group by id."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  |              session.query(TaskSet).filter(
 | 
	
		
			
				|  |  |                  TaskSet.taskset_id == group_id).delete()
 | 
	
		
			
				|  |  |              session.flush()
 | 
	
		
			
				|  |  |              session.commit()
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @retry
 | 
	
		
			
				|  |  |      def _forget(self, task_id):
 | 
	
		
			
				|  |  |          """Forget about result."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  |              session.query(Task).filter(Task.task_id == task_id).delete()
 | 
	
		
			
				|  |  |              session.commit()
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cleanup(self):
 | 
	
		
			
				|  |  |          """Delete expired metadata."""
 | 
	
		
			
				|  |  |          session = self.ResultSession()
 | 
	
		
			
				|  |  |          expires = self.expires
 | 
	
		
			
				|  |  |          now = self.app.now()
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | +        with session_cleanup(session):
 | 
	
		
			
				|  |  |              session.query(Task).filter(
 | 
	
		
			
				|  |  |                  Task.date_done < (now - expires)).delete()
 | 
	
		
			
				|  |  |              session.query(TaskSet).filter(
 | 
	
		
			
				|  |  |                  TaskSet.date_done < (now - expires)).delete()
 | 
	
		
			
				|  |  |              session.commit()
 | 
	
		
			
				|  |  | -        finally:
 | 
	
		
			
				|  |  | -            session.close()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __reduce__(self, args=(), kwargs={}):
 | 
	
		
			
				|  |  |          kwargs.update(
 |