123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597 |
- # -*- coding: utf-8 -*-
- import anyjson
- import logging
- import os
- import sys
- import time
- from datetime import datetime, timedelta
- from kombu.transport.base import Message
- from celery import states
- from celery.app import app_or_default
- from celery.concurrency.base import BasePool
- from celery.datastructures import ExceptionInfo
- from celery.task import task as task_dec
- from celery.exceptions import RetryTaskError, NotRegistered, WorkerLostError
- from celery.log import setup_logger
- from celery.result import AsyncResult
- from celery.task.base import Task
- from celery.utils import gen_unique_id
- from celery.worker.job import WorkerTaskTrace, TaskRequest
- from celery.worker.job import execute_and_trace, AlreadyExecutedError
- from celery.worker.job import InvalidTaskError
- from celery.worker.state import revoked
- from celery.tests.compat import catch_warnings
- from celery.tests.utils import unittest
- from celery.tests.utils import execute_context, StringIO
# Shared mutable state the test helpers write into so tests can observe
# side effects: ``scratch["ACK"]`` is flipped to True by ``on_ack``.
scratch = dict(ACK=False)

# Filled in by ``mytask_some_kwargs`` with the ``logfile`` value it received.
some_kwargs_scratchpad = dict()
def jail(task_id, task_name, args, kwargs):
    """Build a ``WorkerTaskTrace`` for the task and return the result of calling it.

    Note the trace constructor takes the task name first and the task id
    second, the reverse of this helper's own parameter order.
    """
    trace = WorkerTaskTrace(task_name, task_id, args, kwargs)
    return trace()
def on_ack():
    """Acknowledgement callback used by the tests: records the ack in ``scratch``."""
    scratch.update(ACK=True)
@task_dec(accept_magic_kwargs=True)
def mytask(i, **kwargs):
    """Test task: return ``i`` raised to the power of itself.

    Any extra keyword arguments are accepted and ignored.
    """
    return pow(i, i)
@task_dec  # bare decorator (no parentheses) so that code path gets coverage
def mytask_no_kwargs(i):
    """Test task taking no keyword arguments: return ``i`` to the power of ``i``."""
    return pow(i, i)
class MyTaskIgnoreResult(Task):
    """Class-based test task declared with ``ignore_result`` switched on."""

    ignore_result = True

    def run(self, i):
        """Return ``i`` raised to the power of itself."""
        return pow(i, i)
@task_dec(accept_magic_kwargs=True)
def mytask_some_kwargs(i, logfile):
    """Test task that names only ``logfile`` among its keyword arguments.

    The received ``logfile`` is stashed in ``some_kwargs_scratchpad`` so the
    test can check what was passed; the return value is ``i ** i``.
    """
    some_kwargs_scratchpad.update(logfile=logfile)
    return pow(i, i)
@task_dec(accept_magic_kwargs=True)
def mytask_raising(i, **kwargs):
    """Test task that always fails: raises ``KeyError`` carrying ``i``."""
    failure = KeyError(i)
    raise failure
class test_RetryTaskError(unittest.TestCase):
    """Checks on the ``RetryTaskError`` exception type."""

    def test_retry_task_error(self):
        # The exception passed as the second argument must be exposed
        # unchanged as ``.exc``.
        try:
            raise Exception("foo")
        except Exception:
            # sys.exc_info() instead of ``except Exception, exc`` so the
            # module parses on every Python version (the comma form is
            # Python-2-only syntax).
            exc = sys.exc_info()[1]
            ret = RetryTaskError("Retrying task", exc)
            self.assertEqual(ret.exc, exc)
class test_WorkerTaskTrace(unittest.TestCase):
    """Run tasks through ``WorkerTaskTrace`` (via ``jail``) and check the outcome."""

    def test_execute_jail_success(self):
        # mytask(2) == 2 ** 2
        ret = jail(gen_unique_id(), mytask.name, [2], {})
        self.assertEqual(ret, 4)

    def test_marked_as_started(self):
        # With track_started on, the backend's mark_as_started must be called
        # for a normal task, but not once ignore_result is enabled.
        mytask.track_started = True

        class Backend(mytask.backend.__class__):
            # Class-level list: records every task id marked as started.
            _started = []

            def mark_as_started(self, uuid, *args, **kwargs):
                self._started.append(uuid)

        prev, mytask.backend = mytask.backend, Backend()

        try:
            uuid = gen_unique_id()
            jail(uuid, mytask.name, [2], {})
            self.assertIn(uuid, Backend._started)

            mytask.ignore_result = True
            uuid = gen_unique_id()
            jail(uuid, mytask.name, [2], {})
            self.assertNotIn(uuid, Backend._started)
        finally:
            # Always put the shared task object back the way we found it.
            mytask.backend = prev
            mytask.track_started = False
            mytask.ignore_result = False

    def test_execute_jail_failure(self):
        # A raising task yields an ExceptionInfo wrapping the KeyError(4).
        ret = jail(gen_unique_id(), mytask_raising.name,
                   [4], {})
        self.assertIsInstance(ret, ExceptionInfo)
        self.assertTupleEqual(ret.exception.args, (4, ))

    def test_execute_ignore_result(self):
        task_id = gen_unique_id()
        # Run the task under the generated id. The original passed the
        # builtin ``id`` here, so the readiness check below looked up an id
        # that was never executed and could not fail.
        ret = jail(task_id, MyTaskIgnoreResult.name,
                   [4], {})
        self.assertEqual(ret, 256)
        self.assertFalse(AsyncResult(task_id).ready())
class MockEventDispatcher(object):
    """Stand-in event dispatcher that just remembers which events were sent.

    Only the event name is kept, in order, in ``self.sent``; the event
    fields are discarded.
    """

    def __init__(self):
        self.sent = list()

    def send(self, event, **fields):
        """Record ``event``; ``fields`` are accepted and ignored."""
        self.sent.extend([event])
class test_TaskRequest(unittest.TestCase):
    """Exercise ``TaskRequest``: construction, events, revocation, acking,
    failure/timeout handling, message decoding and execution.

    Many tests toggle attributes on the shared ``mytask`` object; each one
    restores them in a ``finally`` so the tests stay order-independent.
    """

    def test_task_wrapper_repr(self):
        # repr() of a request must be non-empty.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        self.assertTrue(repr(tw))

    def test_sets_store_errors(self):
        # _store_errors is off when the task ignores results, unless
        # store_errors_even_if_ignored is also set.
        mytask.ignore_result = True
        try:
            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
            self.assertFalse(tw._store_errors)
            mytask.store_errors_even_if_ignored = True
            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
            self.assertTrue(tw._store_errors)
        finally:
            mytask.ignore_result = False
            mytask.store_errors_even_if_ignored = False

    def test_send_event(self):
        # send_event() forwards the event name to the attached eventer.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.eventer = MockEventDispatcher()
        tw.send_event("task-frobulated")
        self.assertIn("task-frobulated", tw.eventer.sent)

    def test_on_retry(self):
        # A RetryTaskError failure must emit a "task-retried" event.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.eventer = MockEventDispatcher()
        try:
            raise RetryTaskError("foo", KeyError("moofoobar"))
        except RetryTaskError:
            # Catch only what we raised (was a bare ``except:``).
            einfo = ExceptionInfo(sys.exc_info())
            tw.on_failure(einfo)
            self.assertIn("task-retried", tw.eventer.sent)

    def test_revoked_expires_expired(self):
        # An already-expired request is revoked and stored as REVOKED.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.expires = datetime.now() - timedelta(days=1)
        tw.revoked()
        self.assertIn(tw.task_id, revoked)
        self.assertEqual(mytask.backend.get_status(tw.task_id),
                         states.REVOKED)

    def test_revoked_expires_not_expired(self):
        # A request expiring in the future is left alone.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.expires = datetime.now() + timedelta(days=1)
        tw.revoked()
        self.assertNotIn(tw.task_id, revoked)
        self.assertNotEqual(mytask.backend.get_status(tw.task_id),
                            states.REVOKED)

    def test_revoked_expires_ignore_result(self):
        # Expired + ignore_result: revoked, but no REVOKED state is stored.
        mytask.ignore_result = True
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        try:
            tw.expires = datetime.now() - timedelta(days=1)
            tw.revoked()
            self.assertIn(tw.task_id, revoked)
            self.assertNotEqual(mytask.backend.get_status(tw.task_id),
                                states.REVOKED)
        finally:
            mytask.ignore_result = False

    def test_send_email(self):
        # Error e-mails are sent only when send_error_emails is on and the
        # exception type passes the whitelist being set.
        app = app_or_default()
        old_mail_admins = app.mail_admins
        old_enable_mails = mytask.send_error_emails
        # Remember the real whitelist so it can be restored afterwards
        # (the original reset it to a hard-coded empty tuple).
        old_whitelist = getattr(mytask, "error_whitelist", ())
        mail_sent = [False]  # one-element list so the closure can mutate it

        def mock_mail_admins(*args, **kwargs):
            mail_sent[0] = True

        app.mail_admins = mock_mail_admins
        mytask.send_error_emails = True
        try:
            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
            try:
                raise KeyError("moofoobar")
            except KeyError:
                # Catch only what we raised (was a bare ``except:``).
                einfo = ExceptionInfo(sys.exc_info())

            tw.on_failure(einfo)
            self.assertTrue(mail_sent[0])

            mail_sent[0] = False
            mytask.send_error_emails = False
            tw.on_failure(einfo)
            self.assertFalse(mail_sent[0])

            mail_sent[0] = False
            mytask.send_error_emails = True
            mytask.error_whitelist = [KeyError]
            tw.on_failure(einfo)
            self.assertTrue(mail_sent[0])

            mail_sent[0] = False
            mytask.send_error_emails = True
            mytask.error_whitelist = [SyntaxError]
            tw.on_failure(einfo)
            self.assertFalse(mail_sent[0])
        finally:
            app.mail_admins = old_mail_admins
            mytask.send_error_emails = old_enable_mails
            mytask.error_whitelist = old_whitelist

    def test_already_revoked(self):
        # The cached _already_revoked flag short-circuits revoked().
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw._already_revoked = True
        self.assertTrue(tw.revoked())

    def test_revoked(self):
        # An id in the revoked set makes revoked() true, sets the cached
        # flag and acknowledges the request.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        revoked.add(tw.task_id)
        self.assertTrue(tw.revoked())
        self.assertTrue(tw._already_revoked)
        self.assertTrue(tw.acknowledged)

    def test_execute_does_not_execute_revoked(self):
        # Smoke test: execute() on a revoked request must not raise.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        revoked.add(tw.task_id)
        tw.execute()

    def test_execute_acks_late(self):
        # With acks_late the request is acknowledged after execution,
        # even though the task raises.
        mytask_raising.acks_late = True
        tw = TaskRequest(mytask_raising.name, gen_unique_id(), [1], {"f": "x"})
        try:
            tw.execute()
            self.assertTrue(tw.acknowledged)
        finally:
            mytask_raising.acks_late = False

    def test_execute_using_pool_does_not_execute_revoked(self):
        # Smoke test: a revoked request never touches the pool (None here).
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        revoked.add(tw.task_id)
        tw.execute_using_pool(None)

    def test_on_accepted_acks_early(self):
        # Default (early ack): accepted implies acknowledged.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
        self.assertTrue(tw.acknowledged)

    def test_on_accepted_acks_late(self):
        # acks_late: accepting must not acknowledge yet.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        mytask.acks_late = True
        try:
            tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
            self.assertFalse(tw.acknowledged)
        finally:
            mytask.acks_late = False

    def test_on_success_acks_early(self):
        # Default (early ack): on_success itself does not acknowledge.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.time_start = 1
        tw.on_success(42)
        self.assertFalse(tw.acknowledged)

    def test_on_success_acks_late(self):
        # acks_late: on_success acknowledges.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.time_start = 1
        mytask.acks_late = True
        try:
            tw.on_success(42)
            self.assertTrue(tw.acknowledged)
        finally:
            mytask.acks_late = False

    def test_on_failure_WorkerLostError(self):
        # WorkerLostError is stored as FAILURE, unless results are ignored,
        # in which case the state stays PENDING.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        try:
            raise WorkerLostError("do re mi")
        except WorkerLostError:
            exc_info = ExceptionInfo(sys.exc_info())
            tw.on_failure(exc_info)
            self.assertEqual(mytask.backend.get_status(tw.task_id),
                             states.FAILURE)

        mytask.ignore_result = True
        try:
            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
            tw.on_failure(exc_info)
            self.assertEqual(mytask.backend.get_status(tw.task_id),
                             states.PENDING)
        finally:
            mytask.ignore_result = False

    def test_on_failure_acks_late(self):
        # acks_late: on_failure acknowledges.
        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.time_start = 1
        mytask.acks_late = True
        try:
            try:
                raise KeyError("foo")
            except KeyError:
                exc_info = ExceptionInfo(sys.exc_info())
                tw.on_failure(exc_info)
                self.assertTrue(tw.acknowledged)
        finally:
            mytask.acks_late = False

    def test_from_message_invalid_kwargs(self):
        # kwargs must be a mapping; a string is rejected.
        body = dict(task="foo", id=1, args=(), kwargs="foo")
        self.assertRaises(InvalidTaskError,
                          TaskRequest.from_message, None, body)

    def test_on_timeout(self):
        # Soft timeouts log a warning, hard timeouts log an error and store
        # FAILURE; with ignore_result nothing is stored.

        class MockLogger(object):
            # Collects the messages passed to warning() and error().

            def __init__(self):
                self.warnings = []
                self.errors = []

            def warning(self, msg, *args, **kwargs):
                self.warnings.append(msg)

            def error(self, msg, *args, **kwargs):
                self.errors.append(msg)

        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
        tw.logger = MockLogger()
        tw.on_timeout(soft=True)
        self.assertIn("Soft time limit exceeded", tw.logger.warnings[0])
        tw.on_timeout(soft=False)
        self.assertIn("Hard time limit exceeded", tw.logger.errors[0])
        self.assertEqual(mytask.backend.get_status(tw.task_id),
                         states.FAILURE)

        mytask.ignore_result = True
        try:
            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
            tw.logger = MockLogger()
        finally:
            mytask.ignore_result = False
        # NOTE(review): the flag is reset before on_timeout() runs, as in the
        # original statement order; this relies on the request having taken
        # its store-errors decision at construction (see
        # test_sets_store_errors) -- confirm against TaskRequest.
        tw.on_timeout(soft=True)
        self.assertEqual(mytask.backend.get_status(tw.task_id),
                         states.PENDING)

    def test_execute_and_trace(self):
        res = execute_and_trace(mytask.name, gen_unique_id(), [4], {})
        self.assertEqual(res, 4 ** 4)

    def test_execute_safe_catches_exception(self):
        # If the trace machinery itself blows up, execute_and_trace must
        # return an ExceptionInfo and emit a warning instead of raising.
        old_exec = WorkerTaskTrace.execute

        def _error_exec(self, *args, **kwargs):
            raise KeyError("baz")

        WorkerTaskTrace.execute = _error_exec
        try:

            def with_catch_warnings(log):
                res = execute_and_trace(mytask.name, gen_unique_id(),
                                        [4], {})
                self.assertIsInstance(res, ExceptionInfo)
                self.assertTrue(log)
                self.assertIn("Exception outside", log[0].message.args[0])
                self.assertIn("KeyError", log[0].message.args[0])

            # execute_context runs the callback with the context manager
            # entered (helper imported from celery.tests.utils).
            context = catch_warnings(record=True)
            execute_context(context, with_catch_warnings)
        finally:
            WorkerTaskTrace.execute = old_exec

    def create_exception(self, exc):
        # Raise and catch ``exc`` so that it carries a real traceback;
        # returns the (type, value, traceback) triple.
        try:
            raise exc
        except exc.__class__:
            return sys.exc_info()

    def test_worker_task_trace_handle_retry(self):
        # handle_retry stores RETRY only when _store_errors is set.
        # (RetryTaskError comes from the module-level import; the redundant
        # function-local import was dropped.)
        uuid = gen_unique_id()
        w = WorkerTaskTrace(mytask.name, uuid, [4], {})
        type_, value_, tb_ = self.create_exception(ValueError("foo"))
        type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_),
                                                                  exc=value_))
        w._store_errors = False
        w.handle_retry(value_, type_, tb_, "")
        self.assertEqual(mytask.backend.get_status(uuid), states.PENDING)
        w._store_errors = True
        w.handle_retry(value_, type_, tb_, "")
        self.assertEqual(mytask.backend.get_status(uuid), states.RETRY)

    def test_worker_task_trace_handle_failure(self):
        # handle_failure stores FAILURE only when _store_errors is set.
        uuid = gen_unique_id()
        w = WorkerTaskTrace(mytask.name, uuid, [4], {})
        type_, value_, tb_ = self.create_exception(ValueError("foo"))
        w._store_errors = False
        w.handle_failure(value_, type_, tb_, "")
        self.assertEqual(mytask.backend.get_status(uuid), states.PENDING)
        w._store_errors = True
        w.handle_failure(value_, type_, tb_, "")
        self.assertEqual(mytask.backend.get_status(uuid), states.FAILURE)

    def test_executed_bit(self):
        # The executed bit can be set exactly once.
        tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
        self.assertFalse(tw.executed)
        tw._set_executed_bit()
        self.assertTrue(tw.executed)
        self.assertRaises(AlreadyExecutedError, tw._set_executed_bit)

    def test_task_wrapper_mail_attrs(self):
        # The message templates must format with the keys given here.
        tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
        x = tw.success_msg % {"name": tw.task_name,
                              "id": tw.task_id,
                              "return_value": 10,
                              "runtime": 0.3641}
        self.assertTrue(x)
        x = tw.error_msg % {"name": tw.task_name,
                            "id": tw.task_id,
                            "exc": "FOOBARBAZ",
                            "traceback": "foobarbaz"}
        self.assertTrue(x)
        x = tw.email_subject % {"name": tw.task_name,
                                "id": tw.task_id,
                                "exc": "FOOBARBAZ",
                                "hostname": "lana"}
        self.assertTrue(x)

    def test_from_message(self):
        # Round-trip through a JSON message; unicode kwarg keys must come
        # back as UTF-8 encoded byte strings, not ``unicode``.
        body = {"task": mytask.name, "id": gen_unique_id(),
                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
        m = Message(None, body=anyjson.serialize(body), backend="foo",
                    content_type="application/json",
                    content_encoding="utf-8")
        tw = TaskRequest.from_message(m, m.decode())
        self.assertIsInstance(tw, TaskRequest)
        self.assertEqual(tw.task_name, body["task"])
        self.assertEqual(tw.task_id, body["id"])
        self.assertEqual(tw.args, body["args"])
        self.assertEqual(tw.kwargs.keys()[0],
                         u"æØåveéðƒeæ".encode("utf-8"))
        self.assertNotIsInstance(tw.kwargs.keys()[0], unicode)
        self.assertTrue(tw.logger)

    def test_from_message_nonexistant_task(self):
        # Unknown task names are rejected with NotRegistered.
        body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(),
                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
        m = Message(None, body=anyjson.serialize(body), backend="foo",
                    content_type="application/json",
                    content_encoding="utf-8")
        self.assertRaises(NotRegistered, TaskRequest.from_message,
                          m, m.decode())

    def test_execute(self):
        tid = gen_unique_id()
        tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
        self.assertEqual(tw.execute(), 256)
        meta = mytask.backend.get_task_meta(tid)
        self.assertEqual(meta["result"], 256)
        self.assertEqual(meta["status"], states.SUCCESS)

    def test_execute_success_no_kwargs(self):
        tid = gen_unique_id()
        tw = TaskRequest(mytask_no_kwargs.name, tid, [4], {})
        self.assertEqual(tw.execute(), 256)
        meta = mytask_no_kwargs.backend.get_task_meta(tid)
        self.assertEqual(meta["result"], 256)
        self.assertEqual(meta["status"], states.SUCCESS)

    def test_execute_success_some_kwargs(self):
        # The logfile given to execute() must reach the task function.
        tid = gen_unique_id()
        tw = TaskRequest(mytask_some_kwargs.name, tid, [4], {})
        self.assertEqual(tw.execute(logfile="foobaz.log"), 256)
        meta = mytask_some_kwargs.backend.get_task_meta(tid)
        self.assertEqual(some_kwargs_scratchpad.get("logfile"), "foobaz.log")
        self.assertEqual(meta["result"], 256)
        self.assertEqual(meta["status"], states.SUCCESS)

    def test_execute_ack(self):
        # Clear the shared flag first so a stale True left by an earlier
        # run cannot make this test pass by accident.
        scratch["ACK"] = False
        tid = gen_unique_id()
        tw = TaskRequest(mytask.name, tid, [4], {"f": "x"},
                         on_ack=on_ack)
        self.assertEqual(tw.execute(), 256)
        meta = mytask.backend.get_task_meta(tid)
        self.assertTrue(scratch["ACK"])
        self.assertEqual(meta["result"], 256)
        self.assertEqual(meta["status"], states.SUCCESS)

    def test_execute_fail(self):
        tid = gen_unique_id()
        tw = TaskRequest(mytask_raising.name, tid, [4], {"f": "x"})
        self.assertIsInstance(tw.execute(), ExceptionInfo)
        meta = mytask_raising.backend.get_task_meta(tid)
        self.assertEqual(meta["status"], states.FAILURE)
        self.assertIsInstance(meta["result"], KeyError)

    def test_execute_using_pool(self):
        # execute_using_pool must hand (name, id, args, kwargs, ...) to the
        # pool's apply_async.
        tid = gen_unique_id()
        tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})

        class MockPool(BasePool):
            # Records the arguments of the last apply_async call.
            target = None
            args = None
            kwargs = None

            def __init__(self, *args, **kwargs):
                pass

            def apply_async(self, target, args=None, kwargs=None,
                            *margs, **mkwargs):
                self.target = target
                self.args = args
                self.kwargs = kwargs

        p = MockPool()
        tw.execute_using_pool(p)
        self.assertTrue(p.target)
        self.assertEqual(p.args[0], mytask.name)
        self.assertEqual(p.args[1], tid)
        self.assertEqual(p.args[2], [4])
        self.assertIn("f", p.args[3])
        self.assertIn([4], p.args)

    def test_default_kwargs(self):
        # extend_with_default_kwargs merges the request kwargs with the
        # default keyword arguments listed below.
        tid = gen_unique_id()
        tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
        self.assertDictEqual(
                tw.extend_with_default_kwargs(10, "some_logfile"), {
                    "f": "x",
                    "logfile": "some_logfile",
                    "loglevel": 10,
                    "task_id": tw.task_id,
                    "task_retries": 0,
                    "task_is_eager": False,
                    "delivery_info": {},
                    "task_name": tw.task_name})

    def _test_on_failure(self, exception):
        # Shared driver: fail a request with ``exception`` and check that
        # the task name, task id and "ERROR" all end up in the log output.
        app = app_or_default()
        tid = gen_unique_id()
        tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
        try:
            raise exception
        except Exception:
            exc_info = ExceptionInfo(sys.exc_info())

        logfh = StringIO()
        tw.logger.handlers = []
        tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO,
                                 root=False)

        # Restore the previous setting even when an assertion fails, so a
        # failure here cannot leak the flag into other tests. Falls back to
        # False (the value the original always reset to) if it is unset.
        prev = getattr(app.conf, "CELERY_SEND_TASK_ERROR_EMAILS", False)
        app.conf.CELERY_SEND_TASK_ERROR_EMAILS = True
        try:
            tw.on_failure(exc_info)
            logvalue = logfh.getvalue()
            self.assertIn(mytask.name, logvalue)
            self.assertIn(tid, logvalue)
            self.assertIn("ERROR", logvalue)
        finally:
            app.conf.CELERY_SEND_TASK_ERROR_EMAILS = prev

    def test_on_failure(self):
        self._test_on_failure(Exception("Inside unit tests"))

    def test_on_failure_unicode_exception(self):
        self._test_on_failure(Exception(u"Бобры атакуют"))

    def test_on_failure_utf8_exception(self):
        self._test_on_failure(Exception(
            u"Бобры атакуют".encode('utf8')))
|