test_tasks.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. from __future__ import absolute_import
  2. from datetime import datetime, timedelta
  3. from kombu import Queue
  4. from celery import Task
  5. from celery.exceptions import Retry
  6. from celery.five import items, range, string_t
  7. from celery.result import EagerResult
  8. from celery.utils import uuid
  9. from celery.utils.timeutils import parse_iso8601
  10. from celery.tests.case import AppCase, depends_on_current_app, patch
  def return_True(*args, **kwargs):
      """Ignore every argument and return ``True``.

      Defined at module level on purpose: task run functions are pickled,
      so they cannot be closures or lambdas.
      """
      return True
  def raise_exception(self, **kwargs):
      """Always raise a plain ``Exception`` mentioning ``type(self)``.

      Keyword arguments are accepted only so the signature fits generic
      callers; they are ignored.
      """
      message = '%s error' % (self.__class__,)
      raise Exception(message)
  class MockApplyTask(Task):
      """Abstract task base whose ``apply_async`` only counts invocations.

      ``run`` multiplies its two arguments. ``apply_async`` does not call
      the parent implementation; it only bumps ``applied``, which lets tests
      check whether a (re)dispatch happened.
      """

      abstract = True

      # Number of apply_async() calls seen (set per task instance on first bump).
      applied = 0

      def run(self, x, y):
          """Return ``x * y``."""
          product = x * y
          return product

      def apply_async(self, *args, **kwargs):
          """Record the call in ``applied``; send nothing and return ``None``."""
          self.applied = self.applied + 1
  class TasksCase(AppCase):
      """Base test case that defines a set of fixture tasks on ``self.app``.

      Every fixture task is created with ``shared=False``.
      """

      def setup(self):
          """Create the fixture tasks and store them as attributes."""
          # Plain task whose run() always returns True.
          self.mytask = self.app.task(shared=False)(return_True)

          @self.app.task(bind=True, count=0, shared=False)
          def increment_counter(self, increment_by=1):
              # ``or 1``: a falsy increment (0 or None) still adds 1.
              self.count += increment_by or 1
              return self.count
          self.increment_counter = increment_counter

          @self.app.task(shared=False)
          def raising():
              raise KeyError('foo')
          self.raising = raising

          @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
          def retry_task(self, arg1, arg2, kwarg=1, max_retries=None,
                         care=True):
              # Retries until request.retries reaches the limit (a per-call
              # ``max_retries`` overrides the task attribute), then returns
              # arg1. With ``care=False`` it never returns and keeps calling
              # retry().
              self.iterations += 1
              rmax = self.max_retries if max_retries is None else max_retries

              # Sanity check: the request context must be repr()-able.
              assert repr(self.request)
              retries = self.request.retries
              if care and retries >= rmax:
                  return arg1
              else:
                  raise self.retry(countdown=0, max_retries=rmax)
          self.retry_task = retry_task

          @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
          def retry_task_noargs(self, **kwargs):
              # Like retry_task, but with a hard-coded limit of 3 retries.
              self.iterations += 1

              if self.request.retries >= 3:
                  return 42
              else:
                  raise self.retry(countdown=0)
          self.retry_task_noargs = retry_task_noargs

          @self.app.task(bind=True, max_retries=3, iterations=0,
                         base=MockApplyTask, shared=False)
          def retry_task_mockapply(self, arg1, arg2, kwarg=1):
              # base=MockApplyTask: apply_async only increments ``applied``
              # instead of sending anything.
              self.iterations += 1

              retries = self.request.retries
              if retries >= 3:
                  return arg1
              raise self.retry(countdown=0)
          self.retry_task_mockapply = retry_task_mockapply

          @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
          def retry_task_customexc(self, arg1, arg2, kwarg=1, **kwargs):
              # Retries with a custom ``exc`` until three retries have run,
              # then returns arg1 + kwarg.
              self.iterations += 1

              retries = self.request.retries
              if retries >= 3:
                  return arg1 + kwarg
              else:
                  try:
                      raise MyCustomException('Elaine Marie Benes')
                  except MyCustomException as exc:
                      # No ``kwargs=`` is passed to retry(), so the retried
                      # call presumably reuses the current request's kwargs
                      # (including ``kwarg``). Mutating the local ``kwargs``
                      # dict here would have no effect, so it is not done.
                      raise self.retry(countdown=0, exc=exc)
          self.retry_task_customexc = retry_task_customexc
  class MyCustomException(Exception):
      """Exception raised by the ``retry_task_customexc`` fixture.

      Tests expect it to surface from the result once the retry limit is
      exhausted.
      """
  class test_task_retries(TasksCase):
      """Retry behaviour of the fixture tasks."""

      def test_retry(self):
          task = self.retry_task

          # Default limit of 3: the initial run plus three retries.
          task.max_retries = 3
          task.iterations = 0
          task.apply([0xFF, 0xFFFF])
          self.assertEqual(task.iterations, 4)

          # A per-call ``max_retries`` kwarg overrides the task attribute.
          task.max_retries = 3
          task.iterations = 0
          task.apply([0xFF, 0xFFFF], {'max_retries': 10})
          self.assertEqual(task.iterations, 11)

      def test_retry_no_args(self):
          task = self.retry_task_noargs
          task.max_retries = 3
          task.iterations = 0
          task.apply(propagate=True).get()
          self.assertEqual(task.iterations, 4)

      def test_retry_kwargs_can_be_empty(self):
          task = self.retry_task_mockapply
          task.push_request()
          try:
              with self.assertRaises(Retry):
                  task.retry(args=[4, 4], kwargs=None)
          finally:
              task.pop_request()

      def test_retry_not_eager(self):
          task = self.retry_task_mockapply
          task.push_request()
          try:
              task.request.called_directly = False
              err = Exception('baz')

              # throw=False: the task is re-dispatched and nothing is raised.
              try:
                  task.retry(
                      args=[4, 4], kwargs={'task_retries': 0},
                      exc=err, throw=False,
                  )
                  self.assertTrue(task.applied)
              finally:
                  task.applied = 0

              # throw=True: the task is re-dispatched and Retry is raised.
              try:
                  with self.assertRaises(Retry):
                      task.retry(
                          args=[4, 4], kwargs={'task_retries': 0},
                          exc=err, throw=True,
                      )
                  self.assertTrue(task.applied)
              finally:
                  task.applied = 0
          finally:
              task.pop_request()

      def test_retry_with_kwargs(self):
          task = self.retry_task_customexc
          task.max_retries = 3
          task.iterations = 0
          task.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
          self.assertEqual(task.iterations, 4)

      def test_retry_with_custom_exception(self):
          task = self.retry_task_customexc
          task.max_retries = 2
          task.iterations = 0
          outcome = task.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
          with self.assertRaises(MyCustomException):
              outcome.get()
          self.assertEqual(task.iterations, 3)

      def test_max_retries_exceeded(self):
          task = self.retry_task
          # (retry limit, expected number of runs) with care=False, so the
          # task never returns and must run out of retries.
          for limit, expected_runs in ((2, 3), (1, 2)):
              task.max_retries = limit
              task.iterations = 0
              outcome = task.apply([0xFF, 0xFFFF], {'care': False})
              with self.assertRaises(task.MaxRetriesExceededError):
                  outcome.get()
              self.assertEqual(task.iterations, expected_runs)
  class test_canvas_utils(TasksCase):
      """Smoke tests for the canvas helpers available on a task."""

      def test_si(self):
          immutable_sig = self.retry_task.si()
          self.assertTrue(immutable_sig)
          self.assertTrue(self.retry_task.si().immutable)

      def test_chunks(self):
          numbers = range(100)
          self.assertTrue(self.retry_task.chunks(numbers, 10))

      def test_map(self):
          numbers = range(100)
          self.assertTrue(self.retry_task.map(numbers))

      def test_starmap(self):
          numbers = range(100)
          self.assertTrue(self.retry_task.starmap(numbers))

      def test_on_success(self):
          # Only checks that the call does not raise.
          self.retry_task.on_success(1, 1, (), {})
  class test_tasks(TasksCase):
      """General Task API tests: publishing, results, repr and requests."""

      def now(self):
          """Return the current time as reported by ``self.app.now()``."""
          return self.app.now()

      @depends_on_current_app
      def test_unpickle_task(self):
          import pickle

          @self.app.task(shared=True)
          def xxx():
              pass

          restored = pickle.loads(pickle.dumps(xxx))
          self.assertIs(restored, xxx.app.tasks[xxx.name])

      def test_AsyncResult(self):
          tid = uuid()
          res = self.retry_task.AsyncResult(tid)
          self.assertEqual(res.backend, self.retry_task.backend)
          self.assertEqual(res.id, tid)

      def assertNextTaskDataEqual(self, consumer, presult, task_name,
                                  test_eta=False, test_expires=False,
                                  **kwargs):
          """Fetch the next message from *consumer*'s first queue and check it.

          Asserts that the message ``id`` matches *presult* and its ``task``
          matches *task_name*. If requested, asserts that ``eta``/``expires``
          are strings that ``parse_iso8601`` turns into datetimes. Each extra
          keyword argument must equal the same key in the message's
          ``kwargs``.
          """
          message = consumer.queues[0].get(accept=['pickle'])
          data = message.decode()
          self.assertEqual(data['id'], presult.id)
          self.assertEqual(data['task'], task_name)
          sent_kwargs = data.get('kwargs', {})
          for wanted, field in ((test_eta, 'eta'), (test_expires, 'expires')):
              if wanted:
                  raw = data.get(field)
                  self.assertIsInstance(raw, string_t)
                  self.assertIsInstance(parse_iso8601(raw), datetime)
          for arg_name, arg_value in items(kwargs):
              self.assertEqual(sent_kwargs.get(arg_name), arg_value)

      def test_incomplete_task_cls(self):
          class IncompleteTask(Task):
              app = self.app
              name = 'c.unittest.t.itask'

          with self.assertRaises(NotImplementedError):
              IncompleteTask().run()

      def test_task_kwargs_must_be_dictionary(self):
          with self.assertRaises(ValueError):
              self.increment_counter.apply_async([], 'str')

      def test_task_args_must_be_list(self):
          with self.assertRaises(ValueError):
              self.increment_counter.apply_async('str', {})

      def test_regular_task(self):
          task = self.mytask
          self.assertIsInstance(task, Task)
          self.assertTrue(task.run())
          self.assertTrue(
              callable(task), 'Task class is callable()',
          )
          self.assertTrue(task(), 'Task class runs run() when called')
          task_name = task.name

          with self.app.connection_or_acquire() as conn:
              consumer = self.app.amqp.TaskConsumer(conn)
              with self.assertRaises(NotImplementedError):
                  consumer.receive('foo', 'foo')
              consumer.purge()
              self.assertIsNone(consumer.queues[0].get())
              # Building a consumer with an explicit queue list must not fail.
              self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])

              # Without arguments.
              plain = task.delay()
              self.assertNextTaskDataEqual(consumer, plain, task_name)

              # With arguments.
              with_kwargs = task.apply_async(
                  kwargs={'name': 'George Costanza'},
              )
              self.assertNextTaskDataEqual(
                  consumer, with_kwargs, task_name, name='George Costanza',
              )

              # send_task
              sent = self.app.send_task(
                  task_name, kwargs={'name': 'Elaine M. Benes'},
              )
              self.assertNextTaskDataEqual(
                  consumer, sent, task_name, name='Elaine M. Benes',
              )

              # With eta.
              with_eta = task.apply_async(
                  kwargs={'name': 'George Costanza'},
                  eta=self.now() + timedelta(days=1),
                  expires=self.now() + timedelta(days=2),
              )
              self.assertNextTaskDataEqual(
                  consumer, with_eta, task_name,
                  name='George Costanza', test_eta=True, test_expires=True,
              )

              # With countdown.
              with_countdown = task.apply_async(
                  kwargs={'name': 'George Costanza'}, countdown=10, expires=12,
              )
              self.assertNextTaskDataEqual(
                  consumer, with_countdown, task_name,
                  name='George Costanza', test_eta=True, test_expires=True,
              )

              # Discarding all tasks.
              consumer.purge()
              task.apply_async()
              self.assertEqual(consumer.purge(), 1)
              self.assertIsNone(consumer.queues[0].get())

              self.assertFalse(plain.successful())
              task.backend.mark_as_done(plain.id, result=None)
              self.assertTrue(plain.successful())

      def test_repr_v2_compat(self):
          task = self.mytask
          task.__v2_compat__ = True
          self.assertIn('v2 compatible', repr(task))

      def test_apply_with_self(self):
          @self.app.task(__self__=42, shared=False)
          def tawself(self):
              return self

          self.assertEqual(tawself.apply().get(), 42)
          self.assertEqual(tawself(), 42)

      def test_context_get(self):
          task = self.mytask
          task.push_request()
          try:
              req = task.request
              req.foo = 32
              self.assertEqual(req.get('foo'), 32)
              self.assertEqual(req.get('bar', 36), 36)
              req.clear()
          finally:
              task.pop_request()

      def test_task_class_repr(self):
          task_cls = self.mytask.app.Task
          self.assertIn('class Task of', repr(task_cls))
          task_cls._app = None
          self.assertIn('unbound', repr(task_cls))

      def test_bind_no_magic_kwargs(self):
          task = self.mytask
          task.accept_magic_kwargs = None
          task.bind(task.app)

      def test_annotate(self):
          with patch('celery.app.task.resolve_all_annotations') as resolver:
              resolver.return_value = [{'FOO': 'BAR'}]

              @self.app.task(shared=False)
              def task():
                  pass

              task.annotate()
              self.assertEqual(task.FOO, 'BAR')

      def test_after_return(self):
          task = self.mytask
          task.push_request()
          try:
              task.request.chord = task.s()
              task.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
              task.request.clear()
          finally:
              task.pop_request()

      def test_send_task_sent_event(self):
          with self.app.connection() as conn:
              self.app.conf.CELERY_SEND_TASK_SENT_EVENT = True
              producer = self.app.amqp.TaskProducer(conn)
              self.assertTrue(producer.send_sent_event)

      def test_update_state(self):
          @self.app.task(shared=False)
          def yyy():
              pass

          yyy.push_request()
          try:
              tid = uuid()

              def check_stored(expected_state):
                  # Both the state and the meta dict must be persisted.
                  self.assertEqual(yyy.AsyncResult(tid).status, expected_state)
                  self.assertDictEqual(
                      yyy.AsyncResult(tid).result, {'fooz': 'baaz'},
                  )

              # Explicit task id passed positionally.
              yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
              check_stored('FROBULATING')

              # Task id taken from the current request.
              yyy.request.id = tid
              yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
              check_stored('FROBUZATING')
          finally:
              yyy.pop_request()

      def test_repr(self):
          @self.app.task(shared=False)
          def task_test_repr():
              pass

          self.assertIn('task_test_repr', repr(task_test_repr))

      def test_has___name__(self):
          @self.app.task(shared=False)
          def yyy2():
              pass

          self.assertTrue(yyy2.__name__)
  class test_apply_task(TasksCase):
      """Eager execution through ``Task.apply()``."""

      def test_apply_throw(self):
          with self.assertRaises(KeyError):
              self.raising.apply(throw=True)

      def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
          self.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
          with self.assertRaises(KeyError):
              self.raising.apply()

      def test_apply(self):
          counter = self.increment_counter
          counter.count = 0

          # Each apply() adds to the same ``count`` attribute on the task.
          res = counter.apply()
          self.assertIsInstance(res, EagerResult)
          self.assertEqual(res.get(), 1)

          res = counter.apply(args=[1])
          self.assertEqual(res.get(), 2)

          res = counter.apply(kwargs={'increment_by': 4})
          self.assertEqual(res.get(), 6)
          self.assertTrue(res.successful())
          self.assertTrue(res.ready())
          self.assertTrue(repr(res).startswith('<EagerResult:'))

          # A failing task gives a ready, unsuccessful result that re-raises.
          failed = self.raising.apply()
          self.assertTrue(failed.ready())
          self.assertFalse(failed.successful())
          self.assertTrue(failed.traceback)
          with self.assertRaises(KeyError):
              failed.get()