123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- from __future__ import absolute_import, unicode_literals
- import os
- import sys
- import warnings
- from kombu.utils import cached_property, symbol_by_name
- from datetime import datetime
- from importlib import import_module
- from celery import signals
- from celery.app import default_app
- from celery.exceptions import FixupWarning
- if sys.version_info[0] < 3 and not hasattr(sys, 'pypy_version_info'):
- from StringIO import StringIO
- else: # pragma: no cover
- from io import StringIO
- __all__ = ['DjangoFixup', 'fixup']
# Warning text used by fixup() when the settings environment variable is
# set but the ``django`` package cannot be imported.
ERR_NOT_INSTALLED = """\
Environment variable DJANGO_SETTINGS_MODULE is defined
but Django is not installed. Will not apply Django fix-ups!
"""
- def _maybe_close_fd(fh):
- try:
- os.close(fh.fileno())
- except (AttributeError, OSError, TypeError):
- # TypeError added for celery#962
- pass
def fixup(app, env='DJANGO_SETTINGS_MODULE'):
    """Install the Django fix-up on *app* when a Django project is detected.

    The fix-up is applied only when the environment variable named by *env*
    is set and the app's ``loader_cls`` does not already mention "django".
    If Django itself cannot be imported a :class:`FixupWarning` is issued
    instead.

    Returns the installed :class:`DjangoFixup`, or ``None`` when nothing
    was installed.
    """
    settings_module = os.environ.get(env)
    if not settings_module:
        return None
    if 'django' in app.loader_cls.lower():
        # A Django-aware loader is already in use; nothing to fix up.
        return None

    try:
        import django  # noqa
    except ImportError:
        warnings.warn(FixupWarning(ERR_NOT_INSTALLED))
        return None
    return DjangoFixup(app).install()
class DjangoFixup(object):
    """Glue between a Celery app and a Django project.

    Replaces the app loader's ``now`` and ``mail_admins`` hooks with
    Django-backed versions and hands worker-only setup to a lazily
    created :class:`DjangoWorkerFixup`.
    """

    def __init__(self, app):
        self.app = app
        if default_app is None:
            self.app.set_default()
        # Created on first access through the ``worker_fixup`` property.
        self._worker_fixup = None

    def install(self):
        """Patch the loader and connect the signal handlers; returns self."""
        # Need to add project directory to path
        project_dir = os.getcwd()
        sys.path.append(project_dir)

        self._settings = symbol_by_name('django.conf:settings')

        loader = self.app.loader
        loader.now = self.now
        loader.mail_admins = self.mail_admins

        signals.import_modules.connect(self.on_import_modules)
        signals.worker_init.connect(self.on_worker_init)
        return self

    @property
    def worker_fixup(self):
        """The :class:`DjangoWorkerFixup` for this app (built on demand)."""
        existing = self._worker_fixup
        if existing is None:
            existing = self._worker_fixup = DjangoWorkerFixup(self.app)
        return existing

    @worker_fixup.setter
    def worker_fixup(self, value):
        self._worker_fixup = value

    def on_import_modules(self, **kwargs):
        """Handler for the ``import_modules`` signal."""
        # call django.setup() before task modules are imported
        self.worker_fixup.validate_models()

    def on_worker_init(self, **kwargs):
        """Handler for the ``worker_init`` signal."""
        self.worker_fixup.install()

    def now(self, utc=False):
        """Return the current time; ``datetime.utcnow()`` when *utc* is true."""
        if utc:
            return datetime.utcnow()
        return self._now()

    def mail_admins(self, subject, body, fail_silently=False, **kwargs):
        """Forward to Django's ``mail_admins``; extra keywords are ignored."""
        send = self._mail_admins
        return send(subject, body, fail_silently=fail_silently)

    def autodiscover_tasks(self):
        """Return the names of the installed Django applications."""
        try:
            from django.apps import apps
        except ImportError:
            # No app registry available: fall back to the settings list.
            return self._settings.INSTALLED_APPS

        names = []
        for app_config in apps.get_app_configs():
            names.append(app_config.name)
        return names

    @cached_property
    def _mail_admins(self):
        return symbol_by_name('django.core.mail:mail_admins')

    @cached_property
    def _now(self):
        try:
            tz_now = symbol_by_name('django.utils.timezone:now')
        except (AttributeError, ImportError):  # pre django-1.4
            return datetime.now
        return tz_now
class DjangoWorkerFixup(object):
    """Worker-side Django fix-ups: model validation and connection hygiene.

    Once installed (see :meth:`install`) the instance listens to Celery's
    worker and task signals and closes Django database connections and the
    cache around every task and in newly started child processes.
    """

    # Number of close_database() calls since the last real close.  Only
    # consulted when CELERY_DB_REUSE_MAX is set and Django does not provide
    # ``close_old_connections``.
    _db_recycles = 0

    def __init__(self, app):
        """Resolve the Django modules and exception types used later on.

        :param app: Celery app; ``app.conf`` supplies ``CELERY_DB_REUSE_MAX``.
        """
        self.app = app
        self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
        self._db = import_module('django.db')
        self._cache = import_module('django.core.cache')
        self._settings = symbol_by_name('django.conf:settings')

        # Exceptions that are always ignored while closing connections.
        # The fallback must set the same attribute that the close helpers
        # read, otherwise they fail with AttributeError.
        try:
            self.interface_errors = (
                symbol_by_name('django.db.utils.InterfaceError'),
            )
        except (ImportError, AttributeError):
            self.interface_errors = ()

        # Database-related exceptions.
        DatabaseError = symbol_by_name('django.db:DatabaseError')
        try:
            import MySQLdb as mysql
            _my_database_errors = (mysql.DatabaseError,
                                   mysql.InterfaceError,
                                   mysql.OperationalError)
        except ImportError:
            _my_database_errors = ()      # noqa
        try:
            import psycopg2 as pg
            _pg_database_errors = (pg.DatabaseError,
                                   pg.InterfaceError,
                                   pg.OperationalError)
        except ImportError:
            _pg_database_errors = ()      # noqa
        try:
            import sqlite3
            _lite_database_errors = (sqlite3.DatabaseError,
                                     sqlite3.InterfaceError,
                                     sqlite3.OperationalError)
        except ImportError:
            _lite_database_errors = ()    # noqa
        try:
            import cx_Oracle as oracle
            _oracle_database_errors = (oracle.DatabaseError,
                                       oracle.InterfaceError,
                                       oracle.OperationalError)
        except ImportError:
            _oracle_database_errors = ()  # noqa

        try:
            self._close_old_connections = symbol_by_name(
                'django.db:close_old_connections',
            )
        except (ImportError, AttributeError):
            self._close_old_connections = None

        # Errors that are tolerated on close only when their message says
        # the connection is already closed (see _close_database).
        self.database_errors = (
            (DatabaseError,) +
            _my_database_errors +
            _pg_database_errors +
            _lite_database_errors +
            _oracle_database_errors
        )

    def django_setup(self):
        """Call ``django.setup()`` when the installed Django provides it."""
        import django
        try:
            django_setup = django.setup
        except AttributeError:  # pragma: no cover
            pass
        else:
            django_setup()

    def validate_models(self):
        """Set up Django and validate the models.

        :raises RuntimeError: if the legacy validation API reports errors.
        """
        self.django_setup()
        try:
            from django.core.management.validation import get_validation_errors
        except ImportError:
            # Legacy validation module is gone: use the checks framework.
            self._validate_models_django17()
        else:
            s = StringIO()
            num_errors = get_validation_errors(s, None)
            if num_errors:
                raise RuntimeError(
                    'One or more Django models did not validate:\n{0}'.format(
                        s.getvalue()))

    def _validate_models_django17(self):
        """Run the checks of a bare ``BaseCommand`` (``cmd.check()``)."""
        from django.core.management import base
        cmd = base.BaseCommand()
        try:
            # A missing OutputWrapper surfaces as AttributeError on the
            # module, so it has to be caught for the fallback to be used.
            cmd.stdout = base.OutputWrapper(sys.stdout)
            cmd.stderr = base.OutputWrapper(sys.stderr)
        except (AttributeError, ImportError):  # before django 1.5
            cmd.stdout, cmd.stderr = sys.stdout, sys.stderr
        cmd.check()

    def install(self):
        """Connect all signal handlers and close inherited resources.

        Returns ``self``.
        """
        signals.beat_embedded_init.connect(self.close_database)
        signals.worker_ready.connect(self.on_worker_ready)
        signals.task_prerun.connect(self.on_task_prerun)
        signals.task_postrun.connect(self.on_task_postrun)
        signals.worker_process_init.connect(self.on_worker_process_init)
        self.close_database()
        self.close_cache()
        return self

    def on_worker_process_init(self, **kwargs):
        """Handler for ``worker_process_init``: reset state in the child."""
        # Child process must validate models again if on Windows,
        # or if they were started using execv.
        if os.environ.get('FORKED_BY_MULTIPROCESSING'):
            self.validate_models()

        # close connections:
        # the parent process may have established these,
        # so need to close them.

        # calling db.close() on some DB connections will cause
        # the inherited DB conn to also get broken in the parent
        # process so we need to remove it without triggering any
        # network IO that close() might cause.
        try:
            for c in self._db.connections.all():
                if c and c.connection:
                    self._maybe_close_db_fd(c.connection)
        except AttributeError:
            if self._db.connection and self._db.connection.connection:
                self._maybe_close_db_fd(self._db.connection.connection)

        # use the _ version to avoid DB_REUSE preventing the conn.close() call
        self._close_database()
        self.close_cache()

    def _maybe_close_db_fd(self, fd):
        """Close the descriptor behind *fd*, ignoring interface errors."""
        try:
            _maybe_close_fd(fd)
        except self.interface_errors:
            pass

    def on_task_prerun(self, sender, **kwargs):
        """Called before every task."""
        if not getattr(sender.request, 'is_eager', False):
            self.close_database()

    def on_task_postrun(self, sender, **kwargs):
        """Called after every task."""
        # See http://groups.google.com/group/django-users/
        # browse_thread/thread/78200863d0c07c6d/
        if not getattr(sender.request, 'is_eager', False):
            self.close_database()
            self.close_cache()

    def close_database(self, **kwargs):
        """Close database connections, honouring ``CELERY_DB_REUSE_MAX``.

        Delegates to Django's ``close_old_connections`` when it exists.
        Otherwise connections are closed on every call, or, when a reuse
        limit is configured, only once the call counter has reached twice
        that limit.
        """
        if self._close_old_connections:
            return self._close_old_connections()  # Django 1.6
        if not self.db_reuse_max:
            return self._close_database()
        if self._db_recycles >= self.db_reuse_max * 2:
            self._db_recycles = 0
            self._close_database()
        self._db_recycles += 1

    def _close_database(self):
        """Close every connection, tolerating already-closed ones.

        Database errors are re-raised unless their message contains
        ``'closed'`` or ``'not connected'``.
        """
        try:
            funs = [conn.close for conn in self._db.connections.all()]
        except AttributeError:
            if hasattr(self._db, 'close_old_connections'):  # django 1.6
                funs = [self._db.close_old_connections]
            else:
                # pre multidb, pending deprecation in django 1.6
                funs = [self._db.close_connection]

        for close in funs:
            try:
                close()
            except self.interface_errors:
                pass
            except self.database_errors as exc:
                str_exc = str(exc)
                if 'closed' not in str_exc and 'not connected' not in str_exc:
                    raise

    def close_cache(self):
        """Close the default Django cache if it supports ``close()``."""
        try:
            self._cache.cache.close()
        except (TypeError, AttributeError):
            pass

    def on_worker_ready(self, **kwargs):
        """Handler for ``worker_ready``: warn when ``settings.DEBUG`` is on."""
        if self._settings.DEBUG:
            warnings.warn('Using settings.DEBUG leads to a memory leak, never '
                          'use this setting in production environments!')
|