Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify file_filter and test it #123

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions src/django_watchfiles/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import fnmatch
import threading
from collections.abc import Generator
from collections.abc import Iterable
from fnmatch import fnmatch
from pathlib import Path
from typing import Any
from typing import Callable

import watchfiles
from django.utils import autoreload
from watchfiles import Change
from watchfiles import watch


class MutableWatcher:
Expand All @@ -20,7 +21,7 @@ class MutableWatcher:
underlying watchfiles iterator when roots are added or removed.
"""

def __init__(self, filter: Callable[[watchfiles.Change, str], bool]) -> None:
def __init__(self, filter: Callable[[Change, str], bool]) -> None:
self.change_event = threading.Event()
self.stop_event = threading.Event()
self.roots: set[Path] = set()
Expand All @@ -34,10 +35,10 @@ def set_roots(self, roots: set[Path]) -> None:
def stop(self) -> None:
self.stop_event.set()

def __iter__(self) -> Generator[Any]: # TODO: better type
def __iter__(self) -> Generator[set[tuple[Change, str]]]:
while True:
self.change_event.clear()
for changes in watchfiles.watch(
for changes in watch(
*self.roots,
watch_filter=self.filter,
stop_event=self.stop_event,
Expand All @@ -53,32 +54,39 @@ def __iter__(self) -> Generator[Any]: # TODO: better type
class WatchfilesReloader(autoreload.BaseReloader):
def __init__(self) -> None:
self.watcher = MutableWatcher(self.file_filter)
self.watched_files_set: set[Path] = set()
super().__init__()

def file_filter(self, change: watchfiles.Change, filename: str) -> bool:
def file_filter(self, change: Change, filename: str) -> bool:
path = Path(filename)
if path in set(self.watched_files(include_globs=False)):
if path in self.watched_files_set:
return True
for directory, globs in self.directory_globs.items():
try:
relative_path = path.relative_to(directory)
except ValueError:
pass
else:
relative_path_str = str(relative_path)
for glob in globs:
if fnmatch.fnmatch(str(relative_path), glob):
if fnmatch(relative_path_str, glob):
return True
return False

def watched_roots(self, watched_files: list[Path]) -> frozenset[Path]:
def watched_roots(self, watched_files: Iterable[Path]) -> frozenset[Path]:
# Adapted from WatchmanReloader
extra_directories = self.directory_globs.keys()
watched_file_dirs = {f.parent for f in watched_files}
sys_paths = set(autoreload.sys_path_directories())
return frozenset((*extra_directories, *watched_file_dirs, *sys_paths))

def tick(self) -> Generator[None]:
watched_files = list(self.watched_files(include_globs=False))
roots = set(autoreload.common_roots(self.watched_roots(watched_files)))
self.watched_files_set = set(self.watched_files(include_globs=False))
roots = set(
autoreload.common_roots(
self.watched_roots(self.watched_files_set),
)
)
self.watcher.set_roots(roots)

for changes in self.watcher:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_django_watchfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

from django.utils import autoreload
from watchfiles import Change

from django_watchfiles import MutableWatcher
from django_watchfiles import WatchfilesReloader
Expand Down Expand Up @@ -76,6 +77,51 @@ def test_iter_respects_change_event(self):
assert len(changes) == 0


class WatchfilesReloaderTests(SimpleTestCase):
def setUp(self):
temp_dir = self.enterContext(tempfile.TemporaryDirectory())
self.temp_path = Path(temp_dir)

self.reloader = WatchfilesReloader()

def test_file_filter_watched_file(self):
test_txt = self.temp_path / "test.txt"
test_txt.touch()
self.reloader.watched_files_set = {test_txt}

result = self.reloader.file_filter(Change.modified, str(test_txt))

assert result is True

def test_file_filter_unwatched_file(self):
test_txt = self.temp_path / "test.txt"
test_txt.touch()

result = self.reloader.file_filter(Change.modified, str(test_txt))

assert result is False

def test_file_filter_glob_matched(self):
self.reloader.watch_dir(self.temp_path, "*.txt")

result = self.reloader.file_filter(
Change.modified, str(self.temp_path / "test.txt")
)

assert result is True

def test_file_filter_glob_relative_path_impossible(self):
temp_dir2 = self.enterContext(tempfile.TemporaryDirectory())

self.reloader.watch_dir(Path(temp_dir2), "*.txt")

result = self.reloader.file_filter(
Change.modified, str(self.temp_path / "test.txt")
)

assert result is False


class ReplacedGetReloaderTests(SimpleTestCase):
def test_replaced_get_reloader(self):
reloader = autoreload.get_reloader()
Expand Down
Loading