test_database.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. from datetime import datetime
  4. from pickle import loads, dumps
  5. from case import Mock, patch, skip
  6. from celery import states
  7. from celery import uuid
  8. from celery.exceptions import ImproperlyConfigured
  9. try:
  10. import sqlalchemy # noqa
  11. except ImportError:
  12. DatabaseBackend = Task = TaskSet = retry = None # noqa
  13. SessionManager = session_cleanup = None # noqa
  14. else:
  15. from celery.backends.database import (
  16. DatabaseBackend, retry, session_cleanup,
  17. )
  18. from celery.backends.database import session
  19. from celery.backends.database.session import SessionManager
  20. from celery.backends.database.models import Task, TaskSet
  21. class SomeClass(object):
  22. def __init__(self, data):
  23. self.data = data
  24. @skip.unless_module('sqlalchemy')
  25. class test_session_cleanup:
  26. def test_context(self):
  27. session = Mock(name='session')
  28. with session_cleanup(session):
  29. pass
  30. session.close.assert_called_with()
  31. def test_context_raises(self):
  32. session = Mock(name='session')
  33. with pytest.raises(KeyError):
  34. with session_cleanup(session):
  35. raise KeyError()
  36. session.rollback.assert_called_with()
  37. session.close.assert_called_with()
  38. @skip.unless_module('sqlalchemy')
  39. @skip.if_pypy()
  40. @skip.if_jython()
  41. class test_DatabaseBackend:
  42. def setup(self):
  43. self.uri = 'sqlite:///test.db'
  44. self.app.conf.result_serializer = 'pickle'
  45. def test_retry_helper(self):
  46. from celery.backends.database import DatabaseError
  47. calls = [0]
  48. @retry
  49. def raises():
  50. calls[0] += 1
  51. raise DatabaseError(1, 2, 3)
  52. with pytest.raises(DatabaseError):
  53. raises(max_retries=5)
  54. assert calls[0] == 5
  55. def test_missing_dburi_raises_ImproperlyConfigured(self):
  56. self.app.conf.database_url = None
  57. with pytest.raises(ImproperlyConfigured):
  58. DatabaseBackend(app=self.app)
  59. def test_missing_task_id_is_PENDING(self):
  60. tb = DatabaseBackend(self.uri, app=self.app)
  61. assert tb.get_state('xxx-does-not-exist') == states.PENDING
  62. def test_missing_task_meta_is_dict_with_pending(self):
  63. tb = DatabaseBackend(self.uri, app=self.app)
  64. meta = tb.get_task_meta('xxx-does-not-exist-at-all')
  65. assert meta['status'] == states.PENDING
  66. assert meta['task_id'] == 'xxx-does-not-exist-at-all'
  67. assert meta['result'] is None
  68. assert meta['traceback'] is None
  69. def test_mark_as_done(self):
  70. tb = DatabaseBackend(self.uri, app=self.app)
  71. tid = uuid()
  72. assert tb.get_state(tid) == states.PENDING
  73. assert tb.get_result(tid) is None
  74. tb.mark_as_done(tid, 42)
  75. assert tb.get_state(tid) == states.SUCCESS
  76. assert tb.get_result(tid) == 42
  77. def test_is_pickled(self):
  78. tb = DatabaseBackend(self.uri, app=self.app)
  79. tid2 = uuid()
  80. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  81. tb.mark_as_done(tid2, result)
  82. # is serialized properly.
  83. rindb = tb.get_result(tid2)
  84. assert rindb.get('foo') == 'baz'
  85. assert rindb.get('bar').data == 12345
  86. def test_mark_as_started(self):
  87. tb = DatabaseBackend(self.uri, app=self.app)
  88. tid = uuid()
  89. tb.mark_as_started(tid)
  90. assert tb.get_state(tid) == states.STARTED
  91. def test_mark_as_revoked(self):
  92. tb = DatabaseBackend(self.uri, app=self.app)
  93. tid = uuid()
  94. tb.mark_as_revoked(tid)
  95. assert tb.get_state(tid) == states.REVOKED
  96. def test_mark_as_retry(self):
  97. tb = DatabaseBackend(self.uri, app=self.app)
  98. tid = uuid()
  99. try:
  100. raise KeyError('foo')
  101. except KeyError as exception:
  102. import traceback
  103. trace = '\n'.join(traceback.format_stack())
  104. tb.mark_as_retry(tid, exception, traceback=trace)
  105. assert tb.get_state(tid) == states.RETRY
  106. assert isinstance(tb.get_result(tid), KeyError)
  107. assert tb.get_traceback(tid) == trace
  108. def test_mark_as_failure(self):
  109. tb = DatabaseBackend(self.uri, app=self.app)
  110. tid3 = uuid()
  111. try:
  112. raise KeyError('foo')
  113. except KeyError as exception:
  114. import traceback
  115. trace = '\n'.join(traceback.format_stack())
  116. tb.mark_as_failure(tid3, exception, traceback=trace)
  117. assert tb.get_state(tid3) == states.FAILURE
  118. assert isinstance(tb.get_result(tid3), KeyError)
  119. assert tb.get_traceback(tid3) == trace
  120. def test_forget(self):
  121. tb = DatabaseBackend(self.uri, backend='memory://', app=self.app)
  122. tid = uuid()
  123. tb.mark_as_done(tid, {'foo': 'bar'})
  124. tb.mark_as_done(tid, {'foo': 'bar'})
  125. x = self.app.AsyncResult(tid, backend=tb)
  126. x.forget()
  127. assert x.result is None
  128. def test_process_cleanup(self):
  129. tb = DatabaseBackend(self.uri, app=self.app)
  130. tb.process_cleanup()
  131. @pytest.mark.usefixtures('depends_on_current_app')
  132. def test_reduce(self):
  133. tb = DatabaseBackend(self.uri, app=self.app)
  134. assert loads(dumps(tb))
  135. def test_save__restore__delete_group(self):
  136. tb = DatabaseBackend(self.uri, app=self.app)
  137. tid = uuid()
  138. res = {'something': 'special'}
  139. assert tb.save_group(tid, res) == res
  140. res2 = tb.restore_group(tid)
  141. assert res2 == res
  142. tb.delete_group(tid)
  143. assert tb.restore_group(tid) is None
  144. assert tb.restore_group('xxx-nonexisting-id') is None
  145. def test_cleanup(self):
  146. tb = DatabaseBackend(self.uri, app=self.app)
  147. for i in range(10):
  148. tb.mark_as_done(uuid(), 42)
  149. tb.save_group(uuid(), {'foo': 'bar'})
  150. s = tb.ResultSession()
  151. for t in s.query(Task).all():
  152. t.date_done = datetime.now() - tb.expires * 2
  153. for t in s.query(TaskSet).all():
  154. t.date_done = datetime.now() - tb.expires * 2
  155. s.commit()
  156. s.close()
  157. tb.cleanup()
  158. def test_Task__repr__(self):
  159. assert 'foo' in repr(Task('foo'))
  160. def test_TaskSet__repr__(self):
  161. assert 'foo', repr(TaskSet('foo' in None))
  162. @skip.unless_module('sqlalchemy')
  163. class test_SessionManager:
  164. def test_after_fork(self):
  165. s = SessionManager()
  166. assert not s.forked
  167. s._after_fork()
  168. assert s.forked
  169. @patch('celery.backends.database.session.create_engine')
  170. def test_get_engine_forked(self, create_engine):
  171. s = SessionManager()
  172. s._after_fork()
  173. engine = s.get_engine('dburi', foo=1)
  174. create_engine.assert_called_with('dburi', foo=1)
  175. assert engine is create_engine()
  176. engine2 = s.get_engine('dburi', foo=1)
  177. assert engine2 is engine
  178. @patch('celery.backends.database.session.sessionmaker')
  179. def test_create_session_forked(self, sessionmaker):
  180. s = SessionManager()
  181. s.get_engine = Mock(name='get_engine')
  182. s._after_fork()
  183. engine, session = s.create_session('dburi', short_lived_sessions=True)
  184. sessionmaker.assert_called_with(bind=s.get_engine())
  185. assert session is sessionmaker()
  186. sessionmaker.return_value = Mock(name='new')
  187. engine, session2 = s.create_session('dburi', short_lived_sessions=True)
  188. sessionmaker.assert_called_with(bind=s.get_engine())
  189. assert session2 is not session
  190. sessionmaker.return_value = Mock(name='new2')
  191. engine, session3 = s.create_session(
  192. 'dburi', short_lived_sessions=False)
  193. sessionmaker.assert_called_with(bind=s.get_engine())
  194. assert session3 is session2
  195. def test_coverage_madness(self):
  196. prev, session.register_after_fork = (
  197. session.register_after_fork, None,
  198. )
  199. try:
  200. SessionManager()
  201. finally:
  202. session.register_after_fork = prev