Browse Source

Prefork: Use poll() to avoid limitations of select() (Issue #2373)

Ask Solem 10 years ago
parent
commit
7245458cac
2 changed files with 95 additions and 55 deletions
  1. 32 9
      celery/concurrency/asynpool.py
  2. 63 46
      celery/tests/concurrency/test_prefork.py

+ 32 - 9
celery/concurrency/asynpool.py

@@ -28,6 +28,7 @@ import time
 
 from collections import deque, namedtuple
 from io import BytesIO
+from numbers import Integral
 from pickle import HIGHEST_PROTOCOL
 from time import sleep
 from weakref import WeakValueDictionary, ref
@@ -109,8 +110,11 @@ def _get_job_writer(job):
         return writer()  # is a weakref
 
 
-def _select(readers=None, writers=None, err=None, timeout=0):
-    """Simple wrapper to :class:`~select.select`.
+def _select(readers=None, writers=None, err=None, timeout=0,
+            poll=select.poll, POLLIN=select.POLLIN,
+            POLLOUT=select.POLLOUT, POLLERR=select.POLLERR):
+    """Simple wrapper to :class:`~select.select`, using :`~select.poll`
+    as the implementation.
 
     :param readers: Set of reader fds to test if readable.
     :param writers: Set of writer fds to test if writable.
@@ -131,25 +135,44 @@ def _select(readers=None, writers=None, err=None, timeout=0):
     readers = set() if readers is None else readers
     writers = set() if writers is None else writers
     err = set() if err is None else err
+    poller = poll()
+    register = poller.register
+
+    if readers:
+        [register(fd, POLLIN) for fd in readers]
+    if writers:
+        [register(fd, POLLOUT) for fd in writers]
+    if err:
+        [register(fd, POLLERR) for fd in err]
+
+    R, W = set(), set()
+    timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3)
     try:
-        r, w, e = select.select(readers, writers, err, timeout)
-        if e:
-            r = list(set(r) | set(e))
-        return r, w, 0
+        events = poller.poll(timeout)
+        for fd, event in events:
+            if not isinstance(fd, Integral):
+                fd = fd.fileno()
+            if event & POLLIN:
+                R.add(fd)
+            if event & POLLOUT:
+                W.add(fd)
+            if event & POLLERR:
+                R.add(fd)
+        return R, W, 0
     except (select.error, socket.error) as exc:
         if exc.errno == errno.EINTR:
-            return [], [], 1
+            return set(), set(), 1
         elif exc.errno in SELECT_BAD_FD:
             for fd in readers | writers | err:
                 try:
                     select.select([fd], [], [], 0)
                 except (select.error, socket.error) as exc:
-                    if exc.errno not in SELECT_BAD_FD:
+                    if getattr(exc, 'errno', None) not in SELECT_BAD_FD:
                         raise
                     readers.discard(fd)
                     writers.discard(fd)
                     err.discard(fd)
-            return [], [], 1
+            return set(), set(), 1
         else:
             raise
 

+ 63 - 46
celery/tests/concurrency/test_prefork.py

@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 import errno
+import select
 import socket
 import time
 
@@ -8,7 +9,7 @@ from itertools import cycle
 
 from celery.five import items, range
 from celery.utils.functional import noop
-from celery.tests.case import AppCase, Mock, SkipTest, call, patch
+from celery.tests.case import AppCase, Mock, SkipTest, patch
 try:
     from celery.concurrency import prefork as mp
     from celery.concurrency import asynpool
@@ -147,67 +148,83 @@ class test_AsynPool(PoolCase):
         list(g)
         self.assertFalse(asynpool.gen_not_started(g))
 
-    def test_select(self):
+    @patch('select.select', create=True)
+    def test_select(self, __select):
         ebadf = socket.error()
         ebadf.errno = errno.EBADF
-        with patch('select.select') as select:
-            select.return_value = ([3], [], [])
+        with patch('select.poll', create=True) as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+            poll.poll.return_value = [(3, select.POLLIN)]
             self.assertEqual(
-                asynpool._select({3}),
-                ([3], [], 0),
+                asynpool._select({3}, poll=poller),
+                ({3}, set(), 0),
             )
 
-            select.return_value = ([], [], [3])
+            poll.poll.return_value = [(3, select.POLLERR)]
             self.assertEqual(
-                asynpool._select({3}, None, {3}),
-                ([3], [], 0),
+                asynpool._select({3}, None, {3}, poll=poller),
+                ({3}, set(), 0),
             )
 
             eintr = socket.error()
             eintr.errno = errno.EINTR
-            select.side_effect = eintr
+            poll.poll.side_effect = eintr
 
             readers = {3}
-            self.assertEqual(asynpool._select(readers), ([], [], 1))
+            self.assertEqual(
+                asynpool._select(readers, poll=poller),
+                (set(), set(), 1),
+            )
             self.assertIn(3, readers)
 
-        with patch('select.select') as select:
-            select.side_effect = ebadf
-            readers = {3}
-            self.assertEqual(asynpool._select(readers), ([], [], 1))
-            select.assert_has_calls([call([3], [], [], 0)])
-            self.assertNotIn(3, readers)
-
-        with patch('select.select') as select:
-            select.side_effect = MemoryError()
-            with self.assertRaises(MemoryError):
-                asynpool._select({1})
-
-        with patch('select.select') as select:
-
-            def se(*args):
-                select.side_effect = MemoryError()
-                raise ebadf
-            select.side_effect = se
+        with patch('select.poll') as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+            poll.poll.side_effect = ebadf
+            with patch('select.select') as selcheck:
+                selcheck.side_effect = ebadf
+                readers = {3}
+                self.assertEqual(
+                    asynpool._select(readers, poll=poller),
+                    (set(), set(), 1),
+                )
+                self.assertNotIn(3, readers)
+
+        with patch('select.poll') as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+            poll.poll.side_effect = MemoryError()
             with self.assertRaises(MemoryError):
-                asynpool._select({3})
-
-        with patch('select.select') as select:
-
-            def se2(*args):
-                select.side_effect = socket.error()
-                select.side_effect.errno = 1321
-                raise ebadf
-            select.side_effect = se2
-            with self.assertRaises(socket.error):
-                asynpool._select({3})
-
-        with patch('select.select') as select:
-
-            select.side_effect = socket.error()
-            select.side_effect.errno = 34134
+                asynpool._select({1}, poll=poller)
+
+        with patch('select.poll') as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+            with patch('select.select') as selcheck:
+
+                def se(*args):
+                    selcheck.side_effect = MemoryError()
+                    raise ebadf
+                poll.poll.side_effect = se
+                with self.assertRaises(MemoryError):
+                    asynpool._select({3}, poll=poller)
+
+        with patch('select.poll') as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+            with patch('select.select') as selcheck:
+
+                def se2(*args):
+                    selcheck.side_effect = socket.error()
+                    selcheck.side_effect.errno = 1321
+                    raise ebadf
+                poll.poll.side_effect = se2
+                with self.assertRaises(socket.error):
+                    asynpool._select({3}, poll=poller)
+
+        with patch('select.poll') as poller:
+            poll = poller.return_value = Mock(name='poll.poll')
+
+            poll.poll.side_effect = socket.error()
+            poll.poll.side_effect.errno = 34134
             with self.assertRaises(socket.error):
-                asynpool._select({3})
+                asynpool._select({3}, poll=poller)
 
     def test_promise(self):
         fun = Mock()