Browse Source

[synack] Use separate fd for ack reply

Ask Solem 12 years ago
parent
commit
1633ea2752
1 changed files with 67 additions and 65 deletions
  1. 67 65
      celery/concurrency/processes.py

+ 67 - 65
celery/concurrency/processes.py

@@ -17,7 +17,7 @@ import select
 import socket
 import struct
 
-from collections import defaultdict, deque
+from collections import deque
 from pickle import HIGHEST_PROTOCOL
 from time import sleep, time
 
@@ -234,29 +234,32 @@ class AsynPool(_pool.Pool):
 
     def __init__(self, processes=None, *args, **kwargs):
         processes = self.cpu_count() if processes is None else processes
-        self._queuepairs = dict((self.create_process_queuepair(), None)
-                                for _ in range(processes))
+        self._queues = dict((self.create_process_queues(), None)
+                            for _ in range(processes))
         self._fileno_to_inq = {}
         self._fileno_to_outq = {}
+        self._fileno_to_synq = {}
         self._all_inqueues = set()
         super(AsynPool, self).__init__(processes, *args, **kwargs)
 
         for proc in self._pool:
             self._fileno_to_inq[proc.inqW_fd] = proc
             self._fileno_to_outq[proc.outqR_fd] = proc
+            self._fileno_to_synq[proc.synqW_fd] = proc
 
     def _finalize_args(self):
         orig = super(AsynPool, self)._finalize_args()
         return (self._fileno_to_inq, orig)
 
-    def get_process_queuepair(self):
-        return next(pair for pair, owner in items(self._queuepairs)
+    def get_process_queues(self):
+        return next(q for q, owner in items(self._queues)
                     if owner is None)
 
-    def create_process_queuepair(self):
-        inq, outq = _SimpleQueue(), _SimpleQueue()
+    def create_process_queues(self):
+        inq, outq, synq = _SimpleQueue(), _SimpleQueue(), _SimpleQueue()
         inq._writer.setblocking(0)
-        return inq, outq
+        synq._writer.setblocking(0)
+        return inq, outq, synq
 
     def on_worker_alive(self, pid):
         try:
@@ -264,6 +267,7 @@ class AsynPool(_pool.Pool):
         except StopIteration:
             return
         self._fileno_to_inq[proc.inqW_fd] = proc
+        self._fileno_to_synq[proc.synqW_fd] = proc
         self._all_inqueues.add(proc.inqW_fd)
 
     def on_job_process_down(self, job):
@@ -273,9 +277,9 @@ class AsynPool(_pool.Pool):
     def on_job_process_lost(self, job, pid, exitcode):
         self.mark_as_worker_lost(job, exitcode)
 
-    def _process_cleanup_queuepair(self, proc):
+    def _process_cleanup_queues(self, proc):
         try:
-            self._queuepairs[self._find_worker_queuepair(proc)] = None
+            self._queues[self._find_worker_queues(proc)] = None
         except (KeyError, ValueError):
             pass
 
@@ -291,14 +295,15 @@ class AsynPool(_pool.Pool):
             on_worker_alive=self.on_worker_alive,
         )
 
-    def _process_register_queuepair(self, proc, pair):
-        self._queuepairs[pair] = proc
+    def _process_register_queues(self, proc, queues):
+        self._queues[queues] = proc
 
-    def _find_worker_queuepair(self, proc):
-        for pair, owner in items(self._queuepairs):
-            if owner == proc:
-                return pair
-        raise ValueError(proc)
+    def _find_worker_queues(self, proc):
+        try:
+            return next(q for q, owner in items(self._queues)
+                        if owner == proc)
+        except StopIteration:
+            raise ValueError(proc)
 
     def _setup_queues(self):
         self._inqueue = self._outqueue = \
@@ -319,12 +324,13 @@ class AsynPool(_pool.Pool):
             if not job._accepted:
                 self._put_back(job)
 
-            for conn in (proc.inq, proc.outq):
-                for sock in (conn._reader, conn._writer):
-                    if not sock.closed:
-                        os.close(sock.fileno())
-            self._queuepairs[(proc.inq, proc.outq)] = \
-                self._queuepairs[self.create_process_queuepair()] = None
+            for conn in (proc.inq, proc.outq, proc.synq):
+                if conn:
+                    for sock in (conn._reader, conn._writer):
+                        if not sock.closed:
+                            os.close(sock.fileno())
+            self._queues[(proc.inq, proc.outq, proc.synq)] = \
+                self._queues[self.create_process_queues()] = None
 
     @classmethod
     def _set_result_sentinel(cls, _outqueue, _pool):
@@ -429,12 +435,12 @@ class TaskPool(BasePool):
         apply_at = hub.timer.apply_at
         maintain_pool = self.maintain_pool
         on_soft_timeout = self.on_soft_timeout
-        on_hard_timeout = self.on_hard_timeout
+        fileno_to_inq = pool._fileno_to_inq
+        fileno_to_outq = pool._fileno_to_outq
+        fileno_to_synq = pool._fileno_to_synq
         outbound = self.outbound_buffer
         pop_message = outbound.popleft
         put_message = outbound.append
-        fileno_to_inq = pool._fileno_to_inq
-        fileno_to_outq = pool._fileno_to_outq
         all_inqueues = pool._all_inqueues
         active_writes = self._active_writes
         diff = all_inqueues.difference
@@ -492,6 +498,7 @@ class TaskPool(BasePool):
         def on_process_down(proc):
             fileno_to_outq.pop(proc.outqR_fd, None)
             fileno_to_inq.pop(proc.inqW_fd, None)
+            fileno_to_synq.pop(proc.synqW_fd, None)
             all_inqueues.discard(proc.inqW_fd)
             hub_remove(proc.sentinel)
             hub_remove(proc.outqR_fd)
@@ -511,19 +518,43 @@ class TaskPool(BasePool):
             def __hash__(self):
                 return self.i
 
+        def _write_ack(fd, ack, callback=None):
+            header, body, body_size = ack._payload
+            try:
+                try:
+                    proc = fileno_to_synq[fd]
+                except KeyError:
+                    raise StopIteration()
+                send_offset = proc.synq._writer.send_offset
+
+                Hw = Bw = 0
+                while Hw < 4:
+                    try:
+                        Hw += send_offset(header, Hw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        yield
+                while Bw < body_size:
+                    try:
+                        Bw += send_offset(body, Bw)
+                    except Exception as exc:
+                        if get_errno(exc) not in UNAVAIL:
+                            raise
+                        # suspend until more data
+                        yield
+            finally:
+                if callback:
+                    callback()
+                active_writes.discard(fd)
+
         def _write_to(fd, job, callback=None):
-            is_ack = isinstance(job, Ack)
-            if is_ack:
-                print('WRITE ACK!')
-            else:
-                print('WRITE JOB %r' % (job, ))
             header, body, body_size = job._payload
             try:
                 try:
                     proc = fileno_to_inq[fd]
                 except KeyError:
-                    if not isinstance(job, Ack):
-                        put_message(job)
+                    put_message(job)
                     raise StopIteration()
                 send_offset = proc.inq._writer.send_offset
                 # job result keeps track of what process the job is sent to.
@@ -556,9 +587,7 @@ class TaskPool(BasePool):
                 return
             try:
                 job = pop_message()
-                print('GOT JOB: %r' % (job, ))
             except IndexError:
-                print('GOT INDEX ERROR')
                 for inqfd in diff(active_writes):
                     hub_remove(inqfd)
             else:
@@ -574,41 +603,14 @@ class TaskPool(BasePool):
         ACK_SIZE = len(ACK_BODY)
         ACK_HEAD = pack('>I', ACK_SIZE)
 
-        acks = set()
-
-        from types import GeneratorType as generator
-
-        def _send_ack(fd, gen):
-            mark_write_fd_as_active(fd)
-            ex = hub.writers.get(fd)
-            assert not isinstance(ex, generator)
-            hub_add((fd, ), gen, WRITE)
-
-        def send_ack(pid, i):
-            fd = next(p.inqW_fd for p in pool._pool if p.pid == pid)
+        def send_ack(pid, i, fd):
             msg = Ack(i, fd, (ACK_HEAD, ACK_BODY, ACK_SIZE))
-            print('PUT ACK')
-            acks.add(msg)
-            #ex = hub.writers.get(fd)
-            #from types import GeneratorType as generator
-            #assert not isinstance(ex, generator)
-            #acks.append(_write_to(fd, msg))
-            #_send_ack(fd, _write_to(fd, msg))
-            #acks[fd].append(_write_to(fd, msg))
+            mark_write_fd_as_active(fd)
+            hub_add((fd, ), _write_ack(fd, msg), WRITE)
         self._pool.send_ack = send_ack
 
         def on_poll_start(hub):
-            if acks:
-                dirty = set()
-                for ack in acks:
-                    if ack.fd not in active_writes:
-                        active_writes.add(ack.fd)
-                        hub_add(ack.fd, _write_to(ack.fd, ack._payload), WRITE)
-                        dirty.add(ack)
-                for ack in dirty:
-                    acks.discard(ack)
             if outbound:
-                print('SCHEDULE WRITES FOR %r' % (diff(active_writes, )))
                 hub_add(diff(active_writes), schedule_writes, hub.WRITE)
         self.on_poll_start = on_poll_start