Browse Source

100% Coverage for celery.events, events.snapshot, and events.state

Ask Solem 14 years ago
parent
commit
544f2ee87c

+ 9 - 7
celery/events/__init__.py

@@ -100,14 +100,14 @@ class EventDispatcher(object):
             except Exception, exc:
                 if not self.buffer_while_offline:
                     raise
-                self._outbound_buffer.append((event, exc))
+                self._outbound_buffer.append((type, fields, exc))
         finally:
             self._lock.release()
 
     def flush(self):
         while self._outbound_buffer:
-            event, _ = self._outbound_buffer.popleft()
-            self.publisher.send(event)
+            type, fields, _ = self._outbound_buffer.popleft()
+            self.send(type, **fields)
 
     def close(self):
         """Close the event dispatcher."""
@@ -129,13 +129,13 @@ class EventReceiver(object):
     handlers = {}
 
     def __init__(self, connection, handlers=None, routing_key="#",
-            app=None):
+            node_id=None, app=None):
         self.app = app_or_default(app)
         self.connection = connection
         if handlers is not None:
             self.handlers = handlers
         self.routing_key = routing_key
-        self.node_id = gen_unique_id()
+        self.node_id = node_id or gen_unique_id()
         self.queue = Queue("%s.%s" % ("celeryev", self.node_id),
                            exchange=event_exchange,
                            routing_key=self.routing_key,
@@ -195,7 +195,7 @@ class EventReceiver(object):
 
     def drain_events(self, limit=None, timeout=None):
         for iteration in count(0):
-            if limit and iteration > limit:
+            if limit and iteration >= limit:
                 break
             try:
                 self.connection.drain_events(timeout=timeout)
@@ -215,10 +215,12 @@ class Events(object):
     def __init__(self, app=None):
         self.app = app
 
-    def Receiver(self, connection, handlers=None, routing_key="#"):
+    def Receiver(self, connection, handlers=None, routing_key="#",
+            node_id=None):
         return EventReceiver(connection,
                              handlers=handlers,
                              routing_key=routing_key,
+                             node_id=node_id,
                              app=self.app)
 
     def Dispatcher(self, connection=None, hostname=None, enabled=True,

+ 14 - 14
celery/events/snapshot.py

@@ -8,27 +8,30 @@ from celery.utils.timeutils import rate
 
 
 class Polaroid(object):
+    timer = timer2
     shutter_signal = Signal(providing_args=("state", ))
     cleanup_signal = Signal()
     clear_after = False
 
     _tref = None
+    _ctref = None
 
     def __init__(self, state, freq=1.0, maxrate=None,
-            cleanup_freq=3600.0, logger=None, app=None):
+            cleanup_freq=3600.0, logger=None, timer=None, app=None):
         self.app = app_or_default(app)
         self.state = state
         self.freq = freq
         self.cleanup_freq = cleanup_freq
+        self.timer = timer or self.timer
         self.logger = logger or \
                 self.app.log.get_default_logger(name="celery.cam")
         self.maxrate = maxrate and TokenBucket(rate(maxrate))
 
     def install(self):
-        self._tref = timer2.apply_interval(self.freq * 1000.0,
-                                           self.capture)
-        self._ctref = timer2.apply_interval(self.cleanup_freq * 1000.0,
-                                            self.cleanup)
+        self._tref = self.timer.apply_interval(self.freq * 1000.0,
+                                               self.capture)
+        self._ctref = self.timer.apply_interval(self.cleanup_freq * 1000.0,
+                                                self.cleanup)
 
     def on_shutter(self, state):
         pass
@@ -37,17 +40,13 @@ class Polaroid(object):
         pass
 
     def cleanup(self):
-        self.debug("Cleanup: Running...")
+        self.logger.debug("Cleanup: Running...")
         self.cleanup_signal.send(None)
         self.on_cleanup()
 
-    def debug(self, msg):
-        if self.logger:
-            self.logger.debug(msg)
-
     def shutter(self):
         if self.maxrate is None or self.maxrate.can_consume():
-            self.debug("Shutter: %s" % (self.state, ))
+            self.logger.debug("Shutter: %s" % (self.state, ))
             self.shutter_signal.send(self.state)
             self.on_shutter(self.state)
 
@@ -56,7 +55,7 @@ class Polaroid(object):
 
     def cancel(self):
         if self._tref:
-            self._tref()
+            self._tref()  # flush all received events.
             self._tref.cancel()
         if self._ctref:
             self._ctref.cancel()
@@ -70,7 +69,7 @@ class Polaroid(object):
 
 
 def evcam(camera, freq=1.0, maxrate=None, loglevel=0,
-        logfile=None, app=None):
+        logfile=None, timer=None, app=None):
     app = app_or_default(app)
     if not isinstance(loglevel, int):
         loglevel = LOG_LEVELS[loglevel.upper()]
@@ -82,7 +81,8 @@ def evcam(camera, freq=1.0, maxrate=None, loglevel=0,
             camera, freq))
     state = app.events.State()
     cam = instantiate(camera, state, app=app,
-                      freq=freq, maxrate=maxrate, logger=logger)
+                      freq=freq, maxrate=maxrate, logger=logger,
+                      timer=timer)
     cam.install()
     conn = app.broker_connection()
     recv = app.events.Receiver(conn, handlers={"*": state.event})

+ 1 - 1
celery/events/state.py

@@ -109,7 +109,7 @@ class Task(Element):
     def merge(self, state, timestamp, fields):
         keep = self.merge_rules.get(state)
         if keep is not None:
-            fields = dict((key, fields[key]) for key in keep)
+            fields = dict((key, fields.get(key)) for key in keep)
             super(Task, self).update(fields)
 
     def on_sent(self, timestamp=None, **fields):

+ 119 - 5
celery/tests/test_events.py

@@ -1,14 +1,19 @@
-from celery.tests.utils import unittest
+import socket
 
 from celery import events
+from celery.app import app_or_default
+from celery.tests.utils import unittest
 
 
 class MockProducer(object):
+    raise_on_publish = False
 
     def __init__(self, *args, **kwargs):
         self.sent = []
 
     def publish(self, msg, *args, **kwargs):
+        if self.raise_on_publish:
+            raise KeyError()
         self.sent.append(msg)
 
     def close(self):
@@ -31,17 +36,71 @@ class TestEvent(unittest.TestCase):
 
 class TestEventDispatcher(unittest.TestCase):
 
+    def setUp(self):
+        self.app = app_or_default()
+
     def test_send(self):
         producer = MockProducer()
-        eventer = events.EventDispatcher(object(), enabled=False)
+        eventer = self.app.events.Dispatcher(object(), enabled=False)
         eventer.publisher = producer
         eventer.enabled = True
         eventer.send("World War II", ended=True)
         self.assertTrue(producer.has_event("World War II"))
+        eventer.enabled = False
+        eventer.send("World War III")
+        self.assertFalse(producer.has_event("World War III"))
+
+        evs = ("Event 1", "Event 2", "Event 3")
+        eventer.enabled = True
+        eventer.publisher.raise_on_publish = True
+        eventer.buffer_while_offline = False
+        self.assertRaises(KeyError, eventer.send, "Event X")
+        eventer.buffer_while_offline = True
+        for ev in evs:
+            eventer.send(ev)
+        eventer.publisher.raise_on_publish = False
+        eventer.flush()
+        for ev in evs:
+            self.assertTrue(producer.has_event(ev))
+
+    def test_enabled_disable(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            dispatcher = self.app.events.Dispatcher(connection,
+                                                    enabled=True)
+            dispatcher2 = self.app.events.Dispatcher(connection,
+                                                     enabled=True,
+                                                      channel=channel)
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher.channel)
+            self.assertEqual(dispatcher.publisher.serializer,
+                            self.app.conf.CELERY_EVENT_SERIALIZER)
+
+
+            created_channel = dispatcher.publisher.channel
+            dispatcher.disable()
+            dispatcher.disable()  # Disable with no active publisher
+            dispatcher2.disable()
+            self.assertFalse(dispatcher.enabled)
+            self.assertIsNone(dispatcher.publisher)
+            self.assertTrue(created_channel.closed)
+            self.assertFalse(dispatcher2.channel.closed,
+                             "does not close manually provided channel")
+
+            dispatcher.enable()
+            self.assertTrue(dispatcher.enabled)
+            self.assertTrue(dispatcher.publisher)
+        finally:
+            channel.close()
+            connection.close()
 
 
 class TestEventReceiver(unittest.TestCase):
 
+    def setUp(self):
+        self.app = app_or_default()
+
     def test_process(self):
 
         message = {"type": "world-war"}
@@ -51,8 +110,10 @@ class TestEventReceiver(unittest.TestCase):
         def my_handler(event):
             got_event[0] = True
 
-        r = events.EventReceiver(object(), handlers={
-                                    "world-war": my_handler})
+        r = events.EventReceiver(object(),
+                                 handlers={"world-war": my_handler},
+                                 node_id="celery.tests",
+                                 )
         r._receive(message, object())
         self.assertTrue(got_event[0])
 
@@ -65,10 +126,63 @@ class TestEventReceiver(unittest.TestCase):
         def my_handler(event):
             got_event[0] = True
 
-        r = events.EventReceiver(object())
+        r = events.EventReceiver(object(), node_id="celery.tests")
         events.EventReceiver.handlers["*"] = my_handler
         try:
             r._receive(message, object())
             self.assertTrue(got_event[0])
         finally:
             events.EventReceiver.handlers = {}
+
+    def test_itercapture(self):
+        connection = self.app.broker_connection()
+        try:
+            r = self.app.events.Receiver(connection, node_id="celery.tests")
+            it = r.itercapture(timeout=0.0001)
+            consumer = it.next()
+            self.assertTrue(consumer.queues)
+            self.assertEqual(consumer.callbacks[0], r._receive)
+
+            self.assertRaises(socket.timeout, it.next)
+
+            self.assertRaises(socket.timeout,
+                              r.capture, timeout=0.00001)
+        finally:
+            connection.close()
+
+    def test_itercapture_limit(self):
+        connection = self.app.broker_connection()
+        channel = connection.channel()
+        try:
+            events_received = [0]
+            def handler(event):
+                events_received[0] += 1
+
+            producer = self.app.events.Dispatcher(connection,
+                                                  enabled=True,
+                                                  channel=channel)
+            r = self.app.events.Receiver(connection,
+                                         handlers={"*": handler},
+                                         node_id="celery.tests")
+            evs = ["ev1", "ev2", "ev3", "ev4", "ev5"]
+            for ev in evs:
+                producer.send(ev)
+            it = r.itercapture(limit=4, wakeup=True)
+            consumer = it.next()
+            list(it)
+            self.assertEqual(events_received[0], 4)
+        finally:
+            channel.close()
+            connection.close()
+
+
+
+class test_misc(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+
+
+    def test_State(self):
+        state = self.app.events.State()
+        self.assertDictEqual(dict(state.workers), {})

+ 122 - 0
celery/tests/test_events_snapshot.py

@@ -0,0 +1,122 @@
+from celery.app import app_or_default
+from celery.events import Events
+from celery.events.snapshot import Polaroid, evcam
+from celery.tests.utils import unittest
+
+
+class TRef(object):
+    active = True
+    called = False
+
+    def __call__(self):
+        self.called = True
+
+    def cancel(self):
+        self.active = False
+
+
+class MockTimer(object):
+    installed = []
+
+    def apply_interval(self, msecs, fun, *args, **kwargs):
+        self.installed.append(fun)
+        return TRef()
+timer = MockTimer()
+
+
+class test_Polaroid(unittest.TestCase):
+
+    def setUp(self):
+        self.app = app_or_default()
+        self.state = self.app.events.State()
+
+    def test_constructor(self):
+        x = Polaroid(self.state, app=self.app)
+        self.assertIs(x.app, self.app)
+        self.assertIs(x.state, self.state)
+        self.assertTrue(x.freq)
+        self.assertTrue(x.cleanup_freq)
+        self.assertTrue(x.logger)
+        self.assertFalse(x.maxrate)
+
+    def test_install_timers(self):
+        x = Polaroid(self.state, app=self.app)
+        x.timer = timer
+        x.__exit__()
+        x.__enter__()
+        self.assertIn(x.capture, MockTimer.installed)
+        self.assertIn(x.cleanup, MockTimer.installed)
+        self.assertTrue(x._tref.active)
+        self.assertTrue(x._ctref.active)
+        x.__exit__()
+        self.assertFalse(x._tref.active)
+        self.assertFalse(x._ctref.active)
+        self.assertTrue(x._tref.called)
+        self.assertFalse(x._ctref.called)
+
+    def test_cleanup(self):
+        x = Polaroid(self.state, app=self.app)
+        cleanup_signal_sent = [False]
+
+        def handler(**kwargs):
+            cleanup_signal_sent[0] = True
+
+        x.cleanup_signal.connect(handler)
+        x.cleanup()
+        self.assertTrue(cleanup_signal_sent[0])
+
+    def test_shutter__capture(self):
+        x = Polaroid(self.state, app=self.app)
+        shutter_signal_sent = [False]
+
+        def handler(**kwargs):
+            shutter_signal_sent[0] = True
+
+        x.shutter_signal.connect(handler)
+        x.shutter()
+        self.assertTrue(shutter_signal_sent[0])
+
+        shutter_signal_sent[0] = False
+        x.capture()
+        self.assertTrue(shutter_signal_sent[0])
+
+    def test_shutter_maxrate(self):
+        x = Polaroid(self.state, app=self.app, maxrate="1/h")
+        shutter_signal_sent = [0]
+
+        def handler(**kwargs):
+            shutter_signal_sent[0] += 1
+
+        x.shutter_signal.connect(handler)
+        for i in range(30):
+            x.shutter()
+            x.shutter()
+            x.shutter()
+        self.assertEqual(shutter_signal_sent[0], 1)
+
+
+
+class test_evcam(unittest.TestCase):
+
+    class MockReceiver(object):
+        raise_keyboard_interrupt = False
+
+        def capture(self, **kwargs):
+            if self.__class__.raise_keyboard_interrupt:
+                raise KeyboardInterrupt()
+
+    class MockEvents(Events):
+
+        def Receiver(self, *args, **kwargs):
+            return test_evcam.MockReceiver()
+
+
+    def setUp(self):
+        self.app = app_or_default()
+        self.app.events = self.MockEvents()
+
+    def test_evcam(self):
+        evcam(Polaroid, timer=timer)
+        evcam(Polaroid, timer=timer, loglevel="CRITICAL")
+        self.MockReceiver.raise_keyboard_interrupt = True
+        self.assertRaises(SystemExit, evcam, Polaroid, timer=timer)

+ 37 - 0
celery/tests/test_events_state.py

@@ -85,6 +85,9 @@ class test_Worker(unittest.TestCase):
         worker.on_heartbeat(timestamp=None)
         self.assertEqual(worker.heartbeats, [])
 
+    def test_repr(self):
+        self.assertTrue(repr(Worker(hostname="foo")))
+
 
 class test_Task(unittest.TestCase):
 
@@ -119,6 +122,26 @@ class test_Task(unittest.TestCase):
         task.on_succeeded(timestamp=time())
         self.assertTrue(task.ready)
 
+    def test_sent(self):
+        task = Task(uuid="abcdefg",
+                    name="tasks.add")
+        task.on_sent(timestamp=time())
+        self.assertEqual(task.state, states.PENDING)
+
+    def test_merge(self):
+        task = Task()
+        task.on_failed(timestamp=time())
+        task.on_started(timestamp=time())
+        task.on_received(timestamp=time(), name="tasks.add", args=(2, 2))
+        self.assertEqual(task.state, states.FAILURE)
+        self.assertEqual(task.name, "tasks.add")
+        self.assertTupleEqual(task.args, (2, 2))
+        task.on_retried(timestamp=time())
+        self.assertEqual(task.state, states.RETRY)
+
+    def test_repr(self):
+        self.assertTrue(repr(Task(uuid="xxx", name="tasks.add")))
+
 
 class test_State(unittest.TestCase):
 
@@ -221,6 +244,20 @@ class test_State(unittest.TestCase):
         s.freeze_while(work, clear_after=True)
         self.assertFalse(s.event_count)
 
+        s2 = State()
+        r = ev_snapshot(s2)
+        r.play()
+        s2.freeze_while(work, clear_after=False)
+        self.assertTrue(s2.event_count)
+
+    def test_clear_tasks(self):
+        s = State()
+        r = ev_snapshot(s)
+        r.play()
+        self.assertTrue(s.tasks)
+        s.clear_tasks(ready=False)
+        self.assertFalse(s.tasks)
+
     def test_clear(self):
         r = ev_snapshot(State())
         r.play()

+ 2 - 0
setup.cfg

@@ -23,6 +23,8 @@ cover3-exclude = celery
                  celery.contrib*
                  celery.concurrency.threads
                  celery.concurrency.processes.pool
+                 celery.concurrency.evg
+                 celery.concurrency.evlet
                  celery.backends.mongodb
                  celery.backends.tyrant
                  celery.backends.pyredis