redis.py

# -*- coding: utf-8 -*-
"""Redis result store backend."""
from __future__ import absolute_import, unicode_literals

from functools import partial

from kombu.utils.functional import retry_over_time
from kombu.utils.objects import cached_property
from kombu.utils.url import _parse_url

from celery import states
from celery._state import task_join_will_block
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.five import string_t
from celery.utils import deprecated
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds

from . import async
from . import base

try:
    import redis
    from kombu.transport.redis import get_redis_error_classes
except ImportError:  # pragma: no cover
    redis = None  # noqa
    get_redis_error_classes = None  # noqa

try:
    from redis import sentinel
except ImportError:
    sentinel = None

__all__ = ('RedisBackend', 'SentinelBackend')

E_REDIS_MISSING = """
You need to install the redis library in order to use \
the Redis result store backend.
"""

E_REDIS_SENTINEL_MISSING = """
You need to install the redis library with support of \
sentinel in order to use the Redis result store backend.
"""

E_LOST = 'Connection to Redis lost: Retry (%s/%s) %s.'

logger = get_logger(__name__)


class ResultConsumer(async.BaseResultConsumer):
    _pubsub = None

    def __init__(self, *args, **kwargs):
        super(ResultConsumer, self).__init__(*args, **kwargs)
        self._get_key_for_task = self.backend.get_key_for_task
        self._decode_result = self.backend.decode_result
        self.subscribed_to = set()

    def start(self, initial_task_id, **kwargs):
        self._pubsub = self.backend.client.pubsub(
            ignore_subscribe_messages=True,
        )
        self._consume_from(initial_task_id)

    def on_wait_for_pending(self, result, **kwargs):
        for meta in result._iter_meta():
            if meta is not None:
                self.on_state_change(meta, None)

    def stop(self):
        if self._pubsub is not None:
            self._pubsub.close()

    def drain_events(self, timeout=None):
        m = self._pubsub.get_message(timeout=timeout)
        if m and m['type'] == 'message':
            self.on_state_change(self._decode_result(m['data']), m)

    def consume_from(self, task_id):
        if self._pubsub is None:
            return self.start(task_id)
        self._consume_from(task_id)

    def _consume_from(self, task_id):
        key = self._get_key_for_task(task_id)
        if key not in self.subscribed_to:
            self.subscribed_to.add(key)
            self._pubsub.subscribe(key)

    def cancel_for(self, task_id):
        if self._pubsub:
            key = self._get_key_for_task(task_id)
            self.subscribed_to.discard(key)
            self._pubsub.unsubscribe(key)
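
# Note on the scheme used above: the pub/sub channel name is the task's
# result key itself (backend.get_key_for_task), so RedisBackend._set()
# below publishes each stored result on the very channel this consumer is
# subscribed to, and drain_events() decodes the published payload directly
# instead of re-reading the key.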


class RedisBackend(base.BaseKeyValueStoreBackend, async.AsyncBackendMixin):
    """Redis task result store."""

    ResultConsumer = ResultConsumer

    #: :pypi:`redis` client module.
    redis = redis

    #: Maximum number of connections in the pool.
    max_connections = None

    supports_autoexpire = True
    supports_native_join = True

    def __init__(self, host=None, port=None, db=None, password=None,
                 max_connections=None, url=None,
                 connection_pool=None, **kwargs):
        super(RedisBackend, self).__init__(expires_type=int, **kwargs)
        _get = self.app.conf.get
        if self.redis is None:
            raise ImproperlyConfigured(E_REDIS_MISSING.strip())

        if host and '://' in host:
            url, host = host, None

        self.max_connections = (
            max_connections or
            _get('redis_max_connections') or
            self.max_connections)
        self._ConnectionPool = connection_pool
        socket_timeout = _get('redis_socket_timeout')
        socket_connect_timeout = _get('redis_socket_connect_timeout')

        self.connparams = {
            'host': _get('redis_host') or 'localhost',
            'port': _get('redis_port') or 6379,
            'db': _get('redis_db') or 0,
            'password': _get('redis_password'),
            'max_connections': self.max_connections,
            'socket_timeout': socket_timeout and float(socket_timeout),
            'socket_connect_timeout':
                socket_connect_timeout and float(socket_connect_timeout),
        }

        # "redis_backend_use_ssl" must be a dict with the keys:
        # 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile'
        # (the same as "broker_use_ssl")
        ssl = _get('redis_backend_use_ssl')
        if ssl:
            self.connparams.update(ssl)
            self.connparams['connection_class'] = redis.SSLConnection

        if url:
            self.connparams = self._params_from_url(url, self.connparams)
        self.url = url

        self.connection_errors, self.channel_errors = (
            get_redis_error_classes() if get_redis_error_classes
            else ((), ()))
        self.result_consumer = self.ResultConsumer(
            self, self.app, self.accept,
            self._pending_results, self._pending_messages,
        )
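
    # A sketch of how the parameters above resolve (hypothetical values):
    # an explicit URL wins over the redis_* settings, which in turn win
    # over the defaults, e.g. url='redis://:s3cret@redis.example.com:6380/1'
    # yields host='redis.example.com', port=6380, password='s3cret', db=1,
    # while a bare 'redis://' falls back to localhost:6379, db 0.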

    def _params_from_url(self, url, defaults):
        scheme, host, port, _, password, path, query = _parse_url(url)
        connparams = dict(
            defaults, **dictfilter({
                'host': host, 'port': port, 'password': password,
                'db': query.pop('virtual_host', None)})
        )

        if scheme == 'socket':
            # use 'path' as path to the socket… in this case
            # the database number should be given in 'query'
            connparams.update({
                'connection_class': self.redis.UnixDomainSocketConnection,
                'path': '/' + path,
            })

            # host+port are invalid options when using this connection type.
            connparams.pop('host', None)
            connparams.pop('port', None)
            connparams.pop('socket_connect_timeout')
        else:
            connparams['db'] = path

        # db may be string and start with / like in kombu.
        db = connparams.get('db') or 0
        db = db.strip('/') if isinstance(db, string_t) else db
        connparams['db'] = int(db)

        # Query parameters override other parameters
        connparams.update(query)
        return connparams
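
    # The two URL shapes handled above, with hypothetical values:
    #
    #   redis://:password@host:6379/0
    #       TCP connection; the URL path ('0') becomes the database number.
    #   socket:///var/run/redis/redis.sock?virtual_host=1
    #       UnixDomainSocketConnection; the path is the socket, so the
    #       database number travels in the 'virtual_host' query parameter.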

    def on_task_call(self, producer, task_id):
        if not task_join_will_block():
            self.result_consumer.consume_from(task_id)

    def get(self, key):
        return self.client.get(key)

    def mget(self, keys):
        return self.client.mget(keys)

    def ensure(self, fun, args, **policy):
        retry_policy = dict(self.retry_policy, **policy)
        max_retries = retry_policy.get('max_retries')
        return retry_over_time(
            fun, self.connection_errors, args, {},
            partial(self.on_connection_error, max_retries),
            **retry_policy)

    def on_connection_error(self, max_retries, exc, intervals, retries):
        tts = next(intervals)
        logger.error(
            E_LOST.strip(),
            retries, max_retries or 'Inf', humanize_seconds(tts, 'in '))
        return tts
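
    # on_connection_error() doubles as the errback for retry_over_time():
    # it pulls the next sleep interval from the 'intervals' iterator, logs
    # the E_LOST message, and returns that interval so retry_over_time()
    # knows how long to sleep before calling 'fun' again.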

    def set(self, key, value, **retry_policy):
        return self.ensure(self._set, (key, value), **retry_policy)

    def _set(self, key, value):
        with self.client.pipeline() as pipe:
            if self.expires:
                pipe.setex(key, self.expires, value)
            else:
                pipe.set(key, value)
            pipe.publish(key, value)
            pipe.execute()
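
    # The pipeline batches SETEX/SET and PUBLISH into a single request, so
    # the result is both written and announced to any subscribed
    # ResultConsumer in one round trip to the server.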

    def delete(self, key):
        self.client.delete(key)

    def incr(self, key):
        return self.client.incr(key)

    def expire(self, key, value):
        return self.client.expire(key, value)

    def add_to_chord(self, group_id, result):
        self.client.incr(self.get_key_for_group(group_id, '.t'), 1)

    def _unpack_chord_result(self, tup, decode,
                             EXCEPTION_STATES=states.EXCEPTION_STATES,
                             PROPAGATE_STATES=states.PROPAGATE_STATES):
        _, tid, state, retval = decode(tup)
        if state in EXCEPTION_STATES:
            retval = self.exception_to_python(retval)
        if state in PROPAGATE_STATES:
            raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
        return retval

    def apply_chord(self, header, partial_args, group_id, body,
                    result=None, options={}, **kwargs):
        # Overrides this to avoid calling GroupResult.save
        # pylint: disable=method-hidden
        # Note that KeyValueStoreBackend.__init__ sets self.apply_chord
        # if the implements_incr attr is set.  Redis backend doesn't set
        # this flag.
        options['task_id'] = group_id
        return header(*partial_args, **options or {})

    def on_chord_part_return(self, request, state, result,
                             propagate=None, **kwargs):
        app = self.app
        tid, gid = request.id, request.group
        if not gid or not tid:
            return
        client = self.client
        jkey = self.get_key_for_group(gid, '.j')
        tkey = self.get_key_for_group(gid, '.t')
        result = self.encode_result(result, state)
        with client.pipeline() as pipe:
            _, readycount, totaldiff, _, _ = pipe \
                .rpush(jkey, self.encode([1, tid, state, result])) \
                .llen(jkey) \
                .get(tkey) \
                .expire(jkey, self.expires) \
                .expire(tkey, self.expires) \
                .execute()

        totaldiff = int(totaldiff or 0)

        try:
            callback = maybe_signature(request.chord, app=app)
            total = callback['chord_size'] + totaldiff
            if readycount == total:
                decode, unpack = self.decode, self._unpack_chord_result
                with client.pipeline() as pipe:
                    resl, _, _ = pipe \
                        .lrange(jkey, 0, total) \
                        .delete(jkey) \
                        .delete(tkey) \
                        .execute()
                try:
                    callback.delay([unpack(tup, decode) for tup in resl])
                except Exception as exc:  # pylint: disable=broad-except
                    logger.exception(
                        'Chord callback for %r raised: %r', request.group, exc)
                    return self.chord_error_from_stack(
                        callback,
                        ChordError('Callback error: {0!r}'.format(exc)),
                    )
        except ChordError as exc:
            logger.exception('Chord %r raised: %r', request.group, exc)
            return self.chord_error_from_stack(callback, exc)
        except Exception as exc:  # pylint: disable=broad-except
            logger.exception('Chord %r raised: %r', request.group, exc)
            return self.chord_error_from_stack(
                callback,
                ChordError('Join error: {0!r}'.format(exc)),
            )
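
    # Chord bookkeeping, as implemented above: every finished header task
    # RPUSHes its encoded (tid, state, result) tuple onto the group's '.j'
    # list, while the '.t' counter tracks tasks added late via
    # add_to_chord().  Once LLEN('.j') reaches chord_size + int('.t'), the
    # results are LRANGE'd, both keys are deleted, and the body callback
    # is dispatched with the joined results.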

    def _create_client(self, **params):
        return self._get_client()(
            connection_pool=self._get_pool(**params),
        )

    def _get_client(self):
        return self.redis.StrictRedis

    def _get_pool(self, **params):
        return self.ConnectionPool(**params)

    @property
    def ConnectionPool(self):
        if self._ConnectionPool is None:
            self._ConnectionPool = self.redis.ConnectionPool
        return self._ConnectionPool

    @cached_property
    def client(self):
        return self._create_client(**self.connparams)

    def __reduce__(self, args=(), kwargs={}):
        return super(RedisBackend, self).__reduce__(
            (self.url,), {'expires': self.expires},
        )

    @deprecated.Property(4.0, 5.0)
    def host(self):
        return self.connparams['host']

    @deprecated.Property(4.0, 5.0)
    def port(self):
        return self.connparams['port']

    @deprecated.Property(4.0, 5.0)
    def db(self):
        return self.connparams['db']

    @deprecated.Property(4.0, 5.0)
    def password(self):
        return self.connparams['password']


class SentinelBackend(RedisBackend):
    """Redis sentinel task result store."""

    sentinel = sentinel

    def __init__(self, *args, **kwargs):
        if self.sentinel is None:
            raise ImproperlyConfigured(E_REDIS_SENTINEL_MISSING.strip())

        super(SentinelBackend, self).__init__(*args, **kwargs)

    def _params_from_url(self, url, defaults):
        # URL looks like sentinel://0.0.0.0:26347/3;sentinel://0.0.0.0:26348/3.
        chunks = url.split(";")
        connparams = dict(defaults, hosts=[])
        for chunk in chunks:
            data = super(SentinelBackend, self)._params_from_url(
                url=chunk, defaults=defaults)
            connparams['hosts'].append(data)
        for p in ("host", "port", "db", "password"):
            connparams.pop(p)
        # Adding db/password in connparams to connect to the correct instance
        for p in ("db", "password"):
            if connparams['hosts'] and p in connparams['hosts'][0]:
                connparams[p] = connparams['hosts'][0].get(p)
        return connparams
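
    # With the example URL above, the resulting connparams carry no
    # top-level host/port: 'hosts' holds one parsed dict per ';'-separated
    # chunk, and db/password are copied back from the first host, i.e.
    # every listed sentinel is assumed to share them.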

    def _get_sentinel_instance(self, **params):
        connparams = params.copy()

        hosts = connparams.pop("hosts")
        result_backend_transport_opts = self.app.conf.get(
            "result_backend_transport_options", {})
        min_other_sentinels = result_backend_transport_opts.get(
            "min_other_sentinels", 0)
        sentinel_kwargs = result_backend_transport_opts.get(
            "sentinel_kwargs", {})

        sentinel_instance = self.sentinel.Sentinel(
            [(cp['host'], cp['port']) for cp in hosts],
            min_other_sentinels=min_other_sentinels,
            sentinel_kwargs=sentinel_kwargs,
            **connparams)

        return sentinel_instance

    def _get_pool(self, **params):
        sentinel_instance = self._get_sentinel_instance(**params)

        result_backend_transport_opts = self.app.conf.get(
            "result_backend_transport_options", {})
        master_name = result_backend_transport_opts.get("master_name", None)

        return sentinel_instance.master_for(
            service_name=master_name,
            redis_class=self._get_client(),
        ).connection_pool
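
# A minimal configuration sketch for this backend (hypothetical host names;
# the option keys are the ones read by the two methods above):
#
#     app.conf.result_backend = (
#         'sentinel://sentinel-1:26379/0;sentinel://sentinel-2:26379/0')
#     app.conf.result_backend_transport_options = {
#         'master_name': 'mymaster',    # service name passed to master_for()
#         'min_other_sentinels': 0,     # optional, defaults to 0
#         'sentinel_kwargs': {},        # optional, passed to Sentinel()
#     }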