Browse Source

Now requires Django 1.8+

Ask Solem 8 years ago
parent
commit
98369614a7
2 changed files with 19 additions and 219 deletions
  1. 18 104
      celery/fixups/django.py
  2. 1 115
      celery/tests/fixups/test_django.py

+ 18 - 104
celery/fixups/django.py

@@ -88,20 +88,12 @@ class DjangoFixup(object):
         return datetime.utcnow() if utc else self._now()
 
     def autodiscover_tasks(self):
-        try:
-            from django.apps import apps
-        except ImportError:
-            return self._settings.INSTALLED_APPS
-        else:
-            return [config.name for config in apps.get_app_configs()]
+        from django.apps import apps
+        return [config.name for config in apps.get_app_configs()]
 
     @cached_property
     def _now(self):
-        try:
-            return symbol_by_name('django.utils.timezone:now')
-        except (AttributeError, ImportError):  # pre django-1.4
-            return datetime.now
-
+        return symbol_by_name('django.utils.timezone:now')
 
 class DjangoWorkerFixup(object):
     _db_recycles = 0
@@ -120,84 +112,21 @@ class DjangoWorkerFixup(object):
         except (ImportError, AttributeError):
             self._interface_errors = ()
 
-        # Database-related exceptions.
-        DatabaseError = symbol_by_name('django.db:DatabaseError')
-        try:
-            import MySQLdb as mysql
-            _my_database_errors = (mysql.DatabaseError,
-                                   mysql.InterfaceError,
-                                   mysql.OperationalError)
-        except ImportError:
-            _my_database_errors = ()      # noqa
-        try:
-            import psycopg2 as pg
-            _pg_database_errors = (pg.DatabaseError,
-                                   pg.InterfaceError,
-                                   pg.OperationalError)
-        except ImportError:
-            _pg_database_errors = ()      # noqa
-        try:
-            import sqlite3
-            _lite_database_errors = (sqlite3.DatabaseError,
-                                     sqlite3.InterfaceError,
-                                     sqlite3.OperationalError)
-        except ImportError:
-            _lite_database_errors = ()    # noqa
-        try:
-            import cx_Oracle as oracle
-            _oracle_database_errors = (oracle.DatabaseError,
-                                       oracle.InterfaceError,
-                                       oracle.OperationalError)
-        except ImportError:
-            _oracle_database_errors = ()  # noqa
-
-        try:
-            self._close_old_connections = symbol_by_name(
-                'django.db:close_old_connections',
-            )
-        except (ImportError, AttributeError):
-            self._close_old_connections = None
-        self.database_errors = (
-            (DatabaseError,) +
-            _my_database_errors +
-            _pg_database_errors +
-            _lite_database_errors +
-            _oracle_database_errors
-        )
+        self.DatabaseError = symbol_by_name('django.db:DatabaseError')
 
     def django_setup(self):
         import django
-        try:
-            django_setup = django.setup
-        except AttributeError:  # pragma: no cover
-            pass
-        else:
-            django_setup()
+        django.setup()
 
     def validate_models(self):
         self.django_setup()
-        try:
-            from django.core.management.validation import get_validation_errors
-        except ImportError:
-            self._validate_models_django17()
-        else:
-            s = StringIO()
-            num_errors = get_validation_errors(s, None)
-            if num_errors:
-                raise RuntimeError(
-                    'One or more Django models did not validate:\n{0}'.format(
-                        s.getvalue()))
-
-    def _validate_models_django17(self):
-        from django.core.management import base
-        print(base)
-        cmd = base.BaseCommand()
-        try:
-            cmd.stdout = base.OutputWrapper(sys.stdout)
-            cmd.stderr = base.OutputWrapper(sys.stderr)
-        except ImportError:  # before django 1.5
-            cmd.stdout, cmd.stderr = sys.stdout, sys.stderr
-        cmd.check()
+        from django.core.management.validation import get_validation_errors
+        s = StringIO()
+        num_errors = get_validation_errors(s, None)
+        if num_errors:
+            raise RuntimeError(
+                'One or more Django models did not validate:\n{0}'.format(
+                    s.getvalue()))
 
     def install(self):
         signals.beat_embedded_init.connect(self.close_database)
@@ -223,13 +152,9 @@ class DjangoWorkerFixup(object):
         # the inherited DB conn to also get broken in the parent
         # process so we need to remove it without triggering any
         # network IO that close() might cause.
-        try:
-            for c in self._db.connections.all():
-                if c and c.connection:
-                    self._maybe_close_db_fd(c.connection)
-        except AttributeError:
-            if self._db.connection and self._db.connection.connection:
-                self._maybe_close_db_fd(self._db.connection.connection)
+        for c in self._db.connections.all():
+            if c and c.connection:
+                self._maybe_close_db_fd(c.connection)
 
         # use the _ version to avoid DB_REUSE preventing the conn.close() call
         self._close_database()
@@ -254,8 +179,6 @@ class DjangoWorkerFixup(object):
             self.close_cache()
 
     def close_database(self, **kwargs):
-        if self._close_old_connections:
-            return self._close_old_connections()  # Django 1.6
         if not self.db_reuse_max:
             return self._close_database()
         if self._db_recycles >= self.db_reuse_max * 2:
@@ -264,21 +187,12 @@ class DjangoWorkerFixup(object):
         self._db_recycles += 1
 
     def _close_database(self):
-        try:
-            funs = [conn.close for conn in self._db.connections.all()]
-        except AttributeError:
-            if hasattr(self._db, 'close_old_connections'):  # django 1.6
-                funs = [self._db.close_old_connections]
-            else:
-                # pre multidb, pending deprication in django 1.6
-                funs = [self._db.close_connection]
-
-        for close in funs:
+        for conn in self._db.connections.all():
             try:
-                close()
+                conn.close()
             except self.interface_errors:
                 pass
-            except self.database_errors as exc:
+            except self.DatabaseError as exc:
                 str_exc = str(exc)
                 if 'closed' not in str_exc and 'not connected' not in str_exc:
                     raise

+ 1 - 115
celery/tests/fixups/test_django.py

@@ -53,12 +53,6 @@ class test_DjangoFixup(FixupCase):
         f.on_import_modules()
         f.worker_fixup.validate_models.assert_called_with()
 
-    def test_autodiscover_tasks_pre17(self):
-        self.mask_modules('django.apps')
-        f = DjangoFixup(self.app)
-        f._settings = Mock(name='_settings')
-        self.assertIs(f.autodiscover_tasks(), f._settings.INSTALLED_APPS)
-
     def test_autodiscover_tasks(self):
         self.mock_modules('django.apps')
         from django.apps import apps
@@ -93,13 +87,6 @@ class test_DjangoFixup(FixupCase):
         with self.fixup_context(self.app) as (f, importmod, sym):
             self.assertTrue(f)
 
-            def se(name):
-                if name == 'django.utils.timezone:now':
-                    raise ImportError()
-                return Mock()
-            sym.side_effect = se
-            self.assertTrue(self.Fixup(self.app)._now)
-
     def test_install(self):
         self.app.loader = Mock()
         self.cw = self.patch('os.getcwd')
@@ -135,13 +122,6 @@ class test_DjangoWorkerFixup(FixupCase):
         with self.fixup_context(self.app) as (f, importmod, sym):
             self.assertTrue(f)
 
-            def se(name):
-                if name == 'django.db:close_old_connections':
-                    raise ImportError()
-                return Mock()
-            sym.side_effect = se
-            self.assertIsNone(self.Fixup(self.app)._close_old_connections)
-
     def test_install(self):
         self.app.conf = {'CELERY_DB_REUSE_MAX': None}
         self.app.loader = Mock()
@@ -173,13 +153,6 @@ class test_DjangoWorkerFixup(FixupCase):
                             f.close_cache.assert_called_with()
                             f._close_database.assert_called_with()
 
-                            mcf.reset_mock()
-                            _all.side_effect = AttributeError()
-                            f.on_worker_process_init()
-                            mcf.assert_called_with(f._db.connection.connection)
-                            f._db.connection = None
-                            f.on_worker_process_init()
-
                             f.validate_models = Mock(name='validate_models')
                             self.mock_environ('FORKED_BY_MULTIPROCESSING', '1')
                             f.on_worker_process_init()
@@ -218,10 +191,6 @@ class test_DjangoWorkerFixup(FixupCase):
 
     def test_close_database(self):
         with self.fixup_context(self.app) as (f, _, _):
-            f._close_old_connections = Mock()
-            f.close_database()
-            f._close_old_connections.assert_called_with()
-            f._close_old_connections = None
             with patch.object(f, '_close_database') as _close:
                 f.db_reuse_max = None
                 f.close_database()
@@ -240,18 +209,11 @@ class test_DjangoWorkerFixup(FixupCase):
                 _close.assert_called_with()
                 self.assertEqual(f._db_recycles, 1)
 
-    def test_close_database__django16(self):
-        with self.fixup_context(self.app) as (f, _, _):
-            f._db.connections = Mock(name='db.connections')
-            f._db.connections.all.side_effect = AttributeError()
-            f._close_database()
-            f._db.close_old_connections.assert_called_with()
-
     def test__close_database(self):
         with self.fixup_context(self.app) as (f, _, _):
             conns = [Mock(), Mock(), Mock()]
             conns[1].close.side_effect = KeyError('already closed')
-            f.database_errors = (KeyError,)
+            f.DatabaseError = KeyError
             f.interface_errors = ()
 
             f._db.connections = Mock()  # ConnectionHandler
@@ -266,11 +228,6 @@ class test_DjangoWorkerFixup(FixupCase):
             with self.assertRaises(KeyError):
                 f._close_database()
 
-            o = Bunch(close_connection=Mock())
-            f._db = o
-            f._close_database()
-            o.close_connection.assert_called_with()
-
     def test_close_cache(self):
         with self.fixup_context(self.app) as (f, _, _):
             f.close_cache()
@@ -300,21 +257,6 @@ class test_DjangoWorkerFixup(FixupCase):
         with self.assertRaises(RuntimeError):
             f.validate_models()
 
-        self.mask_modules('django.core.management.validation')
-        f._validate_models_django17 = Mock('validate17')
-        f.validate_models()
-        f._validate_models_django17.assert_called_with()
-
-    def test_validate_models_django17(self):
-        self.patch('celery.fixups.django.symbol_by_name')
-        self.patch('celery.fixups.django.import_module')
-        self.mock_modules('django.core.management.base')
-        from django.core.management import base
-        f = self.Fixup(self.app)
-        f._validate_models_django17()
-        base.BaseCommand.assert_called_with()
-        base.BaseCommand().check.assert_called_with()
-
     def test_django_setup(self):
         self.patch('celery.fixups.django.symbol_by_name')
         self.patch('celery.fixups.django.import_module')
@@ -322,59 +264,3 @@ class test_DjangoWorkerFixup(FixupCase):
         f = self.Fixup(self.app)
         f.django_setup()
         django.setup.assert_called_with()
-
-    def test_mysql_errors(self):
-        with mock.module_exists('MySQLdb'):
-            import MySQLdb as mod
-            mod.DatabaseError = Mock()
-            mod.InterfaceError = Mock()
-            mod.OperationalError = Mock()
-            with self.fixup_context(self.app) as (f, _, _):
-                self.assertIn(mod.DatabaseError, f.database_errors)
-                self.assertIn(mod.InterfaceError, f.database_errors)
-                self.assertIn(mod.OperationalError, f.database_errors)
-        with mock.mask_modules('MySQLdb'):
-            with self.fixup_context(self.app):
-                pass
-
-    def test_pg_errors(self):
-        with mock.module_exists('psycopg2'):
-            import psycopg2 as mod
-            mod.DatabaseError = Mock()
-            mod.InterfaceError = Mock()
-            mod.OperationalError = Mock()
-            with self.fixup_context(self.app) as (f, _, _):
-                self.assertIn(mod.DatabaseError, f.database_errors)
-                self.assertIn(mod.InterfaceError, f.database_errors)
-                self.assertIn(mod.OperationalError, f.database_errors)
-        with mock.mask_modules('psycopg2'):
-            with self.fixup_context(self.app):
-                pass
-
-    def test_sqlite_errors(self):
-        with mock.module_exists('sqlite3'):
-            import sqlite3 as mod
-            mod.DatabaseError = Mock()
-            mod.InterfaceError = Mock()
-            mod.OperationalError = Mock()
-            with self.fixup_context(self.app) as (f, _, _):
-                self.assertIn(mod.DatabaseError, f.database_errors)
-                self.assertIn(mod.InterfaceError, f.database_errors)
-                self.assertIn(mod.OperationalError, f.database_errors)
-        with mock.mask_modules('sqlite3'):
-            with self.fixup_context(self.app):
-                pass
-
-    def test_oracle_errors(self):
-        with mock.module_exists('cx_Oracle'):
-            import cx_Oracle as mod
-            mod.DatabaseError = Mock()
-            mod.InterfaceError = Mock()
-            mod.OperationalError = Mock()
-            with self.fixup_context(self.app) as (f, _, _):
-                self.assertIn(mod.DatabaseError, f.database_errors)
-                self.assertIn(mod.InterfaceError, f.database_errors)
-                self.assertIn(mod.OperationalError, f.database_errors)
-        with mock.mask_modules('cx_Oracle'):
-            with self.fixup_context(self.app):
-                pass