# test_redis.py (13 KB)

  1. import pytest
  2. from datetime import timedelta
  3. from contextlib import contextmanager
  4. from pickle import loads, dumps
  5. from case import ANY, ContextMock, Mock, mock, call, patch, skip
  6. from celery import signature
  7. from celery import states
  8. from celery import uuid
  9. from celery.canvas import Signature
  10. from celery.exceptions import ChordError, ImproperlyConfigured
  11. from celery.utils.collections import AttributeDict
  12. def raise_on_second_call(mock, exc, *retval):
  13. def on_first_call(*args, **kwargs):
  14. mock.side_effect = exc
  15. return mock.return_value
  16. mock.side_effect = on_first_call
  17. if retval:
  18. mock.return_value, = retval
  19. class Connection:
  20. connected = True
  21. def disconnect(self):
  22. self.connected = False
  23. class Pipeline:
  24. def __init__(self, client):
  25. self.client = client
  26. self.steps = []
  27. def __getattr__(self, attr):
  28. def add_step(*args, **kwargs):
  29. self.steps.append((getattr(self.client, attr), args, kwargs))
  30. return self
  31. return add_step
  32. def __enter__(self):
  33. return self
  34. def __exit__(self, type, value, traceback):
  35. pass
  36. def execute(self):
  37. return [step(*a, **kw) for step, a, kw in self.steps]
  38. class Redis(mock.MockCallbacks):
  39. Connection = Connection
  40. Pipeline = Pipeline
  41. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  42. self.host = host
  43. self.port = port
  44. self.db = db
  45. self.password = password
  46. self.keyspace = {}
  47. self.expiry = {}
  48. self.connection = self.Connection()
  49. def get(self, key):
  50. return self.keyspace.get(key)
  51. def setex(self, key, expires, value):
  52. self.set(key, value)
  53. self.expire(key, expires)
  54. def set(self, key, value):
  55. self.keyspace[key] = value
  56. def expire(self, key, expires):
  57. self.expiry[key] = expires
  58. return expires
  59. def delete(self, key):
  60. return bool(self.keyspace.pop(key, None))
  61. def pipeline(self):
  62. return self.Pipeline(self)
  63. def _get_list(self, key):
  64. try:
  65. return self.keyspace[key]
  66. except KeyError:
  67. l = self.keyspace[key] = []
  68. return l
  69. def rpush(self, key, value):
  70. self._get_list(key).append(value)
  71. def lrange(self, key, start, stop):
  72. return self._get_list(key)[start:stop]
  73. def llen(self, key):
  74. return len(self.keyspace.get(key) or [])
  75. class redis:
  76. StrictRedis = Redis
  77. class ConnectionPool:
  78. def __init__(self, **kwargs):
  79. pass
  80. class UnixDomainSocketConnection:
  81. def __init__(self, **kwargs):
  82. pass
  83. class test_RedisBackend:
  84. def get_backend(self):
  85. from celery.backends.redis import RedisBackend
  86. class _RedisBackend(RedisBackend):
  87. redis = redis
  88. return _RedisBackend
  89. def get_E_LOST(self):
  90. from celery.backends.redis import E_LOST
  91. return E_LOST
  92. def setup(self):
  93. self.Backend = self.get_backend()
  94. self.E_LOST = self.get_E_LOST()
  95. self.b = self.Backend(app=self.app)
  96. @pytest.mark.usefixtures('depends_on_current_app')
  97. @skip.unless_module('redis')
  98. def test_reduce(self):
  99. from celery.backends.redis import RedisBackend
  100. x = RedisBackend(app=self.app)
  101. assert loads(dumps(x))
  102. def test_no_redis(self):
  103. self.Backend.redis = None
  104. with pytest.raises(ImproperlyConfigured):
  105. self.Backend(app=self.app)
  106. def test_url(self):
  107. self.app.conf.redis_socket_timeout = 30.0
  108. self.app.conf.redis_socket_connect_timeout = 100.0
  109. x = self.Backend(
  110. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  111. )
  112. assert x.connparams
  113. assert x.connparams['host'] == 'vandelay.com'
  114. assert x.connparams['db'] == 1
  115. assert x.connparams['port'] == 123
  116. assert x.connparams['password'] == 'bosco'
  117. assert x.connparams['socket_timeout'] == 30.0
  118. assert x.connparams['socket_connect_timeout'] == 100.0
  119. def test_socket_url(self):
  120. self.app.conf.redis_socket_timeout = 30.0
  121. self.app.conf.redis_socket_connect_timeout = 100.0
  122. x = self.Backend(
  123. 'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
  124. )
  125. assert x.connparams
  126. assert x.connparams['path'] == '/tmp/redis.sock'
  127. assert (x.connparams['connection_class'] is
  128. redis.UnixDomainSocketConnection)
  129. assert 'host' not in x.connparams
  130. assert 'port' not in x.connparams
  131. assert x.connparams['socket_timeout'] == 30.0
  132. assert 'socket_connect_timeout' not in x.connparams
  133. assert x.connparams['db'] == 3
  134. def test_conf_raises_KeyError(self):
  135. self.app.conf = AttributeDict({
  136. 'result_serializer': 'json',
  137. 'result_cache_max': 1,
  138. 'result_expires': None,
  139. 'accept_content': ['json'],
  140. })
  141. self.Backend(app=self.app)
  142. @patch('celery.backends.redis.logger')
  143. def test_on_connection_error(self, logger):
  144. intervals = iter([10, 20, 30])
  145. exc = KeyError()
  146. assert self.b.on_connection_error(None, exc, intervals, 1) == 10
  147. logger.error.assert_called_with(
  148. self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
  149. assert self.b.on_connection_error(10, exc, intervals, 2) == 20
  150. logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
  151. assert self.b.on_connection_error(10, exc, intervals, 3) == 30
  152. logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
  153. def test_incr(self):
  154. self.b.client = Mock(name='client')
  155. self.b.incr('foo')
  156. self.b.client.incr.assert_called_with('foo')
  157. def test_expire(self):
  158. self.b.client = Mock(name='client')
  159. self.b.expire('foo', 300)
  160. self.b.client.expire.assert_called_with('foo', 300)
  161. def test_apply_chord(self):
  162. header = Mock(name='header')
  163. header.results = [Mock(name='t1'), Mock(name='t2')]
  164. self.b.apply_chord(
  165. header, (1, 2), 'gid', None,
  166. options={'max_retries': 10},
  167. )
  168. header.assert_called_with(1, 2, max_retries=10, task_id='gid')
  169. def test_unpack_chord_result(self):
  170. self.b.exception_to_python = Mock(name='etp')
  171. decode = Mock(name='decode')
  172. exc = KeyError()
  173. tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
  174. with pytest.raises(ChordError):
  175. self.b._unpack_chord_result(tup, decode)
  176. decode.assert_called_with(tup)
  177. self.b.exception_to_python.assert_called_with(exc)
  178. exc = ValueError()
  179. tup = decode.return_value = (2, 'id2', states.RETRY, exc)
  180. ret = self.b._unpack_chord_result(tup, decode)
  181. self.b.exception_to_python.assert_called_with(exc)
  182. assert ret is self.b.exception_to_python()
  183. def test_on_chord_part_return_no_gid_or_tid(self):
  184. request = Mock(name='request')
  185. request.id = request.group = None
  186. assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
  187. def test_ConnectionPool(self):
  188. self.b.redis = Mock(name='redis')
  189. assert self.b._ConnectionPool is None
  190. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  191. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  192. def test_expires_defaults_to_config(self):
  193. self.app.conf.result_expires = 10
  194. b = self.Backend(expires=None, app=self.app)
  195. assert b.expires == 10
  196. def test_expires_is_int(self):
  197. b = self.Backend(expires=48, app=self.app)
  198. assert b.expires == 48
  199. def test_add_to_chord(self):
  200. b = self.Backend('redis://', app=self.app)
  201. gid = uuid()
  202. b.add_to_chord(gid, 'sig')
  203. b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
  204. def test_expires_is_None(self):
  205. b = self.Backend(expires=None, app=self.app)
  206. assert b.expires == self.app.conf.result_expires.total_seconds()
  207. def test_expires_is_timedelta(self):
  208. b = self.Backend(expires=timedelta(minutes=1), app=self.app)
  209. assert b.expires == 60
  210. def test_mget(self):
  211. assert self.b.mget(['a', 'b', 'c'])
  212. self.b.client.mget.assert_called_with(['a', 'b', 'c'])
  213. def test_set_no_expire(self):
  214. self.b.expires = None
  215. self.b.set('foo', 'bar')
  216. def create_task(self):
  217. tid = uuid()
  218. task = Mock(name='task-{0}'.format(tid))
  219. task.name = 'foobarbaz'
  220. self.app.tasks['foobarbaz'] = task
  221. task.request.chord = signature(task)
  222. task.request.id = tid
  223. task.request.chord['chord_size'] = 10
  224. task.request.group = 'group_id'
  225. return task
  226. @patch('celery.result.GroupResult.restore')
  227. def test_on_chord_part_return(self, restore):
  228. tasks = [self.create_task() for i in range(10)]
  229. for i in range(10):
  230. self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
  231. assert self.b.client.rpush.call_count
  232. self.b.client.rpush.reset_mock()
  233. assert self.b.client.lrange.call_count
  234. jkey = self.b.get_key_for_group('group_id', '.j')
  235. tkey = self.b.get_key_for_group('group_id', '.t')
  236. self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
  237. self.b.client.expire.assert_has_calls([
  238. call(jkey, 86400), call(tkey, 86400),
  239. ])
  240. def test_on_chord_part_return__success(self):
  241. with self.chord_context(2) as (_, request, callback):
  242. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  243. callback.delay.assert_not_called()
  244. self.b.on_chord_part_return(request, states.SUCCESS, 20)
  245. callback.delay.assert_called_with([10, 20])
  246. def test_on_chord_part_return__callback_raises(self):
  247. with self.chord_context(1) as (_, request, callback):
  248. callback.delay.side_effect = KeyError(10)
  249. task = self.app._tasks['add'] = Mock(name='add_task')
  250. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  251. task.backend.fail_from_current_stack.assert_called_with(
  252. callback.id, exc=ANY,
  253. )
  254. def test_on_chord_part_return__ChordError(self):
  255. with self.chord_context(1) as (_, request, callback):
  256. self.b.client.pipeline = ContextMock()
  257. raise_on_second_call(self.b.client.pipeline, ChordError())
  258. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  259. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  260. task = self.app._tasks['add'] = Mock(name='add_task')
  261. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  262. task.backend.fail_from_current_stack.assert_called_with(
  263. callback.id, exc=ANY,
  264. )
  265. def test_on_chord_part_return__other_error(self):
  266. with self.chord_context(1) as (_, request, callback):
  267. self.b.client.pipeline = ContextMock()
  268. raise_on_second_call(self.b.client.pipeline, RuntimeError())
  269. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  270. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  271. task = self.app._tasks['add'] = Mock(name='add_task')
  272. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  273. task.backend.fail_from_current_stack.assert_called_with(
  274. callback.id, exc=ANY,
  275. )
  276. @contextmanager
  277. def chord_context(self, size=1):
  278. with patch('celery.backends.redis.maybe_signature') as ms:
  279. tasks = [self.create_task() for i in range(size)]
  280. request = Mock(name='request')
  281. request.id = 'id1'
  282. request.group = 'gid1'
  283. callback = ms.return_value = Signature('add')
  284. callback.id = 'id1'
  285. callback['chord_size'] = size
  286. callback.delay = Mock(name='callback.delay')
  287. yield tasks, request, callback
  288. def test_process_cleanup(self):
  289. self.b.process_cleanup()
  290. def test_get_set_forget(self):
  291. tid = uuid()
  292. self.b.store_result(tid, 42, states.SUCCESS)
  293. assert self.b.get_state(tid) == states.SUCCESS
  294. assert self.b.get_result(tid) == 42
  295. self.b.forget(tid)
  296. assert self.b.get_state(tid) == states.PENDING
  297. def test_set_expires(self):
  298. self.b = self.Backend(expires=512, app=self.app)
  299. tid = uuid()
  300. key = self.b.get_key_for_task(tid)
  301. self.b.store_result(tid, 42, states.SUCCESS)
  302. self.b.client.expire.assert_called_with(
  303. key, 512,
  304. )