From 77c61dd655cb3eb872afc161254f6c66d851ff8a Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 6 Mar 2023 18:45:59 +0100 Subject: [PATCH] Implement basic autofix infrastructure and autofixer for TRIO100 --- flake8_trio/__init__.py | 17 +++- flake8_trio/runner.py | 8 +- flake8_trio/visitors/helpers.py | 61 +++++++++++++- flake8_trio/visitors/visitor100.py | 17 +++- flake8_trio/visitors/visitors.py | 2 +- tests/autofix_files/trio100.py | 84 +++++++++++++++++++ tests/autofix_files/trio100_simple_autofix.py | 59 +++++++++++++ tests/conftest.py | 11 +++ tests/eval_files/trio100.py | 31 +++---- tests/eval_files/trio100_simple_autofix.py | 58 +++++++++++++ tests/test_config_and_args.py | 18 +++- tests/test_flake8_trio.py | 30 ++++++- tox.ini | 4 +- 13 files changed, 368 insertions(+), 32 deletions(-) create mode 100644 tests/autofix_files/trio100.py create mode 100644 tests/autofix_files/trio100_simple_autofix.py create mode 100644 tests/eval_files/trio100_simple_autofix.py diff --git a/flake8_trio/__init__.py b/flake8_trio/__init__.py index d235c17..820f661 100644 --- a/flake8_trio/__init__.py +++ b/flake8_trio/__init__.py @@ -101,7 +101,10 @@ def main(): cwd=root, ).stdout.splitlines() except (subprocess.SubprocessError, FileNotFoundError): - print("Doesn't seem to be a git repo; pass filenames to format.") + print( + "Doesn't seem to be a git repo; pass filenames to format.", + file=sys.stderr, + ) sys.exit(1) all_filenames = [ os.path.join(root, f) for f in all_filenames if _should_format(f) @@ -110,6 +113,9 @@ def main(): plugin = Plugin.from_filename(file) for error in sorted(plugin.run()): print(f"{file}:{error}") + if plugin.options.autofix: + with open(file, "w") as file: + file.write(plugin.module.code) class Plugin: @@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]): self._tree = tree source = "".join(lines) - self._module: cst.Module = cst_parse_module_native(source) + self.module: cst.Module = cst_parse_module_native(source) @classmethod def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover @@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin: plugin = Plugin.__new__(cls) super(Plugin, plugin).__init__() plugin._tree = ast.parse(source) - plugin._module = cst_parse_module_native(source) + plugin.module = cst_parse_module_native(source) return plugin def run(self) -> Iterable[Error]: yield from Flake8TrioRunner.run(self._tree, self.options) - yield from Flake8TrioRunner_cst(self.options).run(self._module) + cst_runner = Flake8TrioRunner_cst(self.options, self.module) + yield from cst_runner.run() + self.module = cst_runner.module @staticmethod def add_options(option_manager: OptionManager | ArgumentParser): @@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser): add_argument = functools.partial( option_manager.add_option, parse_from_config=True ) + add_argument("--autofix", action="store_true", required=False) add_argument( "--no-checkpoint-warning-decorators", diff --git a/flake8_trio/runner.py b/flake8_trio/runner.py index ddc907e..945281e 100644 --- a/flake8_trio/runner.py +++ b/flake8_trio/runner.py @@ -100,20 +100,20 @@ def visit(self, node: ast.AST): class Flake8TrioRunner_cst: - def __init__(self, options: Namespace): + def __init__(self, options: Namespace, module: Module): super().__init__() self.state = SharedState(options) self.options = options self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple( v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes) ) + self.module = module - def run(self, module: Module) -> Iterable[Error]: + def run(self) -> Iterable[Error]: if not self.visitors: return - wrapper = cst.MetadataWrapper(module) for v in self.visitors: - _ = wrapper.visit(v) + self.module = cst.MetadataWrapper(self.module).visit(v) yield from self.state.problems def selected(self, error_codes: dict[str, str]) -> bool: diff --git a/flake8_trio/visitors/helpers.py b/flake8_trio/visitors/helpers.py index 3296673..689f3d5 100644 --- a/flake8_trio/visitors/helpers.py +++ b/flake8_trio/visitors/helpers.py @@ -7,7 +7,7 @@ import ast from fnmatch import fnmatch -from typing import TYPE_CHECKING, NamedTuple, TypeVar +from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast import libcst as cst import libcst.matchers as m @@ -341,3 +341,62 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool: ), ) ) + + +def get_comments(node: cst.CSTNode | Iterable[cst.CSTNode]) -> Iterator[cst.EmptyLine]: + # pyright can't use hasattr to narrow the type, so need a bunch of casts + if hasattr(node, "__iter__"): + for n in cast("Iterable[cst.CSTNode]", node): + yield from get_comments(n) + return + yield from ( + cst.EmptyLine(comment=ensure_type(c, cst.Comment)) + for c in m.findall(cast("cst.CSTNode", node), m.Comment()) + ) + return + + +# used in TRIO100 +def flatten_preserving_comments(node: cst.BaseCompoundStatement): + # add leading lines (comments and empty lines) for the node to be removed + new_leading_lines = list(node.leading_lines) + + # add other comments belonging to the node as empty lines with comments + for attr in "lpar", "items", "rpar": + # pragma, since this is currently only used to flatten `With` statements + if comment_nodes := getattr(node, attr, None): # pragma: no cover + new_leading_lines.extend(get_comments(comment_nodes)) + + # node.body is a BaseSuite, whose subclasses are SimpleStatementSuite + # and IndentedBlock + if isinstance(node.body, cst.SimpleStatementSuite): + # `with ...: pass;pass;pass` -> pass;pass;pass + return cst.SimpleStatementLine(node.body.body, leading_lines=new_leading_lines) + + assert isinstance(node.body, cst.IndentedBlock) + nodes = list(node.body.body) + + # nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine + # and BaseCompoundStatement - both of which has leading_lines + assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement)) + + # add body header comment - i.e. comments on the same/last line of the statement + if node.body.header and node.body.header.comment: + new_leading_lines.append( + cst.EmptyLine(indent=True, comment=node.body.header.comment) + ) + # add the leading lines of the first node + new_leading_lines.extend(nodes[0].leading_lines) + # update the first node with all the above constructed lines + nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines) + + # if there's comments in the footer of the indented block, add a pass + # statement with the comments as leading lines + if node.body.footer: + nodes.append( + cst.SimpleStatementLine( + [cst.Pass()], + node.body.footer, + ) + ) + return cst.FlattenSentinel(nodes) diff --git a/flake8_trio/visitors/visitor100.py b/flake8_trio/visitors/visitor100.py index 0c0da13..0f621cd 100644 --- a/flake8_trio/visitors/visitor100.py +++ b/flake8_trio/visitors/visitor100.py @@ -13,7 +13,12 @@ import libcst.matchers as m from .flake8triovisitor import Flake8TrioVisitor_cst -from .helpers import AttributeCall, error_class_cst, with_has_call +from .helpers import ( + AttributeCall, + error_class_cst, + flatten_preserving_comments, + with_has_call, +) @error_class_cst @@ -46,12 +51,16 @@ def visit_With(self, node: cst.With) -> None: else: self.has_checkpoint_stack.append(True) - def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With: + def leave_With( + self, original_node: cst.With, updated_node: cst.With + ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]: if not self.has_checkpoint_stack.pop(): for res in self.node_dict[original_node]: self.error(res.node, res.base, res.function) - # if: autofixing is enabled for this code - # then: remove the with and pop out it's body + + if self.options.autofix and len(updated_node.items) == 1: + return flatten_preserving_comments(updated_node) + return updated_node def visit_For(self, node: cst.For): diff --git a/flake8_trio/visitors/visitors.py b/flake8_trio/visitors/visitors.py index 4fe292d..dc9c99b 100644 --- a/flake8_trio/visitors/visitors.py +++ b/flake8_trio/visitors/visitors.py @@ -91,7 +91,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): nursery = get_matching_call(item.context_expr, "open_nursery") # `isinstance(..., ast.Call)` is done in get_matching_call - body_call = cast(ast.Call, node.body[0].value) + body_call = cast("ast.Call", node.body[0].value) if ( nursery is not None diff --git a/tests/autofix_files/trio100.py b/tests/autofix_files/trio100.py new file mode 100644 index 0000000..df7ea15 --- /dev/null +++ b/tests/autofix_files/trio100.py @@ -0,0 +1,84 @@ +# type: ignore + +import trio + +# error: 5, "trio", "move_on_after" +... + + +async def function_name(): + # fmt: off + ...; ...; ... + # fmt: on + # error: 15, "trio", "fail_after" + ... + # error: 15, "trio", "fail_at" + ... + # error: 15, "trio", "move_on_after" + ... + # error: 15, "trio", "move_on_at" + ... + # error: 15, "trio", "CancelScope" + ... + + with trio.move_on_after(10): + await trio.sleep(1) + + with trio.move_on_after(10): + await trio.sleep(1) + print("hello") + + with trio.move_on_after(10): + while True: + await trio.sleep(1) + print("hello") + + with open("filename") as _: + ... + + # error: 9, "trio", "fail_after" + ... + + send_channel, receive_channel = trio.open_memory_channel(0) + async with trio.fail_after(10): + async with send_channel: + ... + + async with trio.fail_after(10): + async for _ in receive_channel: + ... + + # error: 15, "trio", "fail_after" + for _ in receive_channel: + ... + + # fix missed alarm when function is defined inside the with scope + # error: 9, "trio", "move_on_after" + + async def foo(): + await trio.sleep(1) + + # error: 9, "trio", "move_on_after" + if ...: + + async def foo(): + if ...: + await trio.sleep(1) + + async with random_ignored_library.fail_after(10): + ... + + +async def function_name2(): + with ( + open("") as _, + trio.fail_after(10), # error: 8, "trio", "fail_after" + ): + ... + + with ( + trio.fail_after(5), # error: 8, "trio", "fail_after" + open("") as _, + trio.move_on_after(5), # error: 8, "trio", "move_on_after" + ): + ... diff --git a/tests/autofix_files/trio100_simple_autofix.py b/tests/autofix_files/trio100_simple_autofix.py new file mode 100644 index 0000000..27dd18b --- /dev/null +++ b/tests/autofix_files/trio100_simple_autofix.py @@ -0,0 +1,59 @@ +import trio + +# a +# b +# error: 5, "trio", "move_on_after" +# c +# d +print(1) # e +# f +# g +print(2) # h +# i +# j +print(3) # k +# l +# m +pass +# n + +# error: 5, "trio", "move_on_after" +... + + +# a +# b +# fmt: off +...;...;... +# fmt: on +# c +# d + +# Doesn't autofix With's with multiple withitems +with ( + trio.move_on_after(10), # error: 4, "trio", "move_on_after" + open("") as f, +): + ... + + +# multiline with, despite only being one statement +# a +# b +# c +# error: 4, "trio", "move_on_after" +# d +# e +# f +# g +# h +# this comment is kept +... + +# fmt: off +# a +# b +# error: 4, "trio", "move_on_after" +# c +...; ...; ... +# fmt: on diff --git a/tests/conftest.py b/tests/conftest.py index 0dd074e..a3ebae2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser): parser.addoption( "--runfuzz", action="store_true", default=False, help="run fuzz tests" ) + parser.addoption( + "--generate-autofix", + action="store_true", + default=False, + help="generate autofix file content", + ) parser.addoption( "--enable-visitor-codes-regex", default=".*", @@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item item.add_marker(skip_fuzz) +@pytest.fixture() +def generate_autofix(request: pytest.FixtureRequest): + return request.config.getoption("generate_autofix") + + @pytest.fixture() def enable_visitor_codes_regex(request: pytest.FixtureRequest): return request.config.getoption("--enable-visitor-codes-regex") diff --git a/tests/eval_files/trio100.py b/tests/eval_files/trio100.py index 04d942a..b63cf5e 100644 --- a/tests/eval_files/trio100.py +++ b/tests/eval_files/trio100.py @@ -3,20 +3,23 @@ import trio with trio.move_on_after(10): # error: 5, "trio", "move_on_after" - pass + ... async def function_name(): + # fmt: off + async with trio.fail_after(10): ...; ...; ... # error: 15, "trio", "fail_after" + # fmt: on async with trio.fail_after(10): # error: 15, "trio", "fail_after" - pass + ... async with trio.fail_at(10): # error: 15, "trio", "fail_at" - pass + ... async with trio.move_on_after(10): # error: 15, "trio", "move_on_after" - pass + ... async with trio.move_on_at(10): # error: 15, "trio", "move_on_at" - pass + ... async with trio.CancelScope(...): # error: 15, "trio", "CancelScope" - pass + ... with trio.move_on_after(10): await trio.sleep(1) @@ -31,23 +34,23 @@ async def function_name(): print("hello") with open("filename") as _: - pass + ... with trio.fail_after(10): # error: 9, "trio", "fail_after" - pass + ... send_channel, receive_channel = trio.open_memory_channel(0) async with trio.fail_after(10): async with send_channel: - pass + ... async with trio.fail_after(10): async for _ in receive_channel: - pass + ... async with trio.fail_after(10): # error: 15, "trio", "fail_after" for _ in receive_channel: - pass + ... # fix missed alarm when function is defined inside the with scope with trio.move_on_after(10): # error: 9, "trio", "move_on_after" @@ -63,7 +66,7 @@ async def foo(): await trio.sleep(1) async with random_ignored_library.fail_after(10): - pass + ... async def function_name2(): @@ -71,11 +74,11 @@ async def function_name2(): open("") as _, trio.fail_after(10), # error: 8, "trio", "fail_after" ): - pass + ... with ( trio.fail_after(5), # error: 8, "trio", "fail_after" open("") as _, trio.move_on_after(5), # error: 8, "trio", "move_on_after" ): - pass + ... diff --git a/tests/eval_files/trio100_simple_autofix.py b/tests/eval_files/trio100_simple_autofix.py new file mode 100644 index 0000000..46655c3 --- /dev/null +++ b/tests/eval_files/trio100_simple_autofix.py @@ -0,0 +1,58 @@ +import trio + +# a +# b +with trio.move_on_after(10): # error: 5, "trio", "move_on_after" + # c + # d + print(1) # e + # f + # g + print(2) # h + # i + # j + print(3) # k + # l + # m +# n + +with trio.move_on_after(10): # error: 5, "trio", "move_on_after" + ... + + +# a +# b +# fmt: off +with trio.move_on_after(10): ...;...;... # error: 5, "trio", "move_on_after" +# fmt: on +# c +# d + +# Doesn't autofix With's with multiple withitems +with ( + trio.move_on_after(10), # error: 4, "trio", "move_on_after" + open("") as f, +): + ... + + +# multiline with, despite only being one statement +with ( # a + # b + # c + trio.move_on_after( # error: 4, "trio", "move_on_after" + # d + 9999999999999999999999999999999999999999999999999999999 # e + # f + ) # g + # h +): # this comment is kept + ... + +# fmt: off +with ( # a + # b + trio.move_on_after(10) # error: 4, "trio", "move_on_after" + # c +): ...; ...; ... +# fmt: on diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py index d6cd361..4f3c32d 100644 --- a/tests/test_config_and_args.py +++ b/tests/test_config_and_args.py @@ -63,8 +63,24 @@ def test_run_no_git_repo( with pytest.raises(SystemExit): from flake8_trio import __main__ # noqa out, err = capsys.readouterr() - assert out == "Doesn't seem to be a git repo; pass filenames to format.\n" + assert err == "Doesn't seem to be a git repo; pass filenames to format.\n" + assert not out + + +def test_run_100_autofix( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +): + err_msg = _common_error_setup(tmp_path) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr( + sys, "argv", [tmp_path / "flake8_trio", "--autofix", "./example.py"] + ) + from flake8_trio import __main__ # noqa + + out, err = capsys.readouterr() + assert out == err_msg assert not err + assert tmp_path.joinpath("example.py").read_text() == "import trio\n...\n" def test_114_raises_on_invalid_parameter(capsys: pytest.CaptureFixture[str]): diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index eef016d..6d7867f 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -4,6 +4,7 @@ import ast import copy +import difflib import itertools import os import re @@ -33,6 +34,11 @@ test_files: list[tuple[str, Path]] = sorted( (f.stem.upper(), f) for f in (Path(__file__).parent / "eval_files").iterdir() ) +autofix_files: dict[str, Path] = { + f.stem.upper(): f for f in (Path(__file__).parent / "autofix_files").iterdir() +} +# check that there's an eval file for each autofix file +assert set(autofix_files.keys()) - {f[0] for f in test_files} == set() class ParseError(Exception): @@ -65,19 +71,41 @@ def check_version(test: str): } +def check_autofix(test: str, plugin: Plugin, generate_autofix: bool): + if test not in autofix_files: + return + visited_code = plugin.module.code + current_generated = autofix_files[test].read_text() + if generate_autofix: + if visited_code != current_generated: + print(f"\nregenerating {test}") + sys.stdout.writelines( + difflib.unified_diff( + current_generated.splitlines(keepends=True), + visited_code.splitlines(keepends=True), + ) + ) + autofix_files[test].write_text(visited_code) + return + assert visited_code == current_generated + + @pytest.mark.parametrize(("test", "path"), test_files) -def test_eval(test: str, path: Path): +def test_eval(test: str, path: Path, generate_autofix: bool): content = path.read_text() if "# NOTRIO" in content: pytest.skip("file marked with NOTRIO") expected, parsed_args = _parse_eval_file(test, content) + parsed_args.append("--autofix") if "# TRIO_NO_ERROR" in content: expected = [] plugin = Plugin.from_source(content) _ = assert_expected_errors(plugin, *expected, args=parsed_args) + check_autofix(test, plugin, generate_autofix) + @pytest.mark.parametrize(("test", "path"), test_files) def test_eval_anyio(test: str, path: Path): diff --git a/tox.ini b/tox.ini index 79239ee..e084670 100644 --- a/tox.ini +++ b/tox.ini @@ -48,9 +48,9 @@ exclude_lines = [flake8] max-line-length = 90 -extend-ignore = S101, D101, D102, D103, D105, D106, D107, TC006 +extend-ignore = S101, D101, D102, D103, D105, D106, D107 extend-enable = TC10 -exclude = .*, tests/eval_files/* +exclude = .*, tests/eval_files/*, tests/autofix_files/* per-file-ignores = flake8_trio/visitors/__init__.py: F401, E402 # (E301, E302) black formats stub files without excessive blank lines