Prechádzať zdrojové kódy

Process pool: Non-blocking write and using multiple AF_UNIX sockets (one per process)

Ask Solem 12 rokov pred
rodič
commit
2251f142f2

+ 3 - 0
celery/concurrency/base.py

@@ -62,6 +62,9 @@ class BasePool(object):
     def did_start_ok(self):
         return True
 
+    def flush(self):
+        pass
+
     def on_stop(self):
         pass
 

+ 373 - 28
celery/concurrency/processes.py

@@ -11,10 +11,25 @@
 """
 from __future__ import absolute_import
 
+import errno
 import os
+import select
+import socket
+import struct
+
+from collections import deque
+from pickle import HIGHEST_PROTOCOL
+from time import sleep, time
 
 from billiard import forking_enable
-from billiard.pool import Pool, RUN, CLOSE
+from billiard import pool as _pool
+from billiard.exceptions import WorkerLostError
+from billiard.pool import RUN, CLOSE, TERMINATE, WorkersJoined
+from billiard.queues import _SimpleQueue
+from kombu.serialization import pickle as _pickle
+from kombu.utils import fxrange
+from kombu.utils.compat import get_errno
+from kombu.utils.eventio import SELECT_BAD_FD
 
 from celery import platforms
 from celery import signals
@@ -22,6 +37,8 @@ from celery._state import set_default_app
 from celery.concurrency.base import BasePool
 from celery.five import items
 from celery.task import trace
+from celery.utils.log import get_logger
+from celery.worker.hub import READ, WRITE, ERR
 
 #: List of signals to reset when a child process starts.
 WORKER_SIGRESET = frozenset(['SIGTERM',
@@ -33,6 +50,16 @@ WORKER_SIGRESET = frozenset(['SIGTERM',
 #: List of signals to ignore when a child process starts.
 WORKER_SIGIGNORE = frozenset(['SIGINT'])
 
+UNAVAIL = frozenset([errno.EAGAIN, errno.EINTR, errno.EBADF])
+
+MAXTASKS_NO_BILLIARD = """\
+    maxtasksperchild enabled but billiard C extension not installed!
+    This may lead to a deadlock, please install the billiard C extension.
+"""
+
+logger = get_logger(__name__)
+warning, debug = logger.warning, logger.debug
+
 
 def process_initializer(app, hostname):
     """Initializes the process so it can be used to process tasks."""
@@ -62,9 +89,174 @@ def process_initializer(app, hostname):
     signals.worker_process_init.send(sender=None)
 
 
+def _select(readers=None, writers=None, err=None, timeout=0):
+    readers = set() if readers is None else readers
+    writers = set() if writers is None else writers
+    err = set() if err is None else err
+    try:
+        r, w, e = select.select(readers, writers, err, timeout)
+        _seen = set()  # fold errored fds into readable, order kept, no dups
+        r = [f for f in r + e if f not in _seen and not _seen.add(f)]
+        return r, w, 0
+    except (select.error, socket.error) as exc:
+        if get_errno(exc) == errno.EINTR:
+            return [], [], 1  # interrupted: third item tells caller to retry
+        elif get_errno(exc) in SELECT_BAD_FD:
+            for fd in readers | writers | err:
+                try:
+                    select.select([fd], [], [], 0)
+                except (select.error, socket.error) as exc:
+                    if get_errno(exc) not in SELECT_BAD_FD:
+                        raise
+                    readers.discard(fd)
+                    writers.discard(fd)
+                    err.discard(fd)
+            return [], [], 1
+        raise  # any other errno is unexpected; do not fall through to None
+
+
+class promise(object):
+
+    def __init__(self, fun, *partial_args, **partial_kwargs):
+        self.fun = fun
+        self.args = partial_args
+        self.kwargs = partial_kwargs
+        self.ready = False
+
+    def __call__(self, *args, **kwargs):
+        try:
+            return self.fun(*tuple(self.args) + tuple(args),
+                            **dict(self.kwargs, **kwargs))
+        finally:
+            self.ready = True
+
+
+class ResultHandler(_pool.ResultHandler):
+
+    def on_stop_not_started(self):
+        cache = self.cache
+        check_timeouts = self.check_timeouts
+        fileno_to_proc = self.fileno_to_proc
+        on_state_change = self.on_state_change
+        join_exited_workers = self.join_exited_workers
+
+        outqueues = set(fileno_to_proc)
+        while cache and outqueues and self._state != TERMINATE:
+            if check_timeouts is not None:
+                check_timeouts()
+            for fd in list(outqueues):  # snapshot: body discards dead fds
+                proc = fileno_to_proc[fd]
+                reader = proc.outq._reader
+                try:
+                    if reader.poll(0):
+                        task = reader.recv()
+                    else:
+                        task = None
+                        sleep(0.5)
+                except (IOError, EOFError):
+                    outqueues.discard(fd)  # pipe is gone, stop polling it
+                    continue
+                else:
+                    if task is None:
+                        debug('result handler ignoring extra sentinel')
+                        continue
+                    on_state_change(task)
+                try:
+                    join_exited_workers(shutdown=True)
+                except WorkersJoined:
+                    debug('result handler: all workers terminated')
+                    return
+
+
+class AsynPool(_pool.Pool):
+    ResultHandler = ResultHandler
+
+    def __init__(self, processes=None, *args, **kwargs):
+        processes = self.cpu_count() if processes is None else processes
+        self._queuepairs = dict((self.create_process_queuepair(), None)
+                                for _ in range(processes))
+        super(AsynPool, self).__init__(processes, *args, **kwargs)
+
+    def get_process_queuepair(self):
+        return next(pair for pair, owner in items(self._queuepairs)
+                    if owner is None)
+
+    def create_process_queuepair(self):
+        return _SimpleQueue(), _SimpleQueue()
+
+    def _process_cleanup_queuepair(self, proc):
+        try:
+            self._queuepairs[self._find_worker_queuepair(proc)] = None
+        except (KeyError, ValueError):
+            pass
+
+    @staticmethod
+    def _stop_task_handler(task_handler):
+        for worker in task_handler.pool:
+            # send sentinels
+            worker.inq.put(None)
+
+    def _process_register_queuepair(self, proc, pair):
+        self._queuepairs[pair] = proc
+
+    def _find_worker_queuepair(self, proc):
+        for pair, owner in items(self._queuepairs):
+            if owner == proc:
+                return pair
+        raise ValueError(proc)
+
+    def _setup_queues(self):
+        self._inqueue = self._outqueue = \
+            self._quick_put = self._quick_get = self._poll_result = None
+
+    def on_partial_read(self, job, proc):
+        resq = proc.outq._reader
+        # empty result queue buffer
+        while resq.poll(0):
+            self.handle_result_event(resq.fileno())
+
+        # job was not acked, so find another worker to send it to.
+        if not job._accepted:
+            self._put_back(job)
+
+        # worker terminated by signal:
+        # we cannot reuse the sockets again, because we don't know if
+        # the process wrote/read anything from them, and if so we cannot
+        # restore the message boundaries.
+        if proc.exitcode < 0:
+            for conn in (proc.inq, proc.outq):
+                for sock in (conn._reader, conn._writer):
+                    if not sock.closed:
+                        os.close(sock.fileno())
+            self._queuepairs[(proc.inq, proc.outq)] = \
+                self._queuepairs[self.create_process_queuepair()] = None
+
+    @classmethod
+    def _set_result_sentinel(cls, _outqueue, workers):
+        for worker in workers:
+            worker.outq.put(None)
+
+    @classmethod
+    def _help_stuff_finish(cls, _inqueue, _taskhandler, _size, fileno_to_inq):
+        # task_handler may be blocked trying to put items on inqueue
+        debug(
+            'removing tasks from inqueue until task handler finished',
+        )
+        inqueues = set(fileno_to_inq)
+        while inqueues:
+            readable, _, again = _select(inqueues, timeout=0.5)
+            if again:
+                continue
+            if not readable:
+                break
+            for fd in readable:
+                fileno_to_inq[fd]._reader.recv()
+            sleep(0)
+
+
 class TaskPool(BasePool):
     """Multiprocessing Pool implementation."""
-    Pool = Pool
+    Pool = _pool.Pool
 
     uses_semaphore = True
 
@@ -74,15 +266,33 @@ class TaskPool(BasePool):
         Will pre-fork all workers so they're ready to accept tasks.
 
         """
+        if self.options.get('maxtasksperchild'):
+            try:
+                import _billiard  # noqa
+                _billiard.Connection.send_offset
+            except (ImportError, AttributeError):
+                # billiard C extension not installed
+                warning(MAXTASKS_NO_BILLIARD)
+
         forking_enable(self.forking_enable)
-        P = self._pool = self.Pool(processes=self.limit,
-                                   initializer=process_initializer,
-                                   **self.options)
+        Pool = self.Pool if self.options.get('threads', True) else AsynPool
+        P = self._pool = Pool(processes=self.limit,
+                              initializer=process_initializer,
+                              **self.options)
         self.on_apply = P.apply_async
         self.on_soft_timeout = P._timeout_handler.on_soft_timeout
         self.on_hard_timeout = P._timeout_handler.on_hard_timeout
         self.maintain_pool = P.maintain_pool
+        self.terminate_job = self._pool.terminate_job
+        self.grow = self._pool.grow
+        self.shrink = self._pool.shrink
+        self.restart = self._pool.restart
         self.maybe_handle_result = P._result_handler.handle_event
+        self.outbound_buffer = deque()
+        self.handle_result_event = P.handle_result_event
+        self._all_inqueues = set(p.inqW_fd for p in P._pool)
+        self._active_writes = set()
+        self._active_writers = set()
 
     def did_start_ok(self):
         return self._pool.did_start_ok()
@@ -104,18 +314,6 @@ class TaskPool(BasePool):
         if self._pool is not None and self._pool._state == RUN:
             self._pool.close()
 
-    def terminate_job(self, pid, signal=None):
-        return self._pool.terminate_job(pid, signal)
-
-    def grow(self, n=1):
-        return self._pool.grow(n)
-
-    def shrink(self, n=1):
-        return self._pool.shrink(n)
-
-    def restart(self):
-        self._pool.restart()
-
     def _get_info(self):
         return {'max-concurrency': self.limit,
                 'processes': [p.pid for p in self._pool._pool],
@@ -123,26 +321,173 @@ class TaskPool(BasePool):
                 'put-guarded-by-semaphore': self.putlocks,
                 'timeouts': (self._pool.soft_timeout, self._pool.timeout)}
 
-    def init_callbacks(self, **kwargs):
-        for k, v in items(kwargs):
-            setattr(self._pool, k, v)
+    def on_poll_init(self, w, hub,
+                     now=time, protocol=HIGHEST_PROTOCOL, pack=struct.pack,
+                     dumps=_pickle.dumps):
+        pool = self._pool
+        apply_after = hub.timer.apply_after
+        apply_at = hub.timer.apply_at
+        maintain_pool = self.maintain_pool
+        on_soft_timeout = self.on_soft_timeout
+        on_hard_timeout = self.on_hard_timeout
+        outbound = self.outbound_buffer
+        pop_message = outbound.popleft
+        put_message = outbound.append
+        fileno_to_inq = pool._fileno_to_inq
+        fileno_to_outq = pool._fileno_to_outq
+        all_inqueues = self._all_inqueues
+        active_writes = self._active_writes
+        add_coro = hub.add_coro
+        diff = all_inqueues.difference
+        hub_add, hub_remove = hub.add, hub.remove
+        mark_write_fd_as_active = active_writes.add
+        mark_write_gen_as_active = self._active_writers.add
+        write_generator_gone = self._active_writers.discard
+        get_job = pool._cache.__getitem__
+        pool._put_back = put_message
+
+        # did_start_ok will verify that pool processes were able to start,
+        # but this will only work the first time we start, as
+        # maxtasksperchild will mess up metrics.
+        if not w.consumer.restart_count and not pool.did_start_ok():
+            raise WorkerLostError('Could not start worker processes')
+
+        hub_add(pool.process_sentinels, self.maintain_pool, READ | ERR)
+        hub_add(fileno_to_outq, self.handle_result_event, READ | ERR)
+        for handler, interval in items(self.timers):
+            hub.timer.apply_interval(interval * 1000.0, handler)
+
+        # need to handle pool results before every task
+        # since multiple tasks can be received in a single poll()
+        # XXX do we need this now?!?
+        # hub.on_task.append(pool.maybe_handle_result)
+
+        def on_timeout_set(R, soft, hard):
+
+            def _on_soft_timeout():
+                if hard:
+                    R._tref = apply_at(now() + (hard - soft),
+                                       on_hard_timeout, (R, ))
+                on_soft_timeout(R)
+            if soft:
+                R._tref = apply_after(soft * 1000.0, _on_soft_timeout)
+            elif hard:
+                R._tref = apply_after(hard * 1000.0,
+                                      on_hard_timeout, (R, ))
+        self.on_timeout_set = on_timeout_set
+
+        def on_timeout_cancel(result):
+            try:
+                result._tref.cancel()
+                delattr(result, '_tref')
+            except AttributeError:
+                pass
+        self.on_timeout_cancel = on_timeout_cancel
+
+        def on_process_up(proc):
+            pool._all_inqueues.add(proc.inqW_fd)
+            hub_add(proc.sentinel, maintain_pool, READ | ERR)
+            hub_add(proc.outqR_fd, pool.handle_result_event, READ | ERR)
+        self.on_process_up = on_process_up
+
+        def on_process_down(proc):
+            pool._all_inqueues.discard(proc.inqW_fd)
+            hub_remove(proc.sentinel)
+            hub_remove(proc.outqR_fd)
+        self.on_process_down = on_process_down
+
+        def _write_to(fd, job, callback=None):
+            header, body, body_size = job._payload
+            try:
+                try:
+                    proc = fileno_to_inq[fd]
+                except KeyError:
+                    put_message(job)  # fd unknown: requeue the job
+                    return  # not `raise StopIteration` (PEP 479 RuntimeError)
+                send_offset = proc.inq._writer.send_offset
+                # job result keeps track of what process the job is sent to.
+                job._write_to = proc
+
+                Hw = Bw = 0
+                while Hw < 4:
+                    try:
+                        Hw += send_offset(header, Hw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        yield
+                while Bw < body_size:
+                    try:
+                        Bw += send_offset(body, Bw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        yield
+            finally:
+                if callback:
+                    callback()
+                active_writes.discard(fd)
+
+        def schedule_writes(ready_fd, events):
+            try:
+                job = pop_message()
+            except IndexError:
+                for inqfd in diff(active_writes):
+                    hub_remove(inqfd)
+            else:
+                callback = promise(write_generator_gone)
+                cor = _write_to(ready_fd, job, callback=callback)
+                mark_write_gen_as_active(cor)
+                mark_write_fd_as_active(ready_fd)
+                callback.args = (cor, )  # tricky as we need to pass ref
+                add_coro((ready_fd, ), cor, WRITE)
+
+        def on_poll_start(hub):
+            if outbound:
+                hub_add(diff(active_writes), schedule_writes, hub.WRITE)
+        self.on_poll_start = on_poll_start
+
+        def quick_put(tup):
+            body = dumps(tup, protocol=protocol)
+            body_size = len(body)
+            header = pack('>I', body_size)
+            # index 0 is the job ID.
+            job = get_job(tup[0])
+            job._payload = header, buffer(body), body_size
+            put_message(job)
+        self._pool._quick_put = quick_put
 
     def handle_timeouts(self):
         if self._pool._timeout_handler:
             self._pool._timeout_handler.handle_event()
 
+    def flush(self):
+        if self.outbound_buffer:
+            self.outbound_buffer.clear()
+        try:
+            # flush outgoing buffers
+            intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True)
+            while self._active_writers:
+                writers = list(self._active_writers)
+                for gen in writers:
+                    if gen.gi_frame.f_lasti != -1:  # generator started?
+                        try:
+                            next(gen)
+                        except StopIteration:
+                            self._active_writers.discard(gen)
+                # workers may have exited in the meantime.
+                self.maintain_pool()
+                sleep(next(intervals))  # don't busyloop
+        finally:
+            self.outbound_buffer.clear()
+            self._active_writers.clear()
+
     @property
     def num_processes(self):
         return self._pool._processes
 
-    @property
-    def readers(self):
-        return self._pool.readers
-
-    @property
-    def writers(self):
-        return self._pool.writers
-
     @property
     def timers(self):
         return {self.maintain_pool: 5.0}

+ 3 - 2
celery/datastructures.py

@@ -528,7 +528,6 @@ class LimitedSet(object):
     :keyword expires: Time in seconds, before a membership expires.
 
     """
-    __slots__ = ('maxlen', 'expires', '_data', '__len__', '_heap')
 
     def __init__(self, maxlen=None, expires=None, data=None, heap=None):
         self.maxlen = maxlen
@@ -537,7 +536,9 @@ class LimitedSet(object):
         self._heap = [] if heap is None else heap
         self.__len__ = self._data.__len__
         self.__contains__ = self._data.__contains__
-        self.__iter__ = self._data.__iter__
+
+    def __iter__(self):
+        return iter(self._data)
 
     def add(self, value):
         """Add a new member."""

+ 4 - 4
celery/utils/functional.py

@@ -106,14 +106,14 @@ class LRUCache(UserDict):
         self.mutex = threading.RLock()
 
 
-def is_list(l):
+def is_list(l, scalars=(dict, string_t)):
     """Returns true if object is list-like, but not a dict or string."""
-    return hasattr(l, '__iter__') and not isinstance(l, (dict, string_t))
+    return hasattr(l, '__iter__') and not isinstance(l, scalars or ())
 
 
-def maybe_list(l):
+def maybe_list(l, scalars=(dict, string_t)):
     """Returns list of one element if ``l`` is a scalar."""
-    return l if l is None or is_list(l) else [l]
+    return l if l is None or is_list(l, scalars) else [l]
 
 
 def memoize(maxsize=None, Cache=LRUCache):

+ 2 - 56
celery/worker/components.py

@@ -9,14 +9,11 @@
 from __future__ import absolute_import
 
 import atexit
-import time
 
 from functools import partial
 
-from billiard.exceptions import WorkerLostError
-
 from celery import bootsteps
-from celery.five import items, string_t
+from celery.five import string_t
 from celery.utils.log import worker_logger as logger
 from celery.utils.timer2 import Schedule
 
@@ -92,57 +89,6 @@ class Pool(bootsteps.StartStopStep):
         if w.pool:
             w.pool.terminate()
 
-    def on_poll_init(self, pool, w, hub):
-        apply_after = hub.timer.apply_after
-        apply_at = hub.timer.apply_at
-        on_soft_timeout = pool.on_soft_timeout
-        on_hard_timeout = pool.on_hard_timeout
-        maintain_pool = pool.maintain_pool
-        add_reader = hub.add_reader
-        remove = hub.remove
-        now = time.time
-
-        # did_start_ok will verify that pool processes were able to start,
-        # but this will only work the first time we start, as
-        # maxtasksperchild will mess up metrics.
-        if not w.consumer.restart_count and not pool.did_start_ok():
-            raise WorkerLostError('Could not start worker processes')
-
-        # need to handle pool results before every task
-        # since multiple tasks can be received in a single poll()
-        hub.on_task.append(pool.maybe_handle_result)
-
-        hub.update_readers(pool.readers)
-        for handler, interval in items(pool.timers):
-            hub.timer.apply_interval(interval * 1000.0, handler)
-
-        def on_timeout_set(R, soft, hard):
-
-            def _on_soft_timeout():
-                if hard:
-                    R._tref = apply_at(now() + (hard - soft),
-                                       on_hard_timeout, (R, ))
-                on_soft_timeout(R)
-            if soft:
-                R._tref = apply_after(soft * 1000.0, _on_soft_timeout)
-            elif hard:
-                R._tref = apply_after(hard * 1000.0,
-                                      on_hard_timeout, (R, ))
-
-        def on_timeout_cancel(result):
-            try:
-                result._tref.cancel()
-                delattr(result, '_tref')
-            except AttributeError:
-                pass
-
-        pool.init_callbacks(
-            on_process_up=lambda w: add_reader(w.sentinel, maintain_pool),
-            on_process_down=lambda w: remove(w.sentinel),
-            on_timeout_set=on_timeout_set,
-            on_timeout_cancel=on_timeout_cancel,
-        )
-
     def create(self, w, semaphore=None, max_restarts=None):
         threaded = not w.use_eventloop
         procs = w.min_concurrency
@@ -168,7 +114,7 @@ class Pool(bootsteps.StartStopStep):
             semaphore=semaphore,
         )
         if w.hub:
-            w.hub.on_init.append(partial(self.on_poll_init, pool, w))
+            w.hub.on_init.append(partial(pool.on_poll_init, w))
         return pool
 
     def info(self, w):

+ 7 - 0
celery/worker/consumer.py

@@ -226,6 +226,10 @@ class Consumer(object):
                     sleep(1)
                 if ns.state != CLOSE and self.connection:
                     warn(CONNECTION_RETRY, exc_info=True)
+                    try:
+                        self.connection.close()
+                    except Exception:
+                        pass
                     ns.restart(self)
 
     def shutdown(self):
@@ -270,8 +274,11 @@ class Consumer(object):
         # Clear internal queues to get rid of old messages.
         # They can't be acked anyway, as a delivery tag is specific
         # to the current channel.
+        if self.controller.semaphore:
+            self.controller.semaphore.clear()
         self.timer.clear()
         reserved_requests.clear()
+        self.pool.flush()
 
     def connect(self):
         """Establish the broker connection.

+ 118 - 18
celery/worker/hub.py

@@ -8,15 +8,42 @@
 """
 from __future__ import absolute_import
 
+from functools import wraps
+
 from kombu.utils import cached_property
 from kombu.utils import eventio
+from kombu.utils.eventio import READ, WRITE, ERR
 
 from celery.five import items, range
+from celery.platforms import fileno
+from celery.utils.functional import maybe_list
 from celery.utils.log import get_logger
 from celery.utils.timer2 import Schedule
 
 logger = get_logger(__name__)
-READ, WRITE, ERR = eventio.READ, eventio.WRITE, eventio.ERR
+
+
+def repr_flag(flag):
+    return '{0}{1}{2}'.format('R' if flag & READ else '',
+                              'W' if flag & WRITE else '',
+                              '!' if flag & ERR else '')
+
+
+def _rcb(obj):
+    if isinstance(obj, str):
+        return obj
+    return obj.__name__
+
+
+def coroutine(gen):
+
+    @wraps(gen)
+    def advances(*args, **kwargs):
+        it = gen(*args, **kwargs)
+        next(it)
+        return it
+
+    return advances
 
 
 class BoundedSemaphore(object):
@@ -32,7 +59,7 @@ class BoundedSemaphore(object):
         >>> x = BoundedSemaphore(2)
 
         >>> def callback(i):
-        ...     print('HELLO {0!r}'.format(i))
+        ...     say('HELLO {0!r}'.format(i))
 
         >>> x.acquire(callback, 1)
         HELLO 1
@@ -133,6 +160,33 @@ class Hub(object):
         self.on_init = []
         self.on_close = []
         self.on_task = []
+        self.coros = {}
+
+        self.trampoline = self._trampoline()
+
+    @coroutine
+    def _trampoline(self):
+        coros = self.coros
+        add = self.add_coro
+        remove_self = self.remove
+        pop = self.coros.pop
+
+        while 1:
+            fd, events = (yield)
+            remove_self(fd)
+            try:
+                gen = coros[fd]
+            except KeyError:
+                pass
+            else:
+                try:
+                    next(gen)
+                    add(fd, gen, WRITE)
+                except StopIteration:
+                    pop(fd, None)
+                except Exception:
+                    pop(fd, None)
+                    raise
 
     def start(self):
         """Called by Hub bootstep at worker startup."""
@@ -162,20 +216,36 @@ class Hub(object):
                     logger.error('Error in timer: %r', exc, exc_info=1)
         return min(max(delay or 0, min_delay), max_delay)
 
-    def add(self, fd, callback, flags):
+    def _add(self, fd, cb, flags):
         self.poller.register(fd, flags)
-        if not isinstance(fd, int):
-            fd = fd.fileno()
-        if flags & READ:
-            self.readers[fd] = callback
-        if flags & WRITE:
-            self.writers[fd] = callback
+        (self.readers if flags & READ else self.writers)[fileno(fd)] = cb
 
-    def add_reader(self, fd, callback):
-        return self.add(fd, callback, READ | ERR)
+    def add(self, fds, callback, flags):
+        for fd in maybe_list(fds, None):
+            try:
+                self._add(fd, callback, flags)
+            except ValueError:
+                self._discard(fd)
 
-    def add_writer(self, fd, callback):
-        return self.add(fd, callback, WRITE)
+    def remove(self, fd):
+        fd = fileno(fd)
+        self._unregister(fd)
+        self._discard(fd)
+
+    def add_coro(self, fds, coro, flags):
+        for fd in (fileno(f) for f in maybe_list(fds, None)):
+            self._add(fd, self.trampoline, flags)
+            self.coros[fd] = coro
+
+    def remove_coro(self, fds):
+        for fd in maybe_list(fds, None):
+            self.coros.pop(fileno(fd), None)
+
+    def add_reader(self, fds, callback):
+        return self.add(fds, callback, READ | ERR)
+
+    def add_writer(self, fds, callback):
+        return self.add(fds, callback, WRITE)
 
     def update_readers(self, readers):
         [self.add_reader(*x) for x in items(readers)]
@@ -189,11 +259,10 @@ class Hub(object):
         except (KeyError, OSError):
             pass
 
-    def remove(self, fd):
-        fileno = fd.fileno() if not isinstance(fd, int) else fd
-        self.readers.pop(fileno, None)
-        self.writers.pop(fileno, None)
-        self._unregister(fd)
+    def _discard(self, fd):
+        fd = fileno(fd)
+        self.readers.pop(fd, None)
+        self.writers.pop(fd, None)
 
     def __enter__(self):
         self.init()
@@ -208,6 +277,37 @@ class Hub(object):
             callback(self)
     __exit__ = close
 
+    def _repr_readers(self):
+        return ['{0}->{1}->{2!r}'.format(_rcb(cb), repr_flag(READ | ERR), fd)
+                for fd, cb in items(self.readers)]
+
+    def _repr_writers(self):
+        return ['{0}->{1}->{2!r}'.format(_rcb(cb), repr_flag(WRITE), fd)
+                for fd, cb in items(self.writers)]
+
+    def repr_active(self):
+        return ', '.join(self._repr_readers() + self._repr_writers())
+
+    def repr_events(self, events):
+        return ', '.join(
+            '{0}->{1}'.format(
+                _rcb(self._callback_for(fd, fl, '{0!r}(GONE)'.format(fd))),
+                repr_flag(fl)
+            )
+            for fd, fl in events
+        )
+
+    def _callback_for(self, fd, flag, *default):
+        try:
+            if flag & READ:
+                return self.readers[fileno(fd)]
+            if flag & WRITE:
+                return self.writers[fileno(fd)]
+        except KeyError:
+            if default:
+                return default[0]
+            raise
+
     @cached_property
     def scheduler(self):
         return iter(self.timer)

+ 18 - 13
celery/worker/loops.py

@@ -10,6 +10,7 @@ from __future__ import absolute_import
 import socket
 
 from time import sleep
+from types import GeneratorType as generator
 
 from kombu.utils.eventio import READ, WRITE, ERR
 
@@ -35,12 +36,14 @@ def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
         fire_timers = hub.fire_timers
         scheduled = hub.timer._queue
         hbtick = connection.heartbeat_check
-        on_poll_start = connection.transport.on_poll_start
-        on_poll_empty = connection.transport.on_poll_empty
+        conn_poll_start = connection.transport.on_poll_start
+        conn_poll_empty = connection.transport.on_poll_empty
+        pool_poll_start = obj.pool.on_poll_start
         drain_nowait = connection.drain_nowait
         on_task_callbacks = hub.on_task
         keep_draining = connection.transport.nb_keep_draining
         errors = connection.connection_errors
+        hub_add, hub_remove = hub.add, hub.remove
 
         if heartbeat and connection.supports_heartbeats:
             hub.timer.apply_interval(
@@ -82,7 +85,8 @@ def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
             if qos.prev != qos.value:
                 update_qos()
 
-            update_readers(on_poll_start())
+            update_readers(conn_poll_start())
+            pool_poll_start(hub)
             if readers or writers:
                 connection.more_to_read = True
                 while connection.more_to_read:
@@ -91,19 +95,20 @@ def asynloop(obj, connection, consumer, strategies, ns, hub, qos,
                     except ValueError:  # Issue 882
                         return
                     if not events:
-                        on_poll_empty()
+                        conn_poll_empty()
                     for fileno, event in events or ():
                         try:
                             if event & READ:
-                                readers[fileno](fileno, event)
-                            if event & WRITE:
-                                writers[fileno](fileno, event)
-                            if event & ERR:
-                                for handlermap in readers, writers:
-                                    try:
-                                        handlermap[fileno](fileno, event)
-                                    except KeyError:
-                                        pass
+                                cb = readers[fileno]
+                            elif event & WRITE:
+                                cb = writers[fileno]
+                            elif event & ERR:
+                                cb = (readers.get(fileno) or
+                                      writers.get(fileno))
+                            if isinstance(cb, generator):
+                                cb.send((fileno, event))
+                            else:
+                                cb(fileno, event)
                         except (KeyError, Empty):
                             continue
                         except socket.error: