test_redis_unit.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from datetime import timedelta
  2. from mock import Mock, patch
  3. from kombu.utils import cached_property
  4. from celery import current_app
  5. from celery import states
  6. from celery.utils import gen_unique_id
  7. from celery.utils.timeutils import timedelta_seconds
  8. from celery.tests.utils import unittest
  9. class Redis(object):
  10. class Connection(object):
  11. connected = True
  12. def disconnect(self):
  13. self.connected = False
  14. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  15. self.host = host
  16. self.port = port
  17. self.db = db
  18. self.password = password
  19. self.connection = self.Connection()
  20. self.keyspace = {}
  21. self.expiry = {}
  22. def get(self, key):
  23. return self.keyspace.get(key)
  24. def set(self, key, value):
  25. self.keyspace[key] = value
  26. def expire(self, key, expires):
  27. self.expiry[key] = expires
  28. def delete(self, key):
  29. self.keyspace.pop(key)
  30. class redis(object):
  31. Redis = Redis
  32. class test_RedisBackend(unittest.TestCase):
  33. def get_backend(self):
  34. from celery.backends import redis
  35. class RedisBackend(redis.RedisBackend):
  36. redis = redis
  37. return RedisBackend
  38. def setUp(self):
  39. self.Backend = self.get_backend()
  40. class MockBackend(self.Backend):
  41. @cached_property
  42. def client(self):
  43. return Mock()
  44. self.MockBackend = MockBackend
  45. def test_expires_defaults_to_config(self):
  46. conf = current_app.conf
  47. prev = conf.CELERY_TASK_RESULT_EXPIRES
  48. conf.CELERY_TASK_RESULT_EXPIRES = 10
  49. try:
  50. b = self.Backend(expires=None)
  51. self.assertEqual(b.expires, 10)
  52. finally:
  53. conf.CELERY_TASK_RESULT_EXPIRES = prev
  54. def test_expires_is_int(self):
  55. b = self.Backend(expires=48)
  56. self.assertEqual(b.expires, 48)
  57. def test_expires_is_None(self):
  58. b = self.Backend(expires=None)
  59. self.assertEqual(b.expires, timedelta_seconds(
  60. current_app.conf.CELERY_TASK_RESULT_EXPIRES))
  61. def test_expires_is_timedelta(self):
  62. b = self.Backend(expires=timedelta(minutes=1))
  63. self.assertEqual(b.expires, 60)
  64. def test_on_chord_apply(self):
  65. self.Backend().on_chord_apply()
  66. def test_mget(self):
  67. b = self.MockBackend()
  68. self.assertTrue(b.mget(["a", "b", "c"]))
  69. b.client.mget.assert_called_with(["a", "b", "c"])
  70. def test_set_no_expire(self):
  71. b = self.MockBackend()
  72. b.expires = None
  73. b.set("foo", "bar")
  74. @patch("celery.result.TaskSetResult")
  75. def test_on_chord_part_return(self, setresult):
  76. from celery.registry import tasks
  77. from celery.task import subtask
  78. b = self.MockBackend()
  79. deps = Mock()
  80. deps.total = 10
  81. setresult.restore.return_value = deps
  82. b.client.incr.return_value = 1
  83. task = Mock()
  84. task.name = "foobarbaz"
  85. try:
  86. tasks["foobarbaz"] = task
  87. task.request.chord = subtask(task)
  88. b.on_chord_part_return(task)
  89. self.assertTrue(b.client.incr.call_count)
  90. b.client.incr.return_value = deps.total
  91. b.on_chord_part_return(task)
  92. deps.join.assert_called_with(propagate=False)
  93. deps.delete.assert_called_with()
  94. self.assertTrue(b.client.expire.call_count)
  95. finally:
  96. tasks.pop("foobarbaz")
  97. def test_process_cleanup(self):
  98. self.Backend().process_cleanup()
  99. def test_get_set_forget(self):
  100. b = self.Backend()
  101. uuid = gen_unique_id()
  102. b.store_result(uuid, 42, states.SUCCESS)
  103. self.assertEqual(b.get_status(uuid), states.SUCCESS)
  104. self.assertEqual(b.get_result(uuid), 42)
  105. b.forget(uuid)
  106. self.assertEqual(b.get_status(uuid), states.PENDING)
  107. def test_set_expires(self):
  108. b = self.Backend(expires=512)
  109. uuid = gen_unique_id()
  110. key = b.get_key_for_task(uuid)
  111. b.store_result(uuid, 42, states.SUCCESS)
  112. self.assertEqual(b.client.expiry[key], 512)