
Not using multiprocessing.Pool anymore, but our own algorithm.

Ask Solem committed 16 years ago
commit a404426113

3 changed files with 72 additions and 11 deletions
  1. celery/__init__.py  +1 -1
  2. celery/datastructures.py  +62 -0
  3. celery/worker.py  +9 -10
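
At a glance: the daemon used to hand tasks to a multiprocessing.Pool and
collect results through it; now it starts one multiprocessing.Process per
task and reaps finished processes itself. A rough before/after sketch of
the two call styles (run_task and the ids are placeholders; the real call
site uses jail() in celery/worker.py below):

import multiprocessing

from celery.datastructures import TaskWorkerPool

def run_task(x):
    return x * 2

# Before: Pool owns the worker processes and the result handling.
old_pool = multiprocessing.Pool(processes=2)
result = old_pool.apply_async(run_task, [21])
print result.get()  # 42

# After: TaskWorkerPool starts a plain Process per task; the pool only
# logs completion (see on_finished()), it does not return results.
new_pool = TaskWorkerPool(2)
new_pool.add(run_task, [21], "demo.task", "id-1")
new_pool.wait_for_result()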

celery/__init__.py  +1 -1

@@ -1,5 +1,5 @@
 """Distributed Task Queue for Django"""
-VERSION = (0, 2, 0)
+VERSION = (0, 2, 1)
 __version__ = ".".join(map(str, VERSION))
 __author__ = "Ask Solem"
 __contact__ = "askh@opera.com"

celery/datastructures.py  +62 -0

@@ -4,6 +4,8 @@ Custom Datastructures
 
 """
 import multiprocessing
+import itertools
+import time
 from UserList import UserList
 
 
@@ -46,6 +48,66 @@ class PositionQueue(UserList):
                       self.data)
 
 
+class TaskWorkerPool(object):
+
+    Process = multiprocessing.Process
+
+    class TaskWorker(object):
+        def __init__(self, process, task_name, task_id):
+            self.process = process
+            self.task_name = task_name
+            self.task_id = task_id
+
+    def __init__(self, limit, logger=None, done_msg=None):
+        self.limit = limit
+        self.logger = logger
+        self.done_msg = done_msg
+        self._pool = []
+        self.task_counter = itertools.count(1)
+        self.total_tasks_run = 0
+
+    def add(self, target, args, task_name=None, task_id=None, kwargs=None):
+        self.total_tasks_run = self.task_counter.next()
+        if self._pool and len(self._pool) >= self.limit:
+            self.wait_for_result()
+        else:
+            self.reap()
+
+        current_worker_no = len(self._pool) + 1
+        process_name = "TaskWorker-%d" % current_worker_no
+        process = self.Process(target=target, args=args,
+                               kwargs=kwargs or {}, name=process_name)
+        process.start()
+        task = self.TaskWorker(process, task_name, task_id)
+        self._pool.append(task)
+
+    def wait_for_result(self):
+        """Collect results from processes that are ready."""
+        while True:
+            if self.reap():
+                break
+            time.sleep(0.1)
+
+    def reap(self):
+        processes_reaped = 0
+        for worker in self._pool[:]:  # copy: the pool shrinks as we reap
+            process = worker.process
+            if not process.is_alive():
+                ret_value = process.join()  # join() always returns None
+                self.on_finished(ret_value, worker.task_name,
+                        worker.task_id)
+                self._pool.remove(worker)
+                processes_reaped += 1
+        return processes_reaped
+
+    def on_finished(self, ret_value, task_name, task_id):
+        if self.done_msg and self.logger:
+            self.logger.info(self.done_msg % {
+                "name": task_name,
+                "id": task_id,
+                "return_value": ret_value})
+
+
 class TaskProcessQueue(UserList):
     """Queue of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.

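A minimal sketch of driving the new pool, assuming this commit is checked
out on a POSIX system; the task function, logger name, and task ids are
all hypothetical:

import logging
import time

from celery.datastructures import TaskWorkerPool

def sleep_task(seconds):
    time.sleep(seconds)

logging.basicConfig(level=logging.INFO)
pool = TaskWorkerPool(2, logger=logging.getLogger("demo"),
        done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")

# The first two add() calls start processes immediately; the third finds
# the pool full and blocks in wait_for_result() until one has finished.
for i in range(3):
    pool.add(sleep_task, [1], "demo.sleep", "id-%d" % i)

# Collect the stragglers (peeking at the internal list for the demo);
# the daemon instead calls reap() from a timer in its main loop.
while pool._pool:
    pool.reap()
    time.sleep(0.1)
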
celery/worker.py  +9 -10

@@ -5,7 +5,7 @@ from celery.conf import DAEMON_CONCURRENCY, DAEMON_LOG_FILE
 from celery.conf import QUEUE_WAKEUP_AFTER, EMPTY_MSG_EMIT_EVERY
 from celery.log import setup_logger
 from celery.registry import tasks
-from celery.datastructures import TaskProcessQueue
+from celery.datastructures import TaskWorkerPool
 from celery.models import PeriodicTaskMeta
 from celery.backends import default_backend
 from celery.timer import EventTimer
@@ -162,8 +162,9 @@ class TaskWrapper(object):
 
         """
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
-        return pool.apply_async(jail, [self.task_id, self.task_func,
-                                       self.args, task_func_kwargs])
+        jail_args = [self.task_id, self.task_func, self.args,
+                     task_func_kwargs]
+        return pool.add(jail, jail_args, self.task_name, self.task_id)
 
 
 class TaskDaemon(object):
@@ -231,7 +232,8 @@ class TaskDaemon(object):
         self.queue_wakeup_after = queue_wakeup_after or \
                                     self.queue_wakeup_after
         self.logger = setup_logger(loglevel, logfile)
-        self.pool = multiprocessing.Pool(self.concurrency)
+        self.pool = TaskWorkerPool(self.concurrency, logger=self.logger,
+                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         self.task_consumer = None
         self.reset_connection()
 
@@ -329,13 +331,12 @@ class TaskDaemon(object):
 
     def run(self):
         """Starts the workers main loop."""
-        results = TaskProcessQueue(self.concurrency, logger=self.logger,
-                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         log_wait = lambda: self.logger.info("Waiting for queue...")
         ev_msg_waiting = EventTimer(log_wait, self.empty_msg_emit_every)
         events = [
-            EventTimer(self.run_periodic_tasks, 1),
-            EventTimer(self.schedule_retry_tasks, 2),
+            EventTimer(self.run_periodic_tasks, 2),
+            EventTimer(self.schedule_retry_tasks, 4),
+            EventTimer(self.pool.reap, 2),
         ]
 
         while True:
@@ -357,5 +358,3 @@ class TaskDaemon(object):
                 self.logger.critical("Message queue raised %s: %s\n%s" % (
                              e.__class__, e, traceback.format_exc()))
                 continue
-
-            results.add(result, task_name, task_id)
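
The run() change above replaces the old blocking results.add() call with
periodic reaping. A sketch of that pattern, with IntervalTimer as a
hypothetical stand-in for celery.timer.EventTimer (whose interface is not
shown in this diff):

import time

from celery.datastructures import TaskWorkerPool

class IntervalTimer(object):
    """Call fn at most once every interval seconds when ticked."""

    def __init__(self, fn, interval):
        self.fn = fn
        self.interval = interval
        self.last_run = None

    def tick(self):
        now = time.time()
        if self.last_run is None or now - self.last_run >= self.interval:
            self.last_run = now
            self.fn()

pool = TaskWorkerPool(2)
events = [IntervalTimer(pool.reap, 2)]

# Where the old loop blocked collecting results after every message, the
# new loop just ticks its timers on each pass; finished worker processes
# are reaped as a side effect, at most once every two seconds.
while True:
    for event in events:
        event.tick()
    time.sleep(1)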