
Worker cleanup. The TaskDaemon class is now called WorkController

Ask Solem, 16 years ago
Parent commit 473bc8ad0c
3 changed files with 112 additions and 41 deletions
  1. celery/bin/celeryd.py (+4 -3)
  2. celery/datastructures.py (+75 -17)
  3. celery/worker.py (+33 -21)

celery/bin/celeryd.py (+4 -3)

@@ -45,7 +45,7 @@ from celery.conf import LOG_LEVELS, DAEMON_LOG_FILE, DAEMON_LOG_LEVEL
 from celery.conf import DAEMON_CONCURRENCY, DAEMON_PID_FILE
 from celery.conf import QUEUE_WAKEUP_AFTER
 from celery import discovery
-from celery.worker import TaskDaemon
+from celery.worker import WorkController
 import traceback
 import optparse
 import atexit
@@ -70,10 +70,11 @@ def main(concurrency=DAEMON_CONCURRENCY, daemon=False,
         logfile = None # log to stderr when not running as daemon.
 
     discovery.autodiscover()
-    celeryd = TaskDaemon(concurrency=concurrency,
+    celeryd = WorkController(concurrency=concurrency,
                                loglevel=loglevel,
                                logfile=logfile,
-                               queue_wakeup_after=queue_wakeup_after)
+                               queue_wakeup_after=queue_wakeup_after,
+                               is_detached=daemon)
     try:
         celeryd.run()
     except Exception, e:
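
For reference, a minimal sketch of how celeryd wires up the renamed class after this change. The argument values below are illustrative placeholders; the real ones come from celery.conf and the command-line options.

    # Sketch only: celeryd start-up after the rename. Values are placeholders,
    # not the project defaults; is_detached mirrors the daemon option.
    import logging
    from celery.worker import WorkController

    celeryd = WorkController(concurrency=8,
                             loglevel=logging.INFO,
                             logfile="celeryd.log",
                             queue_wakeup_after=0.3,
                             is_detached=True)
    celeryd.run()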

celery/datastructures.py (+75 -17)

@@ -4,6 +4,7 @@ Custom Datastructures
 
 """
 import multiprocessing
+from multiprocessing.pool import RUN as POOL_STATE_RUN
 import itertools
 import threading
 import time
@@ -85,42 +86,58 @@ class TaskProcessQueue(object):
         self.logger = logger or multiprocessing.get_logger()
         self.done_msg = done_msg
         self.reap_timeout = reap_timeout
-        self._processes = {}
         self._process_counter = itertools.count(1)
+        self._processed_total = 0
         self._data_lock = threading.Condition(threading.Lock())
         self._start()
 
     def _start(self):
-        assert int(self.limit)
+        self._processes = {}
         self._pool = multiprocessing.Pool(processes=self.limit)
 
+    def _terminate_and_restart(self):
+        try:
+            self._pool.terminate()
+        except OSError:
+            pass
+        self._start()
+
     def _restart(self):
         self.logger.info("Closing and restarting the pool...")
         self._pool.close()
+        timeout_thread = threading.Timer(30.0, self._terminate_and_restart)
+        timeout_thread.start()
         self._pool.join()
+        timeout_thread.cancel()
         self._start()
 
+    def _pool_is_running(self):
+        return self._pool._state == POOL_STATE_RUN
+
     def apply_async(self, target, args, kwargs, task_name, task_id):
-        _pid = self._process_counter.next()
 
-        on_return = lambda ret_val: self.on_return(_pid, ret_val,
-                                                   task_name, task_id)
+        if not self._pool_is_running():
+            self._start()
+
+        self._processed_total = self._process_counter.next()
+
+        on_return = lambda r: self.on_return(r, task_name, task_id)
 
         result = self._pool.apply_async(target, args, kwargs,
                                            callback=on_return)
-        self.add(_pid, result, task_name, task_id)
+        self.add(result, task_name, task_id)
 
         return result
 
-    def on_return(self, _pid, ret_val, task_name, task_id):
+    def on_return(self, ret_val, task_name, task_id):
         try:
-            del(self._processes[_pid])
+            del(self._processes[task_id])
         except KeyError:
             pass
         else:
             self.on_ready(ret_val, task_name, task_id)
 
-    def add(self, _pid, result, task_name, task_id):
+    def add(self, result, task_name, task_id):
         """Add a process to the queue.
 
         If the queue is full, it will wait for the first task to finish,
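
The new restart path guards Pool.join() with a threading.Timer watchdog. Stripped of the celery specifics, the general pattern looks roughly like this; a standalone sketch, not the committed code:

    # Standalone sketch of the close/join-with-watchdog pattern used above.
    import multiprocessing
    import threading

    pool = multiprocessing.Pool(processes=2)

    def force_terminate():
        # Fallback if join() hangs: kill the worker processes outright.
        pool.terminate()

    pool.close()                                  # stop accepting new work
    watchdog = threading.Timer(30.0, force_terminate)
    watchdog.start()
    pool.join()                                   # normally returns promptly
    watchdog.cancel()                             # join finished, disarm the timer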
@@ -136,12 +153,22 @@ class TaskProcessQueue(object):
 
         """
 
-        self._processes[_pid] = [result, task_name, task_id]
+        self._processes[task_id] = [result, task_name]
 
         if self.full():
             self.wait_for_result()
 
     def _is_alive(self, pid):
+        """Uses non-blocking ``waitpid`` to see if a process is still alive.
+
+        :param pid: The process id of the process.
+
+        :returns: ``True`` if the process is still running, ``False``
+            otherwise.
+
+        :rtype: bool
+
+        """
         try:
             is_alive = os.waitpid(pid, os.WNOHANG) == (0, 0)
         except OSError, e:
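
The docstring added to _is_alive describes the non-blocking waitpid check. As a quick illustration of that idiom on its own (a POSIX-only sketch; the ECHILD handling is assumed, since the condition inside the except block is not shown in this hunk):

    # Sketch of a non-blocking waitpid liveness check (POSIX, Python 2 era).
    import errno
    import os

    def is_alive(pid):
        try:
            # WNOHANG returns (0, 0) immediately if the child hasn't exited yet.
            return os.waitpid(pid, os.WNOHANG) == (0, 0)
        except OSError, e:
            if e.errno != errno.ECHILD:   # assumed handling; not shown above
                raise
            return False                  # no such child process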
@@ -149,7 +176,7 @@ class TaskProcessQueue(object):
                 raise
         return is_alive
 
-    def reap_zombies(self):
+    def _reap_zombies(self):
         assert hasattr(self._pool, "_pool")
         self.logger.debug("Trying to find zombies...")
         for process in self._pool._pool:
@@ -163,28 +190,59 @@ class TaskProcessQueue(object):
         return len(self._processes.values()) >= self.limit
 
     def wait_for_result(self):
-        """Collect results from processes that are ready."""
+        """Waits for the first process in the pool to finish.
+
+        This operation is blocking.
+
+        """
         while True:
             if self.reap():
                 break
-            self.reap_zombies()
+            #self._reap_zombies()
 
     def reap(self):
+        self.logger.debug("Reaping processes...")
         processes_reaped = 0
         for process_no, entry in enumerate(self._processes.items()):
-            _pid, process_info = entry
-            result, task_name, task_id = process_info
+            task_id, process_info = entry
+            result, task_name = process_info
             try:
-                ret_value = result.get(timeout=0.1)
+                ret_value = result.get(timeout=0.3)
             except multiprocessing.TimeoutError:
                 continue
             else:
-                self.on_return(_pid, ret_value, task_name, task_id)
+                self.on_return(ret_value, task_name, task_id)
                 processes_reaped += 1
         return processes_reaped
 
+    def get_worker_pids(self):
+        """Returns the process id's of all the pool workers.
+
+        :rtype: list
+
+        """
+        return [process.pid for process in self._pool._pool]
 
     def on_ready(self, ret_value, task_name, task_id):
+        """What to do when a worker returns with a result.
+
+        If :attr:`done_msg` is defined, it will log this
+        format string, with level ``logging.INFO``,
+        using these format variables:
+
+            * %(name)
+
+                The name of the task completed
+
+            * %(id)
+
+                The UUID of the task completed.
+
+            * %(return_value)
+
+                Return value of the task function.
+
+        """
         if self.done_msg:
             self.logger.info(self.done_msg % {
                 "name": task_name,

celery/worker.py (+33 -21)

@@ -7,7 +7,7 @@ from celery.log import setup_logger
 from celery.registry import tasks
 from celery.datastructures import TaskProcessQueue
 from celery.models import PeriodicTaskMeta
-from celery.backends import default_backend
+from celery.backends import default_backend, default_periodic_status_backend
 from celery.timer import EventTimer
 import multiprocessing
 import simplejson
@@ -37,16 +37,12 @@ def jail(task_id, func, args, kwargs):
     result, and sets the task status to ``"FAILURE"``.
 
     :param task_id: The id of the task.
-
     :param func: Callable object to execute.
-
     :param args: List of positional args to pass on to the function.
-
     :param kwargs: Keyword arguments mapping to pass on to the function.
 
-    :returns: the function return value on success.
-
-    :returns: the exception instance on failure.
+    :returns: the function return value on success, or
+        the exception instance on failure.
 
     """
     try:
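
The reworded docstring makes the jail() contract explicit: the function's return value on success, the exception instance (returned, not raised) on failure. A caller-side sketch of what that means; handle_failure and handle_success are hypothetical helpers used only for illustration:

    # Sketch of the jail() contract: on failure the exception instance is
    # *returned*, not raised, so callers type-check the result.
    ret = jail(task_id, func, args, kwargs)
    if isinstance(ret, Exception):
        handle_failure(ret)      # hypothetical; task status becomes "FAILURE"
    else:
        handle_success(ret)      # ret is the function's return value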
@@ -101,6 +97,12 @@ class TaskWrapper(object):
         self.args = args
         self.kwargs = kwargs
 
+    def __repr__(self):
+        return '<%s: {name:"%s", id:"%s", args:"%s", kwargs:"%s"}>' % (
+                self.__class__.__name__,
+                self.task_name, self.task_id,
+                self.args, self.kwargs)
+
     @classmethod
     def from_message(cls, message):
         """Create a :class:`TaskWrapper` from a task message sent by
@@ -168,7 +170,7 @@ class TaskWrapper(object):
                                 self.task_name, self.task_id)
 
 
-class TaskDaemon(object):
+class WorkController(object):
     """Executes tasks waiting in the task queue.
 
     :param concurrency: see :attr:`concurrency`.
@@ -226,7 +228,7 @@ class TaskDaemon(object):
     empty_msg_emit_every = EMPTY_MSG_EMIT_EVERY
 
     def __init__(self, concurrency=None, logfile=None, loglevel=None,
-            queue_wakeup_after=None):
+            queue_wakeup_after=None, is_detached=False):
         self.loglevel = loglevel or self.loglevel
         self.concurrency = concurrency or self.concurrency
         self.logfile = logfile or self.logfile
@@ -236,6 +238,7 @@ class TaskDaemon(object):
         self.pool = TaskProcessQueue(self.concurrency, logger=self.logger,
                 done_msg="Task %(name)s[%(id)s] processed: %(return_value)s")
         self.task_consumer = None
+        self.is_detached = is_detached
         self.reset_connection()
 
     def reset_connection(self):
@@ -273,8 +276,11 @@ class TaskDaemon(object):
 
         """
         #self.connection_diagnostics()
+        self.logger.debug("Trying to fetch message from broker...")
         message = self.task_consumer.fetch()
         if message is not None:
+            self.logger.debug("Acknowledging message with delivery tag %s" % (
+                message.delivery_tag))
             message.ack()
         return message
 
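The extra debug logging brackets the existing fetch-and-ack flow. As a sketch of the receive pattern in isolation, using only the calls visible above (fetch(), delivery_tag, ack()); task_consumer and logger are assumed to be in scope:

    # Sketch of the fetch/ack flow, using only the consumer calls shown above.
    message = task_consumer.fetch()          # returns None when the queue is empty
    if message is not None:
        logger.debug("Acknowledging message with delivery tag %s" % (
            message.delivery_tag))
        message.ack()                        # acknowledge receipt to the broker
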
@@ -304,27 +310,23 @@ class TaskDaemon(object):
         :const:`logging.CRITICAL`.
 
         """
+        self.logger.debug("Trying to fetch a task.")
         task, message = self.fetch_next_task()
+        self.logger.debug("Got a task: %s. Trying to execute it..." % task)
+
+        result = task.execute_using_pool(self.pool, self.loglevel,
+                                         self.logfile)
 
-        try:
-            result = task.execute_using_pool(self.pool, self.loglevel,
-                                             self.logfile)
-        except Exception, error:
-            self.logger.critical("Worker got exception %s: %s\n%s" % (
-                error.__class__, error, traceback.format_exc()))
-            return
+        self.logger.debug("Task %s has been executed asynchronously." % task)
 
         return result, task.task_name, task.task_id
 
     def run_periodic_tasks(self):
         """Schedule all waiting periodic tasks for execution.
 
-        :rtype: list of :class:`celery.models.PeriodicTaskMeta` objects.
         """
-        waiting_tasks = PeriodicTaskMeta.objects.get_waiting_tasks()
-        [waiting_task.delay()
-                for waiting_task in waiting_tasks]
-        return waiting_tasks
+        self.logger.debug("Looking for periodic tasks ready for execution...")
+        default_periodic_status_backend.run_periodic_tasks()
 
     def schedule_retry_tasks(self):
         """Reschedule all requeued tasks waiting for retry."""
@@ -339,9 +341,19 @@ class TaskDaemon(object):
             EventTimer(self.schedule_retry_tasks, 2),
         ]
 
+        # If not running as daemon, and DEBUG logging level is enabled,
+        # print pool PIDs and sleep for a second before we start.
+        if self.logger.isEnabledFor(logging.DEBUG):
+            self.logger.debug("Pool child processes: [%s]" % (
+                "|".join(map(str, self.pool.get_worker_pids()))))
+            if not self.is_detached:
+                time.sleep(1)
+
         while True:
+            print("!!!!! Running tick...")
             [event.tick() for event in events]
             try:
+                print("Trying to execute task.")
                 result, task_name, task_id = self.execute_next_task()
             except ValueError:
                 # execute_next_task didn't return a r/name/id tuple,
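
Taken together, the main loop ticks the EventTimer events and then polls for the next task on every iteration. In outline; a simplified sketch of the loop's shape, not the full method, with the ValueError handling and shutdown paths left out (the first events entry is elided here because it is not shown in this hunk):

    # Simplified outline of WorkController.run()'s main loop (sketch only).
    events = [
        # ... other timers ...
        EventTimer(self.schedule_retry_tasks, 2),   # reschedule retries every 2s
    ]

    while True:
        [event.tick() for event in events]          # fire any timers that are due
        try:
            result, task_name, task_id = self.execute_next_task()
        except ValueError:
            # Nothing to unpack this round; real handling is truncated above.
            pass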