Ask Solem 12 лет назад
Родитель
Сommit
dea4fc34e5
1 измененных файлов с 72 добавлено и 9 удалено
  1. 72 9
      celery/concurrency/processes.py

+ 72 - 9
celery/concurrency/processes.py

@@ -17,14 +17,14 @@ import select
 import socket
 import struct
 
-from collections import deque
+from collections import defaultdict, deque
 from pickle import HIGHEST_PROTOCOL
 from time import sleep, time
 
 from billiard import forking_enable
 from billiard import pool as _pool
 from billiard.exceptions import WorkerLostError
-from billiard.pool import RUN, CLOSE, TERMINATE, WorkersJoined, CoroStop
+from billiard.pool import RUN, CLOSE, TERMINATE, ACK, WorkersJoined, CoroStop
 from billiard.queues import _SimpleQueue
 from kombu.serialization import pickle as _pickle
 from kombu.utils import fxrange
@@ -99,12 +99,12 @@ def _select(self, readers=None, writers=None, err=None, timeout=0):
     try:
         r, w, e = select.select(readers, writers, err, timeout)
         if e:
-            _seen = set()
-            r = [f for f in r + e if f not in _seen and not _seen.add(f)]
+            seen = set()
+            r = r | set(f for f in r + e if f not in seen and not seen.add(f))
         return r, w, 0
     except (select.error, socket.error) as exc:
         if get_errno(exc) == errno.EINTR:
-            return
+            return [], [], 1
         elif get_errno(exc) in SELECT_BAD_FD:
             for fd in readers | writers | err:
                 try:
@@ -116,6 +116,8 @@ def _select(self, readers=None, writers=None, err=None, timeout=0):
                     writers.discard(fd)
                     err.discard(fd)
             return [], [], 1
+        else:
+            raise
 
 
 class promise(object):
@@ -199,7 +201,6 @@ class ResultHandler(_pool.ResultHandler):
         while cache and outqueues and self._state != TERMINATE:
             if check_timeouts is not None:
                 check_timeouts()
-            _dirty = set()
             for fd in outqueues:
                 try:
                     proc = fileno_to_outq[fd]
@@ -377,6 +378,7 @@ class TaskPool(BasePool):
                 else self.Pool)
         P = self._pool = Pool(processes=self.limit,
                               initializer=process_initializer,
+                              synack=True,
                               **self.options)
         self.on_apply = P.apply_async
         self.on_soft_timeout = P._timeout_handler.on_soft_timeout
@@ -495,13 +497,33 @@ class TaskPool(BasePool):
             hub_remove(proc.outqR_fd)
         self._pool.on_process_down = on_process_down
 
+        class Ack(object):
+            _write_to = None
+
+            def __init__(self, id, fd, payload):
+                self.id = id
+                self.fd = fd
+                self._payload = payload
+
+            def __eq__(self, other):
+                return self.i == other.i
+
+            def __hash__(self):
+                return self.i
+
         def _write_to(fd, job, callback=None):
+            is_ack = isinstance(job, Ack)
+            if is_ack:
+                print('WRITE ACK!')
+            else:
+                print('WRITE JOB %r' % (job, ))
             header, body, body_size = job._payload
             try:
                 try:
                     proc = fileno_to_inq[fd]
                 except KeyError:
-                    put_message(job)
+                    if not isinstance(job, Ack):
+                        put_message(job)
                     raise StopIteration()
                 send_offset = proc.inq._writer.send_offset
                 # job result keeps track of what process the job is sent to.
@@ -530,9 +552,13 @@ class TaskPool(BasePool):
                 active_writes.discard(fd)
 
         def schedule_writes(ready_fd, events):
+            if ready_fd in active_writes:
+                return
             try:
                 job = pop_message()
+                print('GOT JOB: %r' % (job, ))
             except IndexError:
+                print('GOT INDEX ERROR')
                 for inqfd in diff(active_writes):
                     hub_remove(inqfd)
             else:
@@ -544,8 +570,45 @@ class TaskPool(BasePool):
                     callback.args = (cor, )  # tricky as we need to pass ref
                     hub_add((ready_fd, ), cor, WRITE)
 
+        ACK_BODY = dumps((ACK, (0, )), protocol=protocol)
+        ACK_SIZE = len(ACK_BODY)
+        ACK_HEAD = pack('>I', ACK_SIZE)
+
+        acks = set()
+
+        from types import GeneratorType as generator
+
+        def _send_ack(fd, gen):
+            mark_write_fd_as_active(fd)
+            ex = hub.writers.get(fd)
+            assert not isinstance(ex, generator)
+            hub_add((fd, ), gen, WRITE)
+
+        def send_ack(pid, i):
+            fd = next(p.inqW_fd for p in pool._pool if p.pid == pid)
+            msg = Ack(i, fd, (ACK_HEAD, ACK_BODY, ACK_SIZE))
+            print('PUT ACK')
+            acks.add(msg)
+            #ex = hub.writers.get(fd)
+            #from types import GeneratorType as generator
+            #assert not isinstance(ex, generator)
+            #acks.append(_write_to(fd, msg))
+            #_send_ack(fd, _write_to(fd, msg))
+            #acks[fd].append(_write_to(fd, msg))
+        self._pool.send_ack = send_ack
+
         def on_poll_start(hub):
+            if acks:
+                dirty = set()
+                for ack in acks:
+                    if ack.fd not in active_writes:
+                        active_writes.add(ack.fd)
+                        hub_add(ack.fd, _write_to(ack.fd, ack._payload), WRITE)
+                        dirty.add(ack)
+                for ack in dirty:
+                    acks.discard(ack)
             if outbound:
+                print('SCHEDULE WRITES FOR %r' % (diff(active_writes, )))
                 hub_add(diff(active_writes), schedule_writes, hub.WRITE)
         self.on_poll_start = on_poll_start
 
@@ -553,8 +616,8 @@ class TaskPool(BasePool):
             body = dumps(tup, protocol=protocol)
             body_size = len(body)
             header = pack('>I', body_size)
-            # index 0 is the job ID.
-            job = get_job(tup[0])
+            # index 1,0 is the job ID.
+            job = get_job(tup[1][0])
             job._payload = header, buffer(body), body_size
             put_message(job)
         self._pool._quick_put = quick_put