Преглед на файлове

100% coverage for celery.worker.strategy

Ask Solem преди 12 години
родител
ревизия
8d7f07a7b3
променени са 4 файла, в които са добавени 167 реда и са изтрити 3 реда
  1. 2 2
      celery/app/task.py
  2. 2 1
      celery/datastructures.py
  3. 162 0
      celery/tests/worker/test_strategy.py
  4. 1 0
      celery/worker/strategy.py

+ 2 - 2
celery/app/task.py

@@ -364,8 +364,8 @@ class Task(object):
         """The body of the task executed by workers."""
         raise NotImplementedError('Tasks must define the run method.')
 
-    def start_strategy(self, app, consumer):
-        return instantiate(self.Strategy, self, app, consumer)
+    def start_strategy(self, app, consumer, **kwargs):
+        return instantiate(self.Strategy, self, app, consumer, **kwargs)
 
     def delay(self, *args, **kwargs):
         """Star argument version of :meth:`apply_async`.

+ 2 - 1
celery/datastructures.py

@@ -566,7 +566,7 @@ class LimitedSet(object):
         self._data.clear()
         self._heap[:] = []
 
-    def pop_value(self, value):
+    def discard(self, value):
         """Remove membership by finding value."""
         try:
             itime = self._data[value]
@@ -577,6 +577,7 @@ class LimitedSet(object):
         except ValueError:
             pass
         self._data.pop(value, None)
+    pop_value = discard  # XXX compat
 
     def _expire_item(self):
         """Hunt down and remove an expired item."""

+ 162 - 0
celery/tests/worker/test_strategy.py

@@ -0,0 +1,162 @@
+from __future__ import absolute_import
+
+from collections import defaultdict
+from contextlib import contextmanager
+from datetime import timedelta
+from mock import Mock, patch
+
+from kombu.utils.limits import TokenBucket
+
+from celery import Celery
+from celery.worker import state
+from celery.utils.timeutils import rate
+
+from celery.tests.utils import AppCase
+
+
+def body_from_sig(app, sig, utc=True):
+    sig._freeze()
+    callbacks = sig.options.pop('link', None)
+    errbacks = sig.options.pop('link_error', None)
+    countdown = sig.options.pop('countdown', None)
+    if countdown:
+        sig.options['eta'] = app.now() + timedelta(seconds=countdown)
+    eta = sig.options.pop('eta', None)
+    eta = eta.isoformat() if eta else None
+    return {
+        'task': sig.task,
+        'id': sig.id,
+        'args': sig.args,
+        'kwargs': sig.kwargs,
+        'callbacks': [dict(s) for s in callbacks] if callbacks else None,
+        'errbacks': [dict(s) for s in errbacks] if errbacks else None,
+        'eta': eta,
+        'utc': utc,
+    }
+
+
+class test_default_strategy(AppCase):
+
+    class Context(object):
+
+        def __init__(self, sig, s, reserved, consumer, message, body):
+            self.sig = sig
+            self.s = s
+            self.reserved = reserved
+            self.consumer = consumer
+            self.message = message
+            self.body = body
+
+        def __call__(self, **kwargs):
+            return self.s(self.message, self.body, self.message.ack, **kwargs)
+
+        def was_reserved(self):
+            return self.reserved.called
+
+        def was_rate_limited(self):
+            assert not self.was_reserved()
+            return self.consumer._limit_task.called
+
+        def was_scheduled(self):
+            assert not self.was_reserved()
+            assert not self.was_rate_limited()
+            return self.consumer.timer.apply_at.called
+
+        def event_sent(self):
+            return self.consumer.event_dispatcher.send.call_args
+
+        def get_request(self):
+            if self.was_reserved():
+                return self.reserved.call_args[0][0]
+            if self.was_rate_limited():
+                return self.consumer._limit_task.call_args[0][0]
+            if self.was_scheduled():
+                return self.consumer.timer.apply_at.call_args[0][0]
+            raise ValueError('request not handled')
+
+    def setup(self):
+        self.c = Celery(set_as_current=False)
+
+        @self.c.task()
+        def add(x, y):
+            return x + y
+
+        self.add = add
+
+    @contextmanager
+    def _context(self, sig,
+                 rate_limits=True, events=True, utc=True, limit=None):
+        self.assertTrue(sig.type.Strategy)
+
+        reserved = Mock()
+        consumer = Mock()
+        consumer.task_buckets = defaultdict(lambda: None)
+        if limit:
+            bucket = TokenBucket(rate(limit), capacity=1)
+            consumer.task_buckets[sig.task] = bucket
+        consumer.disable_rate_limits = not rate_limits
+        consumer.event_dispatcher.enabled = events
+        s = sig.type.start_strategy(self.c, consumer, task_reserved=reserved)
+        self.assertTrue(s)
+
+        message = Mock()
+        body = body_from_sig(self.c, sig, utc=utc)
+
+        yield self.Context(sig, s, reserved, consumer, message, body)
+
+    def test_when_logging_disabled(self):
+        with patch('celery.worker.strategy.logger') as logger:
+            logger.isEnabledFor.return_value = False
+            with self._context(self.add.s(2, 2)) as C:
+                C()
+                self.assertFalse(logger.info.called)
+
+    def test_task_strategy(self):
+        with self._context(self.add.s(2, 2)) as C:
+            C()
+            self.assertTrue(C.was_reserved())
+            req = C.get_request()
+            C.consumer.handle_task.assert_called_with(req)
+            self.assertTrue(C.event_sent())
+
+    def test_when_events_disabled(self):
+        with self._context(self.add.s(2, 2), events=False) as C:
+            C()
+            self.assertTrue(C.was_reserved())
+            self.assertFalse(C.event_sent())
+
+    def test_eta_task(self):
+        with self._context(self.add.s(2, 2).set(countdown=10)) as C:
+            C()
+            self.assertTrue(C.was_scheduled())
+            C.consumer.qos.increment_eventually.assert_called_with()
+
+    def test_eta_task_utc_disabled(self):
+        with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C:
+            C()
+            self.assertTrue(C.was_scheduled())
+            C.consumer.qos.increment_eventually.assert_called_with()
+
+    def test_when_rate_limited(self):
+        task = self.add.s(2, 2)
+        with self._context(task, rate_limits=True, limit='1/m') as C:
+            C()
+            self.assertTrue(C.was_rate_limited())
+
+    def test_when_rate_limited__limits_disabled(self):
+        task = self.add.s(2, 2)
+        with self._context(task, rate_limits=False, limit='1/m') as C:
+            C()
+            self.assertTrue(C.was_reserved())
+
+    def test_when_revoked(self):
+        task = self.add.s(2, 2)
+        task._freeze()
+        state.revoked.add(task.id)
+        try:
+            with self._context(task) as C:
+                C()
+                with self.assertRaises(ValueError):
+                    C.get_request()
+        finally:
+            state.revoked.discard(task.id)

+ 1 - 0
celery/worker/strategy.py

@@ -76,6 +76,7 @@ def default(task, app, consumer,
                     eta, apply_eta_task, (req, ), priority=6,
                 )
         else:
+            print('BUCKET: %r' % (bucket, ))
             if rate_limits_enabled:
                 if bucket:
                     return limit_task(req, bucket, 1)