Browse Source

Replaced State.freeze/.thaw with a simple mutex. Fixes race condition.

Ask Solem 14 years ago
parent
commit
2bdc13385c
2 changed files with 32 additions and 63 deletions
  1. 2 2
      celery/events/snapshot.py
  2. 30 61
      celery/events/state.py

+ 2 - 2
celery/events/snapshot.py

@@ -12,6 +12,7 @@ from celery.utils.timeutils import rate
 class Polaroid(object):
     shutter_signal = Signal(providing_args=("state", ))
     cleanup_signal = Signal()
+    clear_after = False
 
     _tref = None
 
@@ -51,10 +52,9 @@ class Polaroid(object):
             self.debug("Shutter: %s" % (self.state, ))
             self.shutter_signal.send(self.state)
             self.on_shutter(self.state)
-            self.state.clear()
 
     def capture(self):
-        return self.state.freeze_while(self.shutter)
+        self.state.freeze_while(self.shutter, clear_after=self.clear_after)
 
     def cancel(self):
         if self._tref:

+ 30 - 61
celery/events/state.py

@@ -2,7 +2,7 @@ import time
 import heapq
 
 from collections import deque
-from threading import RLock
+from threading import Lock
 
 from carrot.utils import partition
 
@@ -147,9 +147,6 @@ class State(object):
     """Records clusters state."""
     event_count = 0
     task_count = 0
-    _buffering = False
-    buffer = deque()
-    frozen = False
 
     def __init__(self, callback=None,
             max_workers_in_memory=5000, max_tasks_in_memory=10000):
@@ -158,55 +155,26 @@ class State(object):
         self.event_callback = callback
         self.group_handlers = {"worker": self.worker_event,
                                "task": self.task_event}
-        self._resource = RLock()
-
-    def freeze(self, buffer=True):
-        """Stop recording the event stream.
-
-        :keyword buffer: If true, any events received while frozen
-           will be buffered, you can use ``thaw(replay=True)`` to apply
-           this buffer. :meth:`thaw` will clear the buffer and resume
-           recording the stream.
-
-        """
-        self._buffering = buffer
-        self.frozen = True
-
-    def _replay(self):
-        while self.buffer:
-            try:
-                event = self.buffer.popleft()
-            except IndexError:
-                pass
-            self._dispatch_event(event)
-
-    def thaw(self, replay=True):
-        """Resume recording of the event stream.
-
-        :keyword replay: Will replay buffered events received while
-          the stream was frozen.
-
-        This will always clear the buffer, deleting any events collected
-        while the stream was frozen.
-
-        """
-        self._buffering = False
-        try:
-            if replay:
-                self._replay()
-            else:
-                self.buffer.clear()
-        finally:
-            self.frozen = False
+        self._mutex = Lock()
 
     def freeze_while(self, fun, *args, **kwargs):
-        self.freeze()
+        clear_after = kwargs.pop("clear_after", False)
+        self._mutex.acquire()
         try:
             return fun(*args, **kwargs)
         finally:
-            self.thaw(replay=True)
+            if clear_after:
+                self._clear()
+            self._mutex.release()
 
     def clear_tasks(self, ready=True):
+        self._mutex.acquire()
+        try:
+            return self._clear_tasks(ready)
+        finally:
+            self._mutex.release()
+
+    def _clear_tasks(self, ready=True):
         if ready:
             self.tasks = dict((uuid, task)
                                 for uuid, task in self.tasks.items()
@@ -214,14 +182,18 @@ class State(object):
         else:
             self.tasks.clear()
 
+    def _clear(self, ready=True):
+        self.workers.clear()
+        self._clear_tasks(ready)
+        self.event_count = 0
+        self.task_count = 0
+
     def clear(self, ready=True):
+        self._mutex.acquire()
         try:
-            self.workers.clear()
-            self.clear_tasks(ready)
-            self.event_count = 0
-            self.task_count = 0
+            return self._clear(ready)
         finally:
-            pass
+            self._mutex.release()
 
     def get_or_create_worker(self, hostname, **kwargs):
         """Get or create worker by hostname."""
@@ -263,6 +235,13 @@ class State(object):
             handler(**fields)
         task.worker = worker
 
+    def event(self, event):
+        self._mutex.acquire()
+        try:
+            return self._dispatch_event(event)
+        finally:
+            self._mutex.release()
+
     def _dispatch_event(self, event):
         self.event_count += 1
         event = kwdict(event)
@@ -271,16 +250,6 @@ class State(object):
         if self.event_callback:
             self.event_callback(self, event)
 
-    def event(self, event):
-        """Process event."""
-        try:
-            if not self.frozen:
-                self._dispatch_event(event)
-            elif self._buffering:
-                self.buffer.append(event)
-        finally:
-            pass
-
     def tasks_by_timestamp(self, limit=None):
         """Get tasks by timestamp.