test_cache.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. from __future__ import absolute_import
  2. import sys
  3. import types
  4. from contextlib import contextmanager
  5. from kombu.utils.encoding import str_to_bytes, ensure_bytes
  6. from celery import signature
  7. from celery import states
  8. from celery import group
  9. from celery.backends.cache import CacheBackend, DummyClient, backends
  10. from celery.exceptions import ImproperlyConfigured
  11. from celery.five import items, string, text_t
  12. from celery.utils import uuid
  13. from celery.tests.case import (
  14. AppCase, Mock, disable_stdouts, mask_modules, patch, reset_modules,
  15. )
  16. PY3 = sys.version_info[0] == 3
  17. class SomeClass(object):
  18. def __init__(self, data):
  19. self.data = data
  20. class test_CacheBackend(AppCase):
  21. def setup(self):
  22. self.tb = CacheBackend(backend='memory://', app=self.app)
  23. self.tid = uuid()
  24. self.old_get_best_memcached = backends['memcache']
  25. backends['memcache'] = lambda: (DummyClient, ensure_bytes)
  26. def teardown(self):
  27. backends['memcache'] = self.old_get_best_memcached
  28. def test_no_backend(self):
  29. self.app.conf.CELERY_CACHE_BACKEND = None
  30. with self.assertRaises(ImproperlyConfigured):
  31. CacheBackend(backend=None, app=self.app)
  32. def test_mark_as_done(self):
  33. self.assertEqual(self.tb.get_status(self.tid), states.PENDING)
  34. self.assertIsNone(self.tb.get_result(self.tid))
  35. self.tb.mark_as_done(self.tid, 42)
  36. self.assertEqual(self.tb.get_status(self.tid), states.SUCCESS)
  37. self.assertEqual(self.tb.get_result(self.tid), 42)
  38. def test_is_pickled(self):
  39. result = {'foo': 'baz', 'bar': SomeClass(12345)}
  40. self.tb.mark_as_done(self.tid, result)
  41. # is serialized properly.
  42. rindb = self.tb.get_result(self.tid)
  43. self.assertEqual(rindb.get('foo'), 'baz')
  44. self.assertEqual(rindb.get('bar').data, 12345)
  45. def test_mark_as_failure(self):
  46. try:
  47. raise KeyError('foo')
  48. except KeyError as exception:
  49. self.tb.mark_as_failure(self.tid, exception)
  50. self.assertEqual(self.tb.get_status(self.tid), states.FAILURE)
  51. self.assertIsInstance(self.tb.get_result(self.tid), KeyError)
  52. def test_apply_chord(self):
  53. tb = CacheBackend(backend='memory://', app=self.app)
  54. gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
  55. tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
  56. @patch('celery.result.GroupResult.restore')
  57. def test_on_chord_part_return(self, restore):
  58. tb = CacheBackend(backend='memory://', app=self.app)
  59. deps = Mock()
  60. deps.__len__ = Mock()
  61. deps.__len__.return_value = 2
  62. restore.return_value = deps
  63. task = Mock()
  64. task.name = 'foobarbaz'
  65. self.app.tasks['foobarbaz'] = task
  66. task.request.chord = signature(task)
  67. gid, res = uuid(), [self.app.AsyncResult(uuid()) for _ in range(3)]
  68. task.request.group = gid
  69. tb.apply_chord(group(app=self.app), (), gid, {}, result=res)
  70. self.assertFalse(deps.join_native.called)
  71. tb.on_chord_part_return(task, 'SUCCESS', 10)
  72. self.assertFalse(deps.join_native.called)
  73. tb.on_chord_part_return(task, 'SUCCESS', 10)
  74. deps.join_native.assert_called_with(propagate=True, timeout=3.0)
  75. deps.delete.assert_called_with()
  76. def test_mget(self):
  77. self.tb.set('foo', 1)
  78. self.tb.set('bar', 2)
  79. self.assertDictEqual(self.tb.mget(['foo', 'bar']),
  80. {'foo': 1, 'bar': 2})
  81. def test_forget(self):
  82. self.tb.mark_as_done(self.tid, {'foo': 'bar'})
  83. x = self.app.AsyncResult(self.tid, backend=self.tb)
  84. x.forget()
  85. self.assertIsNone(x.result)
  86. def test_process_cleanup(self):
  87. self.tb.process_cleanup()
  88. def test_expires_as_int(self):
  89. tb = CacheBackend(backend='memory://', expires=10, app=self.app)
  90. self.assertEqual(tb.expires, 10)
  91. def test_unknown_backend_raises_ImproperlyConfigured(self):
  92. with self.assertRaises(ImproperlyConfigured):
  93. CacheBackend(backend='unknown://', app=self.app)
  94. def test_as_uri_no_servers(self):
  95. self.assertEqual(self.tb.as_uri(), 'memory:///')
  96. def test_as_uri_one_server(self):
  97. backend = 'memcache://127.0.0.1:11211/'
  98. b = CacheBackend(backend=backend, app=self.app)
  99. self.assertEqual(b.as_uri(), backend)
  100. def test_as_uri_multiple_servers(self):
  101. backend = 'memcache://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
  102. b = CacheBackend(backend=backend, app=self.app)
  103. self.assertEqual(b.as_uri(), backend)
  104. @disable_stdouts
  105. def test_regression_worker_startup_info(self):
  106. self.app.conf.result_backend = (
  107. 'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/'
  108. )
  109. worker = self.app.Worker()
  110. worker.on_start()
  111. self.assertTrue(worker.startup_info())
  112. class MyMemcachedStringEncodingError(Exception):
  113. pass
  114. class MemcachedClient(DummyClient):
  115. def set(self, key, value, *args, **kwargs):
  116. if PY3:
  117. key_t, must_be, not_be, cod = bytes, 'string', 'bytes', 'decode'
  118. else:
  119. key_t, must_be, not_be, cod = text_t, 'bytes', 'string', 'encode'
  120. if isinstance(key, key_t):
  121. raise MyMemcachedStringEncodingError(
  122. 'Keys must be {0}, not {1}. Convert your '
  123. 'strings using mystring.{2}(charset)!'.format(
  124. must_be, not_be, cod))
  125. return super(MemcachedClient, self).set(key, value, *args, **kwargs)
  126. class MockCacheMixin(object):
  127. @contextmanager
  128. def mock_memcache(self):
  129. memcache = types.ModuleType('memcache')
  130. memcache.Client = MemcachedClient
  131. memcache.Client.__module__ = memcache.__name__
  132. prev, sys.modules['memcache'] = sys.modules.get('memcache'), memcache
  133. try:
  134. yield True
  135. finally:
  136. if prev is not None:
  137. sys.modules['memcache'] = prev
  138. @contextmanager
  139. def mock_pylibmc(self):
  140. pylibmc = types.ModuleType('pylibmc')
  141. pylibmc.Client = MemcachedClient
  142. pylibmc.Client.__module__ = pylibmc.__name__
  143. prev = sys.modules.get('pylibmc')
  144. sys.modules['pylibmc'] = pylibmc
  145. try:
  146. yield True
  147. finally:
  148. if prev is not None:
  149. sys.modules['pylibmc'] = prev
  150. class test_get_best_memcache(AppCase, MockCacheMixin):
  151. def test_pylibmc(self):
  152. with self.mock_pylibmc():
  153. with reset_modules('celery.backends.cache'):
  154. from celery.backends import cache
  155. cache._imp = [None]
  156. self.assertEqual(cache.get_best_memcache()[0].__module__,
  157. 'pylibmc')
  158. def test_memcache(self):
  159. with self.mock_memcache():
  160. with reset_modules('celery.backends.cache'):
  161. with mask_modules('pylibmc'):
  162. from celery.backends import cache
  163. cache._imp = [None]
  164. self.assertEqual(cache.get_best_memcache()[0]().__module__,
  165. 'memcache')
  166. def test_no_implementations(self):
  167. with mask_modules('pylibmc', 'memcache'):
  168. with reset_modules('celery.backends.cache'):
  169. from celery.backends import cache
  170. cache._imp = [None]
  171. with self.assertRaises(ImproperlyConfigured):
  172. cache.get_best_memcache()
  173. def test_cached(self):
  174. with self.mock_pylibmc():
  175. with reset_modules('celery.backends.cache'):
  176. from celery.backends import cache
  177. cache._imp = [None]
  178. cache.get_best_memcache()[0](behaviors={'foo': 'bar'})
  179. self.assertTrue(cache._imp[0])
  180. cache.get_best_memcache()[0]()
  181. def test_backends(self):
  182. from celery.backends.cache import backends
  183. with self.mock_memcache():
  184. for name, fun in items(backends):
  185. self.assertTrue(fun())
  186. class test_memcache_key(AppCase, MockCacheMixin):
  187. def test_memcache_unicode_key(self):
  188. with self.mock_memcache():
  189. with reset_modules('celery.backends.cache'):
  190. with mask_modules('pylibmc'):
  191. from celery.backends import cache
  192. cache._imp = [None]
  193. task_id, result = string(uuid()), 42
  194. b = cache.CacheBackend(backend='memcache', app=self.app)
  195. b.store_result(task_id, result, status=states.SUCCESS)
  196. self.assertEqual(b.get_result(task_id), result)
  197. def test_memcache_bytes_key(self):
  198. with self.mock_memcache():
  199. with reset_modules('celery.backends.cache'):
  200. with mask_modules('pylibmc'):
  201. from celery.backends import cache
  202. cache._imp = [None]
  203. task_id, result = str_to_bytes(uuid()), 42
  204. b = cache.CacheBackend(backend='memcache', app=self.app)
  205. b.store_result(task_id, result, status=states.SUCCESS)
  206. self.assertEqual(b.get_result(task_id), result)
  207. def test_pylibmc_unicode_key(self):
  208. with reset_modules('celery.backends.cache'):
  209. with self.mock_pylibmc():
  210. from celery.backends import cache
  211. cache._imp = [None]
  212. task_id, result = string(uuid()), 42
  213. b = cache.CacheBackend(backend='memcache', app=self.app)
  214. b.store_result(task_id, result, status=states.SUCCESS)
  215. self.assertEqual(b.get_result(task_id), result)
  216. def test_pylibmc_bytes_key(self):
  217. with reset_modules('celery.backends.cache'):
  218. with self.mock_pylibmc():
  219. from celery.backends import cache
  220. cache._imp = [None]
  221. task_id, result = str_to_bytes(uuid()), 42
  222. b = cache.CacheBackend(backend='memcache', app=self.app)
  223. b.store_result(task_id, result, status=states.SUCCESS)
  224. self.assertEqual(b.get_result(task_id), result)