test_cache.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from __future__ import absolute_import, unicode_literals
  2. import sys
  3. import types
  4. from contextlib import contextmanager
  5. import pytest
  6. from case import Mock, mock, patch, skip
  7. from kombu.utils.encoding import ensure_bytes, str_to_bytes
  8. from celery import signature, states, uuid
  9. from celery.backends.cache import CacheBackend, DummyClient, backends
  10. from celery.exceptions import ImproperlyConfigured
  11. from celery.five import PY3, bytes_if_py2, items, string, text_t
  12. class SomeClass(object):
  13. def __init__(self, data):
  14. self.data = data
  15. class test_CacheBackend:
  16. def setup(self):
  17. self.app.conf.result_serializer = 'pickle'
  18. self.tb = CacheBackend(backend='memory://', app=self.app)
  19. self.tid = uuid()
  20. self.old_get_best_memcached = backends['memcache']
  21. backends['memcache'] = lambda: (DummyClient, ensure_bytes)
  22. def teardown(self):
  23. backends['memcache'] = self.old_get_best_memcached
  24. def test_no_backend(self):
  25. self.app.conf.cache_backend = None
  26. with pytest.raises(ImproperlyConfigured):
  27. CacheBackend(backend=None, app=self.app)
  28. def test_mark_as_done(self):
  29. assert self.tb.get_state(self.tid) == states.PENDING
  30. assert self.tb.get_result(self.tid) is None
  31. self.tb.mark_as_done(self.tid, 42)
  32. assert self.tb.get_state(self.tid) == states.SUCCESS
  33. assert self.tb.get_result(self.tid) == 42
  34. def test_is_pickled(self):
  35. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  36. self.tb.mark_as_done(self.tid, result)
  37. # is serialized properly.
  38. rindb = self.tb.get_result(self.tid)
  39. assert rindb.get('foo') == 'baz'
  40. assert rindb.get('bar').data == 12345
  41. def test_mark_as_failure(self):
  42. try:
  43. raise KeyError('foo')
  44. except KeyError as exception:
  45. self.tb.mark_as_failure(self.tid, exception)
  46. assert self.tb.get_state(self.tid) == states.FAILURE
  47. assert isinstance(self.tb.get_result(self.tid), KeyError)
  48. def test_apply_chord(self):
  49. tb = CacheBackend(backend='memory://', app=self.app)
  50. result = self.app.GroupResult(
  51. uuid(),
  52. [self.app.AsyncResult(uuid()) for _ in range(3)],
  53. )
  54. tb.apply_chord(result, None)
  55. assert self.app.GroupResult.restore(result.id, backend=tb) == result
  56. @patch('celery.result.GroupResult.restore')
  57. def test_on_chord_part_return(self, restore):
  58. tb = CacheBackend(backend='memory://', app=self.app)
  59. deps = Mock()
  60. deps.__len__ = Mock()
  61. deps.__len__.return_value = 2
  62. restore.return_value = deps
  63. task = Mock()
  64. task.name = 'foobarbaz'
  65. self.app.tasks['foobarbaz'] = task
  66. task.request.chord = signature(task)
  67. result = self.app.GroupResult(
  68. uuid(),
  69. [self.app.AsyncResult(uuid()) for _ in range(3)],
  70. )
  71. task.request.group = result.id
  72. tb.apply_chord(result, None)
  73. deps.join_native.assert_not_called()
  74. tb.on_chord_part_return(task.request, 'SUCCESS', 10)
  75. deps.join_native.assert_not_called()
  76. tb.on_chord_part_return(task.request, 'SUCCESS', 10)
  77. deps.join_native.assert_called_with(propagate=True, timeout=3.0)
  78. deps.delete.assert_called_with()
  79. def test_mget(self):
  80. self.tb.set('foo', 1)
  81. self.tb.set('bar', 2)
  82. assert self.tb.mget(['foo', 'bar']) == {'foo': 1, 'bar': 2}
  83. def test_forget(self):
  84. self.tb.mark_as_done(self.tid, {'foo': 'bar'})
  85. x = self.app.AsyncResult(self.tid, backend=self.tb)
  86. x.forget()
  87. assert x.result is None
  88. def test_process_cleanup(self):
  89. self.tb.process_cleanup()
  90. def test_expires_as_int(self):
  91. tb = CacheBackend(backend='memory://', expires=10, app=self.app)
  92. assert tb.expires == 10
  93. def test_unknown_backend_raises_ImproperlyConfigured(self):
  94. with pytest.raises(ImproperlyConfigured):
  95. CacheBackend(backend='unknown://', app=self.app)
  96. def test_as_uri_no_servers(self):
  97. assert self.tb.as_uri() == 'memory:///'
  98. def test_as_uri_one_server(self):
  99. backend = 'memcache://127.0.0.1:11211/'
  100. b = CacheBackend(backend=backend, app=self.app)
  101. assert b.as_uri() == backend
  102. def test_as_uri_multiple_servers(self):
  103. backend = 'memcache://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
  104. b = CacheBackend(backend=backend, app=self.app)
  105. assert b.as_uri() == backend
  106. @skip.unless_module('memcached', name='python-memcached')
  107. def test_regression_worker_startup_info(self):
  108. self.app.conf.result_backend = (
  109. 'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
  110. )
  111. worker = self.app.Worker()
  112. with mock.stdouts():
  113. worker.on_start()
  114. assert worker.startup_info()
  115. class MyMemcachedStringEncodingError(Exception):
  116. pass
  117. class MemcachedClient(DummyClient):
  118. def set(self, key, value, *args, **kwargs):
  119. if PY3:
  120. key_t, must_be, not_be, cod = bytes, 'string', 'bytes', 'decode'
  121. else:
  122. key_t, must_be, not_be, cod = text_t, 'bytes', 'string', 'encode'
  123. if isinstance(key, key_t):
  124. raise MyMemcachedStringEncodingError(
  125. 'Keys must be {0}, not {1}. Convert your '
  126. 'strings using mystring.{2}(charset)!'.format(
  127. must_be, not_be, cod))
  128. return super(MemcachedClient, self).set(key, value, *args, **kwargs)
  129. class MockCacheMixin(object):
  130. @contextmanager
  131. def mock_memcache(self):
  132. memcache = types.ModuleType(bytes_if_py2('memcache'))
  133. memcache.Client = MemcachedClient
  134. memcache.Client.__module__ = memcache.__name__
  135. prev, sys.modules['memcache'] = sys.modules.get('memcache'), memcache
  136. try:
  137. yield True
  138. finally:
  139. if prev is not None:
  140. sys.modules['memcache'] = prev
  141. @contextmanager
  142. def mock_pylibmc(self):
  143. pylibmc = types.ModuleType(bytes_if_py2('pylibmc'))
  144. pylibmc.Client = MemcachedClient
  145. pylibmc.Client.__module__ = pylibmc.__name__
  146. prev = sys.modules.get('pylibmc')
  147. sys.modules['pylibmc'] = pylibmc
  148. try:
  149. yield True
  150. finally:
  151. if prev is not None:
  152. sys.modules['pylibmc'] = prev
  153. class test_get_best_memcache(MockCacheMixin):
  154. def test_pylibmc(self):
  155. with self.mock_pylibmc():
  156. with mock.reset_modules('celery.backends.cache'):
  157. from celery.backends import cache
  158. cache._imp = [None]
  159. assert cache.get_best_memcache()[0].__module__ == 'pylibmc'
  160. def test_memcache(self):
  161. with self.mock_memcache():
  162. with mock.reset_modules('celery.backends.cache'):
  163. with mock.mask_modules('pylibmc'):
  164. from celery.backends import cache
  165. cache._imp = [None]
  166. assert (cache.get_best_memcache()[0]().__module__ ==
  167. 'memcache')
  168. def test_no_implementations(self):
  169. with mock.mask_modules('pylibmc', 'memcache'):
  170. with mock.reset_modules('celery.backends.cache'):
  171. from celery.backends import cache
  172. cache._imp = [None]
  173. with pytest.raises(ImproperlyConfigured):
  174. cache.get_best_memcache()
  175. def test_cached(self):
  176. with self.mock_pylibmc():
  177. with mock.reset_modules('celery.backends.cache'):
  178. from celery.backends import cache
  179. cache._imp = [None]
  180. cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
  181. assert cache._imp[0]
  182. cache.get_best_memcache()[0]()
  183. def test_backends(self):
  184. from celery.backends.cache import backends
  185. with self.mock_memcache():
  186. for name, fun in items(backends):
  187. assert fun()
  188. class test_memcache_key(MockCacheMixin):
  189. def test_memcache_unicode_key(self):
  190. with self.mock_memcache():
  191. with mock.reset_modules('celery.backends.cache'):
  192. with mock.mask_modules('pylibmc'):
  193. from celery.backends import cache
  194. cache._imp = [None]
  195. task_id, result = string(uuid()), 42
  196. b = cache.CacheBackend(backend='memcache', app=self.app)
  197. b.store_result(task_id, result, state=states.SUCCESS)
  198. assert b.get_result(task_id) == result
  199. def test_memcache_bytes_key(self):
  200. with self.mock_memcache():
  201. with mock.reset_modules('celery.backends.cache'):
  202. with mock.mask_modules('pylibmc'):
  203. from celery.backends import cache
  204. cache._imp = [None]
  205. task_id, result = str_to_bytes(uuid()), 42
  206. b = cache.CacheBackend(backend='memcache', app=self.app)
  207. b.store_result(task_id, result, state=states.SUCCESS)
  208. assert b.get_result(task_id) == result
  209. def test_pylibmc_unicode_key(self):
  210. with mock.reset_modules('celery.backends.cache'):
  211. with self.mock_pylibmc():
  212. from celery.backends import cache
  213. cache._imp = [None]
  214. task_id, result = string(uuid()), 42
  215. b = cache.CacheBackend(backend='memcache', app=self.app)
  216. b.store_result(task_id, result, state=states.SUCCESS)
  217. assert b.get_result(task_id) == result
  218. def test_pylibmc_bytes_key(self):
  219. with mock.reset_modules('celery.backends.cache'):
  220. with self.mock_pylibmc():
  221. from celery.backends import cache
  222. cache._imp = [None]
  223. task_id, result = str_to_bytes(uuid()), 42
  224. b = cache.CacheBackend(backend='memcache', app=self.app)
  225. b.store_result(task_id, result, state=states.SUCCESS)
  226. assert b.get_result(task_id) == result