Browse Source

Py3k fixes

Ask Solem 13 years ago
parent
commit
d4bb75ef4b

+ 1 - 0
celery/backends/base.py

@@ -356,6 +356,7 @@ class KeyValueStoreBackend(BaseDictBackend):
 
     def _strip_prefix(self, key):
         """Takes bytes, emits string."""
+        key = ensure_bytes(key)
         for prefix in self.task_keyprefix, self.taskset_keyprefix:
             if key.startswith(prefix):
                 return bytes_to_str(key[len(prefix):])

+ 1 - 1
celery/concurrency/threads.py

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 
-from UserDict import UserDict
+from celery.utils.compat import UserDict
 
 from .base import apply_target, BasePool
 

+ 1 - 1
celery/exceptions.py

@@ -36,7 +36,7 @@ class QueueNotFound(KeyError):
     """Task routed to a queue not in CELERY_QUEUES."""
 
 
-class ImproperlyConfigured(Exception):
+class ImproperlyConfigured(ImportError):
     """Celery is somehow improperly configured."""
 
 

+ 19 - 5
celery/tests/app/test_log.py

@@ -6,6 +6,7 @@ import logging
 from tempfile import mktemp
 
 from mock import patch, Mock
+from nose import SkipTest
 
 from celery import current_app
 from celery import signals
@@ -31,6 +32,7 @@ class test_TaskFormatter(Case):
             msg = "hello world"
             levelname = "info"
             exc_text = exc_info = None
+            stack_info = None
 
             def getMessage(self):
                 return self.msg
@@ -59,7 +61,8 @@ class test_ColorFormatter(Case):
         x = ColorFormatter(value)
         fe.return_value = value
         self.assertTrue(x.formatException(value))
-        self.assertTrue(safe_str.called)
+        if sys.version_info[0] == 2:
+            self.assertTrue(safe_str.called)
 
     @patch("celery.utils.log.safe_str")
     def test_format_raises(self, safe_str):
@@ -72,10 +75,19 @@ class test_ColorFormatter(Case):
                 safe_str.side_effect = None
         safe_str.side_effect = on_safe_str
 
-        record = Mock()
-        record.levelname = "ERROR"
-        record.msg = "HELLO"
-        record.exc_text = "error text"
+        class Record(object):
+            levelname = "ERROR"
+            msg = "HELLO"
+            exc_text = "error text"
+            stack_info = None
+
+            def __str__(self):
+                return on_safe_str("")
+
+            def getMessage(self):
+                return self.msg
+
+        record = Record()
         safe_str.return_value = record
 
         x.format(record)
@@ -84,6 +96,8 @@ class test_ColorFormatter(Case):
 
     @patch("celery.utils.log.safe_str")
     def test_format_raises_no_color(self, safe_str):
+        if sys.version_info[0] == 3:
+            raise SkipTest("py3k")
         x = ColorFormatter("HELLO", False)
         record = Mock()
         record.levelname = "ERROR"

+ 31 - 0
celery/tests/backends/test_cassandra.py

@@ -16,6 +16,31 @@ class Object(object):
     pass
 
 
+def install_exceptions(mod):
+    # py3k: cannot catch exceptions not ineheriting from BaseException.
+
+    class NotFoundException(Exception):
+        pass
+
+    class TException(Exception):
+        pass
+
+    class InvalidRequestException(Exception):
+        pass
+
+    class UnavailableException(Exception):
+        pass
+
+    class TimedOutException(Exception):
+        pass
+
+    mod.NotFoundException = NotFoundException
+    mod.TException = TException
+    mod.InvalidRequestException = InvalidRequestException
+    mod.TimedOutException = TimedOutException
+    mod.UnavailableException = UnavailableException
+
+
 class test_CassandraBackend(AppCase):
 
     def test_init_no_pycassa(self):
@@ -39,6 +64,7 @@ class test_CassandraBackend(AppCase):
         with mock_module("pycassa"):
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
+            install_exceptions(mod.pycassa)
             cons = mod.pycassa.ConsistencyLevel = Object()
             cons.LOCAL_QUORUM = "foo"
 
@@ -65,7 +91,9 @@ class test_CassandraBackend(AppCase):
         with mock_module("pycassa"):
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
+            install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
+            install_exceptions(mod.Thrift)
             app = self.get_app()
             x = mod.CassandraBackend(app=app)
             Get_Column = x._get_column_family = Mock()
@@ -120,7 +148,9 @@ class test_CassandraBackend(AppCase):
         with mock_module("pycassa"):
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
+            install_exceptions(mod.pycassa)
             mod.Thrift = Mock()
+            install_exceptions(mod.Thrift)
             app = self.get_app()
             x = mod.CassandraBackend(app=app)
             Get_Column = x._get_column_family = Mock()
@@ -150,6 +180,7 @@ class test_CassandraBackend(AppCase):
         with mock_module("pycassa"):
             from celery.backends import cassandra as mod
             mod.pycassa = Mock()
+            install_exceptions(mod.pycassa)
             app = self.get_app()
             x = mod.CassandraBackend(app=app)
             self.assertTrue(x._get_column_family())

+ 4 - 4
celery/tests/contrib/test_migrate.py

@@ -10,7 +10,7 @@ from celery.contrib.migrate import (
     migrate_task,
     migrate_tasks,
 )
-from celery.utils.encoding import bytes_t
+from celery.utils.encoding import bytes_t, ensure_bytes
 from celery.tests.utils import AppCase, Case, Mock
 
 
@@ -71,9 +71,9 @@ class test_migrate_tasks(AppCase):
         migrate_tasks(x, y)
 
         yq = q(y.default_channel)
-        self.assertEqual(yq.get().body, "foo")
-        self.assertEqual(yq.get().body, "bar")
-        self.assertEqual(yq.get().body, "baz")
+        self.assertEqual(yq.get().body, ensure_bytes("foo"))
+        self.assertEqual(yq.get().body, ensure_bytes("bar"))
+        self.assertEqual(yq.get().body, ensure_bytes("baz"))
 
         Producer(x).publish("foo", exchange=name, routing_key=name)
         callback = Mock()

+ 1 - 1
celery/utils/functional.py

@@ -93,7 +93,7 @@ class LRUCache(UserDict):
 
 
 def is_list(l):
-    return hasattr(l, "__iter__") and not isinstance(l, dict)
+    return hasattr(l, "__iter__") and not isinstance(l, (dict, basestring))
 
 
 def maybe_list(l):

+ 1 - 0
requirements/default-py3k.txt

@@ -1,3 +1,4 @@
+billiard>=2.7.3.7
 python-dateutil>=2.0
 pytz
 kombu>=2.1.8