Explorar o código

celery.events.state: Support events received out of order. Tries to guess order based on state precedence rules

Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
b3d3169e4a
Modificáronse 2 ficheiros con 46 adicións e 36 borrados
  1. 21 11
      celery/events/state.py
  2. 25 25
      celery/tests/test_events_state.py

+ 21 - 11
celery/events/state.py

@@ -56,8 +56,8 @@ class Task(Element):
                     "result", "eta", "runtime", "expires",
                     "exception")
 
-    _merge_rules = {"RECEIVED": ("name", "args", "kwargs",
-                                 "retries", "eta", "expires")}
+    _merge_rules = {states.RECEIVED: ("name", "args", "kwargs",
+                                      "retries", "eta", "expires")}
 
     _defaults = dict(uuid=None,
                      name=None,
@@ -95,14 +95,25 @@ class Task(Element):
     def ready(self):
         return self.state in states.READY_STATES
 
-    def update(self, d, **extra):
+    def update(self, state, timestamp, fields):
         if self.worker:
-            self.worker.on_heartbeat(timestamp=time.time())
-        return super(Task, self).update(d, **extra)
+            self.worker.on_heartbeat(timestamp=timestamp)
+        if state < self.state:
+            self.merge(state, timestamp, fields)
+        else:
+            self.state = state
+            self.timestamp = timestamp
+            super(Task, self).update(fields)
+
+    def merge(self, state, timestamp, fields):
+        keep = self._merge_rules.get(state)
+        if keep is not None:
+            fields = dict((key, fields[key]) for key in keep)
+            super(Task, self).update(fields)
 
     def on_received(self, timestamp=None, **fields):
         self.received = timestamp
-        self.update("RECEIVED", timestamp, fields)
+        self.update(states.RECEIVED, timestamp, fields)
 
     def on_started(self, timestamp=None, **fields):
         self.started = timestamp
@@ -215,14 +226,13 @@ class State(object):
                     hostname=hostname, **kwargs)
         return worker
 
-    def get_or_create_task(self, uuid, **kwargs):
+    def get_or_create_task(self, uuid):
         """Get or create task by uuid."""
         try:
-            task = self.tasks[uuid]
-            task.update(kwargs)
+            return self.tasks[uuid]
         except KeyError:
-            task = self.tasks[uuid] = Task(uuid=uuid, **kwargs)
-        return task
+            task = self.tasks[uuid] = Task(uuid=uuid)
+            return task
 
     def worker_event(self, type, fields):
         """Process worker event."""

+ 25 - 25
celery/tests/test_events_state.py

@@ -55,13 +55,13 @@ class ev_task_states(replay):
               args="(2, 2)", kwargs="{'foo': 'bar'}",
               retries=0, eta=None, hostname="utest1"),
         Event("task-started", uuid=uuid, hostname="utest1"),
-        Event("task-succeeded", uuid=uuid, result="4",
-              runtime=0.1234, hostname="utest1"),
-        Event("task-failed", uuid=uuid, exception="KeyError('foo')",
-              traceback="line 1 at main", hostname="utest1"),
+        Event("task-revoked", uuid=uuid, hostname="utest1"),
         Event("task-retried", uuid=uuid, exception="KeyError('bar')",
               traceback="line 2 at main", hostname="utest1"),
-        Event("task-revoked", uuid=uuid, hostname="utest1"),
+        Event("task-failed", uuid=uuid, exception="KeyError('foo')",
+              traceback="line 1 at main", hostname="utest1"),
+        Event("task-succeeded", uuid=uuid, result="4",
+              runtime=0.1234, hostname="utest1"),
     ]
 
 
@@ -150,7 +150,7 @@ class test_State(unittest.TestCase):
         r.next()
         self.assertTrue(r.uuid in r.state.tasks)
         task = r.state.tasks[r.uuid]
-        self.assertEqual(task.state, "RECEIVED")
+        self.assertEqual(task.state, states.RECEIVED)
         self.assertTrue(task.received)
         self.assertEqual(task.timestamp, task.received)
         self.assertEqual(task.worker.hostname, "utest1")
@@ -164,23 +164,12 @@ class test_State(unittest.TestCase):
         self.assertEqual(task.timestamp, task.started)
         self.assertEqual(task.worker.hostname, "utest1")
 
-        # SUCCESS
-        r.next()
-        self.assertEqual(task.state, states.SUCCESS)
-        self.assertTrue(task.succeeded)
-        self.assertEqual(task.timestamp, task.succeeded)
-        self.assertEqual(task.worker.hostname, "utest1")
-        self.assertEqual(task.result, "4")
-        self.assertEqual(task.runtime, 0.1234)
-
-        # FAILURE
+        # REVOKED
         r.next()
-        self.assertEqual(task.state, states.FAILURE)
-        self.assertTrue(task.failed)
-        self.assertEqual(task.timestamp, task.failed)
+        self.assertEqual(task.state, states.REVOKED)
+        self.assertTrue(task.revoked)
+        self.assertEqual(task.timestamp, task.revoked)
         self.assertEqual(task.worker.hostname, "utest1")
-        self.assertEqual(task.exception, "KeyError('foo')")
-        self.assertEqual(task.traceback, "line 1 at main")
 
         # RETRY
         r.next()
@@ -191,12 +180,23 @@ class test_State(unittest.TestCase):
         self.assertEqual(task.exception, "KeyError('bar')")
         self.assertEqual(task.traceback, "line 2 at main")
 
-        # REVOKED
+        # FAILURE
         r.next()
-        self.assertEqual(task.state, states.REVOKED)
-        self.assertTrue(task.revoked)
-        self.assertEqual(task.timestamp, task.revoked)
+        self.assertEqual(task.state, states.FAILURE)
+        self.assertTrue(task.failed)
+        self.assertEqual(task.timestamp, task.failed)
+        self.assertEqual(task.worker.hostname, "utest1")
+        self.assertEqual(task.exception, "KeyError('foo')")
+        self.assertEqual(task.traceback, "line 1 at main")
+
+        # SUCCESS
+        r.next()
+        self.assertEqual(task.state, states.SUCCESS)
+        self.assertTrue(task.succeeded)
+        self.assertEqual(task.timestamp, task.succeeded)
         self.assertEqual(task.worker.hostname, "utest1")
+        self.assertEqual(task.result, "4")
+        self.assertEqual(task.runtime, 0.1234)
 
     def test_freeze_thaw__buffering(self):
         s = State()