Parcourir la source

Tests passing again

Ask Solem il y a 14 ans
Parent
commit
0f1c3b6019

+ 1 - 1
celery/concurrency/processes/__init__.py

@@ -54,7 +54,7 @@ class TaskPool(BasePool):
         return self._pool.shrink(n)
 
     def _get_info(self):
-        return {"max-concurrency": self.processes,
+        return {"max-concurrency": self.limit,
                 "processes": [p.pid for p in self._pool._pool],
                 "max-tasks-per-child": self._pool._maxtasksperchild,
                 "put-guarded-by-semaphore": self.putlocks,

+ 2 - 2
celery/tests/test_bin/test_celeryd.py

@@ -221,10 +221,10 @@ class test_signal_handlers(unittest.TestCase):
         terminated = False
         logger = get_logger()
 
-        def stop(self):
+        def stop(self, in_sighandler=False):
             self.stopped = True
 
-        def terminate(self):
+        def terminate(self, in_sighandler=False):
             self.terminated = True
 
     def psig(self, fun, *args, **kwargs):

+ 3 - 3
celery/tests/test_concurrency_processes.py

@@ -4,6 +4,7 @@ import unittest2 as unittest
 from itertools import cycle
 
 from celery.concurrency import processes as mp
+from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.utils import noop
 
@@ -165,13 +166,12 @@ class test_TaskPool(unittest.TestCase):
 
     def test_info(self):
         pool = TaskPool(10)
-        procs = [Object(pid=i) for i in range(pool.processes)]
+        procs = [Object(pid=i) for i in range(pool.limit)]
         pool._pool = Object(_pool=procs,
                             _maxtasksperchild=None,
                             timeout=10,
                             soft_timeout=5)
         info = pool.info
-        self.assertEqual(info["max-concurrency"], pool.processes)
-        self.assertEqual(len(info["processes"]), pool.processes)
+        self.assertEqual(info["max-concurrency"], pool.limit)
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertEqual(info["timeouts"], (5, 10))

+ 1 - 1
celery/tests/test_pool.py

@@ -27,7 +27,7 @@ class TestTaskPool(unittest.TestCase):
 
     def test_attrs(self):
         p = TaskPool(2)
-        self.assertEqual(p.processes, 2)
+        self.assertEqual(p.limit, 2)
         self.assertIsInstance(p.logger, logging.Logger)
         self.assertIsNone(p._pool)
 

+ 15 - 9
celery/tests/test_worker.py

@@ -9,6 +9,7 @@ from kombu.connection import BrokerConnection
 from celery.utils.timer2 import Timer
 
 from celery.app import app_or_default
+from celery.concurrency.base import BasePool
 from celery.decorators import task as task_dec
 from celery.decorators import periodic_task as periodic_task_dec
 from celery.serialization import pickle
@@ -117,7 +118,7 @@ class MockBackend(object):
         self._acked = True
 
 
-class MockPool(object):
+class MockPool(BasePool):
     _terminated = False
     _stopped = False
 
@@ -436,13 +437,16 @@ class test_Consumer(unittest.TestCase):
         l.broadcast_consumer = MockConsumer()
         l.qos = _QoS()
         l.connection = BrokerConnection()
+        l.iterations = 0
 
         def raises_KeyError(limit=None):
-            yield True
-            l.iterations = 1
-            raise KeyError("foo")
+            l.iterations += 1
+            if l.qos.prev != l.qos.next:
+                l.qos.update()
+            if l.iterations >= 2:
+                raise KeyError("foo")
 
-        l._mainloop = raises_KeyError
+        l.consume_messages = raises_KeyError
         self.assertRaises(KeyError, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
@@ -456,11 +460,10 @@ class test_Consumer(unittest.TestCase):
         l.connection = BrokerConnection()
 
         def raises_socket_error(limit=None):
-            yield True
             l.iterations = 1
             raise socket.error("foo")
 
-        l._mainloop = raises_socket_error
+        l.consume_messages = raises_socket_error
         self.assertRaises(socket.error, l.start)
         self.assertTrue(called_back[0])
         self.assertEqual(l.iterations, 1)
@@ -509,8 +512,11 @@ class test_WorkController(unittest.TestCase):
         m = create_message(backend, task=foo_task.name, args=[4, 8, 10],
                            kwargs={})
         task = TaskRequest.from_message(m, m.decode())
-        worker.process_task(task)
-        worker.pool.stop()
+        worker.components = []
+        worker._state = worker.RUN
+        self.assertRaises(KeyboardInterrupt, worker.process_task, task)
+        self.assertEqual(worker._state, worker.TERMINATE)
+
 
     def test_process_task_raise_regular(self):
         worker = self.worker

+ 2 - 1
celery/tests/test_worker_job.py

@@ -10,6 +10,7 @@ from kombu.transport.base import Message
 
 from celery import states
 from celery.app import app_or_default
+from celery.concurrency.base import BasePool
 from celery.datastructures import ExceptionInfo
 from celery.decorators import task as task_dec
 from celery.exceptions import RetryTaskError, NotRegistered
@@ -408,7 +409,7 @@ class test_TaskRequest(unittest.TestCase):
         tid = gen_unique_id()
         tw = TaskRequest(mytask.name, tid, [4], {"f": "x"})
 
-        class MockPool(object):
+        class MockPool(BasePool):
             target = None
             args = None
             kwargs = None

+ 7 - 4
celery/worker/__init__.py

@@ -52,6 +52,9 @@ def process_initializer(app, hostname):
 
 class WorkController(object):
     """Unmanaged worker instance."""
+    RUN = RUN
+    CLOSE = CLOSE
+    TERMINATE = TERMINATE
 
     #: The number of simultaneous processes doing work (default:
     #: :setting:`CELERYD_CONCURRENCY`)
@@ -232,7 +235,7 @@ class WorkController(object):
 
     def start(self):
         """Starts the workers main loop."""
-        self._state = RUN
+        self._state = self.RUN
 
         try:
             for i, component in enumerate(self.components):
@@ -279,11 +282,11 @@ class WorkController(object):
     def _shutdown(self, warm=True):
         what = (warm and "stopping" or "terminating").capitalize()
 
-        if self._state != RUN or self._running != len(self.components):
+        if self._state != self.RUN or self._running != len(self.components):
             # Not fully started, can safely exit.
             return
 
-        self._state = CLOSE
+        self._state = self.CLOSE
         signals.worker_shutdown.send(sender=self)
 
         for component in reversed(self.components):
@@ -295,7 +298,7 @@ class WorkController(object):
             stop()
 
         self.consumer.close_connection()
-        self._state = TERMINATE
+        self._state = self.TERMINATE
 
     def on_timer_error(self, exc_info):
         _, exc, _ = exc_info