Browse Source

Cleanup some code in TaskPool/WorkController

Ask Solem 16 years ago
parent
commit
b368345a1d
2 changed files with 24 additions and 79 deletions
  1. 12 68
      celery/pool.py
  2. 12 11
      celery/worker/__init__.py

+ 12 - 68
celery/pool.py

@@ -10,6 +10,7 @@ import uuid
 
 
 from multiprocessing.pool import RUN as POOL_STATE_RUN
 from multiprocessing.pool import RUN as POOL_STATE_RUN
 from celery.datastructures import ExceptionInfo
 from celery.datastructures import ExceptionInfo
+from functools import partial as curry
 
 
 
 
 class TaskPool(object):
 class TaskPool(object):
@@ -39,49 +40,19 @@ class TaskPool(object):
         self._pool = None
         self._pool = None
         self._processes = None
         self._processes = None
 
 
-    def run(self):
+    def start(self):
         """Run the task pool.
         """Run the task pool.
 
 
         Will pre-fork all workers so they're ready to accept tasks.
         Will pre-fork all workers so they're ready to accept tasks.
 
 
         """
         """
-        self._start()
-
-    def _start(self):
-        """INTERNAL: Starts the pool. Used by :meth:`run`."""
         self._processes = {}
         self._processes = {}
         self._pool = multiprocessing.Pool(processes=self.limit)
         self._pool = multiprocessing.Pool(processes=self.limit)
 
 
-    def terminate(self):
+    def stop(self):
         """Terminate the pool."""
         """Terminate the pool."""
         self._pool.terminate()
         self._pool.terminate()
 
 
-    def _terminate_and_restart(self):
-        """INTERNAL: Terminate and restart the pool."""
-        try:
-            self.terminate()
-        except OSError:
-            pass
-        self._start()
-
-    def _restart(self):
-        """INTERNAL: Close and restart the pool."""
-        self.logger.info("Closing and restarting the pool...")
-        self._pool.close()
-        timeout_thread = threading.Timer(30.0, self._terminate_and_restart)
-        timeout_thread.start()
-        self._pool.join()
-        timeout_thread.cancel()
-        self._start()
-
-    def _pool_is_running(self):
-        """Check if the pool is in the run state.
-
-        :returns: ``True`` if the pool is running.
-
-        """
-        return self._pool._state == POOL_STATE_RUN
-
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
             errbacks=None, on_acknowledge=None, meta=None):
             errbacks=None, on_acknowledge=None, meta=None):
         """Equivalent of the :func:``apply`` built-in function.
         """Equivalent of the :func:``apply`` built-in function.
@@ -97,56 +68,30 @@ class TaskPool(object):
         meta = meta or {}
         meta = meta or {}
         tid = str(uuid.uuid4())
         tid = str(uuid.uuid4())
 
 
-        if not self._pool_is_running():
-            self._start()
-
         self._processed_total = self._process_counter.next()
         self._processed_total = self._process_counter.next()
 
 
-        on_return = lambda r: self.on_return(r, tid, callbacks, errbacks, meta)
-
+        on_return = curry(self.on_return, tid, callbacks, errbacks, meta)
 
 
         if self.full():
         if self.full():
             self.wait_for_result()
             self.wait_for_result()
+
         result = self._pool.apply_async(target, args, kwargs,
         result = self._pool.apply_async(target, args, kwargs,
-                                           callback=on_return)
+                                        callback=on_return)
         if on_acknowledge:
         if on_acknowledge:
             on_acknowledge()
             on_acknowledge()
-        self.add(result, callbacks, errbacks, tid, meta)
+        
+        self._processes[tid] = [result, callbacks, errbacks, meta]
 
 
         return result
         return result
 
 
-    def on_return(self, ret_val, tid, callbacks, errbacks, meta):
+    def on_return(self, tid, callbacks, errbacks, meta, ret_value):
         """What to do when the process returns."""
         """What to do when the process returns."""
         try:
         try:
             del(self._processes[tid])
             del(self._processes[tid])
         except KeyError:
         except KeyError:
             pass
             pass
         else:
         else:
-            self.on_ready(ret_val, callbacks, errbacks, meta)
-
-    def add(self, result, callbacks, errbacks, tid, meta):
-        """Add a process to the queue.
-
-        If the queue is full, it will wait for the first task to finish,
-        collects its result and remove it from the queue, so it's ready
-        to accept new processes.
-
-        :param result: A :class:`multiprocessing.AsyncResult` instance, as
-            returned by :meth:`multiprocessing.Pool.apply_async`.
-
-        :option callbacks: List of callbacks to execute if the task was
-            successful. Must have the function signature:
-            ``mycallback(result, meta)``
-
-        :option errbacks: List of errbacks to execute if the task raised
-            an exception. Must have the function signature:
-            ``myerrback(exc, meta)``.
-
-        :option tid: The tid for this task (unique pool id).
-
-        """
-
-        self._processes[tid] = [result, callbacks, errbacks, meta]
+            self.on_ready(callbacks, errbacks, meta, ret_value)
 
 
     def full(self):
     def full(self):
         """Is the pool full?
         """Is the pool full?
@@ -179,7 +124,7 @@ class TaskPool(object):
             except multiprocessing.TimeoutError:
             except multiprocessing.TimeoutError:
                 continue
                 continue
             else:
             else:
-                self.on_return(ret_value, tid, callbacks, errbacks, meta)
+                self.on_return(tid, callbacks, errbacks, meta, ret_value)
                 processes_reaped += 1
                 processes_reaped += 1
         return processes_reaped
         return processes_reaped
 
 
@@ -187,14 +132,13 @@ class TaskPool(object):
         """Returns the process id's of all the pool workers."""
         """Returns the process id's of all the pool workers."""
         return [process.pid for process in self._pool._pool]
         return [process.pid for process in self._pool._pool]
 
 
-    def on_ready(self, ret_value, callbacks, errbacks, meta):
+    def on_ready(self, callbacks, errbacks, meta, ret_value):
         """What to do when a worker task is ready and its return value has
         """What to do when a worker task is ready and its return value has
         been collected."""
         been collected."""
 
 
         if isinstance(ret_value, ExceptionInfo):
         if isinstance(ret_value, ExceptionInfo):
             if isinstance(ret_value.exception, KeyboardInterrupt) or \
             if isinstance(ret_value.exception, KeyboardInterrupt) or \
                     isinstance(ret_value.exception, SystemExit):
                     isinstance(ret_value.exception, SystemExit):
-                self.terminate()
                 raise ret_value.exception
                 raise ret_value.exception
             for errback in errbacks:
             for errback in errbacks:
                 errback(ret_value, meta)
                 errback(ret_value, meta)

+ 12 - 11
celery/worker/__init__.py

@@ -180,21 +180,26 @@ class WorkController(object):
         self.amqp_listener = AMQPListener(self.bucket_queue, self.hold_queue,
         self.amqp_listener = AMQPListener(self.bucket_queue, self.hold_queue,
                                           logger=self.logger)
                                           logger=self.logger)
         self.mediator = Mediator(self.bucket_queue, self.safe_process_task)
         self.mediator = Mediator(self.bucket_queue, self.safe_process_task)
+        
+        # The order is important here;
+        #   the first in the list is the first to start,
+        # and they must be stopped in reverse order.  
+        self.components = [self.pool,
+                           self.mediator,
+                           self.periodic_work_controller,
+                           self.amqp_listener]
 
 
     def start(self):
     def start(self):
         """Starts the workers main loop."""
         """Starts the workers main loop."""
         self._state = "RUN"
         self._state = "RUN"
 
 
         try:
         try:
-            self.pool.run()
-            self.mediator.start()
-            self.periodic_work_controller.start()
-            self.amqp_listener.start()
+            [component.start() for component in self.components]
         finally:
         finally:
             self.stop()
             self.stop()
 
 
     def safe_process_task(self, task):
     def safe_process_task(self, task):
-        """Same as :meth:`process_task`, but catch all exceptions
+        """Same as :meth:`process_task`, but catches all exceptions
         the task raises and logs them as errors, to make sure the
         the task raises and logs them as errors, to make sure the
         worker doesn't die."""
         worker doesn't die."""
         try:
         try:
@@ -224,9 +229,5 @@ class WorkController(object):
         # shut down the periodic work controller thread
         # shut down the periodic work controller thread
         if self._state != "RUN":
         if self._state != "RUN":
             return
             return
-        self._state = "TERMINATE"
-        self.amqp_listener.stop()
-        self.mediator.stop()
-        self.periodic_work_controller.stop()
-        self.pool.terminate()
-        self._state = "STOP"
+
+        [component.stop() for component in reversed(self.components)]