dynamodb.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # -*- coding: utf-8 -*-
  2. """AWS DynamoDB result store backend."""
  3. from __future__ import absolute_import, unicode_literals
  4. from collections import namedtuple
  5. from time import time, sleep
  6. from kombu.utils.url import _parse_url as parse_url
  7. from celery.exceptions import ImproperlyConfigured
  8. from celery.utils.log import get_logger
  9. from celery.five import string
  10. from .base import KeyValueStoreBackend
  11. try:
  12. import boto3
  13. from botocore.exceptions import ClientError
  14. except ImportError: # pragma: no cover
  15. boto3 = ClientError = None # noqa
  16. __all__ = ['DynamoDBBackend']
  17. # Helper class that describes a DynamoDB attribute
  18. DynamoDBAttribute = namedtuple('DynamoDBAttribute', ('name', 'data_type'))
  19. logger = get_logger(__name__)
  20. class DynamoDBBackend(KeyValueStoreBackend):
  21. """AWS DynamoDB result backend.
  22. Raises:
  23. celery.exceptions.ImproperlyConfigured:
  24. if module :pypi:`boto3` is not available.
  25. """
  26. #: default DynamoDB table name (`default`)
  27. table_name = 'celery'
  28. #: Read Provisioned Throughput (`default`)
  29. read_capacity_units = 1
  30. #: Write Provisioned Throughput (`default`)
  31. write_capacity_units = 1
  32. #: AWS region (`default`)
  33. aws_region = None
  34. #: The endpoint URL that is passed to boto3 (local DynamoDB) (`default`)
  35. endpoint_url = None
  36. _key_field = DynamoDBAttribute(name='id', data_type='S')
  37. _value_field = DynamoDBAttribute(name='result', data_type='B')
  38. _timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
  39. _available_fields = None
  40. def __init__(self, url=None, table_name=None, *args, **kwargs):
  41. super(DynamoDBBackend, self).__init__(*args, **kwargs)
  42. self.url = url
  43. self.table_name = table_name or self.table_name
  44. if not boto3:
  45. raise ImproperlyConfigured(
  46. 'You need to install the boto3 library to use the '
  47. 'DynamoDB backend.')
  48. aws_credentials_given = False
  49. aws_access_key_id = None
  50. aws_secret_access_key = None
  51. if url is not None:
  52. scheme, region, port, username, password, table, query = \
  53. parse_url(url)
  54. aws_access_key_id = username
  55. aws_secret_access_key = password
  56. access_key_given = aws_access_key_id is not None
  57. secret_key_given = aws_secret_access_key is not None
  58. if access_key_given != secret_key_given:
  59. raise ImproperlyConfigured(
  60. 'You need to specify both the Access Key ID '
  61. 'and Secret.')
  62. aws_credentials_given = access_key_given
  63. if region == 'localhost':
  64. # We are using the downloadable, local version of DynamoDB
  65. self.endpoint_url = 'http://localhost:{}'.format(port)
  66. self.aws_region = 'us-east-1'
  67. logger.warning(
  68. 'Using local-only DynamoDB endpoint URL: {}'.format(
  69. self.endpoint_url
  70. )
  71. )
  72. else:
  73. self.aws_region = region
  74. self.read_capacity_units = int(
  75. query.get(
  76. 'read',
  77. self.read_capacity_units
  78. )
  79. )
  80. self.write_capacity_units = int(
  81. query.get(
  82. 'write',
  83. self.write_capacity_units
  84. )
  85. )
  86. self.table_name = table or self.table_name
  87. self._available_fields = (
  88. self._key_field,
  89. self._value_field,
  90. self._timestamp_field
  91. )
  92. self._client = None
  93. if aws_credentials_given:
  94. self._get_client(
  95. access_key_id=aws_access_key_id,
  96. secret_access_key=aws_secret_access_key
  97. )
  98. def _get_client(self, access_key_id=None, secret_access_key=None):
  99. """Get client connection."""
  100. if self._client is None:
  101. client_parameters = dict(
  102. region_name=self.aws_region
  103. )
  104. if access_key_id is not None:
  105. client_parameters.update(dict(
  106. aws_access_key_id=access_key_id,
  107. aws_secret_access_key=secret_access_key
  108. ))
  109. if self.endpoint_url is not None:
  110. client_parameters['endpoint_url'] = self.endpoint_url
  111. self._client = boto3.client(
  112. 'dynamodb',
  113. **client_parameters
  114. )
  115. self._get_or_create_table()
  116. return self._client
  117. def _get_table_schema(self):
  118. """Get the boto3 structure describing the DynamoDB table schema."""
  119. return dict(
  120. AttributeDefinitions=[
  121. {
  122. 'AttributeName': self._key_field.name,
  123. 'AttributeType': self._key_field.data_type
  124. }
  125. ],
  126. TableName=self.table_name,
  127. KeySchema=[
  128. {
  129. 'AttributeName': self._key_field.name,
  130. 'KeyType': 'HASH'
  131. }
  132. ],
  133. ProvisionedThroughput={
  134. 'ReadCapacityUnits': self.read_capacity_units,
  135. 'WriteCapacityUnits': self.write_capacity_units
  136. }
  137. )
  138. def _get_or_create_table(self):
  139. """Create table if not exists, otherwise return the description."""
  140. table_schema = self._get_table_schema()
  141. try:
  142. table_description = self._client.create_table(**table_schema)
  143. logger.info(
  144. 'DynamoDB Table {} did not exist, creating.'.format(
  145. self.table_name
  146. )
  147. )
  148. # In case we created the table, wait until it becomes available.
  149. self._wait_for_table_status('ACTIVE')
  150. logger.info(
  151. 'DynamoDB Table {} is now available.'.format(
  152. self.table_name
  153. )
  154. )
  155. return table_description
  156. except ClientError as e:
  157. error_code = e.response['Error'].get('Code', 'Unknown')
  158. # If table exists, do not fail, just return the description.
  159. if error_code == 'ResourceInUseException':
  160. return self._client.describe_table(
  161. TableName=self.table_name
  162. )
  163. else:
  164. raise e
  165. def _wait_for_table_status(self, expected='ACTIVE'):
  166. """Poll for the expected table status."""
  167. achieved_state = False
  168. while not achieved_state:
  169. table_description = self.client.describe_table(
  170. TableName=self.table_name
  171. )
  172. logger.debug(
  173. 'Waiting for DynamoDB table {} to become {}.'.format(
  174. self.table_name,
  175. expected
  176. )
  177. )
  178. current_status = table_description['Table']['TableStatus']
  179. achieved_state = current_status == expected
  180. sleep(1)
  181. def _prepare_get_request(self, key):
  182. """Construct the item retrieval request parameters."""
  183. return dict(
  184. TableName=self.table_name,
  185. Key={
  186. self._key_field.name: {
  187. self._key_field.data_type: key
  188. }
  189. }
  190. )
  191. def _prepare_put_request(self, key, value):
  192. """Construct the item creation request parameters."""
  193. return dict(
  194. TableName=self.table_name,
  195. Item={
  196. self._key_field.name: {
  197. self._key_field.data_type: key
  198. },
  199. self._value_field.name: {
  200. self._value_field.data_type: value
  201. },
  202. self._timestamp_field.name: {
  203. self._timestamp_field.data_type: str(time())
  204. }
  205. }
  206. )
  207. def _item_to_dict(self, raw_response):
  208. """Convert get_item() response to field-value pairs."""
  209. if 'Item' not in raw_response:
  210. return {}
  211. return {
  212. field.name: raw_response['Item'][field.name][field.data_type]
  213. for field in self._available_fields
  214. }
  215. @property
  216. def client(self):
  217. return self._get_client()
  218. def get(self, key):
  219. key = string(key)
  220. request_parameters = self._prepare_get_request(key)
  221. item_response = self.client.get_item(**request_parameters)
  222. item = self._item_to_dict(item_response)
  223. return item.get(self._value_field.name)
  224. def set(self, key, value):
  225. key = string(key)
  226. request_parameters = self._prepare_put_request(key, value)
  227. self.client.put_item(**request_parameters)
  228. def mget(self, keys):
  229. return [self.get(key) for key in keys]
  230. def delete(self, key):
  231. key = string(key)
  232. request_parameters = self._prepare_get_request(key)
  233. self.client.delete_item(**request_parameters)