Преглед на файлове

Can now specify SQLAlchemy connection string via CELERY_RESULT_DBURI

Ask Solem преди 15 години
родител
ревизия
0096b92828
променени са 5 файла, в които са добавени 48 реда и са изтрити 26 реда
  1. 7 22
      celery/backends/database.py
  2. 1 0
      celery/conf.py
  3. 3 3
      celery/db/models.py
  4. 29 0
      celery/db/session.py
  5. 8 1
      tests/celeryconfig.py

+ 7 - 22
celery/backends/database.py

@@ -1,28 +1,13 @@
 import urllib
 from datetime import datetime
 
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
 
 from celery import conf
-from celery.db.models import ModelBase, Task, TaskSet
+from celery.db.models import Task, TaskSet
+from celery.db.session import ResultSession
 from celery.backends.base import BaseDictBackend
 
-server = '<sql server host>'
-database = '<your db>'
-userid = '<your user>'
-password = '<your password>'
-port = 1433
-raw_cs = "DRIVER={FreeTDS};SERVER=%s;PORT=%d;DATABASE=%s;UID=%s;PWD=%s;CHARSET=UTF8;TDS_VERSION=8.0;TEXTSIZE=10000" % (server, port, database, userid, password)
-#connection_string = "mssql:///?odbc_connect=%s" % urllib.quote_plus(raw_cs)
-#connection_string = 'sqlite:////mnt/winctmp/celery.db'
-connection_string = 'sqlite:///celery.db'
-engine = create_engine(connection_string)
-Session = sessionmaker(bind=engine)
 
-import os
-if os.environ.get("CELERYINIT"):
-    ModelBase.metadata.create_all(engine)
 
 
 class DatabaseBackend(BaseDictBackend):
@@ -30,7 +15,7 @@ class DatabaseBackend(BaseDictBackend):
 
     def _store_result(self, task_id, result, status, traceback=None):
         """Store return value and status of an executed task."""
-        session = Session()
+        session = ResultSession()
         try:
             tasks = session.query(Task).filter(Task.task_id == task_id).all()
             if not tasks:
@@ -50,7 +35,7 @@ class DatabaseBackend(BaseDictBackend):
     def _save_taskset(self, taskset_id, result):
         """Store the result of an executed taskset."""
         taskset = TaskSet(taskset_id, result)
-        session = Session()
+        session = ResultSession()
         try:
             session.add(taskset)
             session.flush()
@@ -61,7 +46,7 @@ class DatabaseBackend(BaseDictBackend):
 
     def _get_task_meta_for(self, task_id):
         """Get task metadata for a task by id."""
-        session = Session()
+        session = ResultSession()
         try:
             task = None
             for task in session.query(Task).filter(Task.task_id == task_id):
@@ -78,7 +63,7 @@ class DatabaseBackend(BaseDictBackend):
 
     def _restore_taskset(self, taskset_id):
         """Get taskset metadata for a taskset by id."""
-        session = Session()
+        session = ResultSession()
         try:
             qs = session.query(TaskSet)
             for taskset in qs.filter(TaskSet.task_id == task_id):
@@ -89,7 +74,7 @@ class DatabaseBackend(BaseDictBackend):
     def cleanup(self):
         """Delete expired metadata."""
         expires = conf.TASK_RESULT_EXPIRES
-        session = Session()
+        session = ResultSession()
         try:
             for task in session.query(Task).filter(
                     Task.date_done < (datetime.now() - expires)):

+ 1 - 0
celery/conf.py

@@ -94,6 +94,7 @@ def _get(name, default=None, compat=None):
 ALWAYS_EAGER = _get("CELERY_ALWAYS_EAGER")
 RESULT_BACKEND = _get("CELERY_RESULT_BACKEND", compat=["CELERY_BACKEND"])
 CELERY_BACKEND = RESULT_BACKEND # FIXME Remove in 1.4
+RESULT_DBURI = _get("CELERY_RESULT_DBURI")
 CELERY_CACHE_BACKEND = _get("CELERY_CACHE_BACKEND")
 TASK_SERIALIZER = _get("CELERY_TASK_SERIALIZER")
 TASK_RESULT_EXPIRES = _get("CELERY_TASK_RESULT_EXPIRES")

+ 3 - 3
celery/db/models.py

@@ -6,11 +6,11 @@ from sqlalchemy.orm import relation
 from sqlalchemy.ext.declarative import declarative_base
 
 from celery import states
+from celery.db.session import ResultModelBase
 
-ModelBase = declarative_base()
 
 
-class Task(ModelBase):
+class Task(ResultModelBase):
     """Task result/status."""
     __tablename__ = "celery_taskmeta"
     __table_args__ = {"sqlite_autoincrement": True}
@@ -44,7 +44,7 @@ class Task(ModelBase):
         return u"<Task: %s successful: %s>" % (self.task_id, self.status)
 
 
-class TaskSet(ModelBase):
+class TaskSet(ResultModelBase):
     """TaskSet result"""
     __tablename__ = "celery_tasksetmeta"
     __table_args__ = {"sqlite_autoincrement": True}

+ 29 - 0
celery/db/session.py

@@ -0,0 +1,29 @@
+import os
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.ext.declarative import declarative_base
+
+from celery import conf
+
+ResultModelBase = declarative_base()
+
+_SETUP = {"results": False}
+
+
+def create_session(dburi, **kwargs):
+    engine = create_engine(dburi, **kwargs)
+    return engine, sessionmaker(bind=engine)
+
+
+def setup_results(engine):
+    if not _SETUP["results"]:
+        ResultModelBase.metadata.create_all(engine)
+        _SETUP["results"] = True
+
+
+def ResultSession(dburi=conf.RESULT_DBURI, **kwargs):
+    engine, session = create_session(dburi, **kwargs)
+    if os.environ.get("CELERYINIT"):
+        setup_results(engine)
+    return session()

+ 8 - 1
tests/celeryconfig.py

@@ -1,3 +1,5 @@
+import atexit
+
 BROKER_HOST = "localhost"
 BROKER_PORT = 5672
 BROKER_USER = "guest"
@@ -5,5 +7,10 @@ BROKER_PASSWORD = "guest"
 BROKER_VHOST = "/"
 
 CELERY_RESULT_BACKEND = "database"
-CELERY_RESULT_DBURI = "sqlite:///:memory:"
+CELERY_RESULT_DBURI = "sqlite:///test.db"
 CELERY_SEND_TASK_ERROR_EMAILS = False
+
+@atexit.register
+def teardown_testdb():
+    import os
+    os.remove("test.db")