Browse Source

Restart pool if a process is dead

Ask Solem 16 years ago
parent
commit
cc199089ac
5 changed files with 50 additions and 85 deletions
  1. 1 1
      README.rst
  2. 1 1
      celery/__init__.py
  3. 3 0
      celery/conf.py
  4. 42 81
      celery/datastructures.py
  5. 3 2
      celery/timer.py

+ 1 - 1
README.rst

@@ -2,7 +2,7 @@
 celery - Distributed Task Queue for Django.
 ============================================
 
-:Version: 0.2.3
+:Version: 0.2.4
 
 Introduction
 ============

+ 1 - 1
celery/__init__.py

@@ -1,5 +1,5 @@
 """Distributed Task Queue for Django"""
-VERSION = (0, 2, 3)
+VERSION = (0, 2, 4)
 __version__ = ".".join(map(str, VERSION))
 __author__ = "Ask Solem"
 __contact__ = "askh@opera.com"

+ 3 - 0
celery/conf.py

@@ -12,6 +12,7 @@ DEFAULT_DAEMON_PID_FILE = "celeryd.pid"
 DEFAULT_LOG_FMT = '[%(asctime)s: %(levelname)s/%(processName)s] %(message)s'
 DEFAULT_DAEMON_LOG_LEVEL = "INFO"
 DEFAULT_DAEMON_LOG_FILE = "celeryd.log"
+DEFAULT_REAP_TIMEOUT = 30
 
 """
 .. data:: LOG_LEVELS
@@ -125,3 +126,5 @@ AMQP_ROUTING_KEY = getattr(settings, "CELERY_AMQP_ROUTING_KEY",
 """
 AMQP_CONSUMER_QUEUE = getattr(settings, "CELERY_AMQP_CONSUMER_QUEUE",
                               DEFAULT_AMQP_CONSUMER_QUEUE)
+
+REAP_TIMEOUT = DEFAULT_REAP_TIMEOUT

+ 42 - 81
celery/datastructures.py

@@ -7,7 +7,10 @@ import multiprocessing
 import itertools
 import threading
 import time
+import os
 from UserList import UserList
+from celery.timer import TimeoutTimer, TimeoutError
+from celery.conf import REAP_TIMEOUT
 
 
 class PositionQueue(UserList):
@@ -49,66 +52,6 @@ class PositionQueue(UserList):
                       self.data)
 
 
-class TaskWorkerPool(object):
-
-    Process = multiprocessing.Process
-
-    class TaskWorker(object):
-        def __init__(self, process, task_name, task_id):
-            self.process = process
-            self.task_name = task_name
-            self.task_id = task_id
-
-    def __init__(self, limit, logger=None, done_msg=None):
-        self.limit = limit
-        self.logger = logger
-        self.done_msg = done_msg
-        self._pool = []
-        self.task_counter = itertools.count(1)
-        self.total_tasks_run = 0
-
-    def add(self, target, args, kwargs=None, task_name=None, task_id=None):
-        self.total_tasks_run = self.task_counter.next()
-        if self._pool and len(self._pool) >= self.limit:
-            self.wait_for_result()
-        else:
-            self.reap()
-
-        current_worker_no = len(self._pool) + 1
-        process_name = "TaskWorker-%d" % current_worker_no
-        process = self.Process(target=target, args=args, kwargs=kwargs,
-                               name=process_name)
-        process.start()
-        task = self.TaskWorker(process, task_name, task_id)
-        self._pool.append(task)
-
-    def wait_for_result(self):
-        """Collect results from processes that are ready."""
-        while True:
-            if self.reap():
-                break
-            time.sleep(0.1)
-
-    def reap(self):
-        processed_reaped = 0
-        for worker_no, worker in enumerate(self._pool):
-            process = worker.process
-            if not process.is_alive():
-                ret_value = process.join()
-                self.on_finished(ret_value, worker.task_name,
-                        worker.task_id)
-                del(self._pool[worker_no])
-                processed_reaped += 1
-        return processed_reaped
-
-    def on_finished(self, ret_value, task_name, task_id):
-        if self.done_msg and self.logger:
-            self.logger.info(self.done_msg % {
-                "name": task_name,
-                "id": task_id,
-                "return_value": ret_value})
-
-
 class TaskProcessQueue(object):
     """Queue of running child processes, which starts waiting for the
     processes to finish when the queue limit has been reached.
@@ -136,45 +79,46 @@ class TaskProcessQueue(object):
 
     """
 
-    def __init__(self, limit, process_timeout=None,
+    def __init__(self, limit, reap_timeout=None,
             logger=None, done_msg=None):
         self.limit = limit
-        self.logger = logger
+        self.logger = logger or multiprocessing.get_logger()
         self.done_msg = done_msg
-        self.process_timeout = process_timeout
+        self.reap_timeout = reap_timeout
         self._processes = {}
         self._process_counter = itertools.count(1)
         self._data_lock = threading.Condition(threading.Lock())
-        self.pool = multiprocessing.Pool(limit)
+        self._start()
+
+    def _start(self):
+        assert int(self.limit)
+        self._pool = multiprocessing.Pool(processes=self.limit)
+
+    def _restart(self):
+        self.logger.info("Closing and restarting the pool...")
+        self._pool.close()
+        self._pool.join()
+        self._start()
 
     def apply_async(self, target, args, kwargs, task_name, task_id):
-        #self._data_lock.acquire()
-        try:
-            _pid = self._process_counter.next()
+        _pid = self._process_counter.next()
 
-            on_return = lambda ret_val: self.on_return(_pid, ret_val,
-                                                       task_name, task_id)
+        on_return = lambda ret_val: self.on_return(_pid, ret_val,
+                                                   task_name, task_id)
 
-            result = self.pool.apply_async(target, args, kwargs,
+        result = self._pool.apply_async(target, args, kwargs,
                                            callback=on_return)
-            self.add(_pid, result, task_name, task_id)
-        finally:
-            pass
-            #self._data_lock.release()
+        self.add(_pid, result, task_name, task_id)
 
         return result
 
     def on_return(self, _pid, ret_val, task_name, task_id):
-        #self._data_lock.acquire()
         try:
             del(self._processes[_pid])
         except KeyError:
             pass
         else:
             self.on_ready(ret_val, task_name, task_id)
-        finally:
-            pass
-            #self._data_lock.acquire()
 
     def add(self, _pid, result, task_name, task_id):
         """Add a process to the queue.
@@ -197,16 +141,33 @@ class TaskProcessQueue(object):
         if self.full():
             self.wait_for_result()
 
+    def _is_alive(self, pid):
+        try:
+            is_alive = os.waitpid(pid, os.WNOHANG) == (0, 0)
+        except OSError, e:
+            if e.errno != errno.ECHILD:
+                raise
+        return is_alive
+
+    def reap_zombies(self):
+        assert hasattr(self._pool, "_pool")
+        self.logger.debug("Trying to find zombies...")
+        for process in self._pool._pool:
+            pid = process.pid
+            if not self._is_alive(pid):
+                self.logger.error(
+                        "Process with pid %d is dead? Restarting pool" % pid)
+                self._restart()
+
     def full(self):
         return len(self._processes.values()) >= self.limit
 
-
     def wait_for_result(self):
         """Collect results from processes that are ready."""
-        assert self.full()
         while True:
             if self.reap():
                 break
+            self.reap_zombies()
 
     def reap(self):
         processes_reaped = 0
@@ -224,7 +185,7 @@ class TaskProcessQueue(object):
 
 
     def on_ready(self, ret_value, task_name, task_id):
-        if self.done_msg and self.logger:
+        if self.done_msg:
             self.logger.info(self.done_msg % {
                 "name": task_name,
                 "id": task_id,

+ 3 - 2
celery/timer.py

@@ -61,8 +61,9 @@ class TimeoutTimer(object):
 
     """
 
-    def __init__(self, timeout):
+    def __init__(self, timeout, timeout_msg="The operation timed out"):
         self.timeout = timeout
+        self.timeout_msg = timeout_msg
         self.time_start = time.time()
 
     def tick(self):
@@ -75,4 +76,4 @@ class TimeoutTimer(object):
         if not self.timeout:
             return
         if time.time() > self.time_start + self.timeout:
-            raise TimeoutError("The operation timed out.")
+            raise TimeoutError(self.timeout_msg)