test_redis.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. from __future__ import absolute_import
  2. from datetime import timedelta
  3. from pickle import loads, dumps
  4. from celery import signature
  5. from celery import states
  6. from celery import group
  7. from celery import uuid
  8. from celery.datastructures import AttributeDict
  9. from celery.exceptions import ImproperlyConfigured
  10. from celery.utils.timeutils import timedelta_seconds
  11. from celery.tests.case import (
  12. AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
  13. )
  14. class Connection(object):
  15. connected = True
  16. def disconnect(self):
  17. self.connected = False
  18. class Pipeline(object):
  19. def __init__(self, client):
  20. self.client = client
  21. self.steps = []
  22. def __getattr__(self, attr):
  23. def add_step(*args, **kwargs):
  24. self.steps.append((getattr(self.client, attr), args, kwargs))
  25. return self
  26. return add_step
  27. def __enter__(self):
  28. return self
  29. def __exit__(self, type, value, traceback):
  30. pass
  31. def execute(self):
  32. return [step(*a, **kw) for step, a, kw in self.steps]
  33. class Redis(MockCallbacks):
  34. Connection = Connection
  35. Pipeline = Pipeline
  36. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  37. self.host = host
  38. self.port = port
  39. self.db = db
  40. self.password = password
  41. self.keyspace = {}
  42. self.expiry = {}
  43. self.connection = self.Connection()
  44. def get(self, key):
  45. return self.keyspace.get(key)
  46. def setex(self, key, value, expires):
  47. self.set(key, value)
  48. self.expire(key, expires)
  49. def set(self, key, value):
  50. self.keyspace[key] = value
  51. def expire(self, key, expires):
  52. self.expiry[key] = expires
  53. return expires
  54. def delete(self, key):
  55. return bool(self.keyspace.pop(key, None))
  56. def pipeline(self):
  57. return self.Pipeline(self)
  58. def _get_list(self, key):
  59. try:
  60. return self.keyspace[key]
  61. except KeyError:
  62. l = self.keyspace[key] = []
  63. return l
  64. def rpush(self, key, value):
  65. self._get_list(key).append(value)
  66. def lrange(self, key, start, stop):
  67. return self._get_list(key)[start:stop]
  68. def llen(self, key):
  69. return len(self.keyspace.get(key) or [])
  70. class redis(object):
  71. Redis = Redis
  72. class ConnectionPool(object):
  73. def __init__(self, **kwargs):
  74. pass
  75. class UnixDomainSocketConnection(object):
  76. def __init__(self, **kwargs):
  77. pass
  78. class test_RedisBackend(AppCase):
  79. def get_backend(self):
  80. from celery.backends.redis import RedisBackend
  81. class _RedisBackend(RedisBackend):
  82. redis = redis
  83. return _RedisBackend
  84. def setup(self):
  85. self.Backend = self.get_backend()
  86. @depends_on_current_app
  87. def test_reduce(self):
  88. try:
  89. from celery.backends.redis import RedisBackend
  90. x = RedisBackend(app=self.app, new_join=True)
  91. self.assertTrue(loads(dumps(x)))
  92. except ImportError:
  93. raise SkipTest('redis not installed')
  94. def test_no_redis(self):
  95. self.Backend.redis = None
  96. with self.assertRaises(ImproperlyConfigured):
  97. self.Backend(app=self.app, new_join=True)
  98. def test_url(self):
  99. x = self.Backend(
  100. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  101. new_join=True,
  102. )
  103. self.assertTrue(x.connparams)
  104. self.assertEqual(x.connparams['host'], 'vandelay.com')
  105. self.assertEqual(x.connparams['db'], 1)
  106. self.assertEqual(x.connparams['port'], 123)
  107. self.assertEqual(x.connparams['password'], 'bosco')
  108. def test_socket_url(self):
  109. x = self.Backend(
  110. 'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
  111. new_join=True,
  112. )
  113. self.assertTrue(x.connparams)
  114. self.assertEqual(x.connparams['path'], '/tmp/redis.sock')
  115. self.assertIs(
  116. x.connparams['connection_class'],
  117. redis.UnixDomainSocketConnection,
  118. )
  119. self.assertNotIn('host', x.connparams)
  120. self.assertNotIn('port', x.connparams)
  121. self.assertEqual(x.connparams['db'], 3)
  122. def test_compat_propertie(self):
  123. x = self.Backend(
  124. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  125. new_join=True,
  126. )
  127. with self.assertPendingDeprecation():
  128. self.assertEqual(x.host, 'vandelay.com')
  129. with self.assertPendingDeprecation():
  130. self.assertEqual(x.db, 1)
  131. with self.assertPendingDeprecation():
  132. self.assertEqual(x.port, 123)
  133. with self.assertPendingDeprecation():
  134. self.assertEqual(x.password, 'bosco')
  135. def test_conf_raises_KeyError(self):
  136. self.app.conf = AttributeDict({
  137. 'CELERY_RESULT_SERIALIZER': 'json',
  138. 'CELERY_MAX_CACHED_RESULTS': 1,
  139. 'CELERY_ACCEPT_CONTENT': ['json'],
  140. 'CELERY_TASK_RESULT_EXPIRES': None,
  141. })
  142. self.Backend(app=self.app, new_join=True)
  143. def test_expires_defaults_to_config(self):
  144. self.app.conf.CELERY_TASK_RESULT_EXPIRES = 10
  145. b = self.Backend(expires=None, app=self.app, new_join=True)
  146. self.assertEqual(b.expires, 10)
  147. def test_expires_is_int(self):
  148. b = self.Backend(expires=48, app=self.app, new_join=True)
  149. self.assertEqual(b.expires, 48)
  150. def test_set_new_join_from_url_query(self):
  151. b = self.Backend('redis://?new_join=True;foobar=1', app=self.app)
  152. self.assertEqual(b.on_chord_part_return, b._new_chord_return)
  153. self.assertEqual(b.apply_chord, b._new_chord_apply)
  154. def test_default_is_old_join(self):
  155. b = self.Backend(app=self.app)
  156. self.assertNotEqual(b.on_chord_part_return, b._new_chord_return)
  157. self.assertNotEqual(b.apply_chord, b._new_chord_apply)
  158. def test_expires_is_None(self):
  159. b = self.Backend(expires=None, app=self.app, new_join=True)
  160. self.assertEqual(b.expires, timedelta_seconds(
  161. self.app.conf.CELERY_TASK_RESULT_EXPIRES))
  162. def test_expires_is_timedelta(self):
  163. b = self.Backend(
  164. expires=timedelta(minutes=1), app=self.app, new_join=1,
  165. )
  166. self.assertEqual(b.expires, 60)
  167. def test_apply_chord(self):
  168. self.Backend(app=self.app, new_join=True).apply_chord(
  169. group(app=self.app), (), 'group_id', {},
  170. result=[self.app.AsyncResult(x) for x in [1, 2, 3]],
  171. )
  172. def test_mget(self):
  173. b = self.Backend(app=self.app, new_join=True)
  174. self.assertTrue(b.mget(['a', 'b', 'c']))
  175. b.client.mget.assert_called_with(['a', 'b', 'c'])
  176. def test_set_no_expire(self):
  177. b = self.Backend(app=self.app, new_join=True)
  178. b.expires = None
  179. b.set('foo', 'bar')
  180. @patch('celery.result.GroupResult.restore')
  181. def test_on_chord_part_return(self, restore):
  182. b = self.Backend(app=self.app, new_join=True)
  183. def create_task():
  184. tid = uuid()
  185. task = Mock(name='task-{0}'.format(tid))
  186. task.name = 'foobarbaz'
  187. self.app.tasks['foobarbaz'] = task
  188. task.request.chord = signature(task)
  189. task.request.id = tid
  190. task.request.chord['chord_size'] = 10
  191. task.request.group = 'group_id'
  192. return task
  193. tasks = [create_task() for i in range(10)]
  194. for i in range(10):
  195. b.on_chord_part_return(tasks[i], states.SUCCESS, i)
  196. self.assertTrue(b.client.rpush.call_count)
  197. b.client.rpush.reset_mock()
  198. self.assertTrue(b.client.lrange.call_count)
  199. gkey = b.get_key_for_group('group_id', '.j')
  200. b.client.delete.assert_called_with(gkey)
  201. b.client.expire.assert_called_with(gkey, 86400)
  202. def test_process_cleanup(self):
  203. self.Backend(app=self.app, new_join=True).process_cleanup()
  204. def test_get_set_forget(self):
  205. b = self.Backend(app=self.app, new_join=True)
  206. tid = uuid()
  207. b.store_result(tid, 42, states.SUCCESS)
  208. self.assertEqual(b.get_status(tid), states.SUCCESS)
  209. self.assertEqual(b.get_result(tid), 42)
  210. b.forget(tid)
  211. self.assertEqual(b.get_status(tid), states.PENDING)
  212. def test_set_expires(self):
  213. b = self.Backend(expires=512, app=self.app, new_join=True)
  214. tid = uuid()
  215. key = b.get_key_for_task(tid)
  216. b.store_result(tid, 42, states.SUCCESS)
  217. b.client.expire.assert_called_with(
  218. key, 512,
  219. )