Sfoglia il codice sorgente

remote-control rate limits: Forgot to update the buckets. See Issue #98.

Ask Solem 15 anni fa
parent
commit
3049f338eb
3 ha cambiato i file con 62 aggiunte e 20 eliminazioni
  1. 26 7
      celery/tests/test_worker_control.py
  2. 27 10
      celery/worker/buckets.py
  3. 9 3
      celery/worker/control.py

+ 26 - 7
celery/tests/test_worker_control.py

@@ -12,7 +12,10 @@ hostname = socket.gethostname()
 class TestControlPanel(unittest.TestCase):
 
     def setUp(self):
-        self.panel = control.ControlDispatch(hostname=hostname)
+        self.panel = self.create_panel()
+
+    def create_panel(self, **kwargs):
+        return control.ControlDispatch(hostname=hostname, **kwargs)
 
     def test_shutdown(self):
         self.assertRaises(SystemExit, self.panel.execute, "shutdown")
@@ -21,17 +24,33 @@ class TestControlPanel(unittest.TestCase):
         self.panel.execute("dump_tasks")
 
     def test_rate_limit(self):
+
+        class Listener(object):
+
+            class ReadyQueue(object):
+                fresh = False
+
+                def refresh(self):
+                    self.fresh = True
+
+            def __init__(self):
+                self.ready_queue = self.ReadyQueue()
+
+        listener = Listener()
+        panel = self.create_panel(listener=listener)
+
         task = tasks[PingTask.name]
         old_rate_limit = task.rate_limit
         try:
-            self.panel.execute("rate_limit", kwargs=dict(
-                                                task_name=task.name,
-                                                rate_limit="100/m"))
+            panel.execute("rate_limit", kwargs=dict(task_name=task.name,
+                                                    rate_limit="100/m"))
             self.assertEqual(task.rate_limit, "100/m")
-            self.panel.execute("rate_limit", kwargs=dict(
-                                                task_name=task.name,
-                                                rate_limit=0))
+            self.assertTrue(listener.ready_queue.fresh)
+            listener.ready_queue.fresh = False
+            panel.execute("rate_limit", kwargs=dict(task_name=task.name,
+                                                    rate_limit=0))
             self.assertEqual(task.rate_limit, 0)
+            self.assertTrue(listener.ready_queue.fresh)
         finally:
             task.rate_limit = old_rate_limit
 

+ 27 - 10
celery/worker/buckets.py

@@ -134,26 +134,31 @@ class TaskBucket(object):
         """Initialize with buckets for all the task types in the registry."""
         map(self.add_bucket_for_type, self.task_registry.keys())
 
+    def refresh(self):
+        """Refresh rate limits for all task types in the registry."""
+        map(self.update_bucket_for_type, self.task_registry.keys())
+
     def get_bucket_for_type(self, task_name):
         """Get the bucket for a particular task type."""
         if task_name not in self.buckets:
             return self.add_bucket_for_type(task_name)
         return self.buckets[task_name]
 
-    def add_bucket_for_type(self, task_name):
-        """Add a bucket for a task type.
-
-        Will read the tasks rate limit and create a :class:`TokenBucketQueue`
-        if it has one. If the task doesn't have a rate limit a regular Queue
-        will be used.
+    def _get_queue_for_type(self, task_name):
+        bucket = self.buckets[task_name]
+        if isinstance(bucket, TokenBucketQueue):
+            return bucket.queue
+        return bucket
 
-        """
-        if task_name in self.buckets:
-            return
+    def update_bucket_for_type(self, task_name):
         task_type = self.task_registry[task_name]
-        task_queue = task_type.rate_limit_queue_type()
         rate_limit = getattr(task_type, "rate_limit", None)
         rate_limit = parse_ratelimit_string(rate_limit)
+        if task_name in self.buckets:
+            task_queue = self._get_queue_for_type(task_name)
+        else:
+            task_queue = task_type.rate_limit_queue_type()
+
         if rate_limit:
             task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
         else:
@@ -162,6 +167,18 @@ class TaskBucket(object):
         self.buckets[task_name] = task_queue
         return task_queue
 
+    def add_bucket_for_type(self, task_name):
+        """Add a bucket for a task type.
+
+        Will read the tasks rate limit and create a :class:`TokenBucketQueue`
+        if it has one. If the task doesn't have a rate limit a regular Queue
+        will be used.
+
+        """
+        if task_name not in self.buckets:
+            return self.update_bucket_for_type(task_name)
+
+
     def qsize(self):
         """Get the total size of all the queues."""
         return sum(bucket.qsize() for bucket in self.buckets.values())

+ 9 - 3
celery/worker/control.py

@@ -20,9 +20,10 @@ class Control(object):
 
     """
 
-    def __init__(self, logger, hostname=None):
+    def __init__(self, logger, hostname=None, listener=None):
         self.logger = logger
         self.hostname = hostname or socket.gethostname()
+        self.listener = listener
 
     @expose
     def revoke(self, task_id, **kwargs):
@@ -45,6 +46,8 @@ class Control(object):
         except KeyError:
             return
 
+        self.listener.ready_queue.refresh()
+
         if not rate_limit:
             self.logger.warn("Disabled rate limits for tasks of type %s" % (
                                 task_name))
@@ -82,10 +85,13 @@ class ControlDispatch(object):
 
     panel_cls = Control
 
-    def __init__(self, logger=None, hostname=None):
+    def __init__(self, logger=None, hostname=None, listener=None):
         self.logger = logger or log.get_default_logger()
         self.hostname = hostname
-        self.panel = self.panel_cls(self.logger, hostname=self.hostname)
+        self.listener = listener
+        self.panel = self.panel_cls(self.logger,
+                                    hostname=self.hostname,
+                                    listener=self.listener)
 
     def dispatch_from_message(self, message):
         """Dispatch by using message data received by the broker.