# test_tasks.py (17 KB)


  1. from __future__ import absolute_import
  2. from datetime import datetime, timedelta
  3. from kombu import Queue
  4. from celery import Task
  5. from celery.exceptions import Retry
  6. from celery.five import items, range, string_t
  7. from celery.result import EagerResult
  8. from celery.utils import uuid
  9. from celery.utils.timeutils import parse_iso8601
  10. from celery.tests.case import AppCase, depends_on_current_app, patch
  11. def return_True(*args, **kwargs):
  12. # Task run functions can't be closures/lambdas, as they're pickled.
  13. return True
  14. def raise_exception(self, **kwargs):
  15. raise Exception('%s error' % self.__class__)
  16. class MockApplyTask(Task):
  17. abstract = True
  18. applied = 0
  19. def run(self, x, y):
  20. return x * y
  21. def apply_async(self, *args, **kwargs):
  22. self.applied += 1
  23. class TasksCase(AppCase):
  24. def setup(self):
  25. self.mytask = self.app.task(shared=False)(return_True)
  26. @self.app.task(bind=True, count=0, shared=False)
  27. def increment_counter(self, increment_by=1):
  28. self.count += increment_by or 1
  29. return self.count
  30. self.increment_counter = increment_counter
  31. @self.app.task(shared=False)
  32. def raising():
  33. raise KeyError('foo')
  34. self.raising = raising
  35. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  36. def retry_task(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
  37. self.iterations += 1
  38. rmax = self.max_retries if max_retries is None else max_retries
  39. assert repr(self.request)
  40. retries = self.request.retries
  41. if care and retries >= rmax:
  42. return arg1
  43. else:
  44. raise self.retry(countdown=0, max_retries=rmax)
  45. self.retry_task = retry_task
  46. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  47. def retry_task_noargs(self, **kwargs):
  48. self.iterations += 1
  49. if self.request.retries >= 3:
  50. return 42
  51. else:
  52. raise self.retry(countdown=0)
  53. self.retry_task_noargs = retry_task_noargs
  54. @self.app.task(bind=True, max_retries=3, iterations=0,
  55. base=MockApplyTask, shared=False)
  56. def retry_task_mockapply(self, arg1, arg2, kwarg=1):
  57. self.iterations += 1
  58. retries = self.request.retries
  59. if retries >= 3:
  60. return arg1
  61. raise self.retry(countdown=0)
  62. self.retry_task_mockapply = retry_task_mockapply
  63. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  64. def retry_task_customexc(self, arg1, arg2, kwarg=1, **kwargs):
  65. self.iterations += 1
  66. retries = self.request.retries
  67. if retries >= 3:
  68. return arg1 + kwarg
  69. else:
  70. try:
  71. raise MyCustomException('Elaine Marie Benes')
  72. except MyCustomException as exc:
  73. kwargs.update(kwarg=kwarg)
  74. raise self.retry(countdown=0, exc=exc)
  75. self.retry_task_customexc = retry_task_customexc
  76. @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
  77. shared=False)
  78. def autoretry_task_no_kwargs(self, a, b):
  79. self.iterations += 1
  80. return a/b
  81. self.autoretry_task_no_kwargs = autoretry_task_no_kwargs
  82. @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
  83. retry_kwargs={'max_retries': 5}, shared=False)
  84. def autoretry_task(self, a, b):
  85. self.iterations += 1
  86. return a/b
  87. self.autoretry_task = autoretry_task
  88. class MyCustomException(Exception):
  89. """Random custom exception."""
  90. class test_task_retries(TasksCase):
  91. def test_retry(self):
  92. self.retry_task.max_retries = 3
  93. self.retry_task.iterations = 0
  94. self.retry_task.apply([0xFF, 0xFFFF])
  95. self.assertEqual(self.retry_task.iterations, 4)
  96. self.retry_task.max_retries = 3
  97. self.retry_task.iterations = 0
  98. self.retry_task.apply([0xFF, 0xFFFF], {'max_retries': 10})
  99. self.assertEqual(self.retry_task.iterations, 11)
  100. def test_retry_no_args(self):
  101. self.retry_task_noargs.max_retries = 3
  102. self.retry_task_noargs.iterations = 0
  103. self.retry_task_noargs.apply(propagate=True).get()
  104. self.assertEqual(self.retry_task_noargs.iterations, 4)
  105. def test_signature_from_request__passes_headers(self):
  106. self.retry_task.push_request()
  107. self.retry_task.request.headers = {'custom': 10.1}
  108. sig = self.retry_task.signature_from_request()
  109. self.assertEqual(sig.options['headers']['custom'], 10.1)
  110. def test_signature_from_request__delivery_info(self):
  111. self.retry_task.push_request()
  112. self.retry_task.request.delivery_info = {
  113. 'exchange': 'testex',
  114. 'routing_key': 'testrk',
  115. }
  116. sig = self.retry_task.signature_from_request()
  117. self.assertEqual(sig.options['exchange'], 'testex')
  118. self.assertEqual(sig.options['routing_key'], 'testrk')
  119. def test_retry_kwargs_can_be_empty(self):
  120. self.retry_task_mockapply.push_request()
  121. try:
  122. with self.assertRaises(Retry):
  123. import sys
  124. try:
  125. sys.exc_clear()
  126. except AttributeError:
  127. pass
  128. self.retry_task_mockapply.retry(args=[4, 4], kwargs=None)
  129. finally:
  130. self.retry_task_mockapply.pop_request()
  131. def test_retry_not_eager(self):
  132. self.retry_task_mockapply.push_request()
  133. try:
  134. self.retry_task_mockapply.request.called_directly = False
  135. exc = Exception('baz')
  136. try:
  137. self.retry_task_mockapply.retry(
  138. args=[4, 4], kwargs={'task_retries': 0},
  139. exc=exc, throw=False,
  140. )
  141. self.assertTrue(self.retry_task_mockapply.applied)
  142. finally:
  143. self.retry_task_mockapply.applied = 0
  144. try:
  145. with self.assertRaises(Retry):
  146. self.retry_task_mockapply.retry(
  147. args=[4, 4], kwargs={'task_retries': 0},
  148. exc=exc, throw=True)
  149. self.assertTrue(self.retry_task_mockapply.applied)
  150. finally:
  151. self.retry_task_mockapply.applied = 0
  152. finally:
  153. self.retry_task_mockapply.pop_request()
  154. def test_retry_with_kwargs(self):
  155. self.retry_task_customexc.max_retries = 3
  156. self.retry_task_customexc.iterations = 0
  157. self.retry_task_customexc.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
  158. self.assertEqual(self.retry_task_customexc.iterations, 4)
  159. def test_retry_with_custom_exception(self):
  160. self.retry_task_customexc.max_retries = 2
  161. self.retry_task_customexc.iterations = 0
  162. result = self.retry_task_customexc.apply(
  163. [0xFF, 0xFFFF], {'kwarg': 0xF},
  164. )
  165. with self.assertRaises(MyCustomException):
  166. result.get()
  167. self.assertEqual(self.retry_task_customexc.iterations, 3)
  168. def test_max_retries_exceeded(self):
  169. self.retry_task.max_retries = 2
  170. self.retry_task.iterations = 0
  171. result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
  172. with self.assertRaises(self.retry_task.MaxRetriesExceededError):
  173. result.get()
  174. self.assertEqual(self.retry_task.iterations, 3)
  175. self.retry_task.max_retries = 1
  176. self.retry_task.iterations = 0
  177. result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
  178. with self.assertRaises(self.retry_task.MaxRetriesExceededError):
  179. result.get()
  180. self.assertEqual(self.retry_task.iterations, 2)
  181. def test_autoretry_no_kwargs(self):
  182. self.autoretry_task_no_kwargs.max_retries = 3
  183. self.autoretry_task_no_kwargs.iterations = 0
  184. self.autoretry_task_no_kwargs.apply((1, 0))
  185. self.assertEqual(self.autoretry_task_no_kwargs.iterations, 4)
  186. def test_autoretry(self):
  187. self.autoretry_task.max_retries = 3
  188. self.autoretry_task.iterations = 0
  189. self.autoretry_task.apply((1, 0))
  190. self.assertEqual(self.autoretry_task.iterations, 6)
  191. class test_canvas_utils(TasksCase):
  192. def test_si(self):
  193. self.assertTrue(self.retry_task.si())
  194. self.assertTrue(self.retry_task.si().immutable)
  195. def test_chunks(self):
  196. self.assertTrue(self.retry_task.chunks(range(100), 10))
  197. def test_map(self):
  198. self.assertTrue(self.retry_task.map(range(100)))
  199. def test_starmap(self):
  200. self.assertTrue(self.retry_task.starmap(range(100)))
  201. def test_on_success(self):
  202. self.retry_task.on_success(1, 1, (), {})
  203. class test_tasks(TasksCase):
  204. def now(self):
  205. return self.app.now()
  206. @depends_on_current_app
  207. def test_unpickle_task(self):
  208. import pickle
  209. @self.app.task(shared=True)
  210. def xxx():
  211. pass
  212. self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])
  213. def test_AsyncResult(self):
  214. task_id = uuid()
  215. result = self.retry_task.AsyncResult(task_id)
  216. self.assertEqual(result.backend, self.retry_task.backend)
  217. self.assertEqual(result.id, task_id)
  218. def assertNextTaskDataEqual(self, consumer, presult, task_name,
  219. test_eta=False, test_expires=False, **kwargs):
  220. next_task = consumer.queues[0].get(accept=['pickle', 'json'])
  221. task_data = next_task.decode()
  222. self.assertEqual(task_data['id'], presult.id)
  223. self.assertEqual(task_data['task'], task_name)
  224. task_kwargs = task_data.get('kwargs', {})
  225. if test_eta:
  226. self.assertIsInstance(task_data.get('eta'), string_t)
  227. to_datetime = parse_iso8601(task_data.get('eta'))
  228. self.assertIsInstance(to_datetime, datetime)
  229. if test_expires:
  230. self.assertIsInstance(task_data.get('expires'), string_t)
  231. to_datetime = parse_iso8601(task_data.get('expires'))
  232. self.assertIsInstance(to_datetime, datetime)
  233. for arg_name, arg_value in items(kwargs):
  234. self.assertEqual(task_kwargs.get(arg_name), arg_value)
  235. def test_incomplete_task_cls(self):
  236. class IncompleteTask(Task):
  237. app = self.app
  238. name = 'c.unittest.t.itask'
  239. with self.assertRaises(NotImplementedError):
  240. IncompleteTask().run()
  241. def test_task_kwargs_must_be_dictionary(self):
  242. with self.assertRaises(TypeError):
  243. self.increment_counter.apply_async([], 'str')
  244. def test_task_args_must_be_list(self):
  245. with self.assertRaises(ValueError):
  246. self.increment_counter.apply_async('s', {})
  247. def test_regular_task(self):
  248. self.assertIsInstance(self.mytask, Task)
  249. self.assertTrue(self.mytask.run())
  250. self.assertTrue(
  251. callable(self.mytask), 'Task class is callable()',
  252. )
  253. self.assertTrue(self.mytask(), 'Task class runs run() when called')
  254. with self.app.connection_or_acquire() as conn:
  255. consumer = self.app.amqp.TaskConsumer(conn)
  256. with self.assertRaises(NotImplementedError):
  257. consumer.receive('foo', 'foo')
  258. consumer.purge()
  259. self.assertIsNone(consumer.queues[0].get())
  260. self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])
  261. # Without arguments.
  262. presult = self.mytask.delay()
  263. self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)
  264. # With arguments.
  265. presult2 = self.mytask.apply_async(
  266. kwargs=dict(name='George Costanza'),
  267. )
  268. self.assertNextTaskDataEqual(
  269. consumer, presult2, self.mytask.name, name='George Costanza',
  270. )
  271. # send_task
  272. sresult = self.app.send_task(self.mytask.name,
  273. kwargs=dict(name='Elaine M. Benes'))
  274. self.assertNextTaskDataEqual(
  275. consumer, sresult, self.mytask.name, name='Elaine M. Benes',
  276. )
  277. # With eta.
  278. presult2 = self.mytask.apply_async(
  279. kwargs=dict(name='George Costanza'),
  280. eta=self.now() + timedelta(days=1),
  281. expires=self.now() + timedelta(days=2),
  282. )
  283. self.assertNextTaskDataEqual(
  284. consumer, presult2, self.mytask.name,
  285. name='George Costanza', test_eta=True, test_expires=True,
  286. )
  287. # With countdown.
  288. presult2 = self.mytask.apply_async(
  289. kwargs=dict(name='George Costanza'), countdown=10, expires=12,
  290. )
  291. self.assertNextTaskDataEqual(
  292. consumer, presult2, self.mytask.name,
  293. name='George Costanza', test_eta=True, test_expires=True,
  294. )
  295. # Discarding all tasks.
  296. consumer.purge()
  297. self.mytask.apply_async()
  298. self.assertEqual(consumer.purge(), 1)
  299. self.assertIsNone(consumer.queues[0].get())
  300. self.assertFalse(presult.successful())
  301. self.mytask.backend.mark_as_done(presult.id, result=None)
  302. self.assertTrue(presult.successful())
  303. def test_repr_v2_compat(self):
  304. self.mytask.__v2_compat__ = True
  305. self.assertIn('v2 compatible', repr(self.mytask))
  306. def test_apply_with_self(self):
  307. @self.app.task(__self__=42, shared=False)
  308. def tawself(self):
  309. return self
  310. self.assertEqual(tawself.apply().get(), 42)
  311. self.assertEqual(tawself(), 42)
  312. def test_context_get(self):
  313. self.mytask.push_request()
  314. try:
  315. request = self.mytask.request
  316. request.foo = 32
  317. self.assertEqual(request.get('foo'), 32)
  318. self.assertEqual(request.get('bar', 36), 36)
  319. request.clear()
  320. finally:
  321. self.mytask.pop_request()
  322. def test_annotate(self):
  323. with patch('celery.app.task.resolve_all_annotations') as anno:
  324. anno.return_value = [{'FOO': 'BAR'}]
  325. @self.app.task(shared=False)
  326. def task():
  327. pass
  328. task.annotate()
  329. self.assertEqual(task.FOO, 'BAR')
  330. def test_after_return(self):
  331. self.mytask.push_request()
  332. try:
  333. self.mytask.request.chord = self.mytask.s()
  334. self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
  335. self.mytask.request.clear()
  336. finally:
  337. self.mytask.pop_request()
  338. def test_update_state(self):
  339. @self.app.task(shared=False)
  340. def yyy():
  341. pass
  342. yyy.push_request()
  343. try:
  344. tid = uuid()
  345. yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
  346. self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
  347. self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
  348. yyy.request.id = tid
  349. yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
  350. self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
  351. self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
  352. finally:
  353. yyy.pop_request()
  354. def test_repr(self):
  355. @self.app.task(shared=False)
  356. def task_test_repr():
  357. pass
  358. self.assertIn('task_test_repr', repr(task_test_repr))
  359. def test_has___name__(self):
  360. @self.app.task(shared=False)
  361. def yyy2():
  362. pass
  363. self.assertTrue(yyy2.__name__)
  364. class test_apply_task(TasksCase):
  365. def test_apply_throw(self):
  366. with self.assertRaises(KeyError):
  367. self.raising.apply(throw=True)
  368. def test_apply_with_task_eager_propagates(self):
  369. self.app.conf.task_eager_propagates = True
  370. with self.assertRaises(KeyError):
  371. self.raising.apply()
  372. def test_apply(self):
  373. self.increment_counter.count = 0
  374. e = self.increment_counter.apply()
  375. self.assertIsInstance(e, EagerResult)
  376. self.assertEqual(e.get(), 1)
  377. e = self.increment_counter.apply(args=[1])
  378. self.assertEqual(e.get(), 2)
  379. e = self.increment_counter.apply(kwargs={'increment_by': 4})
  380. self.assertEqual(e.get(), 6)
  381. self.assertTrue(e.successful())
  382. self.assertTrue(e.ready())
  383. self.assertTrue(repr(e).startswith('<EagerResult:'))
  384. f = self.raising.apply()
  385. self.assertTrue(f.ready())
  386. self.assertFalse(f.successful())
  387. self.assertTrue(f.traceback)
  388. with self.assertRaises(KeyError):
  389. f.get()