test_database.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import unittest2 as unittest
  2. from datetime import datetime
  3. from celery.exceptions import ImproperlyConfigured
  4. from celery import states
  5. from celery.app import app_or_default
  6. from celery.backends.database import DatabaseBackend
  7. from celery.db.models import Task, TaskSet
  8. from celery.result import AsyncResult
  9. from celery.utils import gen_unique_id
  10. class SomeClass(object):
  11. def __init__(self, data):
  12. self.data = data
  13. class test_DatabaseBackend(unittest.TestCase):
  14. def test_missing_dburi_raises_ImproperlyConfigured(self):
  15. conf = app_or_default().conf
  16. prev, conf.CELERY_RESULT_DBURI = conf.CELERY_RESULT_DBURI, None
  17. try:
  18. self.assertRaises(ImproperlyConfigured, DatabaseBackend)
  19. finally:
  20. conf.CELERY_RESULT_DBURI = prev
  21. def test_missing_task_id_is_PENDING(self):
  22. tb = DatabaseBackend()
  23. self.assertEqual(tb.get_status("xxx-does-not-exist"), states.PENDING)
  24. def test_mark_as_done(self):
  25. tb = DatabaseBackend()
  26. tid = gen_unique_id()
  27. self.assertEqual(tb.get_status(tid), states.PENDING)
  28. self.assertIsNone(tb.get_result(tid))
  29. tb.mark_as_done(tid, 42)
  30. self.assertEqual(tb.get_status(tid), states.SUCCESS)
  31. self.assertEqual(tb.get_result(tid), 42)
  32. def test_is_pickled(self):
  33. tb = DatabaseBackend()
  34. tid2 = gen_unique_id()
  35. result = {"foo": "baz", "bar": SomeClass(12345)}
  36. tb.mark_as_done(tid2, result)
  37. # is serialized properly.
  38. rindb = tb.get_result(tid2)
  39. self.assertEqual(rindb.get("foo"), "baz")
  40. self.assertEqual(rindb.get("bar").data, 12345)
  41. def test_mark_as_started(self):
  42. tb = DatabaseBackend()
  43. tid = gen_unique_id()
  44. tb.mark_as_started(tid)
  45. self.assertEqual(tb.get_status(tid), states.STARTED)
  46. def test_mark_as_revoked(self):
  47. tb = DatabaseBackend()
  48. tid = gen_unique_id()
  49. tb.mark_as_revoked(tid)
  50. self.assertEqual(tb.get_status(tid), states.REVOKED)
  51. def test_mark_as_retry(self):
  52. tb = DatabaseBackend()
  53. tid = gen_unique_id()
  54. try:
  55. raise KeyError("foo")
  56. except KeyError, exception:
  57. import traceback
  58. trace = "\n".join(traceback.format_stack())
  59. tb.mark_as_retry(tid, exception, traceback=trace)
  60. self.assertEqual(tb.get_status(tid), states.RETRY)
  61. self.assertIsInstance(tb.get_result(tid), KeyError)
  62. self.assertEqual(tb.get_traceback(tid), trace)
  63. def test_mark_as_failure(self):
  64. tb = DatabaseBackend()
  65. tid3 = gen_unique_id()
  66. try:
  67. raise KeyError("foo")
  68. except KeyError, exception:
  69. import traceback
  70. trace = "\n".join(traceback.format_stack())
  71. tb.mark_as_failure(tid3, exception, traceback=trace)
  72. self.assertEqual(tb.get_status(tid3), states.FAILURE)
  73. self.assertIsInstance(tb.get_result(tid3), KeyError)
  74. self.assertEqual(tb.get_traceback(tid3), trace)
  75. def test_forget(self):
  76. tb = DatabaseBackend(backend="memory://")
  77. tid = gen_unique_id()
  78. tb.mark_as_done(tid, {"foo": "bar"})
  79. x = AsyncResult(tid)
  80. x.forget()
  81. self.assertIsNone(x.result)
  82. def test_process_cleanup(self):
  83. tb = DatabaseBackend()
  84. tb.process_cleanup()
  85. def test_save___restore_taskset(self):
  86. tb = DatabaseBackend()
  87. tid = gen_unique_id()
  88. res = {u"something": "special"}
  89. self.assertEqual(tb.save_taskset(tid, res), res)
  90. res2 = tb.restore_taskset(tid)
  91. self.assertEqual(res2, res)
  92. self.assertIsNone(tb.restore_taskset("xxx-nonexisting-id"))
  93. def test_cleanup(self):
  94. tb = DatabaseBackend()
  95. for i in range(10):
  96. tb.mark_as_done(gen_unique_id(), 42)
  97. tb.save_taskset(gen_unique_id(), {"foo": "bar"})
  98. s = tb.ResultSession()
  99. for t in s.query(Task).all():
  100. t.date_done = datetime.now() - tb.result_expires * 2
  101. for t in s.query(TaskSet).all():
  102. t.date_done = datetime.now() - tb.result_expires * 2
  103. s.commit()
  104. s.close()
  105. tb.cleanup()
  106. s2 = tb.ResultSession()
  107. self.assertEqual(s2.query(Task).count(), 0)
  108. self.assertEqual(s2.query(TaskSet).count(), 0)
  109. def test_Task__repr__(self):
  110. self.assertIn("foo", repr(Task("foo")))
  111. def test_TaskSet__repr__(self):
  112. self.assertIn("foo", repr(TaskSet("foo", None)))