浏览代码

100% coverage for celery.contrib.migrate

Ask Solem 12 年之前
父节点
当前提交
642d0c4193

+ 1 - 1
celery/contrib/migrate.py

@@ -85,7 +85,7 @@ def migrate_task(producer, body_, message, queues=None):
 def filter_callback(callback, tasks):
 
     def filtered(body, message):
-        if tasks and message.payload['task'] not in tasks:
+        if tasks and body['task'] not in tasks:
             return
 
         return callback(body, message)

+ 0 - 1
celery/events/dumper.py

@@ -88,7 +88,6 @@ def evdump(app=None, out=sys.stdout):
             conn.as_uri(), exc, humanize_seconds(interval, 'in', ' ')
         ))
 
-
     while 1:
         try:
             conn.ensure_connection(_error_handler)

+ 1 - 1
celery/tests/bin/test_celeryevdump.py

@@ -49,7 +49,7 @@ class test_Dumper(Case):
     def test_evdump_error_handler(self):
         app = Mock(name='app')
         with patch('celery.events.dumper.Dumper') as Dumper:
-            dumper = Dumper.return_value = Mock(name='dumper')
+            Dumper.return_value = Mock(name='dumper')
             recv = app.events.Receiver.return_value = Mock()
 
             def se(*_a, **_k):

+ 210 - 7
celery/tests/contrib/test_migrate.py

@@ -1,16 +1,35 @@
 from __future__ import absolute_import, unicode_literals
 
+from contextlib import contextmanager
+from mock import patch
+
 from kombu import Connection, Producer, Queue, Exchange
 from kombu.exceptions import StdChannelError
-from mock import patch
+
+from kombu.transport.virtual import QoS
 
 from celery.contrib.migrate import (
+    StopFiltering,
     State,
     migrate_task,
     migrate_tasks,
+    filter_callback,
+    _maybe_queue,
+    filter_status,
+    move_by_taskmap,
+    move_by_idmap,
+    move_task_by_id,
+    start_filter,
+    task_id_in,
+    task_id_eq,
+    expand_dest,
+    move,
 )
 from celery.utils.encoding import bytes_t, ensure_bytes
-from celery.tests.case import AppCase, Case, Mock
+from celery.tests.case import AppCase, Case, Mock, override_stdouts
+
+# hack to ignore error at shutdown
+QoS.restore_at_shutdown = False
 
 
 def Message(body, exchange='exchange', routing_key='rkey',
@@ -41,6 +60,188 @@ class test_State(Case):
         x.total_apx = 100
         self.assertEqual(x.strtotal, '100')
 
+    def test_repr(self):
+        x = State()
+        self.assertTrue(repr(x))
+        x.filtered = 'foo'
+        self.assertTrue(repr(x))
+
+
+class test_move(AppCase):
+
+    @contextmanager
+    def move_context(self, **kwargs):
+        with patch('celery.contrib.migrate.start_filter') as start:
+            with patch('celery.contrib.migrate.republish') as republish:
+                pred = Mock(name='predicate')
+                move(pred, app=self.app,
+                     connection=self.app.connection(), **kwargs)
+                self.assertTrue(start.called)
+                callback = start.call_args[0][2]
+                yield callback, pred, republish
+
+    def msgpair(self, **kwargs):
+        body = dict({'task': 'add', 'id': 'id'}, **kwargs)
+        return body, Message(body)
+
+    def test_move(self):
+        with self.move_context() as (callback, pred, republish):
+            pred.return_value = None
+            body, message = self.msgpair()
+            callback(body, message)
+            self.assertFalse(message.ack.called)
+            self.assertFalse(republish.called)
+
+            pred.return_value = 'foo'
+            callback(body, message)
+            message.ack.assert_called_with()
+            self.assertTrue(republish.called)
+
+    def test_move_transform(self):
+        trans = Mock(name='transform')
+        trans.return_value = Queue('bar')
+        with self.move_context(transform=trans) as (callback, pred, republish):
+            pred.return_value = 'foo'
+            body, message = self.msgpair()
+            with patch('celery.contrib.migrate.maybe_declare') as maybed:
+                callback(body, message)
+                trans.assert_called_with('foo')
+                self.assertTrue(maybed.called)
+                self.assertTrue(republish.called)
+
+    def test_limit(self):
+        with self.move_context(limit=1) as (callback, pred, republish):
+            pred.return_value = 'foo'
+            body, message = self.msgpair()
+            with self.assertRaises(StopFiltering):
+                callback(body, message)
+            self.assertTrue(republish.called)
+
+    def test_callback(self):
+        cb = Mock()
+        with self.move_context(callback=cb) as (callback, pred, republish):
+            pred.return_value = 'foo'
+            body, message = self.msgpair()
+            callback(body, message)
+            self.assertTrue(republish.called)
+            self.assertTrue(cb.called)
+
+
+class test_start_filter(AppCase):
+
+    def test_start(self):
+        with patch('celery.contrib.migrate.eventloop') as evloop:
+            app = Mock()
+            filter = Mock(name='filter')
+            conn = Connection('memory://')
+            evloop.side_effect = StopFiltering()
+            app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}
+            consumer = app.amqp.TaskConsumer.return_value = Mock(name='consum')
+            consumer.queues = list(app.amqp.queues.values())
+            consumer.channel = conn.default_channel
+            consumer.__enter__ = Mock(name='consumer.__enter__')
+            consumer.__exit__ = Mock(name='consumer.__exit__')
+            consumer.callbacks = []
+
+            def register_callback(x):
+                consumer.callbacks.append(x)
+            consumer.register_callback = register_callback
+
+            start_filter(app, conn, filter,
+                         queues='foo,bar', ack_messages=True)
+            body = {'task': 'add', 'id': 'id'}
+            for callback in consumer.callbacks:
+                callback(body, Message(body))
+            consumer.callbacks[:] = []
+            cb = Mock(name='callback=')
+            start_filter(app, conn, filter, tasks='add,mul', callback=cb)
+            for callback in consumer.callbacks:
+                callback(body, Message(body))
+            self.assertTrue(cb.called)
+
+            on_declare_queue = Mock()
+            start_filter(app, conn, filter, tasks='add,mul', queues='foo',
+                         on_declare_queue=on_declare_queue)
+            self.assertTrue(on_declare_queue.called)
+            start_filter(app, conn, filter, queues=['foo', 'bar'])
+            consumer.callbacks[:] = []
+            state = State()
+            start_filter(app, conn, filter,
+                         tasks='add,mul', callback=cb, state=state, limit=1)
+            stop_filtering_raised = False
+            for callback in consumer.callbacks:
+                try:
+                    callback(body, Message(body))
+                except StopFiltering:
+                    stop_filtering_raised = True
+            self.assertTrue(state.count)
+            self.assertTrue(stop_filtering_raised)
+
+
+class test_filter_callback(Case):
+
+    def test_filter(self):
+        callback = Mock()
+        filter = filter_callback(callback, ['add', 'mul'])
+        t1 = {'task': 'add'}
+        t2 = {'task': 'div'}
+
+        message = Mock()
+        filter(t2, message)
+        self.assertFalse(callback.called)
+        filter(t1, message)
+        callback.assert_called_with(t1, message)
+
+
+class test_utils(Case):
+
+    def test_task_id_in(self):
+        self.assertTrue(task_id_in(['A'], {'id': 'A'}, Mock()))
+        self.assertFalse(task_id_in(['A'], {'id': 'B'}, Mock()))
+
+    def test_task_id_eq(self):
+        self.assertTrue(task_id_eq('A', {'id': 'A'}, Mock()))
+        self.assertFalse(task_id_eq('A', {'id': 'B'}, Mock()))
+
+    def test_expand_dest(self):
+        self.assertEqual(expand_dest(None, 'foo', 'bar'), ('foo', 'bar'))
+        self.assertEqual(expand_dest(('b', 'x'), 'foo', 'bar'), ('b', 'x'))
+
+    def test_maybe_queue(self):
+        app = Mock()
+        app.amqp.queues = {'foo': 313}
+        self.assertEqual(_maybe_queue(app, 'foo'), 313)
+        self.assertEqual(_maybe_queue(app, Queue('foo')), Queue('foo'))
+
+    def test_filter_status(self):
+        with override_stdouts() as (stdout, stderr):
+            filter_status(State(), {'id': '1', 'task': 'add'}, Mock())
+            self.assertTrue(stdout.getvalue())
+
+    def test_move_by_taskmap(self):
+        with patch('celery.contrib.migrate.move') as move:
+            move_by_taskmap({'add': Queue('foo')})
+            self.assertTrue(move.called)
+            cb = move.call_args[0][0]
+            self.assertTrue(cb({'task': 'add'}, Mock()))
+
+    def test_move_by_idmap(self):
+        with patch('celery.contrib.migrate.move') as move:
+            move_by_idmap({'123f': Queue('foo')})
+            self.assertTrue(move.called)
+            cb = move.call_args[0][0]
+            self.assertTrue(cb({'id': '123f'}, Mock()))
+
+    def test_move_task_by_id(self):
+        with patch('celery.contrib.migrate.move') as move:
+            move_task_by_id('123f', Queue('foo'))
+            self.assertTrue(move.called)
+            cb = move.call_args[0][0]
+            self.assertEqual(
+                cb({'id': '123f'}, Mock()),
+                Queue('foo'),
+            )
+
 
 class test_migrate_task(Case):
 
@@ -77,7 +278,7 @@ class test_migrate_tasks(AppCase):
         self.assertTrue(x.default_channel.queues)
         self.assertFalse(y.default_channel.queues)
 
-        migrate_tasks(x, y, accept=['text/plain'])
+        migrate_tasks(x, y, accept=['text/plain'], app=self.app)
 
         yq = q(y.default_channel)
         self.assertEqual(yq.get().body, ensure_bytes('foo'))
@@ -86,12 +287,13 @@ class test_migrate_tasks(AppCase):
 
         Producer(x).publish('foo', exchange=name, routing_key=name)
         callback = Mock()
-        migrate_tasks(x, y, callback=callback, accept=['text/plain'])
+        migrate_tasks(x, y,
+                      callback=callback, accept=['text/plain'], app=self.app)
         self.assertTrue(callback.called)
         migrate = Mock()
         Producer(x).publish('baz', exchange=name, routing_key=name)
         migrate_tasks(x, y, callback=callback,
-                      migrate=migrate, accept=['text/plain'])
+                      migrate=migrate, accept=['text/plain'], app=self.app)
         self.assertTrue(migrate.called)
 
         with patch('kombu.transport.virtual.Channel.queue_declare') as qd:
@@ -101,11 +303,12 @@ class test_migrate_tasks(AppCase):
                     raise StdChannelError()
                 return 0, 3, 0
             qd.side_effect = effect
-            migrate_tasks(x, y)
+            migrate_tasks(x, y, app=self.app)
 
         x = Connection('memory://')
         x.default_channel.queues = {}
         y.default_channel.queues = {}
         callback = Mock()
-        migrate_tasks(x, y, callback=callback, accept=['text/plain'])
+        migrate_tasks(x, y,
+                      callback=callback, accept=['text/plain'], app=self.app)
         self.assertFalse(callback.called)

+ 1 - 2
celery/tests/events/test_events.py

@@ -216,7 +216,7 @@ class test_EventReceiver(AppCase):
         r.adjust_clock = Mock()
         ts_adjust = Mock()
 
-        e = r.event_from_message(
+        r.event_from_message(
             {'type': 'worker-online', 'clock': 313},
             localize=False,
             adjust_timestamp=ts_adjust,
@@ -224,7 +224,6 @@ class test_EventReceiver(AppCase):
         self.assertFalse(ts_adjust.called)
         r.adjust_clock.assert_called_with(313)
 
-
     def test_itercapture_limit(self):
         connection = self.app.connection()
         channel = connection.channel()

+ 2 - 1
celery/utils/datastructures.py

@@ -487,7 +487,7 @@ class ConfigurationView(AttributeDictMixin):
 
     def __bool__(self):
         return any(self._order)
-    __nonzero__  = __bool__  # Py2
+    __nonzero__ = __bool__  # Py2
 
     def __repr__(self):
         return repr(dict(items(self)))
@@ -645,3 +645,4 @@ class LimitedSet(object):
         return self.__class__, (
             self.maxlen, self.expires, self._data, self._heap,
         )
+MutableSet.register(LimitedSet)