test_redis_unit.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from __future__ import absolute_import
  2. from datetime import timedelta
  3. from mock import Mock, patch
  4. from celery import current_app
  5. from celery import states
  6. from celery.result import AsyncResult
  7. from celery.registry import tasks
  8. from celery.task import subtask
  9. from celery.utils import cached_property, uuid
  10. from celery.utils.timeutils import timedelta_seconds
  11. from celery.tests.utils import unittest
  12. class Redis(object):
  13. class Connection(object):
  14. connected = True
  15. def disconnect(self):
  16. self.connected = False
  17. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  18. self.host = host
  19. self.port = port
  20. self.db = db
  21. self.password = password
  22. self.connection = self.Connection()
  23. self.keyspace = {}
  24. self.expiry = {}
  25. def get(self, key):
  26. return self.keyspace.get(key)
  27. def set(self, key, value):
  28. self.keyspace[key] = value
  29. def expire(self, key, expires):
  30. self.expiry[key] = expires
  31. def delete(self, key):
  32. self.keyspace.pop(key)
  33. class redis(object):
  34. Redis = Redis
  35. class test_RedisBackend(unittest.TestCase):
  36. def get_backend(self):
  37. from celery.backends import redis
  38. class RedisBackend(redis.RedisBackend):
  39. redis = redis
  40. return RedisBackend
  41. def setUp(self):
  42. self.Backend = self.get_backend()
  43. class MockBackend(self.Backend):
  44. @cached_property
  45. def client(self):
  46. return Mock()
  47. self.MockBackend = MockBackend
  48. def test_expires_defaults_to_config(self):
  49. conf = current_app.conf
  50. prev = conf.CELERY_TASK_RESULT_EXPIRES
  51. conf.CELERY_TASK_RESULT_EXPIRES = 10
  52. try:
  53. b = self.Backend(expires=None)
  54. self.assertEqual(b.expires, 10)
  55. finally:
  56. conf.CELERY_TASK_RESULT_EXPIRES = prev
  57. def test_expires_is_int(self):
  58. b = self.Backend(expires=48)
  59. self.assertEqual(b.expires, 48)
  60. def test_expires_is_None(self):
  61. b = self.Backend(expires=None)
  62. self.assertEqual(b.expires, timedelta_seconds(
  63. current_app.conf.CELERY_TASK_RESULT_EXPIRES))
  64. def test_expires_is_timedelta(self):
  65. b = self.Backend(expires=timedelta(minutes=1))
  66. self.assertEqual(b.expires, 60)
  67. def test_on_chord_apply(self):
  68. self.Backend().on_chord_apply("setid", {},
  69. result=map(AsyncResult, [1, 2, 3]))
  70. def test_mget(self):
  71. b = self.MockBackend()
  72. self.assertTrue(b.mget(["a", "b", "c"]))
  73. b.client.mget.assert_called_with(["a", "b", "c"])
  74. def test_set_no_expire(self):
  75. b = self.MockBackend()
  76. b.expires = None
  77. b.set("foo", "bar")
  78. @patch("celery.result.TaskSetResult")
  79. def test_on_chord_part_return(self, setresult):
  80. b = self.MockBackend()
  81. deps = Mock()
  82. deps.total = 10
  83. setresult.restore.return_value = deps
  84. b.client.incr.return_value = 1
  85. task = Mock()
  86. task.name = "foobarbaz"
  87. try:
  88. tasks["foobarbaz"] = task
  89. task.request.chord = subtask(task)
  90. task.request.taskset = "setid"
  91. b.on_chord_part_return(task)
  92. self.assertTrue(b.client.incr.call_count)
  93. b.client.incr.return_value = deps.total
  94. b.on_chord_part_return(task)
  95. deps.join.assert_called_with(propagate=False)
  96. deps.delete.assert_called_with()
  97. self.assertTrue(b.client.expire.call_count)
  98. finally:
  99. tasks.pop("foobarbaz")
  100. def test_process_cleanup(self):
  101. self.Backend().process_cleanup()
  102. def test_get_set_forget(self):
  103. b = self.Backend()
  104. tid = uuid()
  105. b.store_result(tid, 42, states.SUCCESS)
  106. self.assertEqual(b.get_status(tid), states.SUCCESS)
  107. self.assertEqual(b.get_result(tid), 42)
  108. b.forget(tid)
  109. self.assertEqual(b.get_status(tid), states.PENDING)
  110. def test_set_expires(self):
  111. b = self.Backend(expires=512)
  112. tid = uuid()
  113. key = b.get_key_for_task(tid)
  114. b.store_result(tid, 42, states.SUCCESS)
  115. self.assertEqual(b.client.expiry[key], 512)