From 77c61dd655cb3eb872afc161254f6c66d851ff8a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 6 Mar 2023 18:45:59 +0100
Subject: [PATCH] Implement basic autofix infrastructure and autofixer for
TRIO100
---
flake8_trio/__init__.py | 17 +++-
flake8_trio/runner.py | 8 +-
flake8_trio/visitors/helpers.py | 61 +++++++++++++-
flake8_trio/visitors/visitor100.py | 17 +++-
flake8_trio/visitors/visitors.py | 2 +-
tests/autofix_files/trio100.py | 84 +++++++++++++++++++
tests/autofix_files/trio100_simple_autofix.py | 59 +++++++++++++
tests/conftest.py | 11 +++
tests/eval_files/trio100.py | 31 +++----
tests/eval_files/trio100_simple_autofix.py | 58 +++++++++++++
tests/test_config_and_args.py | 18 +++-
tests/test_flake8_trio.py | 30 ++++++-
tox.ini | 4 +-
13 files changed, 368 insertions(+), 32 deletions(-)
create mode 100644 tests/autofix_files/trio100.py
create mode 100644 tests/autofix_files/trio100_simple_autofix.py
create mode 100644 tests/eval_files/trio100_simple_autofix.py
diff --git a/flake8_trio/__init__.py b/flake8_trio/__init__.py
index d235c17..820f661 100644
--- a/flake8_trio/__init__.py
+++ b/flake8_trio/__init__.py
@@ -101,7 +101,10 @@ def main():
cwd=root,
).stdout.splitlines()
except (subprocess.SubprocessError, FileNotFoundError):
- print("Doesn't seem to be a git repo; pass filenames to format.")
+ print(
+ "Doesn't seem to be a git repo; pass filenames to format.",
+ file=sys.stderr,
+ )
sys.exit(1)
all_filenames = [
os.path.join(root, f) for f in all_filenames if _should_format(f)
@@ -110,6 +113,9 @@ def main():
plugin = Plugin.from_filename(file)
for error in sorted(plugin.run()):
print(f"{file}:{error}")
+ if plugin.options.autofix:
+ with open(file, "w") as file:
+ file.write(plugin.module.code)
class Plugin:
@@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
self._tree = tree
source = "".join(lines)
- self._module: cst.Module = cst_parse_module_native(source)
+ self.module: cst.Module = cst_parse_module_native(source)
@classmethod
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
@@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
plugin = Plugin.__new__(cls)
super(Plugin, plugin).__init__()
plugin._tree = ast.parse(source)
- plugin._module = cst_parse_module_native(source)
+ plugin.module = cst_parse_module_native(source)
return plugin
def run(self) -> Iterable[Error]:
yield from Flake8TrioRunner.run(self._tree, self.options)
- yield from Flake8TrioRunner_cst(self.options).run(self._module)
+ cst_runner = Flake8TrioRunner_cst(self.options, self.module)
+ yield from cst_runner.run()
+ self.module = cst_runner.module
@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
@@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
add_argument = functools.partial(
option_manager.add_option, parse_from_config=True
)
+ add_argument("--autofix", action="store_true", required=False)
add_argument(
"--no-checkpoint-warning-decorators",
diff --git a/flake8_trio/runner.py b/flake8_trio/runner.py
index ddc907e..945281e 100644
--- a/flake8_trio/runner.py
+++ b/flake8_trio/runner.py
@@ -100,20 +100,20 @@ def visit(self, node: ast.AST):
class Flake8TrioRunner_cst:
- def __init__(self, options: Namespace):
+ def __init__(self, options: Namespace, module: Module):
super().__init__()
self.state = SharedState(options)
self.options = options
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
)
+ self.module = module
- def run(self, module: Module) -> Iterable[Error]:
+ def run(self) -> Iterable[Error]:
if not self.visitors:
return
- wrapper = cst.MetadataWrapper(module)
for v in self.visitors:
- _ = wrapper.visit(v)
+ self.module = cst.MetadataWrapper(self.module).visit(v)
yield from self.state.problems
def selected(self, error_codes: dict[str, str]) -> bool:
diff --git a/flake8_trio/visitors/helpers.py b/flake8_trio/visitors/helpers.py
index 3296673..689f3d5 100644
--- a/flake8_trio/visitors/helpers.py
+++ b/flake8_trio/visitors/helpers.py
@@ -7,7 +7,7 @@
import ast
from fnmatch import fnmatch
-from typing import TYPE_CHECKING, NamedTuple, TypeVar
+from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast
import libcst as cst
import libcst.matchers as m
@@ -341,3 +341,62 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
),
)
)
+
+
+def get_comments(node: cst.CSTNode | Iterable[cst.CSTNode]) -> Iterator[cst.EmptyLine]:
+ # pyright can't use hasattr to narrow the type, so need a bunch of casts
+ if hasattr(node, "__iter__"):
+ for n in cast("Iterable[cst.CSTNode]", node):
+ yield from get_comments(n)
+ return
+ yield from (
+ cst.EmptyLine(comment=ensure_type(c, cst.Comment))
+ for c in m.findall(cast("cst.CSTNode", node), m.Comment())
+ )
+ return
+
+
+# used in TRIO100
+def flatten_preserving_comments(node: cst.BaseCompoundStatement):
+ # add leading lines (comments and empty lines) for the node to be removed
+ new_leading_lines = list(node.leading_lines)
+
+ # add other comments belonging to the node as empty lines with comments
+ for attr in "lpar", "items", "rpar":
+ # pragma, since this is currently only used to flatten `With` statements
+ if comment_nodes := getattr(node, attr, None): # pragma: no cover
+ new_leading_lines.extend(get_comments(comment_nodes))
+
+ # node.body is a BaseSuite, whose subclasses are SimpleStatementSuite
+ # and IndentedBlock
+ if isinstance(node.body, cst.SimpleStatementSuite):
+ # `with ...: pass;pass;pass` -> pass;pass;pass
+ return cst.SimpleStatementLine(node.body.body, leading_lines=new_leading_lines)
+
+ assert isinstance(node.body, cst.IndentedBlock)
+ nodes = list(node.body.body)
+
+ # nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
+ # and BaseCompoundStatement - both of which has leading_lines
+ assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))
+
+ # add body header comment - i.e. comments on the same/last line of the statement
+ if node.body.header and node.body.header.comment:
+ new_leading_lines.append(
+ cst.EmptyLine(indent=True, comment=node.body.header.comment)
+ )
+ # add the leading lines of the first node
+ new_leading_lines.extend(nodes[0].leading_lines)
+ # update the first node with all the above constructed lines
+ nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)
+
+ # if there's comments in the footer of the indented block, add a pass
+ # statement with the comments as leading lines
+ if node.body.footer:
+ nodes.append(
+ cst.SimpleStatementLine(
+ [cst.Pass()],
+ node.body.footer,
+ )
+ )
+ return cst.FlattenSentinel(nodes)
diff --git a/flake8_trio/visitors/visitor100.py b/flake8_trio/visitors/visitor100.py
index 0c0da13..0f621cd 100644
--- a/flake8_trio/visitors/visitor100.py
+++ b/flake8_trio/visitors/visitor100.py
@@ -13,7 +13,12 @@
import libcst.matchers as m
from .flake8triovisitor import Flake8TrioVisitor_cst
-from .helpers import AttributeCall, error_class_cst, with_has_call
+from .helpers import (
+ AttributeCall,
+ error_class_cst,
+ flatten_preserving_comments,
+ with_has_call,
+)
@error_class_cst
@@ -46,12 +51,16 @@ def visit_With(self, node: cst.With) -> None:
else:
self.has_checkpoint_stack.append(True)
- def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
+ def leave_With(
+ self, original_node: cst.With, updated_node: cst.With
+ ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
if not self.has_checkpoint_stack.pop():
for res in self.node_dict[original_node]:
self.error(res.node, res.base, res.function)
- # if: autofixing is enabled for this code
- # then: remove the with and pop out it's body
+
+ if self.options.autofix and len(updated_node.items) == 1:
+ return flatten_preserving_comments(updated_node)
+
return updated_node
def visit_For(self, node: cst.For):
diff --git a/flake8_trio/visitors/visitors.py b/flake8_trio/visitors/visitors.py
index 4fe292d..dc9c99b 100644
--- a/flake8_trio/visitors/visitors.py
+++ b/flake8_trio/visitors/visitors.py
@@ -91,7 +91,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
nursery = get_matching_call(item.context_expr, "open_nursery")
# `isinstance(..., ast.Call)` is done in get_matching_call
- body_call = cast(ast.Call, node.body[0].value)
+ body_call = cast("ast.Call", node.body[0].value)
if (
nursery is not None
diff --git a/tests/autofix_files/trio100.py b/tests/autofix_files/trio100.py
new file mode 100644
index 0000000..df7ea15
--- /dev/null
+++ b/tests/autofix_files/trio100.py
@@ -0,0 +1,84 @@
+# type: ignore
+
+import trio
+
+# error: 5, "trio", "move_on_after"
+...
+
+
+async def function_name():
+ # fmt: off
+ ...; ...; ...
+ # fmt: on
+ # error: 15, "trio", "fail_after"
+ ...
+ # error: 15, "trio", "fail_at"
+ ...
+ # error: 15, "trio", "move_on_after"
+ ...
+ # error: 15, "trio", "move_on_at"
+ ...
+ # error: 15, "trio", "CancelScope"
+ ...
+
+ with trio.move_on_after(10):
+ await trio.sleep(1)
+
+ with trio.move_on_after(10):
+ await trio.sleep(1)
+ print("hello")
+
+ with trio.move_on_after(10):
+ while True:
+ await trio.sleep(1)
+ print("hello")
+
+ with open("filename") as _:
+ ...
+
+ # error: 9, "trio", "fail_after"
+ ...
+
+ send_channel, receive_channel = trio.open_memory_channel(0)
+ async with trio.fail_after(10):
+ async with send_channel:
+ ...
+
+ async with trio.fail_after(10):
+ async for _ in receive_channel:
+ ...
+
+ # error: 15, "trio", "fail_after"
+ for _ in receive_channel:
+ ...
+
+ # fix missed alarm when function is defined inside the with scope
+ # error: 9, "trio", "move_on_after"
+
+ async def foo():
+ await trio.sleep(1)
+
+ # error: 9, "trio", "move_on_after"
+ if ...:
+
+ async def foo():
+ if ...:
+ await trio.sleep(1)
+
+ async with random_ignored_library.fail_after(10):
+ ...
+
+
+async def function_name2():
+ with (
+ open("") as _,
+ trio.fail_after(10), # error: 8, "trio", "fail_after"
+ ):
+ ...
+
+ with (
+ trio.fail_after(5), # error: 8, "trio", "fail_after"
+ open("") as _,
+ trio.move_on_after(5), # error: 8, "trio", "move_on_after"
+ ):
+ ...
diff --git a/tests/autofix_files/trio100_simple_autofix.py b/tests/autofix_files/trio100_simple_autofix.py
new file mode 100644
index 0000000..27dd18b
--- /dev/null
+++ b/tests/autofix_files/trio100_simple_autofix.py
@@ -0,0 +1,59 @@
+import trio
+
+# a
+# b
+# error: 5, "trio", "move_on_after"
+# c
+# d
+print(1) # e
+# f
+# g
+print(2) # h
+# i
+# j
+print(3) # k
+# l
+# m
+pass
+# n
+
+# error: 5, "trio", "move_on_after"
+...
+
+
+# a
+# b
+# fmt: off
+...;...;...
+# fmt: on
+# c
+# d
+
+# Doesn't autofix With's with multiple withitems
+with (
+ trio.move_on_after(10), # error: 4, "trio", "move_on_after"
+ open("") as f,
+):
+ ...
+
+
+# multiline with, despite only being one statement
+# a
+# b
+# c
+# error: 4, "trio", "move_on_after"
+# d
+# e
+# f
+# g
+# h
+# this comment is kept
+...
+
+# fmt: off
+# a
+# b
+# error: 4, "trio", "move_on_after"
+# c
+...; ...; ...
+# fmt: on
diff --git a/tests/conftest.py b/tests/conftest.py
index 0dd074e..a3ebae2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
)
+ parser.addoption(
+ "--generate-autofix",
+ action="store_true",
+ default=False,
+ help="generate autofix file content",
+ )
parser.addoption(
"--enable-visitor-codes-regex",
default=".*",
@@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
item.add_marker(skip_fuzz)
+@pytest.fixture()
+def generate_autofix(request: pytest.FixtureRequest):
+ return request.config.getoption("generate_autofix")
+
+
@pytest.fixture()
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
return request.config.getoption("--enable-visitor-codes-regex")
diff --git a/tests/eval_files/trio100.py b/tests/eval_files/trio100.py
index 04d942a..b63cf5e 100644
--- a/tests/eval_files/trio100.py
+++ b/tests/eval_files/trio100.py
@@ -3,20 +3,23 @@
import trio
with trio.move_on_after(10): # error: 5, "trio", "move_on_after"
- pass
+ ...
async def function_name():
+ # fmt: off
+ async with trio.fail_after(10): ...; ...; ... # error: 15, "trio", "fail_after"
+ # fmt: on
async with trio.fail_after(10): # error: 15, "trio", "fail_after"
- pass
+ ...
async with trio.fail_at(10): # error: 15, "trio", "fail_at"
- pass
+ ...
async with trio.move_on_after(10): # error: 15, "trio", "move_on_after"
- pass
+ ...
async with trio.move_on_at(10): # error: 15, "trio", "move_on_at"
- pass
+ ...
async with trio.CancelScope(...): # error: 15, "trio", "CancelScope"
- pass
+ ...
with trio.move_on_after(10):
await trio.sleep(1)
@@ -31,23 +34,23 @@ async def function_name():
print("hello")
with open("filename") as _:
- pass
+ ...
with trio.fail_after(10): # error: 9, "trio", "fail_after"
- pass
+ ...
send_channel, receive_channel = trio.open_memory_channel(0)
async with trio.fail_after(10):
async with send_channel:
- pass
+ ...
async with trio.fail_after(10):
async for _ in receive_channel:
- pass
+ ...
async with trio.fail_after(10): # error: 15, "trio", "fail_after"
for _ in receive_channel:
- pass
+ ...
# fix missed alarm when function is defined inside the with scope
with trio.move_on_after(10): # error: 9, "trio", "move_on_after"
@@ -63,7 +66,7 @@ async def foo():
await trio.sleep(1)
async with random_ignored_library.fail_after(10):
- pass
+ ...
async def function_name2():
@@ -71,11 +74,11 @@ async def function_name2():
open("") as _,
trio.fail_after(10), # error: 8, "trio", "fail_after"
):
- pass
+ ...
with (
trio.fail_after(5), # error: 8, "trio", "fail_after"
open("") as _,
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
):
- pass
+ ...
diff --git a/tests/eval_files/trio100_simple_autofix.py b/tests/eval_files/trio100_simple_autofix.py
new file mode 100644
index 0000000..46655c3
--- /dev/null
+++ b/tests/eval_files/trio100_simple_autofix.py
@@ -0,0 +1,58 @@
+import trio
+
+# a
+# b
+with trio.move_on_after(10): # error: 5, "trio", "move_on_after"
+ # c
+ # d
+ print(1) # e
+ # f
+ # g
+ print(2) # h
+ # i
+ # j
+ print(3) # k
+ # l
+ # m
+# n
+
+with trio.move_on_after(10): # error: 5, "trio", "move_on_after"
+ ...
+
+
+# a
+# b
+# fmt: off
+with trio.move_on_after(10): ...;...;... # error: 5, "trio", "move_on_after"
+# fmt: on
+# c
+# d
+
+# Doesn't autofix With's with multiple withitems
+with (
+ trio.move_on_after(10), # error: 4, "trio", "move_on_after"
+ open("") as f,
+):
+ ...
+
+
+# multiline with, despite only being one statement
+with ( # a
+ # b
+ # c
+ trio.move_on_after( # error: 4, "trio", "move_on_after"
+ # d
+ 9999999999999999999999999999999999999999999999999999999 # e
+ # f
+ ) # g
+ # h
+): # this comment is kept
+ ...
+
+# fmt: off
+with ( # a
+ # b
+ trio.move_on_after(10) # error: 4, "trio", "move_on_after"
+ # c
+): ...; ...; ...
+# fmt: on
diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py
index d6cd361..4f3c32d 100644
--- a/tests/test_config_and_args.py
+++ b/tests/test_config_and_args.py
@@ -63,8 +63,24 @@ def test_run_no_git_repo(
with pytest.raises(SystemExit):
from flake8_trio import __main__ # noqa
out, err = capsys.readouterr()
- assert out == "Doesn't seem to be a git repo; pass filenames to format.\n"
+ assert err == "Doesn't seem to be a git repo; pass filenames to format.\n"
+ assert not out
+
+
+def test_run_100_autofix(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
+):
+ err_msg = _common_error_setup(tmp_path)
+ monkeypatch.chdir(tmp_path)
+ monkeypatch.setattr(
+ sys, "argv", [tmp_path / "flake8_trio", "--autofix", "./example.py"]
+ )
+ from flake8_trio import __main__ # noqa
+
+ out, err = capsys.readouterr()
+ assert out == err_msg
assert not err
+ assert tmp_path.joinpath("example.py").read_text() == "import trio\n...\n"
def test_114_raises_on_invalid_parameter(capsys: pytest.CaptureFixture[str]):
diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py
index eef016d..6d7867f 100644
--- a/tests/test_flake8_trio.py
+++ b/tests/test_flake8_trio.py
@@ -4,6 +4,7 @@
import ast
import copy
+import difflib
import itertools
import os
import re
@@ -33,6 +34,11 @@
test_files: list[tuple[str, Path]] = sorted(
(f.stem.upper(), f) for f in (Path(__file__).parent / "eval_files").iterdir()
)
+autofix_files: dict[str, Path] = {
+ f.stem.upper(): f for f in (Path(__file__).parent / "autofix_files").iterdir()
+}
+# check that there's an eval file for each autofix file
+assert set(autofix_files.keys()) - {f[0] for f in test_files} == set()
class ParseError(Exception):
@@ -65,19 +71,41 @@ def check_version(test: str):
}
+def check_autofix(test: str, plugin: Plugin, generate_autofix: bool):
+ if test not in autofix_files:
+ return
+ visited_code = plugin.module.code
+ current_generated = autofix_files[test].read_text()
+ if generate_autofix:
+ if visited_code != current_generated:
+ print(f"\nregenerating {test}")
+ sys.stdout.writelines(
+ difflib.unified_diff(
+ current_generated.splitlines(keepends=True),
+ visited_code.splitlines(keepends=True),
+ )
+ )
+ autofix_files[test].write_text(visited_code)
+ return
+ assert visited_code == current_generated
+
+
@pytest.mark.parametrize(("test", "path"), test_files)
-def test_eval(test: str, path: Path):
+def test_eval(test: str, path: Path, generate_autofix: bool):
content = path.read_text()
if "# NOTRIO" in content:
pytest.skip("file marked with NOTRIO")
expected, parsed_args = _parse_eval_file(test, content)
+ parsed_args.append("--autofix")
if "# TRIO_NO_ERROR" in content:
expected = []
plugin = Plugin.from_source(content)
_ = assert_expected_errors(plugin, *expected, args=parsed_args)
+ check_autofix(test, plugin, generate_autofix)
+
@pytest.mark.parametrize(("test", "path"), test_files)
def test_eval_anyio(test: str, path: Path):
diff --git a/tox.ini b/tox.ini
index 79239ee..e084670 100644
--- a/tox.ini
+++ b/tox.ini
@@ -48,9 +48,9 @@ exclude_lines =
[flake8]
max-line-length = 90
-extend-ignore = S101, D101, D102, D103, D105, D106, D107, TC006
+extend-ignore = S101, D101, D102, D103, D105, D106, D107
extend-enable = TC10
-exclude = .*, tests/eval_files/*
+exclude = .*, tests/eval_files/*, tests/autofix_files/*
per-file-ignores =
flake8_trio/visitors/__init__.py: F401, E402
# (E301, E302) black formats stub files without excessive blank lines