Browse Source

Fixes autoreloader to work on non-top-level modules. Closes #823

Ask Solem 12 years ago
parent
commit
692c134633
2 changed files with 9 additions and 13 deletions
  1. 1 5
      celery/tests/worker/test_autoreload.py
  2. 8 8
      celery/worker/autoreload.py

+ 1 - 5
celery/tests/worker/test_autoreload.py

@@ -237,7 +237,7 @@ class test_Autoreloader(AppCase):
         mm = x._maybe_modified = Mock(0)
         mm.return_value = True
         x._reload = Mock()
-        x._module_name = Mock()
+        x.file_to_module[__name__] = __name__
         x.on_change([__name__])
         self.assertTrue(x._reload.called)
         mm.return_value = False
@@ -255,7 +255,3 @@ class test_Autoreloader(AppCase):
         x._monitor = Mock()
         x.stop()
         x._monitor.stop.assert_called_with()
-
-    def test_module_name(self):
-        x = Autoreloader(Mock(), modules=[__name__])
-        self.assertEqual(x._module_name('foo/bar/baz.py'), 'baz')

+ 8 - 8
celery/worker/autoreload.py

@@ -225,12 +225,16 @@ class Autoreloader(bgThread):
         self.options = options
         self._monitor = None
         self._hashes = None
+        self.file_to_module = {}
 
     def on_init(self):
-        files = [module_file(sys.modules[m]) for m in self.modules]
-        self._hashes = dict((f, file_hash(f)) for f in files)
-        self._monitor = self.Monitor(files, self.on_change,
+        files = self.file_to_module
+        files.update(dict((module_file(sys.modules[m]), m)
+                        for m in self.modules))
+
+        self._monitor = self.Monitor(files.keys(), self.on_change,
                 shutdown_event=self._is_shutdown, **self.options)
+        self._hashes = dict([(f, file_hash(f)) for f in files])
 
     def on_poll_init(self, hub):
         if self._monitor is None:
@@ -259,7 +263,7 @@ class Autoreloader(bgThread):
     def on_change(self, files):
         modified = [f for f in files if self._maybe_modified(f)]
         if modified:
-            names = [self._module_name(module) for module in modified]
+            names = [self.file_to_module[module] for module in modified]
             logger.info('Detected modified modules: %r', names)
             self._reload(names)
 
@@ -269,7 +273,3 @@ class Autoreloader(bgThread):
     def stop(self):
         if self._monitor:
             self._monitor.stop()
-
-    @staticmethod
-    def _module_name(path):
-        return os.path.splitext(os.path.basename(path))[0]