Bläddra i källkod

Py3: events.State Worker and Task now hashable. Closes #1416

Ask Solem 11 år sedan
förälder
incheckning
b9d53e05be
3 ändrade filer med 60 tillägg och 9 borttagningar
  1. 27 7
      celery/events/state.py
  2. 30 0
      celery/tests/events/test_state.py
  3. 3 2
      celery/worker/consumer.py

+ 27 - 7
celery/events/state.py

@@ -95,19 +95,38 @@ class _lamportinfo(tuple):
     obj = property(itemgetter(3))
 
 
-class Element(AttributeDict):
-    """Base class for worker state elements."""
+def with_unique_field(attr):
 
+    def _decorate_cls(cls):
 
-class Worker(Element):
+        def __eq__(this, other):
+            if isinstance(other, this.__class__):
+                return getattr(this, attr) == getattr(other, attr)
+            return NotImplemented
+        cls.__eq__ = __eq__
+
+        def __ne__(this, other):
+            return not this.__eq__(other)
+        cls.__ne__ = __ne__
+
+        def __hash__(this):
+            return hash(getattr(this, attr))
+        cls.__hash__ = __hash__
+
+        return cls
+    return _decorate_cls
+
+
+@with_unique_field('hostname')
+class Worker(AttributeDict):
     """Worker State."""
     heartbeat_max = 4
     expire_window = HEARTBEAT_EXPIRE_WINDOW
     pid = None
+    _defaults = {'hostname': None, 'pid': None, 'freq': 60}
 
     def __init__(self, **fields):
-        fields.setdefault('freq', 60)
-        super(Worker, self).__init__(**fields)
+        dict.__init__(self, self._defaults, **fields)
         self.heartbeats = []
 
     def on_online(self, timestamp=None, local_received=None, **kwargs):
@@ -158,7 +177,8 @@ class Worker(Element):
         return '{0.hostname}.{0.pid}'.format(self)
 
 
-class Task(Element):
+@with_unique_field('uuid')
+class Task(AttributeDict):
     """Task State."""
 
     #: How to merge out of order events.
@@ -188,7 +208,7 @@ class Task(Element):
                      clock=0)
 
     def __init__(self, **fields):
-        super(Task, self).__init__(**dict(self._defaults, **fields))
+        dict.__init__(self, self._defaults, **fields)
 
     def update(self, state, timestamp, fields):
         """Update state from new event.

+ 30 - 0
celery/tests/events/test_state.py

@@ -109,6 +109,21 @@ class ev_snapshot(replay):
 
 class test_Worker(Case):
 
+    def test_equality(self):
+        self.assertEqual(Worker(hostname='foo').hostname, 'foo')
+        self.assertEqual(
+            Worker(hostname='foo'), Worker(hostname='foo'),
+        )
+        self.assertNotEqual(
+            Worker(hostname='foo'), Worker(hostname='bar'),
+        )
+        self.assertEqual(
+            hash(Worker(hostname='foo')), hash(Worker(hostname='foo')),
+        )
+        self.assertNotEqual(
+            hash(Worker(hostname='foo')), hash(Worker(hostname='bar')),
+        )
+
     def test_survives_missing_timestamp(self):
         worker = Worker(hostname='foo')
         worker.on_heartbeat(timestamp=None)
@@ -134,6 +149,21 @@ class test_Worker(Case):
 
 class test_Task(Case):
 
+    def test_equality(self):
+        self.assertEqual(Task(uuid='foo').uuid, 'foo')
+        self.assertEqual(
+            Task(uuid='foo'), Task(uuid='foo'),
+        )
+        self.assertNotEqual(
+            Task(uuid='foo'), Task(uuid='bar'),
+        )
+        self.assertEqual(
+            hash(Task(uuid='foo')), hash(Task(uuid='foo')),
+        )
+        self.assertNotEqual(
+            hash(Task(uuid='foo')), hash(Task(uuid='bar')),
+        )
+
     def test_info(self):
         task = Task(uuid='abcdefg',
                     name='tasks.add',

+ 3 - 2
celery/worker/consumer.py

@@ -658,13 +658,14 @@ class Gossip(bootsteps.ConsumerStep):
         )
 
     def periodic(self):
+        workers = self.state.workers
         dirty = set()
-        for worker in values(self.state.workers):
+        for worker in values(workers):
             if not worker.alive:
                 dirty.add(worker)
                 self.on_node_lost(worker)
         for worker in dirty:
-            self.state.workers.pop(worker.hostname, None)
+            workers.pop(worker.hostname, None)
 
     def get_consumers(self, channel):
         self.register_timer()