Browse Source

Coverage up to 85%

Ask Solem 16 years ago
parent
commit
4ca63042c9
2 changed files with 129 additions and 9 deletions
  1. 121 0
      celery/tests/test_worker.py
  2. 8 9
      celery/worker/__init__.py

+ 121 - 0
celery/tests/test_worker.py

@@ -16,6 +16,21 @@ def foo_task(x, y, z, **kwargs):
 registry.tasks.register(foo_task, name="c.u.foo")
 
 
+class MockLogger(object):
+
+    def critical(self, *args, **kwargs):
+        pass
+
+    def info(self, *args, **kwargs):
+        pass
+
+    def error(self, *args, **kwargs):
+        pass
+    
+    def debug(self, *args, **kwargs):
+        pass
+
+
 class MockBackend(object):
     _acked = False
 
@@ -23,6 +38,39 @@ class MockBackend(object):
         self._acked = True
 
 
+class MockPool(object):
+    
+    def __init__(self, *args, **kwargs):
+        self.raise_regular = kwargs.get("raise_regular", False)
+        self.raise_base = kwargs.get("raise_base", False)
+
+    def apply_async(self, *args, **kwargs):
+        if self.raise_regular:
+            raise KeyError("some exception")
+        if self.raise_base:
+            raise KeyboardInterrupt("Ctrl+c")
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+        return True
+
+
+class MockController(object):
+
+    def __init__(self, w, *args, **kwargs):
+        self._w = w
+        self._stopped = False
+
+    def start(self):
+        self._w["started"] = True
+        self._stopped = False
+
+    def stop(self):
+        self._stopped = True
+
 def create_message(backend, **data):
     data["id"] = gen_unique_id()
     return BaseMessage(backend, body=pickle.dumps(dict(**data)),
@@ -73,6 +121,16 @@ class TestAMQPListener(unittest.TestCase):
         self.assertEquals(in_bucket.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.hold_queue.get_nowait)
     
+    def test_receieve_message_not_registered(self):
+        l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
+        backend = MockBackend()
+        m = create_message(backend, task="x.X.31x", args=[2, 4, 8], kwargs={})
+
+        self.assertFalse(l.receive_message(m.decode(), m))
+        self.assertRaises(Empty, self.bucket_queue.get_nowait)
+        self.assertRaises(Empty, self.hold_queue.get_nowait)
+
+    
     def test_receieve_message_eta(self):
         l = AMQPListener(self.bucket_queue, self.hold_queue, self.logger)
         backend = MockBackend()
@@ -89,3 +147,66 @@ class TestAMQPListener(unittest.TestCase):
         self.assertEquals(task.task_name, "c.u.foo")
         self.assertEquals(task.execute(), 2 * 4 * 8)
         self.assertRaises(Empty, self.bucket_queue.get_nowait)
+
+
+class TestWorkController(unittest.TestCase):
+
+    def setUp(self):
+        self.worker = WorkController(concurrency=1, loglevel=0,
+                                     is_detached=False)
+        self.worker.logger = MockLogger()
+
+    def test_attrs(self):
+        worker = self.worker
+        self.assertTrue(isinstance(worker.bucket_queue, Queue))
+        self.assertTrue(isinstance(worker.hold_queue, Queue))
+        self.assertTrue(worker.periodic_work_controller)
+        self.assertTrue(worker.pool)
+        self.assertTrue(worker.amqp_listener)
+        self.assertTrue(worker.mediator)
+        self.assertTrue(worker.components)
+
+    def test_safe_process_task(self):
+        worker = self.worker
+        worker.pool = MockPool()
+        backend = MockBackend()
+        m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
+                           kwargs={})
+        task = TaskWrapper.from_message(m, m.decode())
+        worker.safe_process_task(task)
+        worker.pool.stop()
+    
+    def test_safe_process_task_raise_base(self):
+        worker = self.worker
+        worker.pool = MockPool(raise_base=True)
+        backend = MockBackend()
+        m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
+                           kwargs={})
+        task = TaskWrapper.from_message(m, m.decode())
+        worker.safe_process_task(task)
+        worker.pool.stop()
+
+    def test_safe_process_task_raise_regular(self):
+        worker = self.worker
+        worker.pool = MockPool(raise_regular=True)
+        backend = MockBackend()
+        m = create_message(backend, task="c.u.foo", args=[4, 8, 10],
+                           kwargs={})
+        task = TaskWrapper.from_message(m, m.decode())
+        worker.safe_process_task(task)
+        worker.pool.stop()
+
+    def test_start_stop(self):
+        worker = self.worker
+        w1 = {"started": False}
+        w2 = {"started": False}
+        w3 = {"started": False}
+        w4 = {"started": False}
+        worker.components = [MockController(w1), MockController(w2),
+                             MockController(w3), MockController(w4)]
+
+        worker.start()
+        for w in (w1, w2, w3, w4):
+            self.assertTrue(w["started"])
+        for component in worker.components:
+            self.assertTrue(component._stopped)

+ 8 - 9
celery/worker/__init__.py

@@ -67,8 +67,13 @@ class AMQPListener(object):
         otherwise we move it the bucket queue for immediate processing.
 
         """
-        task = TaskWrapper.from_message(message, message_data,
-                                        logger=self.logger)
+        try:
+            task = TaskWrapper.from_message(message, message_data,
+                                            logger=self.logger)
+        except NotRegistered, exc:
+            self.logger.info("Unknown task ignored: %s" % (exc))
+            return
+                
         eta = message_data.get("eta")
         if eta:
             self.hold_queue.put((task, eta))
@@ -208,14 +213,8 @@ class WorkController(object):
         try:
             try:
                 self.process_task(task)
-            except ValueError:
-                # execute_next_task didn't return a r/name/id tuple,
-                # probably because it got an exception.
-                pass
-            except NotRegistered, exc:
-                self.logger.info("Unknown task ignored: %s" % (exc))
             except Exception, exc:
-                self.logger.critical("Message queue raised %s: %s\n%s" % (
+                self.logger.critical("Internal error %s: %s\n%s" % (
                                 exc.__class__, exc, traceback.format_exc()))
         except (SystemExit, KeyboardInterrupt):
             self.stop()