瀏覽代碼

Changed the Pool again. Looks like we're getting where we want (knocks on wood)

Ask Solem 16 年之前
父節點
當前提交
ccf06f6ff2
共有 2 個文件被更改,包括 54 次插入和 20 次刪除
  1. +48 −13
      celery/datastructures.py
  2. +6 −7
      celery/worker.py

+ 48 - 13
celery/datastructures.py

@@ -5,6 +5,7 @@ Custom Datastructures
 """
 import multiprocessing
 import itertools
+import threading
 import time
 from UserList import UserList
 
@@ -108,7 +109,7 @@ class TaskWorkerPool(object):
                 "return_value": ret_value})
 
 
-class TaskProcessQueue(UserList):
+class TaskProcessQueue(object):
     """Queue of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.
 
@@ -135,15 +136,47 @@ class TaskProcessQueue(UserList):
 
     """
 
-    def __init__(self, limit, process_timeout=None, logger=None,
-            done_msg=None):
+    def __init__(self, limit, process_timeout=None,
+            logger=None, done_msg=None):
         self.limit = limit
         self.logger = logger
         self.done_msg = done_msg
         self.process_timeout = process_timeout
-        self.data = []
+        self._processes = {}
+        self._process_counter = itertools.count(1)
+        self._data_lock = threading.Condition(threading.Lock())
+        self.pool = multiprocessing.Pool(limit)
+
+    def apply_async(self, target, args, kwargs, task_name, task_id):
+        #self._data_lock.acquire()
+        try:
+            _pid = self._process_counter.next()
+
+            on_return = lambda ret_val: self.on_return(_pid, ret_val,
+                                                       task_name, task_id)
+
+            result = self.pool.apply_async(target, args, kwargs,
+                                           callback=on_return)
+            self.add(_pid, result, task_name, task_id)
+        finally:
+            pass
+            #self._data_lock.release()
+
+        return result
+
+    def on_return(self, _pid, ret_val, task_name, task_id):
+        #self._data_lock.acquire()
+        try:
+            del(self._processes[_pid])
+        except KeyError:
+            pass
+        else:
+            self.on_ready(ret_val, task_name, task_id)
+        finally:
+            pass
+            #self._data_lock.acquire()
 
-    def add(self, result, task_name, task_id):
+    def add(self, _pid, result, task_name, task_id):
         """Add a process to the queue.
 
         If the queue is full, it will wait for the first task to finish,
@@ -158,32 +191,34 @@ class TaskProcessQueue(UserList):
         :param task_id: Id of the task executed.
 
         """
+      
+        self._processes[_pid] = [result, task_name, task_id]
 
-        self.data.append([result, task_name, task_id])
-
-        if self.data and len(self.data) >= self.limit:
+        if self.full():
             self.wait_for_result()
-        else:
-            self.reap()
+
+    def full(self):
+        return len(self._processes.values()) >= self.limit
 
 
     def wait_for_result(self):
         """Collect results from processes that are ready."""
+        assert self.full()
         while True:
             if self.reap():
                 break
 
     def reap(self):
         processes_reaped = 0
-        for process_no, process_info in enumerate(self.data):
+        for process_no, entry in enumerate(self._processes.items()):
+            _pid, process_info = entry
             result, task_name, task_id = process_info
             try:
                 ret_value = result.get(timeout=0.1)
             except multiprocessing.TimeoutError:
                 continue
             else:
-                del(self[process_no])
-                self.on_ready(ret_value, task_name, task_id)
+                self.on_return(_pid, ret_value, task_name, task_id)
                 processes_reaped += 1
         return processes_reaped
 

+ 6 - 7
celery/worker.py

@@ -162,8 +162,10 @@ class TaskWrapper(object):
 
         """
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return pool.apply_async(jail, [self.task_id, self.task_func,
-                                       self.args, task_func_kwargs])
+        jail_args = [self.task_id, self.task_func,
+                     self.args, task_func_kwargs]
+        return pool.apply_async(jail, jail_args, {},
+                                self.task_name, self.task_id)
 
 
 class TaskDaemon(object):
@@ -231,7 +233,8 @@ class TaskDaemon(object):
         self.queue_wakeup_after = queue_wakeup_after or \
                                     self.queue_wakeup_after
         self.logger = setup_logger(loglevel, logfile)
-        self.pool = multiprocessing.Pool(self.concurrency)
+        self.pool = TaskProcessQueue(self.concurrency, logger=self.logger,
+                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         self.task_consumer = None
         self.reset_connection()
 
@@ -329,8 +332,6 @@ class TaskDaemon(object):
 
     def run(self):
         """Starts the workers main loop."""
-        results = TaskProcessQueue(self.concurrency, logger=self.logger,
-                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         log_wait = lambda: self.logger.info("Waiting for queue...")
         ev_msg_waiting = EventTimer(log_wait, self.empty_msg_emit_every)
         events = [
@@ -357,5 +358,3 @@ class TaskDaemon(object):
                 self.logger.critical("Message queue raised %s: %s\n%s" % (
                              e.__class__, e, traceback.format_exc()))
                 continue
-
-            results.add(result, task_name, task_id)