Browse Source

Tests passing

Ask Solem 11 years ago
parent
commit
6676276426

+ 1 - 1
celery/tests/backends/test_base.py

@@ -79,7 +79,7 @@ class test_exception_pickle(Case):
     def test_oldstyle(self):
         if Oldstyle is None:
             raise SkipTest('py3k does not support old style classes')
-        self.assertIsNone(fnpe(Oldstyle()))
+        self.assertTrue(fnpe(Oldstyle()))
 
     def test_BaseException(self):
         self.assertIsNone(fnpe(Exception()))

+ 1 - 1
celery/tests/utilities/test_timer2.py

@@ -30,7 +30,7 @@ class test_Entry(Case):
         self.assertTrue(tref.cancelled)
 
     def test_repr(self):
-        tref = tiemr2.Entry(lambda x: x (1, ), {})
+        tref = timer2.Entry(lambda x: x (1, ), {})
         self.assertTrue(repr(tref))
 
 

+ 1 - 2
celery/tests/utilities/test_timeutils.py

@@ -248,8 +248,7 @@ class test_ffwd(Case):
 
     def test_radd_with_unknown_gives_NotImplemented(self):
         x = ffwd(year=2012)
-        with self.assertRaises(TypeError):
-            x.__radd__(object())
+        self.assertEqual(x.__radd__(object()), NotImplemented)
 
 
 class test_utcoffset(Case):

+ 2 - 2
celery/tests/worker/test_consumer.py

@@ -28,7 +28,7 @@ class test_Consumer(AppCase):
 
     def get_consumer(self, no_hub=False, **kwargs):
         consumer = Consumer(
-            handle_task=Mock(),
+            on_task=Mock(),
             init_callback=Mock(),
             pool=Mock(),
             app=self.app,
@@ -91,7 +91,7 @@ class test_Consumer(AppCase):
             c._limit_task(request, bucket, 3)
             bucket.can_consume.assert_called_with(3)
             reserved.assert_called_with(request)
-            c.handle_task.assert_called_with(request)
+            c.on_task.assert_called_with(request)
 
         with patch('celery.worker.consumer.task_reserved') as reserved:
             bucket.can_consume.return_value = False

+ 23 - 12
celery/tests/worker/test_loops.py

@@ -7,6 +7,7 @@ from mock import Mock
 from celery.exceptions import InvalidTaskError, SystemTerminate
 from celery.five import Empty
 from celery.worker import state
+from celery.worker.consumer import Consumer
 from celery.worker.loops import asynloop, synloop, CLOSE, READ, WRITE, ERR
 
 from celery.tests.utils import AppCase, body_from_sig
@@ -19,29 +20,22 @@ class X(object):
             self.obj,
             self.connection,
             self.consumer,
-            self.strategies,
             self.blueprint,
             self.hub,
             self.qos,
             self.heartbeat,
-            self.handle_unknown_message,
-            self.handle_unknown_task,
-            self.handle_invalid_task,
             self.clock,
         ) = self.args = [Mock(name='obj'),
                          Mock(name='connection'),
                          Mock(name='consumer'),
-                         {},
                          Mock(name='blueprint'),
                          Mock(name='Hub'),
                          Mock(name='qos'),
                          heartbeat,
-                         Mock(name='handle_unknown_message'),
-                         Mock(name='handle_unknown_task'),
-                         Mock(name='handle_invalid_task'),
                          Mock(name='clock')]
         self.connection.supports_heartbeats = True
         self.consumer.callbacks = []
+        self.obj.strategies = {}
         self.connection.connection_errors = (socket.error, )
         #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
         #self.Hub.__exit__ = Mock(name='Hub.__exit__')
@@ -51,6 +45,23 @@ class X(object):
         self.hub.writers = {}
         self.hub.fire_timers.return_value = 1.7
         self.Hub = self.hub
+        # need this for create_task_handler
+        _consumer = Consumer(Mock(), timer=Mock())
+        self.obj.create_task_handler = _consumer.create_task_handler
+        self.on_unknown_message = self.obj.on_unknown_message = Mock(
+            name='on_unknown_message',
+        )
+        _consumer.on_unknown_message = self.on_unknown_message
+        self.on_unknown_task = self.obj.on_unknown_task = Mock(
+            name='on_unknown_task',
+        )
+        _consumer.on_unknown_task = self.on_unknown_task
+        self.on_invalid_task = self.obj.on_invalid_task = Mock(
+            name='on_invalid_task',
+        )
+        _consumer.on_invalid_task = self.on_invalid_task
+        _consumer.strategies = self.obj.strategies
+
 
     def timeout_then_error(self, mock):
 
@@ -107,7 +118,7 @@ class test_asynloop(AppCase):
         x, on_task = get_task_callback(**kwargs)
         body = body_from_sig(self.app, sig)
         message = Mock()
-        strategy = x.strategies[sig.task] = Mock()
+        strategy = x.obj.strategies[sig.task] = Mock()
         return x, on_task, body, message, strategy
 
     def test_on_task_received(self):
@@ -127,19 +138,19 @@ class test_asynloop(AppCase):
         x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
         body.pop('task')
         on_task(body, msg)
-        x.handle_unknown_message.assert_called_with(body, msg)
+        x.on_unknown_message.assert_called_with(body, msg)
 
     def test_on_task_not_registered(self):
         x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
         exc = strategy.side_effect = KeyError(self.add.name)
         on_task(body, msg)
-        x.handle_unknown_task.assert_called_with(body, msg, exc)
+        x.on_unknown_task.assert_called_with(body, msg, exc)
 
     def test_on_task_InvalidTaskError(self):
         x, on_task, body, msg, strategy = self.task_context(self.add.s(2, 2))
         exc = strategy.side_effect = InvalidTaskError()
         on_task(body, msg)
-        x.handle_invalid_task.assert_called_with(body, msg, exc)
+        x.on_invalid_task.assert_called_with(body, msg, exc)
 
     def test_should_terminate(self):
         x = X()

+ 1 - 1
celery/tests/worker/test_strategy.py

@@ -94,7 +94,7 @@ class test_default_strategy(AppCase):
             C()
             self.assertTrue(C.was_reserved())
             req = C.get_request()
-            C.consumer.handle_task.assert_called_with(req)
+            C.consumer.on_task.assert_called_with(req)
             self.assertTrue(C.event_sent())
 
     def test_when_events_disabled(self):

+ 4 - 18
celery/utils/serialization.py

@@ -8,7 +8,8 @@
 """
 from __future__ import absolute_import
 
-import inspect
+from inspect import getmro
+from itertools import takewhile
 
 try:
     import cPickle as pickle
@@ -21,7 +22,7 @@ from .encoding import safe_repr
 #: List of base classes we probably don't want to reduce to.
 try:
     unwanted_base_classes = (StandardError, Exception, BaseException, object)
-except NameError:
+except NameError:  # pragma: no cover
     unwanted_base_classes = (Exception, BaseException, object)  # py3k
 
 
@@ -58,22 +59,7 @@ find_nearest_pickleable_exception = find_pickleable_exception  # XXX compat
 
 
 def itermro(cls, stop):
-    getmro_ = getattr(cls, 'mro', None)
-
-    # old-style classes doesn't have mro()
-    if not getmro_:  # pragma: no cover
-        # all Py2.4 exceptions has a baseclass.
-        if not getattr(cls, '__bases__', ()):
-            return
-        # Use inspect.getmro() to traverse bases instead.
-        getmro_ = lambda: inspect.getmro(cls)
-
-    for supercls in getmro_():
-        if supercls in stop:
-            # only BaseException and object, from here on down,
-            # we don't care about these.
-            return
-        yield supercls
+    return takewhile(lambda sup: sup not in stop, getmro(cls))
 
 
 def create_exception_cls(name, module, parent=None):

+ 14 - 12
celery/worker/consumer.py

@@ -33,6 +33,7 @@ from kombu.utils.limits import TokenBucket
 from celery import bootsteps
 from celery.app import app_or_default
 from celery.canvas import subtask
+from celery.exceptions import InvalidTaskError
 from celery.five import items, values
 from celery.task.trace import build_tracer
 from celery.utils.functional import noop
@@ -153,7 +154,7 @@ class Consumer(object):
         def shutdown(self, parent):
             self.send_all(parent, 'shutdown')
 
-    def __init__(self, handle_task,
+    def __init__(self, on_task,
                  init_callback=noop, hostname=None,
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
@@ -172,7 +173,7 @@ class Consumer(object):
         self._restart_state = restart_state(maxR=5, maxT=1)
 
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self.handle_task = handle_task
+        self.on_task = on_task
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
         self.disable_rate_limits = disable_rate_limits
 
@@ -223,7 +224,7 @@ class Consumer(object):
             )
         else:
             task_reserved(request)
-            self.handle_task(request)
+            self.on_task(request)
 
     def start(self):
         blueprint, loop = self.blueprint, self.loop
@@ -264,9 +265,7 @@ class Consumer(object):
 
     def loop_args(self):
         return (self, self.connection, self.task_consumer,
-                self.strategies, self.blueprint, self.hub, self.qos,
-                self.amqheartbeat, self.handle_unknown_message,
-                self.handle_unknown_task, self.handle_invalid_task,
+                self.blueprint, self.hub, self.qos, self.amqheartbeat,
                 self.app.clock, self.amqheartbeat_rate)
 
     def on_poll_init(self, hub):
@@ -360,7 +359,7 @@ class Consumer(object):
         """Method called by the timer to apply a task with an
         ETA/countdown."""
         task_reserved(task)
-        self.handle_task(task)
+        self.on_task(task)
         self.qos.decrement_eventually()
 
     def _message_report(self, body, message):
@@ -369,15 +368,15 @@ class Consumer(object):
                                      safe_repr(message.content_encoding),
                                      safe_repr(message.delivery_info))
 
-    def handle_unknown_message(self, body, message):
+    def on_unknown_message(self, body, message):
         warn(UNKNOWN_FORMAT, self._message_report(body, message))
         message.reject_log_error(logger, self.connection_errors)
 
-    def handle_unknown_task(self, body, message, exc):
+    def on_unknown_task(self, body, message, exc):
         error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         message.reject_log_error(logger, self.connection_errors)
 
-    def handle_invalid_task(self, body, message, exc):
+    def on_invalid_task(self, body, message, exc):
         error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
         message.reject_log_error(logger, self.connection_errors)
 
@@ -387,8 +386,11 @@ class Consumer(object):
             self.strategies[name] = task.start_strategy(self.app, self)
             task.__trace__ = build_tracer(name, task, loader, self.hostname)
 
-    def create_task_handler(self, strategies, callbacks,
-            on_unknown_message, on_unknown_task, on_invalid_task):
+    def create_task_handler(self, callbacks):
+        strategies = self.strategies
+        on_unknown_message = self.on_unknown_message
+        on_unknown_task = self.on_unknown_task
+        on_invalid_task = self.on_invalid_task
 
         def on_task_received(body, message):
             if callbacks:

+ 7 - 13
celery/worker/loops.py

@@ -15,7 +15,7 @@ from types import GeneratorType as generator
 from kombu.utils.eventio import READ, WRITE, ERR
 
 from celery.bootsteps import CLOSE
-from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.exceptions import SystemTerminate
 from celery.five import Empty
 from celery.utils.log import get_logger
 
@@ -25,9 +25,8 @@ logger = get_logger(__name__)
 error = logger.error
 
 
-def asynloop(obj, connection, consumer, strategies, blueprint, hub, qos,
-             heartbeat, handle_unknown_message, handle_unknown_task,
-             handle_invalid_task, clock, hbrate=2.0,
+def asynloop(obj, connection, consumer, blueprint, hub, qos,
+             heartbeat, clock, hbrate=2.0,
              sleep=sleep, min=min, Empty=Empty):
     """Non-blocking eventloop consuming messages until connection is lost,
     or shutdown is requested."""
@@ -51,9 +50,7 @@ def asynloop(obj, connection, consumer, strategies, blueprint, hub, qos,
     errors = connection.connection_errors
     hub_add, hub_remove = hub.add, hub.remove
 
-    on_task_received = obj.create_task_handler(
-        strategies, on_task_callbacks, handle_unknown_message,
-        handle_unknown_task, handle_invalid_task)
+    on_task_received = obj.create_task_handler(on_task_callbacks)
 
     if heartbeat and connection.supports_heartbeats:
         hub.timer.apply_interval(
@@ -144,14 +141,11 @@ def asynloop(obj, connection, consumer, strategies, blueprint, hub, qos,
             )
 
 
-def synloop(obj, connection, consumer, strategies, blueprint, hub, qos,
-            heartbeat, handle_unknown_message, handle_unknown_task,
-            handle_invalid_task, clock, hbrate=2.0, **kwargs):
+def synloop(obj, connection, consumer, blueprint, hub, qos,
+            heartbeat, clock, hbrate=2.0, **kwargs):
     """Fallback blocking eventloop for transports that doesn't support AIO."""
 
-    on_task_received = obj.create_task_handler(
-        strategies, [], handle_unknown_message,
-        handle_unknown_task, handle_invalid_task)
+    on_task_received = obj.create_task_handler([])
     consumer.register_callback(on_task_received)
     consumer.consume()
 

+ 1 - 1
celery/worker/strategy.py

@@ -36,7 +36,7 @@ def default(task, app, consumer,
     apply_eta_task = consumer.apply_eta_task
     rate_limits_enabled = not consumer.disable_rate_limits
     bucket = consumer.task_buckets[task.name]
-    handle = consumer.handle_task
+    handle = consumer.on_task
     limit_task = consumer._limit_task
 
     def task_message_handler(message, body, ack, to_timestamp=to_timestamp):