Browse Source

Unittests for celery.events.state

Ask Solem 14 years ago
parent
commit
33a10d2b95
2 changed files with 284 additions and 46 deletions
  1. 108 46
      celery/events/state.py
  2. 176 0
      celery/tests/test_events_state.py

+ 108 - 46
celery/events/state.py

@@ -4,21 +4,27 @@ import heapq
 from carrot.utils import partition
 
 from celery import states
+from celery.datastructures import LocalCache
 
 HEARTBEAT_EXPIRE = 150 # 2 minutes, 30 seconds
 
 
-class Element(object):
+class Element(dict):
     """Base class for types."""
+    # NOTE: ``visited`` lives on the class.  ``__setattr__`` below stores
+    # every assignment as a dict item and ``__getattr__`` is only consulted
+    # when normal attribute lookup fails, so ``obj.visited = True`` writes
+    # ``obj["visited"]`` while reading ``obj.visited`` still finds this
+    # class-level ``False`` first.
     visited = False
 
     def __init__(self, **fields):
-        self.update(fields)
+        # Fields are stored as plain dict items.
+        dict.__init__(self, fields)
 
-    def update(self, fields, **extra):
-        for field_name, field_value in dict(fields, **extra).items():
-            setattr(self, field_name, field_value)
+    def __getattr__(self, key):
+        # Fallback attribute lookup: expose dict items as attributes and
+        # translate a missing key into the usual AttributeError.
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError("'%s' object has no attribute '%s'" % (
+                    self.__class__.__name__, key))
 
+    def __setattr__(self, key, value):
+        # Every attribute assignment becomes a dict item.
+        self[key] = value
 
 
 class Worker(Element):
@@ -28,13 +34,13 @@ class Worker(Element):
         super(Worker, self).__init__(**fields)
         self.heartbeats = []
 
-    def online(self, timestamp=None, **kwargs):
+    def on_online(self, timestamp=None, **kwargs):
+        """Handle a ``worker-online`` event.
+
+        Passes ``timestamp`` to :meth:`_heartpush`; any other event
+        fields are accepted and ignored.
+
+        """
         self._heartpush(timestamp)
 
-    def offline(self, **kwargs):
+    def on_offline(self, **kwargs):
+        """Handle a ``worker-offline`` event.
+
+        Drops all recorded heartbeats, which makes :attr:`alive` falsy.
+
+        """
         self.heartbeats = []
 
-    def heartbeat(self, timestamp=None, **kwargs):
+    def on_heartbeat(self, timestamp=None, **kwargs):
+        """Handle a ``worker-heartbeat`` event.
+
+        Passes ``timestamp`` to :meth:`_heartpush`; any other event
+        fields are accepted and ignored.
+
+        """
         self._heartpush(timestamp)
 
     def _heartpush(self, timestamp):
@@ -44,7 +50,7 @@ class Worker(Element):
     @property
     def alive(self):
+        """``True`` if the newest heartbeat is younger than
+        ``HEARTBEAT_EXPIRE`` seconds, ``False`` otherwise (including when
+        no heartbeat has been recorded)."""
+        # max() rather than [-1]: it yields the most recent heartbeat
+        # whether the list is append-ordered or kept as a heapq heap
+        # (heapq only guarantees that [0] is the smallest item).
+        # bool() keeps an empty list from leaking out as the result.
-        return (self.heartbeats and
-                time.time() < self.heartbeats[0] + HEARTBEAT_EXPIRE)
+        return bool(self.heartbeats and
+                time.time() < max(self.heartbeats) + HEARTBEAT_EXPIRE)
 
 
 class Task(Element):
@@ -52,17 +58,25 @@ class Task(Element):
     _info_fields = ("args", "kwargs", "retries",
                     "result", "eta", "runtime",
                     "exception")
-    uuid = None
-    name = None
-    state = states.PENDING
-    received = False
-    started = False
-    args = None
-    kwargs = None
-    eta = None
-    retries = 0
-    worker = None
-    timestamp = None
+
+    # Default field values.  ``__init__`` merges these with the keyword
+    # arguments it receives, so every new Task starts out with all of
+    # these keys; explicitly passed fields win over the defaults.
+    _defaults = dict(uuid=None,
+                     name=None,
+                     state=states.PENDING,
+                     received=False,
+                     started=False,
+                     succeeded=False,
+                     failed=False,
+                     retried=False,
+                     revoked=False,
+                     args=None,
+                     kwargs=None,
+                     eta=None,
+                     retries=None,
+                     worker=None,
+                     timestamp=None)
+
+    def __init__(self, **fields):
+        """Create a task from ``_defaults`` overridden by ``fields``."""
+        super(Task, self).__init__(**dict(self._defaults, **fields))
 
     def info(self, fields=None, extra=[]):
         if fields is None:
@@ -77,52 +91,58 @@ class Task(Element):
         return self.state in states.READY_STATES
 
     def update(self, d, **extra):
+        """Update task fields from ``d`` and ``extra``.
+
+        If the task already has a worker, that worker first gets a
+        heartbeat stamped with the current time.
+
+        """
-        d = dict(d, **extra)
         if self.worker:
-            self.worker.online()
-        return super(Task, self).update(d)
+            self.worker.on_heartbeat(timestamp=time.time())
+        return super(Task, self).update(d, **extra)
 
-    def received(self, timestamp=None, **fields):
+    def on_received(self, timestamp=None, **fields):
+        """Handle a ``task-received`` event.
+
+        Records ``timestamp`` as the time of receipt, sets the state to
+        ``"RECEIVED"`` and merges the remaining event fields (plus the
+        timestamp) into the task.
+
+        """
         self.received = timestamp
         self.state = "RECEIVED"
         self.update(fields, timestamp=timestamp)
 
-    def started(self, timestamp=None, **fields):
+    def on_started(self, timestamp=None, **fields):
+        """Handle a ``task-started`` event: set state to ``STARTED``,
+        record ``timestamp`` in ``started`` and merge the event fields."""
         self.state = states.STARTED
         self.started = timestamp
-        self.update(fields)
+        self.update(fields, timestamp=timestamp)
 
-    def failed(self, timestamp=None, **fields):
+    def on_failed(self, timestamp=None, **fields):
+        """Handle a ``task-failed`` event: set state to ``FAILURE``,
+        record ``timestamp`` in ``failed`` and merge the event fields."""
         self.state = states.FAILURE
         self.failed = timestamp
         self.update(fields, timestamp=timestamp)
 
-    def retried(self, timestamp=None, **fields):
+    def on_retried(self, timestamp=None, **fields):
+        """Handle a ``task-retried`` event: set state to ``RETRY``,
+        record ``timestamp`` in ``retried`` and merge the event fields."""
         self.state = states.RETRY
         self.retried = timestamp
         self.update(fields, timestamp=timestamp)
 
-    def succeeded(self, timestamp=None, **fields):
+    def on_succeeded(self, timestamp=None, **fields):
+        """Handle a ``task-succeeded`` event: set state to ``SUCCESS``,
+        record ``timestamp`` in ``succeeded`` and merge the event fields."""
         self.state = states.SUCCESS
-        self.suceeded = timestamp
+        self.succeeded = timestamp
         self.update(fields, timestamp=timestamp)
 
-    def revoked(self, timestamp=None):
+    def on_revoked(self, timestamp=None, **fields):
+        """Handle a ``task-revoked`` event: set state to ``REVOKED``,
+        record ``timestamp`` in ``revoked`` and merge the event fields."""
         self.state = states.REVOKED
+        self.revoked = timestamp
+        self.update(fields, timestamp=timestamp)
 
 
 class State(object):
+    """Represents a snapshot of a cluster's state."""
     event_count = 0
     task_count = 0
 
-    def __init__(self, callback=None):
-        self.workers = {}
-        self.tasks = {}
-        self.callback = callback
+    def __init__(self, callback=None,
+            max_workers_in_memory=5000, max_tasks_in_memory=10000):
+        """Set up empty worker/task stores and the event dispatch table.
+
+        :param callback: optional callable stored as ``event_callback``;
+            :meth:`event` calls it as ``callback(state, event)`` after
+            each processed event.
+        :param max_workers_in_memory: argument for the ``LocalCache``
+            holding workers (presumably an entry limit -- confirm
+            against ``celery.datastructures.LocalCache``).
+        :param max_tasks_in_memory: same, for the ``LocalCache`` holding
+            tasks.
+
+        """
+        self.workers = LocalCache(max_workers_in_memory)
+        self.tasks = LocalCache(max_tasks_in_memory)
+        self.event_callback = callback
+        # Maps the group part of an event type ("worker-online" ->
+        # "worker") to the method that processes that group.
         self.group_handlers = {"worker": self.worker_event,
                                "task": self.task_event}
 
-    def get_worker(self, hostname, **kwargs):
+    def get_or_create_worker(self, hostname, **kwargs):
+        """Get or create worker by hostname."""
         try:
             worker = self.workers[hostname]
             worker.update(kwargs)
@@ -131,7 +151,8 @@ class State(object):
                     hostname=hostname, **kwargs)
         return worker
 
-    def get_task(self, uuid, **kwargs):
+    def get_or_create_task(self, uuid, **kwargs):
+        """Get or create task by uuid."""
         try:
             task = self.tasks[uuid]
             task.update(kwargs)
@@ -140,34 +161,75 @@ class State(object):
         return task
 
     def worker_event(self, type, fields):
+        """Process worker event.
+
+        :param type: event type without the ``worker-`` prefix,
+            e.g. ``"online"``; dispatched to ``Worker.on_<type>``.
+        :param fields: event fields; ``hostname`` is popped (and so
+            removed from the caller's dict), the rest is passed to the
+            handler as keyword arguments.
+
+        Event types without a matching ``on_<type>`` handler are ignored.
+
+        """
         hostname = fields.pop("hostname")
-        worker = self.workers[hostname] = Worker(hostname=hostname)
-        handler = getattr(worker, type)
+        worker = self.get_or_create_worker(hostname)
+        # Default of None: without it getattr raises AttributeError for
+        # unknown event types and the guard below could never be false.
+        handler = getattr(worker, "on_%s" % type, None)
         if handler:
             handler(**fields)
 
     def task_event(self, type, fields):
+        """Process task event.
+
+        :param type: event type without the ``task-`` prefix,
+            e.g. ``"received"``; dispatched to ``Task.on_<type>``.
+        :param fields: event fields; ``uuid`` and ``hostname`` are popped
+            (and so removed from the caller's dict), the rest is passed
+            to the handler as keyword arguments.
+
+        Event types without a matching ``on_<type>`` handler are ignored,
+        but the task is still associated with the sending worker.
+
+        """
         uuid = fields.pop("uuid")
         hostname = fields.pop("hostname")
-        worker = self.get_worker(hostname)
-        task = self.get_task(uuid, worker=worker)
-        handler = getattr(task, type)
+        worker = self.get_or_create_worker(hostname)
+        task = self.get_or_create_task(uuid)
+        # Default of None: without it getattr raises AttributeError for
+        # unknown event types and the guard below could never be false.
+        handler = getattr(task, "on_%s" % type, None)
         if type == "received":
             self.task_count += 1
         if handler:
             handler(**fields)
+        task.worker = worker
 
     def event(self, event):
+        """Process event.
+
+        ``event`` is a mapping with at least a ``type`` key of the form
+        ``"<group>-<type>"`` (e.g. ``"task-received"``).  The group picks
+        the handler from ``group_handlers``; afterwards the optional
+        event callback is called with this state and the remaining
+        event fields.
+
+        """
+        # Work on a copy with encoded keys (presumably so the keys can be
+        # used as keyword argument names by the handlers).
         event = dict((key.encode("utf-8"), value)
                         for key, value in event.items())
         self.event_count += 1
         group, _, type = partition(event.pop("type"), "-")
         self.group_handlers[group](type, event)
-        if self.callback:
-            self.callback(self, event)
+        if self.event_callback:
+            self.event_callback(self, event)
 
     def tasks_by_timestamp(self):
-        return sorted(self.tasks.items(), key=lambda t: t[1].timestamp,
-                reverse=True)
+        """Get tasks by timestamp.
+
+        Returns a list of ``(uuid, task)`` tuples, newest timestamp
+        first.
+
+        """
+        return self._sort_tasks_by_time(self.tasks.items())
+
+    def _sort_tasks_by_time(self, tasks):
+        """Sort ``(uuid, task)`` items by task timestamp, newest first."""
+        return sorted(tasks, key=lambda t: t[1].timestamp, reverse=True)
+
+    def tasks_by_type(self, name):
+        """Get all tasks by type.
+
+        Returns a list of ``(uuid, task)`` tuples for the tasks whose
+        ``name`` equals ``name``, newest timestamp first.
+
+        """
+        return self._sort_tasks_by_time([(uuid, task)
+                for uuid, task in self.tasks.items()
+                    if task.name == name])
+
+    def tasks_by_worker(self, hostname):
+        """Get all tasks by worker.
+
+        Returns a list of ``(uuid, task)`` tuples for the tasks whose
+        worker has the given ``hostname``, newest timestamp first.
+        Tasks that have no worker assigned yet are skipped.
+
+        """
+        # ``worker`` defaults to None on a fresh Task, so guard before
+        # dereferencing ``hostname``.
+        return self._sort_tasks_by_time([(uuid, task)
+                for uuid, task in self.tasks.items()
+                    if task.worker is not None
+                        and task.worker.hostname == hostname])
+
+    def task_types(self):
+        """Returns a list of all seen task types.
+
+        Each distinct task ``name`` appears once; the order is not
+        defined.
+
+        """
+        return list(set(task.name for task in self.tasks.values()))
+
+    def alive_workers(self):
+        """Returns a list of (seemingly) alive workers.
+
+        A worker is included when its ``alive`` property is truthy.
+
+        """
+        return [w for w in self.workers.values() if w.alive]
+
 
 state = State()

+ 176 - 0
celery/tests/test_events_state.py

@@ -0,0 +1,176 @@
+import time
+import unittest2 as unittest
+
+from itertools import count
+
+from celery import states
+from celery.events import Event
+from celery.events.state import State, HEARTBEAT_EXPIRE
+from celery.utils import gen_unique_id
+
+
+class replay(object):
+    """Iterator that feeds a fixed list of events into a state object.
+
+    Subclasses provide the ``events`` list; each step of the iteration
+    passes the next event to ``state.event()``.
+
+    """
+
+    def __init__(self, state):
+        self.state = state
+        self.rewind()
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        """Send the next event to the state; stop when none are left."""
+        # NOTE: the try block also covers ``state.event()``, so an
+        # IndexError raised while processing an event ends the iteration
+        # as well.
+        try:
+            self.state.event(self.events[self.position()])
+        except IndexError:
+            raise StopIteration()
+
+    def rewind(self):
+        """Restart from the first event; returns ``self``."""
+        # ``position()`` yields 0, 1, 2, ... on successive calls.
+        self.position = count(0).next
+        return self
+
+    def play(self):
+        """Send all remaining events to the state."""
+        for _ in self:
+            pass
+
+
+class ev_worker_online_offline(replay):
+    """Worker ``utest1`` comes online and then goes offline."""
+    events = [
+        Event("worker-online", hostname="utest1"),
+        Event("worker-offline", hostname="utest1"),
+    ]
+
+
+class ev_worker_heartbeats(replay):
+    """Two heartbeats from ``utest1``: the first one is stamped two
+    expiry periods in the past, the second one has no explicit
+    timestamp."""
+    # The stale timestamp is computed once, when this class body runs.
+    events = [
+        Event("worker-heartbeat", hostname="utest1",
+              timestamp=time.time() - HEARTBEAT_EXPIRE * 2),
+        Event("worker-heartbeat", hostname="utest1"),
+    ]
+
+
+class ev_task_states(replay):
+    """One task on ``utest1`` going through received, started,
+    succeeded, failed, retried and revoked, in that order."""
+    # Generated once at class creation and shared by all events below.
+    uuid = gen_unique_id()
+    events = [
+        Event("task-received", uuid=uuid, name="task1",
+              args="(2, 2)", kwargs="{'foo': 'bar'}",
+              retries=0, eta=None, hostname="utest1"),
+        Event("task-started", uuid=uuid, hostname="utest1"),
+        Event("task-succeeded", uuid=uuid, result="4",
+              runtime=0.1234, hostname="utest1"),
+        Event("task-failed", uuid=uuid, exception="KeyError('foo')",
+              traceback="line 1 at main", hostname="utest1"),
+        Event("task-retried", uuid=uuid, exception="KeyError('bar')",
+              traceback="line 2 at main", hostname="utest1"),
+        Event("task-revoked", uuid=uuid, hostname="utest1"),
+    ]
+
+
+class ev_snapshot(replay):
+    """Three workers coming online followed by 20 received tasks."""
+    events = [
+        Event("worker-online", hostname="utest1"),
+        Event("worker-online", hostname="utest2"),
+        Event("worker-online", hostname="utest3"),
+    ]
+    # Even ``i`` -> worker "utest2" / task "task2", odd ``i`` ->
+    # "utest1" / "task1", i.e. 10 tasks for each pairing.  The loop runs
+    # in the class body, so ``i``, ``worker`` and ``type`` remain behind
+    # as class attributes.
+    for i in range(20):
+        worker = not i % 2 and "utest2" or "utest1"
+        type = not i % 2 and "task2" or "task1"
+        events.append(Event("task-received", name=type,
+                      uuid=gen_unique_id(), hostname=worker))
+
+
+class test_State(unittest.TestCase):
+    """Tests for ``State`` driven by the replay fixtures in this module."""
+
+    def test_worker_online_offline(self):
+        """A worker is alive after going online, not after offline."""
+        r = ev_worker_online_offline(State())
+        r.next()
+        self.assertTrue(r.state.alive_workers())
+        self.assertTrue(r.state.workers["utest1"].alive)
+        r.play()
+        self.assertFalse(r.state.alive_workers())
+        self.assertFalse(r.state.workers["utest1"].alive)
+
+    def test_worker_heartbeat_expire(self):
+        """A stale heartbeat leaves the worker dead, a new one revives it."""
+        r = ev_worker_heartbeats(State())
+        r.next()
+        self.assertFalse(r.state.alive_workers())
+        self.assertFalse(r.state.workers["utest1"].alive)
+        r.play()
+        self.assertTrue(r.state.alive_workers())
+        self.assertTrue(r.state.workers["utest1"].alive)
+
+    def test_task_states(self):
+        """Each task event updates state, its time field and timestamp."""
+        r = ev_task_states(State())
+
+        # RECEIVED
+        r.next()
+        self.assertTrue(r.uuid in r.state.tasks)
+        task = r.state.tasks[r.uuid]
+        self.assertEqual(task.state, "RECEIVED")
+        self.assertTrue(task.received)
+        self.assertEqual(task.timestamp, task.received)
+        self.assertEqual(task.worker.hostname, "utest1")
+
+        # STARTED
+        r.next()
+        self.assertTrue(r.state.workers["utest1"].alive,
+                "any task event adds worker heartbeat")
+        self.assertEqual(task.state, states.STARTED)
+        self.assertTrue(task.started)
+        self.assertEqual(task.timestamp, task.started)
+        self.assertEqual(task.worker.hostname, "utest1")
+
+        # SUCCESS
+        r.next()
+        self.assertEqual(task.state, states.SUCCESS)
+        self.assertTrue(task.succeeded)
+        self.assertEqual(task.timestamp, task.succeeded)
+        self.assertEqual(task.worker.hostname, "utest1")
+        self.assertEqual(task.result, "4")
+        self.assertEqual(task.runtime, 0.1234)
+
+        # FAILURE
+        r.next()
+        self.assertEqual(task.state, states.FAILURE)
+        self.assertTrue(task.failed)
+        self.assertEqual(task.timestamp, task.failed)
+        self.assertEqual(task.worker.hostname, "utest1")
+        self.assertEqual(task.exception, "KeyError('foo')")
+        self.assertEqual(task.traceback, "line 1 at main")
+
+        # RETRY
+        r.next()
+        self.assertEqual(task.state, states.RETRY)
+        self.assertTrue(task.retried)
+        self.assertEqual(task.timestamp, task.retried)
+        self.assertEqual(task.worker.hostname, "utest1")
+        self.assertEqual(task.exception, "KeyError('bar')")
+        self.assertEqual(task.traceback, "line 2 at main")
+
+        # REVOKED
+        r.next()
+        self.assertEqual(task.state, states.REVOKED)
+        self.assertTrue(task.revoked)
+        self.assertEqual(task.timestamp, task.revoked)
+        self.assertEqual(task.worker.hostname, "utest1")
+
+    def test_tasks_by_timestamp(self):
+        """All 20 snapshot tasks are returned."""
+        r = ev_snapshot(State())
+        r.play()
+        self.assertEqual(len(r.state.tasks_by_timestamp()), 20)
+
+    def test_tasks_by_type(self):
+        """The snapshot holds 10 tasks of each of the two types."""
+        r = ev_snapshot(State())
+        r.play()
+        self.assertEqual(len(r.state.tasks_by_type("task1")), 10)
+        self.assertEqual(len(r.state.tasks_by_type("task2")), 10)
+
+    def test_alive_workers(self):
+        """All three snapshot workers count as alive."""
+        r = ev_snapshot(State())
+        r.play()
+        self.assertEqual(len(r.state.alive_workers()), 3)
+
+    def test_tasks_by_worker(self):
+        """The snapshot holds 10 tasks for each of utest1 and utest2."""
+        r = ev_snapshot(State())
+        r.play()
+        self.assertEqual(len(r.state.tasks_by_worker("utest1")), 10)
+        self.assertEqual(len(r.state.tasks_by_worker("utest2")), 10)