Ask Solem пре 14 година
родитељ
комит
1187b1f328

+ 6 - 10
celery/tests/test_task.py

@@ -147,10 +147,6 @@ class TestTaskRetries(unittest.TestCase):
         self.assertEqual(result.get(), 42)
         self.assertEqual(RetryTaskNoArgs.iterations, 4)
 
-    def test_retry_kwargs_can_not_be_empty(self):
-        self.assertRaises(TypeError, RetryTaskMockApply.retry,
-                            args=[4, 4], kwargs={})
-
     def test_retry_not_eager(self):
         exc = Exception("baz")
         try:
@@ -322,20 +318,20 @@ class TestCeleryTasks(unittest.TestCase):
                 name="Elaine M. Benes")
 
         # With eta.
-        presult2 = task.apply_async(t1, kwargs=dict(name="George Costanza"),
-                                    eta=datetime.now() + timedelta(days=1))
+        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
+                                  eta=datetime.now() + timedelta(days=1))
         self.assertNextTaskDataEqual(consumer, presult2, t1.name,
                 name="George Costanza", test_eta=True)
 
         # With countdown.
-        presult2 = task.apply_async(t1, kwargs=dict(name="George Costanza"),
-                                    countdown=10)
+        presult2 = t1.apply_async(kwargs=dict(name="George Costanza"),
+                                  countdown=10)
         self.assertNextTaskDataEqual(consumer, presult2, t1.name,
                 name="George Costanza", test_eta=True)
 
         # Discarding all tasks.
         consumer.discard_all()
-        task.apply_async(t1)
+        t1.apply_async()
         self.assertEqual(consumer.discard_all(), 1)
         self.assertIsNone(consumer.fetch())
 
@@ -418,7 +414,7 @@ class TestTaskSet(unittest.TestCase):
         taskset_res = ts.apply_async()
         subtasks = taskset_res.subtasks
         taskset_id = taskset_res.taskset_id
-        consumer = IncrementCountertask().get_consumer()
+        consumer = IncrementCounterTask().get_consumer()
         for subtask in subtasks:
             m = consumer.fetch().payload
             self.assertDictContainsSubset({"taskset": taskset_id,

+ 26 - 36
celery/tests/test_task_control.py

@@ -1,50 +1,40 @@
 import unittest2 as unittest
 
+from celery.pidbox import Mailbox
 from celery.task import control
 from celery.task.builtins import PingTask
 from celery.utils import gen_unique_id
 from celery.utils.functional import wraps
 
 
-class MockBroadcastPublisher(object):
+class MockMailbox(Mailbox):
     sent = []
 
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def send(self, command, *args, **kwargs):
+    def publish(self, command, *args, **kwargs):
         self.__class__.sent.append(command)
 
     def close(self):
         pass
 
-
-class MockControlReplyConsumer(object):
-
-    def __init__(self, *args, **kwarg):
+    def collect_reply(self, *args, **kwargs):
         pass
 
-    def collect(self, *args, **kwargs):
-        pass
 
-    def close(self):
-        pass
+def mock_mailbox(connection):
+    return MockMailbox("celeryd", connection)
 
 
 def with_mock_broadcast(fun):
 
     @wraps(fun)
     def _mocked(*args, **kwargs):
-        old_pub = control.BroadcastPublisher
-        old_rep = control.ControlReplyConsumer
-        control.BroadcastPublisher = MockBroadcastPublisher
-        control.ControlReplyConsumer = MockControlReplyConsumer
+        old_box = control.mailbox
+        control.mailbox = mock_mailbox
         try:
             return fun(*args, **kwargs)
         finally:
-            MockBroadcastPublisher.sent = []
-            control.BroadcastPublisher = old_pub
-            control.ControlReplyConsumer = old_rep
+            MockMailbox.sent = []
+            control.mailbox = old_box
     return _mocked
 
 
@@ -65,47 +55,47 @@ class test_inspect(unittest.TestCase):
     @with_mock_broadcast
     def test_active(self):
         self.i.active()
-        self.assertIn("dump_active", MockBroadcastPublisher.sent)
+        self.assertIn("dump_active", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_scheduled(self):
         self.i.scheduled()
-        self.assertIn("dump_schedule", MockBroadcastPublisher.sent)
+        self.assertIn("dump_schedule", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_reserved(self):
         self.i.reserved()
-        self.assertIn("dump_reserved", MockBroadcastPublisher.sent)
+        self.assertIn("dump_reserved", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_stats(self):
         self.i.stats()
-        self.assertIn("stats", MockBroadcastPublisher.sent)
+        self.assertIn("stats", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoked(self):
         self.i.revoked()
-        self.assertIn("dump_revoked", MockBroadcastPublisher.sent)
+        self.assertIn("dump_revoked", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_registered_tasks(self):
         self.i.registered_tasks()
-        self.assertIn("dump_tasks", MockBroadcastPublisher.sent)
+        self.assertIn("dump_tasks", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_enable_events(self):
         self.i.enable_events()
-        self.assertIn("enable_events", MockBroadcastPublisher.sent)
+        self.assertIn("enable_events", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_disable_events(self):
         self.i.disable_events()
-        self.assertIn("disable_events", MockBroadcastPublisher.sent)
+        self.assertIn("disable_events", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_ping(self):
         self.i.ping()
-        self.assertIn("ping", MockBroadcastPublisher.sent)
+        self.assertIn("ping", MockMailbox.sent)
 
 
 class test_Broadcast(unittest.TestCase):
@@ -116,13 +106,13 @@ class test_Broadcast(unittest.TestCase):
     @with_mock_broadcast
     def test_broadcast(self):
         control.broadcast("foobarbaz", arguments=[])
-        self.assertIn("foobarbaz", MockBroadcastPublisher.sent)
+        self.assertIn("foobarbaz", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_broadcast_limit(self):
         control.broadcast("foobarbaz1", arguments=[], limit=None,
                 destination=[1, 2, 3])
-        self.assertIn("foobarbaz1", MockBroadcastPublisher.sent)
+        self.assertIn("foobarbaz1", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_broadcast_validate(self):
@@ -132,23 +122,23 @@ class test_Broadcast(unittest.TestCase):
     @with_mock_broadcast
     def test_rate_limit(self):
         control.rate_limit(PingTask.name, "100/m")
-        self.assertIn("rate_limit", MockBroadcastPublisher.sent)
+        self.assertIn("rate_limit", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke(self):
         control.revoke("foozbaaz")
-        self.assertIn("revoke", MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_ping(self):
         control.ping()
-        self.assertIn("ping", MockBroadcastPublisher.sent)
+        self.assertIn("ping", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke_from_result(self):
         from celery.result import AsyncResult
         AsyncResult("foozbazzbar").revoke()
-        self.assertIn("revoke", MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockMailbox.sent)
 
     @with_mock_broadcast
     def test_revoke_from_resultset(self):
@@ -156,4 +146,4 @@ class test_Broadcast(unittest.TestCase):
         r = TaskSetResult(gen_unique_id(), map(AsyncResult, [gen_unique_id()
                                                         for i in range(10)]))
         r.revoke()
-        self.assertIn("revoke", MockBroadcastPublisher.sent)
+        self.assertIn("revoke", MockMailbox.sent)

+ 0 - 50
celery/tests/test_utils.py

@@ -56,14 +56,6 @@ class test_gen_unique_id(unittest.TestCase):
 
 class test_utils(unittest.TestCase):
 
-    def test_repeatlast(self):
-        items = range(6)
-        it = utils.repeatlast(items)
-        for i in items:
-            self.assertEqual(it.next(), i)
-        for j in items:
-            self.assertEqual(it.next(), i)
-
     def test_get_full_cls_name(self):
         Class = type("Fox", (object, ), {"__module__": "quick.brown"})
         self.assertEqual(utils.get_full_cls_name(Class), "quick.brown.Fox")
@@ -122,48 +114,6 @@ class test_utils(unittest.TestCase):
         self.assertIs(utils.get_cls_by_name(instance), instance)
 
 
-class test_retry_over_time(unittest.TestCase):
-
-    def test_returns_retval_on_success(self):
-
-        def _fun(x, y):
-            return x * y
-
-        ret = utils.retry_over_time(_fun, (socket.error, ), args=[16, 16],
-                                    max_retries=3)
-
-        self.assertEqual(ret, 256)
-
-    @sleepdeprived
-    def test_raises_on_unlisted_exception(self):
-
-        def _fun(x, y):
-            raise KeyError("bar")
-
-        self.assertRaises(KeyError, utils.retry_over_time, _fun,
-                         (socket.error, ), args=[32, 32], max_retries=3)
-
-    @sleepdeprived
-    def test_retries_on_failure(self):
-
-        iterations = [0]
-
-        def _fun(x, y):
-            iterations[0] += 1
-            if iterations[0] == 3:
-                return x * y
-            raise socket.error("foozbaz")
-
-        ret = utils.retry_over_time(_fun, (socket.error, ), args=[32, 32],
-                                    max_retries=None)
-
-        self.assertEqual(iterations[0], 3)
-        self.assertEqual(ret, 1024)
-
-        self.assertRaises(socket.error, utils.retry_over_time,
-                        _fun, (socket.error, ), args=[32, 32], max_retries=1)
-
-
 class test_promise(unittest.TestCase):
 
     def test__str__(self):

+ 29 - 51
celery/tests/test_worker.py

@@ -4,7 +4,7 @@ import unittest2 as unittest
 from datetime import datetime, timedelta
 from Queue import Empty
 
-from kombu.backends.base import BaseMessage
+from kombu.transport.base import Message
 from kombu.connection import BrokerConnection
 from celery.utils.timer2 import Timer
 
@@ -23,11 +23,31 @@ from celery.tests.compat import catch_warnings
 from celery.tests.utils import execute_context
 
 
+class MockConsumer(object):
+
+    class Channel(object):
+
+        def close(self):
+            pass
+
+    def register_callback(self, cb):
+        pass
+
+    def consume(self):
+        pass
+
+    @property
+    def channel(self):
+        return self.Channel()
+
+
 class PlaceHolder(object):
         pass
 
 
 class MyCarrotListener(CarrotListener):
+    broadcast_consumer = MockConsumer()
+    task_consumer = MockConsumer()
 
     def restart_heartbeat(self):
         self.heart = None
@@ -139,9 +159,9 @@ class MockController(object):
 
 def create_message(backend, **data):
     data.setdefault("id", gen_unique_id())
-    return BaseMessage(backend, body=pickle.dumps(dict(**data)),
-                       content_type="application/x-python-serialize",
-                       content_encoding="binary")
+    return Message(backend, body=pickle.dumps(dict(**data)),
+                   content_type="application/x-python-serialize",
+                   content_encoding="binary")
 
 
 class test_QoS(unittest.TestCase):
@@ -177,47 +197,6 @@ class test_CarrotListener(unittest.TestCase):
     def tearDown(self):
         self.eta_schedule.stop()
 
-    def test_mainloop(self):
-        l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
-                           send_events=False)
-
-        class MockConnection(object):
-
-            def drain_events(self):
-                return "draining"
-
-        l.connection = MockConnection()
-        l.connection.connection = MockConnection()
-
-        it = l._mainloop()
-        self.assertTrue(it.next(), "draining")
-        records = {}
-
-        def create_recorder(key):
-            def _recorder(*args, **kwargs):
-                records[key] = True
-            return _recorder
-
-        l.task_consumer = PlaceHolder()
-        l.task_consumer.iterconsume = create_recorder("consume_tasks")
-        l.broadcast_consumer = PlaceHolder()
-        l.broadcast_consumer.register_callback = create_recorder(
-                                                    "broadcast_callback")
-        l.broadcast_consumer.iterconsume = create_recorder(
-                                             "consume_broadcast")
-        l.task_consumer.add_consumer = create_recorder("consumer_add")
-
-        records.clear()
-        self.assertEqual(l._detect_wait_method(), l._mainloop)
-        for record in ("broadcast_callback", "consume_broadcast",
-                "consume_tasks"):
-            self.assertTrue(records.get(record))
-
-        records.clear()
-        l.connection.connection = PlaceHolder()
-        self.assertIs(l._detect_wait_method(), l.task_consumer.iterconsume)
-        self.assertTrue(records.get("consumer_add"))
-
     def test_connection(self):
         l = MyCarrotListener(self.ready_queue, self.eta_schedule, self.logger,
                            send_events=False)
@@ -459,9 +438,6 @@ class test_CarrotListener(unittest.TestCase):
                 if self.iterations >= 1:
                     raise KeyError("foo")
 
-            def _detect_wait_method(self):
-                return self.wait_method
-
         called_back = [False]
 
         def init_callback(listener):
@@ -469,6 +445,7 @@ class test_CarrotListener(unittest.TestCase):
 
         l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
+        l.task_consumer = MockConsumer()
         l.qos = _QoS()
         l.connection = BrokerConnection()
 
@@ -477,7 +454,7 @@ class test_CarrotListener(unittest.TestCase):
             l.iterations = 1
             raise KeyError("foo")
 
-        l.wait_method = raises_KeyError
+        l._mainloop = raises_KeyError
         self.assertRaises(KeyError, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
@@ -486,6 +463,7 @@ class test_CarrotListener(unittest.TestCase):
         l = _Listener(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
+        l.task_consumer = MockConsumer()
         l.connection = BrokerConnection()
 
         def raises_socket_error(limit=None):
@@ -493,8 +471,8 @@ class test_CarrotListener(unittest.TestCase):
             l.iterations = 1
             raise socket.error("foo")
 
-        l.wait_method = raises_socket_error
-        self.assertRaises(KeyError, l.start)
+        l._mainloop = raises_socket_error
+        self.assertRaises(socket.error, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
 

+ 3 - 12
celery/tests/test_worker_control.py

@@ -200,19 +200,10 @@ class test_ControlPanel(unittest.TestCase):
 
         replies = []
 
-        class MockReplyPublisher(object):
-
-            def __init__(self, *args, **kwargs):
-                pass
-
-            def send(self, reply, **kwargs):
-                replies.append(reply)
-
-            def close(self):
-                pass
-
         class _Dispatch(control.ControlDispatch):
-            ReplyPublisher = MockReplyPublisher
+
+            def reply(self, data, exchange, routing_key, **kwargs):
+                replies.append(data)
 
         panel = _Dispatch(hostname, listener=Listener())
 

+ 7 - 7
celery/tests/test_worker_job.py

@@ -6,7 +6,7 @@ import unittest2 as unittest
 
 from StringIO import StringIO
 
-from kombu.backends.base import BaseMessage
+from kombu.transport.base import Message
 
 from celery import states
 from celery.app import app_or_default
@@ -338,9 +338,9 @@ class test_TaskRequest(unittest.TestCase):
     def test_from_message(self):
         body = {"task": mytask.name, "id": gen_unique_id(),
                 "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
-        m = BaseMessage(None, body=simplejson.dumps(body), backend="foo",
-                        content_type="application/json",
-                        content_encoding="utf-8")
+        m = Message(None, body=simplejson.dumps(body), backend="foo",
+                          content_type="application/json",
+                          content_encoding="utf-8")
         tw = TaskRequest.from_message(m, m.decode())
         self.assertIsInstance(tw, TaskRequest)
         self.assertEqual(tw.task_name, body["task"])
@@ -354,9 +354,9 @@ class test_TaskRequest(unittest.TestCase):
     def test_from_message_nonexistant_task(self):
         body = {"task": "cu.mytask.doesnotexist", "id": gen_unique_id(),
                 "args": [2], "kwargs": {u"æØåveéðƒeæ": "bar"}}
-        m = BaseMessage(None, body=simplejson.dumps(body), backend="foo",
-                        content_type="application/json",
-                        content_encoding="utf-8")
+        m = Message(None, body=simplejson.dumps(body), backend="foo",
+                          content_type="application/json",
+                          content_encoding="utf-8")
         self.assertRaises(NotRegistered, TaskRequest.from_message,
                           m, m.decode())