ソースを参照

Tests passing on Py 3.2

Ask Solem 13 年 前
コミット
b05a249a66

+ 14 - 13
celery/backends/base.py

@@ -13,7 +13,7 @@ from .. import states
 from ..datastructures import LRUCache
 from ..exceptions import TimeoutError, TaskRevokedError
 from ..utils import timeutils
-from ..utils.encoding import ensure_bytes, from_utf8
+from ..utils.encoding import bytes_to_str, ensure_bytes, from_utf8
 from ..utils.serialization import (get_pickled_exception,
                                    get_pickleable_exception,
                                    create_exception_cls)
@@ -299,9 +299,9 @@ class BaseDictBackend(BaseBackend):
 
 
 class KeyValueStoreBackend(BaseDictBackend):
-    task_keyprefix = "celery-task-meta-"
-    taskset_keyprefix = "celery-taskset-meta-"
-    chord_keyprefix = "chord-unlock-"
+    task_keyprefix = ensure_bytes("celery-task-meta-")
+    taskset_keyprefix = ensure_bytes("celery-taskset-meta-")
+    chord_keyprefix = ensure_bytes("chord-unlock-")
 
     def get(self, key):
         raise NotImplementedError("Must implement the get method.")
@@ -317,21 +317,22 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def get_key_for_task(self, task_id):
         """Get the cache key for a task by id."""
-        return ensure_bytes(self.task_keyprefix) + ensure_bytes(task_id)
+        return self.task_keyprefix + ensure_bytes(task_id)
 
     def get_key_for_taskset(self, taskset_id):
         """Get the cache key for a taskset by id."""
-        return ensure_bytes(self.taskset_keyprefix) + ensure_bytes(taskset_id)
+        return self.taskset_keyprefix + ensure_bytes(taskset_id)
 
     def get_key_for_chord(self, taskset_id):
         """Get the cache key for the chord waiting on taskset with given id."""
-        return ensure_bytes(self.chord_keyprefix) + ensure_bytes(taskset_id)
+        return self.chord_keyprefix + ensure_bytes(taskset_id)
 
     def _strip_prefix(self, key):
+        """Takes bytes, emits string."""
         for prefix in self.task_keyprefix, self.taskset_keyprefix:
             if key.startswith(prefix):
-                return key[len(prefix):]
-        return key
+                return bytes_to_str(key[len(prefix):])
+        return bytes_to_str(key)
 
     def _mget_to_results(self, values, keys):
         if hasattr(values, "items"):
@@ -341,7 +342,7 @@ class KeyValueStoreBackend(BaseDictBackend):
                                 if v is not None)
         else:
             # client returns list so need to recreate mapping.
-            return dict((keys[i], self.decode(value))
+            return dict((bytes_to_str(keys[i]), self.decode(value))
                             for i, value in enumerate(values)
                                 if value is not None)
 
@@ -355,7 +356,7 @@ class KeyValueStoreBackend(BaseDictBackend):
                 pass
             else:
                 if cached["status"] in states.READY_STATES:
-                    yield task_id, cached
+                    yield bytes_to_str(task_id), cached
                     cached_ids.add(task_id)
 
         ids ^= cached_ids
@@ -365,9 +366,9 @@ class KeyValueStoreBackend(BaseDictBackend):
             r = self._mget_to_results(self.mget([self.get_key_for_task(k)
                                                     for k in keys]), keys)
             self._cache.update(r)
-            ids ^= set(r)
+            ids ^= set(map(bytes_to_str, r))
             for key, value in r.iteritems():
-                yield key, value
+                yield bytes_to_str(key), value
             if timeout and iterations * interval >= timeout:
                 raise TimeoutError("Operation timed out (%s)" % (timeout, ))
             time.sleep(interval)  # don't busy loop.

+ 2 - 1
celery/tests/test_app/test_beat.py

@@ -192,7 +192,8 @@ class test_Scheduler(Case):
         self.assertTrue(scheduler.logger.logged[0])
         level, msg, args, kwargs = scheduler.logger.logged[0]
         self.assertEqual(level, logging.ERROR)
-        self.assertIn("Couldn't apply scheduled task", args[0].args[0])
+        self.assertIn("Couldn't apply scheduled task",
+                      repr(args[0].args[0]))
 
     def test_due_tick_RuntimeError(self):
         scheduler = mSchedulerRuntimeError()

+ 2 - 2
celery/tests/test_backends/test_database.py

@@ -128,7 +128,7 @@ class test_DatabaseBackend(Case):
         except KeyError, exception:
             import traceback
             trace = "\n".join(traceback.format_stack())
-        tb.mark_as_retry(tid, exception, traceback=trace)
+            tb.mark_as_retry(tid, exception, traceback=trace)
         self.assertEqual(tb.get_status(tid), states.RETRY)
         self.assertIsInstance(tb.get_result(tid), KeyError)
         self.assertEqual(tb.get_traceback(tid), trace)
@@ -142,7 +142,7 @@ class test_DatabaseBackend(Case):
         except KeyError, exception:
             import traceback
             trace = "\n".join(traceback.format_stack())
-        tb.mark_as_failure(tid3, exception, traceback=trace)
+            tb.mark_as_failure(tid3, exception, traceback=trace)
         self.assertEqual(tb.get_status(tid3), states.FAILURE)
         self.assertIsInstance(tb.get_result(tid3), KeyError)
         self.assertEqual(tb.get_traceback(tid3), trace)

+ 36 - 27
celery/tests/test_backends/test_mongodb.py

@@ -50,38 +50,47 @@ class TestBackendMongoDb(Case):
         binary.Binary = self._reset["Binary"]
         datetime.datetime = self._reset["datetime"]
 
-    @patch("pymongo.connection.Connection")
-    def test_get_connection_connection_exists(self, mock_Connection):
-        self.backend._connection = sentinel._connection
+    def test_get_connection_connection_exists(self):
 
-        connection = self.backend._get_connection()
+        @patch("pymongo.connection.Connection")
+        def do_test(mock_Connection):
+            self.backend._connection = sentinel._connection
 
-        self.assertEquals(sentinel._connection, connection)
-        self.assertFalse(mock_Connection.called)
+            connection = self.backend._get_connection()
 
-    @patch("pymongo.connection.Connection")
-    def test_get_connection_no_connection_host(self, mock_Connection):
-        self.backend._connection = None
-        self.backend.mongodb_host = MONGODB_HOST
-        self.backend.mongodb_port = MONGODB_PORT
-        mock_Connection.return_value = sentinel.connection
-
-        connection = self.backend._get_connection()
-        mock_Connection.assert_called_once_with(
-            MONGODB_HOST, MONGODB_PORT)
-        self.assertEquals(sentinel.connection, connection)
-
-    @patch("pymongo.connection.Connection")
-    def test_get_connection_no_connection_mongodb_uri(self, mock_Connection):
-        mongodb_uri = "mongodb://%s:%d" % (MONGODB_HOST, MONGODB_PORT)
-        self.backend._connection = None
-        self.backend.mongodb_host = mongodb_uri
+            self.assertEquals(sentinel._connection, connection)
+            self.assertFalse(mock_Connection.called)
+        do_test()
+
+    def test_get_connection_no_connection_host(self):
+
+        @patch("pymongo.connection.Connection")
+        def do_test(mock_Connection):
+            self.backend._connection = None
+            self.backend.mongodb_host = MONGODB_HOST
+            self.backend.mongodb_port = MONGODB_PORT
+            mock_Connection.return_value = sentinel.connection
+
+            connection = self.backend._get_connection()
+            mock_Connection.assert_called_once_with(
+                MONGODB_HOST, MONGODB_PORT)
+            self.assertEquals(sentinel.connection, connection)
+        do_test()
+
+    def test_get_connection_no_connection_mongodb_uri(self):
+
+        @patch("pymongo.connection.Connection")
+        def do_test(mock_Connection):
+            mongodb_uri = "mongodb://%s:%d" % (MONGODB_HOST, MONGODB_PORT)
+            self.backend._connection = None
+            self.backend.mongodb_host = mongodb_uri
 
-        mock_Connection.return_value = sentinel.connection
+            mock_Connection.return_value = sentinel.connection
 
-        connection = self.backend._get_connection()
-        mock_Connection.assert_called_once_with(mongodb_uri)
-        self.assertEquals(sentinel.connection, connection)
+            connection = self.backend._get_connection()
+            mock_Connection.assert_called_once_with(mongodb_uri)
+            self.assertEquals(sentinel.connection, connection)
+        do_test()
 
     @patch("celery.backends.mongodb.MongoBackend._get_connection")
     def test_get_database_no_existing(self, mock_get_connection):

+ 52 - 0
celery/tests/utils.py

@@ -3,8 +3,10 @@ from __future__ import absolute_import
 try:
     import unittest
     unittest.skip
+    from unittest.util import safe_repr, unorderable_list_difference
 except AttributeError:
     import unittest2 as unittest
+    from unittest2.util import safe_repr, unorderable_list_difference  # noqa
 
 import importlib
 import logging
@@ -131,6 +133,56 @@ class Case(unittest.TestCase):
         return _AssertWarnsContext(expected_warning, self,
                                    None, expected_regex)
 
+    def assertDictContainsSubset(self, expected, actual, msg=None):
+        missing, mismatched = [], []
+
+        for key, value in expected.iteritems():
+            if key not in actual:
+                missing.append(key)
+            elif value != actual[key]:
+                mismatched.append("%s, expected: %s, actual: %s" % (
+                    safe_repr(key), safe_repr(value),
+                    safe_repr(actual[key])))
+
+        if not (missing or mismatched):
+            return
+
+        standard_msg = ""
+        if missing:
+            standard_msg = "Missing: %s" % ','.join(map(safe_repr, missing))
+
+        if mismatched:
+            if standard_msg:
+                standard_msg += "; "
+            standard_msg += "Mismatched values: %s" % (
+                ','.join(mismatched))
+
+        self.fail(self._formatMessage(msg, standard_msg))
+
+    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
+        try:
+            expected = sorted(expected_seq)
+            actual = sorted(actual_seq)
+        except TypeError:
+            # Unsortable items (example: set(), complex(), ...)
+            expected = list(expected_seq)
+            actual = list(actual_seq)
+            missing, unexpected = unorderable_list_difference(
+                expected, actual)
+        else:
+            return self.assertSequenceEqual(expected, actual, msg=msg)
+
+        errors = []
+        if missing:
+            errors.append('Expected, but missing:\n    %s' % (
+                           safe_repr(missing)))
+        if unexpected:
+            errors.append('Unexpected, but present:\n    %s' % (
+                           safe_repr(unexpected)))
+        if errors:
+            standardMsg = '\n'.join(errors)
+            self.fail(self._formatMessage(msg, standardMsg))
+
 
 class AppCase(Case):
 

+ 5 - 0
celery/worker/control.py

@@ -24,6 +24,11 @@ from ..utils.encoding import safe_repr
 from . import state
 from .state import revoked
 
+try:
+    reload
+except NameError:
+    from imp import reload
+
 TASK_INFO_FIELDS = ("exchange", "routing_key", "rate_limit")