Browse Source

Unittests for celery.job: TaskWrapper + jail

Ask Solem 15 years ago
parent
commit
51d5d5a22b
4 changed files with 155 additions and 12 deletions
  1. 145 0
      celery/tests/test_worker_job.py
  2. 3 2
      celery/worker/__init__.py
  3. 5 10
      celery/worker/job.py
  4. 2 0
      testproj/settings.py

+ 145 - 0
celery/tests/test_worker_job.py

@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+import unittest
+from celery.worker.job import jail
+from celery.worker.job import TaskWrapper
+from celery.datastructures import ExceptionInfo
+from celery.models import TaskMeta
+from celery.registry import tasks, NotRegistered
+from celery.pool import TaskPool
+from uuid import uuid4
+from carrot.backends.base import BaseMessage
+import simplejson
+
+
+def mytask(i, **kwargs):
+    return i ** i
+tasks.register(mytask, name="cu.mytask")
+
+
+def mytask_raising(i, **kwargs):
+    raise KeyError(i)
+tasks.register(mytask_raising, name="cu.mytask-raising")
+
+
+def get_db_connection(i, **kwargs):
+    from django.db import connection
+    return id(connection)
+get_db_connection.ignore_result = True
+
+
+class TestJail(unittest.TestCase):
+
+    def test_execute_jail_success(self):
+        ret = jail(str(uuid4()), str(uuid4()), mytask, [2], {})
+        self.assertEquals(ret, 4)
+    
+    def test_execute_jail_failure(self):
+        ret = jail(str(uuid4()), str(uuid4()), mytask_raising, [4], {})
+        self.assertTrue(isinstance(ret, ExceptionInfo))
+        self.assertEquals(ret.exception.args, (4, ))
+
+    def test_django_db_connection_is_closed(self):
+        from django.db import connection
+        connection._was_closed = False
+        old_connection_close = connection.close
+
+        def monkeypatched_connection_close(*args, **kwargs):
+            connection._was_closed = True
+            return old_connection_close(*args, **kwargs)
+
+        connection.close = monkeypatched_connection_close
+
+        ret = jail(str(uuid4()), str(uuid4()), get_db_connection, [2], {})
+        self.assertTrue(connection._was_closed)
+
+        connection.close = old_connection_close
+
+
+class TestTaskWrapper(unittest.TestCase):
+
+    def test_task_wrapper_attrs(self):
+        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [1], {"f":"x"})
+        for attr in ("task_name", "task_id", "args", "kwargs", "logger"):
+            self.assertTrue(getattr(tw, attr, None))
+
+    def test_task_wrapper_repr(self):
+        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [1], {"f":"x"})
+        self.assertTrue(repr(tw))
+
+    def test_task_wrapper_mail_attrs(self):
+        tw = TaskWrapper(str(uuid4()), str(uuid4()), mytask, [], {})
+        x = tw.success_msg % {"name": tw.task_name,
+                              "id": tw.task_id,
+                              "return_value": 10}
+        self.assertTrue(x)
+        x = tw.fail_msg % {"name": tw.task_name,
+                           "id": tw.task_id,
+                           "exc": "FOOBARBAZ",
+                           "traceback": "foobarbaz"}
+        self.assertTrue(x)
+        x = tw.fail_email_subject % {"name": tw.task_name,
+                                     "id": tw.task_id,
+                                     "exc": "FOOBARBAZ",
+                                     "hostname": "lana"}
+        self.assertTrue(x)
+
+    def test_from_message(self):
+        body = {"task": "cu.mytask", "id": str(uuid4()),
+                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
+        m = BaseMessage(body=simplejson.dumps(body), backend="foo",
+                        content_type="application/json",
+                        content_encoding="utf-8")
+        tw = TaskWrapper.from_message(m, m.decode())
+        self.assertTrue(isinstance(tw, TaskWrapper))
+        self.assertEquals(tw.task_name, body["task"])
+        self.assertEquals(tw.task_id, body["id"])
+        self.assertEquals(tw.args, body["args"])
+        self.assertEquals(tw.kwargs.keys()[0], u"æØåveéðƒeæ".encode("utf-8"))
+        self.assertFalse(isinstance(tw.kwargs.keys()[0], unicode))
+        self.assertEquals(id(mytask), id(tw.task_func))
+        self.assertTrue(tw.logger)
+
+    def test_from_message_nonexistant_task(self):
+        body = {"task": "cu.mytask.doesnotexist", "id": str(uuid4()),
+                "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
+        m = BaseMessage(body=simplejson.dumps(body), backend="foo",
+                        content_type="application/json",
+                        content_encoding="utf-8")
+        self.assertRaises(NotRegistered, TaskWrapper.from_message,
+                          m, m.decode())
+
+    def test_execute(self):
+        tid = str(uuid4())
+        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
+        self.assertEquals(tw.execute(), 256)
+        meta = TaskMeta.objects.get(task_id=tid)
+        self.assertEquals(meta.result, 256)
+        self.assertEquals(meta.status, "DONE")
+
+    def test_execute_fail(self):
+        tid = str(uuid4())
+        tw = TaskWrapper("cu.mytask-raising", tid, mytask_raising, [4],
+                         {"f":"x"})
+        self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
+        meta = TaskMeta.objects.get(task_id=tid)
+        self.assertEquals(meta.status, "FAILURE")
+        self.assertTrue(isinstance(meta.result, KeyError))
+
+    def test_execute_using_pool(self):
+        tid = str(uuid4())
+        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
+        p = TaskPool(2)
+        p.start()
+        asyncres = tw.execute_using_pool(p)
+        self.assertTrue(asyncres.get(), 256)
+        p.stop()
+
+    def test_default_kwargs(self):
+        tid = str(uuid4())
+        tw = TaskWrapper("cu.mytask", tid, mytask, [4], {"f":"x"})
+        self.assertEquals(tw.extend_with_default_kwargs(10, "some_logfile"), {
+            "f": "x",
+            "logfile": "some_logfile",
+            "loglevel": 10,
+            "task_id": tw.task_id,
+            "task_name": tw.task_name})

+ 3 - 2
celery/worker/__init__.py

@@ -7,7 +7,8 @@ Documentation for this module is in ``docs/reference/celery.worker.rst``.
 """
 from carrot.connection import DjangoAMQPConnection
 from celery.worker.controllers import Mediator, PeriodicWorkController
-from celery.worker.job import TaskWrapper, UnknownTaskError
+from celery.worker.job import TaskWrapper
+from celery.registry import NotRegistered
 from celery.messaging import TaskConsumer
 from celery.conf import DAEMON_CONCURRENCY, DAEMON_LOG_FILE
 from celery.log import setup_logger
@@ -209,7 +210,7 @@ class WorkController(object):
                 # execute_next_task didn't return a r/name/id tuple,
                 # probably because it got an exception.
                 pass
-            except UnknownTaskError, exc:
+            except NotRegistered, exc:
                 self.logger.info("Unknown task ignored: %s" % (exc))
             except Exception, exc:
                 self.logger.critical("Message queue raised %s: %s\n%s" % (

+ 5 - 10
celery/worker/job.py

@@ -4,7 +4,7 @@ Jobs Executable by the Worker Server.
 
 """
 from celery.conf import SEND_CELERY_TASK_ERROR_EMAILS
-from celery.registry import tasks
+from celery.registry import tasks, NotRegistered
 from celery.datastructures import ExceptionInfo
 from celery.backends import default_backend
 from django.core.mail import mail_admins
@@ -33,11 +33,6 @@ celeryd at %%(hostname)s.
 """ % {"EMAIL_SIGNATURE_SEP": EMAIL_SIGNATURE_SEP}
 
 
-class UnknownTaskError(Exception):
-    """Got an unknown task in the queue. The message is requeued and
-    ignored."""
-
-
 def jail(task_id, task_name, func, args, kwargs):
     """Wraps the task in a jail, which catches all exceptions, and
     saves the status and result of the task execution to the task
@@ -172,7 +167,7 @@ class TaskWrapper(object):
                 self.args, self.kwargs)
 
     @classmethod
-    def from_message(cls, message, message_data, logger):
+    def from_message(cls, message, message_data, logger=None):
         """Create a :class:`TaskWrapper` from a task message sent by
         :class:`celery.messaging.TaskPublisher`.
 
@@ -192,7 +187,7 @@ class TaskWrapper(object):
                     for key, value in kwargs.items()])
 
         if task_name not in tasks:
-            raise UnknownTaskError(task_name)
+            raise NotRegistered(task_name)
         task_func = tasks[task_name]
         return cls(task_name, task_id, task_func, args, kwargs,
                     on_acknowledge=message.ack, logger=logger)
@@ -222,8 +217,8 @@ class TaskWrapper(object):
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         if self.on_acknowledge:
             self.on_acknowledge()
-        return jail(self.task_id, self.task_name, [
-                        self.task_func, self.args, task_func_kwargs])
+        return jail(self.task_id, self.task_name, self.task_func,
+                    self.args, task_func_kwargs)
 
     def on_success(self, ret_value, meta):
         """The handler used if the task was successfully processed (

+ 2 - 0
testproj/settings.py

@@ -54,3 +54,5 @@ except ImportError:
     pass
 else:
     INSTALLED_APPS += ("test_extensions", )
+
+SEND_CELERY_TASK_ERROR_EMAILS = False