소스 검색

[Optimization] Unrolls Consumer.on_task method call

Ask Solem 12 년 전
부모
커밋
afb75ef924
4개의 변경된 파일74개의 추가작업 그리고 92개의 파일을 삭제
  1. 3 16
      celery/tests/worker/test_control.py
  2. 3 17
      celery/tests/worker/test_worker.py
  3. 3 53
      celery/worker/consumer.py
  4. 65 6
      celery/worker/strategy.py

+ 3 - 16
celery/tests/worker/test_control.py

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import sys
 import socket
 
+from collections import defaultdict
 from datetime import datetime, timedelta
 
 from kombu import pidbox
@@ -42,7 +43,7 @@ class Consumer(consumer.Consumer):
 
     def __init__(self):
         self.buffer = FastQueue()
-        self.handle_task = self.buffer.put()
+        self.handle_task = self.buffer.put
         self.timer = Timer()
         self.app = current_app
         self.event_dispatcher = Mock()
@@ -51,6 +52,7 @@ class Consumer(consumer.Consumer):
 
         from celery.concurrency.base import BasePool
         self.pool = BasePool(10)
+        self.task_buckets = defaultdict(lambda: None)
 
 
 class test_ControlPanel(Case):
@@ -257,21 +259,6 @@ class test_ControlPanel(Case):
         finally:
             state.reserved_requests.clear()
 
-    def test_rate_limit_when_disabled(self):
-        app = current_app
-        app.conf.CELERY_DISABLE_RATE_LIMITS = True
-        try:
-            e = self.panel.handle(
-                'rate_limit',
-                arguments={
-                    'task_name': mytask.name,
-                    'rate_limit': '100/m'
-                },
-            )
-            self.assertIn('rate limits disabled', e.get('error'))
-        finally:
-            app.conf.CELERY_DISABLE_RATE_LIMITS = False
-
     def test_rate_limit_invalid_rate_limit_string(self):
         e = self.panel.handle('rate_limit', arguments=dict(
             task_name='tasks.add', rate_limit='x1240301#%!'))

+ 3 - 17
celery/tests/worker/test_worker.py

@@ -319,7 +319,7 @@ class test_Consumer(Case):
         callback(m.decode(), m)
         self.assertTrue(warn.call_count)
 
-    @patch('celery.worker.consumer.to_timestamp')
+    @patch('celery.worker.strategy.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
@@ -340,6 +340,7 @@ class test_Consumer(Case):
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
         l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l.event_dispatcher = Mock()
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
@@ -379,6 +380,7 @@ class test_Consumer(Case):
 
     def test_receieve_message(self):
         l = Consumer(self.buffer.put, timer=self.timer)
+        l.event_dispatcher = Mock()
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
@@ -804,22 +806,6 @@ class test_Consumer(Case):
         self.assertEqual(None, l.pool)
         l.namespace.start(l)
 
-    def test_on_task_revoked(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
-        task = Mock()
-        task.revoked.return_value = True
-        l.on_task(task)
-
-    def test_on_task_no_events(self):
-        l = Consumer(self.buffer.put, timer=self.timer)
-        task = Mock()
-        task.revoked.return_value = False
-        l.event_dispatcher = Mock()
-        l.event_dispatcher.enabled = False
-        task.eta = None
-        l._does_info = False
-        l.on_task(task)
-
 
 class test_WorkController(AppCase):
 

+ 3 - 53
celery/worker/consumer.py

@@ -35,8 +35,8 @@ from celery.task.trace import build_tracer
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.text import truncate
-from celery.utils.timer2 import default_timer, to_timestamp
-from celery.utils.timeutils import humanize_seconds, timezone, rate
+from celery.utils.timer2 import default_timer
+from celery.utils.timeutils import humanize_seconds, rate
 
 from . import heartbeat, loops, pidbox
 from .state import task_reserved, maybe_shutdown, revoked
@@ -197,7 +197,6 @@ class Consumer(object):
         self.task_buckets.update(
             (n, self.bucket_for_task(t)) for n, t in items(self.app.tasks)
         )
-        print('BUCKETS: %r' % (self.task_buckets, ))
 
     def _limit_task(self, request, bucket, tokens):
         if not bucket.can_consume(tokens):
@@ -206,6 +205,7 @@ class Consumer(object):
                 hold * 1000.0, self._limit_task, (request, bucket, tokens),
             )
         else:
+            task_reserved(request)
             self.handle_task(request)
 
     def start(self):
@@ -325,56 +325,6 @@ class Consumer(object):
         self.app.amqp.queues.select_remove(queue)
         self.task_consumer.cancel_by_queue(queue)
 
-    def on_task(self, task, task_reserved=task_reserved,
-                to_system_tz=timezone.to_system):
-        """Handle received task.
-
-        If the task has an `eta` we enter it into the ETA schedule,
-        otherwise we move it the ready queue for immediate processing.
-
-        """
-        if task.revoked():
-            return
-
-        if self._does_info:
-            info('Got task from broker: %s', task)
-
-        if self.event_dispatcher.enabled:
-            self.event_dispatcher.send(
-                'task-received',
-                uuid=task.id, name=task.name,
-                args=safe_repr(task.args), kwargs=safe_repr(task.kwargs),
-                retries=task.request_dict.get('retries', 0),
-                eta=task.eta and task.eta.isoformat(),
-                expires=task.expires and task.expires.isoformat(),
-            )
-
-        if task.eta:
-            try:
-                if task.utc:
-                    eta = to_timestamp(to_system_tz(task.eta))
-                else:
-                    eta = to_timestamp(task.eta, timezone.local)
-            except OverflowError as exc:
-                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
-                      task.eta, exc, task.info(safe=True), exc_info=True)
-                task.acknowledge()
-            else:
-                self.qos.increment_eventually()
-                self.timer.apply_at(
-                    eta, self.apply_eta_task, (task, ), priority=6,
-                )
-        else:
-            task_reserved(task)
-            if not self.disable_rate_limits:
-                bucket = self.task_buckets[task.name]
-                if bucket:
-                    self._limit_task(task, bucket, 1)
-                else:
-                    self.handle_task(task)
-            else:
-                self.handle_task(task)
-
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         ETA/countdown."""

+ 65 - 6
celery/worker/strategy.py

@@ -8,19 +8,78 @@
 """
 from __future__ import absolute_import
 
+import logging
+
+from kombu.utils.encoding import safe_repr
+
+from celery.utils.log import get_logger
+from celery.utils.timer2 import to_timestamp
+from celery.utils.timeutils import timezone
+
+logger = get_logger(__name__)
+
 from .job import Request
+from .state import task_reserved
 
 
-def default(task, app, consumer):
+def default(task, app, consumer,
+            info=logger.info, error=logger.error, task_reserved=task_reserved,
+            to_timestamp=to_timestamp, to_system_tz=timezone.to_system):
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
     Req = Request
-    handle = consumer.on_task
     connection_errors = consumer.connection_errors
+    _does_info = logger.isEnabledFor(logging.INFO)
+    events = eventer and eventer.enabled
+    send_event = eventer.send
+    timer_apply_at = consumer.timer.apply_at
+    apply_eta_task = consumer.apply_eta_task
+    rate_limits_enabled = not consumer.disable_rate_limits
+    bucket = consumer.task_buckets[task.name]
+    handle = consumer.handle_task
+    limit_task = consumer._limit_task
 
     def task_message_handler(message, body, ack):
-        handle(Req(body, on_ack=ack, app=app, hostname=hostname,
-                   eventer=eventer, task=task,
-                   connection_errors=connection_errors,
-                   delivery_info=message.delivery_info))
+        req = Req(body, on_ack=ack, app=app, hostname=hostname,
+                  eventer=eventer, task=task,
+                  connection_errors=connection_errors,
+                  delivery_info=message.delivery_info)
+        if req.revoked():
+            return
+
+        if _does_info:
+            info('Got task from broker: %s', req)
+
+        if events:
+            send_event(
+                'task-received',
+                uuid=req.id, name=req.name,
+                args=safe_repr(req.args), kwargs=safe_repr(req.kwargs),
+                retries=req.request_dict.get('retries', 0),
+                eta=req.eta and req.eta.isoformat(),
+                expires=req.expires and req.expires.isoformat(),
+            )
+
+        if req.eta:
+            try:
+                if req.utc:
+                    eta = to_timestamp(to_system_tz(req.eta))
+                else:
+                    eta = to_timestamp(req.eta, timezone.local)
+            except OverflowError as exc:
+                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
+                      req.eta, exc, req.info(safe=True), exc_info=True)
+                req.acknowledge()
+            else:
+                consumer.qos.increment_eventually()
+                timer_apply_at(
+                    eta, apply_eta_task, (req, ), priority=6,
+                )
+        else:
+            if rate_limits_enabled:
+                if bucket:
+                    return limit_task(req, bucket, 1)
+            task_reserved(req)
+            handle(req)
+
     return task_message_handler