Browse Source

Add new function to handle etas and limits together (#4251)

* Add new function to handle etas and limits together

* Adding unit test

* Fixing indentation
arpanshah29 7 years ago
parent
commit
ca962fa72f
3 changed files with 42 additions and 13 deletions
  1. 6 0
      celery/worker/consumer/consumer.py
  2. 21 13
      celery/worker/strategy.py
  3. 15 0
      t/unit/worker/test_strategy.py

+ 6 - 0
celery/worker/consumer/consumer.py

@@ -305,6 +305,12 @@ class Consumer(object):
             return bucket.add(request)
         return self._schedule_bucket_request(request, bucket, tokens)
 
+    def _limit_post_eta(self, request, bucket, tokens):
+        self.qos.decrement_eventually()
+        if bucket.contents:
+            return bucket.add(request)
+        return self._schedule_bucket_request(request, bucket, tokens)
+
     def start(self):
         blueprint = self.blueprint
         while blueprint.state not in STOP_CONDITIONS:

+ 21 - 13
celery/worker/strategy.py

@@ -84,6 +84,7 @@ def default(task, app, consumer,
     get_bucket = consumer.task_buckets.__getitem__
     handle = consumer.on_task_request
     limit_task = consumer._limit_task
+    limit_post_eta = consumer._limit_post_eta
     body_can_be_buffer = consumer.pool.body_can_be_buffer
     Request = symbol_by_name(task.Request)
     Req = create_request_cls(Request, task, consumer.pool, hostname, eventer)
@@ -123,6 +124,8 @@ def default(task, app, consumer,
                 expires=req.expires and req.expires.isoformat(),
             )
 
+        bucket = None
+        eta = None
         if req.eta:
             try:
                 if req.utc:
@@ -133,17 +136,22 @@ def default(task, app, consumer,
                 error("Couldn't convert ETA %r to timestamp: %r. Task: %r",
                       req.eta, exc, req.info(safe=True), exc_info=True)
                 req.reject(requeue=False)
-            else:
-                consumer.qos.increment_eventually()
-                call_at(eta, apply_eta_task, (req,), priority=6)
-        else:
-            if rate_limits_enabled:
-                bucket = get_bucket(task.name)
-                if bucket:
-                    return limit_task(req, bucket, 1)
-            task_reserved(req)
-            if callbacks:
-                [callback(req) for callback in callbacks]
-            handle(req)
-
+        if rate_limits_enabled:
+            bucket = get_bucket(task.name)
+
+        if eta and bucket:
+            consumer.qos.increment_eventually()
+            return call_at(eta, limit_post_eta, (req, bucket, 1),
+                           priority=6)
+        if eta:
+            consumer.qos.increment_eventually()
+            call_at(eta, apply_eta_task, (req,), priority=6)
+            return task_message_handler
+        if bucket:
+            return limit_task(req, bucket, 1)
+
+        task_reserved(req)
+        if callbacks:
+            [callback(req) for callback in callbacks]
+        handle(req)
     return task_message_handler

+ 15 - 0
t/unit/worker/test_strategy.py

@@ -98,6 +98,14 @@ class test_default_strategy_proto2:
             assert not self.was_reserved()
             return self.consumer._limit_task.called
 
+        def was_limited_with_eta(self):
+            assert not self.was_reserved()
+            called = self.consumer.timer.call_at.called
+            if called:
+                assert self.consumer.timer.call_at.call_args[0][1] == \
+                    self.consumer._limit_post_eta
+            return called
+
         def was_scheduled(self):
             assert not self.was_reserved()
             assert not self.was_rate_limited()
@@ -186,6 +194,13 @@ class test_default_strategy_proto2:
             C()
             assert C.was_rate_limited()
 
+    def test_when_rate_limited_with_eta(self):
+        task = self.add.s(2, 2).set(countdown=10)
+        with self._context(task, rate_limits=True, limit='1/m') as C:
+            C()
+            assert C.was_limited_with_eta()
+            C.consumer.qos.increment_eventually.assert_called_with()
+
     def test_when_rate_limited__limits_disabled(self):
         task = self.add.s(2, 2)
         with self._context(task, rate_limits=False, limit='1/m') as C: