Kaynağa Gözat

Coverage for parts of celery.concurrency.processes

Ask Solem 12 yıl önce
ebeveyn
işleme
e21680855c

+ 0 - 15
celery/concurrency/base.py

@@ -119,9 +119,6 @@ class BasePool(object):
     def on_close(self):
     def on_close(self):
         pass
         pass
 
 
-    def init_callbacks(self, **kwargs):
-        pass
-
     def apply_async(self, target, args=[], kwargs={}, **options):
     def apply_async(self, target, args=[], kwargs={}, **options):
         """Equivalent of the :func:`apply` built-in function.
         """Equivalent of the :func:`apply` built-in function.
 
 
@@ -152,15 +149,3 @@ class BasePool(object):
     @property
     @property
     def num_processes(self):
     def num_processes(self):
         return self.limit
         return self.limit
-
-    @property
-    def readers(self):
-        return {}
-
-    @property
-    def writers(self):
-        return {}
-
-    @property
-    def timers(self):
-        return {}

+ 4 - 5
celery/concurrency/processes.py

@@ -84,7 +84,7 @@ def process_initializer(app, hostname):
     # run once per process.
     # run once per process.
     app.loader.init_worker()
     app.loader.init_worker()
     app.loader.init_worker_process()
     app.loader.init_worker_process()
-    app.log.setup(int(os.environ.get('CELERY_LOG_LEVEL', 0)),
+    app.log.setup(int(os.environ.get('CELERY_LOG_LEVEL', 0) or 0),
                   os.environ.get('CELERY_LOG_FILE') or None,
                   os.environ.get('CELERY_LOG_FILE') or None,
                   bool(os.environ.get('CELERY_LOG_REDIRECT', False)),
                   bool(os.environ.get('CELERY_LOG_REDIRECT', False)),
                   str(os.environ.get('CELERY_LOG_REDIRECT_LEVEL')))
                   str(os.environ.get('CELERY_LOG_REDIRECT_LEVEL')))
@@ -102,15 +102,14 @@ def process_initializer(app, hostname):
     signals.worker_process_init.send(sender=None)
     signals.worker_process_init.send(sender=None)
 
 
 
 
-def _select(self, readers=None, writers=None, err=None, timeout=0):
+def _select(readers=None, writers=None, err=None, timeout=0):
     readers = set() if readers is None else readers
     readers = set() if readers is None else readers
     writers = set() if writers is None else writers
     writers = set() if writers is None else writers
     err = set() if err is None else err
     err = set() if err is None else err
     try:
     try:
         r, w, e = select.select(readers, writers, err, timeout)
         r, w, e = select.select(readers, writers, err, timeout)
         if e:
         if e:
-            seen = set()
-            r = r | set(f for f in r + e if f not in seen and not seen.add(f))
+            r = list(set(r) | set(e))
         return r, w, 0
         return r, w, 0
     except (select.error, socket.error) as exc:
     except (select.error, socket.error) as exc:
         if get_errno(exc) == errno.EINTR:
         if get_errno(exc) == errno.EINTR:
@@ -442,7 +441,7 @@ class AsynPool(_pool.Pool):
             if not readable:
             if not readable:
                 break
                 break
             for fd in readable:
             for fd in readable:
-                fileno_to_proc[fd]._reader.recv()
+                fileno_to_proc[fd].inq._reader.recv()
             sleep(0)
             sleep(0)
 
 
 
 

+ 29 - 0
celery/tests/concurrency/test_concurrency.py

@@ -3,6 +3,7 @@ from __future__ import absolute_import
 import os
 import os
 
 
 from itertools import count
 from itertools import count
+from mock import Mock
 
 
 from celery.concurrency.base import apply_target, BasePool
 from celery.concurrency.base import apply_target, BasePool
 from celery.tests.utils import Case
 from celery.tests.utils import Case
@@ -85,3 +86,31 @@ class test_BasePool(Case):
     def test_interface_terminate_job(self):
     def test_interface_terminate_job(self):
         with self.assertRaises(NotImplementedError):
         with self.assertRaises(NotImplementedError):
             BasePool(10).terminate_job(101)
             BasePool(10).terminate_job(101)
+
+    def test_interface_did_start_ok(self):
+        self.assertTrue(BasePool(10).did_start_ok())
+
+    def test_interface_on_poll_init(self):
+        self.assertIsNone(BasePool(10).on_poll_init(Mock(), Mock()))
+
+    def test_interface_on_poll_start(self):
+        self.assertIsNone(BasePool(10).on_poll_start(Mock()))
+
+    def test_interface_on_soft_timeout(self):
+        self.assertIsNone(BasePool(10).on_soft_timeout(Mock()))
+
+    def test_interface_on_hard_timeout(self):
+        self.assertIsNone(BasePool(10).on_hard_timeout(Mock()))
+
+    def test_interface_maybe_handle_result(self):
+        self.assertIsNone(BasePool(10).maybe_handle_result(1, 2))
+
+    def test_interface_close(self):
+        p = BasePool(10)
+        p.on_close = Mock()
+        p.close()
+        self.assertEqual(p._state, p.CLOSE)
+        p.on_close.assert_called_with()
+
+    def test_interface_no_close(self):
+        self.assertIsNone(BasePool(10).on_close())

+ 148 - 4
celery/tests/concurrency/test_processes.py

@@ -1,15 +1,17 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+import errno
+import socket
 import time
 import time
 
 
 from itertools import cycle
 from itertools import cycle
 
 
-from mock import Mock
+from mock import Mock, call, patch
 from nose import SkipTest
 from nose import SkipTest
 
 
 from celery.five import items, range
 from celery.five import items, range
 from celery.utils.functional import noop
 from celery.utils.functional import noop
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
 try:
 try:
     from celery.concurrency import processes as mp
     from celery.concurrency import processes as mp
 except ImportError:
 except ImportError:
@@ -117,14 +119,155 @@ class ExeMockTaskPool(mp.TaskPool):
     Pool = BlockingPool = ExeMockPool
     Pool = BlockingPool = ExeMockPool
 
 
 
 
-class test_TaskPool(Case):
+class PoolCase(AppCase):
 
 
-    def setUp(self):
+    def setup(self):
         try:
         try:
             import multiprocessing  # noqa
             import multiprocessing  # noqa
         except ImportError:
         except ImportError:
             raise SkipTest('multiprocessing not supported')
             raise SkipTest('multiprocessing not supported')
 
 
+
+class test_AsynPool(PoolCase):
+
+    def test_gen_not_started(self):
+
+        def gen():
+            yield 1
+            yield 2
+        g = gen()
+        self.assertTrue(mp.gen_not_started(g))
+        next(g)
+        self.assertFalse(mp.gen_not_started(g))
+        list(g)
+        self.assertFalse(mp.gen_not_started(g))
+
+    def test_select(self):
+        ebadf = socket.error()
+        ebadf.errno = errno.EBADF
+        with patch('select.select') as select:
+            select.return_value = ([3], [], [])
+            self.assertEqual(
+                mp._select(set([3])),
+                ([3], [], 0),
+            )
+
+            select.return_value = ([], [], [3])
+            self.assertEqual(
+                mp._select(set([3]), None, set([3])),
+                ([3], [], 0),
+            )
+
+            eintr = socket.error()
+            eintr.errno = errno.EINTR
+            select.side_effect = eintr
+
+            readers = set([3])
+            self.assertEqual(mp._select(readers), ([], [], 1))
+            self.assertIn(3, readers)
+
+        with patch('select.select') as select:
+            select.side_effect = ebadf
+            readers = set([3])
+            self.assertEqual(mp._select(readers), ([], [], 1))
+            select.assert_has_calls([call([3], [], [], 0)])
+            self.assertNotIn(3, readers)
+
+        with patch('select.select') as select:
+            select.side_effect = MemoryError()
+            with self.assertRaises(MemoryError):
+                mp._select(set([1]))
+
+        with patch('select.select') as select:
+
+            def se(*args):
+                select.side_effect = MemoryError()
+                raise ebadf
+            select.side_effect = se
+            with self.assertRaises(MemoryError):
+                mp._select(set([3]))
+
+        with patch('select.select') as select:
+
+            def se(*args):
+                select.side_effect = socket.error()
+                select.side_effect.errno = 1321
+                raise ebadf
+            select.side_effect = se
+            with self.assertRaises(socket.error):
+                mp._select(set([3]))
+
+        with patch('select.select') as select:
+
+            select.side_effect = socket.error()
+            select.side_effect.errno = 34134
+            with self.assertRaises(socket.error):
+                mp._select(set([3]))
+
+    def test_promise(self):
+        fun = Mock()
+        x = mp.promise(fun, 1, foo=1)
+        x()
+        self.assertTrue(x.ready)
+        fun.assert_called_with(1, foo=1)
+
+    def test_Worker(self):
+        w = mp.Worker(Mock(), Mock())
+        w.on_loop_start(1234)
+        w.outq.put.assert_called_with((mp.WORKER_UP, (1234, )))
+
+
+class test_ResultHandler(PoolCase):
+
+    def test_process_result(self):
+        x = mp.ResultHandler(
+            Mock(), Mock(), {}, Mock(),
+            Mock(), Mock(), Mock(), Mock(),
+            fileno_to_outq={},
+            on_process_alive=Mock(),
+        )
+        self.assertTrue(x)
+        x.on_state_change = Mock()
+        proc = x.fileno_to_outq[3] = Mock()
+        reader = proc.outq._reader
+        reader.poll.return_value = False
+        x.handle_event(6)  # KeyError
+        x.handle_event(3)
+        reader.poll.assert_called_with(0)
+        self.assertFalse(x.on_state_change.called)
+
+        reader.poll.reset()
+        reader.poll.return_value = True
+        task = reader.recv.return_value = (1, (2, 3))
+        x.handle_event(3)
+        reader.poll.assert_called_with(0)
+        reader.recv.assert_called_with()
+        x.on_state_change.assert_called_with(task)
+        self.assertTrue(x._it)
+
+        reader.recv.return_value = None
+        x.handle_event(3)
+        self.assertIsNone(x._it)
+
+        x._state = mp.TERMINATE
+        it = x._process_result()
+        next(it)
+        with self.assertRaises(mp.CoroStop):
+            it.send(3)
+        x.handle_event(3)
+        self.assertIsNone(x._it)
+        x._state == mp.RUN
+
+        reader.recv.side_effect = EOFError()
+        it = x._process_result()
+        next(it)
+        with self.assertRaises(mp.CoroStop):
+            it.send(3)
+        reader.recv.side_effect = None
+
+
+class test_TaskPool(PoolCase):
+
     def test_start(self):
     def test_start(self):
         pool = TaskPool(10)
         pool = TaskPool(10)
         pool.start()
         pool.start()
@@ -187,3 +330,4 @@ class test_TaskPool(Case):
         tp.restart()
         tp.restart()
         time.sleep(0.5)
         time.sleep(0.5)
         self.assertEqual(pids, get_pids(tp))
         self.assertEqual(pids, get_pids(tp))
+

+ 9 - 0
celery/tests/worker/test_worker.py

@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
+import os
 import socket
 import socket
 
 
 from collections import deque
 from collections import deque
@@ -869,6 +870,14 @@ class test_WorkController(AppCase):
             'celeryd', hostname='awesome.worker.com',
             'celeryd', hostname='awesome.worker.com',
         )
         )
 
 
+        with patch('celery.task.trace.setup_worker_optimizations') as swo:
+            os.environ['FORKED_BY_MULTIPROCESSING'] = "1"
+            try:
+                process_initializer(app, 'luke.worker.com')
+                swo.assert_called_with(app)
+            finally:
+                os.environ.pop('FORKED_BY_MULTIPROCESSING', None)
+
     def test_attrs(self):
     def test_attrs(self):
         worker = self.worker
         worker = self.worker
         self.assertIsInstance(worker.timer, Timer)
         self.assertIsInstance(worker.timer, Timer)