Browse Source

100% coverage for celery.events.state

Ask Solem 12 years ago
parent
commit
5f23048a1d

+ 5 - 5
celery/events/state.py

@@ -20,7 +20,7 @@ from __future__ import absolute_import
 
 import threading
 
-from heapq import heappush
+from heapq import heappush, heappop
 from itertools import islice
 from operator import itemgetter
 from time import time
@@ -123,7 +123,7 @@ class Worker(Element):
     def update_heartbeat(self, received, timestamp):
         if not received or not timestamp:
             return
-        drift = received - timestamp
+        drift = abs(received - timestamp)
         if drift > HEARTBEAT_DRIFT_MAX:
             warn(DRIFT_WARNING, self.hostname, drift)
         heartbeats, hbmax = self.heartbeats, self.heartbeat_max
@@ -361,14 +361,14 @@ class State(object):
         worker, _ = self.get_or_create_worker(hostname)
         task, created = self.get_or_create_task(uuid)
         task.worker = worker
+        maxtasks = self.max_tasks_in_memory * 2
 
         taskheap = self._taskheap
         timestamp = fields.get('timestamp') or 0
         clock = 0 if type == 'sent' else fields.get('clock')
         heappush(taskheap, _lamportinfo(clock, timestamp, worker.id, task))
-        curcount = len(self.tasks)
-        if len(taskheap) > self.max_tasks_in_memory * 2:
-            taskheap[:] = taskheap[curcount:]
+        if len(taskheap) > maxtasks:
+            heappop(taskheap)
 
         handler = getattr(task, 'on_' + type, None)
         if type == 'received':

+ 2 - 2
celery/five.py

@@ -107,7 +107,7 @@ else:
     def nextfun(it):                # noqa
         return it.next
 
-    def exec_(code, globs=None, locs=None):
+    def exec_(code, globs=None, locs=None):  # pragma: no cover
         """Execute code in a namespace."""
         if globs is None:
             frame = sys._getframe(1)
@@ -296,7 +296,7 @@ class MagicModule(ModuleType):
             for item in self._all_by_module[module.__name__]:
                 setattr(self, item, getattr(module, item))
             return getattr(module, name)
-        elif name in self._direct:
+        elif name in self._direct:  # pragma: no cover
             module = __import__(self._direct[name], None, None, [name])
             setattr(self, name, module)
             return module

+ 41 - 1
celery/tests/events/test_state.py

@@ -4,7 +4,7 @@ import pickle
 
 from time import time
 from itertools import count
-from mock import Mock
+from mock import Mock, patch
 
 from celery import states
 from celery.events import Event
@@ -13,6 +13,7 @@ from celery.events.state import (
     Worker,
     Task,
     HEARTBEAT_EXPIRE_WINDOW,
+    HEARTBEAT_DRIFT_MAX,
     _lamportinfo
 )
 from celery.utils import uuid
@@ -116,6 +117,20 @@ class test_Worker(Case):
     def test_repr(self):
         self.assertTrue(repr(Worker(hostname='foo')))
 
+    def test_drift_warning(self):
+        worker = Worker(hostname='foo')
+        with patch('celery.events.state.warn') as warn:
+            worker.update_heartbeat(time(), time() + (HEARTBEAT_DRIFT_MAX * 2))
+            self.assertTrue(warn.called)
+            self.assertIn('Substantial drift', warn.call_args[0][0])
+
+    def test_update_heartbeat(self):
+        worker = Worker(hostname='foo')
+        worker.update_heartbeat(time(), time())
+        self.assertEqual(len(worker.heartbeats), 1)
+        worker.update_heartbeat(time() - 10, time())
+        self.assertEqual(len(worker.heartbeats), 1)
+
 
 class test_Task(Case):
 
@@ -129,6 +144,7 @@ class test_Task(Case):
                     eta=1,
                     runtime=0.0001,
                     expires=1,
+                    foo=None,
                     exception=1,
                     received=time() - 10,
                     started=time() - 8,
@@ -143,6 +159,7 @@ class test_Task(Case):
 
         self.assertEqual(sorted(['args', 'kwargs']),
                          sorted(task.info(['args', 'kwargs']).keys()))
+        self.assertFalse(list(task.info('foo')))
 
     def test_ready(self):
         task = Task(uuid='abcdefg',
@@ -352,6 +369,29 @@ class test_State(Case):
                                                 'uuid': 'x',
                                                 'hostname': 'y'})
 
+    def test_limits_maxtasks(self):
+        s = State()
+        s.max_tasks_in_memory = 1
+        s.task_event('task-unknown-event-xxx', {'foo': 'bar',
+                                                'uuid': 'x',
+                                                'hostname': 'y',
+                                                'clock': 3})
+        s.task_event('task-unknown-event-xxx', {'foo': 'bar',
+                                                'uuid': 'y',
+                                                'hostname': 'y',
+                                                'clock': 4})
+
+        s.task_event('task-unknown-event-xxx', {'foo': 'bar',
+                                                'uuid': 'z',
+                                                'hostname': 'y',
+                                                'clock': 5})
+        self.assertEqual(len(s._taskheap), 2)
+        self.assertEqual(s._taskheap[0].clock, 4)
+        self.assertEqual(s._taskheap[1].clock, 5)
+
+        s._taskheap.append(s._taskheap[0])
+        self.assertTrue(list(s.tasks_by_time()))
+
     def test_callback(self):
         scratch = {}
 

+ 7 - 0
celery/tests/utils/test_timer2.py

@@ -103,6 +103,13 @@ class test_Timer(Case):
         t.exit_after(300, priority=10)
         t.apply_after.assert_called_with(300, sys.exit, 10)
 
+    def test_ensure_started_not_started(self):
+        t = timer2.Timer()
+        t.running = True
+        t.start = Mock()
+        t.ensure_started()
+        self.assertFalse(t.start.called)
+
     def test_apply_interval(self):
         t = timer2.Timer()
         try:

+ 6 - 5
celery/tests/worker/test_loops.py

@@ -70,12 +70,13 @@ class X(object):
             raise socket.timeout()
         mock.side_effect = first
 
-    def close_then_error(self, mock):
+    def close_then_error(self, mock, mod=0):
 
         def first(*args, **kwargs):
-            self.close()
-            self.connection.more_to_read = False
-            raise socket.error()
+            if not mod or mock.call_count > mod:
+                self.close()
+                self.connection.more_to_read = False
+                raise socket.error()
         mock.side_effect = first
 
     def close(self, *args, **kwargs):
@@ -210,7 +211,7 @@ class test_asynloop(AppCase):
     def test_poll_readable(self):
         x = X()
         x.hub.readers = {6: Mock()}
-        x.close_then_error(x.connection.drain_nowait)
+        x.close_then_error(x.connection.drain_nowait, mod=4)
         x.hub.poller.poll.return_value = [(6, READ)]
         with self.assertRaises(socket.error):
             asynloop(*x.args)