Ver código fonte

small refactor of sqlalchemy code

ergo 15 anos atrás
pai
commit
0ed216be06
2 arquivos alterados com 22 adições e 25 exclusões
  1. 9 12
      celery/backends/database.py
  2. 13 13
      celery/db/models.py

+ 9 - 12
celery/backends/database.py

@@ -31,13 +31,11 @@ class DatabaseBackend(BaseDictBackend):
         """Store return value and status of an executed task."""
         session = self.ResultSession()
         try:
-            tasks = session.query(Task).filter(Task.task_id == task_id).all()
-            if not tasks:
+            task = session.query(Task).filter(Task.task_id == task_id).first()
+            if not task:
                 task = Task(task_id)
                 session.add(task)
                 session.flush()
-            else:
-                task = tasks[0]
             task.result = result
             task.status = status
             task.traceback = traceback
@@ -62,9 +60,7 @@ class DatabaseBackend(BaseDictBackend):
         """Get task metadata for a task by id."""
         session = self.ResultSession()
         try:
-            task = None
-            for task in session.query(Task).filter(Task.task_id == task_id):
-                break
+            task = session.query(Task).filter(Task.task_id == task_id).first()
             if not task:
                 task = Task(task_id)
                 session.add(task)
@@ -80,7 +76,8 @@ class DatabaseBackend(BaseDictBackend):
         session = self.ResultSession()
         try:
             qs = session.query(TaskSet)
-            for taskset in qs.filter(TaskSet.taskset_id == taskset_id):
+            taskset = qs.filter(TaskSet.taskset_id == taskset_id).first()
+            if taskset:
                 return taskset.to_dict()
         finally:
             session.close()
@@ -90,12 +87,12 @@ class DatabaseBackend(BaseDictBackend):
         session = self.ResultSession()
         expires = self.result_expires
         try:
-            for task in session.query(Task).filter(
+            qs = session.query(Task).filter(
                     Task.date_done < (datetime.now() - expires)):
-                session.delete(task)
-            for taskset in session.query(TaskSet).filter(
+            qs.delete()
+            qs = session.query(TaskSet).filter(
                     TaskSet.date_done < (datetime.now() - expires)):
-                session.delete(taskset)
+            qs = session.delete()
             session.commit()
         finally:
             session.close()

+ 13 - 13
celery/db/models.py

@@ -1,7 +1,6 @@
 from datetime import datetime
 
-from sqlalchemy import Column, Sequence
-from sqlalchemy import Integer, String, Text, DateTime
+import sqlalchemy as sa
 
 from celery import states
 from celery.db.session import ResultModelBase
@@ -14,14 +13,15 @@ class Task(ResultModelBase):
     __tablename__ = "celery_taskmeta"
     __table_args__ = {"sqlite_autoincrement": True}
 
-    id = Column("id", Integer, Sequence("task_id_sequence"), primary_key=True,
-            autoincrement=True)
-    task_id = Column("task_id", String(255))
-    status = Column("status", String(50), default=states.PENDING)
-    result = Column("result", PickleType, nullable=True)
-    date_done = Column("date_done", DateTime, default=datetime.now,
+    id = sa.Column(sa.Integer, sa.Sequence("task_id_sequence"),
+                   primary_key=True,
+                   autoincrement=True)
+    task_id = sa.Column(sa.String(255))
+    status = sa.Column(sa.String(50), default=states.PENDING)
+    result = sa.Column(PickleType, nullable=True)
+    date_done = sa.Column(sa.DateTime, default=datetime.now,
                        onupdate=datetime.now, nullable=True)
-    traceback = Column("traceback", Text, nullable=True)
+    traceback = sa.Column(sa.Text, nullable=True)
 
     def __init__(self, task_id):
         self.task_id = task_id
@@ -48,11 +48,11 @@ class TaskSet(ResultModelBase):
     __tablename__ = "celery_tasksetmeta"
     __table_args__ = {"sqlite_autoincrement": True}
 
-    id = Column("id", Integer, Sequence("taskset_id_sequence"),
+    id = sa.Column(sa.Integer, sa.Sequence("taskset_id_sequence"),
                 autoincrement=True, primary_key=True)
-    taskset_id = Column("taskset_id", String(255))
-    result = Column("result", PickleType, nullable=True)
-    date_done = Column("date_done", DateTime, default=datetime.now,
+    taskset_id = sa.Column(sa.String(255))
+    result = sa.Column(sa.PickleType, nullable=True)
+    date_done = sa.Column(sa.DateTime, default=datetime.now,
                        nullable=True)
 
     def __init__(self, task_id):