# test_redis.py (18 KB)

from __future__ import absolute_import, unicode_literals

import random
import ssl
from contextlib import contextmanager
from datetime import timedelta
from pickle import dumps, loads

import pytest
from case import ANY, ContextMock, Mock, call, mock, patch, skip

from celery import signature
from celery import states
from celery import uuid
from celery.canvas import Signature
from celery.exceptions import (
    ChordError, CPendingDeprecationWarning, ImproperlyConfigured,
)
from celery.utils.collections import AttributeDict
  17. def raise_on_second_call(mock, exc, *retval):
  18. def on_first_call(*args, **kwargs):
  19. mock.side_effect = exc
  20. return mock.return_value
  21. mock.side_effect = on_first_call
  22. if retval:
  23. mock.return_value, = retval
  24. class Connection(object):
  25. connected = True
  26. def disconnect(self):
  27. self.connected = False
  28. class Pipeline(object):
  29. def __init__(self, client):
  30. self.client = client
  31. self.steps = []
  32. def __getattr__(self, attr):
  33. def add_step(*args, **kwargs):
  34. self.steps.append((getattr(self.client, attr), args, kwargs))
  35. return self
  36. return add_step
  37. def __enter__(self):
  38. return self
  39. def __exit__(self, type, value, traceback):
  40. pass
  41. def execute(self):
  42. return [step(*a, **kw) for step, a, kw in self.steps]
  43. class Redis(mock.MockCallbacks):
  44. Connection = Connection
  45. Pipeline = Pipeline
  46. def __init__(self, host=None, port=None, db=None, password=None, **kw):
  47. self.host = host
  48. self.port = port
  49. self.db = db
  50. self.password = password
  51. self.keyspace = {}
  52. self.expiry = {}
  53. self.connection = self.Connection()
  54. def get(self, key):
  55. return self.keyspace.get(key)
  56. def setex(self, key, expires, value):
  57. self.set(key, value)
  58. self.expire(key, expires)
  59. def set(self, key, value):
  60. self.keyspace[key] = value
  61. def expire(self, key, expires):
  62. self.expiry[key] = expires
  63. return expires
  64. def delete(self, key):
  65. return bool(self.keyspace.pop(key, None))
  66. def pipeline(self):
  67. return self.Pipeline(self)
  68. def _get_list(self, key):
  69. try:
  70. return self.keyspace[key]
  71. except KeyError:
  72. l = self.keyspace[key] = []
  73. return l
  74. def rpush(self, key, value):
  75. self._get_list(key).append(value)
  76. def lrange(self, key, start, stop):
  77. return self._get_list(key)[start:stop]
  78. def llen(self, key):
  79. return len(self.keyspace.get(key) or [])
  80. class Sentinel(mock.MockCallbacks):
  81. def __init__(self, sentinels, min_other_sentinels=0, sentinel_kwargs=None,
  82. **connection_kwargs):
  83. self.sentinel_kwargs = sentinel_kwargs
  84. self.sentinels = [Redis(hostname, port, **self.sentinel_kwargs)
  85. for hostname, port in sentinels]
  86. self.min_other_sentinels = min_other_sentinels
  87. self.connection_kwargs = connection_kwargs
  88. def master_for(self, service_name, redis_class):
  89. return random.choice(self.sentinels)
  90. class redis(object):
  91. StrictRedis = Redis
  92. class ConnectionPool(object):
  93. def __init__(self, **kwargs):
  94. pass
  95. class UnixDomainSocketConnection(object):
  96. def __init__(self, **kwargs):
  97. pass
  98. class sentinel(object):
  99. Sentinel = Sentinel
  100. class test_RedisBackend:
  101. def get_backend(self):
  102. from celery.backends.redis import RedisBackend
  103. class _RedisBackend(RedisBackend):
  104. redis = redis
  105. return _RedisBackend
  106. def get_E_LOST(self):
  107. from celery.backends.redis import E_LOST
  108. return E_LOST
  109. def setup(self):
  110. self.Backend = self.get_backend()
  111. self.E_LOST = self.get_E_LOST()
  112. self.b = self.Backend(app=self.app)
  113. @pytest.mark.usefixtures('depends_on_current_app')
  114. @skip.unless_module('redis')
  115. def test_reduce(self):
  116. from celery.backends.redis import RedisBackend
  117. x = RedisBackend(app=self.app)
  118. assert loads(dumps(x))
  119. def test_no_redis(self):
  120. self.Backend.redis = None
  121. with pytest.raises(ImproperlyConfigured):
  122. self.Backend(app=self.app)
  123. def test_url(self):
  124. self.app.conf.redis_socket_timeout = 30.0
  125. self.app.conf.redis_socket_connect_timeout = 100.0
  126. x = self.Backend(
  127. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  128. )
  129. assert x.connparams
  130. assert x.connparams['host'] == 'vandelay.com'
  131. assert x.connparams['db'] == 1
  132. assert x.connparams['port'] == 123
  133. assert x.connparams['password'] == 'bosco'
  134. assert x.connparams['socket_timeout'] == 30.0
  135. assert x.connparams['socket_connect_timeout'] == 100.0
  136. def test_socket_url(self):
  137. self.app.conf.redis_socket_timeout = 30.0
  138. self.app.conf.redis_socket_connect_timeout = 100.0
  139. x = self.Backend(
  140. 'socket:///tmp/redis.sock?virtual_host=/3', app=self.app,
  141. )
  142. assert x.connparams
  143. assert x.connparams['path'] == '/tmp/redis.sock'
  144. assert (x.connparams['connection_class'] is
  145. redis.UnixDomainSocketConnection)
  146. assert 'host' not in x.connparams
  147. assert 'port' not in x.connparams
  148. assert x.connparams['socket_timeout'] == 30.0
  149. assert 'socket_connect_timeout' not in x.connparams
  150. assert x.connparams['db'] == 3
  151. @skip.unless_module('redis')
  152. def test_backend_ssl(self):
  153. self.app.conf.redis_backend_use_ssl = {
  154. 'ssl_cert_reqs': ssl.CERT_REQUIRED,
  155. 'ssl_ca_certs': '/path/to/ca.crt',
  156. 'ssl_certfile': '/path/to/client.crt',
  157. 'ssl_keyfile': '/path/to/client.key',
  158. }
  159. self.app.conf.redis_socket_timeout = 30.0
  160. self.app.conf.redis_socket_connect_timeout = 100.0
  161. x = self.Backend(
  162. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  163. )
  164. assert x.connparams
  165. assert x.connparams['host'] == 'vandelay.com'
  166. assert x.connparams['db'] == 1
  167. assert x.connparams['port'] == 123
  168. assert x.connparams['password'] == 'bosco'
  169. assert x.connparams['socket_timeout'] == 30.0
  170. assert x.connparams['socket_connect_timeout'] == 100.0
  171. assert x.connparams['ssl_cert_reqs'] == ssl.CERT_REQUIRED
  172. assert x.connparams['ssl_ca_certs'] == '/path/to/ca.crt'
  173. assert x.connparams['ssl_certfile'] == '/path/to/client.crt'
  174. assert x.connparams['ssl_keyfile'] == '/path/to/client.key'
  175. from redis.connection import SSLConnection
  176. assert x.connparams['connection_class'] is SSLConnection
  177. def test_compat_propertie(self):
  178. x = self.Backend(
  179. 'redis://:bosco@vandelay.com:123//1', app=self.app,
  180. )
  181. with pytest.warns(CPendingDeprecationWarning):
  182. assert x.host == 'vandelay.com'
  183. with pytest.warns(CPendingDeprecationWarning):
  184. assert x.db == 1
  185. with pytest.warns(CPendingDeprecationWarning):
  186. assert x.port == 123
  187. with pytest.warns(CPendingDeprecationWarning):
  188. assert x.password == 'bosco'
  189. def test_conf_raises_KeyError(self):
  190. self.app.conf = AttributeDict({
  191. 'result_serializer': 'json',
  192. 'result_cache_max': 1,
  193. 'result_expires': None,
  194. 'accept_content': ['json'],
  195. })
  196. self.Backend(app=self.app)
  197. @patch('celery.backends.redis.logger')
  198. def test_on_connection_error(self, logger):
  199. intervals = iter([10, 20, 30])
  200. exc = KeyError()
  201. assert self.b.on_connection_error(None, exc, intervals, 1) == 10
  202. logger.error.assert_called_with(
  203. self.E_LOST, 1, 'Inf', 'in 10.00 seconds')
  204. assert self.b.on_connection_error(10, exc, intervals, 2) == 20
  205. logger.error.assert_called_with(self.E_LOST, 2, 10, 'in 20.00 seconds')
  206. assert self.b.on_connection_error(10, exc, intervals, 3) == 30
  207. logger.error.assert_called_with(self.E_LOST, 3, 10, 'in 30.00 seconds')
  208. def test_incr(self):
  209. self.b.client = Mock(name='client')
  210. self.b.incr('foo')
  211. self.b.client.incr.assert_called_with('foo')
  212. def test_expire(self):
  213. self.b.client = Mock(name='client')
  214. self.b.expire('foo', 300)
  215. self.b.client.expire.assert_called_with('foo', 300)
  216. def test_apply_chord(self):
  217. header = Mock(name='header')
  218. header.results = [Mock(name='t1'), Mock(name='t2')]
  219. self.b.apply_chord(
  220. header, (1, 2), 'gid', None,
  221. options={'max_retries': 10},
  222. )
  223. header.assert_called_with(1, 2, max_retries=10, task_id='gid')
  224. def test_unpack_chord_result(self):
  225. self.b.exception_to_python = Mock(name='etp')
  226. decode = Mock(name='decode')
  227. exc = KeyError()
  228. tup = decode.return_value = (1, 'id1', states.FAILURE, exc)
  229. with pytest.raises(ChordError):
  230. self.b._unpack_chord_result(tup, decode)
  231. decode.assert_called_with(tup)
  232. self.b.exception_to_python.assert_called_with(exc)
  233. exc = ValueError()
  234. tup = decode.return_value = (2, 'id2', states.RETRY, exc)
  235. ret = self.b._unpack_chord_result(tup, decode)
  236. self.b.exception_to_python.assert_called_with(exc)
  237. assert ret is self.b.exception_to_python()
  238. def test_on_chord_part_return_no_gid_or_tid(self):
  239. request = Mock(name='request')
  240. request.id = request.group = None
  241. assert self.b.on_chord_part_return(request, 'SUCCESS', 10) is None
  242. def test_ConnectionPool(self):
  243. self.b.redis = Mock(name='redis')
  244. assert self.b._ConnectionPool is None
  245. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  246. assert self.b.ConnectionPool is self.b.redis.ConnectionPool
  247. def test_expires_defaults_to_config(self):
  248. self.app.conf.result_expires = 10
  249. b = self.Backend(expires=None, app=self.app)
  250. assert b.expires == 10
  251. def test_expires_is_int(self):
  252. b = self.Backend(expires=48, app=self.app)
  253. assert b.expires == 48
  254. def test_add_to_chord(self):
  255. b = self.Backend('redis://', app=self.app)
  256. gid = uuid()
  257. b.add_to_chord(gid, 'sig')
  258. b.client.incr.assert_called_with(b.get_key_for_group(gid, '.t'), 1)
  259. def test_expires_is_None(self):
  260. b = self.Backend(expires=None, app=self.app)
  261. assert b.expires == self.app.conf.result_expires.total_seconds()
  262. def test_expires_is_timedelta(self):
  263. b = self.Backend(expires=timedelta(minutes=1), app=self.app)
  264. assert b.expires == 60
  265. def test_mget(self):
  266. assert self.b.mget(['a', 'b', 'c'])
  267. self.b.client.mget.assert_called_with(['a', 'b', 'c'])
  268. def test_set_no_expire(self):
  269. self.b.expires = None
  270. self.b.set('foo', 'bar')
  271. def create_task(self):
  272. tid = uuid()
  273. task = Mock(name='task-{0}'.format(tid))
  274. task.name = 'foobarbaz'
  275. self.app.tasks['foobarbaz'] = task
  276. task.request.chord = signature(task)
  277. task.request.id = tid
  278. task.request.chord['chord_size'] = 10
  279. task.request.group = 'group_id'
  280. return task
  281. @patch('celery.result.GroupResult.restore')
  282. def test_on_chord_part_return(self, restore):
  283. tasks = [self.create_task() for i in range(10)]
  284. for i in range(10):
  285. self.b.on_chord_part_return(tasks[i].request, states.SUCCESS, i)
  286. assert self.b.client.rpush.call_count
  287. self.b.client.rpush.reset_mock()
  288. assert self.b.client.lrange.call_count
  289. jkey = self.b.get_key_for_group('group_id', '.j')
  290. tkey = self.b.get_key_for_group('group_id', '.t')
  291. self.b.client.delete.assert_has_calls([call(jkey), call(tkey)])
  292. self.b.client.expire.assert_has_calls([
  293. call(jkey, 86400), call(tkey, 86400),
  294. ])
  295. def test_on_chord_part_return__success(self):
  296. with self.chord_context(2) as (_, request, callback):
  297. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  298. callback.delay.assert_not_called()
  299. self.b.on_chord_part_return(request, states.SUCCESS, 20)
  300. callback.delay.assert_called_with([10, 20])
  301. def test_on_chord_part_return__callback_raises(self):
  302. with self.chord_context(1) as (_, request, callback):
  303. callback.delay.side_effect = KeyError(10)
  304. task = self.app._tasks['add'] = Mock(name='add_task')
  305. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  306. task.backend.fail_from_current_stack.assert_called_with(
  307. callback.id, exc=ANY,
  308. )
  309. def test_on_chord_part_return__ChordError(self):
  310. with self.chord_context(1) as (_, request, callback):
  311. self.b.client.pipeline = ContextMock()
  312. raise_on_second_call(self.b.client.pipeline, ChordError())
  313. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  314. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  315. task = self.app._tasks['add'] = Mock(name='add_task')
  316. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  317. task.backend.fail_from_current_stack.assert_called_with(
  318. callback.id, exc=ANY,
  319. )
  320. def test_on_chord_part_return__other_error(self):
  321. with self.chord_context(1) as (_, request, callback):
  322. self.b.client.pipeline = ContextMock()
  323. raise_on_second_call(self.b.client.pipeline, RuntimeError())
  324. self.b.client.pipeline.return_value.rpush().llen().get().expire(
  325. ).expire().execute.return_value = (1, 1, 0, 4, 5)
  326. task = self.app._tasks['add'] = Mock(name='add_task')
  327. self.b.on_chord_part_return(request, states.SUCCESS, 10)
  328. task.backend.fail_from_current_stack.assert_called_with(
  329. callback.id, exc=ANY,
  330. )
  331. @contextmanager
  332. def chord_context(self, size=1):
  333. with patch('celery.backends.redis.maybe_signature') as ms:
  334. tasks = [self.create_task() for i in range(size)]
  335. request = Mock(name='request')
  336. request.id = 'id1'
  337. request.group = 'gid1'
  338. callback = ms.return_value = Signature('add')
  339. callback.id = 'id1'
  340. callback['chord_size'] = size
  341. callback.delay = Mock(name='callback.delay')
  342. yield tasks, request, callback
  343. def test_process_cleanup(self):
  344. self.b.process_cleanup()
  345. def test_get_set_forget(self):
  346. tid = uuid()
  347. self.b.store_result(tid, 42, states.SUCCESS)
  348. assert self.b.get_state(tid) == states.SUCCESS
  349. assert self.b.get_result(tid) == 42
  350. self.b.forget(tid)
  351. assert self.b.get_state(tid) == states.PENDING
  352. def test_set_expires(self):
  353. self.b = self.Backend(expires=512, app=self.app)
  354. tid = uuid()
  355. key = self.b.get_key_for_task(tid)
  356. self.b.store_result(tid, 42, states.SUCCESS)
  357. self.b.client.expire.assert_called_with(
  358. key, 512,
  359. )
  360. class test_SentinelBackend:
  361. def get_backend(self):
  362. from celery.backends.redis import SentinelBackend
  363. class _SentinelBackend(SentinelBackend):
  364. redis = redis
  365. sentinel = sentinel
  366. return _SentinelBackend
  367. def get_E_LOST(self):
  368. from celery.backends.redis import E_LOST
  369. return E_LOST
  370. def setup(self):
  371. self.Backend = self.get_backend()
  372. self.E_LOST = self.get_E_LOST()
  373. self.b = self.Backend(app=self.app)
  374. @pytest.mark.usefixtures('depends_on_current_app')
  375. @skip.unless_module('redis')
  376. def test_reduce(self):
  377. from celery.backends.redis import SentinelBackend
  378. x = SentinelBackend(app=self.app)
  379. assert loads(dumps(x))
  380. def test_no_redis(self):
  381. self.Backend.redis = None
  382. with pytest.raises(ImproperlyConfigured):
  383. self.Backend(app=self.app)
  384. def test_url(self):
  385. self.app.conf.redis_socket_timeout = 30.0
  386. self.app.conf.redis_socket_connect_timeout = 100.0
  387. x = self.Backend(
  388. 'sentinel://:test@github.com:123/1;'
  389. 'sentinel://:test@github.com:124/1',
  390. app=self.app,
  391. )
  392. assert x.connparams
  393. assert "host" not in x.connparams
  394. assert x.connparams['db'] == 1
  395. assert "port" not in x.connparams
  396. assert x.connparams['password'] == "test"
  397. assert len(x.connparams['hosts']) == 2
  398. expected_hosts = ["github.com", "github.com"]
  399. found_hosts = [cp['host'] for cp in x.connparams['hosts']]
  400. assert found_hosts == expected_hosts
  401. expected_ports = [123, 124]
  402. found_ports = [cp['port'] for cp in x.connparams['hosts']]
  403. assert found_ports == expected_ports
  404. expected_passwords = ["test", "test"]
  405. found_passwords = [cp['password'] for cp in x.connparams['hosts']]
  406. assert found_passwords == expected_passwords
  407. expected_dbs = [1, 1]
  408. found_dbs = [cp['db'] for cp in x.connparams['hosts']]
  409. assert found_dbs == expected_dbs
  410. def test_get_sentinel_instance(self):
  411. x = self.Backend(
  412. 'sentinel://:test@github.com:123/1;'
  413. 'sentinel://:test@github.com:124/1',
  414. app=self.app,
  415. )
  416. sentinel_instance = x._get_sentinel_instance(**x.connparams)
  417. assert sentinel_instance.sentinel_kwargs == {}
  418. assert sentinel_instance.connection_kwargs['db'] == 1
  419. assert sentinel_instance.connection_kwargs['password'] == "test"
  420. assert len(sentinel_instance.sentinels) == 2
  421. def test_get_pool(self):
  422. x = self.Backend(
  423. 'sentinel://:test@github.com:123/1;'
  424. 'sentinel://:test@github.com:124/1',
  425. app=self.app,
  426. )
  427. pool = x._get_pool(**x.connparams)
  428. assert pool