
Merge branch 'c31-lockfree-rate-limits'

Ask Solem · 12 years ago
commit d6ad9e122f

celery/app/defaults.py  (+0, -1)

@@ -170,7 +170,6 @@ NAMESPACES = {
                             alt='--loglevel argument'),
         'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
                            alt='--logfile argument'),
-        'MEDIATOR': Option('celery.worker.mediator:Mediator'),
         'MAX_TASKS_PER_CHILD': Option(type='int'),
         'POOL': Option(DEFAULT_POOL),
         'POOL_PUTLOCKS': Option(True, type='bool'),

celery/concurrency/base.py  (+0, -9)

@@ -38,15 +38,6 @@ class BasePool(object):
     #: a signal handler.
     signal_safe = True
 
-    #: set to true if pool supports rate limits.
-    #: (this is here for gevent, which currently does not implement
-    #: the necessary timers).
-    rlimit_safe = True
-
-    #: set to true if pool requires the use of a mediator
-    #: thread (e.g. if applying new items can block the current thread).
-    requires_mediator = False
-
     #: set to true if pool uses greenlets.
     is_green = False
 

celery/concurrency/eventlet.py  (+0, -1)

@@ -117,7 +117,6 @@ class Timer(timer2.Timer):
 class TaskPool(base.BasePool):
     Timer = Timer
 
-    rlimit_safe = False
     signal_safe = False
     is_green = True
 

celery/concurrency/gevent.py  (+0, -1)

@@ -107,7 +107,6 @@ class TaskPool(BasePool):
     Timer = Timer
 
     signal_safe = False
-    rlimit_safe = False
     is_green = True
 
     def __init__(self, *args, **kwargs):

celery/concurrency/processes.py  (+0, -1)

@@ -66,7 +66,6 @@ class TaskPool(BasePool):
     """Multiprocessing Pool implementation."""
     Pool = Pool
 
-    requires_mediator = True
     uses_semaphore = True
 
     def on_start(self):

celery/tests/slow/test_buckets.py  (+0, -345)

@@ -1,345 +0,0 @@
-from __future__ import absolute_import
-
-import sys
-import time
-
-from functools import partial
-from itertools import chain
-
-from mock import Mock, patch
-
-from celery.app.registry import TaskRegistry
-from celery.five import Empty, range
-from celery.task.base import Task
-from celery.utils import timeutils
-from celery.utils import uuid
-from celery.worker import buckets
-
-from celery.tests.utils import Case, skip_if_environ, mock_context
-
-skip_if_disabled = partial(skip_if_environ('SKIP_RLIMITS'))
-
-
-class MockJob(object):
-
-    def __init__(self, id, name, args, kwargs):
-        self.id = id
-        self.name = name
-        self.args = args
-        self.kwargs = kwargs
-
-    def __eq__(self, other):
-        if isinstance(other, self.__class__):
-            return bool(self.id == other.id
-                        and self.name == other.name
-                        and self.args == other.args
-                        and self.kwargs == other.kwargs)
-        else:
-            return self == other
-
-    def __repr__(self):
-        return '<MockJob: task:%s id:%s args:%s kwargs:%s' % (
-            self.name, self.id, self.args, self.kwargs)
-
-
-class test_TokenBucketQueue(Case):
-
-    @skip_if_disabled
-    def empty_queue_yields_QueueEmpty(self):
-        x = buckets.TokenBucketQueue(fill_rate=10)
-        with self.assertRaises(buckets.Empty):
-            x.get()
-
-    @skip_if_disabled
-    def test_bucket__put_get(self):
-        x = buckets.TokenBucketQueue(fill_rate=10)
-        x.put('The quick brown fox')
-        self.assertEqual(x.get(), 'The quick brown fox')
-
-        x.put_nowait('The lazy dog')
-        time.sleep(0.2)
-        self.assertEqual(x.get_nowait(), 'The lazy dog')
-
-    @skip_if_disabled
-    def test_fill_rate(self):
-        x = buckets.TokenBucketQueue(fill_rate=10)
-        # 20 items should take at least one second to complete
-        time_start = time.time()
-        [x.put(str(i)) for i in range(20)]
-        for i in range(20):
-            sys.stderr.write('.')
-            x.wait()
-        self.assertGreater(time.time() - time_start, 1.5)
-
-    @skip_if_disabled
-    def test_can_consume(self):
-        x = buckets.TokenBucketQueue(fill_rate=1)
-        x.put('The quick brown fox')
-        self.assertEqual(x.get(), 'The quick brown fox')
-        time.sleep(0.1)
-        # Not yet ready for another token
-        x.put('The lazy dog')
-        with self.assertRaises(x.RateLimitExceeded):
-            x.get()
-
-    @skip_if_disabled
-    def test_expected_time(self):
-        x = buckets.TokenBucketQueue(fill_rate=1)
-        x.put_nowait('The quick brown fox')
-        self.assertEqual(x.get_nowait(), 'The quick brown fox')
-        self.assertFalse(x.expected_time())
-
-    @skip_if_disabled
-    def test_qsize(self):
-        x = buckets.TokenBucketQueue(fill_rate=1)
-        x.put('The quick brown fox')
-        self.assertEqual(x.qsize(), 1)
-        self.assertEqual(x.get_nowait(), 'The quick brown fox')
-
-
-class test_rate_limit_string(Case):
-
-    @skip_if_disabled
-    def test_conversion(self):
-        self.assertEqual(timeutils.rate(999), 999)
-        self.assertEqual(timeutils.rate(7.5), 7.5)
-        self.assertEqual(timeutils.rate('2.5/s'), 2.5)
-        self.assertEqual(timeutils.rate('1456/s'), 1456)
-        self.assertEqual(timeutils.rate('100/m'),
-                         100 / 60.0)
-        self.assertEqual(timeutils.rate('10/h'),
-                         10 / 60.0 / 60.0)
-
-        for zero in (0, None, '0', '0/m', '0/h', '0/s', '0.0/s'):
-            self.assertEqual(timeutils.rate(zero), 0)
-
-
-class TaskA(Task):
-    rate_limit = 10
-
-
-class TaskB(Task):
-    rate_limit = None
-
-
-class TaskC(Task):
-    rate_limit = '1/s'
-
-
-class TaskD(Task):
-    rate_limit = '1000/m'
-
-
-class test_TaskBucket(Case):
-
-    def setUp(self):
-        self.registry = TaskRegistry()
-        self.task_classes = (TaskA, TaskB, TaskC)
-        for task_cls in self.task_classes:
-            self.registry.register(task_cls)
-
-    @skip_if_disabled
-    def test_get_nowait(self):
-        x = buckets.TaskBucket(task_registry=self.registry)
-        with self.assertRaises(buckets.Empty):
-            x.get_nowait()
-
-    @patch('celery.worker.buckets.sleep')
-    def test_get_block(self, sleep):
-        x = buckets.TaskBucket(task_registry=self.registry)
-        x.not_empty = Mock()
-        get = x._get = Mock()
-        remaining = [0]
-
-        def effect():
-            if get.call_count == 1:
-                raise Empty()
-            rem = remaining[0]
-            remaining[0] = 0
-            return rem, Mock()
-        get.side_effect = effect
-
-        with mock_context(Mock()) as context:
-            x.not_empty = context
-            x.wait = Mock()
-            x.get(block=True)
-
-            get.reset()
-            remaining[0] = 1
-            x.get(block=True)
-
-    def test_get_raises_rate(self):
-        x = buckets.TaskBucket(task_registry=self.registry)
-        x.buckets = {1: Mock()}
-        x.buckets[1].get_nowait.side_effect = buckets.RateLimitExceeded()
-        x.buckets[1].expected_time.return_value = 0
-        x._get()
-
-    @skip_if_disabled
-    def test_refresh(self):
-        reg = {}
-        x = buckets.TaskBucket(task_registry=reg)
-        reg['foo'] = 'something'
-        x.refresh()
-        self.assertIn('foo', x.buckets)
-        self.assertTrue(x.get_bucket_for_type('foo'))
-
-    @skip_if_disabled
-    def test__get_queue_for_type(self):
-        x = buckets.TaskBucket(task_registry={})
-        x.buckets['foo'] = buckets.TokenBucketQueue(fill_rate=1)
-        self.assertIs(x._get_queue_for_type('foo'), x.buckets['foo'].queue)
-        x.buckets['bar'] = buckets.FastQueue()
-        self.assertIs(x._get_queue_for_type('bar'), x.buckets['bar'])
-
-    @skip_if_disabled
-    def test_update_bucket_for_type(self):
-        bucket = buckets.TaskBucket(task_registry=self.registry)
-        b = bucket._get_queue_for_type(TaskC.name)
-        self.assertIs(bucket.update_bucket_for_type(TaskC.name).queue, b)
-        self.assertIs(bucket.buckets[TaskC.name].queue, b)
-
-    @skip_if_disabled
-    def test_auto_add_on_missing_put(self):
-        reg = {}
-        b = buckets.TaskBucket(task_registry=reg)
-        reg['nonexisting.task'] = 'foo'
-
-        b.put(MockJob(uuid(), 'nonexisting.task', (), {}))
-        self.assertIn('nonexisting.task', b.buckets)
-
-    @skip_if_disabled
-    def test_auto_add_on_missing(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-        for task_cls in self.task_classes:
-            self.assertIn(task_cls.name, list(b.buckets.keys()))
-        self.registry.register(TaskD)
-        self.assertTrue(b.get_bucket_for_type(TaskD.name))
-        self.assertIn(TaskD.name, list(b.buckets.keys()))
-        self.registry.unregister(TaskD)
-
-    @skip_if_disabled
-    def test_has_rate_limits(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-        self.assertEqual(b.buckets[TaskA.name]._bucket.fill_rate, 10)
-        self.assertIsInstance(b.buckets[TaskB.name], buckets.Queue)
-        self.assertEqual(b.buckets[TaskC.name]._bucket.fill_rate, 1)
-        self.registry.register(TaskD)
-        b.init_with_registry()
-        try:
-            self.assertEqual(b.buckets[TaskD.name]._bucket.fill_rate,
-                             1000 / 60.0)
-        finally:
-            self.registry.unregister(TaskD)
-
-    @skip_if_disabled
-    def test_on_empty_buckets__get_raises_empty(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-        with self.assertRaises(buckets.Empty):
-            b.get(block=False)
-        self.assertEqual(b.qsize(), 0)
-
-    @skip_if_disabled
-    def test_put__get(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-        job = MockJob(uuid(), TaskA.name, ['theqbf'], {'foo': 'bar'})
-        b.put(job)
-        self.assertEqual(b.get(), job)
-
-    @skip_if_disabled
-    def test_fill_rate(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-
-        cjob = lambda i: MockJob(uuid(), TaskA.name, [i], {})
-        jobs = [cjob(i) for i in range(20)]
-        [b.put(job) for job in jobs]
-
-        self.assertEqual(b.qsize(), 20)
-
-        # 20 items should take at least one second to complete
-        time_start = time.time()
-        for i, job in enumerate(jobs):
-            sys.stderr.write('.')
-            self.assertEqual(b.get(), job)
-        self.assertGreater(time.time() - time_start, 1.5)
-
-    @skip_if_disabled
-    def test__very_busy_queue_doesnt_block_others(self):
-        b = buckets.TaskBucket(task_registry=self.registry)
-
-        cjob = lambda i, t: MockJob(uuid(), t.name, [i], {})
-        ajobs = [cjob(i, TaskA) for i in range(10)]
-        bjobs = [cjob(i, TaskB) for i in range(20)]
-        jobs = list(chain(*zip(bjobs, ajobs)))
-        for job in jobs:
-            b.put(job)
-
-        got_ajobs = 0
-        for job in (b.get() for i in range(20)):
-            if job.name == TaskA.name:
-                got_ajobs += 1
-
-        self.assertGreater(got_ajobs, 2)
-
-    @skip_if_disabled
-    def test_thorough__multiple_types(self):
-        self.registry.register(TaskD)
-        try:
-            b = buckets.TaskBucket(task_registry=self.registry)
-
-            cjob = lambda i, t: MockJob(uuid(), t.name, [i], {})
-
-            ajobs = [cjob(i, TaskA) for i in range(10)]
-            bjobs = [cjob(i, TaskB) for i in range(10)]
-            cjobs = [cjob(i, TaskC) for i in range(10)]
-            djobs = [cjob(i, TaskD) for i in range(10)]
-
-            # Spread the jobs around.
-            jobs = list(chain(*zip(ajobs, bjobs, cjobs, djobs)))
-
-            [b.put(job) for job in jobs]
-            for i, job in enumerate(jobs):
-                sys.stderr.write('.')
-                self.assertTrue(b.get(), job)
-            self.assertEqual(i + 1, len(jobs))
-        finally:
-            self.registry.unregister(TaskD)
-
-    @skip_if_disabled
-    def test_empty(self):
-        x = buckets.TaskBucket(task_registry=self.registry)
-        self.assertTrue(x.empty())
-        x.put(MockJob(uuid(), TaskC.name, [], {}))
-        self.assertFalse(x.empty())
-        x.clear()
-        self.assertTrue(x.empty())
-
-    @skip_if_disabled
-    def test_items(self):
-        x = buckets.TaskBucket(task_registry=self.registry)
-        x.buckets[TaskA.name].put(1)
-        x.buckets[TaskB.name].put(2)
-        x.buckets[TaskC.name].put(3)
-        self.assertEqual(sorted(x.items), [1, 2, 3])
-
-
-class test_FastQueue(Case):
-
-    def test_items(self):
-        x = buckets.FastQueue()
-        x.put(10)
-        x.put(20)
-        self.assertListEqual([10, 20], list(x.items))
-
-    def test_wait(self):
-        x = buckets.FastQueue()
-        x.put(10)
-        self.assertEqual(x.wait(), 10)
-
-    def test_clear(self):
-        x = buckets.FastQueue()
-        x.put(10)
-        x.put(20)
-        self.assertFalse(x.empty())
-        x.clear()
-        self.assertTrue(x.empty())

celery/tests/utilities/test_timeutils.py  (+34, -12)

@@ -2,16 +2,23 @@ from __future__ import absolute_import
 
 from datetime import datetime, timedelta
 
-from celery.utils import timeutils
-from celery.utils.timeutils import timezone
+from celery.utils.timeutils import (
+    delta_resolution,
+    humanize_seconds,
+    maybe_iso8601,
+    maybe_timedelta,
+    timedelta_seconds,
+    timezone,
+    rate,
+    remaining,
+)
 from celery.tests.utils import Case
 
 
 class test_timeutils(Case):
 
     def test_delta_resolution(self):
-        D = timeutils.delta_resolution
-
+        D = delta_resolution
         dt = datetime(2010, 3, 30, 11, 50, 58, 41065)
         deltamap = ((timedelta(days=2), datetime(2010, 3, 30, 0, 0)),
                     (timedelta(hours=2), datetime(2010, 3, 30, 11, 0)),
@@ -27,11 +34,11 @@ class test_timeutils(Case):
                     (timedelta(hours=4), 4 * 60 * 60),
                     (timedelta(days=3), 3 * 86400))
         for delta, seconds in deltamap:
-            self.assertEqual(timeutils.timedelta_seconds(delta), seconds)
+            self.assertEqual(timedelta_seconds(delta), seconds)
 
     def test_timedelta_seconds_returns_0_on_negative_time(self):
         delta = timedelta(days=-2)
-        self.assertEqual(timeutils.timedelta_seconds(delta), 0)
+        self.assertEqual(timedelta_seconds(delta), 0)
 
     def test_humanize_seconds(self):
         t = ((4 * 60 * 60 * 24, '4.00 days'),
@@ -46,17 +53,17 @@ class test_timeutils(Case):
              (0, 'now'))
 
         for seconds, human in t:
-            self.assertEqual(timeutils.humanize_seconds(seconds), human)
+            self.assertEqual(humanize_seconds(seconds), human)
 
-        self.assertEqual(timeutils.humanize_seconds(4, prefix='about '),
+        self.assertEqual(humanize_seconds(4, prefix='about '),
                          'about 4.00 seconds')
 
     def test_maybe_iso8601_datetime(self):
         now = datetime.now()
-        self.assertIs(timeutils.maybe_iso8601(now), now)
+        self.assertIs(maybe_iso8601(now), now)
 
     def test_maybe_timedelta(self):
-        D = timeutils.maybe_timedelta
+        D = maybe_timedelta
 
         for i in (30, 30.6):
             self.assertEqual(D(i), timedelta(seconds=i))
@@ -64,11 +71,26 @@ class test_timeutils(Case):
         self.assertEqual(D(timedelta(days=2)), timedelta(days=2))
 
     def test_remaining_relative(self):
-        timeutils.remaining(datetime.utcnow(), timedelta(hours=1),
-                            relative=True)
+        remaining(datetime.utcnow(), timedelta(hours=1), relative=True)
 
 
 class test_timezone(Case):
 
     def test_get_timezone_with_pytz(self):
         self.assertTrue(timezone.get_timezone('UTC'))
+
+
+class test_rate_limit_string(Case):
+
+    def test_conversion(self):
+        self.assertEqual(rate(999), 999)
+        self.assertEqual(rate(7.5), 7.5)
+        self.assertEqual(rate('2.5/s'), 2.5)
+        self.assertEqual(rate('1456/s'), 1456)
+        self.assertEqual(rate('100/m'),
+                         100 / 60.0)
+        self.assertEqual(rate('10/h'),
+                         10 / 60.0 / 60.0)
+
+        for zero in (0, None, '0', '0/m', '0/h', '0/s', '0.0/s'):
+            self.assertEqual(rate(zero), 0)
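
The relocated rate-limit-string tests above exercise
celery.utils.timeutils.rate(), which the lock-free code path feeds into
kombu's TokenBucket (the removed AsyncTaskBucket.add_task_type did the
same).  A minimal sketch of that pairing, using only those two helpers::

    from kombu.utils.limits import TokenBucket
    from celery.utils.timeutils import rate

    limit = rate('100/m')          # -> 100 / 60.0 tokens per second
    # A falsy limit (0, None, '0/s', ...) means the task is not rate limited.
    bucket = TokenBucket(limit, capacity=1) if limit else None

    if bucket is None or bucket.can_consume(1):
        pass                             # dispatch the request right away
    else:
        delay = bucket.expected_time(1)  # seconds until the next token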

celery/tests/worker/test_control.py  (+11, -28)

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import sys
 import socket
 
+from collections import defaultdict
 from datetime import datetime, timedelta
 
 from kombu import pidbox
@@ -17,7 +18,7 @@ from celery.worker import WorkController as _WC
 from celery.worker import consumer
 from celery.worker import control
 from celery.worker import state
-from celery.worker.buckets import FastQueue
+from celery.five import Queue as FastQueue
 from celery.worker.job import TaskRequest
 from celery.worker.state import revoked
 from celery.worker.control import Panel
@@ -41,7 +42,8 @@ class WorkController(object):
 class Consumer(consumer.Consumer):
 
     def __init__(self):
-        self.ready_queue = FastQueue()
+        self.buffer = FastQueue()
+        self.handle_task = self.buffer.put
         self.timer = Timer()
         self.app = current_app
         self.event_dispatcher = Mock()
@@ -50,6 +52,7 @@ class Consumer(consumer.Consumer):
 
         from celery.concurrency.base import BasePool
         self.pool = BasePool(10)
+        self.task_buckets = defaultdict(lambda: None)
 
 
 class test_ControlPanel(Case):
@@ -256,21 +259,6 @@ class test_ControlPanel(Case):
         finally:
             state.reserved_requests.clear()
 
-    def test_rate_limit_when_disabled(self):
-        app = current_app
-        app.conf.CELERY_DISABLE_RATE_LIMITS = True
-        try:
-            e = self.panel.handle(
-                'rate_limit',
-                arguments={
-                    'task_name': mytask.name,
-                    'rate_limit': '100/m'
-                },
-            )
-            self.assertIn('rate limits disabled', e.get('error'))
-        finally:
-            app.conf.CELERY_DISABLE_RATE_LIMITS = False
-
     def test_rate_limit_invalid_rate_limit_string(self):
         e = self.panel.handle('rate_limit', arguments=dict(
             task_name='tasks.add', rate_limit='x1240301#%!'))
@@ -279,15 +267,10 @@ class test_ControlPanel(Case):
     def test_rate_limit(self):
 
         class Consumer(object):
+            reset = False
 
-            class ReadyQueue(object):
-                fresh = False
-
-                def refresh(self):
-                    self.fresh = True
-
-            def __init__(self):
-                self.ready_queue = self.ReadyQueue()
+            def reset_rate_limits(self):
+                self.reset = True
 
         consumer = Consumer()
         panel = self.create_panel(app=current_app, consumer=consumer)
@@ -298,12 +281,12 @@ class test_ControlPanel(Case):
             panel.handle('rate_limit', arguments=dict(task_name=task.name,
                                                       rate_limit='100/m'))
             self.assertEqual(task.rate_limit, '100/m')
-            self.assertTrue(consumer.ready_queue.fresh)
-            consumer.ready_queue.fresh = False
+            self.assertTrue(consumer.reset)
+            consumer.reset = False
             panel.handle('rate_limit', arguments=dict(task_name=task.name,
                                                       rate_limit=0))
             self.assertEqual(task.rate_limit, 0)
-            self.assertTrue(consumer.ready_queue.fresh)
+            self.assertTrue(consumer.reset)
         finally:
             task.rate_limit = old_rate_limit
 

celery/tests/worker/test_mediator.py  (+0, -112)

@@ -1,112 +0,0 @@
-from __future__ import absolute_import
-
-import sys
-
-from mock import Mock, patch
-
-from celery.five import Queue
-from celery.worker.mediator import Mediator
-from celery.worker.state import revoked as revoked_tasks
-from celery.tests.utils import Case
-
-
-class MockTask(object):
-    hostname = 'harness.com'
-    id = 1234
-    name = 'mocktask'
-
-    def __init__(self, value, **kwargs):
-        self.value = value
-
-    on_ack = Mock()
-
-    def revoked(self):
-        if self.id in revoked_tasks:
-            self.on_ack()
-            return True
-        return False
-
-
-class test_Mediator(Case):
-
-    def test_mediator_start__stop(self):
-        ready_queue = Queue()
-        m = Mediator(ready_queue, lambda t: t)
-        m.start()
-        self.assertFalse(m._is_shutdown.isSet())
-        self.assertFalse(m._is_stopped.isSet())
-        m.stop()
-        m.join()
-        self.assertTrue(m._is_shutdown.isSet())
-        self.assertTrue(m._is_stopped.isSet())
-
-    def test_mediator_body(self):
-        ready_queue = Queue()
-        got = {}
-
-        def mycallback(value):
-            got['value'] = value.value
-
-        m = Mediator(ready_queue, mycallback)
-        ready_queue.put(MockTask('George Costanza'))
-
-        m.body()
-
-        self.assertEqual(got['value'], 'George Costanza')
-
-        ready_queue.put(MockTask('Jerry Seinfeld'))
-        m._does_debug = False
-        m.body()
-        self.assertEqual(got['value'], 'Jerry Seinfeld')
-
-    @patch('os._exit')
-    def test_mediator_crash(self, _exit):
-        ms = [None]
-
-        class _Mediator(Mediator):
-
-            def body(self):
-                try:
-                    raise KeyError('foo')
-                finally:
-                    ms[0]._is_shutdown.set()
-
-        ready_queue = Queue()
-        ms[0] = m = _Mediator(ready_queue, None)
-        ready_queue.put(MockTask('George Constanza'))
-
-        stderr = Mock()
-        p, sys.stderr = sys.stderr, stderr
-        try:
-            m.run()
-        finally:
-            sys.stderr = p
-        self.assertTrue(_exit.call_count)
-        self.assertTrue(stderr.write.call_count)
-
-    def test_mediator_body_exception(self):
-        ready_queue = Queue()
-
-        def mycallback(value):
-            raise KeyError('foo')
-
-        m = Mediator(ready_queue, mycallback)
-        ready_queue.put(MockTask('Elaine M. Benes'))
-
-        m.body()
-
-    def test_run(self):
-        ready_queue = Queue()
-
-        condition = [None]
-
-        def mycallback(value):
-            condition[0].set()
-
-        m = Mediator(ready_queue, mycallback)
-        condition[0] = m._is_shutdown
-        ready_queue.put(MockTask('Elaine M. Benes'))
-
-        m.run()
-        self.assertTrue(m._is_shutdown.isSet())
-        self.assertTrue(m._is_stopped.isSet())
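
With these tests gone there is no longer a thread shuttling requests
between queues: the mediator's job of popping the ready_queue and invoking
a callback collapses into the consumer calling the worker's process-task
callback directly (see the handle_task argument introduced in consumer.py
further down).  A rough sketch of the new hand-off; the _on_task name here
is illustrative only::

    # Old flow (removed):  consumer -> ready_queue -> Mediator thread -> pool
    # New flow:            consumer -> handle_task(request) -> pool
    class Consumer(object):

        def __init__(self, handle_task, **kwargs):
            # e.g. WorkController._process_task or _process_task_sem
            self.handle_task = handle_task

        def _on_task(self, request):
            # Rate limiting, when enabled, happens via token buckets and
            # the timer (no extra thread); then the request goes straight
            # to the pool callback.
            self.handle_task(request)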

celery/tests/worker/test_worker.py  (+49, -107)

@@ -12,7 +12,6 @@ from kombu.common import QoS, PREFETCH_COUNT_MAX, ignore_errors
 from kombu.exceptions import StdChannelError
 from kombu.transport.base import Message
 from mock import Mock, patch
-from nose import SkipTest
 
 from celery import current_app
 from celery.app.defaults import DEFAULTS
@@ -20,13 +19,12 @@ from celery.bootsteps import RUN, CLOSE, TERMINATE, StartStopStep
 from celery.concurrency.base import BasePool
 from celery.datastructures import AttributeDict
 from celery.exceptions import SystemTerminate
-from celery.five import Empty, range
+from celery.five import Empty, range, Queue as FastQueue
 from celery.task import task as task_dec
 from celery.task import periodic_task as periodic_task_dec
 from celery.utils import uuid
 from celery.worker import WorkController
 from celery.worker import components
-from celery.worker.buckets import FastQueue, AsyncTaskBucket
 from celery.worker.job import Request
 from celery.worker import consumer
 from celery.worker.consumer import Consumer as __Consumer
@@ -236,14 +234,14 @@ class test_QoS(Case):
 class test_Consumer(Case):
 
     def setUp(self):
-        self.ready_queue = FastQueue()
+        self.buffer = FastQueue()
         self.timer = Timer()
 
     def tearDown(self):
         self.timer.stop()
 
     def test_info(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.task_consumer = Mock()
         l.qos = QoS(l.task_consumer.qos, 10)
         l.connection = Mock()
@@ -257,12 +255,12 @@ class test_Consumer(Case):
         self.assertTrue(info['broker'])
 
     def test_start_when_closed(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.namespace.state = CLOSE
         l.start()
 
     def test_connection(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
 
         l.namespace.start(l)
         self.assertIsInstance(l.connection, Connection)
@@ -287,7 +285,7 @@ class test_Consumer(Case):
         self.assertIsNone(l.task_consumer)
 
     def test_close_connection(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.namespace.state = RUN
         step = find_step(l, consumer.Connection)
         conn = l.connection = Mock()
@@ -295,7 +293,7 @@ class test_Consumer(Case):
         self.assertTrue(conn.close.called)
         self.assertIsNone(l.connection)
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         eventer = l.event_dispatcher = Mock()
         eventer.enabled = True
         heart = l.heart = MockHeart()
@@ -309,7 +307,7 @@ class test_Consumer(Case):
 
     @patch('celery.worker.consumer.warn')
     def test_receive_message_unknown(self, warn):
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, unknown={'baz': '!!!'})
@@ -320,10 +318,10 @@ class test_Consumer(Case):
         callback(m.decode(), m)
         self.assertTrue(warn.call_count)
 
-    @patch('celery.worker.consumer.to_timestamp')
+    @patch('celery.worker.strategy.to_timestamp')
     def test_receive_message_eta_OverflowError(self, to_timestamp):
         to_timestamp.side_effect = OverflowError()
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            args=('2, 2'),
@@ -340,7 +338,8 @@ class test_Consumer(Case):
 
     @patch('celery.worker.consumer.error')
     def test_receive_message_InvalidTaskError(self, error):
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
+        l.event_dispatcher = Mock()
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            args=(1, 2), kwargs='foobarbaz', id=1)
@@ -353,7 +352,7 @@ class test_Consumer(Case):
 
     @patch('celery.worker.consumer.crit')
     def test_on_decode_error(self, crit):
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
 
         class MockMessage(Mock):
             content_type = 'application/x-msgpack'
@@ -379,14 +378,15 @@ class test_Consumer(Case):
         return l.task_consumer.register_callback.call_args[0][0]
 
     def test_receieve_message(self):
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
+        l.event_dispatcher = Mock()
         m = create_message(Mock(), task=foo_task.name,
                            args=[2, 4, 8], kwargs={})
         l.update_strategies()
         callback = self._get_on_message(l)
         callback(m.decode(), m)
 
-        in_bucket = self.ready_queue.get_nowait()
+        in_bucket = self.buffer.get_nowait()
         self.assertIsInstance(in_bucket, Request)
         self.assertEqual(in_bucket.name, foo_task.name)
         self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
@@ -403,7 +403,7 @@ class test_Consumer(Case):
                     raise KeyError('foo')
                 raise SyntaxError('bar')
 
-        l = MockConsumer(self.ready_queue, timer=self.timer,
+        l = MockConsumer(self.buffer.put, timer=self.timer,
                          send_events=False, pool=BasePool())
         l.channel_errors = (KeyError, )
         with self.assertRaises(KeyError):
@@ -421,7 +421,7 @@ class test_Consumer(Case):
                     raise KeyError('foo')
                 raise SyntaxError('bar')
 
-        l = MockConsumer(self.ready_queue, timer=self.timer,
+        l = MockConsumer(self.buffer.put, timer=self.timer,
                          send_events=False, pool=BasePool())
 
         l.connection_errors = (KeyError, )
@@ -437,7 +437,7 @@ class test_Consumer(Case):
                 self.obj.connection = None
                 raise socket.timeout(10)
 
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.connection = Connection()
         l.task_consumer = Mock()
         l.connection.obj = l
@@ -453,7 +453,7 @@ class test_Consumer(Case):
                 self.obj.connection = None
                 raise socket.error('foo')
 
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
         l.namespace.state = RUN
         c = l.connection = Connection()
         l.connection.obj = l
@@ -474,7 +474,7 @@ class test_Consumer(Case):
             def drain_events(self, **kwargs):
                 self.obj.connection = None
 
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
         l.connection = Connection()
         l.connection.obj = l
         l.task_consumer = Mock()
@@ -492,7 +492,7 @@ class test_Consumer(Case):
         l.task_consumer.qos.assert_called_with(prefetch_count=9)
 
     def test_ignore_errors(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.connection_errors = (AttributeError, KeyError, )
         l.channel_errors = (SyntaxError, )
         ignore_errors(l, Mock(side_effect=AttributeError('foo')))
@@ -503,7 +503,7 @@ class test_Consumer(Case):
 
     def test_apply_eta_task(self):
         from celery.worker import state
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.qos = QoS(None, 10)
 
         task = object()
@@ -511,10 +511,10 @@ class test_Consumer(Case):
         l.apply_eta_task(task)
         self.assertIn(task, state.reserved_requests)
         self.assertEqual(l.qos.value, qos - 1)
-        self.assertIs(self.ready_queue.get_nowait(), task)
+        self.assertIs(self.buffer.get_nowait(), task)
 
     def test_receieve_message_eta_isoformat(self):
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         m = create_message(Mock(), task=foo_task.name,
                            eta=datetime.now().isoformat(),
@@ -541,7 +541,7 @@ class test_Consumer(Case):
         l.timer.stop()
 
     def test_pidbox_callback(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         con.reset = Mock()
@@ -561,8 +561,7 @@ class test_Consumer(Case):
         self.assertTrue(con.reset.called)
 
     def test_revoke(self):
-        ready_queue = FastQueue()
-        l = _MyKombuConsumer(ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         backend = Mock()
         id = uuid()
@@ -573,10 +572,10 @@ class test_Consumer(Case):
 
         callback = self._get_on_message(l)
         callback(t.decode(), t)
-        self.assertTrue(ready_queue.empty())
+        self.assertTrue(self.buffer.empty())
 
     def test_receieve_message_not_registered(self):
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         backend = Mock()
         m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})
@@ -585,13 +584,13 @@ class test_Consumer(Case):
         callback = self._get_on_message(l)
         self.assertFalse(callback(m.decode(), m))
         with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
+            self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
 
     @patch('celery.worker.consumer.warn')
     @patch('celery.worker.consumer.logger')
     def test_receieve_message_ack_raises(self, logger, warn):
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
         backend = Mock()
         m = create_message(backend, args=[2, 4, 8], kwargs={})
 
@@ -603,13 +602,13 @@ class test_Consumer(Case):
         self.assertFalse(callback(m.decode(), m))
         self.assertTrue(warn.call_count)
         with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
+            self.buffer.get_nowait()
         self.assertTrue(self.timer.empty())
         m.reject.assert_called_with()
         self.assertTrue(logger.critical.call_count)
 
     def test_receive_message_eta(self):
-        l = _MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = _MyKombuConsumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         l.event_dispatcher = Mock()
         l.event_dispatcher._outbound_buffer = deque()
@@ -640,10 +639,10 @@ class test_Consumer(Case):
         self.assertEqual(task.name, foo_task.name)
         self.assertEqual(task.execute(), 2 * 4 * 8)
         with self.assertRaises(Empty):
-            self.ready_queue.get_nowait()
+            self.buffer.get_nowait()
 
     def test_reset_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         con = find_step(l, consumer.Control).box
         con.node = Mock()
         chan = con.node.channel = Mock()
@@ -657,7 +656,7 @@ class test_Consumer(Case):
         from celery.worker.pidbox import gPidbox
         pool = Mock()
         pool.is_green = True
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
         con = find_step(l, consumer.Control)
         self.assertIsInstance(con.box, gPidbox)
         con.start(l)
@@ -668,7 +667,7 @@ class test_Consumer(Case):
     def test__green_pidbox_node(self):
         pool = Mock()
         pool.is_green = True
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer, pool=pool)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool)
         l.node = Mock()
         controller = find_step(l, consumer.Control)
 
@@ -730,7 +729,7 @@ class test_Consumer(Case):
     @patch('kombu.connection.Connection._establish_connection')
     @patch('kombu.utils.sleep')
     def test_connect_errback(self, sleep, connect):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         from kombu.transport.memory import Transport
         Transport.connection_errors = (StdChannelError, )
 
@@ -743,7 +742,7 @@ class test_Consumer(Case):
         connect.assert_called_with()
 
     def test_stop_pidbox_node(self):
-        l = MyKombuConsumer(self.ready_queue, timer=self.timer)
+        l = MyKombuConsumer(self.buffer.put, timer=self.timer)
         cont = find_step(l, consumer.Control)
         cont._node_stopped = Event()
         cont._node_shutdown = Event()
@@ -767,7 +766,7 @@ class test_Consumer(Case):
                     raise KeyError('foo')
 
         init_callback = Mock()
-        l = _Consumer(self.ready_queue, timer=self.timer,
+        l = _Consumer(self.buffer.put, timer=self.timer,
                       init_callback=init_callback)
         l.task_consumer = Mock()
         l.broadcast_consumer = Mock()
@@ -789,7 +788,7 @@ class test_Consumer(Case):
         self.assertEqual(l.qos.prev, l.qos.value)
 
         init_callback.reset_mock()
-        l = _Consumer(self.ready_queue, timer=self.timer,
+        l = _Consumer(self.buffer.put, timer=self.timer,
                       send_events=False, init_callback=init_callback)
         l.qos = _QoS()
         l.task_consumer = Mock()
@@ -801,27 +800,11 @@ class test_Consumer(Case):
         self.assertTrue(l.loop.call_count)
 
     def test_reset_connection_with_no_node(self):
-        l = Consumer(self.ready_queue, timer=self.timer)
+        l = Consumer(self.buffer.put, timer=self.timer)
         l.steps.pop()
         self.assertEqual(None, l.pool)
         l.namespace.start(l)
 
-    def test_on_task_revoked(self):
-        l = Consumer(self.ready_queue, timer=self.timer)
-        task = Mock()
-        task.revoked.return_value = True
-        l.on_task(task)
-
-    def test_on_task_no_events(self):
-        l = Consumer(self.ready_queue, timer=self.timer)
-        task = Mock()
-        task.revoked.return_value = False
-        l.event_dispatcher = Mock()
-        l.event_dispatcher.enabled = False
-        task.eta = None
-        l._does_info = False
-        l.on_task(task)
-
 
 class test_WorkController(AppCase):
 
@@ -883,18 +866,12 @@ class test_WorkController(AppCase):
             'celeryd', hostname='awesome.worker.com',
         )
 
-    def test_with_rate_limits_disabled(self):
-        worker = WorkController(concurrency=1, loglevel=0,
-                                disable_rate_limits=True)
-        self.assertTrue(hasattr(worker.ready_queue, 'put'))
-
     def test_attrs(self):
         worker = self.worker
         self.assertIsInstance(worker.timer, Timer)
         self.assertTrue(worker.timer)
         self.assertTrue(worker.pool)
         self.assertTrue(worker.consumer)
-        self.assertTrue(worker.mediator)
         self.assertTrue(worker.steps)
 
     def test_with_embedded_beat(self):
@@ -952,7 +929,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode())
-        worker.process_task(task)
+        worker._process_task(task)
         self.assertEqual(worker.pool.apply_async.call_count, 1)
         worker.pool.stop()
 
@@ -967,7 +944,7 @@ class test_WorkController(AppCase):
         worker.steps = []
         worker.namespace.state = RUN
         with self.assertRaises(KeyboardInterrupt):
-            worker.process_task(task)
+            worker._process_task(task)
         self.assertEqual(worker.namespace.state, TERMINATE)
 
     def test_process_task_raise_SystemTerminate(self):
@@ -981,7 +958,7 @@ class test_WorkController(AppCase):
         worker.steps = []
         worker.namespace.state = RUN
         with self.assertRaises(SystemExit):
-            worker.process_task(task)
+            worker._process_task(task)
         self.assertEqual(worker.namespace.state, TERMINATE)
 
     def test_process_task_raise_regular(self):
@@ -992,7 +969,7 @@ class test_WorkController(AppCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = Request.from_message(m, m.decode())
-        worker.process_task(task)
+        worker._process_task(task)
         worker.pool.stop()
 
     def test_start_catches_base_exceptions(self):
@@ -1023,42 +1000,13 @@ class test_WorkController(AppCase):
         finally:
             state.Persistent = Persistent
 
-    def test_disable_rate_limits_solo(self):
-        worker = self.create_worker(disable_rate_limits=True,
-                                    pool_cls='solo')
-        self.assertIsInstance(worker.ready_queue, FastQueue)
-        self.assertIsNone(worker.mediator)
-        self.assertEqual(worker.ready_queue.put, worker.process_task)
-
-    def test_enable_rate_limits_eventloop(self):
-        try:
-            worker = self.create_worker(disable_rate_limits=False,
-                                        use_eventloop=True,
-                                        pool_cls='processes')
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
-        self.assertIsInstance(worker.ready_queue, AsyncTaskBucket)
-        self.assertFalse(worker.mediator)
-        self.assertNotEqual(worker.ready_queue.put, worker.process_task)
-
-    def test_disable_rate_limits_processes(self):
-        try:
-            worker = self.create_worker(disable_rate_limits=True,
-                                        use_eventloop=False,
-                                        pool_cls='processes')
-        except ImportError:
-            raise SkipTest('multiprocessing not supported')
-        self.assertIsInstance(worker.ready_queue, FastQueue)
-        self.assertFalse(worker.mediator)
-        self.assertEqual(worker.ready_queue.put, worker.process_task)
-
     def test_process_task_sem(self):
         worker = self.worker
         worker._quick_acquire = Mock()
 
         req = Mock()
-        worker.process_task_sem(req)
-        worker._quick_acquire.assert_called_with(worker.process_task, req)
+        worker._process_task_sem(req)
+        worker._quick_acquire.assert_called_with(worker._process_task, req)
 
     def test_signal_consumer_close(self):
         worker = self.worker
@@ -1124,17 +1072,11 @@ class test_WorkController(AppCase):
         for step in worker.steps:
             self.assertTrue(step.terminate.call_count)
 
-    def test_Queues_pool_not_rlimit_safe(self):
-        w = Mock()
-        w.pool_cls.rlimit_safe = False
-        components.Queues(w).create(w)
-        self.assertTrue(w.disable_rate_limits)
-
     def test_Queues_pool_no_sem(self):
         w = Mock()
         w.pool_cls.uses_semaphore = False
         components.Queues(w).create(w)
-        self.assertIs(w.ready_queue.put, w.process_task)
+        self.assertIs(w.process_task, w._process_task)
 
     def test_Hub_crate(self):
         w = Mock()

celery/worker/__init__.py  (+3, -5)

@@ -59,7 +59,6 @@ class WorkController(configurated):
     send_events = from_config()
     pool_cls = from_config('pool')
     consumer_cls = from_config('consumer')
-    mediator_cls = from_config('mediator')
     timer_cls = from_config('timer')
     timer_precision = from_config('timer_precision')
     autoscaler_cls = from_config('autoscaler')
@@ -93,7 +92,6 @@ class WorkController(configurated):
             'celery.worker.components:Consumer',
             'celery.worker.autoscale:WorkerComponent',
             'celery.worker.autoreload:WorkerComponent',
-            'celery.worker.mediator:WorkerComponent',
 
         ])
 
@@ -206,10 +204,10 @@ class WorkController(configurated):
         except (KeyboardInterrupt, SystemExit):
             self.stop()
 
-    def process_task_sem(self, req):
-        return self._quick_acquire(self.process_task, req)
+    def _process_task_sem(self, req):
+        return self._quick_acquire(self._process_task, req)
 
-    def process_task(self, req):
+    def _process_task(self, req):
         """Process task by sending it to the pool of workers."""
         try:
             req.execute_using_pool(self.pool)

celery/worker/buckets.py  (+0, -383)

@@ -1,383 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.worker.buckets
-    ~~~~~~~~~~~~~~~~~~~~~
-
-    This module implements the rate limiting of tasks,
-    by having a token bucket queue for each task type.
-    When a task is allowed to be processed it's moved
-    over the ``ready_queue``
-
-    The :mod:`celery.worker.mediator` is then responsible
-    for moving tasks from the ``ready_queue`` to the worker pool.
-
-"""
-from __future__ import absolute_import
-
-import threading
-
-from collections import deque
-from itertools import chain
-from time import time, sleep
-
-from kombu.utils.limits import TokenBucket
-
-from celery.five import Queue, Empty, values, zip_longest
-from celery.utils import timeutils
-
-
-class RateLimitExceeded(Exception):
-    """The token buckets rate limit has been exceeded."""
-
-
-class AsyncTaskBucket(object):
-
-    def __init__(self, task_registry, callback=None, worker=None):
-        self.task_registry = task_registry
-        self.callback = callback
-        self.worker = worker
-        self.buckets = {}
-        self.refresh()
-
-    def cont(self, request, bucket, tokens):
-        if not bucket.can_consume(tokens):
-            hold = bucket.expected_time(tokens)
-            self.worker.timer.apply_after(
-                hold * 1000.0, self.cont, (request, bucket, tokens),
-            )
-        else:
-            self.callback(request)
-
-    def put(self, request):
-        name = request.name
-        try:
-            bucket = self.buckets[name]
-        except KeyError:
-            bucket = self.add_bucket_for_type(name)
-        if not bucket:
-            return self.callback(request)
-        return self.cont(request, bucket, 1)
-
-    def add_task_type(self, name):
-        task_type = self.task_registry[name]
-        limit = getattr(task_type, 'rate_limit', None)
-        limit = timeutils.rate(limit)
-        bucket = self.buckets[name] = (
-            TokenBucket(limit, capacity=1) if limit else None
-        )
-        return bucket
-
-    def clear(self):
-        # called by the worker when the connection is lost,
-        # but this also clears out the timer so we be good.
-        pass
-
-    def refresh(self):
-        for name in self.task_registry:
-            self.add_task_type(name)
-
-
-class TaskBucket(object):
-    """This is a collection of token buckets, each task type having
-    its own token bucket.  If the task type doesn't have a rate limit,
-    it will have a plain :class:`~Queue.Queue` object instead of a
-    :class:`TokenBucketQueue`.
-
-    The :meth:`put` operation forwards the task to its appropriate bucket,
-    while the :meth:`get` operation iterates over the buckets and retrieves
-    the first available item.
-
-    Say we have three types of tasks in the registry: `twitter.update`,
-    `feed.refresh` and `video.compress`, the TaskBucket will consist
-    of the following items::
-
-        {'twitter.update': TokenBucketQueue(fill_rate=300),
-         'feed.refresh': Queue(),
-         'video.compress': TokenBucketQueue(fill_rate=2)}
-
-    The get operation will iterate over these until one of the buckets
-    is able to return an item.  The underlying datastructure is a `dict`,
-    so the order is ignored here.
-
-    :param task_registry: The task registry used to get the task
-                          type class for a given task name.
-
-    """
-
-    def __init__(self, task_registry, callback=None, worker=None):
-        self.task_registry = task_registry
-        self.buckets = {}
-        self.init_with_registry()
-        self.immediate = deque()
-        self.mutex = threading.Lock()
-        self.not_empty = threading.Condition(self.mutex)
-        self.callback = callback
-        self.worker = worker
-
-    def put(self, request):
-        """Put a :class:`~celery.worker.job.Request` into
-        the appropriate bucket."""
-        if request.name not in self.buckets:
-            self.add_bucket_for_type(request.name)
-        self.buckets[request.name].put_nowait(request)
-        with self.mutex:
-            self.not_empty.notify()
-    put_nowait = put
-
-    def _get_immediate(self):
-        try:
-            return self.immediate.popleft()
-        except IndexError:
-            raise Empty()
-
-    def _get(self):
-        # If the first bucket is always returning items, we would never
-        # get to fetch items from the other buckets. So we always iterate over
-        # all the buckets and put any ready items into a queue called
-        # "immediate". This queue is always checked for cached items first.
-        try:
-            return 0, self._get_immediate()
-        except Empty:
-            pass
-
-        remaining_times = []
-        for bucket in values(self.buckets):
-            remaining = bucket.expected_time()
-            if not remaining:
-                try:
-                    # Just put any ready items into the immediate queue.
-                    self.immediate.append(bucket.get_nowait())
-                except Empty:
-                    pass
-                except RateLimitExceeded:
-                    remaining_times.append(bucket.expected_time())
-            else:
-                remaining_times.append(remaining)
-
-        # Try the immediate queue again.
-        try:
-            return 0, self._get_immediate()
-        except Empty:
-            if not remaining_times:
-                # No items in any of the buckets.
-                raise
-
-            # There's items, but have to wait before we can retrieve them,
-            # return the shortest remaining time.
-            return min(remaining_times), None
-
-    def get(self, block=True, timeout=None):
-        """Retrieve the task from the first available bucket.
-
-        Available as in, there is an item in the queue and you can
-        consume tokens from it.
-
-        """
-        tstart = time()
-        get = self._get
-        not_empty = self.not_empty
-
-        with not_empty:
-            while 1:
-                try:
-                    remaining_time, item = get()
-                except Empty:
-                    if not block or (timeout and time() - tstart > timeout):
-                        raise
-                    not_empty.wait(timeout)
-                    continue
-                if remaining_time:
-                    if not block or (timeout and time() - tstart > timeout):
-                        raise Empty()
-                    sleep(min(remaining_time, timeout or 1))
-                else:
-                    return item
-
-    def get_nowait(self):
-        return self.get(block=False)
-
-    def init_with_registry(self):
-        """Initialize with buckets for all the task types in the registry."""
-        for task in self.task_registry:
-            self.add_bucket_for_type(task)
-
-    def refresh(self):
-        """Refresh rate limits for all task types in the registry."""
-        for task in self.task_registry:
-            self.update_bucket_for_type(task)
-
-    def get_bucket_for_type(self, task_name):
-        """Get the bucket for a particular task type."""
-        if task_name not in self.buckets:
-            return self.add_bucket_for_type(task_name)
-        return self.buckets[task_name]
-
-    def _get_queue_for_type(self, task_name):
-        bucket = self.buckets[task_name]
-        if isinstance(bucket, TokenBucketQueue):
-            return bucket.queue
-        return bucket
-
-    def update_bucket_for_type(self, task_name):
-        task_type = self.task_registry[task_name]
-        rate_limit = getattr(task_type, 'rate_limit', None)
-        rate_limit = timeutils.rate(rate_limit)
-        task_queue = FastQueue()
-        if task_name in self.buckets:
-            task_queue = self._get_queue_for_type(task_name)
-        else:
-            task_queue = FastQueue()
-
-        if rate_limit:
-            task_queue = TokenBucketQueue(rate_limit, queue=task_queue)
-
-        self.buckets[task_name] = task_queue
-        return task_queue
-
-    def add_bucket_for_type(self, task_name):
-        """Add a bucket for a task type.
-
-        Will read the tasks rate limit and create a :class:`TokenBucketQueue`
-        if it has one.  If the task doesn't have a rate limit
-        :class:`FastQueue` will be used instead.
-
-        """
-        if task_name not in self.buckets:
-            return self.update_bucket_for_type(task_name)
-
-    def qsize(self):
-        """Get the total size of all the queues."""
-        return sum(bucket.qsize() for bucket in values(self.buckets))
-
-    def empty(self):
-        """Returns :const:`True` if all of the buckets are empty."""
-        return all(bucket.empty() for bucket in values(self.buckets))
-
-    def clear(self):
-        """Delete the data in all of the buckets."""
-        for bucket in values(self.buckets):
-            bucket.clear()
-
-    @property
-    def items(self):
-        """Flattens the data in all of the buckets into a single list."""
-        # for queues with contents [(1, 2), (3, 4), (5, 6), (7, 8)]
-        # zips and flattens to [1, 3, 5, 7, 2, 4, 6, 8]
-        return [x for x in chain.from_iterable(zip_longest(
-            *[bucket.items for bucket in values(self.buckets)])) if x]
-
-
-class FastQueue(Queue):
-    """:class:`Queue.Queue` supporting the interface of
-    :class:`TokenBucketQueue`."""
-
-    def clear(self):
-        return self.queue.clear()
-
-    def expected_time(self, tokens=1):
-        return 0
-
-    def wait(self, block=True):
-        return self.get(block=block)
-
-    @property
-    def items(self):
-        return self.queue
-
-
-class TokenBucketQueue(object):
-    """Queue with rate limited get operations.
-
-    This uses the token bucket algorithm to rate limit the queue on get
-    operations.
-
-    :param fill_rate: The rate in tokens/second that the bucket will
-                      be refilled.
-    :keyword capacity: Maximum number of tokens in the bucket.
-                       Default is 1.
-
-    """
-    RateLimitExceeded = RateLimitExceeded
-
-    def __init__(self, fill_rate, queue=None, capacity=1):
-        self._bucket = TokenBucket(fill_rate, capacity)
-        self.queue = queue
-        if not self.queue:
-            self.queue = Queue()
-
-    def put(self, item, block=True):
-        """Put an item onto the queue."""
-        self.queue.put(item, block=block)
-
-    def put_nowait(self, item):
-        """Put an item into the queue without blocking.
-
-        :raises Queue.Full: If a free slot is not immediately available.
-
-        """
-        return self.put(item, block=False)
-
-    def get(self, block=True):
-        """Remove and return an item from the queue.
-
-        :raises RateLimitExceeded: If a token could not be consumed from the
-                                   token bucket (consuming from the queue
-                                   too fast).
-        :raises Queue.Empty: If an item is not immediately available.
-
-        """
-        get = block and self.queue.get or self.queue.get_nowait
-
-        if not block and not self.items:
-            raise Empty()
-
-        if not self._bucket.can_consume(1):
-            raise RateLimitExceeded()
-
-        return get()
-
-    def get_nowait(self):
-        """Remove and return an item from the queue without blocking.
-
-        :raises RateLimitExceeded: If a token could not be consumed from the
-                                   token bucket (consuming from the queue
-                                   too fast).
-        :raises Queue.Empty: If an item is not immediately available.
-
-        """
-        return self.get(block=False)
-
-    def qsize(self):
-        """Returns the size of the queue."""
-        return self.queue.qsize()
-
-    def empty(self):
-        """Returns :const:`True` if the queue is empty."""
-        return self.queue.empty()
-
-    def clear(self):
-        """Delete all data in the queue."""
-        return self.items.clear()
-
-    def wait(self, block=False):
-        """Wait until a token can be retrieved from the bucket and return
-        the next item."""
-        get = self.get
-        expected_time = self.expected_time
-        while 1:
-            remaining = expected_time()
-            if not remaining:
-                return get(block=block)
-            sleep(remaining)
-
-    def expected_time(self, tokens=1):
-        """Returns the expected time in seconds of when a new token should be
-        available."""
-        if not self.items:
-            return 0
-        return self._bucket.expected_time(tokens)
-
-    @property
-    def items(self):
-        """Underlying data.  Do not modify."""
-        return self.queue.queue
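
The module docstring at the top of this removed file describes the old
design: one token-bucket queue per task type, drained by the mediator into
the pool.  The lock-free replacement keeps only the token bucket and lets
the worker timer retry, as the (also removed) AsyncTaskBucket.cont() did.
A minimal sketch of that pattern, assuming a timer2-style
apply_after(msecs, fun, args) and an illustrative limit_task name::

    from kombu.utils.limits import TokenBucket
    from celery.utils.timeutils import rate

    def limit_task(request, bucket, handle_task, timer, tokens=1):
        # Either dispatch now, or re-check once a token should be
        # available -- no queue, no lock, no mediator thread.
        if bucket.can_consume(tokens):
            handle_task(request)
        else:
            hold = bucket.expected_time(tokens)        # seconds
            timer.apply_after(hold * 1000.0, limit_task,
                              (request, bucket, handle_task, timer, tokens))

    # One bucket per rate-limited task type, e.g. for a '10/s' task:
    bucket = TokenBucket(rate('10/s'), capacity=1)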

celery/worker/components.py  (+8, -19)

@@ -21,7 +21,10 @@ from celery.utils.log import worker_logger as logger
 from celery.utils.timer2 import Schedule
 
 from . import hub
-from .buckets import AsyncTaskBucket, TaskBucket, FastQueue
+
+
+class Object(object):  # XXX
+    pass
 
 
 class Hub(bootsteps.StartStopStep):
@@ -44,26 +47,11 @@ class Queues(bootsteps.Step):
     label = 'Queues (intra)'
     requires = (Hub, )
 
-    def __init__(self, w, **kwargs):
-        w.start_mediator = False
-
     def create(self, w):
-        BucketType = TaskBucket
-        w.start_mediator = True
-        if not w.pool_cls.rlimit_safe:
-            w.disable_rate_limits = True
-        process_task = w.process_task
+        w.process_task = w._process_task
         if w.use_eventloop:
-            BucketType = AsyncTaskBucket
             if w.pool_putlocks and w.pool_cls.uses_semaphore:
-                process_task = w.process_task_sem
-        if w.disable_rate_limits:
-            w.ready_queue = FastQueue()
-            w.ready_queue.put = process_task
-        else:
-            w.ready_queue = BucketType(
-                task_registry=w.app.tasks, callback=process_task, worker=w,
-            )
+                w.process_task = w._process_task_sem
 
 
 class Pool(bootsteps.StartStopStep):
@@ -251,7 +239,7 @@ class Consumer(bootsteps.StartStopStep):
     def create(self, w):
         prefetch_count = w.concurrency * w.prefetch_multiplier
         c = w.consumer = self.instantiate(
-            w.consumer_cls, w.ready_queue,
+            w.consumer_cls, w.process_task,
             hostname=w.hostname,
             send_events=w.send_events,
             init_callback=w.ready_callback,
@@ -262,5 +250,6 @@ class Consumer(bootsteps.StartStopStep):
             controller=w,
             hub=w.hub,
             worker_options=w.options,
+            disable_rate_limits=w.disable_rate_limits,
         )
         return c

+ 32 - 54
celery/worker/consumer.py

@@ -25,6 +25,7 @@ from billiard.exceptions import RestartFreqExceeded
 from kombu.common import QoS, ignore_errors
 from kombu.syn import _detect_environment
 from kombu.utils.encoding import safe_repr
+from kombu.utils.limits import TokenBucket
 
 from celery import bootsteps
 from celery.app import app_or_default
@@ -34,8 +35,8 @@ from celery.task.trace import build_tracer
 from celery.utils.functional import noop
 from celery.utils.log import get_logger
 from celery.utils.text import truncate
-from celery.utils.timer2 import default_timer, to_timestamp
-from celery.utils.timeutils import humanize_seconds, timezone
+from celery.utils.timer2 import default_timer
+from celery.utils.timeutils import humanize_seconds, rate
 
 from . import heartbeat, loops, pidbox
 from .state import task_reserved, maybe_shutdown, revoked
@@ -106,9 +107,6 @@ def dump_body(m, body):
 
 class Consumer(object):
 
-    #: Intra-queue for tasks ready to be handled
-    ready_queue = None
-
     #: Optional callback called the first time the worker
     #: is ready to receive tasks.
     init_callback = None
@@ -139,14 +137,13 @@ class Consumer(object):
         def shutdown(self, parent):
             self.restart(parent, 'Shutdown', 'shutdown')
 
-    def __init__(self, ready_queue,
+    def __init__(self, handle_task,
                  init_callback=noop, hostname=None,
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
-                 worker_options=None, **kwargs):
+                 worker_options=None, disable_rate_limits=False, **kwargs):
         self.app = app_or_default(app)
         self.controller = controller
-        self.ready_queue = ready_queue
         self.init_callback = init_callback
         self.hostname = hostname or socket.gethostname()
         self.pool = pool
@@ -158,8 +155,14 @@ class Consumer(object):
         self._restart_state = restart_state(maxR=5, maxT=1)
 
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self._quick_put = self.ready_queue.put
+        self.handle_task = handle_task
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
+        self.disable_rate_limits = disable_rate_limits
+
+        # This contains a token bucket for each task type by name, used for
+        # rate limits, or None if rate limits are disabled for that task.
+        self.task_buckets = defaultdict(lambda: None)
+        self.reset_rate_limits()
 
         if hub:
             self.amqheartbeat = amqheartbeat
@@ -186,6 +189,25 @@ class Consumer(object):
         )
         self.namespace.apply(self, **dict(worker_options or {}, **kwargs))
 
+    def bucket_for_task(self, type):
+        limit = rate(getattr(type, 'rate_limit', None))
+        return TokenBucket(limit, capacity=1) if limit else None
+
+    def reset_rate_limits(self):
+        self.task_buckets.update(
+            (n, self.bucket_for_task(t)) for n, t in items(self.app.tasks)
+        )
+
+    def _limit_task(self, request, bucket, tokens):
+        if not bucket.can_consume(tokens):
+            hold = bucket.expected_time(tokens)
+            self.timer.apply_after(
+                hold * 1000.0, self._limit_task, (request, bucket, tokens),
+            )
+        else:
+            task_reserved(request)
+            self.handle_task(request)
+
     def start(self):
         ns, loop = self.namespace, self.loop
         while ns.state != CLOSE:
@@ -246,7 +268,6 @@ class Consumer(object):
         # Clear internal queues to get rid of old messages.
         # They can't be acked anyway, as a delivery tag is specific
         # to the current channel.
-        self.ready_queue.clear()
         self.timer.clear()
 
     def connect(self):
@@ -304,54 +325,11 @@ class Consumer(object):
         self.app.amqp.queues.select_remove(queue)
         self.task_consumer.cancel_by_queue(queue)
 
-    def on_task(self, task, task_reserved=task_reserved,
-                to_system_tz=timezone.to_system):
-        """Handle received task.
-
-        If the task has an `eta` we enter it into the ETA schedule,
-        otherwise we move it the ready queue for immediate processing.
-
-        """
-        if task.revoked():
-            return
-
-        if self._does_info:
-            info('Got task from broker: %s', task)
-
-        if self.event_dispatcher.enabled:
-            self.event_dispatcher.send(
-                'task-received',
-                uuid=task.id, name=task.name,
-                args=safe_repr(task.args), kwargs=safe_repr(task.kwargs),
-                retries=task.request_dict.get('retries', 0),
-                eta=task.eta and task.eta.isoformat(),
-                expires=task.expires and task.expires.isoformat(),
-            )
-
-        if task.eta:
-            try:
-                if task.utc:
-                    eta = to_timestamp(to_system_tz(task.eta))
-                else:
-                    eta = to_timestamp(task.eta, timezone.local)
-            except OverflowError as exc:
-                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
-                      task.eta, exc, task.info(safe=True), exc_info=True)
-                task.acknowledge()
-            else:
-                self.qos.increment_eventually()
-                self.timer.apply_at(
-                    eta, self.apply_eta_task, (task, ), priority=6,
-                )
-        else:
-            task_reserved(task)
-            self._quick_put(task)
-
     def apply_eta_task(self, task):
         """Method called by the timer to apply a task with an
         ETA/countdown."""
         task_reserved(task)
-        self._quick_put(task)
+        self.handle_task(task)
         self.qos.decrement_eventually()
 
     def _message_report(self, body, message):
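The `_limit_task` method added above is the heart of the lock-free approach: instead of parking rate-limited tasks in a shared bucket queue drained by a mediator thread, the request is rescheduled on the worker's timer until its bucket has a token. A rough sketch of the same pattern with the collaborators passed in explicitly (`timer` and `handle` here are stand-ins, not the consumer's actual attributes):

    from kombu.utils.limits import TokenBucket

    def limit_task(request, bucket, tokens, timer, handle):
        """Dispatch `request` now if a token is available,
        otherwise retry once the bucket expects one to be."""
        if bucket.can_consume(tokens):
            handle(request)
        else:
            hold = bucket.expected_time(tokens)   # seconds until a token is free
            timer.apply_after(                    # timer intervals are in msecs
                hold * 1000.0, limit_task,
                (request, bucket, tokens, timer, handle),
            )

    # A '10/m' rate limit becomes a bucket refilled at 10/60 tokens per second:
    bucket = TokenBucket(10 / 60.0, capacity=1)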

+ 1 - 5
celery/worker/control.py

@@ -108,11 +108,7 @@ def rate_limit(panel, task_name, rate_limit, **kwargs):
                      task_name, exc_info=True)
         return {'error': 'unknown task'}
 
-    if not hasattr(panel.consumer.ready_queue, 'refresh'):
-        logger.error('Rate limit attempt, but rate limits disabled.')
-        return {'error': 'rate limits disabled'}
-
-    panel.consumer.ready_queue.refresh()
+    panel.consumer.reset_rate_limits()
 
     if not rate_limit:
         logger.info('Rate limits disabled for tasks of type %s', task_name)
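For context, this is the server side of the `rate_limit` remote control command; with the per-task bucket registry living on the consumer it only needs to refresh `task_buckets` rather than poke a rate-limited queue. From a client it is usually driven through the control API, roughly like this (app name, broker URL, and task name are made up for illustration):

    from celery import Celery

    app = Celery('proj', broker='amqp://')   # hypothetical project/broker

    # Ask running workers to apply a new rate limit for one task type;
    # an empty/None rate disables the limit for that task.
    app.control.rate_limit('proj.tasks.add', '10/m')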

+ 1 - 0
celery/worker/job.py

@@ -49,6 +49,7 @@ _does_debug = False
 
 
 def __optimize__():
+    # this is also called by celery.task.trace.setup_worker_optimizations
     global _does_debug
     global _does_info
     _does_debug = logger.isEnabledFor(logging.DEBUG)

+ 0 - 81
celery/worker/mediator.py

@@ -1,81 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    celery.worker.mediator
-    ~~~~~~~~~~~~~~~~~~~~~~
-
-    The mediator is an internal thread that moves tasks
-    from an internal :class:`Queue` to the worker pool.
-
-    This is only used if rate limits are enabled, as it moves
-    messages from the rate limited queue (which holds tasks
-    that are allowed to be processed) to the pool. Disabling
-    rate limits will also disable this machinery,
-    and can improve performance.
-
-"""
-from __future__ import absolute_import
-
-import logging
-
-from celery.app import app_or_default
-from celery.bootsteps import StartStopStep
-from celery.five import Empty
-from celery.utils.threads import bgThread
-from celery.utils.log import get_logger
-
-from . import components
-
-logger = get_logger(__name__)
-
-
-class WorkerComponent(StartStopStep):
-    label = 'Mediator'
-    conditional = True
-    requires = (components.Pool, components.Queues, )
-
-    def __init__(self, w, **kwargs):
-        w.mediator = None
-
-    def include_if(self, w):
-        return w.start_mediator and not w.use_eventloop
-
-    def create(self, w):
-        m = w.mediator = self.instantiate(w.mediator_cls, w.ready_queue,
-                                          app=w.app, callback=w.process_task)
-        return m
-
-
-class Mediator(bgThread):
-    """Mediator thread."""
-
-    #: The task queue, a :class:`~Queue.Queue` instance.
-    ready_queue = None
-
-    #: Callback called when a task is obtained.
-    callback = None
-
-    def __init__(self, ready_queue, callback, app=None, **kw):
-        self.app = app_or_default(app)
-        self.ready_queue = ready_queue
-        self.callback = callback
-        self._does_debug = logger.isEnabledFor(logging.DEBUG)
-        super(Mediator, self).__init__()
-
-    def body(self):
-        try:
-            task = self.ready_queue.get(timeout=1.0)
-        except Empty:
-            return
-
-        if self._does_debug:
-            logger.debug('Mediator: Running callback for task: %s[%s]',
-                         task.name, task.id)
-
-        try:
-            self.callback(task)
-        except Exception as exc:
-            logger.error('Mediator callback raised exception %r',
-                         exc, exc_info=True,
-                         extra={'data': {'id': task.id,
-                                         'name': task.name,
-                                         'hostname': task.hostname}})
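For contrast with the timer-based rescheduling above, the removed mediator is the classic background-consumer pattern: a dedicated thread polls an internal queue and feeds each task to a callback. A condensed, self-contained version of that loop (plain `threading` instead of Celery's `bgThread` helper; names are illustrative):

    import threading
    try:
        from queue import Queue, Empty    # Python 3
    except ImportError:
        from Queue import Queue, Empty    # Python 2

    def mediator_loop(ready_queue, callback, shutdown):
        """Move tasks from `ready_queue` to `callback` until told to stop."""
        while not shutdown.is_set():
            try:
                task = ready_queue.get(timeout=1.0)
            except Empty:
                continue
            callback(task)

    ready_queue, shutdown = Queue(), threading.Event()
    t = threading.Thread(target=mediator_loop,
                         args=(ready_queue, lambda task: None, shutdown))
    t.daemon = True
    t.start()

Dropping this extra thread, and the shared queue it drained, is what the branch is about: the consumer now dispatches directly and defers only via the timer.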

+ 66 - 7
celery/worker/strategy.py

@@ -8,19 +8,78 @@
 """
 from __future__ import absolute_import
 
+import logging
+
+from kombu.utils.encoding import safe_repr
+
+from celery.utils.log import get_logger
+from celery.utils.timer2 import to_timestamp
+from celery.utils.timeutils import timezone
+
+logger = get_logger(__name__)
+
 from .job import Request
+from .state import task_reserved
 
 
-def default(task, app, consumer):
+def default(task, app, consumer,
+            info=logger.info, error=logger.error, task_reserved=task_reserved,
+            to_system_tz=timezone.to_system):
     hostname = consumer.hostname
     eventer = consumer.event_dispatcher
     Req = Request
-    handle = consumer.on_task
     connection_errors = consumer.connection_errors
+    _does_info = logger.isEnabledFor(logging.INFO)
+    events = eventer and eventer.enabled
+    send_event = eventer.send
+    timer_apply_at = consumer.timer.apply_at
+    apply_eta_task = consumer.apply_eta_task
+    rate_limits_enabled = not consumer.disable_rate_limits
+    bucket = consumer.task_buckets[task.name]
+    handle = consumer.handle_task
+    limit_task = consumer._limit_task
+
+    def task_message_handler(message, body, ack, to_timestamp=to_timestamp):
+        req = Req(body, on_ack=ack, app=app, hostname=hostname,
+                  eventer=eventer, task=task,
+                  connection_errors=connection_errors,
+                  delivery_info=message.delivery_info)
+        if req.revoked():
+            return
+
+        if _does_info:
+            info('Got task from broker: %s', req)
+
+        if events:
+            send_event(
+                'task-received',
+                uuid=req.id, name=req.name,
+                args=safe_repr(req.args), kwargs=safe_repr(req.kwargs),
+                retries=req.request_dict.get('retries', 0),
+                eta=req.eta and req.eta.isoformat(),
+                expires=req.expires and req.expires.isoformat(),
+            )
+
+        if req.eta:
+            try:
+                if req.utc:
+                    eta = to_timestamp(to_system_tz(req.eta))
+                else:
+                    eta = to_timestamp(req.eta, timezone.local)
+            except OverflowError as exc:
+                error("Couldn't convert eta %s to timestamp: %r. Task: %r",
+                      req.eta, exc, req.info(safe=True), exc_info=True)
+                req.acknowledge()
+            else:
+                consumer.qos.increment_eventually()
+                timer_apply_at(
+                    eta, apply_eta_task, (req, ), priority=6,
+                )
+        else:
+            if rate_limits_enabled:
+                if bucket:
+                    return limit_task(req, bucket, 1)
+            task_reserved(req)
+            handle(req)
 
-    def task_message_handler(message, body, ack):
-        handle(Req(body, on_ack=ack, app=app, hostname=hostname,
-                   eventer=eventer, task=task,
-                   connection_errors=connection_errors,
-                   delivery_info=message.delivery_info))
     return task_message_handler
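Everything the factory above looks up before defining `task_message_handler` (event flags, the task's bucket, bound timer methods) is per-task-type state, so the per-message closure avoids repeated attribute lookups on the hot path. A stripped-down illustration of that factory pattern, with made-up names:

    import logging

    logger = logging.getLogger(__name__)

    def make_handler(process, dispatcher=None):
        # Resolved once, when the handler for a task type is created ...
        does_info = logger.isEnabledFor(logging.INFO)
        send_event = dispatcher.send if dispatcher and dispatcher.enabled else None

        def on_message(message):
            # ... so the per-message path is just locals and a dispatch.
            if does_info:
                logger.info('received: %r', message)
            if send_event is not None:
                send_event('message-received', id=message.get('id'))
            process(message)

        return on_message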

+ 0 - 8
docs/configuration.rst

@@ -1602,14 +1602,6 @@ CELERYD_CONSUMER
 Name of the consumer class used by the worker.
 Default is :class:`celery.worker.consumer.Consumer`
 
-.. setting:: CELERYD_MEDIATOR
-
-CELERYD_MEDIATOR
-~~~~~~~~~~~~~~~~
-
-Name of the mediator class used by the worker.
-Default is :class:`celery.worker.controllers.Mediator`.
-
 .. setting:: CELERYD_TIMER
 
 CELERYD_TIMER

+ 11 - 23
docs/internals/worker.rst

@@ -19,16 +19,11 @@ with two data structures: the ready queue and the ETA schedule.
 Data structures
 ===============
 
-ready_queue
------------
+timer
+-----
 
-The ready queue is either an instance of :class:`Queue.Queue`, or
-:class:`celery.buckets.TaskBucket`.  The latter if rate limiting is enabled.
-
-eta_schedule
-------------
-
-The ETA schedule is a heap queue sorted by time.
+The timer uses :mod:`heapq` to schedule internal functions.
+It's very efficient and can handle hundreds of thousands of entries.
 
 
 Components
@@ -44,22 +39,15 @@ Receives messages from the broker using `Kombu`_.
 When a message is received it's converted into a
 :class:`celery.worker.job.TaskRequest` object.
 
-Tasks with an ETA are entered into the `eta_schedule`, messages that can
-be immediately processed are moved directly to the `ready_queue`.
+Tasks with an ETA or a rate limit are entered into the `timer`, while
+messages that can be processed immediately are sent to the execution pool.
 
-ScheduleController
-------------------
+Timer
+-----
 
-The schedule controller is running the `eta_schedule`.
-If the scheduled tasks eta has passed it is moved to the `ready_queue`,
-otherwise the thread sleeps until the eta is met (remember that the schedule
-is sorted by time).
-
-Mediator
---------
-The mediator simply moves tasks in the `ready_queue` over to the
-task pool for execution using
-:meth:`celery.worker.job.TaskRequest.execute_using_pool`.
+The timer schedules internal functions, like cleanup and internal monitoring,
+but it also schedules ETA tasks and rate-limited tasks.
+When a scheduled task's ETA has passed, the task is moved to the execution pool.
 
 TaskPool
 --------

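To make the :mod:`heapq` remark above concrete, here is a toy version of heap-ordered scheduling: pending calls are kept in a heap sorted by ETA, so the next due entry is always at the top. This is not the real :mod:`celery.utils.timer2` API; a sequence counter breaks ties so entries never compare by function object:

    import heapq
    import itertools
    import time

    class ToyTimer(object):
        """Keep (eta, seq, fn, args) entries in a heap; run them when due."""

        def __init__(self):
            self._heap = []
            self._seq = itertools.count()

        def apply_after(self, msecs, fn, args=()):
            eta = time.time() + msecs / 1000.0
            heapq.heappush(self._heap, (eta, next(self._seq), fn, args))

        def run_pending(self):
            """Run all due entries; return seconds until the next one (or None)."""
            while self._heap and self._heap[0][0] <= time.time():
                _, _, fn, args = heapq.heappop(self._heap)
                fn(*args)
            return self._heap[0][0] - time.time() if self._heap else None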
+ 0 - 17
docs/userguide/optimizing.rst

@@ -151,20 +151,3 @@ You can enable this behavior by using the following configuration options:
 
     CELERY_ACKS_LATE = True
     CELERYD_PREFETCH_MULTIPLIER = 1
-
-.. optimizing-rate-limits:
-
-Rate Limits
------------
-
-The system responsible for enforcing rate limits introduces some overhead,
-so if you're not using rate limits it may be a good idea to
-disable them completely.  This will disable one thread, and it won't
-spend as many CPU cycles when the queue is inactive.
-
-Set the :setting:`CELERY_DISABLE_RATE_LIMITS` setting to disable
-the rate limit subsystem:
-
-.. code-block:: python
-
-    CELERY_DISABLE_RATE_LIMITS = True

+ 2 - 4
funtests/benchmarks/req.py

@@ -19,9 +19,8 @@ P = TaskPool()
 hostname = socket.gethostname()
 task = {'task': T.name, 'args': (), 'kwargs': {}, 'id': tid, 'flags': 0}
 app = current_app._get_current_object()
-ready_queue = Queue()
 
-def on_put(req):
+def on_task(req):
     req.execute_using_pool(P)
 
 def on_ack(*a): pass
@@ -29,8 +28,7 @@ def on_ack(*a): pass
 
 m = Message(None, {}, {}, task)
 
-ready_queue.put = on_put
-x = Consumer(ready_queue, hostname=hostname, app=app)
+x = Consumer(on_task, hostname=hostname, app=app)
 x.update_strategies()
 name = T.name
 ts = time()

+ 2 - 4
funtests/benchmarks/trace.py

@@ -19,9 +19,8 @@ P = TaskPool()
 hostname = socket.gethostname()
 task = {'task': T.name, 'args': (), 'kwargs': {}, 'id': tid, 'flags': 0}
 app = current_app._get_current_object()
-ready_queue = Queue()
 
-def on_put(req):
+def on_task(req):
     req.execute_using_pool(P)
 
 def on_ack(*a): pass
@@ -29,8 +28,7 @@ def on_ack(*a): pass
 
 m = Message(None, {}, {}, task)
 
-ready_queue.put = on_put
-x = Consumer(ready_queue, hostname=hostname, app=app)
+x = Consumer(on_task, hostname=hostname, app=app)
 x.update_strategies()
 name = T.name
 ts = time()