Browse Source

Replaced SharedCounter with good old fashioned mutual exclusion

Ask Solem 14 years ago
parent
commit
25a59556f0

+ 2 - 2
celery/contrib/batches.py

@@ -106,7 +106,7 @@ class Batches(Task):
         return self.apply_buffer(requests, ([SimpleRequest.from_request(r)
                                                 for r in requests], ))
 
-    def execute(self, request, pool, loglevel, logfile, consumer):
+    def execute(self, request, pool, loglevel, logfile):
         if not self._pool:         # just take pool from first task.
             self._pool = pool
 
@@ -126,7 +126,7 @@ class Batches(Task):
         if self._buffer.qsize():
             requests = list(consume_queue(self._buffer))
             if requests:
-                self.debug("Buffer complete: %s" % (len(requests, )))
+                self.debug("Buffer complete: %s" % (len(requests), ))
                 self.flush(requests)
         if not requests:
             self.debug("Cancelling timer: Nothing in buffer.")

+ 0 - 61
celery/datastructures.py

@@ -190,67 +190,6 @@ def consume_queue(queue):
             break
 
 
-class SharedCounter(object):
-    """Thread-safe counter.
-
-    Please note that the final value is not synchronized, this means
-    that you should not update the value by using a previous value, the only
-    reliable operations are increment and decrement.
-
-    Example::
-
-        >>> max_clients = SharedCounter(initial_value=10)
-
-        # Thread one
-        >>> max_clients += 1 # OK (safe)
-
-        # Thread two
-        >>> max_clients -= 3 # OK (safe)
-
-        # Main thread
-        >>> if client >= int(max_clients): # Max clients now at 8
-        ...    wait()
-
-        >>> max_client = max_clients + 10 # NOT OK (unsafe)
-
-    """
-
-    def __init__(self, initial_value):
-        self._value = initial_value
-        self._modify_queue = Queue()
-
-    def increment(self, n=1):
-        """Increment value."""
-        self += n
-        return int(self)
-
-    def decrement(self, n=1):
-        """Decrement value."""
-        self -= n
-        return int(self)
-
-    def _update_value(self):
-        self._value += sum(consume_queue(self._modify_queue))
-        return self._value
-
-    def __iadd__(self, y):
-        """`self += y`"""
-        self._modify_queue.put(y * +1)
-        return self
-
-    def __isub__(self, y):
-        """`self -= y`"""
-        self._modify_queue.put(y * -1)
-        return self
-
-    def __int__(self):
-        """`int(self) -> int`"""
-        return self._update_value()
-
-    def __repr__(self):
-        return repr(int(self))
-
-
 class LimitedSet(object):
     """Kind-of Set with limitations.
 

+ 1 - 1
celery/tests/__init__.py

@@ -56,7 +56,7 @@ def find_distribution_modules(name=__name__, file=__file__):
 
 
 def import_all_modules(name=__name__, file=__file__,
-        skip=["celery.decorators"]):
+        skip=["celery.decorators", "celery.contrib.batches"]):
     for module in find_distribution_modules(name, file):
         if module not in skip:
             try:

+ 1 - 34
celery/tests/test_datastructures.py

@@ -3,7 +3,7 @@ from celery.tests.utils import unittest
 from Queue import Queue
 
 from celery.datastructures import ExceptionInfo, LocalCache
-from celery.datastructures import LimitedSet, SharedCounter, consume_queue
+from celery.datastructures import LimitedSet, consume_queue
 from celery.datastructures import AttributeDict, DictAttribute
 from celery.datastructures import ConfigurationView
 
@@ -103,39 +103,6 @@ class test_utilities(unittest.TestCase):
         self.assertRaises(StopIteration, it.next)
 
 
-class test_SharedCounter(unittest.TestCase):
-
-    def test_initial_value(self):
-        self.assertEqual(int(SharedCounter(10)), 10)
-
-    def test_increment(self):
-        c = SharedCounter(10)
-        c.increment()
-        self.assertEqual(int(c), 11)
-        c.increment(2)
-        self.assertEqual(int(c), 13)
-
-    def test_decrement(self):
-        c = SharedCounter(10)
-        c.decrement()
-        self.assertEqual(int(c), 9)
-        c.decrement(2)
-        self.assertEqual(int(c), 7)
-
-    def test_iadd(self):
-        c = SharedCounter(10)
-        c += 10
-        self.assertEqual(int(c), 20)
-
-    def test_isub(self):
-        c = SharedCounter(10)
-        c -= 20
-        self.assertEqual(int(c), -10)
-
-    def test_repr(self):
-        self.assertIn("10", repr(SharedCounter(10)))
-
-
 class test_LimitedSet(unittest.TestCase):
 
     def test_add(self):

+ 68 - 14
celery/tests/test_worker.py

@@ -170,34 +170,88 @@ def create_message(backend, **data):
                    content_type="application/x-python-serialize",
                    content_encoding="binary")
 
-
 class test_QoS(unittest.TestCase):
 
+    class _QoS(QoS):
+        def __init__(self, value):
+            self.value = value
+            QoS.__init__(self, None, value, None)
+
+        def set(self, value):
+            return value
+
+    def test_qos_increment_decrement(self):
+        qos = self._QoS(10)
+        self.assertEqual(qos.increment(), 11)
+        self.assertEqual(qos.increment(3), 14)
+        self.assertEqual(qos.increment(-30), 14)
+        self.assertEqual(qos.decrement(7), 7)
+        self.assertEqual(qos.decrement(), 6)
+        self.assertRaises(AssertionError, qos.decrement, 10)
+
+    def test_qos_disabled_increment_decrement(self):
+        qos = self._QoS(0)
+        self.assertEqual(qos.increment(), 0)
+        self.assertEqual(qos.increment(3), 0)
+        self.assertEqual(qos.increment(-30), 0)
+        self.assertEqual(qos.decrement(7), 0)
+        self.assertEqual(qos.decrement(), 0)
+        self.assertEqual(qos.decrement(10), 0)
+
+    def test_qos_thread_safe(self):
+        qos = self._QoS(10)
+
+        def add():
+            for i in xrange(1000):
+                qos.increment()
+
+        def sub():
+            for i in xrange(1000):
+                qos.decrement_eventually()
+
+        def threaded(funs):
+            from threading import Thread
+            threads = [Thread(target=fun) for fun in funs]
+            for thread in threads:
+                thread.start()
+            for thread in threads:
+                thread.join()
+
+        threaded([add, add])
+        self.assertEqual(qos.value, 2010)
+
+        qos.value = 1000
+        threaded([add, sub]) # n = 2
+        self.assertEqual(qos.value, 1000)
+
+        threaded([sub, add, add, sub]) # n = 4
+        self.assertEqual(qos.value, 1000)
+
     class MockConsumer(object):
         prefetch_count = 0
 
         def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
             self.prefetch_count = prefetch_count
 
-    def test_increment_decrement(self):
+    def test_consumer_increment_decrement(self):
         consumer = self.MockConsumer()
         qos = QoS(consumer, 10, app_or_default().log.get_default_logger())
         qos.update()
-        self.assertEqual(int(qos.value), 10)
+        self.assertEqual(qos.value, 10)
         self.assertEqual(consumer.prefetch_count, 10)
         qos.decrement()
-        self.assertEqual(int(qos.value), 9)
+        self.assertEqual(qos.value, 9)
         self.assertEqual(consumer.prefetch_count, 9)
         qos.decrement_eventually()
-        self.assertEqual(int(qos.value), 8)
+        self.assertEqual(qos.value, 8)
         self.assertEqual(consumer.prefetch_count, 9)
 
         # Does not decrement 0 value
-        qos.value._value = 0
+        qos.value = 0
         qos.decrement()
-        self.assertEqual(int(qos.value), 0)
+        self.assertEqual(qos.value, 0)
         qos.increment()
-        self.assertEqual(int(qos.value), 0)
+        self.assertEqual(qos.value, 0)
 
 
 class test_Consumer(unittest.TestCase):
@@ -435,10 +489,10 @@ class test_Consumer(unittest.TestCase):
         l.qos = QoS(None, 10, l.logger)
 
         task = object()
-        qos = l.qos.next
+        qos = l.qos.value
         l.apply_eta_task(task)
         self.assertIn(task, state.reserved_requests)
-        self.assertEqual(l.qos.next, qos - 1)
+        self.assertEqual(l.qos.value, qos - 1)
         self.assertIs(self.ready_queue.get_nowait(), task)
 
     def test_receieve_message_eta_isoformat(self):
@@ -531,10 +585,10 @@ class test_Consumer(unittest.TestCase):
 
         class _QoS(object):
             prev = 3
-            next = 4
+            value = 4
 
             def update(self):
-                self.prev = self.next
+                self.prev = self.value
 
         class _Consumer(MyKombuConsumer):
             iterations = 0
@@ -559,7 +613,7 @@ class test_Consumer(unittest.TestCase):
 
         def raises_KeyError(limit=None):
             l.iterations += 1
-            if l.qos.prev != l.qos.next:
+            if l.qos.prev != l.qos.value:
                 l.qos.update()
             if l.iterations >= 2:
                 raise KeyError("foo")
@@ -568,7 +622,7 @@ class test_Consumer(unittest.TestCase):
         self.assertRaises(KeyError, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
-        self.assertEqual(l.qos.prev, l.qos.next)
+        self.assertEqual(l.qos.prev, l.qos.value)
 
         l = _Consumer(self.ready_queue, self.eta_schedule, self.logger,
                       send_events=False, init_callback=init_callback)

+ 1 - 2
celery/worker/__init__.py

@@ -262,8 +262,7 @@ class WorkController(object):
         """Process task by sending it to the pool of workers."""
         try:
             request.task.execute(request, self.pool,
-                                 self.loglevel, self.logfile,
-                                 consumer=self.consumer)
+                                 self.loglevel, self.logfile)
         except SystemTerminate:
             self.terminate()
             raise SystemExit()

+ 41 - 18
celery/worker/consumer.py

@@ -72,11 +72,12 @@ from __future__ import generators
 
 import socket
 import sys
+import threading
 import traceback
 import warnings
 
 from celery.app import app_or_default
-from celery.datastructures import AttributeDict, SharedCounter
+from celery.datastructures import AttributeDict
 from celery.exceptions import NotRegistered
 from celery.utils import noop
 from celery.utils.timer2 import to_timestamp
@@ -104,17 +105,34 @@ class QoS(object):
     def __init__(self, consumer, initial_value, logger):
         self.consumer = consumer
         self.logger = logger
-        self.value = SharedCounter(initial_value)
+        self._mutex = threading.RLock()
+        self.value = initial_value
 
     def increment(self, n=1):
         """Increment the current prefetch count value by one."""
-        if int(self.value):
-            return self.set(self.value.increment(n))
+        self._mutex.acquire()
+        try:
+            if self.value:
+                self.value += max(n, 0)
+                self.set(self.value)
+            return self.value
+        finally:
+            self._mutex.release()
+
+    def _sub(self, n=1):
+        assert self.value -n > 1
+        self.value -= n
 
     def decrement(self, n=1):
         """Decrement the current prefetch count value by one."""
-        if int(self.value):
-            return self.set(self.value.decrement(n))
+        self._mutex.acquire()
+        try:
+            if self.value:
+                self._sub(n)
+                self.set(self.value)
+            return self.value
+        finally:
+            self._mutex.release()
 
     def decrement_eventually(self, n=1):
         """Decrement the value, but do not update the qos.
@@ -123,23 +141,28 @@ class QoS(object):
         when necessary.
 
         """
-        if int(self.value):
-            self.value.decrement(n)
+        self._mutex.acquire()
+        try:
+            if self.value:
+                self._sub(n)
+        finally:
+            self._mutex.release()
 
     def set(self, pcount):
         """Set channel prefetch_count setting."""
-        self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
-        self.consumer.qos(prefetch_count=pcount)
-        self.prev = pcount
+        if pcount != self.prev:
+            self.logger.debug("basic.qos: prefetch_count->%s" % pcount)
+            self.consumer.qos(prefetch_count=pcount)
+            self.prev = pcount
         return pcount
 
     def update(self):
         """Update prefetch count with current value."""
-        return self.set(self.next)
-
-    @property
-    def next(self):
-        return int(self.value)
+        self._mutex.acquire()
+        try:
+            return self.set(self.value)
+        finally:
+            self._mutex.release()
 
 
 class Consumer(object):
@@ -253,7 +276,7 @@ class Consumer(object):
         while 1:
             if not self.connection:
                 break
-            if self.qos.prev != self.qos.next:
+            if self.qos.prev != self.qos.value:
                 self.qos.update()
             self.connection.drain_events()
 
@@ -481,4 +504,4 @@ class Consumer(object):
             conninfo = self.connection.info()
             conninfo.pop("password", None)  # don't send password.
         return {"broker": conninfo,
-                "prefetch_count": self.qos.next}
+                "prefetch_count": self.qos.value}