diff --git a/src/django_watchfiles/__init__.py b/src/django_watchfiles/__init__.py index a0d49cc..7e7bd28 100644 --- a/src/django_watchfiles/__init__.py +++ b/src/django_watchfiles/__init__.py @@ -1,14 +1,15 @@ from __future__ import annotations -import fnmatch import threading from collections.abc import Generator +from collections.abc import Iterable +from fnmatch import fnmatch from pathlib import Path -from typing import Any from typing import Callable -import watchfiles from django.utils import autoreload +from watchfiles import Change +from watchfiles import watch class MutableWatcher: @@ -20,7 +21,7 @@ class MutableWatcher: underlying watchfiles iterator when roots are added or removed. """ - def __init__(self, filter: Callable[[watchfiles.Change, str], bool]) -> None: + def __init__(self, filter: Callable[[Change, str], bool]) -> None: self.change_event = threading.Event() self.stop_event = threading.Event() self.roots: set[Path] = set() @@ -34,10 +35,10 @@ def set_roots(self, roots: set[Path]) -> None: def stop(self) -> None: self.stop_event.set() - def __iter__(self) -> Generator[Any]: # TODO: better type + def __iter__(self) -> Generator[set[tuple[Change, str]]]: while True: self.change_event.clear() - for changes in watchfiles.watch( + for changes in watch( *self.roots, watch_filter=self.filter, stop_event=self.stop_event, @@ -53,11 +54,12 @@ def __iter__(self) -> Generator[Any]: # TODO: better type class WatchfilesReloader(autoreload.BaseReloader): def __init__(self) -> None: self.watcher = MutableWatcher(self.file_filter) + self.watched_files_set: set[Path] = set() super().__init__() - def file_filter(self, change: watchfiles.Change, filename: str) -> bool: + def file_filter(self, change: Change, filename: str) -> bool: path = Path(filename) - if path in set(self.watched_files(include_globs=False)): + if path in self.watched_files_set: return True for directory, globs in self.directory_globs.items(): try: @@ -65,20 +67,26 @@ def file_filter(self, change: watchfiles.Change, filename: str) -> bool: except ValueError: pass else: + relative_path_str = str(relative_path) for glob in globs: - if fnmatch.fnmatch(str(relative_path), glob): + if fnmatch(relative_path_str, glob): return True return False - def watched_roots(self, watched_files: list[Path]) -> frozenset[Path]: + def watched_roots(self, watched_files: Iterable[Path]) -> frozenset[Path]: + # Adapted from WatchmanReloader extra_directories = self.directory_globs.keys() watched_file_dirs = {f.parent for f in watched_files} sys_paths = set(autoreload.sys_path_directories()) return frozenset((*extra_directories, *watched_file_dirs, *sys_paths)) def tick(self) -> Generator[None]: - watched_files = list(self.watched_files(include_globs=False)) - roots = set(autoreload.common_roots(self.watched_roots(watched_files))) + self.watched_files_set = set(self.watched_files(include_globs=False)) + roots = set( + autoreload.common_roots( + self.watched_roots(self.watched_files_set), + ) + ) self.watcher.set_roots(roots) for changes in self.watcher: diff --git a/tests/test_django_watchfiles.py b/tests/test_django_watchfiles.py index 718a35f..a5293a0 100644 --- a/tests/test_django_watchfiles.py +++ b/tests/test_django_watchfiles.py @@ -5,6 +5,7 @@ from pathlib import Path from django.utils import autoreload +from watchfiles import Change from django_watchfiles import MutableWatcher from django_watchfiles import WatchfilesReloader @@ -76,6 +77,51 @@ def test_iter_respects_change_event(self): assert len(changes) == 0 +class WatchfilesReloaderTests(SimpleTestCase): + def setUp(self): + temp_dir = self.enterContext(tempfile.TemporaryDirectory()) + self.temp_path = Path(temp_dir) + + self.reloader = WatchfilesReloader() + + def test_file_filter_watched_file(self): + test_txt = self.temp_path / "test.txt" + test_txt.touch() + self.reloader.watched_files_set = {test_txt} + + result = self.reloader.file_filter(Change.modified, str(test_txt)) + + assert result is True + + def test_file_filter_unwatched_file(self): + test_txt = self.temp_path / "test.txt" + test_txt.touch() + + result = self.reloader.file_filter(Change.modified, str(test_txt)) + + assert result is False + + def test_file_filter_glob_matched(self): + self.reloader.watch_dir(self.temp_path, "*.txt") + + result = self.reloader.file_filter( + Change.modified, str(self.temp_path / "test.txt") + ) + + assert result is True + + def test_file_filter_glob_relative_path_impossible(self): + temp_dir2 = self.enterContext(tempfile.TemporaryDirectory()) + + self.reloader.watch_dir(Path(temp_dir2), "*.txt") + + result = self.reloader.file_filter( + Change.modified, str(self.temp_path / "test.txt") + ) + + assert result is False + + class ReplacedGetReloaderTests(SimpleTestCase): def test_replaced_get_reloader(self): reloader = autoreload.get_reloader()