Restart pool if a process is dead

Ask Solem committed 16 years ago
parent commit cc199089ac

5 changed files with 53 additions and 85 deletions
  1. README.rst (+1 -1)
  2. celery/__init__.py (+1 -1)
  3. celery/conf.py (+3 -0)
  4. celery/datastructures.py (+45 -81)
  5. celery/timer.py (+3 -2)

+ 1 - 1
README.rst

@@ -2,7 +2,7 @@
 celery - Distributed Task Queue for Django.
 ============================================
 
-:Version: 0.2.3
+:Version: 0.2.4
 
 Introduction
 ============

+ 1 - 1
celery/__init__.py

@@ -1,5 +1,5 @@
 """Distributed Task Queue for Django"""
-VERSION = (0, 2, 3)
+VERSION = (0, 2, 4)
 __version__ = ".".join(map(str, VERSION))
 __author__ = "Ask Solem"
 __contact__ = "askh@opera.com"

+ 3 - 0
celery/conf.py

@@ -12,6 +12,7 @@ DEFAULT_DAEMON_PID_FILE = "celeryd.pid"
 DEFAULT_LOG_FMT = '[%(asctime)s: %(levelname)s/%(processName)s] %(message)s'
 DEFAULT_DAEMON_LOG_LEVEL = "INFO"
 DEFAULT_DAEMON_LOG_FILE = "celeryd.log"
+DEFAULT_REAP_TIMEOUT = 30
 
 """
 .. data:: LOG_LEVELS
@@ -125,3 +126,5 @@ AMQP_ROUTING_KEY = getattr(settings, "CELERY_AMQP_ROUTING_KEY",
 """
 AMQP_CONSUMER_QUEUE = getattr(settings, "CELERY_AMQP_CONSUMER_QUEUE",
                               DEFAULT_AMQP_CONSUMER_QUEUE)
+
+REAP_TIMEOUT = DEFAULT_REAP_TIMEOUT
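
Note that, unlike the other settings in this file, REAP_TIMEOUT is bound directly to its default rather than being looked up in the Django settings. If it were meant to be user-configurable, it could follow the same getattr(settings, ...) pattern used just above; a minimal sketch, assuming a hypothetical CELERY_REAP_TIMEOUT setting name:

    # Sketch only: CELERY_REAP_TIMEOUT is a hypothetical setting name,
    # following the convention of the surrounding CELERY_* lookups.
    REAP_TIMEOUT = getattr(settings, "CELERY_REAP_TIMEOUT",
                           DEFAULT_REAP_TIMEOUT)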

+ 45 - 81
celery/datastructures.py

@@ -7,7 +7,11 @@ import multiprocessing
 import itertools
 import threading
 import time
+import os
+import errno
 from UserList import UserList
+from celery.timer import TimeoutTimer, TimeoutError
+from celery.conf import REAP_TIMEOUT
 
 class PositionQueue(UserList):
@@ -49,66 +53,6 @@
                       self.data)
 
 
-class TaskWorkerPool(object):
-
-    Process = multiprocessing.Process
-
-    class TaskWorker(object):
-        def __init__(self, process, task_name, task_id):
-            self.process = process
-            self.task_name = task_name
-            self.task_id = task_id
-
-    def __init__(self, limit, logger=None, done_msg=None):
-        self.limit = limit
-        self.logger = logger
-        self.done_msg = done_msg
-        self._pool = []
-        self.task_counter = itertools.count(1)
-        self.total_tasks_run = 0
-
-    def add(self, target, args, kwargs=None, task_name=None, task_id=None):
-        self.total_tasks_run = self.task_counter.next()
-        if self._pool and len(self._pool) >= self.limit:
-            self.wait_for_result()
-        else:
-            self.reap()
-
-        current_worker_no = len(self._pool) + 1
-        process_name = "TaskWorker-%d" % current_worker_no
-        process = self.Process(target=target, args=args, kwargs=kwargs,
-                               name=process_name)
-        process.start()
-        task = self.TaskWorker(process, task_name, task_id)
-        self._pool.append(task)
-
-    def wait_for_result(self):
-        """Collect results from processes that are ready."""
-        while True:
-            if self.reap():
-                break
-            time.sleep(0.1)
-            
-    def reap(self):
-        processed_reaped = 0
-        for worker_no, worker in enumerate(self._pool):
-            process = worker.process
-            if not process.is_alive():
-                ret_value = process.join()
-                self.on_finished(ret_value, worker.task_name,
-                        worker.task_id)
-                del(self._pool[worker_no])
-                processed_reaped += 1
-        return processed_reaped
-
-    def on_finished(self, ret_value, task_name, task_id):
-        if self.done_msg and self.logger:
-            self.logger.info(self.done_msg % {
-                "name": task_name,
-                "id": task_id,
-                "return_value": ret_value})
-
-
 class TaskProcessQueue(object):
     """Queue of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.
@@ -136,45 +80,46 @@
 
     """
 
-    def __init__(self, limit, process_timeout=None,
+    def __init__(self, limit, reap_timeout=None,
             logger=None, done_msg=None):
         self.limit = limit
-        self.logger = logger
+        self.logger = logger or multiprocessing.get_logger()
         self.done_msg = done_msg
-        self.process_timeout = process_timeout
+        self.reap_timeout = reap_timeout
         self._processes = {}
         self._process_counter = itertools.count(1)
         self._data_lock = threading.Condition(threading.Lock())
-        self.pool = multiprocessing.Pool(limit)
+        self._start()
+
+    def _start(self):
+        assert int(self.limit)
+        self._pool = multiprocessing.Pool(processes=self.limit)
+
+    def _restart(self):
+        self.logger.info("Closing and restarting the pool...")
+        self._pool.close()
+        self._pool.join()
+        self._start()
 
     def apply_async(self, target, args, kwargs, task_name, task_id):
-        #self._data_lock.acquire()
-        try:
-            _pid = self._process_counter.next()
+        _pid = self._process_counter.next()
 
-            on_return = lambda ret_val: self.on_return(_pid, ret_val,
-                                                       task_name, task_id)
+        on_return = lambda ret_val: self.on_return(_pid, ret_val,
+                                                   task_name, task_id)
 
-            result = self.pool.apply_async(target, args, kwargs,
+        result = self._pool.apply_async(target, args, kwargs,
                                            callback=on_return)
-            self.add(_pid, result, task_name, task_id)
-        finally:
-            pass
-            #self._data_lock.release()
+        self.add(_pid, result, task_name, task_id)
 
         return result
 
     def on_return(self, _pid, ret_val, task_name, task_id):
-        #self._data_lock.acquire()
         try:
             del(self._processes[_pid])
         except KeyError:
             pass
         else:
             self.on_ready(ret_val, task_name, task_id)
-        finally:
-            pass
-            #self._data_lock.acquire()
 
     def add(self, _pid, result, task_name, task_id):
         """Add a process to the queue.
@@ -197,16 +142,35 @@
         if self.full():
             self.wait_for_result()
 
+    def _is_alive(self, pid):
+        try:
+            # WNOHANG polls without blocking: (0, 0) means still running.
+            return os.waitpid(pid, os.WNOHANG) == (0, 0)
+        except OSError, e:
+            if e.errno != errno.ECHILD:
+                raise
+            return False  # No such child: already reaped, so not alive.
+
+    def reap_zombies(self):
+        assert hasattr(self._pool, "_pool")
+        self.logger.debug("Trying to find zombies...")
+        for process in self._pool._pool:
+            pid = process.pid
+            if not self._is_alive(pid):
+                self.logger.error(
+                        "Process with pid %d is dead? Restarting pool" % pid)
+                self._restart()
+
     def full(self):
         return len(self._processes.values()) >= self.limit
 
-
     def wait_for_result(self):
         """Collect results from processes that are ready."""
-        assert self.full()
         while True:
             if self.reap():
                 break
+            self.reap_zombies()
+            time.sleep(0.1)  # Avoid a busy-wait between polls.
 
     def reap(self):
         processes_reaped = 0
@@ -224,7 +188,7 @@
 
 
     def on_ready(self, ret_value, task_name, task_id):
-        if self.done_msg and self.logger:
+        if self.done_msg:
             self.logger.info(self.done_msg % {
                 "name": task_name,
                 "id": task_id,

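The core of this change is the liveness check: os.waitpid() with the os.WNOHANG flag returns (0, 0) while a child is still running, returns the (pid, status) pair once it has exited, and raises OSError with errno ECHILD once the child has already been reaped (or never existed). Since waitpid() only works on direct children, the check is valid here only because the pool's worker processes are children of the daemon. A minimal standalone sketch of the same technique (the spawn-and-kill demo is illustrative, not part of the commit):

    import os
    import errno
    import time
    import multiprocessing


    def is_alive(pid):
        """Return True if the child process ``pid`` is still running."""
        try:
            # WNOHANG makes this a non-blocking poll: (0, 0) means
            # "still running", anything else means it has exited.
            return os.waitpid(pid, os.WNOHANG) == (0, 0)
        except OSError, e:
            if e.errno != errno.ECHILD:
                raise
            return False  # No such child: it was already reaped.


    if __name__ == "__main__":
        child = multiprocessing.Process(target=time.sleep, args=(60, ))
        child.start()
        print "alive?", is_alive(child.pid)   # True
        os.kill(child.pid, 9)                 # SIGKILL the child.
        time.sleep(0.5)                       # Let the kernel update state.
        print "alive?", is_alive(child.pid)   # False

When a dead worker is found, _restart() tears the whole pool down before building a new one: close() stops new work from being submitted and join() waits for the surviving workers to exit. Nothing resubmits the task that was running in the dead process; it is simply dropped.
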
+ 3 - 2
celery/timer.py

@@ -61,8 +61,9 @@ class TimeoutTimer(object):
 
     """
 
-    def __init__(self, timeout):
+    def __init__(self, timeout, timeout_msg="The operation timed out"):
         self.timeout = timeout
+        self.timeout_msg = timeout_msg
         self.time_start = time.time()
 
     def tick(self):
@@ -75,4 +76,4 @@ class TimeoutTimer(object):
         if not self.timeout:
             return
         if time.time() > self.time_start + self.timeout:
-            raise TimeoutError("The operation timed out.")
+            raise TimeoutError(self.timeout_msg)
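
The timer change just threads a custom message through to the TimeoutError so callers can say what timed out. It also explains the otherwise-unused TimeoutTimer and REAP_TIMEOUT imports added to datastructures.py: they look like groundwork for bounding wait_for_result(), though this commit does not wire them up yet. A small usage sketch (the wait loop is hypothetical; only the constructor and tick() come from this commit):

    import time
    from celery.timer import TimeoutTimer, TimeoutError
    from celery.conf import REAP_TIMEOUT

    # Hypothetical wait loop: poll until some condition holds, but give
    # up after REAP_TIMEOUT seconds with a descriptive error message.
    timer = TimeoutTimer(timeout=REAP_TIMEOUT,
                         timeout_msg="Gave up waiting for pool results")
    while True:
        timer.tick()      # Raises TimeoutError(timeout_msg) once the
        time.sleep(0.1)   # deadline set at construction has passed.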