test_base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. from __future__ import absolute_import
  2. import sys
  3. import types
  4. from contextlib import contextmanager
  5. from mock import Mock, patch
  6. from nose import SkipTest
  7. from celery import current_app
  8. from celery.exceptions import ChordError
  9. from celery.five import items, range
  10. from celery.result import AsyncResult, GroupResult
  11. from celery.utils import serialization
  12. from celery.utils.serialization import subclass_exception
  13. from celery.utils.serialization import find_pickleable_exception as fnpe
  14. from celery.utils.serialization import UnpickleableExceptionWrapper
  15. from celery.utils.serialization import get_pickleable_exception as gpe
  16. from celery import states
  17. from celery.backends.base import (
  18. BaseBackend,
  19. KeyValueStoreBackend,
  20. DisabledBackend,
  21. )
  22. from celery.utils import uuid
  23. from celery.tests.case import AppCase, Case
  24. class wrapobject(object):
  25. def __init__(self, *args, **kwargs):
  26. self.args = args
  27. if sys.version_info >= (3, 0):
  28. Oldstyle = None
  29. else:
  30. Oldstyle = types.ClassType('Oldstyle', (), {})
  31. Unpickleable = subclass_exception('Unpickleable', KeyError, 'foo.module')
  32. Impossible = subclass_exception('Impossible', object, 'foo.module')
  33. Lookalike = subclass_exception('Lookalike', wrapobject, 'foo.module')
  34. b = BaseBackend()
  35. class test_serialization(Case):
  36. def test_create_exception_cls(self):
  37. self.assertTrue(serialization.create_exception_cls('FooError', 'm'))
  38. self.assertTrue(serialization.create_exception_cls('FooError', 'm',
  39. KeyError))
  40. class test_BaseBackend_interface(Case):
  41. def test__forget(self):
  42. with self.assertRaises(NotImplementedError):
  43. b._forget('SOMExx-N0Nex1stant-IDxx-')
  44. def test_forget(self):
  45. with self.assertRaises(NotImplementedError):
  46. b.forget('SOMExx-N0nex1stant-IDxx-')
  47. def test_on_chord_part_return(self):
  48. b.on_chord_part_return(None)
  49. def test_on_chord_apply(self, unlock='celery.chord_unlock'):
  50. p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
  51. try:
  52. b.on_chord_apply('dakj221', 'sdokqweok',
  53. result=[AsyncResult(x) for x in [1, 2, 3]])
  54. self.assertTrue(current_app.tasks[unlock].apply_async.call_count)
  55. finally:
  56. current_app.tasks[unlock] = p
  57. class test_exception_pickle(Case):
  58. def test_oldstyle(self):
  59. if Oldstyle is None:
  60. raise SkipTest('py3k does not support old style classes')
  61. self.assertTrue(fnpe(Oldstyle()))
  62. def test_BaseException(self):
  63. self.assertIsNone(fnpe(Exception()))
  64. def test_get_pickleable_exception(self):
  65. exc = Exception('foo')
  66. self.assertEqual(gpe(exc), exc)
  67. def test_unpickleable(self):
  68. self.assertIsInstance(fnpe(Unpickleable()), KeyError)
  69. self.assertIsNone(fnpe(Impossible()))
  70. class test_prepare_exception(Case):
  71. def test_unpickleable(self):
  72. x = b.prepare_exception(Unpickleable(1, 2, 'foo'))
  73. self.assertIsInstance(x, KeyError)
  74. y = b.exception_to_python(x)
  75. self.assertIsInstance(y, KeyError)
  76. def test_impossible(self):
  77. x = b.prepare_exception(Impossible())
  78. self.assertIsInstance(x, UnpickleableExceptionWrapper)
  79. self.assertTrue(str(x))
  80. y = b.exception_to_python(x)
  81. self.assertEqual(y.__class__.__name__, 'Impossible')
  82. if sys.version_info < (2, 5):
  83. self.assertTrue(y.__class__.__module__)
  84. else:
  85. self.assertEqual(y.__class__.__module__, 'foo.module')
  86. def test_regular(self):
  87. x = b.prepare_exception(KeyError('baz'))
  88. self.assertIsInstance(x, KeyError)
  89. y = b.exception_to_python(x)
  90. self.assertIsInstance(y, KeyError)
  91. class KVBackend(KeyValueStoreBackend):
  92. mget_returns_dict = False
  93. def __init__(self, *args, **kwargs):
  94. self.db = {}
  95. super(KVBackend, self).__init__()
  96. def get(self, key):
  97. return self.db.get(key)
  98. def set(self, key, value):
  99. self.db[key] = value
  100. def mget(self, keys):
  101. if self.mget_returns_dict:
  102. return dict((key, self.get(key)) for key in keys)
  103. else:
  104. return [self.get(k) for k in keys]
  105. def delete(self, key):
  106. self.db.pop(key, None)
  107. class DictBackend(BaseBackend):
  108. def __init__(self, *args, **kwargs):
  109. BaseBackend.__init__(self, *args, **kwargs)
  110. self._data = {'can-delete': {'result': 'foo'}}
  111. def _restore_group(self, group_id):
  112. if group_id == 'exists':
  113. return {'result': 'group'}
  114. def _get_task_meta_for(self, task_id):
  115. if task_id == 'task-exists':
  116. return {'result': 'task'}
  117. def _delete_group(self, group_id):
  118. self._data.pop(group_id, None)
  119. class test_BaseBackend_dict(Case):
  120. def setUp(self):
  121. self.b = DictBackend()
  122. def test_delete_group(self):
  123. self.b.delete_group('can-delete')
  124. self.assertNotIn('can-delete', self.b._data)
  125. def test_prepare_exception_json(self):
  126. x = DictBackend(serializer='json')
  127. e = x.prepare_exception(KeyError('foo'))
  128. self.assertIn('exc_type', e)
  129. e = x.exception_to_python(e)
  130. self.assertEqual(e.__class__.__name__, 'KeyError')
  131. self.assertEqual(str(e), "'foo'")
  132. def test_save_group(self):
  133. b = BaseBackend()
  134. b._save_group = Mock()
  135. b.save_group('foofoo', 'xxx')
  136. b._save_group.assert_called_with('foofoo', 'xxx')
  137. def test_forget_interface(self):
  138. b = BaseBackend()
  139. with self.assertRaises(NotImplementedError):
  140. b.forget('foo')
  141. def test_restore_group(self):
  142. self.assertIsNone(self.b.restore_group('missing'))
  143. self.assertIsNone(self.b.restore_group('missing'))
  144. self.assertEqual(self.b.restore_group('exists'), 'group')
  145. self.assertEqual(self.b.restore_group('exists'), 'group')
  146. self.assertEqual(self.b.restore_group('exists', cache=False), 'group')
  147. def test_reload_group_result(self):
  148. self.b._cache = {}
  149. self.b.reload_group_result('exists')
  150. self.b._cache['exists'] = {'result': 'group'}
  151. def test_reload_task_result(self):
  152. self.b._cache = {}
  153. self.b.reload_task_result('task-exists')
  154. self.b._cache['task-exists'] = {'result': 'task'}
  155. def test_fail_from_current_stack(self):
  156. self.b.mark_as_failure = Mock()
  157. try:
  158. raise KeyError('foo')
  159. except KeyError as exc:
  160. self.b.fail_from_current_stack('task_id')
  161. self.assertTrue(self.b.mark_as_failure.called)
  162. args = self.b.mark_as_failure.call_args[0]
  163. self.assertEqual(args[0], 'task_id')
  164. self.assertIs(args[1], exc)
  165. self.assertTrue(args[2])
  166. def test_prepare_value_serializes_group_result(self):
  167. g = GroupResult('group_id', [AsyncResult('foo')])
  168. self.assertIsInstance(self.b.prepare_value(g), (list, tuple))
  169. def test_is_cached(self):
  170. self.b._cache['foo'] = 1
  171. self.assertTrue(self.b.is_cached('foo'))
  172. self.assertFalse(self.b.is_cached('false'))
  173. class test_KeyValueStoreBackend(AppCase):
  174. def setup(self):
  175. self.b = KVBackend()
  176. def test_on_chord_part_return(self):
  177. assert not self.b.implements_incr
  178. self.b.on_chord_part_return(None)
  179. def test_get_store_delete_result(self):
  180. tid = uuid()
  181. self.b.mark_as_done(tid, 'Hello world')
  182. self.assertEqual(self.b.get_result(tid), 'Hello world')
  183. self.assertEqual(self.b.get_status(tid), states.SUCCESS)
  184. self.b.forget(tid)
  185. self.assertEqual(self.b.get_status(tid), states.PENDING)
  186. def test_strip_prefix(self):
  187. x = self.b.get_key_for_task('x1b34')
  188. self.assertEqual(self.b._strip_prefix(x), 'x1b34')
  189. self.assertEqual(self.b._strip_prefix('x1b34'), 'x1b34')
  190. def test_get_many(self):
  191. for is_dict in True, False:
  192. self.b.mget_returns_dict = is_dict
  193. ids = dict((uuid(), i) for i in range(10))
  194. for id, i in items(ids):
  195. self.b.mark_as_done(id, i)
  196. it = self.b.get_many(list(ids))
  197. for i, (got_id, got_state) in enumerate(it):
  198. self.assertEqual(got_state['result'], ids[got_id])
  199. self.assertEqual(i, 9)
  200. self.assertTrue(list(self.b.get_many(list(ids))))
  201. def test_get_many_times_out(self):
  202. tasks = [uuid() for _ in range(4)]
  203. self.b._cache[tasks[1]] = {'status': 'PENDING'}
  204. with self.assertRaises(self.b.TimeoutError):
  205. list(self.b.get_many(tasks, timeout=0.01, interval=0.01))
  206. def test_chord_part_return_no_gid(self):
  207. self.b.implements_incr = True
  208. task = Mock()
  209. task.request.group = None
  210. self.b.get_key_for_chord = Mock()
  211. self.b.get_key_for_chord.side_effect = AssertionError(
  212. 'should not get here',
  213. )
  214. self.assertIsNone(self.b.on_chord_part_return(task))
  215. @contextmanager
  216. def _chord_part_context(self, b):
  217. @self.app.task()
  218. def callback(result):
  219. pass
  220. b.implements_incr = True
  221. b.client = Mock()
  222. with patch('celery.result.GroupResult') as GR:
  223. deps = GR.restore.return_value = Mock()
  224. deps.__len__ = Mock()
  225. deps.__len__.return_value = 10
  226. b.incr = Mock()
  227. b.incr.return_value = 10
  228. b.expire = Mock()
  229. task = Mock()
  230. task.request.group = 'grid'
  231. cb = task.request.chord = callback.s()
  232. task.request.chord._freeze()
  233. callback.backend = b
  234. callback.backend.fail_from_current_stack = Mock()
  235. yield task, deps, cb
  236. def test_chord_part_return_propagate_set(self):
  237. with self._chord_part_context(self.b) as (task, deps, _):
  238. self.b.on_chord_part_return(task, propagate=True)
  239. self.assertFalse(self.b.expire.called)
  240. deps.delete.assert_called_with()
  241. deps.join_native.assert_called_with(propagate=True)
  242. def test_chord_part_return_propagate_default(self):
  243. with self._chord_part_context(self.b) as (task, deps, _):
  244. self.b.on_chord_part_return(task, propagate=None)
  245. self.assertFalse(self.b.expire.called)
  246. deps.delete.assert_called_with()
  247. deps.join_native.assert_called_with(
  248. propagate=self.b.app.conf.CELERY_CHORD_PROPAGATES,
  249. )
  250. def test_chord_part_return_join_raises_internal(self):
  251. with self._chord_part_context(self.b) as (task, deps, callback):
  252. deps._failed_join_report = lambda: iter([])
  253. deps.join_native.side_effect = KeyError('foo')
  254. self.b.on_chord_part_return(task)
  255. self.assertTrue(self.b.fail_from_current_stack.called)
  256. args = self.b.fail_from_current_stack.call_args
  257. exc = args[1]['exc']
  258. self.assertIsInstance(exc, ChordError)
  259. self.assertIn('foo', str(exc))
  260. def test_chord_part_return_join_raises_task(self):
  261. with self._chord_part_context(self.b) as (task, deps, callback):
  262. deps._failed_join_report = lambda: iter([AsyncResult('culprit')])
  263. deps.join_native.side_effect = KeyError('foo')
  264. self.b.on_chord_part_return(task)
  265. self.assertTrue(self.b.fail_from_current_stack.called)
  266. args = self.b.fail_from_current_stack.call_args
  267. exc = args[1]['exc']
  268. self.assertIsInstance(exc, ChordError)
  269. self.assertIn('Dependency culprit raised', str(exc))
  270. def test_restore_group_from_json(self):
  271. b = KVBackend(serializer='json')
  272. g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
  273. b._save_group(g.id, g)
  274. g2 = b._restore_group(g.id)['result']
  275. self.assertEqual(g2, g)
  276. def test_restore_group_from_pickle(self):
  277. b = KVBackend(serializer='pickle')
  278. g = GroupResult('group_id', [AsyncResult('a'), AsyncResult('b')])
  279. b._save_group(g.id, g)
  280. g2 = b._restore_group(g.id)['result']
  281. self.assertEqual(g2, g)
  282. def test_chord_apply_fallback(self):
  283. self.b.implements_incr = False
  284. self.b.fallback_chord_unlock = Mock()
  285. self.b.on_chord_apply('group_id', 'body', 'result', foo=1)
  286. self.b.fallback_chord_unlock.assert_called_with(
  287. 'group_id', 'body', 'result', foo=1,
  288. )
  289. def test_get_missing_meta(self):
  290. self.assertIsNone(self.b.get_result('xxx-missing'))
  291. self.assertEqual(self.b.get_status('xxx-missing'), states.PENDING)
  292. def test_save_restore_delete_group(self):
  293. tid = uuid()
  294. tsr = GroupResult(tid, [AsyncResult(uuid()) for _ in range(10)])
  295. self.b.save_group(tid, tsr)
  296. self.b.restore_group(tid)
  297. self.assertEqual(self.b.restore_group(tid), tsr)
  298. self.b.delete_group(tid)
  299. self.assertIsNone(self.b.restore_group(tid))
  300. def test_restore_missing_group(self):
  301. self.assertIsNone(self.b.restore_group('xxx-nonexistant'))
  302. class test_KeyValueStoreBackend_interface(Case):
  303. def test_get(self):
  304. with self.assertRaises(NotImplementedError):
  305. KeyValueStoreBackend().get('a')
  306. def test_set(self):
  307. with self.assertRaises(NotImplementedError):
  308. KeyValueStoreBackend().set('a', 1)
  309. def test_incr(self):
  310. with self.assertRaises(NotImplementedError):
  311. KeyValueStoreBackend().incr('a')
  312. def test_cleanup(self):
  313. self.assertFalse(KeyValueStoreBackend().cleanup())
  314. def test_delete(self):
  315. with self.assertRaises(NotImplementedError):
  316. KeyValueStoreBackend().delete('a')
  317. def test_mget(self):
  318. with self.assertRaises(NotImplementedError):
  319. KeyValueStoreBackend().mget(['a'])
  320. def test_forget(self):
  321. with self.assertRaises(NotImplementedError):
  322. KeyValueStoreBackend().forget('a')
  323. class test_DisabledBackend(Case):
  324. def test_store_result(self):
  325. DisabledBackend().store_result()
  326. def test_is_disabled(self):
  327. with self.assertRaises(NotImplementedError):
  328. DisabledBackend().get_status('foo')