# -*- coding: utf-8 -*- import sys import logging import unittest import simplejson from StringIO import StringIO from django.core import cache from carrot.backends.base import BaseMessage from celery import states from celery.log import setup_logger from celery.task.base import Task from celery.utils import gen_unique_id from celery.models import TaskMeta from celery.result import AsyncResult from celery.worker.job import WorkerTaskTrace, TaskWrapper from celery.worker.pool import TaskPool from celery.exceptions import RetryTaskError, NotRegistered from celery.decorators import task as task_dec from celery.datastructures import ExceptionInfo from testunits.utils import execute_context from testunits.compat import catch_warnings scratch = {"ACK": False} some_kwargs_scratchpad = {} def jail(task_id, task_name, args, kwargs): return WorkerTaskTrace(task_name, task_id, args, kwargs)() def on_ack(): scratch["ACK"] = True @task_dec() def mytask(i, **kwargs): return i ** i @task_dec() def mytask_no_kwargs(i): return i ** i class MyTaskIgnoreResult(Task): ignore_result = True def run(self, i): return i ** i @task_dec() def mytask_some_kwargs(i, logfile): some_kwargs_scratchpad["logfile"] = logfile return i ** i @task_dec() def mytask_raising(i, **kwargs): raise KeyError(i) @task_dec() def get_db_connection(i, **kwargs): from django.db import connection return id(connection) get_db_connection.ignore_result = True class TestRetryTaskError(unittest.TestCase): def test_retry_task_error(self): try: raise Exception("foo") except Exception, exc: ret = RetryTaskError("Retrying task", exc) self.assertEquals(ret.exc, exc) class TestJail(unittest.TestCase): def test_execute_jail_success(self): ret = jail(gen_unique_id(), mytask.name, [2], {}) self.assertEquals(ret, 4) def test_execute_jail_failure(self): ret = jail(gen_unique_id(), mytask_raising.name, [4], {}) self.assertTrue(isinstance(ret, ExceptionInfo)) self.assertEquals(ret.exception.args, (4, )) def test_execute_ignore_result(self): task_id = gen_unique_id() ret = jail(id, MyTaskIgnoreResult.name, [4], {}) self.assertTrue(ret, 8) self.assertFalse(AsyncResult(task_id).ready()) def test_django_db_connection_is_closed(self): from django.db import connection connection._was_closed = False old_connection_close = connection.close def monkeypatched_connection_close(*args, **kwargs): connection._was_closed = True return old_connection_close(*args, **kwargs) connection.close = monkeypatched_connection_close try: jail(gen_unique_id(), get_db_connection.name, [2], {}) self.assertTrue(connection._was_closed) finally: connection.close = old_connection_close def test_django_cache_connection_is_closed(self): old_cache_close = getattr(cache.cache, "close", None) old_backend = cache.settings.CACHE_BACKEND cache.settings.CACHE_BACKEND = "libmemcached" cache._was_closed = False old_cache_parse_backend = getattr(cache, "parse_backend_uri", None) if old_cache_parse_backend: # checks to make sure attr exists delattr(cache, 'parse_backend_uri') def monkeypatched_cache_close(*args, **kwargs): cache._was_closed = True cache.cache.close = monkeypatched_cache_close jail(gen_unique_id(), mytask.name, [4], {}) self.assertTrue(cache._was_closed) cache.cache.close = old_cache_close cache.settings.CACHE_BACKEND = old_backend if old_cache_parse_backend: cache.parse_backend_uri = old_cache_parse_backend def test_django_cache_connection_is_closed_django_1_1(self): old_cache_close = getattr(cache.cache, "close", None) old_backend = cache.settings.CACHE_BACKEND cache.settings.CACHE_BACKEND = "libmemcached" cache._was_closed = False old_cache_parse_backend = getattr(cache, "parse_backend_uri", None) cache.parse_backend_uri = lambda uri: ["libmemcached", "1", "2"] def monkeypatched_cache_close(*args, **kwargs): cache._was_closed = True cache.cache.close = monkeypatched_cache_close jail(gen_unique_id(), mytask.name, [4], {}) self.assertTrue(cache._was_closed) cache.cache.close = old_cache_close cache.settings.CACHE_BACKEND = old_backend if old_cache_parse_backend: cache.parse_backend_uri = old_cache_parse_backend else: del(cache.parse_backend_uri) class MockEventDispatcher(object): def __init__(self): self.sent = [] def send(self, event): self.sent.append(event) class TestTaskWrapper(unittest.TestCase): def test_task_wrapper_repr(self): tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"}) self.assertTrue(repr(tw)) def test_send_event(self): tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"}) tw.eventer = MockEventDispatcher() tw.send_event("task-frobulated") self.assertTrue("task-frobulated" in tw.eventer.sent) def test_send_email(self): from celery import conf from celery.worker import job old_mail_admins = job.mail_admins old_enable_mails = conf.CELERY_SEND_TASK_ERROR_EMAILS mail_sent = [False] def mock_mail_admins(*args, **kwargs): mail_sent[0] = True job.mail_admins = mock_mail_admins conf.CELERY_SEND_TASK_ERROR_EMAILS = True try: tw = TaskWrapper(mytask.name, gen_unique_id(), [1], {"f": "x"}) try: raise KeyError("foo") except KeyError, exc: einfo = ExceptionInfo(sys.exc_info()) tw.on_failure(einfo) self.assertTrue(mail_sent[0]) mail_sent[0] = False conf.CELERY_SEND_TASK_ERROR_EMAILS = False tw.on_failure(einfo) self.assertFalse(mail_sent[0]) finally: job.mail_admins = old_mail_admins conf.CELERY_SEND_TASK_ERROR_EMAILS = old_enable_mails def test_execute_and_trace(self): from celery.worker.job import execute_and_trace res = execute_and_trace(mytask.name, gen_unique_id(), [4], {}) self.assertEquals(res, 4 ** 4) def test_execute_safe_catches_exception(self): from celery.worker.job import execute_and_trace, WorkerTaskTrace old_exec = WorkerTaskTrace.execute def _error_exec(self, *args, **kwargs): raise KeyError("baz") WorkerTaskTrace.execute = _error_exec try: def with_catch_warnings(log): res = execute_and_trace(mytask.name, gen_unique_id(), [4], {}) self.assertTrue(isinstance(res, ExceptionInfo)) self.assertTrue(log) self.assertTrue("Exception outside" in log[0].message.args[0]) self.assertTrue("KeyError" in log[0].message.args[0]) context = catch_warnings(record=True) execute_context(context, with_catch_warnings) finally: WorkerTaskTrace.execute = old_exec def create_exception(self, exc): try: raise exc except exc.__class__, thrown: return sys.exc_info() def test_worker_task_trace_handle_retry(self): from celery.exceptions import RetryTaskError uuid = gen_unique_id() w = WorkerTaskTrace(mytask.name, uuid, [4], {}) type_, value_, tb_ = self.create_exception(ValueError("foo")) type_, value_, tb_ = self.create_exception(RetryTaskError(str(value_), exc=value_)) w._store_errors = False w.handle_retry(value_, type_, tb_, "") self.assertEquals(mytask.backend.get_status(uuid), states.PENDING) w._store_errors = True w.handle_retry(value_, type_, tb_, "") self.assertEquals(mytask.backend.get_status(uuid), states.RETRY) def test_worker_task_trace_handle_failure(self): from celery.worker.job import WorkerTaskTrace uuid = gen_unique_id() w = WorkerTaskTrace(mytask.name, uuid, [4], {}) type_, value_, tb_ = self.create_exception(ValueError("foo")) w._store_errors = False w.handle_failure(value_, type_, tb_, "") self.assertEquals(mytask.backend.get_status(uuid), states.PENDING) w._store_errors = True w.handle_failure(value_, type_, tb_, "") self.assertEquals(mytask.backend.get_status(uuid), states.FAILURE) def test_executed_bit(self): from celery.worker.job import AlreadyExecutedError tw = TaskWrapper(mytask.name, gen_unique_id(), [], {}) self.assertFalse(tw.executed) tw._set_executed_bit() self.assertTrue(tw.executed) self.assertRaises(AlreadyExecutedError, tw._set_executed_bit) def test_task_wrapper_mail_attrs(self): tw = TaskWrapper(mytask.name, gen_unique_id(), [], {}) x = tw.success_msg % {"name": tw.task_name, "id": tw.task_id, "return_value": 10} self.assertTrue(x) x = tw.fail_msg % {"name": tw.task_name, "id": tw.task_id, "exc": "FOOBARBAZ", "traceback": "foobarbaz"} self.assertTrue(x) x = tw.fail_email_subject % {"name": tw.task_name, "id": tw.task_id, "exc": "FOOBARBAZ", "hostname": "lana"} self.assertTrue(x) def test_from_message(self): body = {"task": mytask.name, "id": gen_unique_id(), "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}} m = BaseMessage(body=simplejson.dumps(body), backend="foo", content_type="application/json", content_encoding="utf-8") tw = TaskWrapper.from_message(m, m.decode()) self.assertTrue(isinstance(tw, TaskWrapper)) self.assertEquals(tw.task_name, body["task"]) self.assertEquals(tw.task_id, body["id"]) self.assertEquals(tw.args, body["args"]) self.assertEquals(tw.kwargs.keys()[0], u"æØåveéðƒeæ".encode("utf-8")) self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode)) self.assertTrue(tw.logger) def test_from_message_nonexistant_task(self): body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(), "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}} m = BaseMessage(body=simplejson.dumps(body), backend="foo", content_type="application/json", content_encoding="utf-8") self.assertRaises(NotRegistered, TaskWrapper.from_message, m, m.decode()) def test_execute(self): tid = gen_unique_id() tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"}) self.assertEquals(tw.execute(), 256) meta = TaskMeta.objects.get(task_id=tid) self.assertEquals(meta.result, 256) self.assertEquals(meta.status, states.SUCCESS) def test_execute_success_no_kwargs(self): tid = gen_unique_id() tw = TaskWrapper(mytask_no_kwargs.name, tid, [4], {}) self.assertEquals(tw.execute(), 256) meta = TaskMeta.objects.get(task_id=tid) self.assertEquals(meta.result, 256) self.assertEquals(meta.status, states.SUCCESS) def test_execute_success_some_kwargs(self): tid = gen_unique_id() tw = TaskWrapper(mytask_some_kwargs.name, tid, [4], {}) self.assertEquals(tw.execute(logfile="foobaz.log"), 256) meta = TaskMeta.objects.get(task_id=tid) self.assertEquals(some_kwargs_scratchpad.get("logfile"), "foobaz.log") self.assertEquals(meta.result, 256) self.assertEquals(meta.status, states.SUCCESS) def test_execute_ack(self): tid = gen_unique_id() tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"}, on_ack=on_ack) self.assertEquals(tw.execute(), 256) meta = TaskMeta.objects.get(task_id=tid) self.assertTrue(scratch["ACK"]) self.assertEquals(meta.result, 256) self.assertEquals(meta.status, states.SUCCESS) def test_execute_fail(self): tid = gen_unique_id() tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"}) self.assertTrue(isinstance(tw.execute(), ExceptionInfo)) meta = TaskMeta.objects.get(task_id=tid) self.assertEquals(meta.status, states.FAILURE) self.assertTrue(isinstance(meta.result, KeyError)) def test_execute_using_pool(self): tid = gen_unique_id() tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"}) p = TaskPool(2) p.start() asyncres = tw.execute_using_pool(p) self.assertTrue(asyncres.get(), 256) p.stop() def test_default_kwargs(self): tid = gen_unique_id() tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"}) self.assertEquals(tw.extend_with_default_kwargs(10, "some_logfile"), { "f": "x", "logfile": "some_logfile", "loglevel": 10, "task_id": tw.task_id, "task_retries": 0, "task_is_eager": False, "delivery_info": {}, "task_name": tw.task_name}) def test_on_failure(self): tid = gen_unique_id() tw = TaskWrapper(mytask.name, tid, [4], {"f": "x"}) try: raise Exception("Inside unit tests") except Exception: exc_info = ExceptionInfo(sys.exc_info()) logfh = StringIO() tw.logger.handlers = [] tw.logger = setup_logger(logfile=logfh, loglevel=logging.INFO) from celery import conf conf.CELERY_SEND_TASK_ERROR_EMAILS = True tw.on_failure(exc_info) logvalue = logfh.getvalue() self.assertTrue(mytask.name in logvalue) self.assertTrue(tid in logvalue) self.assertTrue("ERROR" in logvalue) conf.CELERY_SEND_TASK_ERROR_EMAILS = False