Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement autofix for TRIO100 #149

Merged
merged 1 commit into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions flake8_trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def main():
cwd=root,
).stdout.splitlines()
except (subprocess.SubprocessError, FileNotFoundError):
print("Doesn't seem to be a git repo; pass filenames to format.")
print(
"Doesn't seem to be a git repo; pass filenames to format.",
file=sys.stderr,
)
sys.exit(1)
all_filenames = [
os.path.join(root, f) for f in all_filenames if _should_format(f)
Expand All @@ -110,6 +113,9 @@ def main():
plugin = Plugin.from_filename(file)
for error in sorted(plugin.run()):
print(f"{file}:{error}")
if plugin.options.autofix:
with open(file, "w") as file:
file.write(plugin.module.code)


class Plugin:
Expand All @@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
self._tree = tree
source = "".join(lines)

self._module: cst.Module = cst_parse_module_native(source)
self.module: cst.Module = cst_parse_module_native(source)

@classmethod
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
Expand All @@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
plugin = Plugin.__new__(cls)
super(Plugin, plugin).__init__()
plugin._tree = ast.parse(source)
plugin._module = cst_parse_module_native(source)
plugin.module = cst_parse_module_native(source)
return plugin

def run(self) -> Iterable[Error]:
yield from Flake8TrioRunner.run(self._tree, self.options)
yield from Flake8TrioRunner_cst(self.options).run(self._module)
cst_runner = Flake8TrioRunner_cst(self.options, self.module)
yield from cst_runner.run()
self.module = cst_runner.module

@staticmethod
def add_options(option_manager: OptionManager | ArgumentParser):
Expand All @@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
add_argument = functools.partial(
option_manager.add_option, parse_from_config=True
)
add_argument("--autofix", action="store_true", required=False)

add_argument(
"--no-checkpoint-warning-decorators",
Expand Down
8 changes: 4 additions & 4 deletions flake8_trio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ def visit(self, node: ast.AST):


class Flake8TrioRunner_cst:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the runner is kind of .. not necessary and clunky, esp as currently __init__ does a bunch of logic itself. Might warrant some restructuring.

def __init__(self, options: Namespace):
def __init__(self, options: Namespace, module: Module):
super().__init__()
self.state = SharedState(options)
self.options = options
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
)
self.module = module

def run(self, module: Module) -> Iterable[Error]:
def run(self) -> Iterable[Error]:
if not self.visitors:
return
wrapper = cst.MetadataWrapper(module)
for v in self.visitors:
_ = wrapper.visit(v)
self.module = cst.MetadataWrapper(self.module).visit(v)
yield from self.state.problems
Comment on lines -111 to 117
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now saves and updates self.module after visiting.


def selected(self, error_codes: dict[str, str]) -> bool:
Expand Down
61 changes: 60 additions & 1 deletion flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar
from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast

import libcst as cst
import libcst.matchers as m
Expand Down Expand Up @@ -341,3 +341,62 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
),
)
)


def get_comments(node: cst.CSTNode | Iterable[cst.CSTNode]) -> Iterator[cst.EmptyLine]:
# pyright can't use hasattr to narrow the type, so need a bunch of casts
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(node, "__iter__"):
for n in cast("Iterable[cst.CSTNode]", node):
yield from get_comments(n)
return
yield from (
cst.EmptyLine(comment=ensure_type(c, cst.Comment))
for c in m.findall(cast("cst.CSTNode", node), m.Comment())
)
return


# used in TRIO100
def flatten_preserving_comments(node: cst.BaseCompoundStatement):
# add leading lines (comments and empty lines) for the node to be removed
new_leading_lines = list(node.leading_lines)

# add other comments belonging to the node as empty lines with comments
for attr in "lpar", "items", "rpar":
# pragma, since this is currently only used to flatten `With` statements
if comment_nodes := getattr(node, attr, None): # pragma: no cover
new_leading_lines.extend(get_comments(comment_nodes))

# node.body is a BaseSuite, whose subclasses are SimpleStatementSuite
# and IndentedBlock
if isinstance(node.body, cst.SimpleStatementSuite):
# `with ...: pass;pass;pass` -> pass;pass;pass
return cst.SimpleStatementLine(node.body.body, leading_lines=new_leading_lines)

assert isinstance(node.body, cst.IndentedBlock)
nodes = list(node.body.body)

# nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
# and BaseCompoundStatement - both of which has leading_lines
assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))

# add body header comment - i.e. comments on the same/last line of the statement
if node.body.header and node.body.header.comment:
new_leading_lines.append(
cst.EmptyLine(indent=True, comment=node.body.header.comment)
)
# add the leading lines of the first node
new_leading_lines.extend(nodes[0].leading_lines)
# update the first node with all the above constructed lines
nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)

# if there's comments in the footer of the indented block, add a pass
# statement with the comments as leading lines
if node.body.footer:
nodes.append(
cst.SimpleStatementLine(
[cst.Pass()],
node.body.footer,
)
)
return cst.FlattenSentinel(nodes)
17 changes: 13 additions & 4 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
import libcst.matchers as m

from .flake8triovisitor import Flake8TrioVisitor_cst
from .helpers import AttributeCall, error_class_cst, with_has_call
from .helpers import (
AttributeCall,
error_class_cst,
flatten_preserving_comments,
with_has_call,
)


@error_class_cst
Expand Down Expand Up @@ -46,12 +51,16 @@ def visit_With(self, node: cst.With) -> None:
else:
self.has_checkpoint_stack.append(True)

def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
def leave_With(
self, original_node: cst.With, updated_node: cst.With
) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
if not self.has_checkpoint_stack.pop():
for res in self.node_dict[original_node]:
self.error(res.node, res.base, res.function)
# if: autofixing is enabled for this code
# then: remove the with and pop out it's body

if self.options.autofix and len(updated_node.items) == 1:
return flatten_preserving_comments(updated_node)
Comment on lines 58 to +62
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that we don't actually emit the error for autofixed code?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This currently always does, I'll introduce ways of configuring that soon :)


return updated_node

def visit_For(self, node: cst.For):
Expand Down
2 changes: 1 addition & 1 deletion flake8_trio/visitors/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
nursery = get_matching_call(item.context_expr, "open_nursery")

# `isinstance(..., ast.Call)` is done in get_matching_call
body_call = cast(ast.Call, node.body[0].value)
body_call = cast("ast.Call", node.body[0].value)

if (
nursery is not None
Expand Down
84 changes: 84 additions & 0 deletions tests/autofix_files/trio100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# type: ignore
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's not a ton to see here ("trio100_simple_autofix.py" is easier to read), but if you want to review it you should view this and tests/eval_files/trio100.py with a difftool of your choice.


import trio

# error: 5, "trio", "move_on_after"
...


async def function_name():
# fmt: off
...; ...; ...
# fmt: on
# error: 15, "trio", "fail_after"
...
# error: 15, "trio", "fail_at"
...
# error: 15, "trio", "move_on_after"
...
# error: 15, "trio", "move_on_at"
...
# error: 15, "trio", "CancelScope"
...

with trio.move_on_after(10):
await trio.sleep(1)

with trio.move_on_after(10):
await trio.sleep(1)
print("hello")

with trio.move_on_after(10):
while True:
await trio.sleep(1)
print("hello")

with open("filename") as _:
...

# error: 9, "trio", "fail_after"
...

send_channel, receive_channel = trio.open_memory_channel(0)
async with trio.fail_after(10):
async with send_channel:
...

async with trio.fail_after(10):
async for _ in receive_channel:
...

# error: 15, "trio", "fail_after"
for _ in receive_channel:
...

# fix missed alarm when function is defined inside the with scope
# error: 9, "trio", "move_on_after"

async def foo():
await trio.sleep(1)

# error: 9, "trio", "move_on_after"
if ...:

async def foo():
if ...:
await trio.sleep(1)

async with random_ignored_library.fail_after(10):
...


async def function_name2():
with (
open("") as _,
trio.fail_after(10), # error: 8, "trio", "fail_after"
):
...

with (
trio.fail_after(5), # error: 8, "trio", "fail_after"
open("") as _,
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
):
...
59 changes: 59 additions & 0 deletions tests/autofix_files/trio100_simple_autofix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import trio

# a
# b
# error: 5, "trio", "move_on_after"
# c
# d
print(1) # e
# f
# g
print(2) # h
# i
# j
print(3) # k
# l
# m
pass
# n

# error: 5, "trio", "move_on_after"
...


# a
# b
# fmt: off
...;...;...
# fmt: on
# c
# d

# Doesn't autofix With's with multiple withitems
with (
trio.move_on_after(10), # error: 4, "trio", "move_on_after"
open("") as f,
):
...


# multiline with, despite only being one statement
# a
# b
# c
# error: 4, "trio", "move_on_after"
# d
# e
# f
# g
# h
# this comment is kept
...

# fmt: off
# a
# b
# error: 4, "trio", "move_on_after"
# c
...; ...; ...
# fmt: on
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
)
parser.addoption(
"--generate-autofix",
action="store_true",
default=False,
help="generate autofix file content",
)
parser.addoption(
"--enable-visitor-codes-regex",
default=".*",
Expand All @@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
item.add_marker(skip_fuzz)


@pytest.fixture()
def generate_autofix(request: pytest.FixtureRequest):
return request.config.getoption("generate_autofix")


@pytest.fixture()
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
return request.config.getoption("--enable-visitor-codes-regex")
Loading