Browse Source

Tests passing

Ask Solem 11 years ago
parent
commit
2cce90ec44

+ 56 - 20
celery/tests/app/test_app.py

@@ -7,6 +7,7 @@ import itertools
 from copy import deepcopy
 from pickle import loads, dumps
 
+from amqp import promise
 from kombu import Exchange
 
 from celery import shared_task, current_app
@@ -46,6 +47,14 @@ object_config = ObjectConfig()
 dict_config = dict(FOO=10, BAR=20)
 
 
+class ObjectConfig2(object):
+    LEAVE_FOR_WORK = True
+    MOMENT_TO_STOP = True
+    CALL_ME_BACK = 123456789
+    WANT_ME_TO = False
+    UNDERSTAND_ME = True
+
+
 class Object(object):
 
     def __init__(self, **kwargs):
@@ -156,20 +165,38 @@ class test_App(AppCase):
         self.app._using_v1_reduce = True
         self.assertTrue(loads(dumps(self.app)))
 
-    def test_autodiscover_tasks(self):
+    def test_autodiscover_tasks_force(self):
         self.app.conf.CELERY_FORCE_BILLIARD_LOGGING = True
         with patch('celery.app.base.ensure_process_aware_logger') as ep:
             self.app.loader.autodiscover_tasks = Mock()
-            self.app.autodiscover_tasks(['proj.A', 'proj.B'])
+            self.app.autodiscover_tasks(['proj.A', 'proj.B'], force=True)
             ep.assert_called_with()
             self.app.loader.autodiscover_tasks.assert_called_with(
                 ['proj.A', 'proj.B'], 'tasks',
             )
         with patch('celery.app.base.ensure_process_aware_logger') as ep:
+            self.app.loader.autodiscover_tasks = Mock()
             self.app.conf.CELERY_FORCE_BILLIARD_LOGGING = False
-            self.app.autodiscover_tasks(['proj.A', 'proj.B'])
+            self.app.autodiscover_tasks(
+                lambda: ['proj.A', 'proj.B'],
+                related_name='george',
+                force=True,
+            )
+            self.app.loader.autodiscover_tasks.assert_called_with(
+                ['proj.A', 'proj.B'], 'george',
+            )
             self.assertFalse(ep.called)
 
+    def test_autodiscover_tasks_lazy(self):
+        with patch('celery.signals.import_modules') as import_modules:
+            packages = lambda: [1, 2, 3]
+            self.app.autodiscover_tasks(packages)
+            self.assertTrue(import_modules.connect.called)
+            prom = import_modules.connect.call_args[0][0]
+            self.assertIsInstance(prom, promise)
+            self.assertEqual(prom.fun, self.app._autodiscover_tasks)
+            self.assertEqual(prom.args[0](), [1, 2, 3])
+
     @with_environ('CELERY_BROKER_URL', '')
     def test_with_broker(self):
         with self.Celery(broker='foo://baribaz') as app:
@@ -346,23 +373,26 @@ class test_App(AppCase):
         self.app.config_from_envvar('CELERYTEST_CONFIG_OBJECT')
         self.assertEqual(self.app.conf.THIS_IS_A_KEY, 'this is a value')
 
-    def test_config_from_object(self):
-
-        class Object(object):
-            LEAVE_FOR_WORK = True
-            MOMENT_TO_STOP = True
-            CALL_ME_BACK = 123456789
-            WANT_ME_TO = False
-            UNDERSTAND_ME = True
-
-        self.app.config_from_object(Object())
-
+    def assert_config2(self):
         self.assertTrue(self.app.conf.LEAVE_FOR_WORK)
         self.assertTrue(self.app.conf.MOMENT_TO_STOP)
         self.assertEqual(self.app.conf.CALL_ME_BACK, 123456789)
         self.assertFalse(self.app.conf.WANT_ME_TO)
         self.assertTrue(self.app.conf.UNDERSTAND_ME)
 
+    def test_config_from_object__lazy(self):
+        conf = ObjectConfig2()
+        self.app.config_from_object(conf)
+        self.assertFalse(self.app.loader._conf)
+        self.assertIs(self.app._config_source, conf)
+
+        self.assert_config2()
+
+    def test_config_from_object__force(self):
+        self.app.config_from_object(ObjectConfig2(), force=True)
+        self.assertTrue(self.app.loader._conf)
+        self.assert_config2()
+
     def test_config_from_cmdline(self):
         cmdline = ['.always_eager=no',
                    '.result_backend=/dev/null',
@@ -434,22 +464,28 @@ class test_App(AppCase):
             next(app for app in _state._get_active_apps() if id(app) == appid)
 
     def test_config_from_envvar_more(self, key='CELERY_HARNESS_CFG1'):
-        self.assertFalse(self.app.config_from_envvar('HDSAJIHWIQHEWQU',
-                                                     silent=True))
+        self.assertFalse(
+            self.app.config_from_envvar(
+                'HDSAJIHWIQHEWQU', force=True, silent=True),
+        )
         with self.assertRaises(ImproperlyConfigured):
-            self.app.config_from_envvar('HDSAJIHWIQHEWQU', silent=False)
+            self.app.config_from_envvar(
+                'HDSAJIHWIQHEWQU', force=True, silent=False,
+            )
         os.environ[key] = __name__ + '.object_config'
-        self.assertTrue(self.app.config_from_envvar(key))
+        self.assertTrue(self.app.config_from_envvar(key, force=True))
         self.assertEqual(self.app.conf['FOO'], 1)
         self.assertEqual(self.app.conf['BAR'], 2)
 
         os.environ[key] = 'unknown_asdwqe.asdwqewqe'
         with self.assertRaises(ImportError):
             self.app.config_from_envvar(key, silent=False)
-        self.assertFalse(self.app.config_from_envvar(key, silent=True))
+        self.assertFalse(
+            self.app.config_from_envvar(key, force=True, silent=True),
+        )
 
         os.environ[key] = __name__ + '.dict_config'
-        self.assertTrue(self.app.config_from_envvar(key))
+        self.assertTrue(self.app.config_from_envvar(key, force=True))
         self.assertEqual(self.app.conf['FOO'], 10)
         self.assertEqual(self.app.conf['BAR'], 20)
 

+ 23 - 0
celery/tests/case.py

@@ -138,6 +138,29 @@ class Mock(mock.Mock):
         for attr_name, attr_value in items(attrs):
             setattr(self, attr_name, attr_value)
 
+class _ContextMock(Mock):
+    """Dummy class implementing __enter__ and __exit__
+    as the with statement requires these to be implemented
+    in the class, not just the instance."""
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *exc_info):
+        pass
+
+
+def ContextMock(*args, **kwargs):
+    obj = _ContextMock(*args, **kwargs)
+    obj.attach_mock(_ContextMock(), '__enter__')
+    obj.attach_mock(_ContextMock(), '__exit__')
+    obj.__enter__.return_value = obj
+    # if __exit__ return a value the exception is ignored,
+    # so it must return None here.
+    obj.__exit__.return_value = None
+    return obj
+
+
 
 def skip_unless_module(module):
 

+ 53 - 31
celery/tests/fixups/test_django.py

@@ -8,6 +8,7 @@ from celery.fixups.django import (
     _maybe_close_fd,
     fixup,
     DjangoFixup,
+    DjangoWorkerFixup,
 )
 
 from celery.tests.case import (
@@ -15,7 +16,19 @@ from celery.tests.case import (
 )
 
 
-class test_DjangoFixup(AppCase):
+class FixupCase(AppCase):
+    Fixup = None
+
+    @contextmanager
+    def fixup_context(self, app):
+        with patch('celery.fixups.django.import_module') as import_module:
+            with patch('celery.fixups.django.symbol_by_name') as symbyname:
+                f = self.Fixup(app)
+                yield f, import_module, symbyname
+
+
+class test_DjangoFixup(FixupCase):
+    Fixup = DjangoFixup
 
     def test_fixup(self):
         with patch('celery.fixups.django.DjangoFixup') as Fixup:
@@ -31,13 +44,6 @@ class test_DjangoFixup(AppCase):
                     fixup(self.app)
                     self.assertTrue(Fixup.called)
 
-    @contextmanager
-    def fixup_context(self, app):
-        with patch('celery.fixups.django.import_module') as import_module:
-            with patch('celery.fixups.django.symbol_by_name') as symbyname:
-                f = DjangoFixup(app)
-                yield f, import_module, symbyname
-
     def test_maybe_close_fd(self):
         with patch('os.close'):
             _maybe_close_fd(Mock())
@@ -52,33 +58,16 @@ class test_DjangoFixup(AppCase):
                     raise ImportError()
                 return Mock()
             sym.side_effect = se
-            self.assertTrue(DjangoFixup(self.app)._now)
-
-            def se2(name):
-                if name == 'django.db:close_old_connections':
-                    raise ImportError()
-                return Mock()
-            sym.side_effect = se2
-            self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
+            self.assertTrue(self.Fixup(self.app)._now)
 
     def test_install(self):
-        self.app.conf = {'CELERY_DB_REUSE_MAX': None}
         self.app.loader = Mock()
         with self.fixup_context(self.app) as (f, _, _):
             with patch_many('os.getcwd', 'sys.path',
                             'celery.fixups.django.signals') as (cw, p, sigs):
                 cw.return_value = '/opt/vandelay'
                 f.install()
-                sigs.beat_embedded_init.connect.assert_called_with(
-                    f.close_database,
-                )
-                sigs.worker_ready.connect.assert_called_with(f.on_worker_ready)
-                sigs.task_prerun.connect.assert_called_with(f.on_task_prerun)
-                sigs.task_postrun.connect.assert_called_with(f.on_task_postrun)
                 sigs.worker_init.connect.assert_called_with(f.on_worker_init)
-                sigs.worker_process_init.connect.assert_called_with(
-                    f.on_worker_process_init,
-                )
                 self.assertEqual(self.app.loader.now, f.now)
                 self.assertEqual(self.app.loader.mail_admins, f.mail_admins)
                 p.append.assert_called_with('/opt/vandelay')
@@ -99,11 +88,44 @@ class test_DjangoFixup(AppCase):
 
     def test_on_worker_init(self):
         with self.fixup_context(self.app) as (f, _, _):
-            f.close_database = Mock()
-            f.close_cache = Mock()
-            f.on_worker_init()
-            f.close_database.assert_called_with()
-            f.close_cache.assert_called_with()
+            with patch('celery.fixups.django.DjangoWorkerFixup') as DWF:
+                f.on_worker_init()
+                DWF.assert_called_with(f.app)
+                DWF.return_value.install.assert_called_with()
+                self.assertIs(
+                    f._worker_fixup, DWF.return_value.install.return_value,
+                )
+
+
+class test_DjangoWorkerFixup(FixupCase):
+    Fixup = DjangoWorkerFixup
+
+    def test_init(self):
+        with self.fixup_context(self.app) as (f, importmod, sym):
+            self.assertTrue(f)
+
+            def se(name):
+                if name == 'django.db:close_old_connections':
+                    raise ImportError()
+                return Mock()
+            sym.side_effect = se
+            self.assertIsNone(self.Fixup(self.app)._close_old_connections)
+
+    def test_install(self):
+        self.app.conf = {'CELERY_DB_REUSE_MAX': None}
+        self.app.loader = Mock()
+        with self.fixup_context(self.app) as (f, _, _):
+            with patch_many('celery.fixups.django.signals') as (sigs, ):
+                f.install()
+                sigs.beat_embedded_init.connect.assert_called_with(
+                    f.close_database,
+                )
+                sigs.worker_ready.connect.assert_called_with(f.on_worker_ready)
+                sigs.task_prerun.connect.assert_called_with(f.on_task_prerun)
+                sigs.task_postrun.connect.assert_called_with(f.on_task_postrun)
+                sigs.worker_process_init.connect.assert_called_with(
+                    f.on_worker_process_init,
+                )
 
     def test_on_worker_process_init(self):
         with self.fixup_context(self.app) as (f, _, _):

+ 18 - 4
celery/tests/worker/test_consumer.py

@@ -18,7 +18,7 @@ from celery.worker.consumer import (
     CLOSE,
 )
 
-from celery.tests.case import AppCase, Mock, SkipTest, call, patch
+from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch
 
 
 class test_Consumer(AppCase):
@@ -36,7 +36,7 @@ class test_Consumer(AppCase):
         )
         consumer.blueprint = Mock()
         consumer._restart_state = Mock()
-        consumer.connection = Mock()
+        consumer.connection = _amqp_connection()
         consumer.connection_errors = (socket.error, OSError, )
         return consumer
 
@@ -144,8 +144,8 @@ class test_Consumer(AppCase):
             c.on_close()
 
     def test_connect_error_handler(self):
-        self.app.connection = Mock()
-        conn = self.app.connection.return_value = Mock()
+        self.app.connection = _amqp_connection()
+        conn = self.app.connection.return_value
         c = self.get_consumer()
         self.assertTrue(c.connect())
         self.assertTrue(conn.ensure_connection.called)
@@ -204,6 +204,7 @@ class test_Mingle(AppCase):
 
     def test_start_no_replies(self):
         c = Mock()
+        c.app.connection = _amqp_connection()
         mingle = Mingle(c)
         I = c.app.control.inspect.return_value = Mock()
         I.hello.return_value = {}
@@ -212,6 +213,7 @@ class test_Mingle(AppCase):
     def test_start(self):
         try:
             c = Mock()
+            c.app.connection = _amqp_connection()
             mingle = Mingle(c)
             self.assertTrue(mingle.enabled)
 
@@ -248,16 +250,24 @@ class test_Mingle(AppCase):
             worker_state.revoked.clear()
 
 
+def _amqp_connection():
+    connection = ContextMock()
+    connection.return_value = ContextMock()
+    connection.return_value.transport.driver_type = 'amqp'
+    return connection
+
 class test_Gossip(AppCase):
 
     def test_init(self):
         c = self.Consumer()
+        c.app.connection = _amqp_connection()
         g = Gossip(c)
         self.assertTrue(g.enabled)
         self.assertIs(c.gossip, g)
 
     def test_election(self):
         c = self.Consumer()
+        c.app.connection = _amqp_connection()
         g = Gossip(c)
         g.start(c)
         g.election('id', 'topic', 'action')
@@ -268,6 +278,7 @@ class test_Gossip(AppCase):
 
     def test_call_task(self):
         c = self.Consumer()
+        c.app.connection = _amqp_connection()
         g = Gossip(c)
         g.start(c)
 
@@ -298,6 +309,7 @@ class test_Gossip(AppCase):
 
     def test_on_elect(self):
         c = self.Consumer()
+        c.app.connection = _amqp_connection()
         g = Gossip(c)
         g.start(c)
 
@@ -314,6 +326,7 @@ class test_Gossip(AppCase):
 
     def Consumer(self, hostname='foo@x.com', pid=4312):
         c = Mock()
+        c.app.connection = _amqp_connection()
         c.hostname = hostname
         c.pid = pid
         return c
@@ -355,6 +368,7 @@ class test_Gossip(AppCase):
 
     def test_on_elect_ack_lose(self):
         c = self.Consumer(hostname='bar@x.com')  # I will lose
+        c.app.connection = _amqp_connection()
         g = Gossip(c)
         handler = g.election_handlers['topic'] = Mock()
         self.setup_election(g, c)

+ 23 - 14
celery/tests/worker/test_worker.py

@@ -38,6 +38,13 @@ def MockStep(step=None):
     return step
 
 
+def mock_event_dispatcher():
+    evd = Mock(name='event_dispatcher')
+    evd.groups = ['worker']
+    evd._outbound_buffer = deque()
+    return evd
+
+
 class PlaceHolder(object):
         pass
 
@@ -182,7 +189,7 @@ class test_Consumer(AppCase):
         self.assertIsNone(l.connection)
 
         l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
-        eventer = l.event_dispatcher = Mock()
+        eventer = l.event_dispatcher = mock_event_dispatcher()
         eventer.enabled = True
         heart = l.heart = MockHeart()
         l.blueprint.state = RUN
@@ -200,7 +207,7 @@ class test_Consumer(AppCase):
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.node = MockNode()
 
         callback = self._get_on_message(l)
@@ -217,7 +224,7 @@ class test_Consumer(AppCase):
                            args=('2, 2'),
                            kwargs={},
                            eta=datetime.now().isoformat())
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.node = MockNode()
         l.update_strategies()
         l.qos = Mock()
@@ -230,12 +237,12 @@ class test_Consumer(AppCase):
     def test_receive_message_InvalidTaskError(self, error):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.steps.pop()
         m = create_message(Mock(), task=self.foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
         l.update_strategies()
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
 
         callback = self._get_on_message(l)
         callback(m.decode(), m)
@@ -258,7 +265,7 @@ class test_Consumer(AppCase):
     def _get_on_message(self, l):
         if l.qos is None:
             l.qos = Mock()
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.task_consumer = Mock()
         l.connection = Mock()
         l.connection.drain_events.side_effect = SystemExit()
@@ -271,7 +278,7 @@ class test_Consumer(AppCase):
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
         l.blueprint.state = RUN
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         m = create_message(Mock(), task=self.foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
@@ -419,7 +426,7 @@ class test_Consumer(AppCase):
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer.qos, 1)
         current_pcount = l.qos.value
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.enabled = False
         l.update_strategies()
         callback = self._get_on_message(l)
@@ -478,7 +485,7 @@ class test_Consumer(AppCase):
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
 
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         callback = self._get_on_message(l)
         self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
@@ -493,7 +500,7 @@ class test_Consumer(AppCase):
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
-        l.event_dispatcher = Mock()
+        l.event_dispatcher = mock_event_dispatcher()
         l.connection_errors = (socket.error, )
         m.reject = Mock()
         m.reject.side_effect = socket.error('foo')
@@ -509,8 +516,7 @@ class test_Consumer(AppCase):
     def test_receive_message_eta(self):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
         l.steps.pop()
-        l.event_dispatcher = Mock()
-        l.event_dispatcher._outbound_buffer = deque()
+        l.event_dispatcher = mock_event_dispatcher()
         backend = Mock()
         m = create_message(
             backend, task=self.foo_task.name,
@@ -525,12 +531,15 @@ class test_Consumer(AppCase):
             l.blueprint.start(l)
             l.app.conf.BROKER_CONNECTION_RETRY = p
             l.blueprint.restart(l)
-            l.event_dispatcher = Mock()
+            l.event_dispatcher = mock_event_dispatcher()
             callback = self._get_on_message(l)
             callback(m.decode(), m)
         finally:
             l.timer.stop()
-            l.timer.join()
+            try:
+                l.timer.join()
+            except RuntimeError:
+                pass
 
         in_hold = l.timer.queue[0]
         self.assertEqual(len(in_hold), 3)