
Asynpool: Fixes race condition when child process terminated fast

Ask Solem · 11 years ago
commit f2e2bac7c8
1 changed file with 110 additions and 33 deletions

celery/concurrency/asynpool.py (+110, -33)

@@ -15,6 +15,7 @@
 from __future__ import absolute_import

 import errno
+import os
 import random
 import select
 import socket
@@ -42,9 +43,9 @@ from celery.five import Counter, items, values
 from celery.utils.log import get_logger

 logger = get_logger(__name__)
-debug = logger.debug
+error, debug = logger.error, logger.debug

-UNAVAIL = frozenset([errno.EAGAIN, errno.EINTR, errno.EBADF])
+UNAVAIL = frozenset([errno.EAGAIN, errno.EINTR])

 #: Constant sent by child process when started (ready to accept work)
 WORKER_UP = 15
@@ -243,6 +244,13 @@ class AsynPool(_pool.Pool):
         # synqueue fileno -> process mapping
         self._fileno_to_synq = {}

+        # We keep track of processes that have not yet
+        # sent a WORKER_UP message.  If a process fails to send
+        # this message within proc_up_timeout we terminate it
+        # and hope the next process will recover.
+        self._proc_alive_timeout = 2.0
+        self._waiting_to_start = set()
+
         # denormalized set of all inqueues.
         self._all_inqueues = set()

@@ -264,20 +272,24 @@ class AsynPool(_pool.Pool):
         for proc in self._pool:
             # create initial mappings, these will be updated
             # as processes are recycled, or found lost elsewhere.
-            self._fileno_to_inq[proc.inqW_fd] = proc
             self._fileno_to_outq[proc.outqR_fd] = proc
             self._fileno_to_synq[proc.synqW_fd] = proc
         self.on_soft_timeout = self._timeout_handler.on_soft_timeout
         self.on_hard_timeout = self._timeout_handler.on_hard_timeout

+    def _event_process_exit(self, hub, fd):
+        # This method is called whenever the process sentinel is readable.
+        hub.remove(fd)
+        self.maintain_pool()
+
     def register_with_event_loop(self, hub):
         """Registers the async pool with the current event loop."""
         self._create_timelimit_handlers(hub)
         self._create_process_handlers(hub)
         self._create_write_handlers(hub)

-        # Maintain_pool is called whenever a process exits.
-        [hub.add_reader(fd, self.maintain_pool)
+        # Add handler for when a process exits (calls maintain_pool)
+        [hub.add_reader(fd, self._event_process_exit, hub, fd)
          for fd in self.process_sentinels]
         # Handle_result_event is called whenever one of the
         # result queues are readable.
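The hunk above swaps the bare maintain_pool reader for _event_process_exit, which also unregisters the sentinel fd from the hub before reaping. Outside of Celery the same sentinel mechanism looks roughly like the sketch below: a plain select() loop stands in for the kombu hub, a Unix platform and multiprocessing's .sentinel attribute are assumed (the diff's own proc.sentinel is billiard's equivalent), and every name in the sketch is illustrative rather than part of the diff.

# Sketch: waking up when a child exits by watching its sentinel fd.
# select() stands in for the kombu event hub; names are illustrative.
import select
from multiprocessing import Process
from time import sleep

def child():
    sleep(0.5)      # pretend to work for a moment, then exit

if __name__ == '__main__':
    proc = Process(target=child)
    proc.start()

    readers = {proc.sentinel: proc}     # fd -> process, like process_sentinels
    while readers:
        ready, _, _ = select.select(list(readers), [], [])
        for fd in ready:
            # The sentinel becomes readable when the child terminates, so
            # stop watching the fd first (hub.remove in the diff) and then
            # reap the process (maintain_pool's job in the real pool).
            p = readers.pop(fd)
            p.join()
            print('child %s exited with %s' % (p.pid, p.exitcode))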
@@ -363,12 +375,21 @@ class AsynPool(_pool.Pool):
         fileno_to_outq = self._fileno_to_outq
         fileno_to_synq = self._fileno_to_synq
         busy_workers = self._busy_workers
-        maintain_pool = self.maintain_pool
+        event_process_exit = self._event_process_exit
         handle_result_event = self.handle_result_event
         process_flush_queues = self.process_flush_queues
+        waiting_to_start = self._waiting_to_start
+
+        def verify_process_alive(proc):
+            if proc.exitcode is None and proc in waiting_to_start:
+                assert proc.outqR_fd in fileno_to_outq
+                assert fileno_to_outq[proc.outqR_fd] is proc
+                assert proc.outqR_fd in hub.readers
+                error('Timed out waiting for UP message from %r', proc)
+                os.kill(proc.pid, 9)

         def on_process_up(proc):
-            """Called when a WORKER_UP message is received from process."""
+            """Called when a process has started."""
             # If we got the same fd as a previous process then we will also
             # receive jobs in the old buffer, so we need to reset the
             # job._write_to and job._scheduled_for attributes used to recover
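verify_process_alive together with _waiting_to_start is essentially a start-up watchdog: a process that has neither exited nor sent WORKER_UP within _proc_alive_timeout is forcibly killed so pool maintenance can replace it. The standalone sketch below shows the same pattern under stated assumptions: threading.Timer stands in for hub.call_later, SIGKILL matches the os.kill(proc.pid, 9) above, and the function names are illustrative only.

# Sketch of the start-up watchdog: kill workers that never report in.
import os
import signal
import threading

PROC_ALIVE_TIMEOUT = 2.0      # mirrors self._proc_alive_timeout in the diff
waiting_to_start = set()      # processes that have not sent WORKER_UP yet

def on_process_up(proc):
    # Called as soon as the process is forked, before WORKER_UP arrives.
    waiting_to_start.add(proc)
    timer = threading.Timer(
        PROC_ALIVE_TIMEOUT, verify_process_alive, args=(proc,),
    )
    timer.daemon = True
    timer.start()

def on_worker_up(proc):
    # Called when the WORKER_UP message is read from the process outqueue.
    waiting_to_start.discard(proc)

def verify_process_alive(proc):
    # Still running but never said hello: kill it and let pool
    # maintenance start a replacement worker.
    if proc.exitcode is None and proc in waiting_to_start:
        os.kill(proc.pid, signal.SIGKILL)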
@@ -381,22 +402,53 @@ class AsynPool(_pool.Pool):
                     job._scheduled_for = proc
             fileno_to_outq[proc.outqR_fd] = proc
             # maintain_pool is called whenever a process exits.
-            add_reader(proc.sentinel, maintain_pool)
+            add_reader(
+                proc.sentinel, event_process_exit, hub, proc.sentinel,
+            )
             # handle_result_event is called when the processes outqueue is
             # readable.
             add_reader(proc.outqR_fd, handle_result_event, proc.outqR_fd)
+
+            waiting_to_start.add(proc)
+            hub.call_later(
+                self._proc_alive_timeout, verify_process_alive, proc,
+            )
+
         self.on_process_up = on_process_up

+        def _remove_from_index(obj, proc, index, callback=None):
+            # this removes the file descriptors for a process from
+            # the indices.  We have to make sure we don't overwrite
+            # another process's fds, as the fds may be reused.
+            try:
+                fd = obj.fileno()
+            except (IOError, OSError):
+                return
+
+            try:
+                if index[fd] is proc:
+                    # fd has not been reused so we can remove it from index.
+                    index.pop(fd, None)
+            except KeyError:
+                pass
+            else:
+                hub_remove(fd)
+                if callback is not None:
+                    callback(fd)
+            return fd
+
         def on_process_down(proc):
             """Called when a worker process exits."""
+            assert not proc.dead
             process_flush_queues(proc)
-            fileno_to_outq.pop(proc.outqR_fd, None)
-            fileno_to_inq.pop(proc.inqW_fd, None)
-            fileno_to_synq.pop(proc.synqW_fd, None)
-            all_inqueues.discard(proc.inqW_fd)
-            busy_workers.discard(proc.inqW_fd)
+            _remove_from_index(proc.outq._reader, proc, fileno_to_outq)
+            if proc.synq:
+                _remove_from_index(proc.synq._writer, proc, fileno_to_synq)
+            inq = _remove_from_index(proc.inq._writer, proc, fileno_to_inq,
+                                     callback=all_inqueues.discard)
+            if inq:
+                busy_workers.discard(inq)
             hub_remove(proc.sentinel)
-            hub_remove(proc.outqR_fd)
         self.on_process_down = on_process_down

     def _create_write_handlers(self, hub,
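_remove_from_index and the reworked on_process_down above exist because file descriptor numbers are recycled: by the time a dead worker is cleaned up, the kernel may already have handed the same fd to a newly forked process, so a mapping is only dropped when it still points at the process being removed. A condensed sketch of that guard, using a plain dict and a throwaway class rather than Celery's types:

# Sketch: only forget an fd -> process mapping if it still refers to the
# process being cleaned up (the fd number may already belong to a new worker).

def remove_from_index(fd, proc, index, callback=None):
    if index.get(fd) is proc:
        index.pop(fd, None)           # fd not reused yet: safe to drop
        if callback is not None:
            callback(fd)
        return fd
    return None                       # fd reassigned: leave the new mapping alone

class Proc(object):
    pass

fileno_to_inq = {}
old, new = Proc(), Proc()
fileno_to_inq[7] = old
fileno_to_inq[7] = new                # fd 7 recycled for the replacement worker

remove_from_index(7, old, fileno_to_inq)    # no-op, mapping preserved
assert fileno_to_inq[7] is new
remove_from_index(7, new, fileno_to_inq)    # actually removes the entry
assert 7 not in fileno_to_inq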
@@ -437,16 +489,26 @@ class AsynPool(_pool.Pool):
             # argument.  Using this means we minimize the risk of having
             # the same fd receive every task if the pipe read buffer is not
             # full.
+
             if outbound:
                 [hub_add(fd, None, WRITE | ERR, consolidate=True)
                  for fd in diff(active_writes)]
+            else:
+                [hub_remove(fd) for fd in diff(active_writes)]
         self.on_poll_start = on_poll_start

-        def on_inqueue_close(fd):
+        def on_inqueue_close(fd, proc):
             # Makes sure the fd is removed from tracking when
             # the connection is closed, this is essential as fds may be reused.
-            active_writes.discard(fd)
-            all_inqueues.discard(fd)
+            busy_workers.discard(fd)
+            try:
+                if fileno_to_inq[fd] is proc:
+                    fileno_to_inq.pop(fd, None)
+                    active_writes.discard(fd)
+                    all_inqueues.discard(fd)
+                    hub_remove(fd)
+            except KeyError:
+                pass
         self.on_inqueue_close = on_inqueue_close

         def schedule_writes(ready_fds, shuffle=random.shuffle):
@@ -463,6 +525,9 @@ class AsynPool(_pool.Pool):
                 if is_fair_strategy and ready_fd in busy_workers:
                     # worker is already busy with another task
                     continue
+                if ready_fd not in all_inqueues:
+                    hub_remove(ready_fd)
+                    continue
                 try:
                     job = pop_message()
                 except IndexError:
@@ -496,9 +561,13 @@ class AsynPool(_pool.Pool):
                         # Try to write immediately, in case there's an error.
                         try:
                             next(cor)
-                            add_writer(ready_fd, cor)
                         except StopIteration:
                             pass
+                        except OSError as exc:
+                            if get_errno(exc) != errno.EBADF:
+                                raise
+                        else:
+                            add_writer(ready_fd, cor)
         hub.consolidate_callback = schedule_writes

         def send_job(tup):
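Moving add_writer into an else clause changes when the write generator is handed to the hub: it is primed with next(cor) first, and only registered if that first attempt neither completed (StopIteration) nor failed on a closed pipe (EBADF). A rough sketch of the same "prime first, register only on success" pattern, with a plain dict standing in for the hub's writer registry and an illustrative generator in place of _write_job:

# Sketch: prime a write coroutine once; register it with the event loop
# only if the first attempt neither finished nor hit a closed fd.
import errno
import os

writers = {}      # fd -> generator, stand-in for hub.add_writer bookkeeping

def write_job(fd, data):
    # Write `data` to a non-blocking fd, yielding whenever it would block.
    while data:
        try:
            n = os.write(fd, data)
            data = data[n:]
        except OSError as exc:
            if exc.errno not in (errno.EAGAIN, errno.EINTR):
                raise               # e.g. EBADF: let the caller decide
            yield                   # resume when the fd is writable again

def schedule_write(fd, data):
    cor = write_job(fd, data)
    try:
        next(cor)                   # try to write immediately
    except StopIteration:
        pass                        # everything fit in one write
    except OSError as exc:
        if exc.errno != errno.EBADF:
            raise
        # worker already closed its end of the pipe: drop the write
    else:
        writers[fd] = cor           # partial write: let the loop finish it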
@@ -707,6 +776,9 @@ class AsynPool(_pool.Pool):
         except StopIteration:
             # process already exited :(  this will be handled elsewhere.
             return
+        assert proc.inqW_fd not in self._fileno_to_inq
+        assert proc.inqW_fd not in self._all_inqueues
+        self._waiting_to_start.discard(proc)
         self._fileno_to_inq[proc.inqW_fd] = proc
         self._fileno_to_synq[proc.synqW_fd] = proc
         self._all_inqueues.add(proc.inqW_fd)
@@ -714,10 +786,10 @@ class AsynPool(_pool.Pool):
     def on_job_process_down(self, job, pid_gone):
         """Handler called for each job when the process it was assigned to
         exits."""
-        if job._write_to:
+        if job._write_to and job._write_to.exitcode:
             # job was partially written
             self.on_partial_read(job, job._write_to)
-        elif job._scheduled_for:
+        elif job._scheduled_for and job._scheduled_for.exitcode:
             # job was only scheduled to be written to this process,
             # but no data was sent so put it back on the outbound_buffer.
             self._put_back(job)
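The new exitcode checks mean on_job_process_down only recovers or requeues a job when the process it was bound to has really terminated; while the worker is alive, exitcode is still None and the job is left alone. The short sketch below is just a reminder of the multiprocessing.Process.exitcode semantics the guard relies on (standalone code, Unix assumed, not part of the diff):

# exitcode semantics used by the guard above (multiprocessing.Process):
#   None -> still running    0 -> exited normally    -N -> killed by signal N
from multiprocessing import Process
from time import sleep

def worker():
    sleep(10)

if __name__ == '__main__':
    proc = Process(target=worker)
    proc.start()
    print(proc.exitcode)        # None: still alive, do not requeue its jobs
    proc.terminate()            # SIGTERM on Unix
    proc.join()
    print(proc.exitcode)        # -15: really gone, safe to recover the job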
@@ -833,27 +905,32 @@ class AsynPool(_pool.Pool):
             if writer:
                 self._active_writers.discard(writer)

-            # Replace queues to avoid reuse
-            before = len(self._queues)
-            try:
-                queues = self._find_worker_queues(proc)
-                if self.destroy_queues(queues):
-                    self._queues[self.create_process_queues()] = None
-            except ValueError:
-                # Not in queue map, make sure sockets are closed.
-                self.destroy_queues((proc.inq, proc.outq, proc.synq))
-            assert len(self._queues) == before
-
-    def destroy_queues(self, queues):
+            if not proc.dead:
+                proc.dead = True
+                # Replace queues to avoid reuse
+                before = len(self._queues)
+                try:
+                    queues = self._find_worker_queues(proc)
+                    if self.destroy_queues(queues, proc):
+                        self._queues[self.create_process_queues()] = None
+                except ValueError:
+                    pass
+                    # Not in queue map, make sure sockets are closed.
+                    #self.destroy_queues((proc.inq, proc.outq, proc.synq))
+                assert len(self._queues) == before
+
+    def destroy_queues(self, queues, proc):
         """Destroy queues that can no longer be used, so that they
         """Destroy queues that can no longer be used, so that they
         be replaced by new sockets."""
         be replaced by new sockets."""
+        assert proc.exitcode is not None
+        self._waiting_to_start.discard(proc)
         removed = 1
         removed = 1
         try:
             self._queues.pop(queues)
         except KeyError:
             removed = 0
         try:
+            self.on_inqueue_close(queues[0]._writer.fileno(), proc)
         except IOError:
         except IOError:
             pass
         for queue in queues:
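The proc.dead flag introduced above makes the queue-replacement step idempotent: the sentinel handler, pool maintenance and destroy_queues can all race to clean up a worker that died right after starting, but the old queues must be thrown away and replaced exactly once. A minimal sketch of that one-shot guard follows (a plain attribute suffices because everything runs on the single event-loop thread); the class and function names are illustrative only:

# Sketch: per-process cleanup that stays a no-op after the first call,
# no matter how many code paths race to trigger it.

class WorkerProc(object):
    def __init__(self):
        self.dead = False          # like proc.dead in the diff

def replace_queues(proc, queues):
    if proc.dead:
        return False               # another path already did the cleanup
    proc.dead = True
    # ... destroy the old sockets and allocate a fresh queue pair here ...
    queues.append(object())
    return True

queues = []
proc = WorkerProc()
assert replace_queues(proc, queues) is True
assert replace_queues(proc, queues) is False   # second call is a no-op
assert len(queues) == 1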