Skip to content

Commit

Permalink
Add support for forward-reference typings (PEP-484#forward-references)
Browse files Browse the repository at this point in the history
Implementation: A context tracker monitors the incomping token stream and reports whether we are currently in a typing context or not, if we are in a typing context then the normal replacement code treats string values as a new token string to unasyncify recursively.
  • Loading branch information
ntninja committed Dec 18, 2020
1 parent 600807e commit 3a6ea0b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 3 deletions.
82 changes: 79 additions & 3 deletions src/unasync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,57 @@ def _unasync_file(self, filepath):
def _unasync_tokens(self, tokens):
# TODO __await__, ...?
used_space = None
context = None # Can be `None`, `"func_decl"`, `"func_name"`, `"arg_list"`, `"arg_list_end"`, `"return_type"`
brace_depth = 0
typing_ctx = False

for space, toknum, tokval in tokens:
# Update context state tracker
if context is None and toknum == std_tokenize.NAME and tokval == "def":
context = "func_decl"
elif context == "func_decl" and toknum == std_tokenize.NAME:
context = "func_name"
elif context == "func_name" and toknum == std_tokenize.OP and tokval == "(":
context = "arg_list"
elif context == "arg_list":
if toknum == std_tokenize.OP and tokval in ("(", "["):
brace_depth += 1
elif (
toknum == std_tokenize.OP
and tokval in (")", "]")
and brace_depth >= 1
):
brace_depth -= 1
elif toknum == std_tokenize.OP and tokval == ")":
context = "arg_list_end"
elif toknum == std_tokenize.OP and tokval == ":" and brace_depth < 1:
typing_ctx = True
elif toknum == std_tokenize.OP and tokval == "," and brace_depth < 1:
typing_ctx = False
elif (
context == "arg_list_end"
and toknum == std_tokenize.OP
and tokval == "->"
):
context = "return_type"
typing_ctx = True
elif context == "return_type":
if toknum == std_tokenize.OP and tokval in ("(", "["):
brace_depth += 1
elif (
toknum == std_tokenize.OP
and tokval in (")", "]")
and brace_depth >= 1
):
brace_depth -= 1
elif toknum == std_tokenize.OP and tokval == ":":
context = None
typing_ctx = False
else: # Something unexpected happend - reset state
context = None
brace_depth = 0
typing_ctx = False

if tokval in ["async", "await"]:
# When removing async or await, we want to use the whitespace that
# was before async/await before the next token so that
Expand All @@ -111,8 +161,34 @@ def _unasync_tokens(self, tokens):
if toknum == std_tokenize.NAME:
tokval = self._unasync_name(tokval)
elif toknum == std_tokenize.STRING:
left_quote, name, right_quote = tokval[0], tokval[1:-1], tokval[-1]
tokval = left_quote + self._unasync_name(name) + right_quote
# Strings in typing context are forward-references and should be unasyncified
quote = ""
prefix = ""
while ord(tokval[0]) in range(ord("a"), ord("z") + 1):
prefix += tokval[0]
tokval = tokval[1:]

if tokval.startswith('"""') and tokval.endswith('"""'):
quote = '"""' # Broken syntax highlighters workaround: """
elif tokval.startswith("'''") and tokval.endswith("'''"):
quote = "'''" # Broken syntax highlighters wokraround: '''
elif tokval.startswith('"') and tokval.endswith('"'):
quote = '"'
elif tokval.startswith("'") and tokval.endswith(
"'"
): # pragma: no branch
quote = "'"
assert (
len(quote) > 0
), "Quoting style of string {0!r} unknown".format(tokval)
stringval = tokval[len(quote) : -len(quote)]
if typing_ctx:
stringval = _untokenize(
self._unasync_tokens(_tokenize(StringIO(stringval)))
)
else:
stringval = self._unasync_name(stringval)
tokval = prefix + quote + stringval + quote
elif toknum == std_tokenize.COMMENT and tokval.startswith(
_TYPE_COMMENT_PREFIX
):
Expand Down Expand Up @@ -193,7 +269,7 @@ def _tokenize(f):
# Somehow Python 3.5 and below produce the ENDMARKER in a way that
# causes superfluous continuation lines to be generated
if tok.type != std_tokenize.ENDMARKER:
yield ("", std_tokenize.STRING, " \\\n")
yield (" ", std_tokenize.NEWLINE, "\\\n")
last_end = (tok.start[0], 0)

space = ""
Expand Down
15 changes: 15 additions & 0 deletions tests/data/async/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ async def func2(a): # type: (typing.AsyncIterable[int]) -> str
return str(b)


# fmt: off
# A forward-reference typed function that returns an iterator for an (a)sync iterable
async def aiter1(a: "typing.AsyncIterable[int]") -> 'typing.AsyncIterable[int]':
return a.__aiter__()

# Same as the above but using tripple-quoted strings
async def aiter2(a: """typing.AsyncIterable[int]""") -> r'''typing.AsyncIterable[int]''':
return a.__aiter__()

# Same as the above but without forward-references
async def aiter3(a: typing.AsyncIterable[int]) -> typing.AsyncIterable[int]:
return a.__aiter__()
# fmt: on


# And some funky edge cases to at least cover the relevant at all in this test
a: int = 5
b: str = a # type: ignore # This is the actual comment and the type declaration silences the warning that would otherwise happen
Expand Down
15 changes: 15 additions & 0 deletions tests/data/sync/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,21 @@ def func2(a): # type: (typing.Iterable[int]) -> str
return str(b)


# fmt: off
# A forward-reference typed function that returns an iterator for an (a)sync iterable
def aiter1(a: "typing.Iterable[int]") -> 'typing.Iterable[int]':
return a.__iter__()

# Same as the above but using tripple-quoted strings
def aiter2(a: """typing.Iterable[int]""") -> r'''typing.Iterable[int]''':
return a.__iter__()

# Same as the above but without forward-references
def aiter3(a: typing.Iterable[int]) -> typing.Iterable[int]:
return a.__iter__()
# fmt: on


# And some funky edge cases to at least cover the relevant at all in this test
a: int = 5
b: str = a # type: ignore # This is the actual comment and the type declaration silences the warning that would otherwise happen
Expand Down
3 changes: 3 additions & 0 deletions tests/test_unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
SYNC_DIR = os.path.join(TEST_DIR, "sync")
TEST_FILES = sorted([f for f in os.listdir(ASYNC_DIR) if f.endswith(".py")])

if sys.version_info[0] == 2:
TEST_FILES.remove("typing.py")


def list_files(startpath):
output = ""
Expand Down

0 comments on commit 3a6ea0b

Please sign in to comment.