diff --git a/pandera/mypy.py b/pandera/mypy.py index fc044edc0..613e8283b 100644 --- a/pandera/mypy.py +++ b/pandera/mypy.py @@ -133,7 +133,13 @@ class Schema(BaseSchema): x: pa.typing.Series[str] # mypy assignment error, cannot override types """ for def_ in self.ctx.cls.defs.body: - if not hasattr(def_, "type") or def_.type is None: + if ( + not hasattr(def_, "type") + or def_.type is None + # e.g. UnionType does not have module_name or name + or not hasattr(def_.type, "module_name") + or not hasattr(def_.type, "name") + ): continue type_ = def_.type if str(def_.type) in FIELD_GENERICS_FULLNAMES: diff --git a/tests/mypy/modules/pandas_dataframe.py b/tests/mypy/modules/pandas_dataframe.py index 533a344a0..e82a1e8fd 100644 --- a/tests/mypy/modules/pandas_dataframe.py +++ b/tests/mypy/modules/pandas_dataframe.py @@ -5,7 +5,7 @@ run statically check the functions marked pytest.mark.mypy_testing """ -from typing import cast +from typing import Optional, cast import pandas as pd @@ -24,7 +24,7 @@ class SchemaOut(pa.DataFrameModel): class AnotherSchema(pa.DataFrameModel): id: Series[int] - first_name: Series[str] + first_name: Optional[Series[str]] def fn(df: DataFrame[Schema]) -> DataFrame[SchemaOut]: diff --git a/tests/mypy/test_static_type_checking.py b/tests/mypy/test_static_type_checking.py index 7368bf7ce..0d82405ee 100644 --- a/tests/mypy/test_static_type_checking.py +++ b/tests/mypy/test_static_type_checking.py @@ -54,7 +54,11 @@ def _get_mypy_errors( ] -def test_mypy_pandas_dataframe(capfd) -> None: +@pytest.mark.parametrize( + ["config_file", "expected_errors"], + [("no_plugin.ini", PANDAS_DATAFRAME_ERRORS), ("plugin_mypy.ini", [])], +) +def test_mypy_pandas_dataframe(capfd, config_file, expected_errors) -> None: """Test that mypy raises expected errors on pandera-decorated functions.""" # pylint: disable=subprocess-run-check pytest.xfail( @@ -71,13 +75,13 @@ def test_mypy_pandas_dataframe(capfd) -> None: "--cache-dir", cache_dir, "--config-file", - str(test_module_dir / "config" / "no_plugin.ini"), + str(test_module_dir / "config" / config_file), ], text=True, ) errors = _get_mypy_errors("pandas_dataframe.py", capfd.readouterr().out) - assert len(PANDAS_DATAFRAME_ERRORS) == len(errors) - for expected, error in zip(PANDAS_DATAFRAME_ERRORS, errors): + assert len(expected_errors) == len(errors) + for expected, error in zip(expected_errors, errors): assert error["errcode"] == expected["errcode"] assert expected["msg"] == error["msg"] or re.match( expected["msg"], error["msg"]