Browse Source

Use celery.states constants in unittests.

Ask Solem 15 years ago
parent
commit
b8dc2c3b5b

+ 5 - 4
celery/tests/test_backends/test_amqp.py

@@ -1,13 +1,14 @@
 from __future__ import with_statement
 
 import sys
-import unittest
 import errno
+import unittest
 
 from django.core.exceptions import ImproperlyConfigured
 
-from celery.backends.amqp import AMQPBackend
+from celery import states
 from celery.utils import gen_unique_id
+from celery.backends.amqp import AMQPBackend
 from celery.datastructures import ExceptionInfo
 
 
@@ -29,7 +30,7 @@ class TestRedisBackend(unittest.TestCase):
 
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
-        self.assertEquals(tb.get_status(tid), "SUCCESS")
+        self.assertEquals(tb.get_status(tid), states.SUCCESS)
         self.assertEquals(tb.get_result(tid), 42)
         self.assertTrue(tb._cache.get(tid))
         self.assertTrue(tb.get_result(tid), 42)
@@ -55,7 +56,7 @@ class TestRedisBackend(unittest.TestCase):
             einfo = ExceptionInfo(sys.exc_info())
         tb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(tb.is_successful(tid3))
-        self.assertEquals(tb.get_status(tid3), "FAILURE")
+        self.assertEquals(tb.get_status(tid3), states.FAILURE)
         self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
         self.assertEquals(tb.get_traceback(tid3), einfo.traceback)
 

+ 2 - 1
celery/tests/test_backends/test_base.py

@@ -6,6 +6,7 @@ from billiard.serialization import find_nearest_pickleable_exception as fnpe
 from billiard.serialization import UnpickleableExceptionWrapper
 from billiard.serialization import get_pickleable_exception as gpe
 
+from celery import states
 from celery.backends.base import BaseBackend, KeyValueStoreBackend
 
 
@@ -30,7 +31,7 @@ class TestBaseBackendInterface(unittest.TestCase):
 
     def test_store_result(self):
         self.assertRaises(NotImplementedError,
-                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, "SUCCESS")
+                b.store_result, "SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
 
     def test_get_result(self):
         self.assertRaises(NotImplementedError,

+ 4 - 3
celery/tests/test_backends/test_cache.py

@@ -3,6 +3,7 @@ import unittest
 
 from billiard.serialization import pickle
 
+from celery import states
 from celery.utils import gen_unique_id
 from celery.backends.cache import CacheBackend
 from celery.datastructures import ExceptionInfo
@@ -22,12 +23,12 @@ class TestCacheBackend(unittest.TestCase):
         tid = gen_unique_id()
 
         self.assertFalse(cb.is_successful(tid))
-        self.assertEquals(cb.get_status(tid), "PENDING")
+        self.assertEquals(cb.get_status(tid), states.PENDING)
         self.assertEquals(cb.get_result(tid), None)
 
         cb.mark_as_done(tid, 42)
         self.assertTrue(cb.is_successful(tid))
-        self.assertEquals(cb.get_status(tid), "SUCCESS")
+        self.assertEquals(cb.get_status(tid), states.SUCCESS)
         self.assertEquals(cb.get_result(tid), 42)
         self.assertTrue(cb._cache.get(tid))
         self.assertTrue(cb.get_result(tid), 42)
@@ -55,7 +56,7 @@ class TestCacheBackend(unittest.TestCase):
             pass
         cb.mark_as_failure(tid3, exception, traceback=einfo.traceback)
         self.assertFalse(cb.is_successful(tid3))
-        self.assertEquals(cb.get_status(tid3), "FAILURE")
+        self.assertEquals(cb.get_status(tid3), states.FAILURE)
         self.assertTrue(isinstance(cb.get_result(tid3), KeyError))
         self.assertEquals(cb.get_traceback(tid3), einfo.traceback)
 

+ 4 - 3
celery/tests/test_backends/test_database.py

@@ -1,6 +1,7 @@
 import unittest
 from datetime import timedelta
 
+from celery import states
 from celery.task import PeriodicTask
 from celery.utils import gen_unique_id
 from celery.backends.database import DatabaseBackend
@@ -27,12 +28,12 @@ class TestDatabaseBackend(unittest.TestCase):
         tid = gen_unique_id()
 
         self.assertFalse(b.is_successful(tid))
-        self.assertEquals(b.get_status(tid), "PENDING")
+        self.assertEquals(b.get_status(tid), states.PENDING)
         self.assertTrue(b.get_result(tid) is None)
 
         b.mark_as_done(tid, 42)
         self.assertTrue(b.is_successful(tid))
-        self.assertEquals(b.get_status(tid), "SUCCESS")
+        self.assertEquals(b.get_status(tid), states.SUCCESS)
         self.assertEquals(b.get_result(tid), 42)
         self.assertTrue(b._cache.get(tid))
         self.assertTrue(b.get_result(tid), 42)
@@ -52,7 +53,7 @@ class TestDatabaseBackend(unittest.TestCase):
             pass
         b.mark_as_failure(tid3, exception)
         self.assertFalse(b.is_successful(tid3))
-        self.assertEquals(b.get_status(tid3), "FAILURE")
+        self.assertEquals(b.get_status(tid3), states.FAILURE)
         self.assertTrue(isinstance(b.get_result(tid3), KeyError))
 
     def test_taskset_store(self):

+ 5 - 4
celery/tests/test_backends/test_redis.py

@@ -6,9 +6,10 @@ import errno
 
 from django.core.exceptions import ImproperlyConfigured
 
+from celery import states
+from celery.utils import gen_unique_id
 from celery.backends import pyredis
 from celery.backends.pyredis import RedisBackend
-from celery.utils import gen_unique_id
 
 _no_redis_msg = "* Redis %s. Will not execute related tests."
 _no_redis_msg_emitted = False
@@ -64,12 +65,12 @@ class TestRedisBackend(unittest.TestCase):
         tid = gen_unique_id()
 
         self.assertFalse(tb.is_successful(tid))
-        self.assertEquals(tb.get_status(tid), "PENDING")
+        self.assertEquals(tb.get_status(tid), states.PENDING)
         self.assertEquals(tb.get_result(tid), None)
 
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
-        self.assertEquals(tb.get_status(tid), "SUCCESS")
+        self.assertEquals(tb.get_status(tid), states.SUCCESS)
         self.assertEquals(tb.get_result(tid), 42)
         self.assertTrue(tb._cache.get(tid))
         self.assertTrue(tb.get_result(tid), 42)
@@ -99,7 +100,7 @@ class TestRedisBackend(unittest.TestCase):
             pass
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
-        self.assertEquals(tb.get_status(tid3), "FAILURE")
+        self.assertEquals(tb.get_status(tid3), states.FAILURE)
         self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
 
     def test_process_cleanup(self):

+ 9 - 6
celery/tests/test_backends/test_tyrant.py

@@ -1,11 +1,14 @@
 import sys
-import unittest
 import errno
 import socket
+import unittest
+
+from django.core.exceptions import ImproperlyConfigured
+
+from celery import states
+from celery.utils import gen_unique_id
 from celery.backends import tyrant
 from celery.backends.tyrant import TyrantBackend
-from celery.utils import gen_unique_id
-from django.core.exceptions import ImproperlyConfigured
 
 _no_tyrant_msg = "* Tokyo Tyrant %s. Will not execute related tests."
 _no_tyrant_msg_emitted = False
@@ -64,12 +67,12 @@ class TestTyrantBackend(unittest.TestCase):
         tid = gen_unique_id()
 
         self.assertFalse(tb.is_successful(tid))
-        self.assertEquals(tb.get_status(tid), "PENDING")
+        self.assertEquals(tb.get_status(tid), states.PENDING)
         self.assertEquals(tb.get_result(tid), None)
 
         tb.mark_as_done(tid, 42)
         self.assertTrue(tb.is_successful(tid))
-        self.assertEquals(tb.get_status(tid), "SUCCESS")
+        self.assertEquals(tb.get_status(tid), states.SUCCESS)
         self.assertEquals(tb.get_result(tid), 42)
         self.assertTrue(tb._cache.get(tid))
         self.assertTrue(tb.get_result(tid), 42)
@@ -99,7 +102,7 @@ class TestTyrantBackend(unittest.TestCase):
             pass
         tb.mark_as_failure(tid3, exception)
         self.assertFalse(tb.is_successful(tid3))
-        self.assertEquals(tb.get_status(tid3), "FAILURE")
+        self.assertEquals(tb.get_status(tid3), states.FAILURE)
         self.assertTrue(isinstance(tb.get_result(tid3), KeyError))
 
     def test_process_cleanup(self):

+ 6 - 5
celery/tests/test_models.py

@@ -1,6 +1,7 @@
 import unittest
 from datetime import datetime, timedelta
 
+from celery import states
 from celery.utils import gen_unique_id
 from celery.models import TaskMeta, TaskSetMeta
 
@@ -28,13 +29,13 @@ class TestModels(unittest.TestCase):
         self.assertEquals(TaskMeta.objects.get_task(m1.task_id).task_id,
                 m1.task_id)
         self.assertFalse(
-                TaskMeta.objects.get_task(m1.task_id).status == "SUCCESS")
-        TaskMeta.objects.store_result(m1.task_id, True, status="SUCCESS")
-        TaskMeta.objects.store_result(m2.task_id, True, status="SUCCESS")
+                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
+        TaskMeta.objects.store_result(m1.task_id, True, status=states.SUCCESS)
+        TaskMeta.objects.store_result(m2.task_id, True, status=states.SUCCESS)
         self.assertTrue(
-                TaskMeta.objects.get_task(m1.task_id).status == "SUCCESS")
+                TaskMeta.objects.get_task(m1.task_id).status == states.SUCCESS)
         self.assertTrue(
-                TaskMeta.objects.get_task(m2.task_id).status == "SUCCESS")
+                TaskMeta.objects.get_task(m2.task_id).status == states.SUCCESS)
 
         # Have to avoid save() because it applies the auto_now=True.
         TaskMeta.objects.filter(task_id=m1.task_id).update(

+ 11 - 10
celery/tests/test_result.py

@@ -1,5 +1,6 @@
 import unittest
 
+from celery import states
 from celery.utils import gen_unique_id
 from celery.tests.utils import skip_if_quick
 from celery.result import AsyncResult, TaskSetResult
@@ -14,9 +15,9 @@ def mock_task(name, status, result):
 
 def save_result(task):
     traceback = "Some traceback"
-    if task["status"] == "SUCCESS":
+    if task["status"] == states.SUCCESS:
         default_backend.mark_as_done(task["id"], task["result"])
-    elif task["status"] == "RETRY":
+    elif task["status"] == states.RETRY:
         default_backend.mark_as_retry(task["id"], task["result"],
                 traceback=traceback)
     else:
@@ -25,7 +26,7 @@ def save_result(task):
 
 
 def make_mock_taskset(size=10):
-    tasks = [mock_task("ts%d" % i, "SUCCESS", i) for i in xrange(size)]
+    tasks = [mock_task("ts%d" % i, states.SUCCESS, i) for i in xrange(size)]
     [save_result(task) for task in tasks]
     return [AsyncResult(task["id"]) for task in tasks]
 
@@ -33,10 +34,10 @@ def make_mock_taskset(size=10):
 class TestAsyncResult(unittest.TestCase):
 
     def setUp(self):
-        self.task1 = mock_task("task1", "SUCCESS", "the")
-        self.task2 = mock_task("task2", "SUCCESS", "quick")
-        self.task3 = mock_task("task3", "FAILURE", KeyError("brown"))
-        self.task4 = mock_task("task3", "RETRY", KeyError("red"))
+        self.task1 = mock_task("task1", states.SUCCESS, "the")
+        self.task2 = mock_task("task2", states.SUCCESS, "quick")
+        self.task3 = mock_task("task3", states.FAILURE, KeyError("brown"))
+        self.task4 = mock_task("task3", states.RETRY, KeyError("red"))
 
         for task in (self.task1, self.task2, self.task3, self.task4):
             save_result(task)
@@ -113,7 +114,7 @@ class MockAsyncResultFailure(AsyncResult):
 
     @property
     def status(self):
-        return "FAILURE"
+        return states.FAILURE
 
 
 class MockAsyncResultSuccess(AsyncResult):
@@ -124,7 +125,7 @@ class MockAsyncResultSuccess(AsyncResult):
 
     @property
     def status(self):
-        return "SUCCESS"
+        return states.SUCCESS
 
 
 class TestTaskSetResult(unittest.TestCase):
@@ -205,7 +206,7 @@ class TestFailedTaskSetResult(TestTaskSetResult):
     def setUp(self):
         self.size = 11
         subtasks = make_mock_taskset(10)
-        failed = mock_task("ts11", "FAILED", KeyError("Baz"))
+        failed = mock_task("ts11", states.FAILURE, KeyError("Baz"))
         save_result(failed)
         failed_res = AsyncResult(failed["id"])
         self.ts = TaskSetResult(gen_unique_id(), subtasks + [failed_res])

+ 8 - 7
celery/tests/test_views.py

@@ -11,6 +11,7 @@ from anyjson import deserialize as JSON_load
 from billiard.utils.functional import curry
 
 from celery import conf
+from celery import states
 from celery.utils import gen_unique_id, get_full_cls_name
 from celery.backends import default_backend
 from celery.exceptions import RetryTaskError
@@ -93,16 +94,16 @@ class TestTaskStatus(ViewTestCase):
         self.assertJSONEquals(json, dict(task=expect))
 
     def test_task_status_success(self):
-        self.assertStatusForIs("SUCCESS", "The quick brown fox")
+        self.assertStatusForIs(states.SUCCESS, "The quick brown fox")
 
     def test_task_status_failure(self):
         exc, tb = catch_exception(KeyError("foo"))
-        self.assertStatusForIs("FAILURE", exc, tb)
+        self.assertStatusForIs(states.FAILURE, exc, tb)
 
     def test_task_status_retry(self):
         oexc, _ = catch_exception(KeyError("Resource not available"))
         exc, tb = catch_exception(RetryTaskError(str(oexc), oexc))
-        self.assertStatusForIs("RETRY", exc, tb)
+        self.assertStatusForIs(states.RETRY, exc, tb)
 
 
 class TestTaskIsSuccessful(ViewTestCase):
@@ -116,13 +117,13 @@ class TestTaskIsSuccessful(ViewTestCase):
                                               "executed": outcome}})
 
     def test_is_successful_success(self):
-        self.assertStatusForIs("SUCCESS", True)
+        self.assertStatusForIs(states.SUCCESS, True)
 
     def test_is_successful_pending(self):
-        self.assertStatusForIs("PENDING", False)
+        self.assertStatusForIs(states.PENDING, False)
 
     def test_is_successful_failure(self):
-        self.assertStatusForIs("FAILURE", False)
+        self.assertStatusForIs(states.FAILURE, False)
 
     def test_is_successful_retry(self):
-        self.assertStatusForIs("RETRY", False)
+        self.assertStatusForIs(states.RETRY, False)

+ 10 - 9
celery/tests/test_worker_job.py

@@ -10,6 +10,7 @@ from StringIO import StringIO
 from django.core import cache
 from carrot.backends.base import BaseMessage
 
+from celery import states
 from celery.log import setup_logger
 from celery.task.base import Task
 from celery.utils import gen_unique_id
@@ -250,10 +251,10 @@ class TestTaskWrapper(unittest.TestCase):
                                                                   exc=value_))
         w._store_errors = False
         w.handle_retry(value_, type_, tb_, "")
-        self.assertEquals(mytask.backend.get_status(uuid), "PENDING")
+        self.assertEquals(mytask.backend.get_status(uuid), states.PENDING)
         w._store_errors = True
         w.handle_retry(value_, type_, tb_, "")
-        self.assertEquals(mytask.backend.get_status(uuid), "RETRY")
+        self.assertEquals(mytask.backend.get_status(uuid), states.RETRY)
 
     def test_worker_task_trace_handle_failure(self):
         from celery.worker.job import WorkerTaskTrace
@@ -262,10 +263,10 @@ class TestTaskWrapper(unittest.TestCase):
         type_, value_, tb_ = self.create_exception(ValueError("foo"))
         w._store_errors = False
         w.handle_failure(value_, type_, tb_, "")
-        self.assertEquals(mytask.backend.get_status(uuid), "PENDING")
+        self.assertEquals(mytask.backend.get_status(uuid), states.PENDING)
         w._store_errors = True
         w.handle_failure(value_, type_, tb_, "")
-        self.assertEquals(mytask.backend.get_status(uuid), "FAILURE")
+        self.assertEquals(mytask.backend.get_status(uuid), states.FAILURE)
 
     def test_executed_bit(self):
         from celery.worker.job import AlreadyExecutedError
@@ -323,7 +324,7 @@ class TestTaskWrapper(unittest.TestCase):
         self.assertEquals(tw.execute(), 256)
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEquals(meta.result, 256)
-        self.assertEquals(meta.status, "SUCCESS")
+        self.assertEquals(meta.status, states.SUCCESS)
 
     def test_execute_success_no_kwargs(self):
         tid = gen_unique_id()
@@ -331,7 +332,7 @@ class TestTaskWrapper(unittest.TestCase):
         self.assertEquals(tw.execute(), 256)
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEquals(meta.result, 256)
-        self.assertEquals(meta.status, "SUCCESS")
+        self.assertEquals(meta.status, states.SUCCESS)
 
     def test_execute_success_some_kwargs(self):
         tid = gen_unique_id()
@@ -340,7 +341,7 @@ class TestTaskWrapper(unittest.TestCase):
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertEquals(some_kwargs_scratchpad.get("logfile"), "foobaz.log")
         self.assertEquals(meta.result, 256)
-        self.assertEquals(meta.status, "SUCCESS")
+        self.assertEquals(meta.status, states.SUCCESS)
 
     def test_execute_ack(self):
         tid = gen_unique_id()
@@ -350,14 +351,14 @@ class TestTaskWrapper(unittest.TestCase):
         meta = TaskMeta.objects.get(task_id=tid)
         self.assertTrue(scratch["ACK"])
         self.assertEquals(meta.result, 256)
-        self.assertEquals(meta.status, "SUCCESS")
+        self.assertEquals(meta.status, states.SUCCESS)
 
     def test_execute_fail(self):
         tid = gen_unique_id()
         tw = TaskWrapper(mytask_raising.name, tid, [4], {"f": "x"})
         self.assertTrue(isinstance(tw.execute(), ExceptionInfo))
         meta = TaskMeta.objects.get(task_id=tid)
-        self.assertEquals(meta.status, "FAILURE")
+        self.assertEquals(meta.status, states.FAILURE)
         self.assertTrue(isinstance(meta.result, KeyError))
 
     def test_execute_using_pool(self):