Refactor the TaskProcessQueue code (now renamed to celery.pool.TaskPool)

Ask Solem committed 15 years ago
commit 1da3f8cad7
2 changed files with 207 additions and 9 deletions
  1. celery/pool.py (+175 -0)
  2. celery/worker.py (+32 -9)

+ 175 - 0
celery/pool.py

@@ -0,0 +1,175 @@
+import multiprocessing
+import itertools
+import threading
+import uuid
+import time
+import os
+
+from multiprocessing.pool import RUN as POOL_STATE_RUN
+from celery.timer import TimeoutTimer, TimeoutError
+from celery.conf import REAP_TIMEOUT
+
+
+class TaskPool(object):
+    """Queue of running child processes, which starts waiting for the
+    processes to finish when the queue limit has been reached.
+
+    :param limit: see :attr:`limit` attribute.
+
+    :param logger: see :attr:`logger` attribute.
+
+    :param reap_timeout: see :attr:`reap_timeout` attribute.
+
+
+    .. attribute:: limit
+
+        The number of processes that can run simultaneously before
+        we start collecting results.
+
+    .. attribute:: logger
+
+        The logger used by the pool.
+
+    .. attribute:: reap_timeout
+
+        Optional timeout for reaping process results. It is stored on the
+        pool but not yet used by this version.
+
+    """
+
+    def __init__(self, limit, logger=None, reap_timeout=None):
+        self.limit = limit
+        self.logger = logger or multiprocessing.get_logger()
+        self.reap_timeout = reap_timeout
+        self._process_counter = itertools.count(1)
+        self._processed_total = 0
+
+    def run(self):
+        self._start()
+
+    def _start(self):
+        self._processes = {}
+        self._pool = multiprocessing.Pool(processes=self.limit)
+
+    def _terminate_and_restart(self):
+        try:
+            self._pool.terminate()
+        except OSError:
+            pass
+        self._start()
+
+    def _restart(self):
+        self.logger.info("Closing and restarting the pool...")
+        self._pool.close()
+        timeout_thread = threading.Timer(30.0, self._terminate_and_restart)
+        timeout_thread.start()
+        self._pool.join()
+        timeout_thread.cancel()
+        self._start()
+
+    def _pool_is_running(self):
+        return self._pool._state == POOL_STATE_RUN
+
+    def apply_async(self, target, args=None, kwargs=None, callbacks=None,
+            errbacks=None, meta=None):
+        args = args or []
+        kwargs = kwargs or {}
+        callbacks = callbacks or []
+        errbacks = errbacks or []
+        meta = meta or {}
+        id = str(uuid.uuid4())
+
+        if not self._pool_is_running():
+            self._start()
+
+        self._processed_total = self._process_counter.next()
+
+        on_return = lambda r: self.on_return(r, id, callbacks, errbacks, meta)
+
+        result = self._pool.apply_async(target, args, kwargs,
+                                        callback=on_return)
+
+        self.add(result, callbacks, errbacks, id, meta)
+
+        return result
+
+    def on_return(self, ret_val, id, callbacks, errbacks, meta):
+        try:
+            del(self._processes[id])
+        except KeyError:
+            pass
+        else:
+            self.on_ready(ret_val, callbacks, errbacks, meta)
+
+    def add(self, result, callbacks, errbacks, id, meta):
+        """Add a process to the queue.
+
+        If the queue is full, it waits for the first task to finish,
+        collects its result and removes it from the queue, so the queue
+        is ready to accept new processes.
+
+        :param result: A :class:`multiprocessing.AsyncResult` instance, as
+            returned by :meth:`multiprocessing.Pool.apply_async`.
+
+        :option callbacks: List of callbacks to execute if the task was
+            successful. Must have the function signature:
+                ``mycallback(result, meta)``
+
+        :option errbacks: List of errbacks to execute if the task raised
+            an exception. Must have the function signature:
+                ``myerrback(exc, meta)``.
+
+        :option id: Explicitly set the id for this task.
+        """
+
+        self._processes[id] = [result, callbacks, errbacks, meta]
+
+        if self.full():
+            self.wait_for_result()
+
+    def full(self):
+        return len(self._processes) >= self.limit
+
+    def wait_for_result(self):
+        """Waits for the first process in the pool to finish.
+
+        This operation is blocking.
+
+        """
+        while True:
+            if self.reap():
+                break
+
+    def reap(self):
+        self.logger.debug("Reaping processes...")
+        processes_reaped = 0
+        for process_no, entry in enumerate(self._processes.items()):
+            id, process_info = entry
+            result, callbacks, errbacks, meta = process_info
+            try:
+                ret_value = result.get(timeout=0.3)
+            except multiprocessing.TimeoutError:
+                continue
+            else:
+                self.on_return(ret_value, id, callbacks, errbacks, meta)
+                processes_reaped += 1
+        return processes_reaped
+
+    def get_worker_pids(self):
+        """Returns the process id's of all the pool workers.
+
+        :rtype: list
+
+        """
+        return [process.pid for process in self._pool._pool]
+
+    def on_ready(self, ret_value, callbacks, errbacks, meta):
+        """What to do when a worker task is ready and its return value has
+        been collected."""
+
+        if isinstance(ret_value, Exception):
+            for errback in errbacks:
+                errback(ret_value, meta)
+        else:
+            for callback in callbacks:
+                callback(ret_value, meta)
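
Not part of the commit, but as a quick orientation: a minimal sketch of how the new TaskPool interface is driven. The add function, task id, and meta contents below are made up for illustration; callbacks receive (result, meta) and errbacks receive (exc, meta), matching the signatures documented in add() above.

    from celery.pool import TaskPool

    def add(x, y):
        # hypothetical task function; in the worker the actual target is celery's jail()
        return x + y

    def on_done(result, meta):
        # success callback: mycallback(result, meta)
        print("task %s gave %r" % (meta["task_id"], result))

    def on_error(exc, meta):
        # errback: myerrback(exc, meta)
        print("task %s failed: %r" % (meta["task_id"], exc))

    if __name__ == "__main__":
        pool = TaskPool(limit=2)
        pool.run()                    # creates the underlying multiprocessing.Pool
        result = pool.apply_async(add, args=[2, 2],
                                  callbacks=[on_done], errbacks=[on_error],
                                  meta={"task_id": "example-1"})
        result.wait()                 # multiprocessing AsyncResult; on_done fires once the result arrives

Once limit results are pending, add() blocks in wait_for_result()/reap() until at least one result has been collected, which is how the worker applies back-pressure on the broker.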

+ 32 - 9
celery/worker.py

@@ -5,7 +5,7 @@ from celery.conf import DAEMON_CONCURRENCY, DAEMON_LOG_FILE
 from celery.conf import QUEUE_WAKEUP_AFTER, EMPTY_MSG_EMIT_EVERY
 from celery.log import setup_logger
 from celery.registry import tasks
-from celery.datastructures import TaskProcessQueue
+from celery.pool import TaskPool
 from celery.models import PeriodicTaskMeta
 from celery.backends import default_backend, default_periodic_status_backend
 from celery.timer import EventTimer
@@ -92,13 +92,16 @@ class TaskWrapper(object):
         Mapping of keyword arguments to apply to the task.
 
     """
+    done_msg = "Task %(name)s[%(id)s] processed: %(return_value)s"
 
-    def __init__(self, task_name, task_id, task_func, args, kwargs):
+    def __init__(self, task_name, task_id, task_func, args, kwargs, **opts):
         self.task_name = task_name
         self.task_id = task_id
         self.task_func = task_func
         self.args = args
         self.kwargs = kwargs
+        self.done_msg = opts.get("done_msg", self.done_msg)
+        self.logger = opts.get("logger", multiprocessing.get_logger())
 
     def __repr__(self):
         return '<%s: {name:"%s", id:"%s", args:"%s", kwargs:"%s"}>' % (
@@ -107,7 +110,7 @@ class TaskWrapper(object):
                 self.args, self.kwargs)
 
     @classmethod
-    def from_message(cls, message):
+    def from_message(cls, message, logger):
         """Create a :class:`TaskWrapper` from a task message sent by
         :class:`celery.messaging.TaskPublisher`.
 
@@ -125,7 +128,7 @@ class TaskWrapper(object):
         if task_name not in tasks:
             raise UnknownTask(task_name)
         task_func = tasks[task_name]
-        return cls(task_name, task_id, task_func, args, kwargs)
+        return cls(task_name, task_id, task_func, args, kwargs, logger=logger)
 
     def extend_with_default_kwargs(self, loglevel, logfile):
         """Extend the tasks keyword arguments with standard task arguments.
@@ -153,6 +156,24 @@ class TaskWrapper(object):
         return jail(self.task_id, [
                         self.task_func, self.args, task_func_kwargs])
 
+    def on_success(self, ret_value, meta):
+        task_id = meta.get("task_id")
+        task_name = meta.get("task_name")
+        msg = self.done_msg % {
+                "id": task_id,
+                "name": task_name,
+                "return_value": ret_value}
+        self.logger.info(msg)
+
+    def on_failure(self, ret_value, meta):
+        task_id = meta.get("task_id")
+        task_name = meta.get("task_name")
+        msg = self.done_msg % {
+                "id": task_id,
+                "name": task_name,
+                "return_value": ret_value}
+        self.logger.error(msg)
+
     def execute_using_pool(self, pool, loglevel=None, logfile=None):
         """Like :meth:`execute`, but using the :mod:`multiprocessing` pool.
 
@@ -168,8 +189,9 @@ class TaskWrapper(object):
         task_func_kwargs = self.extend_with_default_kwargs(loglevel, logfile)
         jail_args = [self.task_id, self.task_func,
                      self.args, task_func_kwargs]
-        return pool.apply_async(jail, jail_args, {},
-                                self.task_name, self.task_id)
+        return pool.apply_async(jail, args=jail_args,
+                callbacks=[self.on_success], errbacks=[self.on_failure],
+                meta={"task_id": self.task_id, "task_name": self.task_name})
 
 
 class WorkController(object):
@@ -237,8 +259,7 @@ class WorkController(object):
         self.queue_wakeup_after = queue_wakeup_after or \
                                     self.queue_wakeup_after
         self.logger = setup_logger(loglevel, logfile)
-        self.pool = TaskProcessQueue(self.concurrency, logger=self.logger,
-                done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
+        self.pool = TaskPool(self.concurrency, logger=self.logger)
         self.task_consumer = None
         self.is_detached = is_detached
         self.reset_connection()
@@ -298,7 +319,7 @@ class WorkController(object):
         if message is None: # No messages waiting.
             raise EmptyQueue()
 
-        task = TaskWrapper.from_message(message)
+        task = TaskWrapper.from_message(message, logger=self.logger)
         self.logger.info("Got task from broker: %s[%s]" % (
                             task.task_name, task.task_id))
 
@@ -342,6 +363,8 @@ class WorkController(object):
             EventTimer(self.schedule_retry_tasks, 2),
         ]
 
+        self.pool.run()
+
         # If not running as daemon, and DEBUG logging level is enabled,
         # print pool PIDs and sleep for a second before we start.
         if self.logger.isEnabledFor(logging.DEBUG):