Ask Solem 14 лет назад
Родитель
Сommit
fe22b23383
5 измененных файлов с 24 добавлено и 4 удалено
  1. 2 2
      celery/db/models.py
  2. 4 0
      celery/result.py
  3. 2 1
      celery/task/base.py
  4. 8 0
      celery/tests/test_task/test_result.py
  5. 8 1
      celery/worker/job.py

+ 2 - 2
celery/db/models.py

@@ -19,7 +19,7 @@ class Task(ResultModelBase):
     id = sa.Column(sa.Integer, sa.Sequence("task_id_sequence"),
     id = sa.Column(sa.Integer, sa.Sequence("task_id_sequence"),
                    primary_key=True,
                    primary_key=True,
                    autoincrement=True)
                    autoincrement=True)
-    task_id = sa.Column(sa.String(255))
+    task_id = sa.Column(sa.String(255), unique=True)
     status = sa.Column(sa.String(50), default=states.PENDING)
     status = sa.Column(sa.String(50), default=states.PENDING)
     result = sa.Column(PickleType, nullable=True)
     result = sa.Column(PickleType, nullable=True)
     date_done = sa.Column(sa.DateTime, default=datetime.now,
     date_done = sa.Column(sa.DateTime, default=datetime.now,
@@ -46,7 +46,7 @@ class TaskSet(ResultModelBase):
 
 
     id = sa.Column(sa.Integer, sa.Sequence("taskset_id_sequence"),
     id = sa.Column(sa.Integer, sa.Sequence("taskset_id_sequence"),
                 autoincrement=True, primary_key=True)
                 autoincrement=True, primary_key=True)
-    taskset_id = sa.Column(sa.String(255))
+    taskset_id = sa.Column(sa.String(255), unique=True)
     result = sa.Column(sa.PickleType, nullable=True)
     result = sa.Column(sa.PickleType, nullable=True)
     date_done = sa.Column(sa.DateTime, default=datetime.now,
     date_done = sa.Column(sa.DateTime, default=datetime.now,
                        nullable=True)
                        nullable=True)

+ 4 - 0
celery/result.py

@@ -446,6 +446,10 @@ class EagerResult(BaseAsyncResult):
         return (self.__class__, (self.task_id, self._result,
         return (self.__class__, (self.task_id, self._result,
                                  self._state, self._traceback))
                                  self._state, self._traceback))
 
 
+    def __copy__(self):
+        cls, args = self.__reduce__()
+        return cls(*args)
+
     def successful(self):
     def successful(self):
         """Returns :const:`True` if the task executed without failure."""
         """Returns :const:`True` if the task executed without failure."""
         return self.state == states.SUCCESS
         return self.state == states.SUCCESS

+ 2 - 1
celery/task/base.py

@@ -34,7 +34,8 @@ _default_context = {"logfile": None,
                     "kwargs": None,
                     "kwargs": None,
                     "retries": 0,
                     "retries": 0,
                     "is_eager": False,
                     "is_eager": False,
-                    "delivery_info": None}
+                    "delivery_info": None,
+                    "taskset": None}
 
 
 
 
 class Context(threading.local):
 class Context(threading.local):

+ 8 - 0
celery/tests/test_task/test_result.py

@@ -244,6 +244,14 @@ class TestTaskSetResult(unittest.TestCase):
         self.assertEqual(it.next(), 42)
         self.assertEqual(it.next(), 42)
         self.assertEqual(it.next(), 42)
         self.assertEqual(it.next(), 42)
 
 
+    def test_iterate_eager(self):
+        ar1 = EagerResult(gen_unique_id(), 42, states.SUCCESS)
+        ar2 = EagerResult(gen_unique_id(), 42, states.SUCCESS)
+        ts = TaskSetResult(gen_unique_id(), [ar1, ar2])
+        it = iter(ts)
+        self.assertEqual(it.next(), 42)
+        self.assertEqual(it.next(), 42)
+
     def test_join_timeout(self):
     def test_join_timeout(self):
         ar = MockAsyncResultSuccess(gen_unique_id())
         ar = MockAsyncResultSuccess(gen_unique_id())
         ar2 = MockAsyncResultSuccess(gen_unique_id())
         ar2 = MockAsyncResultSuccess(gen_unique_id())

+ 8 - 1
celery/worker/job.py

@@ -178,6 +178,9 @@ class TaskRequest(object):
     #: UUID of the task.
     #: UUID of the task.
     task_id = None
     task_id = None
 
 
+    #: UUID of the taskset that this task belongs to.
+    taskset_id = None
+
     #: List of positional arguments to apply to the task.
     #: List of positional arguments to apply to the task.
     args = None
     args = None
 
 
@@ -242,10 +245,12 @@ class TaskRequest(object):
     def __init__(self, task_name, task_id, args, kwargs,
     def __init__(self, task_name, task_id, args, kwargs,
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
             on_ack=noop, retries=0, delivery_info=None, hostname=None,
             email_subject=None, email_body=None, logger=None,
             email_subject=None, email_body=None, logger=None,
-            eventer=None, eta=None, expires=None, app=None, **opts):
+            eventer=None, eta=None, expires=None, app=None,
+            taskset_id=None, **opts):
         self.app = app_or_default(app)
         self.app = app_or_default(app)
         self.task_name = task_name
         self.task_name = task_name
         self.task_id = task_id
         self.task_id = task_id
+        self.taskset_id = taskset_id
         self.retries = retries
         self.retries = retries
         self.args = args
         self.args = args
         self.kwargs = kwargs
         self.kwargs = kwargs
@@ -282,6 +287,7 @@ class TaskRequest(object):
 
 
         return cls(task_name=body["task"],
         return cls(task_name=body["task"],
                    task_id=body["id"],
                    task_id=body["id"],
+                   taskset_id=body.get("taskset", None),
                    args=body["args"],
                    args=body["args"],
                    kwargs=kwdict(kwargs),
                    kwargs=kwdict(kwargs),
                    retries=body.get("retries", 0),
                    retries=body.get("retries", 0),
@@ -295,6 +301,7 @@ class TaskRequest(object):
         return {"logfile": logfile,
         return {"logfile": logfile,
                 "loglevel": loglevel,
                 "loglevel": loglevel,
                 "id": self.task_id,
                 "id": self.task_id,
+                "taskset": self.taskset_id,
                 "retries": self.retries,
                 "retries": self.retries,
                 "is_eager": False,
                 "is_eager": False,
                 "delivery_info": self.delivery_info}
                 "delivery_info": self.delivery_info}