Explorar o código

AMQP Backend: Fixed possible DOS attack introduced by previous commit.

There is now a limit to the number of state messages we receive,
so it doesn't enter a busy loop. (by default it fast forwards to a maximum of
100 messages).
Ask Solem %!s(int64=14) %!d(string=hai) anos
pai
achega
be4c7060ea
Modificáronse 2 ficheiros con 12 adicións e 3 borrados
  1. 11 2
      celery/backends/amqp.py
  2. 1 1
      celery/tests/test_backends/test_amqp.py

+ 11 - 2
celery/backends/amqp.py

@@ -4,6 +4,7 @@ import socket
 import time
 
 from datetime import timedelta
+from itertools import count
 
 from kombu.entity import Exchange, Queue
 from kombu.messaging import Consumer, Producer
@@ -14,6 +15,10 @@ from celery.exceptions import TimeoutError
 from celery.utils import timeutils
 
 
+class BacklogLimitExceeded(Exception):
+    """Too much state history to fast-forward."""
+
+
 def repair_uuid(s):
     # Historically the dashes in UUIDS are removed from AMQ entity names,
     # but there is no known reason to.  Hopefully we'll be able to fix
@@ -28,6 +33,8 @@ class AMQPBackend(BaseDictBackend):
     Consumer = Consumer
     Producer = Producer
 
+    BacklogLimitExceeded = BacklogLimitExceeded
+
     _pool = None
     _pool_owner_pid = None
 
@@ -139,17 +146,19 @@ class AMQPBackend(BaseDictBackend):
         else:
             return self.wait_for(task_id, timeout, cache)
 
-    def poll(self, task_id):
+    def poll(self, task_id, backlog_limit=100):
         conn = self.pool.acquire(block=True)
         channel = conn.channel()
         try:
             binding = self._create_binding(task_id)(channel)
             binding.declare()
             latest, acc = None, None
-            while 1:  # fetch the last state
+            for i in count():  # fast-forward
                 latest, acc = acc, binding.get(no_ack=True)
                 if not acc:
                     break
+                if i > backlog_limit:
+                    raise self.BacklogLimitExceeded(task_id)
             if latest:
                 payload = self._cache[task_id] = latest.payload
                 return payload

+ 1 - 1
celery/tests/test_backends/test_amqp.py

@@ -135,7 +135,7 @@ class test_AMQPBackend(unittest.TestCase):
             def declare(self):
                 pass
 
-            def get(self):
+            def get(self, no_ack=False):
                 if self.get_returns[0]:
                     class Object(object):
                         payload = {"status": "STARTED",