Browse Source

Merge branch 'code-reload'

Ask Solem 13 years ago
parent
commit
3b9ea2a840

+ 4 - 0
celery/concurrency/base.py

@@ -59,6 +59,10 @@ class BasePool(object):
         raise NotImplementedError(
                 "%s does not implement kill_job" % (self.__class__, ))
 
+    def restart(self):
+        raise NotImplementedError(
+                "%s does not implement restart" % (self.__class__, ))
+
     def stop(self):
         self._state = self.CLOSE
         self.on_stop()

+ 3 - 0
celery/concurrency/processes/__init__.py

@@ -65,6 +65,9 @@ class TaskPool(BasePool):
     def shrink(self, n=1):
         return self._pool.shrink(n)
 
+    def restart(self):
+        self._pool.restart()
+
     def _get_info(self):
         return {"max-concurrency": self.limit,
                 "processes": [p.pid for p in self._pool._pool],

+ 17 - 3
celery/concurrency/processes/pool.py

@@ -24,7 +24,7 @@ import signal
 import warnings
 import logging
 
-from multiprocessing import Process, cpu_count, TimeoutError
+from multiprocessing import Process, cpu_count, TimeoutError, Event
 from multiprocessing import util
 from multiprocessing.util import Finalize, debug
 
@@ -134,7 +134,8 @@ def soft_timeout_sighandler(signum, frame):
 #
 
 
-def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
+def worker(inqueue, outqueue, initializer=None, initargs=(),
+           maxtasks=None, sentinel=None):
     # Re-init logging system.
     # Workaround for http://bugs.python.org/issue6721#msg140215
     # Python logging module uses RLock() objects which are broken after
@@ -177,6 +178,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
 
     completed = 0
     while maxtasks is None or (maxtasks and completed < maxtasks):
+        if sentinel is not None and sentinel.is_set():
+            debug('worker got sentinel -- exiting')
+            break
+
         try:
             ready, task = poll(1.0)
             if not ready:
@@ -543,6 +548,7 @@ class Pool(object):
             raise TypeError('initializer must be a callable')
 
         self._pool = []
+        self._poolctrl = {}
         for i in range(processes):
             self._create_worker_process()
 
@@ -580,16 +586,19 @@ class Pool(object):
             )
 
     def _create_worker_process(self):
+        sentinel = Event()
         w = self.Process(
             target=worker,
             args=(self._inqueue, self._outqueue,
                     self._initializer, self._initargs,
-                    self._maxtasksperchild),
+                    self._maxtasksperchild,
+                    sentinel),
             )
         self._pool.append(w)
         w.name = w.name.replace('Process', 'PoolWorker')
         w.daemon = True
         w.start()
+        self._poolctrl[w.pid] = sentinel
         return w
 
     def _join_exited_workers(self, shutdown=False):
@@ -626,6 +635,7 @@ class Pool(object):
                 debug('Supervisor: worked %d joined' % i)
                 cleaned.append(worker.pid)
                 del self._pool[i]
+                del self._poolctrl[worker.pid]
         if cleaned:
             for job in self._cache.values():
                 for worker_pid in job.worker_pids():
@@ -874,6 +884,10 @@ class Pool(object):
             debug('joining worker %s/%s (%r)' % (i, len(self._pool), p, ))
             p.join()
 
+    def restart(self):
+        for e in self._poolctrl.itervalues():
+            e.set()
+
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
         # task_handler may be blocked trying to put items on inqueue

+ 6 - 0
celery/tests/test_concurrency/__init__.py

@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+from __future__ import with_statement
 
 import os
 
@@ -63,3 +64,8 @@ class test_BasePool(unittest.TestCase):
         self.assertFalse(p.active)
         p._state = p.RUN
         self.assertTrue(p.active)
+
+    def test_restart(self):
+        p = BasePool(10)
+        with self.assertRaises(NotImplementedError):
+            p.restart()

+ 13 - 0
celery/tests/test_concurrency/test_concurrency_processes.py

@@ -3,6 +3,7 @@ from __future__ import with_statement
 
 import signal
 import sys
+import time
 
 from itertools import cycle
 
@@ -226,3 +227,15 @@ class test_TaskPool(unittest.TestCase):
         self.assertEqual(info["max-concurrency"], pool.limit)
         self.assertIsNone(info["max-tasks-per-child"])
         self.assertEqual(info["timeouts"], (5, 10))
+
+    def test_restart(self):
+        def get_pids(pool):
+            return set([p.pid for p in pool._pool._pool])
+
+        tp = self.TaskPool(5)
+        time.sleep(0.5)
+        tp.start()
+        pids = get_pids(tp)
+        tp.restart()
+        time.sleep(0.5)
+        self.assertEqual(pids, get_pids(tp))

+ 60 - 1
celery/tests/test_worker/test_worker_control.py

@@ -1,12 +1,13 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
+import sys
 import socket
 
 from datetime import datetime, timedelta
 
 from kombu import pidbox
-from mock import Mock
+from mock import Mock, patch
 
 from celery import current_app
 from celery.datastructures import AttributeDict
@@ -376,3 +377,61 @@ class test_ControlPanel(unittest.TestCase):
                                              "routing_key": "x"})
         self.assertEqual(r, "pong")
         self.assertDictEqual(replies[0], {panel.hostname: "pong"})
+
+    def test_pool_restart(self):
+        consumer = Consumer()
+        consumer.pool.restart = Mock()
+        panel = self.create_panel(consumer=consumer)
+        panel.app = self.app
+        _import = panel.app.loader.import_from_cwd = Mock()
+        _reload = Mock()
+
+        panel.handle("pool_restart", {"reload": _reload})
+        self.assertTrue(consumer.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertFalse(_import.called)
+
+    def test_pool_restart_import_modules(self):
+        consumer = Consumer()
+        consumer.pool.restart = Mock()
+        panel = self.create_panel(consumer=consumer)
+        panel.app = self.app
+        _import = panel.app.loader.import_from_cwd = Mock()
+        _reload = Mock()
+
+        panel.handle("pool_restart", {"imports": ["foo", "bar"],
+                                      "reload": _reload})
+
+        self.assertTrue(consumer.pool.restart.called)
+        self.assertFalse(_reload.called)
+        self.assertEqual([(("foo",), {}), (("bar",), {})],
+                          _import.call_args_list)
+
+    def test_pool_restart_relaod_modules(self):
+        consumer = Consumer()
+        consumer.pool.restart = Mock()
+        panel = self.create_panel(consumer=consumer)
+        panel.app = self.app
+        _import = panel.app.loader.import_from_cwd = Mock()
+        _reload = Mock()
+
+        with patch.dict(sys.modules, {"foo": None}):
+            panel.handle("pool_restart", {"imports": ["foo"],
+                                          "reload_imports": False,
+                                          "reload": _reload})
+
+            self.assertTrue(consumer.pool.restart.called)
+            self.assertFalse(_reload.called)
+            self.assertFalse(_import.called)
+
+            _import.reset_mock()
+            _reload.reset_mock()
+            consumer.pool.restart.reset_mock()
+
+            panel.handle("pool_restart", {"imports": ["foo"],
+                                          "reload_imports": True,
+                                          "reload": _reload})
+
+            self.assertTrue(consumer.pool.restart.called)
+            self.assertTrue(_reload.called)
+            self.assertFalse(_import.called)

+ 179 - 0
celery/worker/autoreload.py

@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+"""
+    celery.worker.autoreload
+    ~~~~~~~~~~~~~~~~~~~~~~~~
+
+    This module implements automatic module reloading
+"""
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import os
+import sys
+import time
+import select
+import hashlib
+
+from collections import defaultdict
+
+from .. import current_app
+
+
+def file_hash(filename, algorithm='md5'):
+    hobj = hashlib.new(algorithm)
+    with open(filename, 'rb') as f:
+        for chunk in iter(lambda: f.read(2 ** 20), ''):
+            hobj.update(chunk)
+    return hobj.digest()
+
+
+class StatMonitor(object):
+    """File change monitor based on `stat` system call"""
+    def __init__(self, files, on_change=None, interval=0.5):
+        self._files = files
+        self._interval = interval
+        self._on_change = on_change
+        self._modify_times = defaultdict(int)
+
+    def start(self):
+        while True:
+            modified = {}
+            for m in self._files:
+                mt = self._mtime(m)
+                if mt is None:
+                    break
+                if self._modify_times[m] != mt:
+                    modified[m] = mt
+            else:
+                if modified:
+                    self.on_change(modified.keys())
+                    self._modify_times.update(modified)
+
+            time.sleep(self._interval)
+
+    def on_change(self, modified):
+        if self._on_change:
+            return self._on_change(modified)
+
+    @classmethod
+    def _mtime(cls, path):
+        try:
+            return os.stat(path).st_mtime
+        except:
+            return
+
+
+class KQueueMonitor(object):
+    """File change monitor based on BSD kernel event notifications"""
+    def __init__(self, files, on_change=None):
+        assert hasattr(select, 'kqueue')
+        self._files = dict([(f, None) for f in files])
+        self._on_change = on_change
+
+    def start(self):
+        try:
+            self._kq = select.kqueue()
+            kevents = []
+            for f in self._files:
+                self._files[f] = fd = os.open(f, os.O_RDONLY)
+
+                ev = select.kevent(fd,
+                        filter=select.KQ_FILTER_VNODE,
+                        flags=select.KQ_EV_ADD |
+                              select.KQ_EV_ENABLE |
+                              select.KQ_EV_CLEAR,
+                        fflags=select.KQ_NOTE_WRITE |
+                               select.KQ_NOTE_EXTEND)
+                kevents.append(ev)
+
+            events = self._kq.control(kevents, 0)
+            while True:
+                events = self._kq.control(kevents, 1)
+                fds = [e.ident for e in events]
+                modified = [k for k, v in self._files.iteritems()
+                                            if v in fds]
+                self.on_change(modified)
+        finally:
+            self.close()
+
+    def close(self):
+        self._kq.close()
+        for f in self._files:
+            if self._files[f] is not None:
+                os.close(self._files[f])
+                self._files[f] = None
+
+    def on_change(self, modified):
+        if self._on_change:
+            return self._on_change(modified)
+
+
+try:
+    import pyinotify
+except ImportError:
+    pyinotify = None    # noqa
+
+
+class InotifyMonitor(pyinotify and pyinotify.ProcessEvent or object):
+    """File change monitor based on  Linux kernel `inotify` subsystem"""
+    def __init__(self, modules, on_change=None):
+        assert pyinotify
+        self._modules = modules
+        self._on_change = on_change
+
+    def start(self):
+        try:
+            self._wm = pyinotify.WatchManager()
+            self._notifier = pyinotify.Notifier(self._wm)
+            for m in self._modules:
+                self._wm.add_watch(m, pyinotify.IN_MODIFY)
+            self._notifier.loop()
+        finally:
+            self.close()
+
+    def close(self):
+        self._notifier.stop()
+        self._wm.close()
+
+    def process_IN_MODIFY(self, event):
+        self.on_change(event.pathname)
+
+    def on_change(self, modified):
+        if self._on_change:
+            self._on_change(modified)
+
+
+if hasattr(select, 'kqueue'):
+    _monitor_cls = KQueueMonitor
+elif sys.platform.startswith('linux') and pyinotify:
+    _monitor_cls = InotifyMonitor
+else:
+    _monitor_cls = StatMonitor
+
+
+class AutoReloader(object):
+    """Tracks changes in modules and fires reload commands"""
+    def __init__(self, modules, monitor_cls=_monitor_cls, *args, **kwargs):
+        self._monitor = monitor_cls(modules, self.on_change, *args, **kwargs)
+        self._hashes = dict([(f, file_hash(f)) for f in modules])
+
+    def start(self):
+        self._monitor.start()
+
+    def on_change(self, files):
+        modified = []
+        for f in files:
+            fhash = file_hash(f)
+            if fhash != self._hashes[f]:
+                modified.append(f)
+                self._hashes[f] = fhash
+        if modified:
+            self._reload(map(self._module_name, modified))
+
+    def _reload(self, modules):
+        current_app.control.broadcast("pool_restart",
+                arguments={"imports": modules, "reload_modules": True})
+
+    @classmethod
+    def _module_name(cls, path):
+        return os.path.splitext(os.path.basename(path))[0]

+ 16 - 0
celery/worker/control.py

@@ -238,6 +238,22 @@ def pool_shrink(panel, n=1, **kwargs):
     return {"ok": "terminated worker processes"}
 
 
+@Panel.register
+def pool_restart(panel, imports=None, reload_imports=False,
+                 reload=reload, **kwargs):
+    imports = set(imports or [])
+    for m in imports:
+        if m not in sys.modules:
+            panel.app.loader.import_from_cwd(m)
+            panel.logger.debug("imported %s module" % m)
+        elif reload_imports:
+            reload(sys.modules[m])
+            panel.logger.debug("reloaded %s module" % m)
+
+    panel.consumer.pool.restart()
+    return {"ok": "started restarting worker processes"}
+
+
 @Panel.register
 def autoscale(panel, max=None, min=None):
     autoscaler = panel.consumer.controller.autoscaler

+ 50 - 0
docs/userguide/workers.rst

@@ -354,6 +354,56 @@ a worker using :program:`celeryev`/:program:`celerymon`.
     >>> broadcast("enable_events")
     >>> broadcast("disable_events")
 
+Adding/Reloading modules
+------------------------
+
+.. versionadded:: 2.5
+
+The remote control command ``pool_restart`` sends restart requests to
+the worker's child processes.  It is particularly useful for forcing
+the worker to import new modules, or for reloading already imported
+modules.  This command does not interrupt executing tasks.
+
+Example
+~~~~~~~
+
+Running the following command will result in the `foo` and `bar` modules
+being imported by the worker processes:
+
+.. code-block:: python
+
+    >>> from celery.task.control import broadcast
+    >>> broadcast("pool_restart", arguments={"imports": ["foo", "bar"]})
+
+If you want to reload all modules you can use:
+
+.. code-block:: python
+
+    >>> from celery.task.control import broadcast
+    >>> from celery import current_app
+    >>> modules = current_app.conf.CELERY_IMPORTS
+    >>> broadcast("pool_restart",
+                  arguments={"imports": modules, "reload_imports": True})
+
+The `imports` argument is a list of modules to import or reload.  The
+`reload_imports` argument specifies whether already imported modules
+should be reloaded, and defaults to `False`.  The `pool_restart` command
+uses the built-in :func:`reload` function to reload modules, but you can
+provide a custom reloader as well (via the `reload` argument).
+
+.. note::
+
+    Module reloading comes with caveats that are documented in :func:`reload`.
+    Make sure your modules are suitable for reloading.
+
+.. seealso::
+
+    http://pyunit.sourceforge.net/notes/reloading.html
+
+    http://www.indelible.org/ink/python-reloading/
+
+    http://docs.python.org/library/functions.html#reload
+
 .. _worker-custom-control-commands:
 
 Writing your own remote control commands