Browse Source

Changed the Pool again. Looks like we're getting where we want (knocks on wood)

Ask Solem 16 years ago
parent
commit
ccf06f6ff2
2 changed files with 54 additions and 20 deletions
  1. celery/datastructures.py  (+48, -13)
  2. celery/worker.py  (+6, -7)

+ 48 - 13
celery/datastructures.py

@@ -5,6 +5,7 @@ Custom Datastructures
 """
 import multiprocessing
 import itertools
+import threading
 import time
 from UserList import UserList
 
@@ -108,7 +109,7 @@ class TaskWorkerPool(object):
                 "return_value": ret_value})


-class TaskProcessQueue(UserList):
+class TaskProcessQueue(object):
     """Queue of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.

@@ -135,15 +136,47 @@ class TaskProcessQueue(UserList):

    """

-    def __init__(self, limit, process_timeout=None, logger=None,
-            done_msg=None):
+    def __init__(self, limit, process_timeout=None,
+            logger=None, done_msg=None):
        self.limit = limit
        self.logger = logger
        self.done_msg = done_msg
        self.process_timeout = process_timeout
-        self.data = []
+        self._processes = {}
+        self._process_counter = itertools.count(1)
+        self._data_lock = threading.Condition(threading.Lock())
+        self.pool = multiprocessing.Pool(limit)
+
+    def apply_async(self, target, args, kwargs, task_name, task_id):
+        #self._data_lock.acquire()
+        try:
+            _pid = self._process_counter.next()
+
+            on_return = lambda ret_val: self.on_return(_pid, ret_val,
+                                                       task_name, task_id)
+
+            result = self.pool.apply_async(target, args, kwargs,
+                                           callback=on_return)
+            self.add(_pid, result, task_name, task_id)
+        finally:
+            pass
+            #self._data_lock.release()
+
+        return result
+
+    def on_return(self, _pid, ret_val, task_name, task_id):
+        #self._data_lock.acquire()
+        try:
+            del(self._processes[_pid])
+        except KeyError:
+            pass
+        else:
+            self.on_ready(ret_val, task_name, task_id)
+        finally:
+            pass
+            #self._data_lock.acquire()

-    def add(self, result, task_name, task_id):
+    def add(self, _pid, result, task_name, task_id):
        """Add a process to the queue.

        If the queue is full, it will wait for the first task to finish,
@@ -158,32 +191,34 @@ class TaskProcessQueue(UserList):
        :param task_id: Id of the task executed.

        """
+
+        self._processes[_pid] = [result, task_name, task_id]

-        self.data.append([result, task_name, task_id])
-
-        if self.data and len(self.data) >= self.limit:
+        if self.full():
            self.wait_for_result()
-        else:
-            self.reap()
+
+    def full(self):
+        return len(self._processes.values()) >= self.limit


    def wait_for_result(self):
        """Collect results from processes that are ready."""
+        assert self.full()
        while True:
            if self.reap():
                break

    def reap(self):
        processes_reaped = 0
-        for process_no, process_info in enumerate(self.data):
+        for process_no, entry in enumerate(self._processes.items()):
+            _pid, process_info = entry
            result, task_name, task_id = process_info
            try:
                ret_value = result.get(timeout=0.1)
            except multiprocessing.TimeoutError:
                continue
            else:
-                del(self[process_no])
-                self.on_ready(ret_value, task_name, task_id)
+                self.on_return(_pid, ret_value, task_name, task_id)
                processes_reaped += 1
        return processes_reaped


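For readers skimming the diff: the new TaskProcessQueue hands work to a multiprocessing.Pool through apply_async(), remembers each outstanding AsyncResult in a dict keyed by an internal counter, drops the entry in a callback when the task returns, and falls back to a blocking reap loop once the number of in-flight tasks reaches the concurrency limit. Below is a minimal, standalone sketch of that pattern, not the committed code: BoundedPool and square() are invented names, and it uses Python 3 idioms (next(counter)) where the commit uses Python 2 ones (counter.next(), UserList).

# Minimal, standalone sketch of the bookkeeping pattern this commit introduces.
# Not the committed code; names are illustrative.
import itertools
import multiprocessing


class BoundedPool(object):
    """Bounded wrapper around multiprocessing.Pool.apply_async()."""

    def __init__(self, limit):
        self.limit = limit
        self._processes = {}                  # internal id -> AsyncResult
        self._counter = itertools.count(1)
        self.pool = multiprocessing.Pool(limit)

    def apply_async(self, target, args=(), kwargs=None):
        pid = next(self._counter)
        # The callback runs in the pool's result-handler thread and drops the
        # finished entry, much like on_return() in the diff above.
        on_return = lambda ret_val, _pid=pid: self._processes.pop(_pid, None)
        result = self.pool.apply_async(target, args, kwargs or {},
                                       callback=on_return)
        self._processes[pid] = result
        if len(self._processes) >= self.limit:
            self.wait_for_result()            # block until a slot frees up
        return result

    def wait_for_result(self):
        while not self.reap():
            pass

    def reap(self):
        """Collect any results that are already ready; return how many."""
        reaped = 0
        for pid, result in list(self._processes.items()):
            try:
                result.get(timeout=0.1)
            except multiprocessing.TimeoutError:
                continue
            self._processes.pop(pid, None)
            reaped += 1
        return reaped


def square(n):
    return n * n


if __name__ == "__main__":
    pool = BoundedPool(limit=2)
    results = [pool.apply_async(square, (i,)) for i in range(10)]
    print([r.get() for r in results])

Note that the commit also creates a threading.Condition (_data_lock) for this bookkeeping but leaves every acquire/release call commented out; the sketch likewise relies only on the callback-plus-reap combination.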
+ 6 - 7
celery/worker.py

@@ -162,8 +162,10 @@ class TaskWrapper(object):

        """
        task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return pool.apply_async(jail, [self.task_id, self.task_func,
-                                       self.args, task_func_kwargs])
+        jail_args = [self.task_id, self.task_func,
+                     self.args, task_func_kwargs]
+        return pool.apply_async(jail, jail_args, {},
+                                self.task_name, self.task_id)


class TaskDaemon(object):
@@ -231,7 +233,8 @@ class TaskDaemon(object):
        self.queue_wakeup_after = queue_wakeup_after or \
                                    self.queue_wakeup_after
        self.logger = setup_logger(loglevel, logfile)
-        self.pool = multiprocessing.Pool(self.concurrency)
+        self.pool = TaskProcessQueue(self.concurrency, logger=self.logger,
+                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
        self.task_consumer = None
        self.reset_connection()

@@ -329,8 +332,6 @@ class TaskDaemon(object):

    def run(self):
        """Starts the workers main loop."""
-        results = TaskProcessQueue(self.concurrency, logger=self.logger,
-                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
        log_wait = lambda: self.logger.info("Waiting for queue...")
        ev_msg_waiting = EventTimer(log_wait, self.empty_msg_emit_every)
        events = [
@@ -357,5 +358,3 @@
                self.logger.critical("Message queue raised %s: %s\n%s" % (
                             e.__class__, e, traceback.format_exc()))
                continue
-
-            results.add(result, task_name, task_id)
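On the worker side the net effect is that the TaskWrapper method in the first worker.py hunk now makes a single call into the queue-style pool, passing task_name and task_id along so completions can be reported via done_msg, and run() no longer builds its own TaskProcessQueue or calls results.add() after every message. A rough usage sketch follows, assuming the hypothetical BoundedPool class from the earlier sketch is in scope; fake_jail() is a made-up stand-in for jail().

# Usage sketch only, reusing the hypothetical BoundedPool defined above; the
# real call also threads task_name/task_id through for logging.
import time


def fake_jail(task_id, n):
    time.sleep(0.05)        # pretend to run the task body
    return task_id, n * n


if __name__ == "__main__":
    pool = BoundedPool(limit=4)              # created once at start-up, not per run()
    for i in range(20):
        # A single call replaces the old pool.apply_async() + results.add() pair.
        pool.apply_async(fake_jail, (i, i))
    pool.pool.close()                        # inner multiprocessing.Pool
    pool.pool.join()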