Browse files

Asynpool: Fixes race condition when child process terminated fast

Ask Solem 11 years ago
parent
commit
f2e2bac7c8
1 changed file with 110 additions and 33 deletions

+ 110 - 33
celery/concurrency/asynpool.py

@@ -15,6 +15,7 @@
 from __future__ import absolute_import
 
 import errno
+import os
 import random
 import select
 import socket
@@ -42,9 +43,9 @@ from celery.five import Counter, items, values
 from celery.utils.log import get_logger
 
 logger = get_logger(__name__)
-debug = logger.debug
+error, debug = logger.error, logger.debug
 
-UNAVAIL = frozenset([errno.EAGAIN, errno.EINTR, errno.EBADF])
+UNAVAIL = frozenset([errno.EAGAIN, errno.EINTR])
 
 #: Constant sent by child process when started (ready to accept work)
 WORKER_UP = 15
@@ -243,6 +244,13 @@ class AsynPool(_pool.Pool):
         # synqueue fileno -> process mapping
         self._fileno_to_synq = {}
 
+        # We keep track of processes that have not yet
+        # sent a WORKER_UP message.  If a process fails to send
+        # this message within proc_up_timeout we terminate it
+        # and hope the next process will recover.
+        self._proc_alive_timeout = 2.0
+        self._waiting_to_start = set()
+
         # denormalized set of all inqueues.
         self._all_inqueues = set()
 
@@ -264,20 +272,24 @@ class AsynPool(_pool.Pool):
         for proc in self._pool:
             # create initial mappings, these will be updated
             # as processes are recycled, or found lost elsewhere.
-            self._fileno_to_inq[proc.inqW_fd] = proc
             self._fileno_to_outq[proc.outqR_fd] = proc
             self._fileno_to_synq[proc.synqW_fd] = proc
         self.on_soft_timeout = self._timeout_handler.on_soft_timeout
         self.on_hard_timeout = self._timeout_handler.on_hard_timeout
 
+    def _event_process_exit(self, hub, fd):
+        # This method is called whenever the process sentinel is readable.
+        hub.remove(fd)
+        self.maintain_pool()
+
     def register_with_event_loop(self, hub):
         """Registers the async pool with the current event loop."""
         self._create_timelimit_handlers(hub)
         self._create_process_handlers(hub)
         self._create_write_handlers(hub)
 
-        # Maintain_pool is called whenever a process exits.
-        [hub.add_reader(fd, self.maintain_pool)
+        # Add handler for when a process exits (calls maintain_pool)
+        [hub.add_reader(fd, self._event_process_exit, hub, fd)
          for fd in self.process_sentinels]
         # Handle_result_event is called whenever one of the
         # result queues are readable.
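
The new _event_process_exit handler unregisters the sentinel descriptor before repairing the pool: the dead child's sentinel is about to be closed and its number can be handed to a new pipe, so a reader left installed could fire for an unrelated descriptor. A rough equivalent of that unregister-then-repair step, written against the standard selectors module with hypothetical helper names rather than the hub's real API:

    import selectors

    sel = selectors.DefaultSelector()

    def add_exit_watch(sentinel_fd, repair_pool):
        def _on_exit():
            sel.unregister(sentinel_fd)   # drop the fd first; it may be closed and reused
            repair_pool()                 # then replace the dead worker
        sel.register(sentinel_fd, selectors.EVENT_READ, data=_on_exit)

    # Event-loop side: run the stored callback once the sentinel becomes readable.
    # for key, _events in sel.select():
    #     key.data()
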
@@ -363,12 +375,21 @@ class AsynPool(_pool.Pool):
         fileno_to_outq = self._fileno_to_outq
         fileno_to_synq = self._fileno_to_synq
         busy_workers = self._busy_workers
-        maintain_pool = self.maintain_pool
+        event_process_exit = self._event_process_exit
         handle_result_event = self.handle_result_event
         process_flush_queues = self.process_flush_queues
+        waiting_to_start = self._waiting_to_start
+
+        def verify_process_alive(proc):
+            if proc.exitcode is None and proc in waiting_to_start:
+                assert proc.outqR_fd in fileno_to_outq
+                assert fileno_to_outq[proc.outqR_fd] is proc
+                assert proc.outqR_fd in hub.readers
+                error('Timed out waiting for UP message from %r', proc)
+                os.kill(proc.pid, 9)
 
         def on_process_up(proc):
-            """Called when a WORKER_UP message is received from process."""
+            """Called when a process has started."""
             # If we got the same fd as a previous process then we will also
             # receive jobs in the old buffer, so we need to reset the
             # job._write_to and job._scheduled_for attributes used to recover
@@ -381,22 +402,53 @@ class AsynPool(_pool.Pool):
                     job._scheduled_for = proc
             fileno_to_outq[proc.outqR_fd] = proc
             # maintain_pool is called whenever a process exits.
-            add_reader(proc.sentinel, maintain_pool)
+            add_reader(
+                proc.sentinel, event_process_exit, hub, proc.sentinel,
+            )
             # handle_result_event is called when the processes outqueue is
             # readable.
             add_reader(proc.outqR_fd, handle_result_event, proc.outqR_fd)
+
+            waiting_to_start.add(proc)
+            hub.call_later(
+                self._proc_alive_timeout, verify_process_alive, proc,
+            )
+
         self.on_process_up = on_process_up
 
+        def _remove_from_index(obj, proc, index, callback=None):
+            # this remove the file descriptors for a process from
+            # the indices.  we have to make sure we don't overwrite
+            # another processes fds, as the fds may be reused.
+            try:
+                fd = obj.fileno()
+            except (IOError, OSError):
+                return
+
+            try:
+                if index[fd] is proc:
+                    # fd has not been reused so we can remove it from index.
+                    index.pop(fd, None)
+            except KeyError:
+                pass
+            else:
+                hub_remove(fd)
+                if callback is not None:
+                    callback(fd)
+            return fd
+
         def on_process_down(proc):
             """Called when a worker process exits."""
+            assert not proc.dead
             process_flush_queues(proc)
-            fileno_to_outq.pop(proc.outqR_fd, None)
-            fileno_to_inq.pop(proc.inqW_fd, None)
-            fileno_to_synq.pop(proc.synqW_fd, None)
-            all_inqueues.discard(proc.inqW_fd)
-            busy_workers.discard(proc.inqW_fd)
+            _remove_from_index(proc.outq._reader, proc, fileno_to_outq)
+            if proc.synq:
+                _remove_from_index(proc.synq._writer, proc, fileno_to_synq)
+            inq = _remove_from_index(proc.inq._writer, proc, fileno_to_inq,
+                                     callback=all_inqueues.discard)
+            if inq:
+                busy_workers.discard(inq)
             hub_remove(proc.sentinel)
-            hub_remove(proc.outqR_fd)
         self.on_process_down = on_process_down
 
     def _create_write_handlers(self, hub,
@@ -437,16 +489,26 @@ class AsynPool(_pool.Pool):
             # argument.  Using this means we minimize the risk of having
             # the same fd receive every task if the pipe read buffer is not
             # full.
+
             if outbound:
                 [hub_add(fd, None, WRITE | ERR, consolidate=True)
                  for fd in diff(active_writes)]
+            else:
+                [hub_remove(fd) for fd in diff(active_writes)]
         self.on_poll_start = on_poll_start
 
-        def on_inqueue_close(fd):
+        def on_inqueue_close(fd, proc):
             # Makes sure the fd is removed from tracking when
             # the connection is closed, this is essential as fds may be reused.
-            active_writes.discard(fd)
-            all_inqueues.discard(fd)
+            busy_workers.discard(fd)
+            try:
+                if fileno_to_inq[fd] is proc:
+                    fileno_to_inq.pop(fd, None)
+                    active_writes.discard(fd)
+                    all_inqueues.discard(fd)
+                    hub_remove(fd)
+            except KeyError:
+                pass
         self.on_inqueue_close = on_inqueue_close
 
         def schedule_writes(ready_fds, shuffle=random.shuffle):
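
Both _remove_from_index above and the reworked on_inqueue_close apply the same guard: a descriptor is dropped from the shared maps only while the map still points at the process being cleaned up, because the kernel may hand the same fd number to a replacement worker in the meantime. The guard reduces to a compare-before-delete, sketched here with plain dicts and hypothetical names:

    def remove_fd_if_owned(fd, proc, index, on_removed=None):
        # Evict fd only while it still belongs to proc; a replacement
        # worker may already own the reused descriptor number.
        if index.get(fd) is proc:
            del index[fd]
            if on_removed is not None:
                on_removed(fd)
            return fd
        return None

    fileno_to_inq = {7: 'old-worker'}
    remove_fd_if_owned(7, 'old-worker', fileno_to_inq)   # owned by us: removed
    fileno_to_inq[7] = 'new-worker'                      # fd number reused
    remove_fd_if_owned(7, 'old-worker', fileno_to_inq)   # now owned elsewhere: no-op
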
@@ -463,6 +525,9 @@ class AsynPool(_pool.Pool):
                 if is_fair_strategy and ready_fd in busy_workers:
                     # worker is already busy with another task
                     continue
+                if ready_fd not in all_inqueues:
+                    hub_remove(ready_fd)
+                    continue
                 try:
                     job = pop_message()
                 except IndexError:
@@ -496,9 +561,13 @@ class AsynPool(_pool.Pool):
                         # Try to write immediately, in case there's an error.
                         try:
                             next(cor)
-                            add_writer(ready_fd, cor)
                         except StopIteration:
                             pass
+                        except OSError as exc:
+                            if get_errno(exc) != errno.EBADF:
+                                raise
+                        else:
+                            add_writer(ready_fd, cor)
         hub.consolidate_callback = schedule_writes
 
         def send_job(tup):
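
The reordered error handling above primes the write generator once before handing it to the event loop, so a pipe whose worker already died surfaces as EBADF immediately instead of after the writer has been registered; only a healthy, partially written send is added as a writer. In isolation the pattern looks roughly like this (register_writer is a hypothetical callback, not Celery's API):

    import errno

    def start_writer(fd, writer_gen, register_writer):
        # Advance the generator once so an already-dead pipe fails right away.
        try:
            next(writer_gen)
        except StopIteration:
            return                      # the whole payload fit in one write
        except OSError as exc:
            if exc.errno != errno.EBADF:
                raise                   # a real error: propagate it
            # EBADF: the worker died and its fd was closed; the job will be
            # recovered by the process-down handlers instead.
        else:
            register_writer(fd, writer_gen)   # more to send: wait for writability
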
@@ -707,6 +776,9 @@ class AsynPool(_pool.Pool):
         except StopIteration:
             # process already exited :(  this will be handled elsewhere.
             return
+        assert proc.inqW_fd not in self._fileno_to_inq
+        assert proc.inqW_fd not in self._all_inqueues
+        self._waiting_to_start.discard(proc)
         self._fileno_to_inq[proc.inqW_fd] = proc
         self._fileno_to_synq[proc.synqW_fd] = proc
         self._all_inqueues.add(proc.inqW_fd)
@@ -714,10 +786,10 @@ class AsynPool(_pool.Pool):
     def on_job_process_down(self, job, pid_gone):
         """Handler called for each job when the process it was assigned to
         exits."""
-        if job._write_to:
+        if job._write_to and job._write_to.exitcode:
             # job was partially written
             self.on_partial_read(job, job._write_to)
-        elif job._scheduled_for:
+        elif job._scheduled_for and job._scheduled_for.exitcode:
             # job was only scheduled to be written to this process,
             # but no data was sent so put it back on the outbound_buffer.
             self._put_back(job)
@@ -833,27 +905,32 @@ class AsynPool(_pool.Pool):
             if writer:
                 self._active_writers.discard(writer)
 
-            # Replace queues to avoid reuse
-            before = len(self._queues)
-            try:
-                queues = self._find_worker_queues(proc)
-                if self.destroy_queues(queues):
-                    self._queues[self.create_process_queues()] = None
-            except ValueError:
-                # Not in queue map, make sure sockets are closed.
-                self.destroy_queues((proc.inq, proc.outq, proc.synq))
-            assert len(self._queues) == before
-
-    def destroy_queues(self, queues):
+            if not proc.dead:
+                proc.dead = True
+                # Replace queues to avoid reuse
+                before = len(self._queues)
+                try:
+                    queues = self._find_worker_queues(proc)
+                    if self.destroy_queues(queues, proc):
+                        self._queues[self.create_process_queues()] = None
+                except ValueError:
+                    pass
+                    # Not in queue map, make sure sockets are closed.
+                    #self.destroy_queues((proc.inq, proc.outq, proc.synq))
+                assert len(self._queues) == before
+
+    def destroy_queues(self, queues, proc):
         """Destroy queues that can no longer be used, so that they
         be replaced by new sockets."""
+        assert proc.exitcode is not None
+        self._waiting_to_start.discard(proc)
         removed = 1
         try:
             self._queues.pop(queues)
         except KeyError:
             removed = 0
         try:
-            self.on_inqueue_close(queues[0]._writer.fileno())
+            self.on_inqueue_close(queues[0]._writer.fileno(), proc)
         except IOError:
             pass
         for queue in queues: