Browse Source

Dynamic module loading

mher 14 years ago
parent
commit
97369467b0

+ 4 - 0
celery/concurrency/base.py

@@ -52,6 +52,10 @@ class BasePool(object):
         raise NotImplementedError(
                 "%s does not implement kill_job" % (self.__class__, ))
 
 
+    def restart(self):
+        raise NotImplementedError(
+                "%s does not implement restart" % (self.__class__, ))
+
     def stop(self):
     def stop(self):
         self._state = self.CLOSE
         self._state = self.CLOSE
         self.on_stop()
         self.on_stop()

+ 3 - 0
celery/concurrency/processes/__init__.py

@@ -68,6 +68,9 @@ class TaskPool(BasePool):
     def shrink(self, n=1):
     def shrink(self, n=1):
         return self._pool.shrink(n)
         return self._pool.shrink(n)
 
 
+    def restart(self):
+        self._pool.restart()
+
     def _get_info(self):
     def _get_info(self):
         return {"max-concurrency": self.limit,
         return {"max-concurrency": self.limit,
                 "processes": [p.pid for p in self._pool._pool],
                 "processes": [p.pid for p in self._pool._pool],

+ 17 - 3
celery/concurrency/processes/pool.py

@@ -23,7 +23,7 @@ import time
 import signal
 import signal
 import warnings
 import warnings
 
 
-from multiprocessing import Process, cpu_count, TimeoutError
+from multiprocessing import Process, cpu_count, TimeoutError, Event
 from multiprocessing import util
 from multiprocessing import util
 from multiprocessing.util import Finalize, debug
 from multiprocessing.util import Finalize, debug
 
 
@@ -118,7 +118,8 @@ def soft_timeout_sighandler(signum, frame):
 #
 #
 
 
 
 
-def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
+def worker(inqueue, outqueue, initializer=None, initargs=(),
+           maxtasks=None, sentinel=None):
     pid = os.getpid()
     pid = os.getpid()
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
     put = outqueue.put
@@ -150,6 +151,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
 
 
     completed = 0
     completed = 0
     while maxtasks is None or (maxtasks and completed < maxtasks):
     while maxtasks is None or (maxtasks and completed < maxtasks):
+        if sentinel is not None and sentinel.is_set():
+            debug('worker got sentinel -- exiting')
+            break
+
         try:
         try:
             ready, task = poll(1.0)
             ready, task = poll(1.0)
             if not ready:
             if not ready:
@@ -509,6 +514,7 @@ class Pool(object):
             raise TypeError('initializer must be a callable')
             raise TypeError('initializer must be a callable')
 
 
         self._pool = []
         self._pool = []
+        self._poolctrl = {}
         for i in range(processes):
         for i in range(processes):
             self._create_worker_process()
             self._create_worker_process()
 
 
@@ -546,16 +552,19 @@ class Pool(object):
             )
             )
 
 
     def _create_worker_process(self):
     def _create_worker_process(self):
+        sentinel = Event()
         w = self.Process(
         w = self.Process(
             target=worker,
             target=worker,
             args=(self._inqueue, self._outqueue,
             args=(self._inqueue, self._outqueue,
                     self._initializer, self._initargs,
                     self._initializer, self._initargs,
-                    self._maxtasksperchild),
+                    self._maxtasksperchild,
+                    sentinel),
             )
             )
         self._pool.append(w)
         self._pool.append(w)
         w.name = w.name.replace('Process', 'PoolWorker')
         w.name = w.name.replace('Process', 'PoolWorker')
         w.daemon = True
         w.daemon = True
         w.start()
         w.start()
+        self._poolctrl[w.pid] = sentinel
         return w
         return w
 
 
     def _join_exited_workers(self, shutdown=False, lost_worker_timeout=10.0):
     def _join_exited_workers(self, shutdown=False, lost_worker_timeout=10.0):
@@ -587,6 +596,7 @@ class Pool(object):
                 debug('Supervisor: worked %d joined' % i)
                 debug('Supervisor: worked %d joined' % i)
                 cleaned.append(worker.pid)
                 cleaned.append(worker.pid)
                 del self._pool[i]
                 del self._pool[i]
+                del self._poolctrl[worker.pid]
         if cleaned:
         if cleaned:
             for job in self._cache.values():
             for job in self._cache.values():
                 for worker_pid in job.worker_pids():
                 for worker_pid in job.worker_pids():
@@ -830,6 +840,10 @@ class Pool(object):
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             p.join()
             p.join()
 
 
+    def restart(self):
+        for e in self._poolctrl.itervalues():
+            e.set()
+
     @staticmethod
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
     def _help_stuff_finish(inqueue, task_handler, size):
         # task_handler may be blocked trying to put items on inqueue
         # task_handler may be blocked trying to put items on inqueue

+ 18 - 0
celery/worker/control/builtins.py

@@ -207,6 +207,24 @@ def pool_shrink(panel, n=1, **kwargs):
     return {"ok": "terminated worker processes"}
     return {"ok": "terminated worker processes"}
 
 
 
 
+@Panel.register
def pool_restart(panel, imports=None, reload_imports=False, **kwargs):
    """Restart the consumer pool's worker processes, optionally
    (re)importing task modules first.

    :param panel: control panel instance (provides ``app``, ``logger``
        and ``consumer``).
    :keyword imports: module names to import; names already configured
        in ``CELERY_IMPORTS`` are reloaded instead.
    :keyword reload_imports: if true, reload *every* configured module,
        not only the ones requested in ``imports``.

    """
    import sys  # sys.modules is used below; sys may not be imported at module level.

    # NOTE: a ``None`` default replaces the original mutable default
    # argument (``imports=[]``), which is shared across calls.
    imports = set(imports or [])
    celery_imports = set(panel.app.conf.CELERY_IMPORTS)

    # Requested modules that are not configured are imported fresh.
    for m in imports - celery_imports:
        panel.app.loader.import_from_cwd(m)
        panel.logger.debug("imported %s module" % m)

    # Reload either all configured modules or just the requested ones.
    # (``reload`` is the Python 2 builtin; use importlib.reload on Python 3.)
    to_reload = celery_imports if reload_imports else imports & celery_imports
    for m in to_reload:
        reload(sys.modules[m])
        panel.logger.debug("reloaded %s module" % m)

    panel.consumer.pool.restart()
    return {"ok": "started restarting worker processes"}
+
 @Panel.register
 @Panel.register
 def shutdown(panel, **kwargs):
 def shutdown(panel, **kwargs):
     panel.logger.warning("Got shutdown from remote.")
     panel.logger.warning("Got shutdown from remote.")