
Event loop API is now closer to Tulip

Ask Solem, 11 years ago
parent commit 6895da4984
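
For context, Tulip (the reference event loop that later became asyncio) names its registration calls loop.add_reader, loop.call_later and loop.call_at; the hub calls introduced in this diff (hub.add_reader, hub.call_later, hub.call_at, hub.call_repeatedly) follow the same naming. A minimal asyncio sketch of that style, for illustration only; the hub's exact signatures and timer units may differ:

    # Illustrative only: standard-library asyncio (Tulip) calls that the
    # hub API now mirrors.  Unix-only, as add_reader needs a selector loop.
    import asyncio
    import socket

    loop = asyncio.new_event_loop()
    r, w = socket.socketpair()

    def on_readable():
        # cf. hub.add_reader(fd, handler) in the diff below
        print('received:', r.recv(64))
        loop.stop()

    loop.add_reader(r.fileno(), on_readable)
    # cf. hub.call_later(delay, handler, args); hub.call_repeatedly has no
    # direct asyncio equivalent (callbacks re-schedule themselves instead).
    loop.call_later(0.05, w.send, b'ping')
    loop.run_forever()
    loop.close()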

+ 1 - 1
celery/concurrency/base.py

@@ -70,7 +70,7 @@ class BasePool(object):
     def on_stop(self):
         pass
 
-    def on_poll_init(self, worker, hub):
+    def register_with_event_loop(self, worker, hub):
         pass
 
     def on_poll_start(self, hub):
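
For third-party pool implementations, the hook is now spelled register_with_event_loop rather than on_poll_init. A minimal sketch of a subclass adopting the new name; apart from the hook signature and the hub.add_reader/hub.call_repeatedly calls shown elsewhere in this commit, every name below is hypothetical:

    from celery.concurrency.base import BasePool

    class MyPool(BasePool):

        def register_with_event_loop(self, worker, hub):
            # hypothetical: watch this pool's own file descriptors
            for fd in getattr(self, 'watched_fds', ()):
                hub.add_reader(fd, self.on_readable)
            # periodic maintenance, mirroring AsynPool's timers
            hub.call_repeatedly(5.0, self.maintain_pool)

        def on_readable(self):
            # hypothetical read handler
            pass

        def maintain_pool(self):
            # hypothetical maintenance hook
            pass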

+ 643 - 649
celery/concurrency/processes.py

@@ -32,7 +32,6 @@ from weakref import WeakValueDictionary, ref
 from amqp.utils import promise
 from billiard import forking_enable
 from billiard import pool as _pool
-from billiard.exceptions import WorkerLostError
 from billiard.pool import (
     RUN, CLOSE, TERMINATE, ACK, NACK, EX_RECYCLE, WorkersJoined, CoroStop,
 )
@@ -297,6 +296,16 @@ class AsynPool(_pool.Pool):
 
         # denormalized set of all inqueues.
         self._all_inqueues = set()
+
+        # Set of fds being written to (busy)
+        self._active_writes = set()
+
+        # Set of active co-routines currently writing jobs.
+        self._active_writers = set()
+
+        # Holds jobs waiting to be written to child processes.
+        self.outbound_buffer = deque()
+
         super(AsynPool, self).__init__(processes, *args, **kwargs)
 
         for proc in self._pool:
@@ -305,730 +314,715 @@ class AsynPool(_pool.Pool):
             self._fileno_to_inq[proc.inqW_fd] = proc
             self._fileno_to_outq[proc.outqR_fd] = proc
             self._fileno_to_synq[proc.synqW_fd] = proc
+        self.on_soft_timeout = self._timeout_handler.on_soft_timeout
+        self.on_hard_timeout = self._timeout_handler.on_hard_timeout
 
-    def get_process_queues(self):
-        """Get queues for a new process.
-
-        Here we will find an unused slot, as there should always
-        be one available when we start a new process.
-        """
-        return next(q for q, owner in items(self._queues)
-                    if owner is None)
-
-    def on_grow(self, n):
-        """Grow the pool by ``n`` proceses."""
-        diff = max(self._processes - len(self._queues), 0)
-        if diff:
-            self._queues.update(
-                dict((self.create_process_queues(), None) for _ in range(diff))
-            )
-
-    def on_shrink(self, n):
-        """Shrink the pool by ``n`` processes."""
-        pass
+    def register_with_event_loop(self, hub):
+        """Registers the async pool with the current event loop."""
+        self._create_timelimit_handlers(hub)
+        self._create_process_handlers(hub)
+        self._create_write_handlers(hub)
 
-    def create_process_queues(self):
-        """Creates new in, out (and optionally syn) queues,
-        returned as a tuple."""
-        inq, outq, synq = _SimpleQueue(), _SimpleQueue(), None
-        inq._writer.setblocking(0)
-        if self.synack:
-            synq = _SimpleQueue()
-            synq._writer.setblocking(0)
-        return inq, outq, synq
+        # Maintain_pool is called whenever a process exits.
+        [hub.add_reader(fd, self.maintain_pool)
+         for fd in self.process_sentinels]
+        # Handle_result_event is called whenever one of the
+        # result queues is readable.
+        [hub.add_reader(fd, self.handle_result_event)
+         for fd in self._fileno_to_outq]
 
-    def on_process_alive(self, pid):
-        """Handler called when the WORKER_UP message is received
-        from a child process, which marks the process as ready
-        to receive work."""
-        try:
-            proc = next(w for w in self._pool if w.pid == pid)
-        except StopIteration:
-            # process already exited :(  this will be handled elsewhere.
-            return
-        self._fileno_to_inq[proc.inqW_fd] = proc
-        self._fileno_to_synq[proc.synqW_fd] = proc
-        self._all_inqueues.add(proc.inqW_fd)
+        # Timers include calling maintain_pool at a regular interval
+        # to be certain processes are restarted.
+        for handler, interval in items(self.timers):
+            hub.call_repeatedly(interval, handler)
 
-    def on_job_process_down(self, job, pid_gone):
-        """Handler called for each job when the process it was assigned to
-        exits."""
-        if job._write_to:
-            # job was partially written
-            self.on_partial_read(job, job._write_to)
-        elif job._scheduled_for:
-            # job was only scheduled to be written to this process,
-            # but no data was sent so put it back on the outbound_buffer.
-            self._put_back(job)
+        hub.on_tick.add(self.on_poll_start)
 
-    def on_job_process_lost(self, job, pid, exitcode):
-        """Handler called for each *started* job when the process it
-        was assigned to exited by mysterious means (error exitcodes and
-        signals)"""
-        self.mark_as_worker_lost(job, exitcode)
+    def _create_timelimit_handlers(self, hub, now=time):
+        """For async pool this sets up the handlers used
+        to implement time limits."""
+        call_later = hub.call_later
+        trefs = self._tref_for_id = WeakValueDictionary()
 
-    def _process_cleanup_queues(self, proc):
-        """Handler called to clean up a processes queues after process
-        exit."""
-        try:
-            self._queues[self._find_worker_queues(proc)] = None
-        except (KeyError, ValueError):
-            pass
+        def on_timeout_set(R, soft, hard):
+            if soft:
+                trefs[R._job] = call_later(
+                    soft * 1000.0,
+                    self._on_soft_timeout, (R._job, soft, hard, hub),
+                )
+            elif hard:
+                trefs[R._job] = call_later(
+                    hard * 1000.0,
+                    self._on_hard_timeout, (R._job, )
+                )
+        self.on_timeout_set = on_timeout_set
 
-    @staticmethod
-    def _stop_task_handler(task_handler):
-        """Called at shutdown to tell processes that we are shutting down."""
-        for proc in task_handler.pool:
-            proc.inq._writer.setblocking(1)
+        def _discard_tref(job):
             try:
-                proc.inq.put(None)
-            except OSError as exc:
-                if get_errno(exc) != errno.EBADF:
-                    raise
+                tref = trefs.pop(job)
+                tref.cancel()
+                del(tref)
+            except (KeyError, AttributeError):
+                pass  # out of scope
+        self._discard_tref = _discard_tref
 
-    def create_result_handler(self):
-        return super(AsynPool, self).create_result_handler(
-            fileno_to_outq=self._fileno_to_outq,
-            on_process_alive=self.on_process_alive,
-        )
+        def on_timeout_cancel(R):
+            _discard_tref(R._job)
+        self.on_timeout_cancel = on_timeout_cancel
 
-    def _process_register_queues(self, proc, queues):
-        """Marks new ownership for ``queues`` so that the fileno indices are
-        updated."""
-        assert queues in self._queues
-        b = len(self._queues)
-        self._queues[queues] = proc
-        assert b == len(self._queues)
+    def _on_soft_timeout(self, job, soft, hard, hub, now=time):
+        # only used by async pool.
+        if hard:
+            self._tref_for_id[job] = hub.call_at(
+                now() + (hard - soft),
+                self._on_hard_timeout, (job, ),
+            )
+        try:
+            result = self._cache[job]
+        except KeyError:
+            pass  # job ready
+        else:
+            self.on_soft_timeout(result)
+        finally:
+            if not hard:
+                # remove tref
+                self._discard_tref(job)
 
-    def _find_worker_queues(self, proc):
-        """Find the queues owned by ``proc``."""
+    def _on_hard_timeout(self, job):
+        # only used by async pool.
         try:
-            return next(q for q, owner in items(self._queues)
-                        if owner == proc)
-        except StopIteration:
-            raise ValueError(proc)
+            result = self._cache[job]
+        except KeyError:
+            pass  # job ready
+        else:
+            self.on_hard_timeout(result)
+        finally:
+            # remove tref
+            self._discard_tref(job)
 
-    def _setup_queues(self):
-        # this is only used by the original pool which uses a shared
-        # queue for all processes.
+    def _create_process_handlers(self, hub, READ=READ, ERR=ERR):
+        """For async pool this will create the handlers called
+        when a process is up/down, etc."""
+        add_reader, hub_remove = hub.add_reader, hub.remove
+        cache = self._cache
+        all_inqueues = self._all_inqueues
+        fileno_to_inq = self._fileno_to_inq
+        fileno_to_outq = self._fileno_to_outq
+        fileno_to_synq = self._fileno_to_synq
+        maintain_pool = self.maintain_pool
+        handle_result_event = self.handle_result_event
+        process_flush_queues = self.process_flush_queues
 
-        # these attributes make no sense for us, but we will still
-        # have to initialize them.
-        self._inqueue = self._outqueue = \
-            self._quick_put = self._quick_get = self._poll_result = None
+        def on_process_up(proc):
+            """Called when a WORKER_UP message is received from process."""
+            # If we got the same fd as a previous process then we will also
+            # receive jobs in the old buffer, so we need to reset the
+            # job._write_to and job._scheduled_for attributes used to recover
+            # message boundaries when processes exit.
+            infd = proc.inqW_fd
+            for job in values(cache):
+                if job._write_to and job._write_to.inqW_fd == infd:
+                    job._write_to = proc
+                if job._scheduled_for and job._scheduled_for.inqW_fd == infd:
+                    job._scheduled_for = proc
+            fileno_to_outq[proc.outqR_fd] = proc
+            # maintain_pool is called whenever a process exits.
+            add_reader(proc.sentinel, maintain_pool)
+            # handle_result_event is called when the processes outqueue is
+            # readable.
+            add_reader(proc.outqR_fd, handle_result_event)
+        self.on_process_up = on_process_up
 
-    def process_flush_queues(self, proc):
-        """Flushes all queues, including the outbound buffer, so that
-        all tasks that have not been started will be discarded.
+        def on_process_down(proc):
+            """Called when a worker process exits."""
+            process_flush_queues(proc)
+            fileno_to_outq.pop(proc.outqR_fd, None)
+            fileno_to_inq.pop(proc.inqW_fd, None)
+            fileno_to_synq.pop(proc.synqW_fd, None)
+            all_inqueues.discard(proc.inqW_fd)
+            hub_remove(proc.sentinel)
+            hub_remove(proc.outqR_fd)
+        self.on_process_down = on_process_down
 
-        In Celery this is called whenever the transport connection is lost
-        (consumer restart).
+    def _create_write_handlers(self, hub,
+                               pack=struct.pack, dumps=_pickle.dumps,
+                               protocol=HIGHEST_PROTOCOL):
+        """For async pool this creates the handlers used to write data to
+        child processes."""
+        fileno_to_inq = self._fileno_to_inq
+        fileno_to_synq = self._fileno_to_synq
+        outbound = self.outbound_buffer
+        pop_message = outbound.popleft
+        put_message = outbound.append
+        all_inqueues = self._all_inqueues
+        active_writes = self._active_writes
+        diff = all_inqueues.difference
+        add_reader, add_writer = hub.add_reader, hub.add_writer
+        hub_add, hub_remove = hub.add, hub.remove
+        mark_write_fd_as_active = active_writes.add
+        mark_write_gen_as_active = self._active_writers.add
+        write_generator_done = self._active_writers.discard
+        get_job = self._cache.__getitem__
+        # puts back at the front of the queue (so it will be sent next)
+        self._put_back = outbound.appendleft
+        precalc = {ACK: self._create_payload(ACK, (0, )),
+                   NACK: self._create_payload(NACK, (0, ))}
 
-        """
-        resq = proc.outq._reader
-        on_state_change = self._result_handler.on_state_change
-        while not resq.closed and resq.poll(0) and self._state != TERMINATE:
-            try:
-                task = resq.recv()
-            except (IOError, EOFError) as exc:
-                debug('got %r while flushing process %r',
-                      exc, proc, exc_info=1)
-                break
-            else:
-                if task is not None:
-                    on_state_change(task)
-                else:
-                    debug('got sentinel while flushing process %r', proc)
+        def on_poll_start():
+            # called for every event loop iteration, and if there
+            # are messages pending this will schedule writing one message
+            # by registering the 'schedule_writes' function for all currently
+            # inactive inqueues (not already being written to)
 
-    def on_partial_read(self, job, proc):
-        """Called when a job was only partially written to a child process
-        and it exited."""
-        # worker terminated by signal:
-        # we cannot reuse the sockets again, because we don't know if
-        # the process wrote/read anything from them, and if so we cannot
-        # restore the message boundaries.
-        if proc.exitcode != EX_RECYCLE:
-            if not job._accepted:
-                # job was not acked, so find another worker to send it to.
-                self._put_back(job)
-            writer = getattr(job, '_writer')
-            writer = writer and writer() or None
-            if writer:
-                self._active_writers.discard(writer)
+            # consolidate means the event loop will merge them
+            # and call the callback once with the list of writable fds as
+            # argument.  Using this means we minimize the risk of having
+            # the same fd receive every task if the pipe read buffer is not
+            # full.
+            if outbound:
+                [hub_add(fd, None, WRITE | ERR, consolidate=True)
+                 for fd in diff(active_writes)]
+        self.on_poll_start = on_poll_start
 
-            # Replace queues to avoid reuse
-            before = len(self._queues)
-            try:
-                queues = self._find_worker_queues(proc)
-                if self.destroy_queues(queues):
-                    self._queues[self.create_process_queues()] = None
-            except ValueError:
-                # Not in queue map, make sure sockets are closed.
-                self.destroy_queues((proc.inq, proc.outq, proc.synq))
-            assert len(self._queues) == before
+        def on_inqueue_close(fd):
+            # Makes sure the fd is removed from tracking when
+            # the connection is closed, this is essential as fds may be reused.
+            active_writes.discard(fd)
+            all_inqueues.discard(fd)
+        self.on_inqueue_close = on_inqueue_close
 
-    def destroy_queues(self, queues):
-        """Destroy queues that can no longer be used, so that they
-        be replaced by new sockets."""
-        removed = 1
-        try:
-            self._queues.pop(queues)
-        except KeyError:
-            removed = 0
-        try:
-            self.on_inqueue_close(queues[0]._writer.fileno())
-        except IOError:
-            pass
-        for queue in queues:
-            if queue:
-                for sock in (queue._reader, queue._writer):
-                    if not sock.closed:
+        def schedule_writes(ready_fds, shuffle=random.shuffle):
+            # Schedule write operation to ready file descriptor.
+            # The file descriptor is writeable, but that does not
+            # mean the process is currently reading from the socket.
+            # The socket is buffered so writeable simply means that
+            # the buffer can accept at least 1 byte of data.
+            shuffle(ready_fds)
+            for ready_fd in ready_fds:
+                if ready_fd in active_writes:
+                    # already writing to this fd
+                    continue
+                try:
+                    job = pop_message()
+                except IndexError:
+                    # no more messages, remove all inactive fds from the hub.
+                    # this is important since the fds are always writeable
+                    # as long as there's 1 byte left in the buffer, and so
+                    # this may create a spinloop where the event loop
+                    # always wakes up.
+                    for inqfd in diff(active_writes):
+                        hub_remove(inqfd)
+                    break
+
+                else:
+                    if not job._accepted:  # job not accepted by another worker
                         try:
-                            sock.close()
-                        except (IOError, OSError):
+                            # keep track of what process the write operation
+                            # was scheduled for.
+                            proc = job._scheduled_for = fileno_to_inq[ready_fd]
+                        except KeyError:
+                            # write was scheduled for this fd but the process
+                            # has since exited and the message must be sent to
+                            # another process.
+                            put_message(job)
+                            continue
+                        cor = _write_job(proc, ready_fd, job)
+                        job._writer = ref(cor)
+                        mark_write_gen_as_active(cor)
+                        mark_write_fd_as_active(ready_fd)
+
+                        # Try to write immediately, in case there's an error.
+                        try:
+                            next(cor)
+                            add_writer(ready_fd, cor)
+                        except StopIteration:
                             pass
-        return removed
+        hub.consolidate_callback = schedule_writes
 
-    def _create_payload(self, type_, args,
-                        dumps=_pickle.dumps, pack=struct.pack,
-                        protocol=HIGHEST_PROTOCOL):
-        body = dumps((type_, args), protocol=protocol)
-        size = len(body)
-        header = pack('>I', size)
-        return header, body, size
+        def send_job(tup):
+            # Schedule writing job request for when one of the process
+            # inqueues is writable.
+            body = dumps(tup, protocol=protocol)
+            body_size = len(body)
+            header = pack('>I', body_size)
+            # index 1,0 is the job ID.
+            job = get_job(tup[1][0])
+            job._payload = header, body, body_size
+            put_message(job)
+        self._quick_put = send_job
 
-    @classmethod
-    def _set_result_sentinel(cls, _outqueue, _pool):
-        # unused
-        pass
+        write_stats = self.write_stats = Counter()
 
-    def _help_stuff_finish_args(self):
-        # Pool._help_stuff_finish is a classmethod so we have to use this
-        # trick to modify the arguments passed to it.
-        return (self._pool, )
+        def on_not_recovering(proc):
+            # XXX Theoretically a possibility, but not seen in practice yet.
+            raise Exception(
+                'Process writable but cannot write. Contact support!')
 
-    @classmethod
-    def _help_stuff_finish(cls, pool):
-        debug(
-            'removing tasks from inqueue until task handler finished',
-        )
-        fileno_to_proc = {}
-        inqR = set()
-        for w in pool:
+        def _write_job(proc, fd, job):
+            # writes job to the worker process.
+            # Operation must complete if more than one byte of data
+            # was written.  If the broker connection is lost
+            # and no data was written the operation shall be cancelled.
+            header, body, body_size = job._payload
+            errors = 0
             try:
-                fd = w.inq._reader.fileno()
-                inqR.add(fd)
-                fileno_to_proc[fd] = w
-            except IOError:
-                pass
-        while inqR:
-            readable, _, again = _select(inqR, timeout=0.5)
-            if again:
-                continue
-            if not readable:
-                break
-            for fd in readable:
-                fileno_to_proc[fd].inq._reader.recv()
-            sleep(0)
-
-
-class TaskPool(BasePool):
-    """Multiprocessing Pool implementation."""
-    Pool = AsynPool
-    BlockingPool = _pool.Pool
+                # job result keeps track of what process the job is sent to.
+                job._write_to = proc
+                send = proc.send_job_offset
 
-    uses_semaphore = True
-    write_stats = None
+                Hw = Bw = 0
+                # write header
+                while Hw < 4:
+                    try:
+                        Hw += send(header, Hw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        errors += 1
+                        if errors > 100:
+                            on_not_recovering(proc)
+                            raise StopIteration()
+                        yield
+                    errors = 0
 
-    def on_start(self):
-        """Run the task pool.
+                # write body
+                while Bw < body_size:
+                    try:
+                        Bw += send(body, Bw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        errors += 1
+                        if errors > 100:
+                            on_not_recovering(proc)
+                            raise StopIteration()
+                        yield
+                    errors = 0
+            finally:
+                write_stats[proc.index] += 1
+                # message written, so this fd is now available
+                active_writes.discard(fd)
+                write_generator_done(job._writer())  # is a weakref
 
-        Will pre-fork all workers so they're ready to accept tasks.
+        def send_ack(response, pid, job, fd, WRITE=WRITE, ERR=ERR):
+            # Only used when synack is enabled.
+            # Schedule writing ack response for when the fd is writeable.
+            msg = Ack(job, fd, precalc[response])
+            callback = promise(write_generator_done)
+            cor = _write_ack(fd, msg, callback=callback)
+            mark_write_gen_as_active(cor)
+            mark_write_fd_as_active(fd)
+            callback.args = (cor, )
+            add_writer(fd, cor)
+        self.send_ack = send_ack
 
-        """
-        if self.options.get('maxtasksperchild'):
+        def _write_ack(fd, ack, callback=None):
+            # writes ack back to the worker if synack enabled.
+            # this operation *MUST* complete, otherwise
+            # the worker process will hang waiting for the ack.
+            header, body, body_size = ack[2]
             try:
-                from billiard.connection import Connection
-                Connection.send_offset
-            except (ImportError, AttributeError):
-                # billiard C extension not installed
-                warning(MAXTASKS_NO_BILLIARD)
-
-        forking_enable(self.forking_enable)
-        Pool = (self.BlockingPool if self.options.get('threads', True)
-                else self.Pool)
-        P = self._pool = Pool(processes=self.limit,
-                              initializer=process_initializer,
-                              synack=False,
-                              **self.options)
-
-        # Create proxy methods
-        self.on_apply = P.apply_async
-        self.on_soft_timeout = P._timeout_handler.on_soft_timeout
-        self.on_hard_timeout = P._timeout_handler.on_hard_timeout
-        self.maintain_pool = P.maintain_pool
-        self.terminate_job = P.terminate_job
-        self.grow = P.grow
-        self.shrink = P.shrink
-        self.restart = P.restart
-        self.maybe_handle_result = P._result_handler.handle_event
-        self.handle_result_event = P.handle_result_event
-
-        # Holds jobs waiting to be written to child processes.
-        self.outbound_buffer = deque()
-
-        # Set of fds being written to (busy)
-        self._active_writes = set()
-
-        # Set of active co-routines currently writing jobs.
-        self._active_writers = set()
-
-    def did_start_ok(self):
-        return self._pool.did_start_ok()
-
-    def on_stop(self):
-        """Gracefully stop the pool."""
-        if self._pool is not None and self._pool._state in (RUN, CLOSE):
-            self._pool.close()
-            self._pool.join()
-            self._pool = None
-
-    def on_terminate(self):
-        """Force terminate the pool."""
-        if self._pool is not None:
-            self._pool.terminate()
-            self._pool = None
-
-    def on_close(self):
-        if self._pool is not None and self._pool._state == RUN:
-            self._pool.close()
+                try:
+                    proc = fileno_to_synq[fd]
+                except KeyError:
+                    # process died, we can safely discard the ack at this
+                    # point.
+                    raise StopIteration()
+                send = proc.send_syn_offset
 
-    def _get_info(self):
-        return {
-            'max-concurrency': self.limit,
-            'processes': [p.pid for p in self._pool._pool],
-            'max-tasks-per-child': self._pool._maxtasksperchild or 'N/A',
-            'put-guarded-by-semaphore': self.putlocks,
-            'timeouts': (self._pool.soft_timeout or 0,
-                         self._pool.timeout or 0),
-            'writes': self.human_write_stats(),
-        }
+                Hw = Bw = 0
+                # write header
+                while Hw < 4:
+                    try:
+                        Hw += send(header, Hw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        yield
 
-    def human_write_stats(self):
-        if self.write_stats is None:
-            return 'N/A'
-        vals = list(values(self.write_stats))
-        total = sum(vals)
+                # write body
+                while Bw < body_size:
+                    try:
+                        Bw += send(body, Bw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        yield
+            finally:
+                if callback:
+                    callback()
+                # message written, so this fd is now available
+                active_writes.discard(fd)
 
-        def per(v, total):
-            return '{0:.2f}%'.format((float(v) / total) * 100.0 if v else 0)
+    def flush(self):
+        if self._state == TERMINATE:
+            return
+        # cancel all tasks that have not been accepted so that NACK is sent.
+        for job in values(self._cache):
+            if not job._accepted:
+                job._cancel()
 
-        return {
-            'total': total,
-            'avg': per(total / len(self.write_stats) if total else 0, total),
-            'all': ', '.join(per(v, total) for v in vals),
-            'raw': ', '.join(map(str, vals)),
-        }
+        # clear the outgoing buffer as the tasks will be redelivered by
+        # the broker anyway.
+        if self.outbound_buffer:
+            self.outbound_buffer.clear()
+        try:
+            # ...but we must continue writing the payloads we already started
+            # to keep message boundaries.
+            # The messages may be NACK'ed later if synack is enabled.
+            if self._state == RUN:
+                # flush outgoing buffers
+                intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True)
+                while self._active_writers:
+                    writers = list(self._active_writers)
+                    for gen in writers:
+                        if (gen.__name__ == '_write_job' and
+                                gen_not_started(gen)):
+                            # has not started writing the job so can
+                            # discard the task, but we must also remove
+                            # it from the Pool._cache.
+                            job_to_discard = None
+                            for job in values(self._cache):
+                                if job._writer() is gen:  # _writer is saferef
+                                    # removes from Pool._cache
+                                    job_to_discard = job
+                                    break
+                            if job_to_discard:
+                                job_to_discard.discard()
+                            self._active_writers.discard(gen)
+                        else:
+                            try:
+                                next(gen)
+                            except StopIteration:
+                                self._active_writers.discard(gen)
+                    # workers may have exited in the meantime.
+                    self.maintain_pool()
+                    sleep(next(intervals))  # don't busyloop
+        finally:
+            self.outbound_buffer.clear()
+            self._active_writers.clear()
 
-    def on_poll_init(self, w, hub):
-        """Initialize async pool using the eventloop hub."""
-        pool = self._pool
-        pool._active_writers = self._active_writers
+    def get_process_queues(self):
+        """Get queues for a new process.
 
-        self._create_timelimit_handlers(hub)
-        self._create_process_handlers(hub)
-        self._create_write_handlers(hub)
+        Here we will find an unused slot, as there should always
+        be one available when we start a new process.
+        """
+        return next(q for q, owner in items(self._queues)
+                    if owner is None)
 
-        # did_start_ok will verify that pool processes were able to start,
-        # but this will only work the first time we start, as
-        # maxtasksperchild will mess up metrics.
-        if not w.consumer.restart_count and not pool.did_start_ok():
-            raise WorkerLostError('Could not start worker processes')
+    def on_grow(self, n):
+        """Grow the pool by ``n`` proceses."""
+        diff = max(self._processes - len(self._queues), 0)
+        if diff:
+            self._queues.update(
+                dict((self.create_process_queues(), None) for _ in range(diff))
+            )
 
-        # Maintain_pool is called whenever a process exits.
-        hub.add(pool.process_sentinels, self.maintain_pool, READ | ERR)
-        # Handle_result_event is called whenever one of the
-        # result queues is readable.
-        hub.add(pool._fileno_to_outq, self.handle_result_event, READ | ERR)
+    def on_shrink(self, n):
+        """Shrink the pool by ``n`` processes."""
+        pass
 
-        # Timers include calling maintain_pool at a regular interval
-        # to be certain processes are restarted.
-        for handler, interval in items(self.timers):
-            hub.timer.apply_interval(interval * 1000.0, handler)
+    def create_process_queues(self):
+        """Creates new in, out (and optionally syn) queues,
+        returned as a tuple."""
+        inq, outq, synq = _SimpleQueue(), _SimpleQueue(), None
+        inq._writer.setblocking(0)
+        if self.synack:
+            synq = _SimpleQueue()
+            synq._writer.setblocking(0)
+        return inq, outq, synq
 
-    def _create_timelimit_handlers(self, hub, now=time):
-        """For async pool this sets up the handlers used
-        to implement time limits."""
-        apply_after = hub.timer.apply_after
-        trefs = self._tref_for_id = WeakValueDictionary()
+    def on_process_alive(self, pid):
+        """Handler called when the WORKER_UP message is received
+        from a child process, which marks the process as ready
+        to receive work."""
+        try:
+            proc = next(w for w in self._pool if w.pid == pid)
+        except StopIteration:
+            # process already exited :(  this will be handled elsewhere.
+            return
+        self._fileno_to_inq[proc.inqW_fd] = proc
+        self._fileno_to_synq[proc.synqW_fd] = proc
+        self._all_inqueues.add(proc.inqW_fd)
 
-        def on_timeout_set(R, soft, hard):
-            if soft:
-                trefs[R._job] = apply_after(
-                    soft * 1000.0,
-                    self._on_soft_timeout, (R._job, soft, hard, hub),
-                )
-            elif hard:
-                trefs[R._job] = apply_after(
-                    hard * 1000.0,
-                    self._on_hard_timeout, (R._job, )
-                )
-        self._pool.on_timeout_set = on_timeout_set
+    def on_job_process_down(self, job, pid_gone):
+        """Handler called for each job when the process it was assigned to
+        exits."""
+        if job._write_to:
+            # job was partially written
+            self.on_partial_read(job, job._write_to)
+        elif job._scheduled_for:
+            # job was only scheduled to be written to this process,
+            # but no data was sent so put it back on the outbound_buffer.
+            self._put_back(job)
 
-        def _discard_tref(job):
-            try:
-                tref = trefs.pop(job)
-                tref.cancel()
-                del(tref)
-            except (KeyError, AttributeError):
-                pass  # out of scope
-        self._discard_tref = _discard_tref
+    def on_job_process_lost(self, job, pid, exitcode):
+        """Handler called for each *started* job when the process it
+        was assigned to exited by mysterious means (error exitcodes and
+        signals)"""
+        self.mark_as_worker_lost(job, exitcode)
 
-        def on_timeout_cancel(R):
-            _discard_tref(R._job)
-        self._pool.on_timeout_cancel = on_timeout_cancel
+    def human_write_stats(self):
+        if self.write_stats is None:
+            return 'N/A'
+        vals = list(values(self.write_stats))
+        total = sum(vals)
 
-    def _on_soft_timeout(self, job, soft, hard, hub, now=time):
-        # only used by async pool.
-        if hard:
-            self._tref_for_id[job] = hub.timer.apply_at(
-                now() + (hard - soft),
-                self._on_hard_timeout, (job, ),
-            )
-        try:
-            result = self._pool._cache[job]
-        except KeyError:
-            pass  # job ready
-        else:
-            self.on_soft_timeout(result)
-        finally:
-            if not hard:
-                # remove tref
-                self._discard_tref(job)
+        def per(v, total):
+            return '{0:.2f}%'.format((float(v) / total) * 100.0 if v else 0)
 
-    def _on_hard_timeout(self, job):
-        # only used by async pool.
+        return {
+            'total': total,
+            'avg': per(total / len(self.write_stats) if total else 0, total),
+            'all': ', '.join(per(v, total) for v in vals),
+            'raw': ', '.join(map(str, vals)),
+        }
+
+    def _process_cleanup_queues(self, proc):
+        """Handler called to clean up a processes queues after process
+        exit."""
         try:
-            result = self._pool._cache[job]
-        except KeyError:
-            pass  # job ready
-        else:
-            self.on_hard_timeout(result)
-        finally:
-            # remove tref
-            self._discard_tref(job)
+            self._queues[self._find_worker_queues(proc)] = None
+        except (KeyError, ValueError):
+            pass
 
-    def _create_process_handlers(self, hub, READ=READ, ERR=ERR):
-        """For async pool this will create the handlers called
-        when a process is up/down, etc."""
-        pool = self._pool
-        hub_add, hub_remove = hub.add, hub.remove
-        all_inqueues = self._pool._all_inqueues
-        fileno_to_inq = self._pool._fileno_to_inq
-        fileno_to_outq = self._pool._fileno_to_outq
-        fileno_to_synq = self._pool._fileno_to_synq
-        maintain_pool = self._pool.maintain_pool
-        handle_result_event = self._pool.handle_result_event
-        process_flush_queues = self._pool.process_flush_queues
+    @staticmethod
+    def _stop_task_handler(task_handler):
+        """Called at shutdown to tell processes that we are shutting down."""
+        for proc in task_handler.pool:
+            proc.inq._writer.setblocking(1)
+            try:
+                proc.inq.put(None)
+            except OSError as exc:
+                if get_errno(exc) != errno.EBADF:
+                    raise
 
-        def on_process_up(proc):
-            """Called when a WORKER_UP message is received from process."""
-            # If we got the same fd as a previous process then we will also
-            # receive jobs in the old buffer, so we need to reset the
-            # job._write_to and job._scheduled_for attributes used to recover
-            # message boundaries when processes exit.
-            infd = proc.inqW_fd
-            for job in values(pool._cache):
-                if job._write_to and job._write_to.inqW_fd == infd:
-                    job._write_to = proc
-                if job._scheduled_for and job._scheduled_for.inqW_fd == infd:
-                    job._scheduled_for = proc
-            fileno_to_outq[proc.outqR_fd] = proc
-            # maintain_pool is called whenever a process exits.
-            hub_add(proc.sentinel, maintain_pool, READ | ERR)
-            # handle_result_event is called when the processes outqueue is
-            # readable.
-            hub_add(proc.outqR_fd, handle_result_event, READ | ERR)
-        self._pool.on_process_up = on_process_up
+    def create_result_handler(self):
+        return super(AsynPool, self).create_result_handler(
+            fileno_to_outq=self._fileno_to_outq,
+            on_process_alive=self.on_process_alive,
+        )
 
-        def on_process_down(proc):
-            """Called when a worker process exits."""
-            process_flush_queues(proc)
-            fileno_to_outq.pop(proc.outqR_fd, None)
-            fileno_to_inq.pop(proc.inqW_fd, None)
-            fileno_to_synq.pop(proc.synqW_fd, None)
-            all_inqueues.discard(proc.inqW_fd)
-            hub_remove(proc.sentinel)
-            hub_remove(proc.outqR_fd)
-        self._pool.on_process_down = on_process_down
+    def _process_register_queues(self, proc, queues):
+        """Marks new ownership for ``queues`` so that the fileno indices are
+        updated."""
+        assert queues in self._queues
+        b = len(self._queues)
+        self._queues[queues] = proc
+        assert b == len(self._queues)
 
-    def _create_write_handlers(self, hub,
-                               pack=struct.pack, dumps=_pickle.dumps,
-                               protocol=HIGHEST_PROTOCOL):
-        """For async pool this creates the handlers used to write data to
-        child processes."""
-        pool = self._pool
-        fileno_to_inq = pool._fileno_to_inq
-        fileno_to_synq = pool._fileno_to_synq
-        outbound = self.outbound_buffer
-        pop_message = outbound.popleft
-        put_message = outbound.append
-        all_inqueues = pool._all_inqueues
-        active_writes = self._active_writes
-        diff = all_inqueues.difference
-        hub_add, hub_remove = hub.add, hub.remove
-        mark_write_fd_as_active = active_writes.add
-        mark_write_gen_as_active = self._active_writers.add
-        write_generator_done = self._active_writers.discard
-        get_job = pool._cache.__getitem__
-        # puts back at the front of the queue (so it will be sent next)
-        pool._put_back = outbound.appendleft
-        precalc = {ACK: pool._create_payload(ACK, (0, )),
-                   NACK: pool._create_payload(NACK, (0, ))}
+    def _find_worker_queues(self, proc):
+        """Find the queues owned by ``proc``."""
+        try:
+            return next(q for q, owner in items(self._queues)
+                        if owner == proc)
+        except StopIteration:
+            raise ValueError(proc)
 
-        def on_poll_start(hub):
-            # called for every eventloop iteration, and if there
-            # are messages pending this will schedule writing one message
-            # by registering the 'schedule_writes' function for all currently
-            # inactive inqueues (not already being written to)
+    def _setup_queues(self):
+        # this is only used by the original pool which uses a shared
+        # queue for all processes.
 
-            # consolidate means the eventloop will merge them
-            # and call the callback once with the list of writable fds as
-            # argument.  Using this means we minimize the risk of having
-            # the same fd receive every task if the pipe read buffer is not
-            # full.
-            if outbound:
-                hub_add(
-                    diff(active_writes), None, WRITE | ERR,
-                    consolidate=True,
-                )
-        self.on_poll_start = on_poll_start
+        # these attributes make no sense for us, but we will still
+        # have to initialize them.
+        self._inqueue = self._outqueue = \
+            self._quick_put = self._quick_get = self._poll_result = None
 
-        def on_inqueue_close(fd):
-            # Makes sure the fd is removed from tracking when
-            # the connection is closed, this is essential as fds may be reused.
-            active_writes.discard(fd)
-            all_inqueues.discard(fd)
-        self._pool.on_inqueue_close = on_inqueue_close
+    def process_flush_queues(self, proc):
+        """Flushes all queues, including the outbound buffer, so that
+        all tasks that have not been started will be discarded.
 
-        def schedule_writes(ready_fds, shuffle=random.shuffle):
-            # Schedule write operation to ready file descriptor.
-            # The file descriptor is writeable, but that does not
-            # mean the process is currently reading from the socket.
-            # The socket is buffered so writeable simply means that
-            # the buffer can accept at least 1 byte of data.
-            shuffle(ready_fds)
-            for ready_fd in ready_fds:
-                if ready_fd in active_writes:
-                    # already writing to this fd
-                    continue
-                try:
-                    job = pop_message()
-                except IndexError:
-                    # no more messages, remove all inactive fds from the hub.
-                    # this is important since the fds are always writeable
-                    # as long as there's 1 byte left in the buffer, and so
-                    # this may create a spinloop where the eventloop
-                    # always wakes up.
-                    for inqfd in diff(active_writes):
-                        hub_remove(inqfd)
-                    break
+        In Celery this is called whenever the transport connection is lost
+        (consumer restart).
 
+        """
+        resq = proc.outq._reader
+        on_state_change = self._result_handler.on_state_change
+        while not resq.closed and resq.poll(0) and self._state != TERMINATE:
+            try:
+                task = resq.recv()
+            except (IOError, EOFError) as exc:
+                debug('got %r while flushing process %r',
+                      exc, proc, exc_info=1)
+                break
+            else:
+                if task is not None:
+                    on_state_change(task)
                 else:
-                    if not job._accepted:  # job not accepted by another worker
-                        try:
-                            # keep track of what process the write operation
-                            # was scheduled for.
-                            proc = job._scheduled_for = fileno_to_inq[ready_fd]
-                        except KeyError:
-                            # write was scheduled for this fd but the process
-                            # has since exited and the message must be sent to
-                            # another process.
-                            put_message(job)
-                            continue
-                        cor = _write_job(proc, ready_fd, job)
-                        job._writer = ref(cor)
-                        mark_write_gen_as_active(cor)
-                        mark_write_fd_as_active(ready_fd)
+                    debug('got sentinel while flushing process %r', proc)
 
-                        # Try to write immediately, in case there's an error.
+    def on_partial_read(self, job, proc):
+        """Called when a job was only partially written to a child process
+        and it exited."""
+        # worker terminated by signal:
+        # we cannot reuse the sockets again, because we don't know if
+        # the process wrote/read anything from them, and if so we cannot
+        # restore the message boundaries.
+        if proc.exitcode != EX_RECYCLE:
+            if not job._accepted:
+                # job was not acked, so find another worker to send it to.
+                self._put_back(job)
+            writer = getattr(job, '_writer')
+            writer = writer and writer() or None
+            if writer:
+                self._active_writers.discard(writer)
+
+            # Replace queues to avoid reuse
+            before = len(self._queues)
+            try:
+                queues = self._find_worker_queues(proc)
+                if self.destroy_queues(queues):
+                    self._queues[self.create_process_queues()] = None
+            except ValueError:
+                # Not in queue map, make sure sockets are closed.
+                self.destroy_queues((proc.inq, proc.outq, proc.synq))
+            assert len(self._queues) == before
+
+    def destroy_queues(self, queues):
+        """Destroy queues that can no longer be used, so that they
+        be replaced by new sockets."""
+        removed = 1
+        try:
+            self._queues.pop(queues)
+        except KeyError:
+            removed = 0
+        try:
+            self.on_inqueue_close(queues[0]._writer.fileno())
+        except IOError:
+            pass
+        for queue in queues:
+            if queue:
+                for sock in (queue._reader, queue._writer):
+                    if not sock.closed:
                         try:
-                            next(cor)
-                            hub_add((ready_fd, ), cor, WRITE | ERR)
-                        except StopIteration:
+                            sock.close()
+                        except (IOError, OSError):
                             pass
-        hub.consolidate_callback = schedule_writes
+        return removed
 
-        def send_job(tup):
-            # Schedule writing job request for when one of the process
-            # inqueues is writable.
-            body = dumps(tup, protocol=protocol)
-            body_size = len(body)
-            header = pack('>I', body_size)
-            # index 1,0 is the job ID.
-            job = get_job(tup[1][0])
-            job._payload = header, body, body_size
-            put_message(job)
-        self._pool._quick_put = send_job
+    def _create_payload(self, type_, args,
+                        dumps=_pickle.dumps, pack=struct.pack,
+                        protocol=HIGHEST_PROTOCOL):
+        body = dumps((type_, args), protocol=protocol)
+        size = len(body)
+        header = pack('>I', size)
+        return header, body, size
 
-        write_stats = self.write_stats = Counter()
+    @classmethod
+    def _set_result_sentinel(cls, _outqueue, _pool):
+        # unused
+        pass
 
-        def on_not_recovering(proc):
-            # XXX Theoretically a possibility, but not seen in practice yet.
-            raise Exception(
-                'Process writable but cannot write. Contact support!')
+    def _help_stuff_finish_args(self):
+        # Pool._help_stuff_finish is a classmethod so we have to use this
+        # trick to modify the arguments passed to it.
+        return (self._pool, )
 
-        def _write_job(proc, fd, job):
-            # writes job to the worker process.
-            # Operation must complete if more than one byte of data
-            # was written.  If the broker connection is lost
-            # and no data was written the operation shall be cancelled.
-            header, body, body_size = job._payload
-            errors = 0
+    @classmethod
+    def _help_stuff_finish(cls, pool):
+        debug(
+            'removing tasks from inqueue until task handler finished',
+        )
+        fileno_to_proc = {}
+        inqR = set()
+        for w in pool:
             try:
-                # job result keeps track of what process the job is sent to.
-                job._write_to = proc
-                send = proc.send_job_offset
+                fd = w.inq._reader.fileno()
+                inqR.add(fd)
+                fileno_to_proc[fd] = w
+            except IOError:
+                pass
+        while inqR:
+            readable, _, again = _select(inqR, timeout=0.5)
+            if again:
+                continue
+            if not readable:
+                break
+            for fd in readable:
+                fileno_to_proc[fd].inq._reader.recv()
+            sleep(0)
 
-                Hw = Bw = 0
-                # write header
-                while Hw < 4:
-                    try:
-                        Hw += send(header, Hw)
-                    except Exception as exc:
-                        if get_errno(exc) not in UNAVAIL:
-                            raise
-                        # suspend until more data
-                        errors += 1
-                        if errors > 100:
-                            on_not_recovering(proc)
-                            raise StopIteration()
-                        yield
-                    errors = 0
+    @property
+    def timers(self):
+        return {self.maintain_pool: 5.0}
 
-                # write body
-                while Bw < body_size:
-                    try:
-                        Bw += send(body, Bw)
-                    except Exception as exc:
-                        if get_errno(exc) not in UNAVAIL:
-                            raise
-                        # suspend until more data
-                        errors += 1
-                        if errors > 100:
-                            on_not_recovering(proc)
-                            raise StopIteration()
-                        yield
-                    errors = 0
-            finally:
-                write_stats[proc.index] += 1
-                # message written, so this fd is now available
-                active_writes.discard(fd)
-                write_generator_done(job._writer())  # is a weakref
 
-        def send_ack(response, pid, job, fd, WRITE=WRITE, ERR=ERR):
-            # Only used when synack is enabled.
-            # Schedule writing ack response for when the fd is writeable.
-            msg = Ack(job, fd, precalc[response])
-            callback = promise(write_generator_done)
-            cor = _write_ack(fd, msg, callback=callback)
-            mark_write_gen_as_active(cor)
-            mark_write_fd_as_active(fd)
-            callback.args = (cor, )
-            hub_add((fd, ), cor, WRITE | ERR)
-        self._pool.send_ack = send_ack
+class TaskPool(BasePool):
+    """Multiprocessing Pool implementation."""
+    Pool = AsynPool
+    BlockingPool = _pool.Pool
 
-        def _write_ack(fd, ack, callback=None):
-            # writes ack back to the worker if synack enabled.
-            # this operation *MUST* complete, otherwise
-            # the worker process will hang waiting for the ack.
-            header, body, body_size = ack[2]
+    uses_semaphore = True
+    write_stats = None
+
+    def on_start(self):
+        """Run the task pool.
+
+        Will pre-fork all workers so they're ready to accept tasks.
+
+        """
+        if self.options.get('maxtasksperchild'):
             try:
-                try:
-                    proc = fileno_to_synq[fd]
-                except KeyError:
-                    # process died, we can safely discard the ack at this
-                    # point.
-                    raise StopIteration()
-                send = proc.send_syn_offset
+                from billiard.connection import Connection
+                Connection.send_offset
+            except (ImportError, AttributeError):
+                # billiard C extension not installed
+                warning(MAXTASKS_NO_BILLIARD)
 
-                Hw = Bw = 0
-                # write header
-                while Hw < 4:
-                    try:
-                        Hw += send(header, Hw)
-                    except Exception as exc:
-                        if get_errno(exc) not in UNAVAIL:
-                            raise
-                        yield
+        forking_enable(self.forking_enable)
+        Pool = (self.BlockingPool if self.options.get('threads', True)
+                else self.Pool)
+        P = self._pool = Pool(processes=self.limit,
+                              initializer=process_initializer,
+                              synack=False,
+                              **self.options)
 
-                # write body
-                while Bw < body_size:
-                    try:
-                        Bw += send(body, Bw)
-                    except Exception as exc:
-                        if get_errno(exc) not in UNAVAIL:
-                            raise
-                        # suspend until more data
-                        yield
-            finally:
-                if callback:
-                    callback()
-                # message written, so this fd is now available
-                active_writes.discard(fd)
+        # Create proxy methods
+        self.on_apply = P.apply_async
+        self.maintain_pool = P.maintain_pool
+        self.terminate_job = P.terminate_job
+        self.grow = P.grow
+        self.shrink = P.shrink
+        self.restart = P.restart
+        self.maybe_handle_result = P._result_handler.handle_event
+        self.handle_result_event = P.handle_result_event
+        self.register_with_event_loop = P.register_with_event_loop
 
-    def flush(self):
-        if self._pool._state == TERMINATE:
-            return
-        # cancel all tasks that have not been accepted so that NACK is sent.
-        for job in values(self._pool._cache):
-            if not job._accepted:
-                job._cancel()
+    def did_start_ok(self):
+        return self._pool.did_start_ok()
 
-        # clear the outgoing buffer as the tasks will be redelivered by
-        # the broker anyway.
-        if self.outbound_buffer:
-            self.outbound_buffer.clear()
-        try:
-            # ...but we must continue writing the payloads we already started
-            # to keep message boundaries.
-            # The messages may be NACK'ed later if synack is enabled.
-            if self._pool._state == RUN:
-                # flush outgoing buffers
-                intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True)
-                while self._active_writers:
-                    writers = list(self._active_writers)
-                    for gen in writers:
-                        if (gen.__name__ == '_write_job' and
-                                gen_not_started(gen)):
-                            # has not started writing the job so can
-                            # discard the task, but we must also remove
-                            # it from the Pool._cache.
-                            job_to_discard = None
-                            for job in values(self._pool._cache):
-                                if job._writer() is gen:  # _writer is saferef
-                                    # removes from Pool._cache
-                                    job_to_discard = job
-                                    break
-                            if job_to_discard:
-                                job_to_discard.discard()
-                            self._active_writers.discard(gen)
-                        else:
-                            try:
-                                next(gen)
-                            except StopIteration:
-                                self._active_writers.discard(gen)
-                    # workers may have exited in the meantime.
-                    self.maintain_pool()
-                    sleep(next(intervals))  # don't busyloop
-        finally:
-            self.outbound_buffer.clear()
-            self._active_writers.clear()
+    def on_stop(self):
+        """Gracefully stop the pool."""
+        if self._pool is not None and self._pool._state in (RUN, CLOSE):
+            self._pool.close()
+            self._pool.join()
+            self._pool = None
+
+    def on_terminate(self):
+        """Force terminate the pool."""
+        if self._pool is not None:
+            self._pool.terminate()
+            self._pool = None
+
+    def on_close(self):
+        if self._pool is not None and self._pool._state == RUN:
+            self._pool.close()
+
+    def _get_info(self):
+        return {
+            'max-concurrency': self.limit,
+            'processes': [p.pid for p in self._pool._pool],
+            'max-tasks-per-child': self._pool._maxtasksperchild or 'N/A',
+            'put-guarded-by-semaphore': self.putlocks,
+            'timeouts': (self._pool.soft_timeout or 0,
+                         self._pool.timeout or 0),
+            'writes': self._pool.human_write_stats(),
+        }
 
     @property
     def num_processes(self):
         return self._pool._processes
-
-    @property
-    def timers(self):
-        return {self.maintain_pool: 5.0}

+ 4 - 2
celery/tests/concurrency/test_concurrency.py

@@ -90,8 +90,10 @@ class test_BasePool(AppCase):
     def test_interface_did_start_ok(self):
         self.assertTrue(BasePool(10).did_start_ok())
 
-    def test_interface_on_poll_init(self):
-        self.assertIsNone(BasePool(10).on_poll_init(Mock(), Mock()))
+    def test_interface_register_with_event_loop(self):
+        self.assertIsNone(
+            BasePool(10).register_with_event_loop(Mock(), Mock()),
+        )
 
     def test_interface_on_poll_start(self):
         self.assertIsNone(BasePool(10).on_poll_start(Mock()))

+ 18 - 16
celery/tests/worker/test_autoreload.py

@@ -42,9 +42,11 @@ class test_WorkerComponent(AppCase):
         x.instantiate = Mock()
         r = x.create(w)
         x.instantiate.assert_called_with(w.autoreloader_cls, w)
+        x.register_with_event_loop(w, w.hub)
         self.assertIsNone(r)
-        w.hub.on_init.append.assert_called_with(w.autoreloader.on_poll_init)
-        w.hub.on_close.append.assert_called_with(w.autoreloader.on_poll_close)
+        w.hub.on_close.add.assert_called_with(
+            w.autoreloader.on_event_loop_close,
+        )
 
 
 class test_file_hash(Case):
@@ -120,22 +122,22 @@ class test_KQueueMontior(Case):
         close.side_effect.errno = errno.EBADF
         x.stop()
 
-    def test_on_poll_init(self):
+    def test_register_with_event_loop(self):
         x = KQueueMonitor(['a', 'b'])
         hub = Mock()
         x.add_events = Mock()
-        x.on_poll_init(hub)
+        x.register_with_event_loop(hub)
         x.add_events.assert_called_with(hub.poller)
         self.assertEqual(
             hub.poller.on_file_change,
             x.handle_event,
         )
 
-    def test_on_poll_close(self):
+    def test_on_event_loop_close(self):
         x = KQueueMonitor(['a', 'b'])
         x.close = Mock()
         hub = Mock()
-        x.on_poll_close(hub)
+        x.on_event_loop_close(hub)
         x.close.assert_called_with(hub.poller)
 
     def test_handle_event(self):
@@ -242,7 +244,7 @@ class test_default_implementation(Case):
 
 class test_Autoreloader(AppCase):
 
-    def test_on_poll_init(self):
+    def test_register_with_event_loop(self):
         x = Autoreloader(Mock(), modules=[__name__])
         hub = Mock()
         x._monitor = None
@@ -252,22 +254,22 @@ class test_Autoreloader(AppCase):
             x._monitor = Mock()
         x.on_init.side_effect = se
 
-        x.on_poll_init(hub)
+        x.register_with_event_loop(hub)
         x.on_init.assert_called_with()
-        x._monitor.on_poll_init.assert_called_with(hub)
+        x._monitor.register_with_event_loop.assert_called_with(hub)
 
-        x._monitor.on_poll_init.reset_mock()
-        x.on_poll_init(hub)
-        x._monitor.on_poll_init.assert_called_with(hub)
+        x._monitor.register_with_event_loop.reset_mock()
+        x.register_with_event_loop(hub)
+        x._monitor.register_with_event_loop.assert_called_with(hub)
 
-    def test_on_poll_close(self):
+    def test_on_event_loop_close(self):
         x = Autoreloader(Mock(), modules=[__name__])
         hub = Mock()
         x._monitor = Mock()
-        x.on_poll_close(hub)
-        x._monitor.on_poll_close.assert_called_with(hub)
+        x.on_event_loop_close(hub)
+        x._monitor.on_event_loop_close.assert_called_with(hub)
         x._monitor = None
-        x.on_poll_close(hub)
+        x.on_event_loop_close(hub)
 
     @patch('celery.worker.autoreload.file_hash')
     def test_start(self, fhash):

+ 15 - 12
celery/tests/worker/test_autoscale.py

@@ -2,6 +2,7 @@ from __future__ import absolute_import
 
 import sys
 
+from collections import defaultdict
 from time import time
 
 from mock import Mock, patch
@@ -42,28 +43,30 @@ class MockPool(BasePool):
 
 class test_WorkerComponent(AppCase):
 
-    def test_on_poll_init(self):
-        parent = Mock()
+    def test_register_with_event_loop(self):
+        parent = Mock(name='parent')
         parent.autoscale = True
+        parent.consumer.on_task_message = set()
         w = autoscale.WorkerComponent(parent)
         self.assertIsNone(parent.autoscaler)
         self.assertTrue(w.enabled)
 
-        hub = Mock()
-        hub.on_task = []
-        scaler = Mock()
-        scaler.keepalive = 10
-        w.on_poll_init(scaler, hub)
-        self.assertIn(scaler.maybe_scale, hub.on_task)
-        hub.timer.apply_interval.assert_called_with(
-            10 * 1000.0, scaler.maybe_scale,
+        hub = Mock(name='hub')
+        w.create(parent)
+        w.register_with_event_loop(parent, hub)
+        self.assertIn(
+            parent.autoscaler.maybe_scale,
+            parent.consumer.on_task_message,
+        )
+        hub.call_repeatedly.assert_called_with(
+            parent.autoscaler.keepalive, parent.autoscaler.maybe_scale,
         )
 
         parent.hub = hub
         hub.on_init = []
         w.instantiate = Mock()
-        w.create_ev(parent)
-        self.assertTrue(hub.on_init)
+        w.register_with_event_loop(parent, Mock(name='loop'))
+        self.assertTrue(parent.consumer.on_task_message)
 
 
 class test_Autoscaler(AppCase):

+ 6 - 6
celery/tests/worker/test_consumer.py

@@ -28,7 +28,7 @@ class test_Consumer(AppCase):
 
     def get_consumer(self, no_hub=False, **kwargs):
         consumer = Consumer(
-            on_task=Mock(),
+            on_task_request=Mock(),
             init_callback=Mock(),
             pool=Mock(),
             app=self.app,
@@ -81,7 +81,7 @@ class test_Consumer(AppCase):
             c._limit_task(request, bucket, 3)
             bucket.can_consume.assert_called_with(3)
             reserved.assert_called_with(request)
-            c.on_task.assert_called_with(request)
+            c.on_task_request.assert_called_with(request)
 
         with patch('celery.worker.consumer.task_reserved') as reserved:
             bucket.can_consume.return_value = False
@@ -128,15 +128,15 @@ class test_Consumer(AppCase):
         c.start()
         c.connection.collect.assert_called_with()
 
-    def test_on_poll_init(self):
+    def test_register_with_event_loop(self):
         c = self.get_consumer()
         c.connection = Mock()
         c.connection.eventmap = {1: 2}
         hub = Mock()
-        c.on_poll_init(hub)
+        c.register_with_event_loop(hub)
 
-        hub.update_readers.assert_called_with({1: 2})
-        c.connection.transport.on_poll_init.assert_called_with(hub.poller)
+        hub.add_reader.assert_called_with(1, 2)
+        c.connection.transport.register_with_event_loop.assert_called_with(hub)
 
     def test_on_close_clears_semaphore_timer_and_reqs(self):
         with patch('celery.worker.consumer.reserved_requests') as reserved:

+ 24 - 34
celery/tests/worker/test_hub.py

@@ -1,11 +1,7 @@
 from __future__ import absolute_import
 
-from kombu.async import (
-    Hub,
-    repr_flag,
-    _rcb,
-    READ, WRITE, ERR
-)
+from kombu.async import Hub, READ, WRITE, ERR
+from kombu.async.hub import repr_flag, _rcb
 from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 
 from mock import Mock, call, patch
@@ -143,21 +139,11 @@ class test_Hub(Case):
     @patch('kombu.async.hub.poll')
     def test_start_stop(self, poll):
         hub = Hub()
-        hub.start()
         poll.assert_called_with()
 
+        poller = hub.poller
         hub.stop()
-        hub.poller.close.assert_called_with()
-
-    def test_init(self):
-        hub = Hub()
-        cb1 = Mock()
-        cb2 = Mock()
-        hub.on_init.extend([cb1, cb2])
-
-        hub.init()
-        cb1.assert_called_with(hub)
-        cb2.assert_called_with(hub)
+        poller.close.assert_called_with()
 
     def test_fire_timers(self):
         hub = Hub()
@@ -207,17 +193,17 @@ class test_Hub(Case):
 
         eback.side_effect = ValueError('foo')
         hub.scheduler = iter([(0, eback)])
-        with patch('celery.worker.hub.logger') as logger:
+        with patch('kombu.async.hub.logger') as logger:
             with self.assertRaises(StopIteration):
                 hub.fire_timers()
             self.assertTrue(logger.error.called)
 
     def test_add_raises_ValueError(self):
         hub = Hub()
-        hub._add = Mock()
-        hub._add.side_effect = ValueError()
-        hub._discard = Mock()
-        hub.add([2], Mock(), READ)
+        hub.poller = Mock(name='hub.poller')
+        hub.poller.register.side_effect = ValueError()
+        hub._discard = Mock(name='hub.discard')
+        hub.add(2, Mock(), READ)
         hub._discard.assert_called_with(2)
 
     def test_repr_active(self):
@@ -254,21 +240,22 @@ class test_Hub(Case):
             hub._callback_for(6, WRITE)
         self.assertEqual(hub._callback_for(6, WRITE, 'foo'), 'foo')
 
-    def test_update_readers(self):
+    def test_add_remove_readers(self):
         hub = Hub()
         P = hub.poller = Mock()
 
         read_A = Mock()
         read_B = Mock()
-        hub.update_readers({10: read_A, File(11): read_B})
+        hub.add_reader(10, read_A)
+        hub.add_reader(File(11), read_B)
 
         P.register.assert_has_calls([
             call(10, hub.READ | hub.ERR),
             call(File(11), hub.READ | hub.ERR),
         ], any_order=True)
 
-        self.assertIs(hub.readers[10], read_A)
-        self.assertIs(hub.readers[11], read_B)
+        self.assertEqual(hub.readers[10], (read_A, ()))
+        self.assertEqual(hub.readers[11], (read_B, ()))
 
         hub.remove(10)
         self.assertNotIn(10, hub.readers)
@@ -291,21 +278,22 @@ class test_Hub(Case):
 
         hub.remove(313)
 
-    def test_update_writers(self):
+    def test_add_writers(self):
         hub = Hub()
         P = hub.poller = Mock()
 
         write_A = Mock()
         write_B = Mock()
-        hub.update_writers({20: write_A, File(21): write_B})
+        hub.add_writer(20, write_A)
+        hub.add_writer(File(21), write_B)
 
         P.register.assert_has_calls([
             call(20, hub.WRITE),
             call(File(21), hub.WRITE),
         ], any_order=True)
 
-        self.assertIs(hub.writers[20], write_A)
-        self.assertIs(hub.writers[21], write_B)
+        self.assertEqual(hub.writers[20], (write_A, ()))
+        self.assertEqual(hub.writers[21], (write_B, ()))
 
         hub.remove(20)
         self.assertNotIn(20, hub.writers)
@@ -321,7 +309,7 @@ class test_Hub(Case):
         hub.init = Mock()
 
         on_close = Mock()
-        hub.on_close.append(on_close)
+        hub.on_close.add(on_close)
 
         hub.init()
         try:
@@ -329,10 +317,12 @@ class test_Hub(Case):
 
             read_A = Mock()
             read_B = Mock()
-            hub.update_readers({10: read_A, File(11): read_B})
+            hub.add_reader(10, read_A)
+            hub.add_reader(File(11), read_B)
             write_A = Mock()
             write_B = Mock()
-            hub.update_writers({20: write_A, File(21): write_B})
+            hub.add_writer(20, write_A)
+            hub.add_writer(File(21), write_B)
             self.assertTrue(hub.readers)
             self.assertTrue(hub.writers)
         finally:

+ 3 - 5
celery/tests/worker/test_loops.py

@@ -2,6 +2,7 @@ from __future__ import absolute_import
 
 import socket
 
+from collections import defaultdict
 from mock import Mock
 
 from celery.exceptions import InvalidTaskError, SystemTerminate
@@ -40,7 +41,6 @@ class X(object):
         #hent = self.Hub.__enter__ = Mock(name='Hub.__enter__')
         #self.Hub.__exit__ = Mock(name='Hub.__exit__')
         #self.hub = hent.return_value = Mock(name='hub_context')
-        self.hub.on_task = on_task or []
         self.hub.readers = {}
         self.hub.writers = {}
         self.hub.consolidate = set()
@@ -111,8 +111,8 @@ class test_asynloop(AppCase):
         asynloop(*x.args)
         x.consumer.consume.assert_called_with()
         x.obj.on_ready.assert_called_with()
-        x.hub.timer.apply_interval.assert_called_with(
-            10 * 1000.0 / 2.0, x.connection.heartbeat_check, (2.0, ),
+        x.hub.call_repeatedly.assert_called_with(
+            10 / 2.0, x.connection.heartbeat_check, (2.0, ),
         )
 
     def task_context(self, sig, **kwargs):
@@ -196,7 +196,6 @@ class test_asynloop(AppCase):
         asynloop(*x.args, sleep=x.closer())
         x.qos.update.assert_called_with()
         x.hub.fire_timers.assert_called_with(propagate=(socket.error, ))
-        x.connection.transport.on_poll_start.assert_called_with()
 
     def test_poll_empty(self):
         x = X(self.app)
@@ -207,7 +206,6 @@ class test_asynloop(AppCase):
         with self.assertRaises(socket.error):
             asynloop(*x.args)
         x.hub.poller.poll.assert_called_with(33.37)
-        x.connection.transport.on_poll_empty.assert_called_with()
 
     def test_poll_readable(self):
         x = X(self.app)

+ 1 - 1
celery/tests/worker/test_strategy.py

@@ -91,7 +91,7 @@ class test_default_strategy(AppCase):
             C()
             self.assertTrue(C.was_reserved())
             req = C.get_request()
-            C.consumer.on_task.assert_called_with(req)
+            C.consumer.on_task_request.assert_called_with(req)
             self.assertTrue(C.event_sent())
 
     def test_when_events_disabled(self):

+ 5 - 8
celery/tests/worker/test_worker.py

@@ -1025,7 +1025,6 @@ class test_WorkController(AppCase):
         x = components.Hub(w)
         hub = x.create(w)
         self.assertTrue(w.timer.max_interval)
-        self.assertIs(w.hub, hub)
 
     def test_Pool_crate_threaded(self):
         w = Mock()
@@ -1040,7 +1039,6 @@ class test_WorkController(AppCase):
         w = Mock()
         w._conninfo.connection_errors = w._conninfo.channel_errors = ()
         w.hub = Mock()
-        w.hub.on_init = []
 
         PoolImp = Mock()
         poolimp = PoolImp.return_value = Mock()
@@ -1063,19 +1061,18 @@ class test_WorkController(AppCase):
         w.consumer.restart_count = -1
         pool = components.Pool(w)
         pool.create(w)
+        pool.register_with_event_loop(w, w.hub)
         self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
-        self.assertTrue(w.hub.on_init)
         P = w.pool
         P.start()
 
         hub = Mock()
-        w.hub.on_init[0](hub)
 
         w = Mock()
         poolimp.on_process_up(w)
-        hub.add.assert_has_calls([
-            call(w.sentinel, P.maintain_pool, READ | ERR),
-            call(w.outqR_fd, P.handle_result_event, READ | ERR),
+        hub.add_reader.assert_has_calls([
+            call(w.sentinel, P.maintain_pool),
+            call(w.outqR_fd, P.handle_result_event),
         ])
 
         poolimp.on_process_down(w)
@@ -1093,4 +1090,4 @@ class test_WorkController(AppCase):
             P._pool.did_start_ok = Mock()
             P._pool.did_start_ok.return_value = False
             w.consumer.restart_count = 0
-            P.on_poll_init(w, hub)
+            P.register_with_event_loop(w, hub)

+ 3 - 0
celery/worker/__init__.py

@@ -199,6 +199,9 @@ class WorkController(object):
         except (KeyboardInterrupt, SystemExit):
             self.stop()
 
+    def register_with_event_loop(self, hub):
+        self.blueprint.send_all(self, 'register_with_event_loop', args=(hub, ))
+
     def _process_task_sem(self, req):
         return self._quick_acquire(self._process_task, req)
 

+ 12 - 17
celery/worker/autoreload.py

@@ -52,19 +52,14 @@ class WorkerComponent(bootsteps.StartStopStep):
         self.enabled = w.autoreload = autoreload
         w.autoreloader = None
 
-    def create_ev(self, w):
-        ar = w.autoreloader = self.instantiate(w.autoreloader_cls, w)
-        w.hub.on_init.append(ar.on_poll_init)
-        w.hub.on_close.append(ar.on_poll_close)
-
-    def create_threaded(self, w):
+    def create(self, w):
         w.autoreloader = self.instantiate(w.autoreloader_cls, w)
-        return w.autoreloader
+        return w.autoreloader if not w.use_eventloop else None
 
-    def create(self, w):
-        if hasattr(select, 'kqueue') and w.use_eventloop:
-            return self.create_ev(w)
-        return self.create_threaded(w)
+    def register_with_event_loop(self, w, hub):
+        if hasattr(select, 'kqueue'):
+            w.autoreloader.register_with_event_loop(hub)
+            hub.on_close.add(w.autoreloader.on_event_loop_close)
 
 
 def file_hash(filename, algorithm='md5'):
@@ -130,11 +125,11 @@ class KQueueMonitor(BaseMonitor):
         self.filemap = dict((f, None) for f in self.files)
         self.fdmap = {}
 
-    def on_poll_init(self, hub):
+    def register_with_event_loop(self, hub):
         self.add_events(hub.poller)
         hub.poller.on_file_change = self.handle_event
 
-    def on_poll_close(self, hub):
+    def on_event_loop_close(self, hub):
         self.close(hub.poller)
 
     def add_events(self, poller):
@@ -244,14 +239,14 @@ class Autoreloader(bgThread):
             shutdown_event=self._is_shutdown, **self.options)
         self._hashes = dict([(f, file_hash(f)) for f in files])
 
-    def on_poll_init(self, hub):
+    def register_with_event_loop(self, hub):
         if self._monitor is None:
             self.on_init()
-        self._monitor.on_poll_init(hub)
+        self._monitor.register_with_event_loop(hub)
 
-    def on_poll_close(self, hub):
+    def on_event_loop_close(self, hub):
         if self._monitor is not None:
-            self._monitor.on_poll_close(hub)
+            self._monitor.on_event_loop_close(hub)
 
     def body(self):
         self.on_init()

+ 8 - 16
celery/worker/autoscale.py

@@ -45,28 +45,20 @@ class WorkerComponent(bootsteps.StartStopStep):
         self.enabled = w.autoscale
         w.autoscaler = None
 
-    def create_threaded(self, w):
+    def create(self, w):
         scaler = w.autoscaler = self.instantiate(
             w.autoscaler_cls,
             w.pool, w.max_concurrency, w.min_concurrency,
+            mutex=DummyLock() if w.use_eventloop else None,
         )
-        return scaler
-
-    def on_poll_init(self, scaler, hub):
-        hub.on_task.append(scaler.maybe_scale)
-        hub.timer.apply_interval(scaler.keepalive * 1000.0, scaler.maybe_scale)
+        return scaler if not w.use_eventloop else None
 
-    def create_ev(self, w):
-        scaler = w.autoscaler = self.instantiate(
-            w.autoscaler_cls,
-            w.pool, w.max_concurrency, w.min_concurrency,
-            mutex=DummyLock(),
+    def register_with_event_loop(self, w, hub):
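+        # Run maybe_scale for every incoming task message and also
+        # periodically, on the autoscaler keepalive interval, via the hub.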
+        w.consumer.on_task_message.add(w.autoscaler.maybe_scale)
+        hub.call_repeatedly(
+            w.autoscaler.keepalive, w.autoscaler.maybe_scale,
         )
-        w.hub.on_init.append(partial(self.on_poll_init, scaler))
-
-    def create(self, w):
-        return (self.create_ev if w.use_eventloop
-                else self.create_threaded)(w)
 
 
 class Autoscaler(bgThread):

+ 10 - 5
celery/worker/components.py

@@ -10,8 +10,6 @@ from __future__ import absolute_import
 
 import atexit
 
-from functools import partial
-
 from kombu.async import Hub as _Hub, get_event_loop, set_event_loop
 from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
 
@@ -67,7 +65,13 @@ class Hub(bootsteps.StartStopStep):
         if w.hub is None:
             w.hub = set_event_loop(_Hub(w.timer))
         self._patch_thread_primitives(w)
-        return w.hub
+        return self
+
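+    # Returning the step itself (rather than the hub) lets the blueprint
+    # call start/stop on it; stop() closes the hub at worker shutdown.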
+    def start(self, w):
+        pass
+
+    def stop(self, w):
+        w.hub.close()
 
     def _patch_thread_primitives(self, w):
         # make clock use dummy lock
@@ -156,13 +160,14 @@ class Pool(bootsteps.StartStopStep):
             forking_enable=forking_enable,
             semaphore=semaphore,
         )
-        if w.hub:
-            w.hub.on_init.append(partial(pool.on_poll_init, w))
         return pool
 
     def info(self, w):
         return {'pool': w.pool.info}
 
+    def register_with_event_loop(self, w, hub):
+        w.pool.register_with_event_loop(hub)
+
 
 class Beat(bootsteps.StartStopStep):
     """Step used to embed a beat process.

+ 10 - 10
celery/worker/consumer.py

@@ -161,7 +161,7 @@ class Consumer(object):
         def shutdown(self, parent):
             self.send_all(parent, 'shutdown')
 
-    def __init__(self, on_task,
+    def __init__(self, on_task_request,
                  init_callback=noop, hostname=None,
                  pool=None, app=None,
                  timer=None, controller=None, hub=None, amqheartbeat=None,
@@ -180,7 +180,8 @@ class Consumer(object):
         self._restart_state = restart_state(maxR=5, maxT=1)
 
         self._does_info = logger.isEnabledFor(logging.INFO)
-        self.on_task = on_task
+        self.on_task_request = on_task_request
+        self.on_task_message = set()
         self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
         self.disable_rate_limits = disable_rate_limits
 
@@ -194,7 +195,6 @@ class Consumer(object):
             if self.amqheartbeat is None:
                 self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
             self.hub = hub
-            self.hub.on_init.append(self.on_poll_init)
         else:
             self.hub = None
             self.amqheartbeat = 0
@@ -231,7 +231,7 @@ class Consumer(object):
             )
         else:
             task_reserved(request)
-            self.on_task(request)
+            self.on_task_request(request)
 
     def start(self):
         blueprint, loop = self.blueprint, self.loop
@@ -258,6 +258,9 @@ class Consumer(object):
                     self.on_close()
                     blueprint.restart(self)
 
+    def register_with_event_loop(self, hub):
+        self.blueprint.send_all(self, 'register_with_event_loop', args=(hub, ))
+
     def shutdown(self):
         self.in_shutdown = True
         self.blueprint.shutdown(self)
@@ -275,10 +278,6 @@ class Consumer(object):
                 self.blueprint, self.hub, self.qos, self.amqheartbeat,
                 self.app.clock, self.amqheartbeat_rate)
 
-    def on_poll_init(self, hub):
-        hub.update_readers(self.connection.eventmap)
-        self.connection.transport.on_poll_init(hub.poller)
-
     def on_decode_error(self, message, exc):
         """Callback called if an error occurs while decoding
         a message received.
@@ -366,7 +365,7 @@ class Consumer(object):
         """Method called by the timer to apply a task with an
         ETA/countdown."""
         task_reserved(task)
-        self.on_task(task)
+        self.on_task_request(task)
         self.qos.decrement_eventually()
 
     def _message_report(self, body, message):
@@ -394,11 +393,12 @@ class Consumer(object):
             task.__trace__ = build_tracer(name, task, loader, self.hostname,
                                           app=self.app)
 
-    def create_task_handler(self, callbacks):
+    def create_task_handler(self):
         strategies = self.strategies
         on_unknown_message = self.on_unknown_message
         on_unknown_task = self.on_unknown_task
         on_invalid_task = self.on_invalid_task
+        callbacks = self.on_task_message
 
         def on_task_received(body, message):
             if callbacks:

+ 17 - 92
celery/worker/loops.py

@@ -10,12 +10,9 @@ from __future__ import absolute_import
 import socket
 
 from time import sleep
-from types import GeneratorType as generator
-
-from kombu.utils.eventio import READ, WRITE, ERR
 
 from celery.bootsteps import CLOSE
-from celery.exceptions import SystemTerminate
+from celery.exceptions import SystemTerminate, WorkerLostError
 from celery.five import Empty
 from celery.utils.log import get_logger
 
@@ -30,39 +27,32 @@ error = logger.error
 def asynloop(obj, connection, consumer, blueprint, hub, qos,
              heartbeat, clock, hbrate=2.0,
              sleep=sleep, min=min, Empty=Empty):
-    """Non-blocking eventloop consuming messages until connection is lost,
+    """Non-blocking event loop consuming messages until connection is lost,
     or shutdown is requested."""
 
-    hub.init()
     update_qos = qos.update
-    update_readers = hub.update_readers
     readers, writers = hub.readers, hub.writers
-    poll = hub.poller.poll
-    fire_timers = hub.fire_timers
-    hub_add = hub.add
-    hub_remove = hub.remove
-    scheduled = hub.timer._queue
     hbtick = connection.heartbeat_check
-    conn_poll_start = connection.transport.on_poll_start
-    conn_poll_empty = connection.transport.on_poll_empty
-    pool_poll_start = obj.pool.on_poll_start
-    drain_nowait = connection.drain_nowait
-    on_task_callbacks = hub.on_task
-    keep_draining = connection.transport.nb_keep_draining
     errors = connection.connection_errors
     hub_add, hub_remove = hub.add, hub.remove
-    consolidate = hub.consolidate
-    consolidate_callback = hub.consolidate_callback
 
-    on_task_received = obj.create_task_handler(on_task_callbacks)
+    on_task_received = obj.create_task_handler()
 
     if heartbeat and connection.supports_heartbeats:
-        hub.timer.apply_interval(
-            heartbeat * 1000.0 / hbrate, hbtick, (hbrate, ))
+        hub.call_repeatedly(heartbeat / hbrate, hbtick, (hbrate, ))
 
     consumer.callbacks = [on_task_received]
     consumer.consume()
     obj.on_ready()
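+    # Hand fds and timers over to the hub: the worker controller and the
+    # consumer blueprint each register their components with the event loop.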
+    obj.controller.register_with_event_loop(hub)
+    obj.register_with_event_loop(hub)
+
+    # did_start_ok will verify that pool processes were able to start,
+    # but this will only work the first time we start, as
+    # maxtasksperchild will mess up metrics.
+    if not obj.restart_count and not obj.pool.did_start_ok():
+        raise WorkerLostError('Could not start worker processes')
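+    # hub._loop returns the event loop as an iterator; each next(loop) call
+    # in the while-loop below performs one poll/dispatch iteration.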
+    loop = hub._loop(propagate=errors)
 
     try:
         while blueprint.state != CLOSE and obj.connection:
@@ -72,91 +62,26 @@ def asynloop(obj, connection, consumer, blueprint, hub, qos,
             elif state.should_terminate:
                 raise SystemTerminate()
 
-            # fire any ready timers, this also returns
-            # the number of seconds until we need to fire timers again.
-            poll_timeout = fire_timers(propagate=errors) if scheduled else 1
-
             # We only update QoS when there are no more messages to read.
             # This groups together qos calls, and makes sure that remote
             # control commands will be prioritized over task messages.
             if qos.prev != qos.value:
                 update_qos()
-
-            #print('[[[HUB]]]: %s' % (hub.repr_active(), ))
-
-            update_readers(conn_poll_start())
-            pool_poll_start(hub)
-            if readers or writers:
-                to_consolidate = []
-                connection.more_to_read = True
-                while connection.more_to_read:
-                    try:
-                        events = poll(poll_timeout)
-                        #print('[[[EV]]]: %s' % (hub.repr_events(events), ))
-                    except ValueError:  # Issue 882
-                        return
-
-                    if not events:
-                        conn_poll_empty()
-                    for fileno, event in events or ():
-                        if fileno in consolidate and \
-                                writers.get(fileno) is None:
-                            to_consolidate.append(fileno)
-                            continue
-                        cb = None
-                        try:
-                            if event & READ:
-                                cb = readers[fileno]
-                            elif event & WRITE:
-                                cb = writers[fileno]
-                            elif event & ERR:
-                                cb = (readers.get(fileno) or
-                                      writers.get(fileno))
-                        except (KeyError, Empty):
-                            continue
-                        if cb is None:
-                            continue
-                        try:
-                            if isinstance(cb, generator):
-                                try:
-                                    next(cb)
-                                except StopIteration:
-                                    hub_remove(fileno)
-                                except Exception:
-                                    hub_remove(fileno)
-                                    raise
-                            else:
-                                try:
-                                    cb(fileno, event)
-                                except Empty:
-                                    continue
-                        except socket.error:
-                            if blueprint.state != CLOSE:  # pragma: no cover
-                                raise
-                    if to_consolidate:
-                        consolidate_callback(to_consolidate)
-                    if keep_draining:
-                        drain_nowait()
-                        poll_timeout = 0
-                    else:
-                        connection.more_to_read = False
-            else:
-                # no sockets yet, startup is probably not done.
-                sleep(min(poll_timeout, 0.1))
+            next(loop)
     finally:
         try:
             hub.close()
         except Exception as exc:
             error(
-                'Error cleaning up after eventloop: %r', exc, exc_info=1,
+                'Error cleaning up after event loop: %r', exc, exc_info=1,
             )
 
 
 def synloop(obj, connection, consumer, blueprint, hub, qos,
             heartbeat, clock, hbrate=2.0, **kwargs):
-    """Fallback blocking eventloop for transports that doesn't support AIO."""
+    """Fallback blocking event loop for transports that doesn't support AIO."""
 
-    on_task_received = obj.create_task_handler([])
+    on_task_received = obj.create_task_handler()
     consumer.register_callback(on_task_received)
     consumer.consume()
 

+ 1 - 1
celery/worker/strategy.py

@@ -38,7 +38,7 @@ def default(task, app, consumer,
     apply_eta_task = consumer.apply_eta_task
     rate_limits_enabled = not consumer.disable_rate_limits
     bucket = consumer.task_buckets[task.name]
-    handle = consumer.on_task
+    handle = consumer.on_task_request
     limit_task = consumer._limit_task
 
     def task_message_handler(message, body, ack, to_timestamp=to_timestamp):