Browse Source

100% coverage for celery.fixups.django

Ask Solem 11 years ago
parent
commit
dbe02f65ff
4 changed files with 290 additions and 22 deletions
  1. 20 21
      celery/fixups/django.py
  2. 2 1
      celery/tests/case.py
  3. 0 0
      celery/tests/fixups/__init__.py
  4. 268 0
      celery/tests/fixups/test_django.py

+ 20 - 21
celery/fixups/django.py

@@ -4,12 +4,14 @@ import os
 import sys
 import warnings
 
+from kombu.utils import symbol_by_name
+
 from datetime import datetime
+from importlib import import_module
 
 from celery import signals
 from celery.exceptions import FixupWarning
 
-SETTINGS_MODULE = os.environ.get('DJANGO_SETTINGS_MODULE')
 ERR_NOT_INSTALLED = """\
 Environment variable DJANGO_SETTINGS_MODULE is defined
 but Django is not installed.  Will not apply Django fixups!
@@ -24,32 +26,36 @@ def _maybe_close_fd(fh):
         pass
 
 
-def fixup(app):
+def fixup(app, env='DJANGO_SETTINGS_MODULE'):
+    SETTINGS_MODULE = os.environ.get(env)
     if SETTINGS_MODULE:
         try:
             import django  # noqa
         except ImportError:
             warnings.warn(FixupWarning(ERR_NOT_INSTALLED))
-        return DjangoFixup(app).install()
+        else:
+            return DjangoFixup(app).install()
 
 
 class DjangoFixup(object):
     _db_recycles = 0
 
     def __init__(self, app):
-        from django import db
-        from django.core import cache
-        from django.conf import settings
-        from django.core.mail import mail_admins
+        self.app = app
+        self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
+        self._db = import_module('django.db')
+        self._cache = import_module('django.core.cache')
+        self._settings = symbol_by_name('django.conf:settings')
+        self._mail_admins = symbol_by_name('django.core.mail:mail_admins')
 
         # Current time and date
         try:
-            from django.utils.timezone import now
+            self._now = symbol_by_name('django.utils.timezone:now')
         except ImportError:  # pre django-1.4
-            now = datetime.now  # noqa
+            self._now = datetime.now  # noqa
 
         # Database-related exceptions.
-        from django.db import DatabaseError
+        DatabaseError = symbol_by_name('django.db:DatabaseError')
         try:
             import MySQLdb as mysql
             _my_database_errors = (mysql.DatabaseError,
@@ -80,24 +86,17 @@ class DjangoFixup(object):
             _oracle_database_errors = ()  # noqa
 
         try:
-            from django.db import close_old_connections
-            self._close_old_connections = close_old_connections
+            self._close_old_connections = symbol_by_name(
+                'django.db:close_old_connections',
+            )
         except ImportError:
             self._close_old_connections = None
-
-        self.app = app
-        self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
-        self._cache = cache
-        self._settings = settings
-        self._db = db
-        self._mail_admins = mail_admins
-        self._now = now
         self.database_errors = (
             (DatabaseError, ) +
             _my_database_errors +
             _pg_database_errors +
             _lite_database_errors +
-            _oracle_database_errors,
+            _oracle_database_errors
         )
 
     def install(self):

+ 2 - 1
celery/tests/case.py

@@ -465,7 +465,8 @@ def reset_modules(*modules):
 def patch_modules(*modules):
     prev = {}
     for mod in modules:
-        prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
+        prev[mod] = sys.modules.get(mod)
+        sys.modules[mod] = ModuleType(mod)
     try:
         yield
     finally:

+ 0 - 0
celery/tests/fixups/__init__.py


+ 268 - 0
celery/tests/fixups/test_django.py

@@ -0,0 +1,268 @@
+from __future__ import absolute_import
+
+import os
+
+from contextlib import contextmanager
+from mock import Mock, patch
+
+from celery import Celery
+from celery.fixups.django import (
+    _maybe_close_fd,
+    fixup,
+    DjangoFixup,
+)
+
+from celery.tests.case import AppCase, patch_many, patch_modules, mask_modules
+
+
+class test_DjangoFixup(AppCase):
+
+    def test_fixup(self):
+        with patch('celery.fixups.django.DjangoFixup') as Fixup:
+            with patch.dict(os.environ, DJANGO_SETTINGS_MODULE=''):
+                fixup(self.app)
+                self.assertFalse(Fixup.called)
+            with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'):
+                with mask_modules('django'):
+                    with self.assertWarnsRegex(UserWarning, 'but Django is'):
+                        fixup(self.app)
+                        self.assertFalse(Fixup.called)
+                with patch_modules('django'):
+                    fixup(self.app)
+                    self.assertTrue(Fixup.called)
+
+    @contextmanager
+    def fixup_context(self, app):
+        with patch('celery.fixups.django.import_module') as import_module:
+            with patch('celery.fixups.django.symbol_by_name') as symbyname:
+                f = DjangoFixup(app)
+                yield f, import_module, symbyname
+
+    def test_maybe_close_fd(self):
+        with patch('os.close'):
+            _maybe_close_fd(Mock())
+            _maybe_close_fd(object())
+
+    def test_init(self):
+        with self.fixup_context(self.app) as (f, importmod, sym):
+            self.assertTrue(f)
+
+            def se(name):
+                if name == 'django.utils.timezone:now':
+                    raise ImportError()
+                return Mock()
+            sym.side_effect = se
+            self.assertTrue(DjangoFixup(self.app)._now)
+
+            def se(name):
+                if name == 'django.db:close_old_connections':
+                    raise ImportError()
+                return Mock()
+            sym.side_effect = se
+            self.assertIsNone(DjangoFixup(self.app)._close_old_connections)
+
+    def test_install(self):
+        app = Celery(set_as_current=False)
+        app.conf = {'CELERY_DB_REUSE_MAX': None}
+        app.loader = Mock()
+        with self.fixup_context(app) as (f, _, _):
+            with patch_many('os.getcwd', 'sys.path',
+                            'celery.fixups.django.signals') as (cw, p, sigs):
+                cw.return_value = '/opt/vandelay'
+                f.install()
+                sigs.beat_embedded_init.connect.assert_called_with(
+                    f.close_database,
+                )
+                sigs.worker_ready.connect.assert_called_with(f.on_worker_ready)
+                sigs.task_prerun.connect.assert_called_with(f.on_task_prerun)
+                sigs.task_postrun.connect.assert_called_with(f.on_task_postrun)
+                sigs.worker_init.connect.assert_called_with(f.on_worker_init)
+                sigs.worker_process_init.connect.assert_called_with(
+                    f.on_worker_process_init,
+                )
+                self.assertEqual(app.loader.now, f.now)
+                self.assertEqual(app.loader.mail_admins, f.mail_admins)
+                p.append.assert_called_with('/opt/vandelay')
+
+    def test_now(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            self.assertTrue(f.now(utc=True))
+            self.assertFalse(f._now.called)
+            self.assertTrue(f.now(utc=False))
+            self.assertTrue(f._now.called)
+
+    def test_mail_admins(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f.mail_admins('sub', 'body', True)
+            f._mail_admins.assert_called_with(
+                'sub', 'body', fail_silently=True,
+            )
+
+    def test_on_worker_init(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f.close_database = Mock()
+            f.close_cache = Mock()
+            f.on_worker_init()
+            f.close_database.assert_called_with()
+            f.close_cache.assert_called_with()
+
+    def test_on_worker_process_init(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            with patch('celery.fixups.django._maybe_close_fd') as mcf:
+                    _all = f._db.connections.all = Mock()
+                    conns = _all.return_value = [
+                        Mock(), Mock(),
+                    ]
+                    conns[0].connection = None
+                    with patch.object(f, 'close_cache'):
+                        with patch.object(f, '_close_database'):
+                            f.on_worker_process_init()
+                            mcf.assert_called_with(conns[1].connection)
+                            f.close_cache.assert_called_with()
+                            f._close_database.assert_called_with()
+
+                            mcf.reset_mock()
+                            _all.side_effect = AttributeError()
+                            f.on_worker_process_init()
+                            mcf.assert_called_with(f._db.connection.connection)
+                            f._db.connection = None
+                            f.on_worker_process_init()
+
+    def test_on_task_prerun(self):
+        task = Mock()
+        with self.fixup_context(self.app) as (f, _, _):
+            task.request.is_eager = False
+            with patch.object(f, 'close_database'):
+                f.on_task_prerun(task)
+                f.close_database.assert_called_with()
+
+            task.request.is_eager = True
+            with patch.object(f, 'close_database'):
+                f.on_task_prerun(task)
+                self.assertFalse(f.close_database.called)
+
+    def test_on_task_postrun(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            with patch.object(f, 'close_database'):
+                with patch.object(f, 'close_cache'):
+                    f.on_task_postrun()
+                    f.close_database.assert_called_with()
+                    f.close_cache.assert_called_with()
+
+    def test_close_database(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f._close_old_connections = Mock()
+            f.close_database()
+            f._close_old_connections.assert_called_with()
+            f._close_old_connections = None
+            with patch.object(f, '_close_database') as _close:
+                f.db_reuse_max = None
+                f.close_database()
+                _close.assert_called_with()
+                _close.reset_mock()
+
+                f.db_reuse_max = 10
+                f._db_recycles = 3
+                f.close_database()
+                self.assertFalse(_close.called)
+                self.assertEqual(f._db_recycles, 4)
+                _close.reset_mock()
+
+                f._db_recycles = 20
+                f.close_database()
+                _close.assert_called_with()
+                self.assertEqual(f._db_recycles, 1)
+
+    def test__close_database(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            conns = f._db.connections = [Mock(), Mock(), Mock()]
+            conns[1].close.side_effect = KeyError('already closed')
+            f.database_errors = (KeyError, )
+
+            f._close_database()
+            conns[0].close.assert_called_with()
+            conns[1].close.assert_called_with()
+            conns[2].close.assert_called_with()
+
+            conns[1].close.side_effect = KeyError('omg')
+            with self.assertRaises(KeyError):
+                f._close_database()
+
+            class Object(object):
+                pass
+            o = Object()
+            o.close_connection = Mock()
+            f._db = o
+            f._close_database()
+            o.close_connection.assert_called_with()
+
+    def test_close_cache(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f.close_cache()
+            f._cache.cache.close.assert_called_with()
+            f._cache.cache.close.side_effect = TypeError()
+            f.close_cache()
+
+    def test_on_worker_ready(self):
+        with self.fixup_context(self.app) as (f, _, _):
+            f._settings.DEBUG = False
+            f.on_worker_ready()
+            with self.assertWarnsRegex(UserWarning, r'leads to a memory leak'):
+                f._settings.DEBUG = True
+                f.on_worker_ready()
+
+    def test_mysql_errors(self):
+        with patch_modules('MySQLdb'):
+            import MySQLdb as mod
+            mod.DatabaseError = Mock()
+            mod.InterfaceError = Mock()
+            mod.OperationalError = Mock()
+            with self.fixup_context(self.app) as (f, _, _):
+                self.assertIn(mod.DatabaseError, f.database_errors)
+                self.assertIn(mod.InterfaceError, f.database_errors)
+                self.assertIn(mod.OperationalError, f.database_errors)
+        with mask_modules('MySQLdb'):
+            with self.fixup_context(self.app):
+                pass
+
+    def test_pg_errors(self):
+        with patch_modules('psycopg2'):
+            import psycopg2 as mod
+            mod.DatabaseError = Mock()
+            mod.InterfaceError = Mock()
+            mod.OperationalError = Mock()
+            with self.fixup_context(self.app) as (f, _, _):
+                self.assertIn(mod.DatabaseError, f.database_errors)
+                self.assertIn(mod.InterfaceError, f.database_errors)
+                self.assertIn(mod.OperationalError, f.database_errors)
+        with mask_modules('psycopg2'):
+            with self.fixup_context(self.app):
+                pass
+
+    def test_sqlite_errors(self):
+        with patch_modules('sqlite3'):
+            import sqlite3 as mod
+            mod.DatabaseError = Mock()
+            mod.InterfaceError = Mock()
+            mod.OperationalError = Mock()
+            with self.fixup_context(self.app) as (f, _, _):
+                self.assertIn(mod.DatabaseError, f.database_errors)
+                self.assertIn(mod.InterfaceError, f.database_errors)
+                self.assertIn(mod.OperationalError, f.database_errors)
+        with mask_modules('sqlite3'):
+            with self.fixup_context(self.app):
+                pass
+
+    def test_oracle_errors(self):
+        with patch_modules('cx_Oracle'):
+            import cx_Oracle as mod
+            mod.DatabaseError = Mock()
+            mod.InterfaceError = Mock()
+            mod.OperationalError = Mock()
+            with self.fixup_context(self.app) as (f, _, _):
+                self.assertIn(mod.DatabaseError, f.database_errors)
+                self.assertIn(mod.InterfaceError, f.database_errors)
+                self.assertIn(mod.OperationalError, f.database_errors)
+        with mask_modules('cx_Oracle'):
+            with self.fixup_context(self.app):
+                pass