Browse Source

Shutdown properly on TERM, and ack() early instead of late, this is important as we can't ack it if the broker connection is lost.

Ask Solem 15 years ago
parent
commit
af6d45f0d9

+ 0 - 1
celery/bin/celeryd.py

@@ -229,7 +229,6 @@ class Worker(object):
 def install_worker_term_handler(worker):
 def install_worker_term_handler(worker):
 
 
     def _stop(signum, frame):
     def _stop(signum, frame):
-        worker.stop()
         raise SystemExit()
         raise SystemExit()
 
 
     platform.install_signal_handler("SIGTERM", _stop)
     platform.install_signal_handler("SIGTERM", _stop)

+ 3 - 2
celery/tests/test_worker.py

@@ -177,7 +177,7 @@ class TestCarrotListener(unittest.TestCase):
         l.reset_connection()
         l.reset_connection()
         self.assertTrue(isinstance(l.connection, BrokerConnection))
         self.assertTrue(isinstance(l.connection, BrokerConnection))
 
 
-        l.close_connection()
+        l.stop_consumers()
         self.assertTrue(l.connection is None)
         self.assertTrue(l.connection is None)
         self.assertTrue(l.task_consumer is None)
         self.assertTrue(l.task_consumer is None)
 
 
@@ -185,6 +185,7 @@ class TestCarrotListener(unittest.TestCase):
         self.assertTrue(isinstance(l.connection, BrokerConnection))
         self.assertTrue(isinstance(l.connection, BrokerConnection))
 
 
         l.stop()
         l.stop()
+        l.close_connection()
         self.assertTrue(l.connection is None)
         self.assertTrue(l.connection is None)
         self.assertTrue(l.task_consumer is None)
         self.assertTrue(l.task_consumer is None)
 
 
@@ -209,7 +210,7 @@ class TestCarrotListener(unittest.TestCase):
         eventer = l.event_dispatcher = MockEventDispatcher()
         eventer = l.event_dispatcher = MockEventDispatcher()
         heart = l.heart = MockHeart()
         heart = l.heart = MockHeart()
         l._state = RUN
         l._state = RUN
-        l.close_connection()
+        l.stop_consumers()
         self.assertTrue(eventer.closed)
         self.assertTrue(eventer.closed)
         self.assertTrue(heart.closed)
         self.assertTrue(heart.closed)
 
 

+ 4 - 1
celery/worker/__init__.py

@@ -7,6 +7,7 @@ import socket
 import logging
 import logging
 import traceback
 import traceback
 from Queue import Queue
 from Queue import Queue
+from multiprocessing.util import Finalize
 
 
 from celery import conf
 from celery import conf
 from celery import registry
 from celery import registry
@@ -27,6 +28,7 @@ def process_initializer():
     # There seems to a bug in multiprocessing (backport?)
     # There seems to a bug in multiprocessing (backport?)
     # when detached, where the worker gets EOFErrors from time to time
     # when detached, where the worker gets EOFErrors from time to time
     # and the logger is left from the parent process causing a crash.
     # and the logger is left from the parent process causing a crash.
+    platform.reset_signal("SIGTERM")
     _hijack_multiprocessing_logger()
     _hijack_multiprocessing_logger()
     platform.set_mp_process_title("celeryd")
     platform.set_mp_process_title("celeryd")
 
 
@@ -116,6 +118,7 @@ class WorkController(object):
         self.embed_clockservice = embed_clockservice
         self.embed_clockservice = embed_clockservice
         self.ready_callback = ready_callback
         self.ready_callback = ready_callback
         self.send_events = send_events
         self.send_events = send_events
+        self._finalize = Finalize(self, self.stop, exitpriority=20)
 
 
         # Queues
         # Queues
         if conf.DISABLE_RATE_LIMITS:
         if conf.DISABLE_RATE_LIMITS:
@@ -190,5 +193,5 @@ class WorkController(object):
 
 
         signals.worker_shutdown.send(sender=self)
         signals.worker_shutdown.send(sender=self)
         [component.stop() for component in reversed(self.components)]
         [component.stop() for component in reversed(self.components)]
-
+        self.listener.close_connection()
         self._state = "STOP"
         self._state = "STOP"

+ 4 - 3
celery/worker/job.py

@@ -303,9 +303,10 @@ class TaskWrapper(object):
 
 
         args = self._get_tracer_args(loglevel, logfile)
         args = self._get_tracer_args(loglevel, logfile)
         self.time_start = time.time()
         self.time_start = time.time()
-        return pool.apply_async(execute_and_trace, args=args,
-                callbacks=[self.on_success], errbacks=[self.on_failure],
-                on_ack=self.on_ack)
+        result = pool.apply_async(execute_and_trace, args=args,
+                    callbacks=[self.on_success], errbacks=[self.on_failure])
+        self.on_ack()
+        return result
 
 
     def on_success(self, ret_value):
     def on_success(self, ret_value):
         """The handler used if the task was successfully processed (
         """The handler used if the task was successfully processed (

+ 12 - 7
celery/worker/listener.py

@@ -160,6 +160,11 @@ class CarrotListener(object):
         message.ack()
         message.ack()
 
 
     def close_connection(self):
     def close_connection(self):
+        self.logger.debug("CarrotListener: "
+                          "Closing connection to broker...")
+        self.connection = self.connection and self.connection.close()
+
+    def stop_consumers(self, close=True):
         if not self._state == RUN:
         if not self._state == RUN:
             return
             return
         self._state = CLOSE
         self._state = CLOSE
@@ -175,14 +180,13 @@ class CarrotListener(object):
             self.logger.debug("EventDispatcher: Shutting down...")
             self.logger.debug("EventDispatcher: Shutting down...")
             self.event_dispatcher = self.event_dispatcher.close()
             self.event_dispatcher = self.event_dispatcher.close()
 
 
-        self.logger.debug("CarrotListener: "
-                          "Closing connection to broker...")
-        self.connection = self.connection and self.connection.close()
+        if close:
+            self.close_connection()
 
 
     def reset_connection(self):
     def reset_connection(self):
         self.logger.debug(
         self.logger.debug(
                 "CarrotListener: Re-establishing connection to the broker...")
                 "CarrotListener: Re-establishing connection to the broker...")
-        self.close_connection()
+        self.stop_consumers()
         self.connection = self._open_connection()
         self.connection = self._open_connection()
         self.logger.debug("CarrotListener: Connection Established.")
         self.logger.debug("CarrotListener: Connection Established.")
         self.task_consumer = get_consumer_set(connection=self.connection)
         self.task_consumer = get_consumer_set(connection=self.connection)
@@ -219,11 +223,11 @@ class CarrotListener(object):
 
 
         def _connection_error_handler(exc, interval):
         def _connection_error_handler(exc, interval):
             """Callback handler for connection errors."""
             """Callback handler for connection errors."""
-            self.logger.error("AMQP Listener: Connection Error: %s. " % exc
+            self.logger.error("CarrotListener: Connection Error: %s. " % exc
                      + "Trying again in %d seconds..." % interval)
                      + "Trying again in %d seconds..." % interval)
 
 
         def _establish_connection():
         def _establish_connection():
-            """Establish a connection to the AMQP broker."""
+            """Establish a connection to the broker."""
             conn = establish_connection()
             conn = establish_connection()
             conn.connect() # Connection is established lazily, so connect.
             conn.connect() # Connection is established lazily, so connect.
             return conn
             return conn
@@ -237,4 +241,5 @@ class CarrotListener(object):
         return conn
         return conn
 
 
     def stop(self):
     def stop(self):
-        self.close_connection()
+        self.logger.debug("CarrotListener: Stopping consumers...")
+        self.stop_consumers(close=False)

+ 4 - 6
celery/worker/pool.py

@@ -45,7 +45,7 @@ class TaskPool(object):
 
 
     def stop(self):
     def stop(self):
         """Terminate the pool."""
         """Terminate the pool."""
-        self._pool.terminate()
+        self._pool.close()
         self._pool.join()
         self._pool.join()
         self._pool = None
         self._pool = None
 
 
@@ -58,7 +58,7 @@ class TaskPool(object):
                     dead_count))
                     dead_count))
 
 
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
     def apply_async(self, target, args=None, kwargs=None, callbacks=None,
-            errbacks=None, on_ack=noop):
+            errbacks=None, **compat):
         """Equivalent of the :func:``apply`` built-in function.
         """Equivalent of the :func:``apply`` built-in function.
 
 
         All ``callbacks`` and ``errbacks`` should complete immediately since
         All ``callbacks`` and ``errbacks`` should complete immediately since
@@ -70,7 +70,7 @@ class TaskPool(object):
         callbacks = callbacks or []
         callbacks = callbacks or []
         errbacks = errbacks or []
         errbacks = errbacks or []
 
 
-        on_ready = curry(self.on_ready, callbacks, errbacks, on_ack)
+        on_ready = curry(self.on_ready, callbacks, errbacks)
 
 
         self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
         self.logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)" % (
             target, args, kwargs))
             target, args, kwargs))
@@ -80,11 +80,9 @@ class TaskPool(object):
         return self._pool.apply_async(target, args, kwargs,
         return self._pool.apply_async(target, args, kwargs,
                                         callback=on_ready)
                                         callback=on_ready)
 
 
-    def on_ready(self, callbacks, errbacks, on_ack, ret_value):
+    def on_ready(self, callbacks, errbacks, ret_value):
         """What to do when a worker task is ready and its return value has
         """What to do when a worker task is ready and its return value has
         been collected."""
         been collected."""
-        # Acknowledge the task as being processed.
-        on_ack()
 
 
         if isinstance(ret_value, ExceptionInfo):
         if isinstance(ret_value, ExceptionInfo):
             if isinstance(ret_value.exception, (
             if isinstance(ret_value.exception, (