|  | @@ -0,0 +1,145 @@
 | 
	
		
			
				|  |  | +# -*- coding: utf-8 -*-
 | 
	
		
			
				|  |  | +import unittest
 | 
	
		
			
				|  |  | +from celery.worker.job import jail
 | 
	
		
			
				|  |  | +from celery.worker.job import TaskWrapper
 | 
	
		
			
				|  |  | +from celery.datastructures import ExceptionInfo
 | 
	
		
			
				|  |  | +from celery.models import TaskMeta
 | 
	
		
			
				|  |  | +from celery.registry import tasks, NotRegistered
 | 
	
		
			
				|  |  | +from celery.pool import TaskPool
 | 
	
		
			
				|  |  | +from uuid import uuid4
 | 
	
		
			
				|  |  | +from carrot.backends.base import BaseMessage
 | 
	
		
			
				|  |  | +import simplejson
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def mytask(i, **kwargs):
 | 
	
		
			
				|  |  | +    return i ** i
 | 
	
		
			
				|  |  | +tasks.register(mytask, name="cu.mytask")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def mytask_raising(i, **kwargs):
 | 
	
		
			
				|  |  | +    raise KeyError(i)
 | 
	
		
			
				|  |  | +tasks.register(mytask_raising, name="cu.mytask-raising")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def get_db_connection(i, **kwargs):
 | 
	
		
			
				|  |  | +    from django.db import connection
 | 
	
		
			
				|  |  | +    return id(connection)
 | 
	
		
			
				|  |  | +get_db_connection.ignore_result = True
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class TestJail(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_execute_jail_success(self):
 | 
	
		
			
				|  |  | +        ret = jail(str(uuid4()), str(uuid4()), mytask, [2], {})
 | 
	
		
			
				|  |  | +        self.assertEquals(ret, 4)
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    def test_execute_jail_failure(self):
 | 
	
		
			
				|  |  | +        ret = jail(str(uuid4()), str(uuid4()), mytask_raising, [4], {})
 | 
	
		
			
				|  |  | +        self.assertTrue(isinstance(ret, ExceptionInfo))
 | 
	
		
			
				|  |  | +        self.assertEquals(ret.exception.args, (4, ))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_django_db_connection_is_closed(self):
 | 
	
		
			
				|  |  | +        from django.db import connection
 | 
	
		
			
				|  |  | +        connection._was_closed = False
 | 
	
		
			
				|  |  | +        old_connection_close = connection.close
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def monkeypatched_connection_close(*args, **kwargs):
 | 
	
		
			
				|  |  | +            connection._was_closed = True
 | 
	
		
			
				|  |  | +            return old_connection_close(*args, **kwargs)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        connection.close = monkeypatched_connection_close
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ret = jail(str(uuid4()), str(uuid4()), get_db_connection, [2], {})
 | 
	
		
			
				|  |  | +        self.assertTrue(connection._was_closed)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        connection.close = old_connection_close
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +class TestTaskWrapper(unittest.TestCase):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_task_wrapper_attrs(self):
 | 
	
		
			
				|  |  | +        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [1], {"f":"x"})
 | 
	
		
			
				|  |  | +        for attr in ("task_name", "task_id", "args", "kwargs", "logger"):
 | 
	
		
			
				|  |  | +            self.assertTrue(getattr(tw, attr, None))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_task_wrapper_repr(self):
 | 
	
		
			
				|  |  | +        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [1], {"f":"x"})
 | 
	
		
			
				|  |  | +        self.assertTrue(repr(tw))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_task_wrapper_mail_attrs(self):
 | 
	
		
			
				|  |  | +        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [], {})
 | 
	
		
			
				|  |  | +        x = tw.success_msg % {"name": tw.task_name,
 | 
	
		
			
				|  |  | +                              "id": tw.task_id,
 | 
	
		
			
				|  |  | +                              "return_value": 10}
 | 
	
		
			
				|  |  | +        self.assertTrue(x)
 | 
	
		
			
				|  |  | +        x = tw.fail_msg % {"name": tw.task_name,
 | 
	
		
			
				|  |  | +                           "id": tw.task_id,
 | 
	
		
			
				|  |  | +                           "exc": "FOOBARBAZ",
 | 
	
		
			
				|  |  | +                           "traceback": "foobarbaz"}
 | 
	
		
			
				|  |  | +        self.assertTrue(x)
 | 
	
		
			
				|  |  | +        x = tw.fail_email_subject % {"name": tw.task_name,
 | 
	
		
			
				|  |  | +                                     "id": tw.task_id,
 | 
	
		
			
				|  |  | +                                     "exc": "FOOBARBAZ",
 | 
	
		
			
				|  |  | +                                     "hostname": "lana"}
 | 
	
		
			
				|  |  | +        self.assertTrue(x)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_from_message(self):
 | 
	
		
			
				|  |  | +        body = {"task": "cu.mytask", "id": str(uuid4()),
 | 
	
		
			
				|  |  | +                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
 | 
	
		
			
				|  |  | +        m = BaseMessage(body=simplejson.dumps(body), backend="foo",
 | 
	
		
			
				|  |  | +                        content_type="application/json",
 | 
	
		
			
				|  |  | +                        content_encoding="utf-8")
 | 
	
		
			
				|  |  | +        tw = TaskWrapper.from_message(m, m.decode())
 | 
	
		
			
				|  |  | +        self.assertTrue(isinstance(tw, TaskWrapper))
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.task_name, body["task"])
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.task_id, body["id"])
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.args, body["args"])
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.kwargs.keys()[0], u"æØåveéðƒeæ".encode("utf-8"))
 | 
	
		
			
				|  |  | +        self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
 | 
	
		
			
				|  |  | +        self.assertEquals(id(mytask), id(tw.task_func))
 | 
	
		
			
				|  |  | +        self.assertTrue(tw.logger)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_from_message_nonexistant_task(self):
 | 
	
		
			
				|  |  | +        body = {"task": "cu.mytask.doesnotexist", "id": str(uuid4()),
 | 
	
		
			
				|  |  | +                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
 | 
	
		
			
				|  |  | +        m = BaseMessage(body=simplejson.dumps(body), backend="foo",
 | 
	
		
			
				|  |  | +                        content_type="application/json",
 | 
	
		
			
				|  |  | +                        content_encoding="utf-8")
 | 
	
		
			
				|  |  | +        self.assertRaises(NotRegistered, TaskWrapper.from_message,
 | 
	
		
			
				|  |  | +                          m, m.decode())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_execute(self):
 | 
	
		
			
				|  |  | +        tid = str(uuid4())
 | 
	
		
			
				|  |  | +        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.execute(), 256)
 | 
	
		
			
				|  |  | +        meta = TaskMeta.objects.get(task_id=tid)
 | 
	
		
			
				|  |  | +        self.assertEquals(meta.result, 256)
 | 
	
		
			
				|  |  | +        self.assertEquals(meta.status, "DONE")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_execute_fail(self):
 | 
	
		
			
				|  |  | +        tid = str(uuid4())
 | 
	
		
			
				|  |  | +        tw = TaskWrapper("cu.mytask-raising", tid, mytask_raising, [4],
 | 
	
		
			
				|  |  | +                         {"f":"x"})
 | 
	
		
			
				|  |  | +        self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
 | 
	
		
			
				|  |  | +        meta = TaskMeta.objects.get(task_id=tid)
 | 
	
		
			
				|  |  | +        self.assertEquals(meta.status, "FAILURE")
 | 
	
		
			
				|  |  | +        self.assertTrue(isinstance(meta.result, KeyError))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_execute_using_pool(self):
 | 
	
		
			
				|  |  | +        tid = str(uuid4())
 | 
	
		
			
				|  |  | +        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
 | 
	
		
			
				|  |  | +        p = TaskPool(2)
 | 
	
		
			
				|  |  | +        p.start()
 | 
	
		
			
				|  |  | +        asyncres = tw.execute_using_pool(p)
 | 
	
		
			
				|  |  | +        self.assertTrue(asyncres.get(), 256)
 | 
	
		
			
				|  |  | +        p.stop()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def test_default_kwargs(self):
 | 
	
		
			
				|  |  | +        tid = str(uuid4())
 | 
	
		
			
				|  |  | +        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
 | 
	
		
			
				|  |  | +        self.assertEquals(tw.extend_with_default_kwargs(10, "some_logfile"), {
 | 
	
		
			
				|  |  | +            "f": "x",
 | 
	
		
			
				|  |  | +            "logfile": "some_logfile",
 | 
	
		
			
				|  |  | +            "loglevel": 10,
 | 
	
		
			
				|  |  | +            "task_id": tw.task_id,
 | 
	
		
			
				|  |  | +            "task_name": tw.task_name})
 |