test_redis.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. from __future__ import absolute_import, unicode_literals
  2. import pytest
  3. import ssl
  4. from datetime import timedelta
  5. from contextlib import contextmanager
  6. from pickle import loads, dumps
  7. from case import ANY, ContextMock, Mock, mock, call, patch, skip
  8. from celery import signature
  9. from celery import states
  10. from celery import uuid
  11. from celery.canvas import Signature
  12. from celery.exceptions import (
  13. ChordError, CPendingDeprecationWarning, ImproperlyConfigured,
  14. )
  15. from celery.utils.collections import AttributeDict
  16. def raise_on_second_call(mock, exc, *retval):
  17. def on_first_call(*args, **kwargs):
  18. mock.side_effect = exc
  19. return mock.return_value
  20. mock.side_effect = on_first_call
  21. if retval:
  22. mock.return_value, = retval
  23. class Connection(object):
  24. connected = True
  25. def disconnect(self):
  26. self.connected = False
  27. class Pipeline(object):
  28. def __init__(self, client):
  29. self.client = client
  30. self.steps = []
  31. def __getattr__(self, attr):
  32. def add_step(*args, **kwargs):
  33. self.steps.append((getattr(self.client, attr), args, kwargs))
  34. return self
  35. return add_step
  36. def __enter__(self):
  37. return self
  38. def __exit__(self, type, value, traceback):
  39. pass
  40. def execute(self):
  41. return [step(*a, **kw) for step, a, kw in self.steps]
  42. class Redis(mock.MockCallbacks):
  43. Connection = Connection
  44. Pipeline = Pipeline
  45. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  46. self.host = host
  47. self.port = port
  48. self.db = db
  49. self.password = password
  50. self.keyspace = {}
  51. self.expiry = {}
  52. self.connection = self.Connection()
  53. def get(self, key):
  54. return self.keyspace.get(key)
  55. def setex(self, key, expires, value):
  56. self.set(key, value)
  57. self.expire(key, expires)
  58. def set(self, key, value):
  59. self.keyspace[key] = value
  60. def expire(self, key, expires):
  61. self.expiry[key] = expires
  62. return expires
  63. def delete(self, key):
  64. return bool(self.keyspace.pop(key, None))
  65. def pipeline(self):
  66. return self.Pipeline(self)
  67. def _get_list(self, key):
  68. try:
  69. return self.keyspace[key]
  70. except KeyError:
  71. l = self.keyspace[key] = []
  72. return l
  73. def rpush(self, key, value):
  74. self._get_list(key).append(value)
  75. def lrange(self, key, start, stop):
  76. return self._get_list(key)[start:stop]
  77. def llen(self, key):
  78. return len(self.keyspace.get(key) or [])
  79. class redis(object):
  80. StrictRedis = Redis
  81. class ConnectionPool(object):
  82. def __init__(self, **kwargs):
  83. pass
  84. class UnixDomainSocketConnection(object):
  85. def __init__(self, **kwargs):
  86. pass
  87. class test_RedisBackend:
  88. def get_backend(self):
  89. from celery.backends.redis import RedisBackend
  90. class _RedisBackend(RedisBackend):
  91. redis = redis
  92. return _RedisBackend
  93. def get_E_LOST(self):
  94. from celery.backends.redis import E_LOST
  95. return E_LOST
  96. def setup(self):
  97. self.Backend = self.get_backend()
  98. self.E_LOST = self.get_E_LOST()
  99. self.b = self.Backend(app=self.app)
  100. @pytest.mark.usefixtures('depends_on_current_app')
  101. @skip.unless_module('redis')
  102. def test_reduce(self):
  103. from celery.backends.redis import RedisBackend
  104. x = RedisBackend(app=self.app)
  105. assert loads(dumps(x))
  106. def test_no_redis(self):
  107. self.Backend.redis = None
  108. with pytest.raises(ImproperlyConfigured):
  109. self.Backend(app=self.app)
  110. def test_url(self):
  111. self.app.conf.redis_socket_timeout = 30.0
  112. self.app.conf.redis_socket_connect_timeout = 100.0
  113. x = self.Backend(
  114. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  115. )
  116. assert x.connparams
  117. assert x.connparams['host'] == 'vandelay.com'
  118. assert x.connparams['db'] == 1
  119. assert x.connparams['port'] == 123
  120. assert x.connparams['password'] == 'bosco'
  121. assert x.connparams['socket_timeout'] == 30.0
  122. assert x.connparams['socket_connect_timeout'] == 100.0
  123. def test_socket_url(self):
  124. self.app.conf.redis_socket_timeout = 30.0
  125. self.app.conf.redis_socket_connect_timeout = 100.0
  126. x = self.Backend(
  127. 'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
  128. )
  129. assert x.connparams
  130. assert x.connparams['path'] == '/tmp/redis.sock'
  131. assert (x.connparams['connection_class'] is
  132. redis.UnixDomainSocketConnection)
  133. assert 'host' not in x.connparams
  134. assert 'port' not in x.connparams
  135. assert x.connparams['socket_timeout'] == 30.0
  136. assert 'socket_connect_timeout' not in x.connparams
  137. assert x.connparams['db'] == 3
  138. @skip.unless_module('redis')
  139. def test_backend_ssl(self):
  140. self.app.conf.redis_backend_use_ssl = {
  141. 'ssl_cert_reqs': ssl.CERT_REQUIRED,
  142. 'ssl_ca_certs': '/path/to/ca.crt',
  143. 'ssl_certfile': '/path/to/client.crt',
  144. 'ssl_keyfile': '/path/to/client.key',
  145. }
  146. self.app.conf.redis_socket_timeout = 30.0
  147. self.app.conf.redis_socket_connect_timeout = 100.0
  148. x = self.Backend(
  149. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  150. )
  151. assert x.connparams
  152. assert x.connparams['host'] == 'vandelay.com'
  153. assert x.connparams['db'] == 1
  154. assert x.connparams['port'] == 123
  155. assert x.connparams['password'] == 'bosco'
  156. assert x.connparams['socket_timeout'] == 30.0
  157. assert x.connparams['socket_connect_timeout'] == 100.0
  158. assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
  159. assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
  160. assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
  161. assert x.connparams['ssl_keyfile'] == '/path/to/client.key'
  162. from redis.connection import SSLConnection
  163. assert x.connparams['connection_class'] is SSLConnection
  164. def test_compat_propertie(self):
  165. x = self.Backend(
  166. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  167. )
  168. with pytest.warns(CPendingDeprecationWarning):
  169. assert x.host == 'vandelay.com'
  170. with pytest.warns(CPendingDeprecationWarning):
  171. assert x.db == 1
  172. with pytest.warns(CPendingDeprecationWarning):
  173. assert x.port == 123
  174. with pytest.warns(CPendingDeprecationWarning):
  175. assert x.password == 'bosco'
  176. def test_conf_raises_KeyError(self):
  177. self.app.conf = AttributeDict({
  178. 'result_serializer': 'json',
  179. 'result_cache_max': 1,
  180. 'result_expires': None,
  181. 'accept_content': ['json'],
  182. })
  183. self.Backend(app=self.app)
  184. @patch('celery.backends.redis.logger')
  185. def test_on_connection_error(self, logger):
  186. intervals = iter([10, 20, 30])
  187. exc = KeyError()
  188. assert self.b.on_connection_error(None, exc, intervals, 1) == 10
  189. logger.error.assert_called_with(
  190. self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
  191. assert self.b.on_connection_error(10, exc, intervals, 2) == 20
  192. logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
  193. assert self.b.on_connection_error(10, exc, intervals, 3) == 30
  194. logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
  195. def test_incr(self):
  196. self.b.client = Mock(name='client')
  197. self.b.incr('foo')
  198. self.b.client.incr.assert_called_with('foo')
  199. def test_expire(self):
  200. self.b.client = Mock(name='client')
  201. self.b.expire('foo', 300)
  202. self.b.client.expire.assert_called_with('foo', 300)
  203. def test_apply_chord(self):
  204. header = Mock(name='header')
  205. header.results = [Mock(name='t1'), Mock(name='t2')]
  206. self.b.apply_chord(
  207. header, (1, 2), 'gid', None,
  208. options={'max_retries': 10},
  209. )
  210. header.assert_called_with(1, 2, max_retries=10, task_id='gid')
  211. def test_unpack_chord_result(self):
  212. self.b.exception_to_python = Mock(name='etp')
  213. decode = Mock(name='decode')
  214. exc = KeyError()
  215. tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
  216. with pytest.raises(ChordError):
  217. self.b._unpack_chord_result(tup, decode)
  218. decode.assert_called_with(tup)
  219. self.b.exception_to_python.assert_called_with(exc)
  220. exc = ValueError()
  221. tup = decode.return_value = (2, 'id2', states.RETRY, exc)
  222. ret = self.b._unpack_chord_result(tup, decode)
  223. self.b.exception_to_python.assert_called_with(exc)
  224. assert ret is self.b.exception_to_python()
  225. def test_on_chord_part_return_no_gid_or_tid(self):
  226. request = Mock(name='request')
  227. request.id = request.group = None
  228. assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
  229. def test_ConnectionPool(self):
  230. self.b.redis = Mock(name='redis')
  231. assert self.b._ConnectionPool is None
  232. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  233. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  234. def test_expires_defaults_to_config(self):
  235. self.app.conf.result_expires = 10
  236. b = self.Backend(expires=None, app=self.app)
  237. assert b.expires == 10
  238. def test_expires_is_int(self):
  239. b = self.Backend(expires=48, app=self.app)
  240. assert b.expires == 48
  241. def test_add_to_chord(self):
  242. b = self.Backend('redis://', app=self.app)
  243. gid = uuid()
  244. b.add_to_chord(gid, 'sig')
  245. b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
  246. def test_expires_is_None(self):
  247. b = self.Backend(expires=None, app=self.app)
  248. assert b.expires == self.app.conf.result_expires.total_seconds()
  249. def test_expires_is_timedelta(self):
  250. b = self.Backend(expires=timedelta(minutes=1), app=self.app)
  251. assert b.expires == 60
  252. def test_mget(self):
  253. assert self.b.mget(['a', 'b', 'c'])
  254. self.b.client.mget.assert_called_with(['a', 'b', 'c'])
  255. def test_set_no_expire(self):
  256. self.b.expires = None
  257. self.b.set('foo', 'bar')
  258. def create_task(self):
  259. tid = uuid()
  260. task = Mock(name='task-{0}'.format(tid))
  261. task.name = 'foobarbaz'
  262. self.app.tasks['foobarbaz'] = task
  263. task.request.chord = signature(task)
  264. task.request.id = tid
  265. task.request.chord['chord_size'] = 10
  266. task.request.group = 'group_id'
  267. return task
  268. @patch('celery.result.GroupResult.restore')
  269. def test_on_chord_part_return(self, restore):
  270. tasks = [self.create_task() for i in range(10)]
  271. for i in range(10):
  272. self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
  273. assert self.b.client.rpush.call_count
  274. self.b.client.rpush.reset_mock()
  275. assert self.b.client.lrange.call_count
  276. jkey = self.b.get_key_for_group('group_id', '.j')
  277. tkey = self.b.get_key_for_group('group_id', '.t')
  278. self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
  279. self.b.client.expire.assert_has_calls([
  280. call(jkey, 86400), call(tkey, 86400),
  281. ])
  282. def test_on_chord_part_return__success(self):
  283. with self.chord_context(2) as (_, request, callback):
  284. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  285. callback.delay.assert_not_called()
  286. self.b.on_chord_part_return(request, states.SUCCESS, 20)
  287. callback.delay.assert_called_with([10, 20])
  288. def test_on_chord_part_return__callback_raises(self):
  289. with self.chord_context(1) as (_, request, callback):
  290. callback.delay.side_effect = KeyError(10)
  291. task = self.app._tasks['add'] = Mock(name='add_task')
  292. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  293. task.backend.fail_from_current_stack.assert_called_with(
  294. callback.id, exc=ANY,
  295. )
  296. def test_on_chord_part_return__ChordError(self):
  297. with self.chord_context(1) as (_, request, callback):
  298. self.b.client.pipeline = ContextMock()
  299. raise_on_second_call(self.b.client.pipeline, ChordError())
  300. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  301. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  302. task = self.app._tasks['add'] = Mock(name='add_task')
  303. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  304. task.backend.fail_from_current_stack.assert_called_with(
  305. callback.id, exc=ANY,
  306. )
  307. def test_on_chord_part_return__other_error(self):
  308. with self.chord_context(1) as (_, request, callback):
  309. self.b.client.pipeline = ContextMock()
  310. raise_on_second_call(self.b.client.pipeline, RuntimeError())
  311. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  312. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  313. task = self.app._tasks['add'] = Mock(name='add_task')
  314. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  315. task.backend.fail_from_current_stack.assert_called_with(
  316. callback.id, exc=ANY,
  317. )
  318. @contextmanager
  319. def chord_context(self, size=1):
  320. with patch('celery.backends.redis.maybe_signature') as ms:
  321. tasks = [self.create_task() for i in range(size)]
  322. request = Mock(name='request')
  323. request.id = 'id1'
  324. request.group = 'gid1'
  325. callback = ms.return_value = Signature('add')
  326. callback.id = 'id1'
  327. callback['chord_size'] = size
  328. callback.delay = Mock(name='callback.delay')
  329. yield tasks, request, callback
  330. def test_process_cleanup(self):
  331. self.b.process_cleanup()
  332. def test_get_set_forget(self):
  333. tid = uuid()
  334. self.b.store_result(tid, 42, states.SUCCESS)
  335. assert self.b.get_state(tid) == states.SUCCESS
  336. assert self.b.get_result(tid) == 42
  337. self.b.forget(tid)
  338. assert self.b.get_state(tid) == states.PENDING
  339. def test_set_expires(self):
  340. self.b = self.Backend(expires=512, app=self.app)
  341. tid = uuid()
  342. key = self.b.get_key_for_task(tid)
  343. self.b.store_result(tid, 42, states.SUCCESS)
  344. self.b.client.expire.assert_called_with(
  345. key, 512,
  346. )