Browse Source

100% Coverage for celery.worker.job

Ask Solem 14 years ago
parent
commit
46f40a29b7
1 changed files with 122 additions and 6 deletions
  1. 122 6
      celery/tests/test_worker_job.py

+ 122 - 6
celery/tests/test_worker_job.py

@@ -2,8 +2,8 @@
 import logging
 import anyjson
 import sys
-from celery.tests.utils import unittest
-from celery.tests.utils import StringIO
+
+from datetime import datetime, timedelta
 
 from kombu.transport.base import Message
 
@@ -12,7 +12,7 @@ from celery.app import app_or_default
 from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.task import task as task_dec
-from celery.exceptions import RetryTaskError, NotRegistered
+from celery.exceptions import RetryTaskError, NotRegistered, WorkerLostError
 from celery.log import setup_logger
 from celery.result import AsyncResult
 from celery.task.base import Task
@@ -23,7 +23,9 @@ from celery.worker.job import InvalidTaskError
 from celery.worker.state import revoked
 
 from celery.tests.compat import catch_warnings
-from celery.tests.utils import execute_context
+from celery.tests.utils import unittest
+from celery.tests.utils import execute_context, StringIO
+
 
 scratch = {"ACK": False}
 some_kwargs_scratchpad = {}
@@ -84,10 +86,28 @@ class test_WorkerTaskTrace(unittest.TestCase):
 
     def test_marked_as_started(self):
         mytask.track_started = True
+
+        class Backend(mytask.backend.__class__):
+            _started = []
+
+            def mark_as_started(self, uuid, *args, **kwargs):
+                self._started.append(uuid)
+
+        prev, mytask.backend = mytask.backend, Backend()
+
         try:
-            jail(gen_unique_id(), mytask.name, [2], {})
+            uuid = gen_unique_id()
+            jail(uuid, mytask.name, [2], {})
+            self.assertIn(uuid, Backend._started)
+
+            mytask.ignore_result = True
+            uuid = gen_unique_id()
+            jail(uuid, mytask.name, [2], {})
+            self.assertNotIn(uuid, Backend._started)
         finally:
+            mytask.backend = prev
             mytask.track_started = False
+            mytask.ignore_result = False
 
     def test_execute_jail_failure(self):
         ret = jail(gen_unique_id(), mytask_raising.name,
@@ -108,7 +128,7 @@ class MockEventDispatcher(object):
     def __init__(self):
         self.sent = []
 
-    def send(self, event):
+    def send(self, event, **fields):
         self.sent.append(event)
 
 
@@ -118,12 +138,64 @@ class test_TaskRequest(unittest.TestCase):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         self.assertTrue(repr(tw))
 
+    def test_sets_store_errors(self):
+        mytask.ignore_result = True
+        try:
+            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+            self.assertFalse(tw._store_errors)
+            mytask.store_errors_even_if_ignored = True
+            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+            self.assertTrue(tw._store_errors)
+        finally:
+            mytask.ignore_result = False
+            mytask.store_errors_even_if_ignored = False
+
     def test_send_event(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.eventer = MockEventDispatcher()
         tw.send_event("task-frobulated")
         self.assertIn("task-frobulated", tw.eventer.sent)
 
+    def test_on_retry(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.eventer = MockEventDispatcher()
+        try:
+            raise RetryTaskError("foo", KeyError("moofoobar"))
+        except:
+            einfo = ExceptionInfo(sys.exc_info())
+        tw.on_failure(einfo)
+        self.assertIn("task-retried", tw.eventer.sent)
+
+    def test_revoked_expires_expired(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.expires = datetime.now() - timedelta(days=1)
+        tw.revoked()
+        self.assertIn(tw.task_id, revoked)
+        self.assertEqual(mytask.backend.get_status(tw.task_id),
+                         states.REVOKED)
+
+    def test_revoked_expires_not_expired(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        tw.expires = datetime.now() + timedelta(days=1)
+        tw.revoked()
+        self.assertNotIn(tw.task_id, revoked)
+        self.assertNotEqual(mytask.backend.get_status(tw.task_id),
+                         states.REVOKED)
+
+    def test_revoked_expires_ignore_result(self):
+        mytask.ignore_result = True
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        try:
+            tw.expires = datetime.now() - timedelta(days=1)
+            tw.revoked()
+            self.assertIn(tw.task_id, revoked)
+            self.assertNotEqual(mytask.backend.get_status(tw.task_id),
+                                states.REVOKED)
+
+        finally:
+            mytask.ignore_result = False
+
+
     def test_send_email(self):
         app = app_or_default()
         old_mail_admins = app.mail_admins
@@ -150,9 +222,22 @@ class test_TaskRequest(unittest.TestCase):
             tw.on_failure(einfo)
             self.assertFalse(mail_sent[0])
 
+            mail_sent[0] = False
+            mytask.send_error_emails = True
+            mytask.error_whitelist = [KeyError]
+            tw.on_failure(einfo)
+            self.assertTrue(mail_sent[0])
+
+            mail_sent[0] = False
+            mytask.send_error_emails = True
+            mytask.error_whitelist = [SyntaxError]
+            tw.on_failure(einfo)
+            self.assertFalse(mail_sent[0])
+
         finally:
             app.mail_admins = old_mail_admins
             mytask.send_error_emails = old_enable_mails
+            mytask.error_whitelist = ()
 
     def test_already_revoked(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
@@ -215,6 +300,25 @@ class test_TaskRequest(unittest.TestCase):
         finally:
             mytask.acks_late = False
 
+    def test_on_failure_WorkerLostError(self):
+        tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+        try:
+            raise WorkerLostError("do re mi")
+        except WorkerLostError:
+            exc_info = ExceptionInfo(sys.exc_info())
+        tw.on_failure(exc_info)
+        self.assertEqual(mytask.backend.get_status(tw.task_id),
+                         states.FAILURE)
+
+        mytask.ignore_result = True
+        try:
+            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+            tw.on_failure(exc_info)
+            self.assertEqual(mytask.backend.get_status(tw.task_id),
+                             states.PENDING)
+        finally:
+            mytask.ignore_result = False
+
     def test_on_failure_acks_late(self):
         tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
         tw.time_start = 1
@@ -254,6 +358,18 @@ class test_TaskRequest(unittest.TestCase):
         self.assertIn("Soft time limit exceeded", tw.logger.warnings[0])
         tw.on_timeout(soft=False)
         self.assertIn("Hard time limit exceeded", tw.logger.errors[0])
+        self.assertEqual(mytask.backend.get_status(tw.task_id),
+                         states.FAILURE)
+
+        mytask.ignore_result = True
+        try:
+            tw = TaskRequest(mytask.name, gen_unique_id(), [1], {"f": "x"})
+            tw.logger = MockLogger()
+        finally:
+            mytask.ignore_result = False
+            tw.on_timeout(soft=True)
+            self.assertEqual(mytask.backend.get_status(tw.task_id),
+                             states.PENDING)
 
     def test_execute_and_trace(self):
         res = execute_and_trace(mytask.name, gen_unique_id(), [4], {})