Browse Source

100% coverage for celery.worker.job

Ask Solem 15 years ago
parent
commit
baa329b77c
3 changed files with 127 additions and 18 deletions
  1. 2 2
      celery/task/builtins.py
  2. 124 15
      celery/tests/test_worker_job.py
  3. 1 1
      celery/worker/job.py

+ 2 - 2
celery/task/builtins.py

@@ -19,8 +19,8 @@ class DeleteExpiredTaskMetaTask(PeriodicTask):
     def run(self, **kwargs):
         """:returns: None"""
         logger = self.get_logger(**kwargs)
-        logger.info("Deleting expired task meta objects...")
-        default_backend.cleanup()
+        logger.info("Deleting expired task results...")
+        self.backend.cleanup()
 
 
 class PingTask(Task):

+ 124 - 15
celery/tests/test_worker_job.py

@@ -1,25 +1,30 @@
 # -*- coding: utf-8 -*-
-import sys
 import logging
-import unittest2 as unittest
 import simplejson
+import sys
+import unittest2 as unittest
+
 from StringIO import StringIO
 
 from carrot.backends.base import BaseMessage
 
 from celery import states
+from celery.backends import default_backend
+from celery.datastructures import ExceptionInfo
+from celery.decorators import task as task_dec
+from celery.exceptions import RetryTaskError, NotRegistered
 from celery.log import setup_logger
+from celery.registry import tasks
+from celery.result import AsyncResult
 from celery.task.base import Task
 from celery.utils import gen_unique_id
-from celery.result import AsyncResult
 from celery.worker.job import WorkerTaskTrace, TaskRequest
-from celery.backends import default_backend
-from celery.exceptions import RetryTaskError, NotRegistered
-from celery.decorators import task as task_dec
-from celery.datastructures import ExceptionInfo
+from celery.worker.job import execute_and_trace, AlreadyExecutedError
+from celery.worker.job import InvalidTaskError
+from celery.worker.revoke import revoked
 
-from celery.tests.utils import execute_context
 from celery.tests.compat import catch_warnings
+from celery.tests.utils import execute_context
 
 scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
@@ -61,7 +66,7 @@ def mytask_raising(i, **kwargs):
     raise KeyError(i)
 
 
-class TestRetryTaskError(unittest.TestCase):
+class test_RetryTaskError(unittest.TestCase):
 
     def test_retry_task_error(self):
         try:
@@ -72,12 +77,19 @@ class TestRetryTaskError(unittest.TestCase):
         self.assertEqual(ret.exc, exc)
 
 
-class TestJail(unittest.TestCase):
+class test_WorkerTaskTrace(unittest.TestCase):
 
     def test_execute_jail_success(self):
         ret = jail(gen_unique_id(), mytask.name, [2], {})
         self.assertEqual(ret, 4)
 
+    def test_marked_as_started(self):
+        mytask.track_started = True
+        try:
+            ret = jail(gen_unique_id(), mytask.name, [2], {})
+        finally:
+            mytask.track_started = False
+
     def test_execute_jail_failure(self):
         ret = jail(gen_unique_id(), mytask_raising.name,
                    [4], {})
@@ -101,7 +113,7 @@ class MockEventDispatcher(object):
         self.sent.append(event)
 
 
-class TestTaskRequest(unittest.TestCase):
+class test_TaskRequest(unittest.TestCase):
 
     def test_task_wrapper_repr(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
@@ -144,13 +156,112 @@ class TestTaskRequest(unittest.TestCase):
             job.mail_admins = old_mail_admins
             conf.CELERY_SEND_TASK_ERROR_EMAILS = old_enable_mails
 
+    def test_already_revoked(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw._already_revoked = True
+        self.assertTrue(tw.revoked())
+
+    def test_revoked(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        revoked.add(tw.task_id)
+        self.assertTrue(tw.revoked())
+        self.assertTrue(tw._already_revoked)
+        self.assertTrue(tw.acknowledged)
+
+    def test_execute_does_not_execute_revoked(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        revoked.add(tw.task_id)
+        tw.execute()
+
+    def test_execute_acks_late(self):
+        mytask_raising.acks_late = True
+        tw = TaskRequest(mytask_raising.name, gen_unique_id(), [1], {"f": "x"})
+        try:
+            tw.execute()
+            self.assertTrue(tw.acknowledged)
+        finally:
+            mytask_raising.acks_late = False
+
+    def test_execute_using_pool_does_not_execute_revoked(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        revoked.add(tw.task_id)
+        tw.execute_using_pool(None)
+
+    def test_on_accepted_acks_early(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.on_accepted()
+        self.assertTrue(tw.acknowledged)
+
+    def test_on_accepted_acks_late(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        mytask.acks_late = True
+        try:
+            tw.on_accepted()
+            self.assertFalse(tw.acknowledged)
+        finally:
+            mytask.acks_late = False
+
+    def test_on_success_acks_early(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = 1
+        tw.on_success(42)
+        self.assertFalse(tw.acknowledged)
+
+    def test_on_success_acks_late(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = 1
+        mytask.acks_late = True
+        try:
+            tw.on_success(42)
+            self.assertTrue(tw.acknowledged)
+        finally:
+            mytask.acks_late = False
+
+    def test_on_failure_acks_late(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.time_start = 1
+        mytask.acks_late = True
+        try:
+            try:
+                raise KeyError("foo")
+            except KeyError:
+                exc_info = ExceptionInfo(sys.exc_info())
+            tw.on_failure(exc_info)
+            self.assertTrue(tw.acknowledged)
+        finally:
+            mytask.acks_late = False
+
+    def test_from_message_invalid_kwargs(self):
+        message_data = dict(task="foo", id=1, args=(), kwargs="foo")
+        self.assertRaises(InvalidTaskError, TaskRequest.from_message, None,
+                message_data)
+
+    def test_on_timeout(self):
+
+        class MockLogger(object):
+
+            def __init__(self):
+                self.warnings = []
+                self.errors = []
+
+            def warning(self, msg, *args, **kwargs):
+                self.warnings.append(msg)
+
+            def error(self, msg, *args, **kwargs):
+                self.errors.append(msg)
+
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.logger = MockLogger()
+        tw.on_timeout(soft=True)
+        self.assertIn("Soft time limit exceeded", tw.logger.warnings[0])
+        tw.on_timeout(soft=False)
+        self.assertIn("Hard time limit exceeded", tw.logger.errors[0])
+
     def test_execute_and_trace(self):
-        from celery.worker.job import execute_and_trace
         res = execute_and_trace(mytask.name, gen_unique_id(), [4], {})
         self.assertEqual(res, 4 ** 4)
 
     def test_execute_safe_catches_exception(self):
-        from celery.worker.job import execute_and_trace, WorkerTaskTrace
         old_exec = WorkerTaskTrace.execute
 
         def _error_exec(self, *args, **kwargs):
@@ -192,7 +303,6 @@ class TestTaskRequest(unittest.TestCase):
         self.assertEqual(mytask.backend.get_status(uuid), states.RETRY)
 
     def test_worker_task_trace_handle_failure(self):
-        from celery.worker.job import WorkerTaskTrace
         uuid = gen_unique_id()
         w = WorkerTaskTrace(mytask.name, uuid, [4], {})
         type_, value_, tb_ = self.create_exception(ValueError("foo"))
@@ -204,7 +314,6 @@ class TestTaskRequest(unittest.TestCase):
         self.assertEqual(mytask.backend.get_status(uuid), states.FAILURE)
 
     def test_executed_bit(self):
-        from celery.worker.job import AlreadyExecutedError
         tw = TaskRequest(mytask.name, gen_unique_id(), [], {})
         self.assertFalse(tw.executed)
         tw._set_executed_bit()

+ 1 - 1
celery/worker/job.py

@@ -237,7 +237,7 @@ class TaskRequest(object):
             self.logger.warn("Skipping revoked task: %s[%s]" % (
                 self.task_name, self.task_id))
             self.send_event("task-revoked", uuid=self.task_id)
-            self.on_ack()
+            self.acknowledge()
             self._already_revoked = True
             return True
         return False