123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- from __future__ import absolute_import
- from datetime import datetime, timedelta
- from kombu import Queue
- from celery import Task
- from celery import group
- from celery.app.task import _reprtask
- from celery.exceptions import Ignore, Retry
- from celery.five import items, range, string_t
- from celery.result import EagerResult
- from celery.utils import uuid
- from celery.utils.timeutils import parse_iso8601
- from celery.tests.case import (
- AppCase, ContextMock, Mock, depends_on_current_app, patch,
- )
def return_True(*args, **kwargs):
    """Ignore all positional and keyword arguments and return True.

    Kept at module level on purpose: task run functions are pickled, so
    they cannot be closures or lambdas.
    """
    outcome = True
    return outcome
def raise_exception(self, **kwargs):
    """Always raise a plain ``Exception`` whose message names ``self``'s class.

    Any keyword arguments are accepted and ignored.
    """
    owner_cls = self.__class__
    raise Exception('%s error' % (owner_cls,))
class MockApplyTask(Task):
    """Abstract task base whose ``apply_async`` only counts invocations.

    ``run`` multiplies its two arguments.  ``apply_async`` publishes
    nothing and returns None; it just bumps ``applied``, so tests can
    assert that a (re)send was attempted.
    """

    abstract = True
    # Number of apply_async calls seen; tests reset this to 0 themselves.
    applied = 0

    def run(self, x, y):
        product = x * y
        return product

    def apply_async(self, *args, **kwargs):
        # Record the call instead of sending a message.
        self.applied = self.applied + 1
class TasksCase(AppCase):
    """Base test case that registers a fresh set of helper tasks per test.

    Every task is created with ``shared=False`` so it belongs only to this
    test's app.  Tests freely mutate task attributes such as ``count``,
    ``iterations`` and ``max_retries``.
    """

    def setup(self):
        # Plain module-level function registered as a task; returns True.
        self.mytask = self.app.task(shared=False)(return_True)

        @self.app.task(bind=True, count=0, shared=False)
        def increment_counter(self, increment_by=1):
            # A falsy ``increment_by`` (0 / None) still increments by 1.
            self.count += increment_by or 1
            return self.count
        self.increment_counter = increment_counter

        @self.app.task(shared=False)
        def raising():
            raise KeyError('foo')
        self.raising = raising

        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task(self, arg1, arg2, kwarg=1, max_retries=None, care=True):
            # Counts every execution.  With ``care=True`` it returns ``arg1``
            # once ``rmax`` retries have happened; with ``care=False`` it
            # always asks for another retry.
            self.iterations += 1
            rmax = self.max_retries if max_retries is None else max_retries

            assert repr(self.request)
            retries = self.request.retries
            if care and retries >= rmax:
                return arg1
            else:
                raise self.retry(countdown=0, max_retries=rmax)
        self.retry_task = retry_task

        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task_noargs(self, **kwargs):
            self.iterations += 1

            if self.request.retries >= 3:
                return 42
            else:
                raise self.retry(countdown=0)
        self.retry_task_noargs = retry_task_noargs

        @self.app.task(bind=True, max_retries=3, iterations=0,
                       base=MockApplyTask, shared=False)
        def retry_task_mockapply(self, arg1, arg2, kwarg=1):
            self.iterations += 1

            retries = self.request.retries
            if retries >= 3:
                return arg1
            raise self.retry(countdown=0)
        self.retry_task_mockapply = retry_task_mockapply

        @self.app.task(bind=True, max_retries=3, iterations=0, shared=False)
        def retry_task_customexc(self, arg1, arg2, kwarg=1, **kwargs):
            self.iterations += 1

            retries = self.request.retries
            if retries >= 3:
                return arg1 + kwarg
            else:
                try:
                    # MyCustomException is defined later in this module; the
                    # name is only looked up when the task body runs.
                    raise MyCustomException('Elaine Marie Benes')
                except MyCustomException as exc:
                    # Only ``exc`` is handed to retry(); no args/kwargs are
                    # passed, so the retried call presumably reuses the
                    # current request's arguments.
                    raise self.retry(countdown=0, exc=exc)
        self.retry_task_customexc = retry_task_customexc

        # ``iterations=0`` is declared like on the other counting tasks so
        # ``self.iterations += 1`` works without a test pre-seeding it.
        @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
                       iterations=0, shared=False)
        def autoretry_task_no_kwargs(self, a, b):
            self.iterations += 1
            return a / b
        self.autoretry_task_no_kwargs = autoretry_task_no_kwargs

        @self.app.task(bind=True, autoretry_for=(ZeroDivisionError,),
                       retry_kwargs={'max_retries': 5}, iterations=0,
                       shared=False)
        def autoretry_task(self, a, b):
            self.iterations += 1
            return a / b
        self.autoretry_task = autoretry_task
class MyCustomException(Exception):
    """Arbitrary user-defined exception used to exercise ``retry(exc=...)``."""
class test_task_retries(TasksCase):
    """Retry and autoretry behaviour of bound tasks applied eagerly."""

    def test_retry(self):
        self.retry_task.max_retries = 3
        self.retry_task.iterations = 0
        self.retry_task.apply([0xFF, 0xFFFF])
        self.assertEqual(self.retry_task.iterations, 4)

        # An explicit ``max_retries`` argument overrides the task attribute.
        self.retry_task.max_retries = 3
        self.retry_task.iterations = 0
        self.retry_task.apply([0xFF, 0xFFFF], {'max_retries': 10})
        self.assertEqual(self.retry_task.iterations, 11)

    def test_retry_no_args(self):
        self.retry_task_noargs.max_retries = 3
        self.retry_task_noargs.iterations = 0
        self.retry_task_noargs.apply(propagate=True).get()
        self.assertEqual(self.retry_task_noargs.iterations, 4)

    def test_signature_from_request__passes_headers(self):
        self.retry_task.push_request()
        try:
            self.retry_task.request.headers = {'custom': 10.1}
            sig = self.retry_task.signature_from_request()
            self.assertEqual(sig.options['headers']['custom'], 10.1)
        finally:
            # Keep push/pop balanced, as the other tests in this class do.
            self.retry_task.pop_request()

    def test_signature_from_request__delivery_info(self):
        self.retry_task.push_request()
        try:
            self.retry_task.request.delivery_info = {
                'exchange': 'testex',
                'routing_key': 'testrk',
            }
            sig = self.retry_task.signature_from_request()
            self.assertEqual(sig.options['exchange'], 'testex')
            self.assertEqual(sig.options['routing_key'], 'testrk')
        finally:
            self.retry_task.pop_request()

    def test_retry_kwargs_can_be_empty(self):
        import sys
        self.retry_task_mockapply.push_request()
        try:
            with self.assertRaises(Retry):
                # sys.exc_clear() only exists on Python 2; presumably called
                # so retry() doesn't pick up a leftover "current" exception.
                try:
                    sys.exc_clear()
                except AttributeError:
                    pass
                self.retry_task_mockapply.retry(args=[4, 4], kwargs=None)
        finally:
            self.retry_task_mockapply.pop_request()

    def test_retry_not_eager(self):
        self.retry_task_mockapply.push_request()
        try:
            self.retry_task_mockapply.request.called_directly = False
            exc = Exception('baz')

            # throw=False: the retry is re-sent (apply_async) but not raised.
            try:
                self.retry_task_mockapply.retry(
                    args=[4, 4], kwargs={'task_retries': 0},
                    exc=exc, throw=False,
                )
                self.assertTrue(self.retry_task_mockapply.applied)
            finally:
                self.retry_task_mockapply.applied = 0

            # throw=True: the retry is re-sent *and* Retry is raised.
            try:
                with self.assertRaises(Retry):
                    self.retry_task_mockapply.retry(
                        args=[4, 4], kwargs={'task_retries': 0},
                        exc=exc, throw=True,
                    )
                # Checked outside the ``with`` so it actually runs after the
                # expected Retry has been raised.
                self.assertTrue(self.retry_task_mockapply.applied)
            finally:
                self.retry_task_mockapply.applied = 0
        finally:
            self.retry_task_mockapply.pop_request()

    def test_retry_with_kwargs(self):
        self.retry_task_customexc.max_retries = 3
        self.retry_task_customexc.iterations = 0
        self.retry_task_customexc.apply([0xFF, 0xFFFF], {'kwarg': 0xF})
        self.assertEqual(self.retry_task_customexc.iterations, 4)

    def test_retry_with_custom_exception(self):
        # With fewer allowed retries than the task needs, the exception
        # passed to retry(exc=...) is what the result finally raises.
        self.retry_task_customexc.max_retries = 2
        self.retry_task_customexc.iterations = 0
        result = self.retry_task_customexc.apply(
            [0xFF, 0xFFFF], {'kwarg': 0xF},
        )
        with self.assertRaises(MyCustomException):
            result.get()
        self.assertEqual(self.retry_task_customexc.iterations, 3)

    def test_max_retries_exceeded(self):
        # care=False makes the task retry every time, so it must hit the
        # limit after max_retries + 1 executions.
        self.retry_task.max_retries = 2
        self.retry_task.iterations = 0
        result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
        with self.assertRaises(self.retry_task.MaxRetriesExceededError):
            result.get()
        self.assertEqual(self.retry_task.iterations, 3)

        self.retry_task.max_retries = 1
        self.retry_task.iterations = 0
        result = self.retry_task.apply([0xFF, 0xFFFF], {'care': False})
        with self.assertRaises(self.retry_task.MaxRetriesExceededError):
            result.get()
        self.assertEqual(self.retry_task.iterations, 2)

    def test_autoretry_no_kwargs(self):
        self.autoretry_task_no_kwargs.max_retries = 3
        self.autoretry_task_no_kwargs.iterations = 0
        self.autoretry_task_no_kwargs.apply((1, 0))
        self.assertEqual(self.autoretry_task_no_kwargs.iterations, 4)

    def test_autoretry(self):
        # retry_kwargs={'max_retries': 5} (set in setup) takes precedence
        # over the task attribute, giving 1 + 5 executions.
        self.autoretry_task.max_retries = 3
        self.autoretry_task.iterations = 0
        self.autoretry_task.apply((1, 0))
        self.assertEqual(self.autoretry_task.iterations, 6)
class test_canvas_utils(TasksCase):
    """Smoke tests: canvas shortcut methods on a task return truthy objects."""

    def test_si(self):
        self.assertTrue(self.retry_task.si())
        immutable_sig = self.retry_task.si()
        self.assertTrue(immutable_sig.immutable)

    def test_chunks(self):
        numbers = range(100)
        self.assertTrue(self.retry_task.chunks(numbers, 10))

    def test_map(self):
        numbers = range(100)
        self.assertTrue(self.retry_task.map(numbers))

    def test_starmap(self):
        numbers = range(100)
        self.assertTrue(self.retry_task.starmap(numbers))

    def test_on_success(self):
        # Calling the default on_success hook directly must not raise.
        self.retry_task.on_success(1, 1, (), {})
class test_tasks(TasksCase):
    """General Task API behaviour: publishing, requests, repr and state."""

    def now(self):
        return self.app.now()

    @depends_on_current_app
    def test_unpickle_task(self):
        import pickle

        @self.app.task(shared=True)
        def xxx():
            pass

        self.assertIs(pickle.loads(pickle.dumps(xxx)), xxx.app.tasks[xxx.name])

    @patch('celery.app.task.current_app')
    @depends_on_current_app
    def test_bind__no_app(self, current_app):

        class XTask(Task):
            _app = None

        XTask._app = None
        XTask.__bound__ = False
        XTask.bind = Mock(name='bind')
        self.assertIs(XTask.app, current_app)
        XTask.bind.assert_called_with(current_app)

    def test_reprtask__no_fmt(self):
        self.assertTrue(_reprtask(self.mytask))

    def test_AsyncResult(self):
        task_id = uuid()
        result = self.retry_task.AsyncResult(task_id)
        self.assertEqual(result.backend, self.retry_task.backend)
        self.assertEqual(result.id, task_id)

    def assertNextTaskDataEqual(self, consumer, presult, task_name,
                                test_eta=False, test_expires=False, **kwargs):
        """Pop the next message from ``consumer`` and check its payload.

        Asserts the message ``id`` equals ``presult.id`` and ``task`` equals
        ``task_name``.  When ``test_eta`` / ``test_expires`` are set, also
        asserts that those fields are strings that parse as ISO-8601
        datetimes.  Every extra keyword must equal the same key in the
        message's ``kwargs``.
        """
        next_task = consumer.queues[0].get(accept=['pickle', 'json'])
        task_data = next_task.decode()
        self.assertEqual(task_data['id'], presult.id)
        self.assertEqual(task_data['task'], task_name)
        task_kwargs = task_data.get('kwargs', {})
        if test_eta:
            self.assertIsInstance(task_data.get('eta'), string_t)
            to_datetime = parse_iso8601(task_data.get('eta'))
            self.assertIsInstance(to_datetime, datetime)
        if test_expires:
            self.assertIsInstance(task_data.get('expires'), string_t)
            to_datetime = parse_iso8601(task_data.get('expires'))
            self.assertIsInstance(to_datetime, datetime)
        for arg_name, arg_value in items(kwargs):
            self.assertEqual(task_kwargs.get(arg_name), arg_value)

    def test_incomplete_task_cls(self):

        class IncompleteTask(Task):
            app = self.app
            name = 'c.unittest.t.itask'

        with self.assertRaises(NotImplementedError):
            IncompleteTask().run()

    def test_task_kwargs_must_be_dictionary(self):
        with self.assertRaises(TypeError):
            self.increment_counter.apply_async([], 'str')

    def test_task_args_must_be_list(self):
        with self.assertRaises(ValueError):
            self.increment_counter.apply_async('s', {})

    def test_regular_task(self):
        self.assertIsInstance(self.mytask, Task)
        self.assertTrue(self.mytask.run())
        self.assertTrue(
            callable(self.mytask), 'Task class is callable()',
        )
        self.assertTrue(self.mytask(), 'Task class runs run() when called')

        with self.app.connection_or_acquire() as conn:
            consumer = self.app.amqp.TaskConsumer(conn)
            with self.assertRaises(NotImplementedError):
                consumer.receive('foo', 'foo')
            consumer.purge()
            self.assertIsNone(consumer.queues[0].get())
            # Constructing a consumer with explicit queues must not raise.
            self.app.amqp.TaskConsumer(conn, queues=[Queue('foo')])

            # Without arguments.
            presult = self.mytask.delay()
            self.assertNextTaskDataEqual(consumer, presult, self.mytask.name)

            # With arguments.
            kwargs_result = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
            )
            self.assertNextTaskDataEqual(
                consumer, kwargs_result, self.mytask.name,
                name='George Costanza',
            )

            # send_task
            sresult = self.app.send_task(self.mytask.name,
                                         kwargs=dict(name='Elaine M. Benes'))
            self.assertNextTaskDataEqual(
                consumer, sresult, self.mytask.name, name='Elaine M. Benes',
            )

            # With eta.
            eta_result = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'),
                eta=self.now() + timedelta(days=1),
                expires=self.now() + timedelta(days=2),
            )
            self.assertNextTaskDataEqual(
                consumer, eta_result, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # With countdown.
            countdown_result = self.mytask.apply_async(
                kwargs=dict(name='George Costanza'), countdown=10, expires=12,
            )
            self.assertNextTaskDataEqual(
                consumer, countdown_result, self.mytask.name,
                name='George Costanza', test_eta=True, test_expires=True,
            )

            # Discarding all tasks.
            consumer.purge()
            self.mytask.apply_async()
            self.assertEqual(consumer.purge(), 1)
            self.assertIsNone(consumer.queues[0].get())

            self.assertFalse(presult.successful())
            self.mytask.backend.mark_as_done(presult.id, result=None)
            self.assertTrue(presult.successful())

    def test_send_event(self):
        mytask = self.mytask._get_current_object()
        mytask.app.events = Mock(name='events')
        mytask.app.events.attach_mock(ContextMock(), 'default_dispatcher')
        mytask.request.id = 'fb'
        mytask.send_event('task-foo', id=3122)
        mytask.app.events.default_dispatcher().send.assert_called_with(
            'task-foo', uuid='fb', id=3122,
        )

    def test_replace(self):
        sig1 = Mock(name='sig1')
        with self.assertRaises(Ignore):
            self.mytask.replace(sig1)

    def test_replace__group(self):
        c = group([self.mytask.s()], app=self.app)
        c.freeze = Mock(name='freeze')
        c.delay = Mock(name='delay')
        self.mytask.request.id = 'id'
        self.mytask.request.group = 'group'
        # No trailing comma: that would make root_id the tuple ('root_id',).
        self.mytask.request.root_id = 'root_id'
        with self.assertRaises(Ignore):
            self.mytask.replace(c)

    def test_send_error_email_enabled(self):
        mytask = self.increment_counter._get_current_object()
        mytask.send_error_emails = True
        mytask.disable_error_emails = False
        mytask.ErrorMail = Mock(name='ErrorMail')
        context = Mock(name='context')
        exc = Mock(name='exc')
        mytask.send_error_email(context, exc, foo=1)
        mytask.ErrorMail.assert_called_with(mytask, foo=1)
        mytask.ErrorMail().send.assert_called_with(context, exc)

    def test_add_trail__no_trail(self):
        mytask = self.increment_counter._get_current_object()
        mytask.trail = False
        mytask.add_trail('foo')

    def test_repr_v2_compat(self):
        self.mytask.__v2_compat__ = True
        self.assertIn('v2 compatible', repr(self.mytask))

    def test_apply_with_self(self):

        @self.app.task(__self__=42, shared=False)
        def tawself(self):
            return self

        self.assertEqual(tawself.apply().get(), 42)

        self.assertEqual(tawself(), 42)

    def test_context_get(self):
        self.mytask.push_request()
        try:
            request = self.mytask.request
            request.foo = 32
            self.assertEqual(request.get('foo'), 32)
            self.assertEqual(request.get('bar', 36), 36)
            request.clear()
        finally:
            self.mytask.pop_request()

    def test_annotate(self):
        with patch('celery.app.task.resolve_all_annotations') as anno:
            anno.return_value = [{'FOO': 'BAR'}]

            @self.app.task(shared=False)
            def task():
                pass

            task.annotate()
            self.assertEqual(task.FOO, 'BAR')

    def test_after_return(self):
        self.mytask.push_request()
        try:
            self.mytask.request.chord = self.mytask.s()
            self.mytask.after_return('SUCCESS', 1.0, 'foobar', (), {}, None)
            self.mytask.request.clear()
        finally:
            self.mytask.pop_request()

    def test_update_state(self):

        @self.app.task(shared=False)
        def yyy():
            pass

        yyy.push_request()
        try:
            tid = uuid()
            # Explicit task id as the first positional argument.
            yyy.update_state(tid, 'FROBULATING', {'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBULATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})

            # No id given: the current request's id is used.
            yyy.request.id = tid
            yyy.update_state(state='FROBUZATING', meta={'fooz': 'baaz'})
            self.assertEqual(yyy.AsyncResult(tid).status, 'FROBUZATING')
            self.assertDictEqual(yyy.AsyncResult(tid).result, {'fooz': 'baaz'})
        finally:
            yyy.pop_request()

    def test_repr(self):

        @self.app.task(shared=False)
        def task_test_repr():
            pass

        self.assertIn('task_test_repr', repr(task_test_repr))

    def test_has___name__(self):

        @self.app.task(shared=False)
        def yyy2():
            pass

        self.assertTrue(yyy2.__name__)
class test_apply_task(TasksCase):
    """Eager execution through ``Task.apply()``."""

    def test_apply_throw(self):
        with self.assertRaises(KeyError):
            self.raising.apply(throw=True)

    def test_apply_with_task_eager_propagates(self):
        self.app.conf.task_eager_propagates = True
        with self.assertRaises(KeyError):
            self.raising.apply()

    def test_apply(self):
        counter = self.increment_counter
        counter.count = 0

        first = counter.apply()
        self.assertIsInstance(first, EagerResult)
        self.assertEqual(first.get(), 1)

        second = counter.apply(args=[1])
        self.assertEqual(second.get(), 2)

        third = counter.apply(kwargs={'increment_by': 4})
        self.assertEqual(third.get(), 6)
        self.assertTrue(third.successful())
        self.assertTrue(third.ready())
        self.assertTrue(repr(third).startswith('<EagerResult:'))

        # A task that raises still yields a ready, unsuccessful result that
        # carries a traceback and re-raises on get().
        failed = self.raising.apply()
        self.assertTrue(failed.ready())
        self.assertFalse(failed.successful())
        self.assertTrue(failed.traceback)
        with self.assertRaises(KeyError):
            failed.get()
|