Explorar el Código

99% coverage for .worker.autoreload

Ask Solem hace 13 años
padre
commit
ed29fa3a5d

+ 5 - 21
celery/tests/test_concurrency/test_threads.py

@@ -1,15 +1,11 @@
 from __future__ import absolute_import
 from __future__ import with_statement
 
-import sys
-
-from contextlib import contextmanager
 from mock import Mock
-from types import ModuleType
 
 from celery.concurrency.threads import NullDict, TaskPool, apply_target
 
-from celery.tests.utils import Case, mask_modules
+from celery.tests.utils import Case, mask_modules, mock_module
 
 
 class test_NullDict(Case):
@@ -21,18 +17,6 @@ class test_NullDict(Case):
             x["foo"]
 
 
-@contextmanager
-def threadpool_module():
-
-    prev = sys.modules.get("threadpool")
-    tp = sys.modules["threadpool"] = ModuleType("threadpool")
-    tp.WorkRequest = Mock()
-    tp.ThreadPool = Mock()
-    yield tp
-    if prev:
-        sys.modules["threadpool"] = prev
-
-
 class test_TaskPool(Case):
 
     def test_without_threadpool(self):
@@ -42,27 +26,27 @@ class test_TaskPool(Case):
                 TaskPool()
 
     def test_with_threadpool(self):
-        with threadpool_module():
+        with mock_module("threadpool"):
             x = TaskPool()
             self.assertTrue(x.ThreadPool)
             self.assertTrue(x.WorkRequest)
 
     def test_on_start(self):
-        with threadpool_module():
+        with mock_module("threadpool"):
             x = TaskPool()
             x.on_start()
             self.assertTrue(x._pool)
             self.assertIsInstance(x._pool.workRequests, NullDict)
 
     def test_on_stop(self):
-        with threadpool_module():
+        with mock_module("threadpool"):
             x = TaskPool()
             x.on_start()
             x.on_stop()
             x._pool.dismissWorkers.assert_called_with(x.limit, do_join=True)
 
     def test_on_apply(self):
-        with threadpool_module():
+        with mock_module("threadpool"):
             x = TaskPool()
             x.on_start()
             callback = Mock()

+ 243 - 0
celery/tests/test_worker/test_worker_autoreload.py

@@ -0,0 +1,243 @@
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import errno
+import select
+import sys
+
+from mock import Mock, patch
+from time import time
+
+from celery.worker import autoreload
+from celery.worker.autoreload import (
+    WorkerComponent,
+    file_hash,
+    BaseMonitor,
+    StatMonitor,
+    KQueueMonitor,
+    InotifyMonitor,
+    default_implementation,
+    Autoreloader,
+)
+
+from celery.tests.utils import AppCase, Case, WhateverIO
+
+
+class test_WorkerComponent(AppCase):
+
+    def test_create(self):
+        w = Mock()
+        x = WorkerComponent(w)
+        x.instantiate = Mock()
+        r = x.create(w)
+        x.instantiate.assert_called_with(w.autoreloader_cls,
+                                         controller=w)
+        self.assertIs(r, w.autoreloader)
+
+
+class test_file_hash(Case):
+
+    @patch("__builtin__.open")
+    def test_hash(self, open_):
+        context = open_.return_value = Mock()
+        context.__enter__ = Mock()
+        context.__exit__ = Mock()
+        a = context.__enter__.return_value = WhateverIO()
+        a.write("the quick brown fox\n")
+        a.seek(0)
+        A = file_hash("foo")
+        b = context.__enter__.return_value = WhateverIO()
+        b.write("the quick brown bar\n")
+        b.seek(0)
+        B = file_hash("bar")
+        self.assertNotEqual(A, B)
+
+
+class test_BaseMonitor(Case):
+
+    def test_start_stop_on_change(self):
+        x = BaseMonitor(["a", "b"])
+
+        with self.assertRaises(NotImplementedError):
+            x.start()
+        x.stop()
+        x.on_change([])
+        x._on_change = Mock()
+        x.on_change("foo")
+        x._on_change.assert_called_with("foo")
+
+
+class test_StatMonitor(Case):
+
+    @patch("os.stat")
+    def test_start(self, stat):
+
+        class st(object):
+            st_mtime = time()
+        stat.return_value = st()
+        x = StatMonitor(["a", "b"])
+        calls = [0]
+
+        def on_is_set():
+            calls[0] += 1
+            if calls[0] > 2:
+                return True
+            return False
+        x.shutdown_event = Mock()
+        x.shutdown_event.is_set.side_effect = on_is_set
+
+        x.start()
+        calls[0] = 0
+        stat.side_effect = OSError()
+        x.start()
+
+
+class test_KQueueMontior(Case):
+
+    @patch("select.kqueue", create=True)
+    @patch("os.close")
+    def test_stop(self, close, kqueue):
+        x = KQueueMonitor(["a", "b"])
+        x._kq = Mock()
+        x.filemap["a"] = 10
+        x.stop()
+        x._kq.close.assert_called_with()
+        close.assert_called_with(10)
+
+        close.side_effect = OSError()
+        close.side_effect.errno = errno.EBADF
+        x.stop()
+
+    @patch("select.kqueue", create=True)
+    @patch("select.kevent", create=True)
+    @patch("os.open")
+    def test_start(self, osopen, kevent, kqueue):
+        prev = {}
+        flags = ["KQ_FILTER_VNODE", "KQ_EV_ADD", "KQ_EV_ENABLE",
+                 "KQ_EV_CLEAR", "KW_NOTE_WRITE", "KQ_NOTE_EXTEND"]
+        for i, flag in enumerate(flags):
+            prev[flag] = getattr(select, flag, None)
+            if not prev[flag]:
+                setattr(select, flag, i)
+        try:
+            kq = kqueue.return_value = Mock()
+
+            class ev(object):
+                ident = 10
+            kq.control.return_value = [ev()]
+            x = KQueueMonitor(["a"])
+            osopen.return_value = 10
+            calls = [0]
+
+            def on_is_set():
+                calls[0] += 1
+                if calls[0] > 2:
+                    return True
+                return False
+            x.shutdown_event = Mock()
+            x.shutdown_event.is_set.side_effect = on_is_set
+            x.start()
+        finally:
+            for flag in flags:
+                if not prev[flag]:
+                    delattr(select, flag)
+
+
+class test_InotifyMonitor(Case):
+
+    @patch("celery.worker.autoreload.pyinotify")
+    def test_start(self, inotify):
+            x = InotifyMonitor(["a"])
+            inotify.IN_MODIFY = 1
+            inotify.IN_ATTRIB = 2
+            x.start()
+
+            inotify.WatchManager.side_effect = ValueError()
+            with self.assertRaises(ValueError):
+                x.start()
+            x.stop()
+
+            x._on_change = None
+            x.process_(Mock())
+            x._on_change = Mock()
+            x.process_(Mock())
+            self.assertTrue(x._on_change.called)
+
+
+class test_default_implementation(Case):
+
+    @patch("select.kqueue", create=True)
+    def test_kqueue(self, kqueue):
+        self.assertEqual(default_implementation(), "kqueue")
+
+    @patch("celery.worker.autoreload.pyinotify")
+    def test_inotify(self, pyinotify):
+        kq = getattr(select, "kqueue", None)
+        delattr(select, "kqueue")
+        platform, sys.platform = sys.platform, "linux"
+        try:
+            self.assertEqual(default_implementation(), "inotify")
+            ino, autoreload.pyinotify = autoreload.pyinotify, None
+            try:
+                self.assertEqual(default_implementation(), "stat")
+            finally:
+                autoreload.pyinotify = ino
+        finally:
+            if kq:
+                select.kqueue = kq
+            sys.platform = platform
+
+
+class test_Autoreloader(AppCase):
+
+    @patch("celery.worker.autoreload.file_hash")
+    def test_start(self, fhash):
+        x = Autoreloader(Mock(), modules=[__name__])
+        x.Monitor = Mock()
+        mon = x.Monitor.return_value = Mock()
+        mon.start.side_effect = OSError()
+        mon.start.side_effect.errno = errno.EINTR
+        x.body()
+        mon.start.side_effect.errno = errno.ENOENT
+        with self.assertRaises(OSError):
+            x.body()
+        mon.start.side_effect = None
+        x.body()
+
+    @patch("celery.worker.autoreload.file_hash")
+    def test_maybe_modified(self, fhash):
+        fhash.return_value = "abcd"
+        x = Autoreloader(Mock(), modules=[__name__])
+        x._hashes = {}
+        x._hashes[__name__] = "dcba"
+        self.assertTrue(x._maybe_modified(__name__))
+        x._hashes[__name__] = "abcd"
+        self.assertFalse(x._maybe_modified(__name__))
+
+    def test_on_change(self):
+        x = Autoreloader(Mock(), modules=[__name__])
+        mm = x._maybe_modified = Mock(0)
+        mm.return_value = True
+        x._reload = Mock()
+        x._module_name = Mock()
+        x.on_change([__name__])
+        self.assertTrue(x._reload.called)
+        mm.return_value = False
+        x.on_change([__name__])
+
+    def test_reload(self):
+        x = Autoreloader(Mock(), modules=[__name__])
+        x._reload([__name__])
+        x.controller.reload.assert_called_with([__name__], reload=True)
+
+    def test_stop(self):
+        x = Autoreloader(Mock(), modules=[__name__])
+        x._monitor = None
+        x.stop()
+        x._monitor = Mock()
+        x.stop()
+        x._monitor.stop.assert_called_with()
+
+    def test_module_name(self):
+        x = Autoreloader(Mock(), modules=[__name__])
+        self.assertEqual(x._module_name("foo/bar/baz.py"), "baz")

+ 19 - 3
celery/tests/utils.py

@@ -21,8 +21,9 @@ try:
 except ImportError:  # py3k
     import builtins  # noqa
 
-from functools import partial, wraps
 from contextlib import contextmanager
+from functools import partial, wraps
+from types import ModuleType
 
 import mock
 from nose import SkipTest
@@ -450,8 +451,6 @@ def reset_modules(*modules):
 
 @contextmanager
 def patch_modules(*modules):
-    from types import ModuleType
-
     prev = {}
     for mod in modules:
         prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
@@ -478,3 +477,20 @@ class create_pidlock(object):
                 pass
 
         return Object()
+
+
+@contextmanager
+def mock_module(name):
+
+    prev = sys.modules.get(name)
+
+    class MockModule(ModuleType):
+
+        def __getattr__(self, attr):
+            setattr(self, attr, Mock())
+            return ModuleType.__getattribute__(self, attr)
+
+    mod = sys.modules[name] = MockModule(name)
+    yield mod
+    if prev:
+        sys.modules[name] = prev

+ 2 - 2
celery/worker/autoreload.py

@@ -27,7 +27,7 @@ from .abstract import StartStopComponent
 try:
     import pyinotify
     _ProcessEvent = pyinotify.ProcessEvent
-except ImportError:
+except ImportError:         # pragma: no cover
     pyinotify = None        # noqa
     _ProcessEvent = object  # noqa
 
@@ -137,7 +137,7 @@ class KQueueMonitor(BaseMonitor):
     def stop(self):
         self._kq.close()
         for fd in filter(None, self.filemap.values()):
-            with ignore_EBADF:
+            with ignore_EBADF():
                 os.close(fd)
             self.filemap[fd] = None
         self.filemap.clear()

+ 0 - 1
setup.cfg

@@ -14,7 +14,6 @@ cover3-exclude = celery
                  celery.backends.mongodb
                  celery.backends.cassandra
                  celery.events.cursesmon
-                 celery.worker.autoreload
 
 [build_sphinx]
 source-dir = docs/