
Prefork: Fixes Pool.flush

Ask Solem 11 years ago
parent
commit
123f002aaf
2 changed files with 51 additions and 19 deletions
  1. celery/concurrency/asynpool.py  +50 -19
  2. celery/concurrency/prefork.py   +1 -0

+ 50 - 19
celery/concurrency/asynpool.py

@@ -26,7 +26,7 @@ import socket
 import struct
 import time
 
-from collections import deque, namedtuple
+from collections import defaultdict, deque, namedtuple
 from pickle import HIGHEST_PROTOCOL
 from time import sleep
 from weakref import WeakValueDictionary, ref
@@ -70,6 +70,15 @@ def gen_not_started(gen):
     return gen.gi_frame and gen.gi_frame.f_lasti == -1
 
 
+def _get_job_writer(job):
+    try:
+        writer = job._writer
+    except AttributeError:
+        pass
+    else:
+        return writer()  # is a weakref
+
+
 def _select(readers=None, writers=None, err=None, timeout=0):
     """Simple wrapper to :class:`~select.select`.
 
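Worth noting for the hunks below: `job._writer` holds a *weak* reference to the generator writing that job's payload, so the new `_get_job_writer` helper does two things at once: it tolerates jobs that never had a writer attached (the AttributeError branch) and dereferences the weakref, returning the live generator or None once it has been collected. A minimal sketch of that pattern, with a hypothetical stand-in for the job object:

    import weakref

    class Job(object):
        """Hypothetical stand-in for billiard's async job objects."""

    def payload_writer():
        yield  # stands in for the real send_job write generator

    job = Job()
    gen = payload_writer()
    job._writer = weakref.ref(gen)   # the pool keeps only a weak reference

    assert job._writer() is gen      # calling the ref returns the target
    del gen                          # drop the last strong reference...
    assert job._writer() is None     # ...and the ref now returns None (CPython)
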
@@ -469,14 +478,15 @@ class AsynPool(_pool.Pool):
         put_message = outbound.append
         all_inqueues = self._all_inqueues
         active_writes = self._active_writes
+        active_writers = self._active_writers
         busy_workers = self._busy_workers
         diff = all_inqueues.difference
         add_reader, add_writer = hub.add_reader, hub.add_writer
         hub_add, hub_remove = hub.add, hub.remove
         mark_write_fd_as_active = active_writes.add
-        mark_write_gen_as_active = self._active_writers.add
+        mark_write_gen_as_active = active_writers.add
         mark_worker_as_busy = busy_workers.add
-        write_generator_done = self._active_writers.discard
+        write_generator_done = active_writers.discard
         get_job = self._cache.__getitem__
 
         def _put_back(job):
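The only change in this hunk is hoisting `self._active_writers` into a local before the two bound-method aliases are taken, matching the surrounding style: every name used inside the scheduling closures is bound once up front, since a local lookup plus a pre-bound method is cheaper in CPython than an attribute lookup on every call. Schematically (class and method names invented):

    class Sketch(object):
        def __init__(self):
            self._active_writers = set()

        def _create_write_handlers(self, gens):
            active_writers = self._active_writers          # one attribute lookup
            mark_write_gen_as_active = active_writers.add  # pre-bound method
            for gen in gens:
                mark_write_gen_as_active(gen)              # fast local call
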
@@ -499,10 +509,10 @@ class AsynPool(_pool.Pool):
             # argument.  Using this means we minimize the risk of having
             # the same fd receive every task if the pipe read buffer is not
             # full.
-
             if outbound:
+                inactive = diff(active_writes)
                 [hub_add(fd, None, WRITE | ERR, consolidate=True)
-                 for fd in diff(active_writes)]
+                 for fd in inactive]
             else:
                 [hub_remove(fd) for fd in diff(active_writes)]
         self.on_poll_start = on_poll_start
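Extracting `diff(active_writes)` into `inactive` is a pure readability change: the fds registered for WRITE | ERR on each poll iteration are exactly those inqueue fds that do not already have a write in flight. The set arithmetic, with made-up fd numbers:

    all_inqueues = {10, 11, 12, 13}   # one write-side pipe fd per worker
    active_writes = {11}              # fds with a send already in progress
    inactive = all_inqueues.difference(active_writes)
    assert inactive == {10, 12, 13}   # only these are (re)registered with the hub
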
@@ -710,6 +720,9 @@ class AsynPool(_pool.Pool):
         # the broker anyway.
         if self.outbound_buffer:
             self.outbound_buffer.clear()
+
+        self.maintain_pool()
+
         try:
             # ...but we must continue writing the payloads we already started
             # to keep message boundaries.
@@ -717,6 +730,12 @@ class AsynPool(_pool.Pool):
             if self._state == RUN:
                 # flush outgoing buffers
                 intervals = fxrange(0.01, 0.1, 0.01, repeatlast=True)
+                owned_by = defaultdict(set)
+                for job in values(self._cache):
+                    writer = _get_job_writer(job)
+                    if writer is not None:
+                        owned_by[writer].add(job)
+
                 while self._active_writers:
                     writers = list(self._active_writers)
                     for gen in writers:
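This is the heart of the fix: instead of rescanning `Pool._cache` once per generator per flush iteration (the removed loop in the next hunk), the cache is walked a single time up front to build an inverted index from each live writer generator to the set of jobs it owns. As a standalone sketch, with `get_job_writer` playing the role of `_get_job_writer` above:

    from collections import defaultdict

    def build_owned_by(cache, get_job_writer):
        """Index: writer generator -> set of jobs it is writing.

        Jobs whose weakly referenced writer is already gone are skipped.
        """
        owned_by = defaultdict(set)
        for job in cache.values():
            writer = get_job_writer(job)
            if writer is not None:
                owned_by[writer].add(job)
        return owned_by
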
@@ -725,26 +744,38 @@ class AsynPool(_pool.Pool):
                             # has not started writing the job so can
                             # discard the task, but we must also remove
                             # it from the Pool._cache.
-                            job_to_discard = None
-                            for job in values(self._cache):
-                                if job._writer() is gen:  # _writer is saferef
-                                    # removes from Pool._cache
-                                    job_to_discard = job
-                                    break
-                            if job_to_discard:
-                                job_to_discard.discard()
+                            for job in owned_by[gen]:
+                                # removes from Pool._cache
+                                job.discard()
                             self._active_writers.discard(gen)
                         else:
-                            try:
-                                next(gen)
-                            except StopIteration:
-                                self._active_writers.discard(gen)
+                            terminated = set()
+                            for job in owned_by[gen]:
+                                if job._write_to.exitcode is not None:
+                                    terminated.add(job)
+                            if len(terminated) < len(owned_by[gen]):
+                                self._flush_writer(gen)
                     # workers may have exited in the meantime.
                     self.maintain_pool()
                     sleep(next(intervals))  # don't busyloop
         finally:
             self.outbound_buffer.clear()
             self._active_writers.clear()
+            self._active_writes.clear()
+            self._busy_workers.clear()
+
+    def _flush_writer(self, writer):
+        try:
+            while 1:
+                try:
+                    next(writer)
+                except StopIteration:
+                    break
+        except (OSError, IOError) as exc:
+            if get_errno(exc) != errno.EBADF:
+                raise
+        finally:
+            self._active_writers.discard(writer)
 
     def get_process_queues(self):
         """Get queues for a new process.
@@ -911,10 +942,10 @@ class AsynPool(_pool.Pool):
             if not job._accepted:
                 # job was not acked, so find another worker to send it to.
                 self._put_back(job)
-            writer = getattr(job, '_writer')
-            writer = writer and writer() or None
+            writer = _get_job_writer(job)
             if writer:
                 self._active_writers.discard(writer)
+                del(writer)
 
             if not proc.dead:
                 proc.dead = True
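The explicit `del(writer)` after the discard looks redundant but appears deliberate: `writer` is a strong reference recovered from the job's weakref, and holding it for the rest of `on_process_down` would keep the dead worker's write generator (and whatever it closes over) alive longer than necessary. On CPython the effect is immediate:

    import weakref

    def payload_writer():
        yield

    gen = payload_writer()
    ref = weakref.ref(gen)

    writer = ref()        # strong reference recovered from the weakref
    del writer, gen       # drop every strong reference
    assert ref() is None  # collected at once under refcounting
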

+ 1 - 0
celery/concurrency/prefork.py

@@ -128,6 +128,7 @@ class TaskPool(BasePool):
         self.terminate_job = P.terminate_job
         self.grow = P.grow
         self.shrink = P.shrink
+        self.flush = P.flush
         self.restart = P.restart
         self.maybe_handle_result = P._result_handler.handle_event
         self.handle_result_event = P.handle_result_event
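The prefork side of the fix is a single line: `TaskPool.on_start` already re-exports the concrete pool's bound methods (`grow`, `shrink`, `terminate_job`, ...) as its own attributes, and `flush` now joins that list so the worker can invoke `pool.flush()` directly at shutdown. The forwarding pattern, reduced to its essentials (class bodies are illustrative):

    class RealPool(object):
        def flush(self):
            print('flushing half-written payloads')

    class TaskPool(object):
        def on_start(self):
            P = self._pool = RealPool()
            self.flush = P.flush      # the one line this commit adds

    pool = TaskPool()
    pool.on_start()
    pool.flush()                      # dispatches straight to RealPool.flush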