# pyredis.py (3.3 KB)

  1. from datetime import timedelta
  2. from kombu.utils import cached_property
  3. from celery.backends.base import KeyValueStoreBackend
  4. from celery.exceptions import ImproperlyConfigured
  5. from celery.result import TaskSetResult
  6. from celery.task.sets import subtask
  7. from celery.utils import timeutils
  8. try:
  9. import redis
  10. from redis.exceptions import ConnectionError
  11. except ImportError:
  12. redis = None
  13. ConnectionError = None
  14. class RedisBackend(KeyValueStoreBackend):
  15. """Redis task result store."""
  16. #: redis-py client module.
  17. redis = redis
  18. #: default Redis server hostname (`localhost`).
  19. redis_host = "localhost"
  20. #: default Redis server port (6379)
  21. redis_port = 6379
  22. redis_db = 0
  23. #: default Redis password (:const:`None`)
  24. redis_password = None
  25. def __init__(self, redis_host=None, redis_port=None, redis_db=None,
  26. redis_password=None,
  27. expires=None, **kwargs):
  28. super(RedisBackend, self).__init__(**kwargs)
  29. if self.redis is None:
  30. raise ImproperlyConfigured(
  31. "You need to install the redis library in order to use "
  32. + "Redis result store backend.")
  33. self.redis_host = (redis_host or
  34. self.app.conf.get("REDIS_HOST") or
  35. self.redis_host)
  36. self.redis_port = (redis_port or
  37. self.app.conf.get("REDIS_PORT") or
  38. self.redis_port)
  39. self.redis_db = (redis_db or
  40. self.app.conf.get("REDIS_DB") or
  41. self.redis_db)
  42. self.redis_password = (redis_password or
  43. self.app.conf.get("REDIS_PASSWORD") or
  44. self.redis_password)
  45. self.expires = expires
  46. if self.expires is None:
  47. self.expires = self.app.conf.CELERY_TASK_RESULT_EXPIRES
  48. if isinstance(self.expires, timedelta):
  49. self.expires = timeutils.timedelta_seconds(self.expires)
  50. if self.expires is not None:
  51. self.expires = int(self.expires)
  52. self.redis_port = int(self.redis_port)
  53. def get(self, key):
  54. return self.client.get(key)
  55. def set(self, key, value):
  56. client = self.client
  57. client.set(key, value)
  58. if self.expires is not None:
  59. client.expire(key, self.expires)
  60. def delete(self, key):
  61. self.client.delete(key)
  62. def close(self):
  63. """Closes the Redis connection."""
  64. del(self.client)
  65. def process_cleanup(self):
  66. self.close()
  67. def on_chord_apply(self, setid, body):
  68. pass
  69. def on_chord_part_return(self, task, keyprefix="chord-unlock-%s"):
  70. setid = task.request.taskset
  71. key = keyprefix % setid
  72. deps = TaskSetResult.restore(setid, backend=task.backend)
  73. if self.client.incr(key) >= deps.total:
  74. subtask(task.request.chord).delay(deps.join())
  75. self.client.expire(key, 86400)
  76. @cached_property
  77. def client(self):
  78. return self.redis.Redis(host=self.redis_host,
  79. port=self.redis_port,
  80. db=self.redis_db,
  81. password=self.redis_password)
  82. @client.deleter
  83. def client(self, client):
  84. client.connection.disconnect()