from __future__ import absolute_import, unicode_literals import pytest from datetime import datetime from pickle import loads, dumps from case import Mock, patch, skip from celery import states from celery import uuid from celery.exceptions import ImproperlyConfigured try: import sqlalchemy # noqa except ImportError: DatabaseBackend = Task = TaskSet = retry = None # noqa SessionManager = session_cleanup = None # noqa else: from celery.backends.database import ( DatabaseBackend, retry, session_cleanup, ) from celery.backends.database import session from celery.backends.database.session import SessionManager from celery.backends.database.models import Task, TaskSet class SomeClass(object): def __init__(self, data): self.data = data @skip.unless_module('sqlalchemy') class test_session_cleanup: def test_context(self): session = Mock(name='session') with session_cleanup(session): pass session.close.assert_called_with() def test_context_raises(self): session = Mock(name='session') with pytest.raises(KeyError): with session_cleanup(session): raise KeyError() session.rollback.assert_called_with() session.close.assert_called_with() @skip.unless_module('sqlalchemy') @skip.if_pypy() @skip.if_jython() class test_DatabaseBackend: def setup(self): self.uri = 'sqlite:///test.db' self.app.conf.result_serializer = 'pickle' def test_retry_helper(self): from celery.backends.database import DatabaseError calls = [0] @retry def raises(): calls[0] += 1 raise DatabaseError(1, 2, 3) with pytest.raises(DatabaseError): raises(max_retries=5) assert calls[0] == 5 def test_missing_dburi_raises_ImproperlyConfigured(self): self.app.conf.database_url = None with pytest.raises(ImproperlyConfigured): DatabaseBackend(app=self.app) def test_missing_task_id_is_PENDING(self): tb = DatabaseBackend(self.uri, app=self.app) assert tb.get_state('xxx-does-not-exist') == states.PENDING def test_missing_task_meta_is_dict_with_pending(self): tb = DatabaseBackend(self.uri, app=self.app) meta = tb.get_task_meta('xxx-does-not-exist-at-all') assert meta['status'] == states.PENDING assert meta['task_id'] == 'xxx-does-not-exist-at-all' assert meta['result'] is None assert meta['traceback'] is None def test_mark_as_done(self): tb = DatabaseBackend(self.uri, app=self.app) tid = uuid() assert tb.get_state(tid) == states.PENDING assert tb.get_result(tid) is None tb.mark_as_done(tid, 42) assert tb.get_state(tid) == states.SUCCESS assert tb.get_result(tid) == 42 def test_is_pickled(self): tb = DatabaseBackend(self.uri, app=self.app) tid2 = uuid() result = {'foo': 'baz', 'bar': SomeClass(12345)} tb.mark_as_done(tid2, result) # is serialized properly. rindb = tb.get_result(tid2) assert rindb.get('foo') == 'baz' assert rindb.get('bar').data == 12345 def test_mark_as_started(self): tb = DatabaseBackend(self.uri, app=self.app) tid = uuid() tb.mark_as_started(tid) assert tb.get_state(tid) == states.STARTED def test_mark_as_revoked(self): tb = DatabaseBackend(self.uri, app=self.app) tid = uuid() tb.mark_as_revoked(tid) assert tb.get_state(tid) == states.REVOKED def test_mark_as_retry(self): tb = DatabaseBackend(self.uri, app=self.app) tid = uuid() try: raise KeyError('foo') except KeyError as exception: import traceback trace = '\n'.join(traceback.format_stack()) tb.mark_as_retry(tid, exception, traceback=trace) assert tb.get_state(tid) == states.RETRY assert isinstance(tb.get_result(tid), KeyError) assert tb.get_traceback(tid) == trace def test_mark_as_failure(self): tb = DatabaseBackend(self.uri, app=self.app) tid3 = uuid() try: raise KeyError('foo') except KeyError as exception: import traceback trace = '\n'.join(traceback.format_stack()) tb.mark_as_failure(tid3, exception, traceback=trace) assert tb.get_state(tid3) == states.FAILURE assert isinstance(tb.get_result(tid3), KeyError) assert tb.get_traceback(tid3) == trace def test_forget(self): tb = DatabaseBackend(self.uri, backend='memory://', app=self.app) tid = uuid() tb.mark_as_done(tid, {'foo': 'bar'}) tb.mark_as_done(tid, {'foo': 'bar'}) x = self.app.AsyncResult(tid, backend=tb) x.forget() assert x.result is None def test_process_cleanup(self): tb = DatabaseBackend(self.uri, app=self.app) tb.process_cleanup() @pytest.mark.usefixtures('depends_on_current_app') def test_reduce(self): tb = DatabaseBackend(self.uri, app=self.app) assert loads(dumps(tb)) def test_save__restore__delete_group(self): tb = DatabaseBackend(self.uri, app=self.app) tid = uuid() res = {'something': 'special'} assert tb.save_group(tid, res) == res res2 = tb.restore_group(tid) assert res2 == res tb.delete_group(tid) assert tb.restore_group(tid) is None assert tb.restore_group('xxx-nonexisting-id') is None def test_cleanup(self): tb = DatabaseBackend(self.uri, app=self.app) for i in range(10): tb.mark_as_done(uuid(), 42) tb.save_group(uuid(), {'foo': 'bar'}) s = tb.ResultSession() for t in s.query(Task).all(): t.date_done = datetime.now() - tb.expires * 2 for t in s.query(TaskSet).all(): t.date_done = datetime.now() - tb.expires * 2 s.commit() s.close() tb.cleanup() def test_Task__repr__(self): assert 'foo' in repr(Task('foo')) def test_TaskSet__repr__(self): assert 'foo', repr(TaskSet('foo' in None)) @skip.unless_module('sqlalchemy') class test_SessionManager: def test_after_fork(self): s = SessionManager() assert not s.forked s._after_fork() assert s.forked @patch('celery.backends.database.session.create_engine') def test_get_engine_forked(self, create_engine): s = SessionManager() s._after_fork() engine = s.get_engine('dburi', foo=1) create_engine.assert_called_with('dburi', foo=1) assert engine is create_engine() engine2 = s.get_engine('dburi', foo=1) assert engine2 is engine @patch('celery.backends.database.session.sessionmaker') def test_create_session_forked(self, sessionmaker): s = SessionManager() s.get_engine = Mock(name='get_engine') s._after_fork() engine, session = s.create_session('dburi', short_lived_sessions=True) sessionmaker.assert_called_with(bind=s.get_engine()) assert session is sessionmaker() sessionmaker.return_value = Mock(name='new') engine, session2 = s.create_session('dburi', short_lived_sessions=True) sessionmaker.assert_called_with(bind=s.get_engine()) assert session2 is not session sessionmaker.return_value = Mock(name='new2') engine, session3 = s.create_session( 'dburi', short_lived_sessions=False) sessionmaker.assert_called_with(bind=s.get_engine()) assert session3 is session2 def test_coverage_madness(self): prev, session.register_after_fork = ( session.register_after_fork, None, ) try: SessionManager() finally: session.register_after_fork = prev