Przeglądaj źródła

Refactored celery.bin.celeryd so it's possible to test it

Ask Solem 15 lat temu
rodzic
commit
db5e073893

+ 98 - 76
celery/bin/celeryd.py

@@ -47,10 +47,12 @@ import multiprocessing
 
 import celery
 from celery import conf
+from celery import signals
 from celery import platform
 from celery.log import emergency_error
 from celery.task import discard_all
 from celery.utils import info
+from celery.utils import get_full_cls_name
 from celery.worker import WorkController
 
 STARTUP_INFO_FMT = """
@@ -97,98 +99,114 @@ OPTION_LIST = (
 )
 
 
-def run_worker(concurrency=conf.CELERYD_CONCURRENCY,
-        loglevel=conf.CELERYD_LOG_LEVEL, logfile=conf.CELERYD_LOG_FILE,
-        hostname=None,
-        discard=False, run_clockservice=False, events=False, **kwargs):
-    """Starts the celery worker server."""
-
-    hostname = hostname or socket.gethostname()
-
-    print("celery@%s v%s is starting." % (hostname, celery.__version__))
-
-    from celery.loaders import current_loader, load_settings
-    loader = current_loader()
-    settings = load_settings()
-
-    if not concurrency:
-        concurrency = multiprocessing.cpu_count()
-
-    if conf.CELERY_BACKEND == "database" \
-            and settings.DATABASE_ENGINE == "sqlite3" and \
-            concurrency > 1:
-        import warnings
-        warnings.warn("The sqlite3 database engine doesn't support "
-                "concurrency. We'll be using a single process only.",
-                UserWarning)
-        concurrency = 1
-
-    # Setup logging
-    if not isinstance(loglevel, int):
-        loglevel = conf.LOG_LEVELS[loglevel.upper()]
-
-    if discard:
+class Worker(object):
+
+    def __init__(self, concurrency=conf.CELERYD_CONCURRENCY,
+            loglevel=conf.CELERYD_LOG_LEVEL, logfile=conf.CELERYD_LOG_FILE,
+            hostname=None, discard=False, run_clockservice=False,
+            events=False, **kwargs):
+        self.concurrency = concurrency or multiprocessing.cpu_count()
+        self.loglevel = loglevel
+        self.logfile = logfile
+        self.hostname = hostname or socket.gethostname()
+        self.discard = discard
+        self.run_clockservice = run_clockservice
+        self.events = events
+        if not isinstance(self.loglevel, int):
+            self.loglevel = conf.LOG_LEVELS[self.loglevel.upper()]
+
+    def run(self):
+        print("celery@%s v%s is starting." % (self.hostname,
+                                              celery.__version__))
+
+        self.init_loader()
+
+        if conf.CELERY_BACKEND == "database" \
+                and self.settings.DATABASE_ENGINE == "sqlite3" and \
+                self.concurrency > 1:
+            import warnings
+            warnings.warn("The sqlite3 database engine doesn't handle "
+                          "concurrency well. Will use a single process only.",
+                          UserWarning)
+            self.concurrency = 1
+
+        if self.discard:
+            self.purge_messages()
+        self.worker_init()
+
+        # Dump configuration to screen so we have some basic information
+        # for when users send bug reports.
+        print(self.startup_info())
+        set_process_status("Running...")
+
+        self.run_worker()
+
+    def on_listener_ready(self, listener):
+        signals.worker_ready.send(sender=listener)
+        print("celery@%s has started." % self.hostname)
+
+    def init_loader(self):
+        from celery.loaders import current_loader, load_settings
+        self.loader = current_loader()
+        self.settings = load_settings()
+
+    def purge_messages(self):
         discarded_count = discard_all()
         what = discarded_count > 1 and "messages" or "message"
         print("discard: Erased %d %s from the queue.\n" % (
-                discarded_count, what))
-
-    # Run the worker init handler.
-    # (Usually imports task modules and such.)
-    loader.on_worker_init()
+            discarded_count, what))
 
-    # Dump configuration to screen so we have some basic information
-    # when users sends e-mails.
+    def worker_init(self):
+        # Run the worker init handler.
+        # (Usually imports task modules and such.)
+        self.loader.on_worker_init()
 
-    tasklist = ""
-    if loglevel <= logging.INFO:
+    def tasklist(self, include_builtins=True):
         from celery.registry import tasks
         tasklist = tasks.keys()
-        if not loglevel <= logging.DEBUG:
-            tasklist = filter(lambda s: not s.startswith("celery."), tasklist)
-        tasklist = TASK_LIST_FMT % "\n".join("        . %s" % task
-                                                for task in sorted(tasklist))
-
-    print(STARTUP_INFO_FMT % {
+        if not include_builtins:
+            tasklist = filter(lambda s: not s.startswith("celery."),
+                              tasklist)
+        return TASK_LIST_FMT % "\n".join("\t. %s" % task
+                                            for task in sorted(tasklist))
+
+    def startup_info(self):
+        tasklist = ""
+        if self.loglevel <= logging.INFO:
+            include_builtins = self.loglevel <= logging.DEBUG
+            tasklist = self.tasklist(include_builtins=include_builtins)
+
+        return STARTUP_INFO_FMT % {
             "conninfo": info.format_broker_info(),
             "queues": info.format_routing_table(indent=8),
-            "concurrency": concurrency,
-            "loglevel": conf.LOG_LEVELS[loglevel],
-            "logfile": logfile or "[stderr]",
-            "celerybeat": run_clockservice and "ON" or "OFF",
-            "events": events and "ON" or "OFF",
+            "concurrency": self.concurrency,
+            "loglevel": conf.LOG_LEVELS[self.loglevel],
+            "logfile": self.logfile or "[stderr]",
+            "celerybeat": self.run_clockservice and "ON" or "OFF",
+            "events": self.events and "ON" or "OFF",
             "tasks": tasklist,
-            "loader": loader.__class__.__module__,
-    })
-
-    print("Celery has started.")
-    set_process_status("Running...")
-
-    def run_worker():
-        worker = WorkController(concurrency=concurrency,
-                                loglevel=loglevel,
-                                logfile=logfile,
-                                hostname=hostname,
-                                embed_clockservice=run_clockservice,
-                                send_events=events)
+            "loader": get_full_cls_name(self.loader.__class__),
+        }
+
+    def run_worker(self):
+        worker = WorkController(concurrency=self.concurrency,
+                                loglevel=self.loglevel,
+                                logfile=self.logfile,
+                                hostname=self.hostname,
+                                ready_callback=self.on_listener_ready,
+                                embed_clockservice=self.run_clockservice,
+                                send_events=self.events)
 
         # Install signal handler so SIGHUP restarts the worker.
         install_worker_restart_handler(worker)
 
-        from celery import signals
         signals.worker_init.send(sender=worker)
-
         try:
             worker.start()
-        except Exception, e:
-            emergency_error(logfile, "celeryd raised exception %s: %s\n%s" % (
-                            e.__class__, e, traceback.format_exc()))
-
-    try:
-        run_worker()
-    except:
-        set_process_status("Exiting...")
-        raise
+        except Exception, exc:
+            emergency_error(self.logfile,
+                    "celeryd raised exception %s: %s\n%s" % (
+                        exc.__class__, exc, traceback.format_exc()))
 
 
 def install_worker_restart_handler(worker):
@@ -217,9 +235,13 @@ def set_process_status(info):
     platform.set_mp_process_title("celeryd", info=info)
 
 
+def run_worker(**options):
+    return Worker(**options).run()
+
+
 def main():
     options = parse_options(sys.argv[1:])
-    run_worker(**vars(options))
+    return run_worker(**vars(options))
 
 if __name__ == "__main__":
     main()

+ 14 - 0
celery/tests/test_bin_celeryd.py

@@ -0,0 +1,14 @@
+import unittest
+
+from celery.bin import celeryd
+
+
+class TestWorker(unittest.TestCase):
+
+    def test_init_loader(self):
+
+        w = celeryd.Worker()
+        w.init_loader()
+        self.assertTrue(w.loader)
+        self.assertTrue(w.settings)
+

+ 6 - 3
celery/worker/__init__.py

@@ -14,6 +14,7 @@ from celery import platform
 from celery import signals
 from celery.log import setup_logger, _hijack_multiprocessing_logger
 from celery.beat import EmbeddedClockService
+from celery.utils import noop
 
 from celery.worker.pool import TaskPool
 from celery.worker.buckets import TaskBucket
@@ -103,7 +104,7 @@ class WorkController(object):
 
     def __init__(self, concurrency=None, logfile=None, loglevel=None,
             send_events=conf.SEND_EVENTS, hostname=None,
-            embed_clockservice=False):
+            ready_callback=noop, embed_clockservice=False):
 
         # Options
         self.loglevel = loglevel or self.loglevel
@@ -112,6 +113,7 @@ class WorkController(object):
         self.logger = setup_logger(loglevel, logfile)
         self.hostname = hostname or socket.gethostname()
         self.embed_clockservice = embed_clockservice
+        self.ready_callback = ready_callback
         self.send_events = send_events
 
         # Queues
@@ -137,12 +139,13 @@ class WorkController(object):
         if self.embed_clockservice:
             self.clockservice = EmbeddedClockService(logger=self.logger)
 
-        prefetch_count = concurrency * conf.CELERYD_PREFETCH_MULTIPLIER
+        prefetch_count = self.concurrency * conf.CELERYD_PREFETCH_MULTIPLIER
         self.listener = CarrotListener(self.ready_queue,
                                        self.eta_schedule,
                                        logger=self.logger,
                                        hostname=self.hostname,
-                                       send_events=send_events,
+                                       send_events=self.send_events,
+                                       init_callback=self.ready_callback,
                                        initial_prefetch_count=prefetch_count)
 
         # The order is important here;

+ 5 - 3
celery/worker/listener.py

@@ -7,7 +7,7 @@ from carrot.connection import AMQPConnectionException
 
 from celery import conf
 from celery import signals
-from celery.utils import retry_over_time
+from celery.utils import noop, retry_over_time
 from celery.worker.job import TaskWrapper, InvalidTaskError
 from celery.worker.revoke import revoked
 from celery.worker.control import ControlDispatch
@@ -45,12 +45,14 @@ class CarrotListener(object):
     """
 
     def __init__(self, ready_queue, eta_schedule, logger,
-            send_events=False, hostname=None, initial_prefetch_count=2):
+            init_callback=noop, send_events=False, hostname=None,
+            initial_prefetch_count=2):
         self.connection = None
         self.task_consumer = None
         self.ready_queue = ready_queue
         self.eta_schedule = eta_schedule
         self.send_events = send_events
+        self.init_callback = init_callback
         self.logger = logger
         self.hostname = hostname or socket.gethostname()
         self.control_dispatch = ControlDispatch(logger=logger,
@@ -68,7 +70,7 @@ class CarrotListener(object):
 
         """
 
-        signals.worker_ready.send(sender=self)
+        self.init_callback(self)
 
         while 1:
             self.reset_connection()

+ 2 - 1
testproj/settings.py

@@ -25,7 +25,8 @@ COVERAGE_EXCLUDE_MODULES = ("celery.__init__",
                             "celery.tests.*",
                             "celery.management.*",
                             "celery.contrib.*",
-                            "celery.bin.*",
+                            "celery.bin.celeryinit",
+                            "celery.bin.celerybeat",
                             "celery.utils.patch",
                             "celery.utils.compat",
                             "celery.task.rest",