test_redis.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. from __future__ import absolute_import
  2. from datetime import timedelta
  3. from mock import Mock, patch
  4. from nose import SkipTest
  5. from pickle import loads, dumps
  6. from kombu.utils import cached_property, uuid
  7. from celery import current_app
  8. from celery import states
  9. from celery.datastructures import AttributeDict
  10. from celery.exceptions import ImproperlyConfigured
  11. from celery.result import AsyncResult
  12. from celery.task import subtask
  13. from celery.utils.timeutils import timedelta_seconds
  14. from celery.tests.utils import Case
  class Redis(object):
      """In-memory stand-in for a redis client, used by the backend tests.

      Values are kept in the ``keyspace`` dict.  TTLs passed to
      :meth:`expire` / :meth:`setex` are only recorded in ``expiry`` (never
      enforced) so tests can inspect them.
      """

      class Connection(object):
          """Minimal connection object exposing ``connected``/``disconnect``."""

          connected = True

          def disconnect(self):
              self.connected = False

      def __init__(self, host=None, port=None, db=None, password=None, **kw):
          # Connection parameters are stored for inspection only; extra
          # keyword arguments are accepted and ignored.
          self.host = host
          self.port = port
          self.db = db
          self.password = password
          self.connection = self.Connection()
          self.keyspace = {}  # key -> stored value
          self.expiry = {}    # key -> TTL recorded by expire(); not enforced

      def get(self, key):
          """Return the value stored at *key*, or ``None`` if it is absent."""
          return self.keyspace.get(key)

      def setex(self, key, value, expires):
          """Store *value* at *key* and record *expires* as its TTL."""
          self.set(key, value)
          self.expire(key, expires)

      def set(self, key, value):
          """Store *value* at *key*, overwriting any previous value."""
          self.keyspace[key] = value

      def expire(self, key, expires):
          """Record *expires* as the TTL of *key* (the key is never evicted)."""
          self.expiry[key] = expires

      def delete(self, key):
          """Remove *key* and its recorded TTL; return how many keys were removed.

          Deleting a key that does not exist is not an error and returns 0,
          mirroring the Redis ``DEL`` command.
          """
          self.expiry.pop(key, None)
          if key in self.keyspace:
              del self.keyspace[key]
              return 1
          return 0

      def publish(self, key, value):
          """Accept and discard a published message."""
          pass
  class redis(object):
      """Stand-in for the ``redis`` client library module.

      Re-exports this file's in-memory ``Redis`` fake and supplies a
      ``ConnectionPool`` that accepts and ignores any keyword arguments.
      """

      # Assigned in the class body, so the right-hand side is looked up in
      # module globals: this binds the module-level ``Redis`` fake.
      Redis = Redis

      class ConnectionPool(object):
          """Connection-pool placeholder that keeps none of its options."""

          def __init__(self, **options):
              """Ignore every keyword option."""
  class test_RedisBackend(Case):
      """Tests for ``celery.backends.redis.RedisBackend`` using fake clients."""

      def get_backend(self):
          """Return a ``RedisBackend`` subclass wired to this file's fake ``redis``.

          The real backend module is imported under an alias so that it can
          not be confused with the module-level fake ``redis`` namespace,
          which the subclass installs as its ``redis`` attribute.
          """
          from celery.backends import redis as redis_backend
          fake_redis = redis  # module-level fake defined in this file

          class RedisBackend(redis_backend.RedisBackend):
              redis = fake_redis

          return RedisBackend

      def setUp(self):
          self.Backend = self.get_backend()

          class MockBackend(self.Backend):
              # Swap ``client`` for a Mock (cached per instance) so tests can
              # assert on the calls the backend makes to it.
              @cached_property
              def client(self):
                  return Mock()

          self.MockBackend = MockBackend

      def test_reduce(self):
          # The real backend class is used here, so the redis library must be
          # available.  Without it, construction raises ImproperlyConfigured
          # (see test_no_redis) rather than ImportError, so skip on either.
          try:
              from celery.backends.redis import RedisBackend
              x = RedisBackend()
          except (ImportError, ImproperlyConfigured):
              raise SkipTest('redis not installed')
          self.assertTrue(loads(dumps(x)))

      def test_no_redis(self):
          # With no redis library attached, construction must be refused.
          self.MockBackend.redis = None
          with self.assertRaises(ImproperlyConfigured):
              self.MockBackend()

      def test_url(self):
          x = self.MockBackend('redis://foobar//1')
          self.assertEqual(x.host, 'foobar')
          self.assertEqual(x.db, '1')

      def test_conf_raises_KeyError(self):
          # Build the backend against a config holding only these three
          # settings; the test passes as long as construction does not raise.
          conf = AttributeDict({'CELERY_RESULT_SERIALIZER': 'json',
                                'CELERY_MAX_CACHED_RESULTS': 1,
                                'CELERY_TASK_RESULT_EXPIRES': None})
          prev, current_app.conf = current_app.conf, conf
          try:
              self.MockBackend()
          finally:
              current_app.conf = prev

      def test_expires_defaults_to_config(self):
          conf = current_app.conf
          prev = conf.CELERY_TASK_RESULT_EXPIRES
          conf.CELERY_TASK_RESULT_EXPIRES = 10
          try:
              b = self.Backend(expires=None)
              self.assertEqual(b.expires, 10)
          finally:
              conf.CELERY_TASK_RESULT_EXPIRES = prev

      def test_expires_is_int(self):
          b = self.Backend(expires=48)
          self.assertEqual(b.expires, 48)

      def test_expires_is_None(self):
          b = self.Backend(expires=None)
          self.assertEqual(b.expires, timedelta_seconds(
              current_app.conf.CELERY_TASK_RESULT_EXPIRES))

      def test_expires_is_timedelta(self):
          b = self.Backend(expires=timedelta(minutes=1))
          self.assertEqual(b.expires, 60)

      def test_on_chord_apply(self):
          self.Backend().on_chord_apply('group_id', {},
                                        result=map(AsyncResult, [1, 2, 3]))

      def test_mget(self):
          b = self.MockBackend()
          self.assertTrue(b.mget(['a', 'b', 'c']))
          b.client.mget.assert_called_with(['a', 'b', 'c'])

      def test_set_no_expire(self):
          b = self.MockBackend()
          b.expires = None
          b.set('foo', 'bar')

      @patch('celery.result.GroupResult')
      def test_on_chord_part_return(self, setresult):
          b = self.MockBackend()
          # A restored group whose len() is 10.
          deps = Mock()
          deps.__len__ = Mock()
          deps.__len__.return_value = 10
          setresult.restore.return_value = deps
          b.client.incr.return_value = 1
          task = Mock()
          task.name = 'foobarbaz'
          # Register before entering the try block so the cleanup in
          # ``finally`` only runs once the entry is known to exist.
          current_app.tasks['foobarbaz'] = task
          try:
              task.request.chord = subtask(task)
              task.request.group = 'group_id'

              # incr() reports 1 of 10 parts: only the counter is touched.
              b.on_chord_part_return(task)
              self.assertTrue(b.client.incr.call_count)

              # incr() reports len(deps): the group is joined and deleted.
              b.client.incr.return_value = len(deps)
              b.on_chord_part_return(task)
              deps.join.assert_called_with(propagate=False)
              deps.delete.assert_called_with()
              self.assertTrue(b.client.expire.call_count)
          finally:
              current_app.tasks.pop('foobarbaz')

      def test_process_cleanup(self):
          self.Backend().process_cleanup()

      def test_get_set_forget(self):
          b = self.Backend()
          tid = uuid()
          b.store_result(tid, 42, states.SUCCESS)
          self.assertEqual(b.get_status(tid), states.SUCCESS)
          self.assertEqual(b.get_result(tid), 42)
          b.forget(tid)
          self.assertEqual(b.get_status(tid), states.PENDING)

      def test_set_expires(self):
          b = self.Backend(expires=512)
          tid = uuid()
          key = b.get_key_for_task(tid)
          b.store_result(tid, 42, states.SUCCESS)
          self.assertEqual(b.client.expiry[key], 512)