Browse Source

Fixes race condition in celery.events.state, where task info would be removed while iterating.

Closes #501.  Thanks to erikcw.
Ask Solem 13 years ago
parent
commit
dae6359b16
4 changed files with 76 additions and 8 deletions
  1. 3 0
      Changelog
  2. 16 2
      celery/datastructures.py
  3. 13 6
      celery/events/state.py
  4. 44 0
      celery/tests/test_utils/test_datastructures.py

+ 3 - 0
Changelog

@@ -184,6 +184,9 @@ News
 
     Contributed by Chris Chamberlin.
 
+* Fixed race condition in celery.events.state (celerymon/celeryev)
+  where task info would be removed while iterating over it (Issue #501).
+
 * The Cache, Cassandra, MongoDB, Redis and Tyrant backends now respects
   the :setting:`CELERY_RESULT_SERIALIZER` setting (Issue #435).
 

+ 16 - 2
celery/datastructures.py

@@ -341,10 +341,10 @@ class LRUCache(UserDict):
         return self.data.keys()
 
     def values(self):
-        return self.data.values()
+        return list(self.itervalues())
 
     def items(self):
-        return self.data.items()
+        return list(self.iteritems())
 
     def __setitem__(self, key, value):
         # remove least recently used key.
@@ -356,6 +356,20 @@ class LRUCache(UserDict):
     def __iter__(self):
         return self.data.iterkeys()
 
+    def iteritems(self):
+        for k in self.data:
+            try:
+                yield (k, self.data[k])
+            except KeyError:
+                pass
+
+    def itervalues(self):
+        for k in self.data:
+            try:
+                yield self.data[k]
+            except KeyError:
+                pass
+
 
 class TokenBucket(object):
     """Token Bucket Algorithm.

+ 13 - 6
celery/events/state.py

@@ -216,8 +216,9 @@ class State(object):
 
     def _clear_tasks(self, ready=True):
         if ready:
-            self.tasks = dict((uuid, task)
-                                for uuid, task in self.tasks.items()
+            self.tasks.clear()
+            self.tasks.update((uuid, task)
+                                for uuid, task in self.itertasks()
                                     if task.state not in states.READY_STATES)
         else:
             self.tasks.clear()
@@ -286,13 +287,19 @@ class State(object):
         if self.event_callback:
             self.event_callback(self, event)
 
+    def itertasks(self, limit=None):
+        for index, row in enumerate(self.tasks.iteritems()):
+            yield row
+            if limit and index >= limit:
+                break
+
     def tasks_by_timestamp(self, limit=None):
         """Get tasks by timestamp.
 
         Returns a list of `(uuid, task)` tuples.
 
         """
-        return self._sort_tasks_by_time(self.tasks.items()[:limit])
+        return self._sort_tasks_by_time(self.itertasks(limit))
 
     def _sort_tasks_by_time(self, tasks):
         """Sort task items by time."""
@@ -306,7 +313,7 @@ class State(object):
 
         """
         return self._sort_tasks_by_time([(uuid, task)
-                for uuid, task in self.tasks.items()[:limit]
+                for uuid, task in self.itertasks(limit)
                     if task.name == name])
 
     def tasks_by_worker(self, hostname, limit=None):
@@ -316,12 +323,12 @@ class State(object):
 
         """
         return self._sort_tasks_by_time([(uuid, task)
-                for uuid, task in self.tasks.items()[:limit]
+                for uuid, task in self.itertasks(limit)
                     if task.worker.hostname == hostname])
 
     def task_types(self):
         """Returns a list of all seen task types."""
-        return list(sorted(set(task.name for task in self.tasks.values())))
+        return list(sorted(set(task.name for task in self.tasks.itervalues())))
 
     def alive_workers(self):
         """Returns a list of (seemingly) alive workers."""

+ 44 - 0
celery/tests/test_utils/test_datastructures.py

@@ -151,6 +151,50 @@ class test_LRUCache(unittest.TestCase):
         x[7] = 7
         self.assertEqual(x.keys(), [3, 6, 7])
 
+    def assertSafeIter(self, method, interval=0.01, size=10000):
+        from threading import Thread, Event
+        from time import sleep
+        x = LRUCache(size)
+        x.update(zip(xrange(size), xrange(size)))
+
+        class Burglar(Thread):
+
+            def __init__(self, cache):
+                self.cache = cache
+                self._is_shutdown = Event()
+                self._is_stopped = Event()
+                Thread.__init__(self)
+
+            def run(self):
+                while not self._is_shutdown.isSet():
+                    try:
+                        self.cache.data.popitem(last=False)
+                    except KeyError:
+                        break
+                self._is_stopped.set()
+
+            def stop(self):
+                self._is_shutdown.set()
+                self._is_stopped.wait()
+                self.join(1e10)
+
+        burglar = Burglar(x)
+        burglar.start()
+        try:
+            for _ in getattr(x, method)():
+                sleep(0.0001)
+        finally:
+            burglar.stop()
+
+    def test_safe_to_remove_while_iteritems(self):
+        self.assertSafeIter("iteritems")
+
+    def test_safe_to_remove_while_iterkeys(self):
+        self.assertSafeIter("iterkeys")
+
+    def test_safe_to_remove_while_itervalues(self):
+        self.assertSafeIter("itervalues")
+
 
 class test_AttributeDict(unittest.TestCase):