diff --git a/src/unasync/__init__.py b/src/unasync/__init__.py
index 52af030..17c4259 100644
--- a/src/unasync/__init__.py
+++ b/src/unasync/__init__.py
@@ -100,7 +100,57 @@ def _unasync_file(self, filepath):
     def _unasync_tokens(self, tokens):
         # TODO __await__, ...?
         used_space = None
+        context = None  # Can be `None`, `"func_decl"`, `"func_name"`, `"arg_list"`, `"arg_list_end"`, `"return_type"`
+        brace_depth = 0
+        typing_ctx = False
+
         for space, toknum, tokval in tokens:
+            # Update context state tracker
+            if context is None and toknum == std_tokenize.NAME and tokval == "def":
+                context = "func_decl"
+            elif context == "func_decl" and toknum == std_tokenize.NAME:
+                context = "func_name"
+            elif context == "func_name" and toknum == std_tokenize.OP and tokval == "(":
+                context = "arg_list"
+            elif context == "arg_list":
+                if toknum == std_tokenize.OP and tokval in ("(", "["):
+                    brace_depth += 1
+                elif (
+                    toknum == std_tokenize.OP
+                    and tokval in (")", "]")
+                    and brace_depth >= 1
+                ):
+                    brace_depth -= 1
+                elif toknum == std_tokenize.OP and tokval == ")":
+                    context = "arg_list_end"
+                elif toknum == std_tokenize.OP and tokval == ":" and brace_depth < 1:
+                    typing_ctx = True
+                elif toknum == std_tokenize.OP and tokval == "," and brace_depth < 1:
+                    typing_ctx = False
+            elif (
+                context == "arg_list_end"
+                and toknum == std_tokenize.OP
+                and tokval == "->"
+            ):
+                context = "return_type"
+                typing_ctx = True
+            elif context == "return_type":
+                if toknum == std_tokenize.OP and tokval in ("(", "["):
+                    brace_depth += 1
+                elif (
+                    toknum == std_tokenize.OP
+                    and tokval in (")", "]")
+                    and brace_depth >= 1
+                ):
+                    brace_depth -= 1
+                elif toknum == std_tokenize.OP and tokval == ":":
+                    context = None
+                    typing_ctx = False
+            else:  # Something unexpected happened - reset state
+                context = None
+                brace_depth = 0
+                typing_ctx = False
+
             if tokval in ["async", "await"]:
                 # When removing async or await, we want to use the whitespace that
                 # was before async/await before the next token so that
@@ -111,8 +161,34 @@ def _unasync_tokens(self, tokens):
                 if toknum == std_tokenize.NAME:
                     tokval = self._unasync_name(tokval)
                 elif toknum == std_tokenize.STRING:
-                    left_quote, name, right_quote = tokval[0], tokval[1:-1], tokval[-1]
-                    tokval = left_quote + self._unasync_name(name) + right_quote
+                    # Strings in typing context are forward-references and should be unasyncified
+                    quote = ""
+                    prefix = ""
+                    while ord(tokval[0]) in range(ord("a"), ord("z") + 1):
+                        prefix += tokval[0]
+                        tokval = tokval[1:]
+
+                    if tokval.startswith('"""') and tokval.endswith('"""'):
+                        quote = '"""'  # Broken syntax highlighters workaround: """
+                    elif tokval.startswith("'''") and tokval.endswith("'''"):
+                        quote = "'''"  # Broken syntax highlighters workaround: '''
+                    elif tokval.startswith('"') and tokval.endswith('"'):
+                        quote = '"'
+                    elif tokval.startswith("'") and tokval.endswith(
+                        "'"
+                    ):  # pragma: no branch
+                        quote = "'"
+                    assert (
+                        len(quote) > 0
+                    ), "Quoting style of string {0!r} unknown".format(tokval)
+                    stringval = tokval[len(quote) : -len(quote)]
+                    if typing_ctx:
+                        stringval = _untokenize(
+                            self._unasync_tokens(_tokenize(StringIO(stringval)))
+                        )
+                    else:
+                        stringval = self._unasync_name(stringval)
+                    tokval = prefix + quote + stringval + quote
                 elif toknum == std_tokenize.COMMENT and tokval.startswith(
                     _TYPE_COMMENT_PREFIX
                 ):
@@ -193,7 +269,7 @@ def _tokenize(f):
             # Somehow Python 3.5 and below produce the ENDMARKER in a way that
             # causes superfluous continuation lines to be generated
             if tok.type != std_tokenize.ENDMARKER:
-                yield ("", std_tokenize.STRING, " \\\n")
+                yield (" ", std_tokenize.NEWLINE, "\\\n")
                 last_end = (tok.start[0], 0)
 
         space = ""
diff --git a/tests/data/async/typing.py b/tests/data/async/typing.py
index 64bcfb6..b5a0e83 100644
--- a/tests/data/async/typing.py
+++ b/tests/data/async/typing.py
@@ -18,6 +18,21 @@ async def func2(a):
     # type: (typing.AsyncIterable[int]) -> str
     return str(b)
 
+# fmt: off
+# A forward-reference typed function that returns an iterator for an (a)sync iterable
+async def aiter1(a: "typing.AsyncIterable[int]") -> 'typing.AsyncIterable[int]':
+    return a.__aiter__()
+
+# Same as the above but using triple-quoted strings
+async def aiter2(a: """typing.AsyncIterable[int]""") -> r'''typing.AsyncIterable[int]''':
+    return a.__aiter__()
+
+# Same as the above but without forward-references
+async def aiter3(a: typing.AsyncIterable[int]) -> typing.AsyncIterable[int]:
+    return a.__aiter__()
+# fmt: on
+
+
 # And some funky edge cases to at least cover the relevant at all in this test
 a: int = 5
 b: str = a  # type: ignore  # This is the actual comment and the type declaration silences the warning that would otherwise happen
diff --git a/tests/data/sync/typing.py b/tests/data/sync/typing.py
index 213b048..11eedc7 100644
--- a/tests/data/sync/typing.py
+++ b/tests/data/sync/typing.py
@@ -18,6 +18,21 @@ def func2(a):
     # type: (typing.Iterable[int]) -> str
     return str(b)
 
+# fmt: off
+# A forward-reference typed function that returns an iterator for an (a)sync iterable
+def aiter1(a: "typing.Iterable[int]") -> 'typing.Iterable[int]':
+    return a.__iter__()
+
+# Same as the above but using triple-quoted strings
+def aiter2(a: """typing.Iterable[int]""") -> r'''typing.Iterable[int]''':
+    return a.__iter__()
+
+# Same as the above but without forward-references
+def aiter3(a: typing.Iterable[int]) -> typing.Iterable[int]:
+    return a.__iter__()
+# fmt: on
+
+
 # And some funky edge cases to at least cover the relevant at all in this test
 a: int = 5
 b: str = a  # type: ignore  # This is the actual comment and the type declaration silences the warning that would otherwise happen
diff --git a/tests/test_unasync.py b/tests/test_unasync.py
index f4b08eb..f8c4ed7 100644
--- a/tests/test_unasync.py
+++ b/tests/test_unasync.py
@@ -15,6 +15,9 @@
 SYNC_DIR = os.path.join(TEST_DIR, "sync")
 TEST_FILES = sorted([f for f in os.listdir(ASYNC_DIR) if f.endswith(".py")])
 
+if sys.version_info[0] == 2:
+    TEST_FILES.remove("typing.py")
+
 
 def list_files(startpath):
     output = ""