test_tasks.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. from __future__ import absolute_import
  2. from datetime import datetime, timedelta
  3. from kombu import Queue
  4. from celery import Task
  5. from celery.exceptions import Retry
  6. from celery.five import items, range, string_t
  7. from celery.result import EagerResult
  8. from celery.utils import uuid
  9. from celery.utils.timeutils import parse_iso8601
  10. from celery.tests.case import AppCase, depends_on_current_app, patch
  11. def return_True(*args, **kwargs):
  12. # Task run functions can't be closures/lambdas, as they're pickled.
  13. return True
  14. def raise_exception(self, **kwargs):
  15. raise Exception('%s error' % self.__class__)
  16. class MockApplyTask(Task):
  17. abstract = True
  18. applied = 0
  19. def run(self, x, y):
  20. return x * y
  21. def apply_async(self, *args, **kwargs):
  22. self.applied += 1
  23. class TasksCase(AppCase):
  24. def setup(self):
  25. self.mytask = self.app.task(shared=False)(return_True)
  26. @self.app.task(bind=True, count=0, shared=False)
  27. def increment_counter(self, increment_by=1):
  28. self.count += increment_by or 1
  29. return self.count
  30. self.increment_counter = increment_counter
  31. @self.app.task(shared=False)
  32. def raising():
  33. raise KeyError('foo')
  34. self.raising = raising
  35. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  36. def retry_task(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
  37. self.iterations += 1
  38. rmax = self.max_retries if max_retries is None else max_retries
  39. assert repr(self.request)
  40. retries = self.request.retries
  41. if care and retries >= rmax:
  42. return arg1
  43. else:
  44. raise self.retry(countdown=0, max_retries=rmax)
  45. self.retry_task = retry_task
  46. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  47. def retry_task_noargs(self, **kwargs):
  48. self.iterations += 1
  49. if self.request.retries >= 3:
  50. return 42
  51. else:
  52. raise self.retry(countdown=0)
  53. self.retry_task_noargs = retry_task_noargs
  54. @self.app.task(bind=True, max_retries=3, iterations=0,
  55. base=MockApplyTask, shared=False)
  56. def retry_task_mockapply(self, arg1, arg2, kwarg=1):
  57. self.iterations += 1
  58. retries = self.request.retries
  59. if retries >= 3:
  60. return arg1
  61. raise self.retry(countdown=0)
  62. self.retry_task_mockapply = retry_task_mockapply
  63. @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
  64. def retry_task_customexc(self, arg1, arg2, kwarg=1, **kwargs):
  65. self.iterations += 1
  66. retries = self.request.retries
  67. if retries >= 3:
  68. return arg1 + kwarg
  69. else:
  70. try:
  71. raise MyCustomException('Elaine Marie Benes')
  72. except MyCustomException as exc:
  73. kwargs.update(kwarg=kwarg)
  74. raise self.retry(countdown=0, exc=exc)
  75. self.retry_task_customexc = retry_task_customexc
  76. class MyCustomException(Exception):
  77. """Random custom exception."""
  78. class test_task_retries(TasksCase):
  79. def test_retry(self):
  80. self.retry_task.max_retries = 3
  81. self.retry_task.iterations = 0
  82. self.retry_task.apply([0xFF, 0xFFFF])
  83. self.assertEqual(self.retry_task.iterations, 4)
  84. self.retry_task.max_retries = 3
  85. self.retry_task.iterations = 0
  86. self.retry_task.apply([0xFF, 0xFFFF], {'max_retries': 10})
  87. self.assertEqual(self.retry_task.iterations, 11)
  88. def test_retry_no_args(self):
  89. self.retry_task_noargs.max_retries = 3
  90. self.retry_task_noargs.iterations = 0
  91. self.retry_task_noargs.apply(propagate=True).get()
  92. self.assertEqual(self.retry_task_noargs.iterations, 4)
  93. def test_retry_kwargs_can_be_empty(self):
  94. self.retry_task_mockapply.push_request()
  95. try:
  96. with self.assertRaises(Retry):
  97. import sys
  98. sys.exc_clear()
  99. self.retry_task_mockapply.retry(args=[4, 4], kwargs=None)
  100. finally:
  101. self.retry_task_mockapply.pop_request()
  102. def test_retry_not_eager(self):
  103. self.retry_task_mockapply.push_request()
  104. try:
  105. self.retry_task_mockapply.request.called_directly = False
  106. exc = Exception('baz')
  107. try:
  108. self.retry_task_mockapply.retry(
  109. args=[4, 4], kwargs={'task_retries': 0},
  110. exc=exc, throw=False,
  111. )
  112. self.assertTrue(self.retry_task_mockapply.applied)
  113. finally:
  114. self.retry_task_mockapply.applied = 0
  115. try:
  116. with self.assertRaises(Retry):
  117. self.retry_task_mockapply.retry(
  118. args=[4, 4], kwargs={'task_retries': 0},
  119. exc=exc, throw=True)
  120. self.assertTrue(self.retry_task_mockapply.applied)
  121. finally:
  122. self.retry_task_mockapply.applied = 0
  123. finally:
  124. self.retry_task_mockapply.pop_request()
  125. def test_retry_with_kwargs(self):
  126. self.retry_task_customexc.max_retries = 3
  127. self.retry_task_customexc.iterations = 0
  128. self.retry_task_customexc.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
  129. self.assertEqual(self.retry_task_customexc.iterations, 4)
  130. def test_retry_with_custom_exception(self):
  131. self.retry_task_customexc.max_retries = 2
  132. self.retry_task_customexc.iterations = 0
  133. result = self.retry_task_customexc.apply(
  134. [0xFF, 0xFFFF], {'kwarg': 0xF},
  135. )
  136. with self.assertRaises(MyCustomException):
  137. result.get()
  138. self.assertEqual(self.retry_task_customexc.iterations, 3)
  139. def test_max_retries_exceeded(self):
  140. self.retry_task.max_retries = 2
  141. self.retry_task.iterations = 0
  142. result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
  143. with self.assertRaises(self.retry_task.MaxRetriesExceededError):
  144. result.get()
  145. self.assertEqual(self.retry_task.iterations, 3)
  146. self.retry_task.max_retries = 1
  147. self.retry_task.iterations = 0
  148. result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
  149. with self.assertRaises(self.retry_task.MaxRetriesExceededError):
  150. result.get()
  151. self.assertEqual(self.retry_task.iterations, 2)
  152. class test_canvas_utils(TasksCase):
  153. def test_si(self):
  154. self.assertTrue(self.retry_task.si())
  155. self.assertTrue(self.retry_task.si().immutable)
  156. def test_chunks(self):
  157. self.assertTrue(self.retry_task.chunks(range(100), 10))
  158. def test_map(self):
  159. self.assertTrue(self.retry_task.map(range(100)))
  160. def test_starmap(self):
  161. self.assertTrue(self.retry_task.starmap(range(100)))
  162. def test_on_success(self):
  163. self.retry_task.on_success(1, 1, (), {})
class test_tasks(TasksCase):
    """General Task tests: publishing, requests, state, repr and binding."""

    def now(self):
        # Current time according to the test app; used for eta/expires.
        return self.app.now()

    @depends_on_current_app
    def test_unpickle_task(self):
        import pickle

        @self.app.task(shared=True)
        def xxx():
            pass
        # Unpickling must give back the very instance registered in the app.
        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])

    def test_AsyncResult(self):
        task_id = uuid()
        result = self.retry_task.AsyncResult(task_id)
        # The result uses the task's backend and the id it was given.
        self.assertEqual(result.backend, self.retry_task.backend)
        self.assertEqual(result.id, task_id)

    def assertNextTaskDataEqual(self, consumer, presult, task_name,
                                test_eta=False, test_expires=False, **kwargs):
        """Pop the next message from ``consumer`` and check its payload.

        :param consumer: consumer whose first queue is read (pickle accepted).
        :param presult: result returned when the task was sent; its ``id``
            must equal the message ``id``.
        :param task_name: expected value of the message ``task`` field.
        :keyword test_eta: also require ``eta`` to be a string that parses
            as ISO-8601 into a datetime.
        :keyword test_expires: same check for ``expires``.
        :keyword kwargs: each item must equal the same key in the message's
            ``kwargs`` (missing message kwargs are treated as ``{}``).
        """
        next_task = consumer.queues[0].get(accept=['pickle'])
        task_data = next_task.decode()
        self.assertEqual(task_data['id'], presult.id)
        self.assertEqual(task_data['task'], task_name)
        task_kwargs = task_data.get('kwargs', {})
        if test_eta:
            self.assertIsInstance(task_data.get('eta'), string_t)
            to_datetime = parse_iso8601(task_data.get('eta'))
            self.assertIsInstance(to_datetime, datetime)
        if test_expires:
            self.assertIsInstance(task_data.get('expires'), string_t)
            to_datetime = parse_iso8601(task_data.get('expires'))
            self.assertIsInstance(to_datetime, datetime)
        for arg_name, arg_value in items(kwargs):
            self.assertEqual(task_kwargs.get(arg_name), arg_value)

    def test_incomplete_task_cls(self):
        # A Task subclass that does not define run() raises on run().
        class IncompleteTask(Task):
            app = self.app
            name = 'c.unittest.t.itask'

        with self.assertRaises(NotImplementedError):
            IncompleteTask().run()

    def test_task_kwargs_must_be_dictionary(self):
        # apply_async rejects non-dict kwargs with ValueError.
        with self.assertRaises(ValueError):
            self.increment_counter.apply_async([], 'str')

    def test_task_args_must_be_list(self):
        # apply_async rejects a string as positional args with ValueError.
        with self.assertRaises(ValueError):
            self.increment_counter.apply_async('str', {})

    def test_regular_task(self):
        self.assertIsInstance(self.mytask, Task)
        self.assertTrue(self.mytask.run())
        self.assertTrue(
            callable(self.mytask), 'Task class is callable()',
        )
        self.assertTrue(self.mytask(), 'Task class runs run() when called')

        # Every message below is published to the same queue and then
        # popped by assertNextTaskDataEqual, so the publish/check order
        # matters.
        with self.app.connection_or_acquire() as conn:
            consumer = self.app.amqp.TaskConsumer(conn)
            with self.assertRaises(NotImplementedError):
                consumer.receive('foo', 'foo')
            # Start from an empty queue.
            consumer.purge()
            self.assertIsNone(consumer.queues[0].get())
            # Only checks that a consumer with explicit queues can be built.
            self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])

            # Without arguments.
            presult = self.mytask.delay()
            self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)

            # With arguments.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name, name='George Costanza',
            )

            # send_task
            sresult = self.app.send_task(self.mytask.name,
                                         kwargs=dict(name='Elaine M. Benes'))
            self.assertNextTaskDataEqual(
                consumer, sresult, self.mytask.name, name='Elaine M. Benes',
            )

            # With eta.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
                eta=self.now() + timedelta(days=1),
                expires=self.now() + timedelta(days=2),
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # With countdown.
            # countdown/numeric expires still end up as eta/expires
            # strings in the message.
            presult2 = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'), countdown=10, expires=12,
            )
            self.assertNextTaskDataEqual(
                consumer, presult2, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # Discarding all tasks.
            # purge() reports exactly the one message published just before.
            consumer.purge()
            self.mytask.apply_async()
            self.assertEqual(consumer.purge(), 1)
            self.assertIsNone(consumer.queues[0].get())

            # The first published task only becomes successful once the
            # backend marks it done.
            self.assertFalse(presult.successful())
            self.mytask.backend.mark_as_done(presult.id, result=None)
            self.assertTrue(presult.successful())

    def test_repr_v2_compat(self):
        self.mytask.__v2_compat__ = True
        self.assertIn('v2 compatible', repr(self.mytask))

    def test_apply_with_self(self):
        @self.app.task(__self__=42, shared=False)
        def tawself(self):
            return self
        # With __self__ set, the task gets it as its first argument, both
        # through apply() and when called directly.
        self.assertEqual(tawself.apply().get(), 42)
        self.assertEqual(tawself(), 42)

    def test_context_get(self):
        self.mytask.push_request()
        try:
            request = self.mytask.request
            request.foo = 32
            # request.get() reads attributes, falling back to the default.
            self.assertEqual(request.get('foo'), 32)
            self.assertEqual(request.get('bar', 36), 36)
            request.clear()
        finally:
            self.mytask.pop_request()

    def test_task_class_repr(self):
        self.assertIn('class Task of', repr(self.mytask.app.Task))
        # Detaching the class from its app makes the repr say "unbound".
        self.mytask.app.Task._app = None
        self.assertIn('unbound', repr(self.mytask.app.Task, ))

    def test_bind_no_magic_kwargs(self):
        # Only checks that bind() does not raise with accept_magic_kwargs=None.
        self.mytask.accept_magic_kwargs = None
        self.mytask.bind(self.mytask.app)

    def test_annotate(self):
        # With annotation resolution patched to return one mapping,
        # annotate() copies its items onto the task as attributes.
        with patch('celery.app.task.resolve_all_annotations') as anno:
            anno.return_value = [{'FOO': 'BAR'}]

            @self.app.task(shared=False)
            def task():
                pass
            task.annotate()
            self.assertEqual(task.FOO, 'BAR')

    def test_after_return(self):
        # Only checks that after_return() does not raise when the request
        # carries a chord signature.
        self.mytask.push_request()
        try:
            self.mytask.request.chord = self.mytask.s()
            self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
            self.mytask.request.clear()
        finally:
            self.mytask.pop_request()

    def test_send_task_sent_event(self):
        # With CELERY_SEND_TASK_SENT_EVENT enabled, a new producer has
        # send_sent_event set.
        with self.app.connection() as conn:
            self.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
            self.assertTrue(self.app.amqp.TaskProducer(conn).send_sent_event)

    def test_update_state(self):
        @self.app.task(shared=False)
        def yyy():
            pass
        yyy.push_request()
        try:
            tid = uuid()
            # Explicit task id: state and meta are stored for that id.
            yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})

            # No task id: update_state falls back to request.id.
            yyy.request.id = tid
            yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
        finally:
            yyy.pop_request()

    def test_repr(self):
        @self.app.task(shared=False)
        def task_test_repr():
            pass
        self.assertIn('task_test_repr', repr(task_test_repr))

    def test_has___name__(self):
        @self.app.task(shared=False)
        def yyy2():
            pass
        self.assertTrue(yyy2.__name__)
  336. class test_apply_task(TasksCase):
  337. def test_apply_throw(self):
  338. with self.assertRaises(KeyError):
  339. self.raising.apply(throw=True)
  340. def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
  341. self.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
  342. with self.assertRaises(KeyError):
  343. self.raising.apply()
  344. def test_apply(self):
  345. self.increment_counter.count = 0
  346. e = self.increment_counter.apply()
  347. self.assertIsInstance(e, EagerResult)
  348. self.assertEqual(e.get(), 1)
  349. e = self.increment_counter.apply(args=[1])
  350. self.assertEqual(e.get(), 2)
  351. e = self.increment_counter.apply(kwargs={'increment_by': 4})
  352. self.assertEqual(e.get(), 6)
  353. self.assertTrue(e.successful())
  354. self.assertTrue(e.ready())
  355. self.assertTrue(repr(e).startswith('<EagerResult:'))
  356. f = self.raising.apply()
  357. self.assertTrue(f.ready())
  358. self.assertFalse(f.successful())
  359. self.assertTrue(f.traceback)
  360. with self.assertRaises(KeyError):
  361. f.get()