test_task.py 18 KB


  1. import unittest2 as unittest
  2. from StringIO import StringIO
  3. from datetime import datetime, timedelta
  4. from billiard.utils.functional import wraps
  5. from celery import task
  6. from celery import messaging
  7. from celery.utils import gen_unique_id
  8. from celery.result import EagerResult
  9. from celery.execute import send_task
  10. from celery.backends import default_backend
  11. from celery.decorators import task as task_dec
  12. from celery.exceptions import RetryTaskError
  13. from celery.worker.listener import parse_iso8601
  14. def return_True(*args, **kwargs):
  15. # Task run functions can't be closures/lambdas, as they're pickled.
  16. return True
# Wrap the module-level ``return_True`` with the ``task`` decorator so tests
# can address it through the resulting task object (e.g. its ``.name``).
return_True_task = task_dec()(return_True)
  18. def raise_exception(self, **kwargs):
  19. raise Exception("%s error" % self.__class__)
  20. class MockApplyTask(task.Task):
  21. def run(self, x, y):
  22. return x * y
  23. @classmethod
  24. def apply_async(self, *args, **kwargs):
  25. pass
  26. class IncrementCounterTask(task.Task):
  27. name = "c.unittest.increment_counter_task"
  28. count = 0
  29. def run(self, increment_by=1, **kwargs):
  30. increment_by = increment_by or 1
  31. self.__class__.count += increment_by
  32. return self.__class__.count
  33. class RaisingTask(task.Task):
  34. name = "c.unittest.raising_task"
  35. def run(self, **kwargs):
  36. raise KeyError("foo")
  37. class RetryTask(task.Task):
  38. max_retries = 3
  39. iterations = 0
  40. def run(self, arg1, arg2, kwarg=1, **kwargs):
  41. self.__class__.iterations += 1
  42. retries = kwargs["task_retries"]
  43. if retries >= 3:
  44. return arg1
  45. else:
  46. kwargs.update({"kwarg": kwarg})
  47. return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
  48. class RetryTaskNoArgs(task.Task):
  49. max_retries = 3
  50. iterations = 0
  51. def run(self, **kwargs):
  52. self.__class__.iterations += 1
  53. retries = kwargs["task_retries"]
  54. if retries >= 3:
  55. return 42
  56. else:
  57. return self.retry(kwargs=kwargs, countdown=0)
  58. class RetryTaskMockApply(task.Task):
  59. max_retries = 3
  60. iterations = 0
  61. applied = 0
  62. def run(self, arg1, arg2, kwarg=1, **kwargs):
  63. self.__class__.iterations += 1
  64. retries = kwargs["task_retries"]
  65. if retries >= 3:
  66. return arg1
  67. else:
  68. kwargs.update({"kwarg": kwarg})
  69. return self.retry(args=[arg1, arg2], kwargs=kwargs, countdown=0)
  70. @classmethod
  71. def apply_async(self, *args, **kwargs):
  72. self.applied = 1
  73. class MyCustomException(Exception):
  74. """Random custom exception."""
  75. class RetryTaskCustomExc(task.Task):
  76. max_retries = 3
  77. iterations = 0
  78. def run(self, arg1, arg2, kwarg=1, **kwargs):
  79. self.__class__.iterations += 1
  80. retries = kwargs["task_retries"]
  81. if retries >= 3:
  82. return arg1 + kwarg
  83. else:
  84. try:
  85. raise MyCustomException("Elaine Marie Benes")
  86. except MyCustomException, exc:
  87. kwargs.update({"kwarg": kwarg})
  88. return self.retry(args=[arg1, arg2], kwargs=kwargs,
  89. countdown=0, exc=exc)
  90. class TestTaskRetries(unittest.TestCase):
  91. def test_retry(self):
  92. RetryTask.max_retries = 3
  93. RetryTask.iterations = 0
  94. result = RetryTask.apply([0xFF, 0xFFFF])
  95. self.assertEqual(result.get(), 0xFF)
  96. self.assertEqual(RetryTask.iterations, 4)
  97. def test_retry_no_args(self):
  98. RetryTaskNoArgs.max_retries = 3
  99. RetryTaskNoArgs.iterations = 0
  100. result = RetryTaskNoArgs.apply()
  101. self.assertEqual(result.get(), 42)
  102. self.assertEqual(RetryTaskNoArgs.iterations, 4)
  103. def test_retry_not_eager(self):
  104. exc = Exception("baz")
  105. try:
  106. RetryTaskMockApply.retry(args=[4, 4], kwargs={},
  107. exc=exc, throw=False)
  108. self.assertTrue(RetryTaskMockApply.applied)
  109. finally:
  110. RetryTaskMockApply.applied = 0
  111. try:
  112. self.assertRaises(RetryTaskError, RetryTaskMockApply.retry,
  113. args=[4, 4], kwargs={}, exc=exc, throw=True)
  114. self.assertTrue(RetryTaskMockApply.applied)
  115. finally:
  116. RetryTaskMockApply.applied = 0
  117. def test_retry_with_kwargs(self):
  118. RetryTaskCustomExc.max_retries = 3
  119. RetryTaskCustomExc.iterations = 0
  120. result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
  121. self.assertEqual(result.get(), 0xFF + 0xF)
  122. self.assertEqual(RetryTaskCustomExc.iterations, 4)
  123. def test_retry_with_custom_exception(self):
  124. RetryTaskCustomExc.max_retries = 2
  125. RetryTaskCustomExc.iterations = 0
  126. result = RetryTaskCustomExc.apply([0xFF, 0xFFFF], {"kwarg": 0xF})
  127. self.assertRaises(MyCustomException,
  128. result.get)
  129. self.assertEqual(RetryTaskCustomExc.iterations, 3)
  130. def test_max_retries_exceeded(self):
  131. RetryTask.max_retries = 2
  132. RetryTask.iterations = 0
  133. result = RetryTask.apply([0xFF, 0xFFFF])
  134. self.assertRaises(RetryTask.MaxRetriesExceededError,
  135. result.get)
  136. self.assertEqual(RetryTask.iterations, 3)
  137. RetryTask.max_retries = 1
  138. RetryTask.iterations = 0
  139. result = RetryTask.apply([0xFF, 0xFFFF])
  140. self.assertRaises(RetryTask.MaxRetriesExceededError,
  141. result.get)
  142. self.assertEqual(RetryTask.iterations, 2)
  143. class MockPublisher(object):
  144. def __init__(self, *args, **kwargs):
  145. self.kwargs = kwargs
  146. class TestCeleryTasks(unittest.TestCase):
  147. def createTaskCls(self, cls_name, task_name=None):
  148. attrs = {"__module__": self.__module__}
  149. if task_name:
  150. attrs["name"] = task_name
  151. cls = type(cls_name, (task.Task, ), attrs)
  152. cls.run = return_True
  153. return cls
  154. def test_AsyncResult(self):
  155. task_id = gen_unique_id()
  156. result = RetryTask.AsyncResult(task_id)
  157. self.assertEqual(result.backend, RetryTask.backend)
  158. self.assertEqual(result.task_id, task_id)
  159. def test_ping(self):
  160. from celery import conf
  161. conf.ALWAYS_EAGER = True
  162. self.assertEqual(task.ping(), 'pong')
  163. conf.ALWAYS_EAGER = False
  164. def test_execute_remote(self):
  165. from celery import conf
  166. conf.ALWAYS_EAGER = True
  167. self.assertEqual(task.execute_remote(return_True, ["foo"]).get(),
  168. True)
  169. conf.ALWAYS_EAGER = False
  170. def test_dmap(self):
  171. from celery import conf
  172. import operator
  173. conf.ALWAYS_EAGER = True
  174. res = task.dmap(operator.add, zip(xrange(10), xrange(10)))
  175. self.assertEqual(sum(res), sum(operator.add(x, x)
  176. for x in xrange(10)))
  177. conf.ALWAYS_EAGER = False
  178. def test_dmap_async(self):
  179. from celery import conf
  180. import operator
  181. conf.ALWAYS_EAGER = True
  182. res = task.dmap_async(operator.add, zip(xrange(10), xrange(10)))
  183. self.assertEqual(sum(res.get()), sum(operator.add(x, x)
  184. for x in xrange(10)))
  185. conf.ALWAYS_EAGER = False
  186. def assertNextTaskDataEqual(self, consumer, presult, task_name,
  187. test_eta=False, **kwargs):
  188. next_task = consumer.fetch()
  189. task_data = next_task.decode()
  190. self.assertEqual(task_data["id"], presult.task_id)
  191. self.assertEqual(task_data["task"], task_name)
  192. task_kwargs = task_data.get("kwargs", {})
  193. if test_eta:
  194. self.assertIsInstance(task_data.get("eta"), basestring)
  195. to_datetime = parse_iso8601(task_data.get("eta"))
  196. self.assertIsInstance(to_datetime, datetime)
  197. for arg_name, arg_value in kwargs.items():
  198. self.assertEqual(task_kwargs.get(arg_name), arg_value)
  199. def test_incomplete_task_cls(self):
  200. class IncompleteTask(task.Task):
  201. name = "c.unittest.t.itask"
  202. self.assertRaises(NotImplementedError, IncompleteTask().run)
  203. def test_task_kwargs_must_be_dictionary(self):
  204. self.assertRaises(ValueError, IncrementCounterTask.apply_async,
  205. [], "str")
  206. def test_task_args_must_be_list(self):
  207. self.assertRaises(ValueError, IncrementCounterTask.apply_async,
  208. "str", {})
  209. def test_regular_task(self):
  210. T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  211. self.assertIsInstance(T1(), T1)
  212. self.assertTrue(T1().run())
  213. self.assertTrue(callable(T1()),
  214. "Task class is callable()")
  215. self.assertTrue(T1()(),
  216. "Task class runs run() when called")
  217. # task name generated out of class module + name.
  218. T2 = self.createTaskCls("T2")
  219. self.assertTrue(T2().name.endswith("test_task.T2"))
  220. t1 = T1()
  221. consumer = t1.get_consumer()
  222. self.assertRaises(NotImplementedError, consumer.receive, "foo", "foo")
  223. consumer.discard_all()
  224. self.assertIsNone(consumer.fetch())
  225. # Without arguments.
  226. presult = t1.delay()
  227. self.assertNextTaskDataEqual(consumer, presult, t1.name)
  228. # With arguments.
  229. presult2 = t1.apply_async(kwargs=dict(name="George Constanza"))
  230. self.assertNextTaskDataEqual(consumer, presult2, t1.name,
  231. name="George Constanza")
  232. # send_task
  233. sresult = send_task(t1.name, kwargs=dict(name="Elaine M. Benes"))
  234. self.assertNextTaskDataEqual(consumer, sresult, t1.name,
  235. name="Elaine M. Benes")
  236. # With eta.
  237. presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
  238. eta=datetime.now() + timedelta(days=1))
  239. self.assertNextTaskDataEqual(consumer, presult2, t1.name,
  240. name="George Constanza", test_eta=True)
  241. # With countdown.
  242. presult2 = task.apply_async(t1, kwargs=dict(name="George Constanza"),
  243. countdown=10)
  244. self.assertNextTaskDataEqual(consumer, presult2, t1.name,
  245. name="George Constanza", test_eta=True)
  246. # Discarding all tasks.
  247. consumer.discard_all()
  248. task.apply_async(t1)
  249. self.assertEqual(consumer.discard_all(), 1)
  250. self.assertIsNone(consumer.fetch())
  251. self.assertFalse(presult.successful())
  252. default_backend.mark_as_done(presult.task_id, result=None)
  253. self.assertTrue(presult.successful())
  254. publisher = t1.get_publisher()
  255. self.assertIsInstance(publisher, messaging.TaskPublisher)
  256. def test_get_publisher(self):
  257. from celery.task import base
  258. old_pub = base.TaskPublisher
  259. base.TaskPublisher = MockPublisher
  260. try:
  261. p = IncrementCounterTask.get_publisher(exchange="foo",
  262. connection="bar")
  263. self.assertEqual(p.kwargs["exchange"], "foo")
  264. finally:
  265. base.TaskPublisher = old_pub
  266. def test_get_logger(self):
  267. T1 = self.createTaskCls("T1", "c.unittest.t.t1")
  268. t1 = T1()
  269. logfh = StringIO()
  270. logger = t1.get_logger(logfile=logfh, loglevel=0)
  271. self.assertTrue(logger)
  272. class TestTaskSet(unittest.TestCase):
  273. def test_function_taskset(self):
  274. from celery import conf
  275. conf.ALWAYS_EAGER = True
  276. ts = task.TaskSet(return_True_task.name, [
  277. ([1], {}), [[2], {}], [[3], {}], [[4], {}], [[5], {}]])
  278. res = ts.apply_async()
  279. self.assertListEqual(res.join(), [True, True, True, True, True])
  280. conf.ALWAYS_EAGER = False
  281. def test_counter_taskset(self):
  282. IncrementCounterTask.count = 0
  283. ts = task.TaskSet(IncrementCounterTask, [
  284. ([], {}),
  285. ([], {"increment_by": 2}),
  286. ([], {"increment_by": 3}),
  287. ([], {"increment_by": 4}),
  288. ([], {"increment_by": 5}),
  289. ([], {"increment_by": 6}),
  290. ([], {"increment_by": 7}),
  291. ([], {"increment_by": 8}),
  292. ([], {"increment_by": 9}),
  293. ])
  294. self.assertEqual(ts.task_name, IncrementCounterTask.name)
  295. self.assertEqual(ts.total, 9)
  296. consumer = IncrementCounterTask().get_consumer()
  297. consumer.discard_all()
  298. taskset_res = ts.apply_async()
  299. subtasks = taskset_res.subtasks
  300. taskset_id = taskset_res.taskset_id
  301. for subtask in subtasks:
  302. m = consumer.fetch().payload
  303. self.assertDictContainsSubset({"taskset": taskset_id,
  304. "task": IncrementCounterTask.name,
  305. "id": subtask.task_id}, m)
  306. IncrementCounterTask().run(
  307. increment_by=m.get("kwargs", {}).get("increment_by"))
  308. self.assertEqual(IncrementCounterTask.count, sum(xrange(1, 10)))
  309. class TestTaskApply(unittest.TestCase):
  310. def test_apply(self):
  311. IncrementCounterTask.count = 0
  312. e = IncrementCounterTask.apply()
  313. self.assertIsInstance(e, EagerResult)
  314. self.assertEqual(e.get(), 1)
  315. e = IncrementCounterTask.apply(args=[1])
  316. self.assertEqual(e.get(), 2)
  317. e = IncrementCounterTask.apply(kwargs={"increment_by": 4})
  318. self.assertEqual(e.get(), 6)
  319. self.assertTrue(e.successful())
  320. self.assertTrue(e.ready())
  321. self.assertTrue(repr(e).startswith("<EagerResult:"))
  322. f = RaisingTask.apply()
  323. self.assertTrue(f.ready())
  324. self.assertFalse(f.successful())
  325. self.assertTrue(f.traceback)
  326. self.assertRaises(KeyError, f.get)
  327. class MyPeriodic(task.PeriodicTask):
  328. run_every = timedelta(hours=1)
  329. class TestPeriodicTask(unittest.TestCase):
  330. def test_must_have_run_every(self):
  331. self.assertRaises(NotImplementedError, type, "Foo",
  332. (task.PeriodicTask, ), {"__module__": __name__})
  333. def test_remaining_estimate(self):
  334. self.assertIsInstance(
  335. MyPeriodic().remaining_estimate(datetime.now()),
  336. timedelta)
  337. def test_timedelta_seconds_returns_0_on_negative_time(self):
  338. delta = timedelta(days=-2)
  339. self.assertEqual(MyPeriodic().timedelta_seconds(delta), 0)
  340. def test_timedelta_seconds(self):
  341. deltamap = ((timedelta(seconds=1), 1),
  342. (timedelta(seconds=27), 27),
  343. (timedelta(minutes=3), 3 * 60),
  344. (timedelta(hours=4), 4 * 60 * 60),
  345. (timedelta(days=3), 3 * 86400))
  346. for delta, seconds in deltamap:
  347. self.assertEqual(MyPeriodic().timedelta_seconds(delta), seconds)
  348. def test_delta_resolution(self):
  349. D = MyPeriodic().run_every.delta_resolution
  350. dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
  351. deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
  352. (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
  353. (timedelta(minutes=2), datetime(2010, 3, 30, 11, 50)),
  354. (timedelta(seconds=2), dt))
  355. for delta, shoulda in deltamap:
  356. self.assertEqual(D(dt, delta), shoulda)
  357. def test_is_due_not_due(self):
  358. due, remaining = MyPeriodic().is_due(datetime.now())
  359. self.assertFalse(due)
  360. self.assertGreater(remaining, 60)
  361. def test_is_due(self):
  362. p = MyPeriodic()
  363. due, remaining = p.is_due(datetime.now() - p.run_every.run_every)
  364. self.assertTrue(due)
  365. self.assertEqual(remaining,
  366. p.timedelta_seconds(p.run_every.run_every))
  367. class EveryMinutePeriodic(task.PeriodicTask):
  368. run_every = task.crontab()
  369. class HourlyPeriodic(task.PeriodicTask):
  370. run_every = task.crontab(minute=30)
  371. class DailyPeriodic(task.PeriodicTask):
  372. run_every = task.crontab(hour=7, minute=30)
  373. class WeeklyPeriodic(task.PeriodicTask):
  374. run_every = task.crontab(hour=7, minute=30, day_of_week=4)
  375. def patch_crontab_nowfun(cls, retval):
  376. def create_patcher(fun):
  377. @wraps(fun)
  378. def __inner(*args, **kwargs):
  379. prev_nowfun = cls.run_every.nowfun
  380. cls.run_every.nowfun = lambda: retval
  381. try:
  382. return fun(*args, **kwargs)
  383. finally:
  384. cls.run_every.nowfun = prev_nowfun
  385. return __inner
  386. return create_patcher
  387. class test_crontab(unittest.TestCase):
  388. def test_every_minute_execution_is_due(self):
  389. last_ran = datetime.now() - timedelta(seconds=61)
  390. due, remaining = EveryMinutePeriodic().is_due(last_ran)
  391. self.assertTrue(due)
  392. self.assertEquals(remaining, 1)
  393. def test_every_minute_execution_is_not_due(self):
  394. last_ran = datetime.now() - timedelta(seconds=30)
  395. due, remaining = EveryMinutePeriodic().is_due(last_ran)
  396. self.assertFalse(due)
  397. self.assertEquals(remaining, 1)
  398. @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 30))
  399. def test_every_hour_execution_is_due(self):
  400. due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 6, 30))
  401. self.assertTrue(due)
  402. self.assertEquals(remaining, 1)
  403. @patch_crontab_nowfun(HourlyPeriodic, datetime(2010, 5, 10, 10, 29))
  404. def test_every_hour_execution_is_not_due(self):
  405. due, remaining = HourlyPeriodic().is_due(datetime(2010, 5, 10, 6, 30))
  406. self.assertFalse(due)
  407. self.assertEquals(remaining, 1)
  408. @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 7, 30))
  409. def test_daily_execution_is_due(self):
  410. due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 9, 7, 30))
  411. self.assertTrue(due)
  412. self.assertEquals(remaining, 1)
  413. @patch_crontab_nowfun(DailyPeriodic, datetime(2010, 5, 10, 10, 30))
  414. def test_daily_execution_is_not_due(self):
  415. due, remaining = DailyPeriodic().is_due(datetime(2010, 5, 10, 6, 29))
  416. self.assertFalse(due)
  417. self.assertEquals(remaining, 1)
  418. @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 6, 7, 30))
  419. def test_weekly_execution_is_due(self):
  420. due, remaining = WeeklyPeriodic().is_due(datetime(2010, 4, 30, 7, 30))
  421. self.assertTrue(due)
  422. self.assertEquals(remaining, 1)
  423. @patch_crontab_nowfun(WeeklyPeriodic, datetime(2010, 5, 7, 10, 30))
  424. def test_weekly_execution_is_not_due(self):
  425. due, remaining = WeeklyPeriodic().is_due(datetime(2010, 4, 30, 6, 29))
  426. self.assertFalse(due)
  427. self.assertEquals(remaining, 1)