Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@check_types now properly passes in *args **kwargs and checks their types #1336

Merged
merged 6 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 86 additions & 12 deletions pandera/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,29 +692,103 @@ def _check_arg(arg_name: str, arg_value: Any) -> Any:

sig = inspect.signature(wrapped)

def validate_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
return {
arg_name: _check_arg(arg_name, arg_value)
for arg_name, arg_value in arguments.items()
}
def validate_args(
    # pylint: disable=unused-argument
    named_arguments: Dict[str, Any],
    arguments: Tuple[Any, ...],
) -> List[Any]:
    """
    Validates schemas of both explicit and *args-like function arguments.

    Each bound argument is looked up in the wrapped function's signature.
    Values bound to a variadic positional parameter (``*args``) are checked
    one by one against that parameter's annotation and flattened back into
    the result; every other argument is checked as a single value.

    The parameter kind is used instead of comparing ``len(arguments)`` with
    ``len(named_arguments)``. A call that passes exactly one variadic value,
    e.g. ``f(1, 2)`` for ``def f(a, *args)``, binds to two entries, so the
    lengths are equal and the ``(2,)`` tuple would be forwarded as a single
    positional argument.

    :param named_arguments: Bundled function arguments. Organized as key-value pairs of the
        argument name and value. *args-like arguments are bundled into a single tuple.
        Example: OrderedDict({'arg1': 1, 'arg2': 2, 'star_args': (3, 4, 5)})
        The mapping is not modified.
    :param arguments: Unpacked function arguments, as written in the function call.
        Example: (1, 2, 3, 4, 5)
        Kept for backward compatibility; it is no longer needed to detect ``*args``.
    :return: List of validated function arguments, in call order.
    """
    validated: List[Any] = []
    for arg_name, arg_value in named_arguments.items():
        if sig.parameters[arg_name].kind is inspect.Parameter.VAR_POSITIONAL:
            # Every element of *args shares the annotation of the *args parameter.
            validated.extend(
                _check_arg(arg_name, star_value) for star_value in arg_value
            )
        else:
            validated.append(_check_arg(arg_name, arg_value))
    return validated

def validate_kwargs(
    # pylint: disable=unused-argument
    named_kwargs: Dict[str, Any],
    kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Validates schemas of both explicit and **kwargs-like function arguments.

    Each bound keyword argument is looked up in the wrapped function's
    signature. Values bundled under a variadic keyword parameter
    (``**kwargs``) are checked one by one against that parameter's annotation
    and merged back under their own keys; every other argument is checked as
    a single value.

    The parameter kind is used instead of comparing the key sets of
    ``kwargs`` and ``named_kwargs``. The key sets are equal when the only
    extra keyword has the same name as the ``**kwargs`` parameter, e.g.
    ``f(kwargs=5)`` for ``def f(**kwargs)``, and the bundle would then be
    forwarded nested under that name.

    :param named_kwargs: Bundled function keyword arguments. Organized as key-value pairs of
        the keyword argument name and value. **kwargs-like arguments are bundled into a single
        dictionary.
        Example: OrderedDict({'kwarg1': 1, 'kwarg2': 2, 'star_kwargs': {'kwarg3': 3, 'kwarg4': 4}})
        The mapping is not modified.
    :param kwargs: Unpacked function keyword arguments, as written in the function call.
        Example: {'kwarg1': 1, 'kwarg2': 2, 'kwarg3': 3, 'kwarg4': 4}
        Kept for backward compatibility; it is no longer needed to detect ``**kwargs``.
    :return: Dictionary of validated function keyword arguments.
    """
    validated: Dict[str, Any] = {}
    for arg_name, arg_value in named_kwargs.items():
        if sig.parameters[arg_name].kind is inspect.Parameter.VAR_KEYWORD:
            # Every value in **kwargs shares the annotation of the **kwargs parameter.
            for star_key, star_value in arg_value.items():
                validated[star_key] = _check_arg(arg_name, star_value)
        else:
            validated[arg_name] = _check_arg(arg_name, arg_value)
    return validated

def validate_inputs(
instance: Optional[Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
) -> Tuple[List[Any], Dict[str, Any]]:
if instance is not None:
# If the wrapped function is a method -> add "self" as the first positional arg
args = (instance, *args)

validated_pos = validate_args(sig.bind_partial(*args).arguments)
validated_kwd = validate_args(sig.bind_partial(**kwargs).arguments)
validated_pos = validate_args(sig.bind_partial(*args).arguments, args)
validated_kwd = validate_kwargs(
sig.bind_partial(**kwargs).arguments, kwargs
)

if instance is not None:
# If the decorated func is a method, "wrapped" is a bound method
# -> remove "self" before passing positional args through
first_pos_arg = list(sig.parameters)[0]
del validated_pos[first_pos_arg]
del validated_pos[0]

return validated_pos, validated_kwd

Expand All @@ -733,7 +807,7 @@ async def _wrapper(
validated_pos, validated_kwd = validate_inputs(
instance, args, kwargs
)
out = await wrapped_(*validated_pos.values(), **validated_kwd)
out = await wrapped_(*validated_pos, **validated_kwd)
return _check_arg("return", out)

else:
Expand All @@ -751,7 +825,7 @@ def _wrapper(
validated_pos, validated_kwd = validate_inputs(
instance, args, kwargs
)
out = wrapped_(*validated_pos.values(), **validated_kwd)
out = wrapped_(*validated_pos, **validated_kwd)
return _check_arg("return", out)

wrapped_fn = _wrapper(wrapped) # pylint:disable=no-value-for-parameter
Expand Down
116 changes: 116 additions & 0 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,122 @@ def union_df_int_types_pydantic_check(
assert isinstance(str_val_pydantic, int)


def test_check_types_star_args() -> None:
    """Check that check_types validates values passed through *args."""

    @check_types
    def get_len_star_args__int(
        # pylint: disable=unused-argument
        arg1: int,
        *args: int,
    ) -> int:
        return len(args)

    @check_types
    def get_len_star_args__dataframe(
        # pylint: disable=unused-argument
        arg1: DataFrame[InSchema],
        *args: DataFrame[InSchema],
    ) -> int:
        return len(args)

    def build_frame(column: str) -> pd.DataFrame:
        # Single-row frame whose only column is named `column`.
        return pd.DataFrame({column: [1]}, index=["1"])

    # Three separate frames with column "a", plus one frame that only has
    # column "b".
    valid_1, valid_2, valid_3 = (build_frame("a") for _ in range(3))
    invalid = build_frame("b")

    assert get_len_star_args__int(1, 2, 3) == 2
    assert get_len_star_args__dataframe(valid_1, valid_2) == 1
    assert get_len_star_args__dataframe(valid_1, valid_2, valid_3) == 2

    # A frame without column "a" inside *args must raise a SchemaError.
    with pytest.raises(
        errors.SchemaError, match="column 'a' not in dataframe"
    ):
        get_len_star_args__dataframe(valid_1, valid_2, invalid)


def test_check_types_star_kwargs() -> None:
    """Check that check_types validates values passed through **kwargs."""

    @check_types
    def get_star_kwargs_keys_int(
        # pylint: disable=unused-argument
        kwarg1: int = 1,
        **kwargs: int,
    ) -> typing.List[str]:
        return list(kwargs.keys())

    @check_types
    def get_star_kwargs_keys_dataframe(
        # pylint: disable=unused-argument
        kwarg1: DataFrame[InSchema] = None,
        **kwargs: DataFrame[InSchema],
    ) -> typing.List[str]:
        return list(kwargs.keys())

    def build_frame(column: str) -> pd.DataFrame:
        # Single-row frame whose only column is named `column`.
        return pd.DataFrame({column: [1]}, index=["1"])

    # Three separate frames with column "a", plus one frame that only has
    # column "b".
    valid_1, valid_2, valid_3 = (build_frame("a") for _ in range(3))
    invalid = build_frame("b")

    keys_from_ints = get_star_kwargs_keys_int(kwarg1=1, kwarg2=2, kwarg3=3)
    keys_from_two_frames = get_star_kwargs_keys_dataframe(
        kwarg1=valid_1,
        kwarg2=valid_2,
    )
    keys_from_three_frames = get_star_kwargs_keys_dataframe(
        kwarg1=valid_1, kwarg2=valid_2, kwarg3=valid_3
    )

    # Only the keywords not named in the signature end up in **kwargs.
    assert keys_from_ints == ["kwarg2", "kwarg3"]
    assert keys_from_two_frames == ["kwarg2"]
    assert keys_from_three_frames == ["kwarg2", "kwarg3"]

    # A frame without column "a" inside **kwargs must raise a SchemaError.
    with pytest.raises(
        errors.SchemaError, match="column 'a' not in dataframe"
    ):
        get_star_kwargs_keys_dataframe(
            kwarg1=valid_1, kwarg2=valid_2, kwarg3=invalid
        )


def test_check_types_star_args_kwargs() -> None:
    """Test check_types for functions with both *args and **kwargs.

    The decorated function echoes what it received, so the test can check
    that explicit, *args, keyword-only and **kwargs values each arrive in
    the right slot.
    """

    @check_types
    def star_args_kwargs(
        arg1: DataFrame[InSchema],
        *args: DataFrame[InSchema],
        kwarg1: DataFrame[InSchema],
        **kwargs: DataFrame[InSchema],
    ):
        return arg1, args, kwarg1, kwargs

    in_1 = pd.DataFrame({"a": [1]}, index=["1"])
    in_2 = pd.DataFrame({"a": [1]}, index=["1"])
    in_3 = pd.DataFrame({"a": [1]}, index=["1"])

    expected_arg = in_1
    expected_star_args = (in_2, in_3)
    expected_kwarg = in_1
    expected_star_kwargs = {"kwarg2": in_2, "kwarg3": in_3}

    arg, star_args, kwarg, star_kwargs = star_args_kwargs(
        in_1, in_2, in_3, kwarg1=in_1, kwarg2=in_2, kwarg3=in_3
    )

    pd.testing.assert_frame_equal(expected_arg, arg)
    pd.testing.assert_frame_equal(expected_kwarg, kwarg)

    # zip() stops at the shorter input, so check the element count first;
    # otherwise a dropped *args value would go unnoticed.
    assert len(star_args) == len(expected_star_args)
    for expected, actual in zip(expected_star_args, star_args):
        pd.testing.assert_frame_equal(expected, actual)

    # Compare by key, not by position, so a value arriving under the wrong
    # keyword name is caught as well.
    assert list(star_kwargs) == list(expected_star_kwargs)
    for key, expected in expected_star_kwargs.items():
        pd.testing.assert_frame_equal(expected, star_kwargs[key])


def test_coroutines(event_loop: AbstractEventLoop) -> None:
# pylint: disable=missing-class-docstring,too-few-public-methods,missing-function-docstring
class Schema(DataFrameModel):
Expand Down
Loading