support pydantic v2 (#1253)
* support pydantic v2

* fix tests for py 3.11

* fix tests for py 3.7

* bump cache

* fix fastapi bug

* fix pydantic test

* dont run fastapi ci for py3.7

* debug

* debugging

* debug

* skip fastapi tests with pydantic > v2

* fix BaseSettings

* dont check docstrings for pydantic v2

* [wip] need to figure out how to replace use of ModelField in the fastapi.UploadFile type

* fix fastapi

Signed-off-by: Niels Bantilan <[email protected]>

* update dependencies

Signed-off-by: Niels Bantilan <[email protected]>

* mypy

Signed-off-by: Niels Bantilan <[email protected]>

* ignore modin-ray in pydantic v2

Signed-off-by: Niels Bantilan <[email protected]>

* update

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update pydantic version

* update ci

---------

Signed-off-by: Niels Bantilan <[email protected]>
cosmicBboy authored Aug 26, 2023
1 parent d71d890 commit 850dcf8
Showing 20 changed files with 334 additions and 146 deletions.
31 changes: 21 additions & 10 deletions .github/workflows/ci-tests.yml
@@ -86,7 +86,7 @@ jobs:

tests:
name: >
CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }})
CI Tests (${{ matrix.python-version }}, ${{ matrix.os }}, pandas-${{ matrix.pandas-version }}, pydantic-${{ matrix.pydantic-version }})
runs-on: ${{ matrix.os }}
defaults:
run:
@@ -101,10 +101,11 @@ jobs:
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
pandas-version: ["1.3.0", "1.5.2", "2.0.1"]
pandas-version: ["1.5.3", "2.0.3"]
pydantic-version: ["1.10.11", "2.3.0"]
exclude:
- python-version: "3.7"
pandas-version: "2.0.1"
pandas-version: "2.0.3"
- python-version: "3.7"
pandas-version: "1.5.2"
- python-version: "3.10"
@@ -163,19 +164,26 @@ jobs:

# need to install pandas via pip: conda installation is on the fritz
- name: Install Conda Deps [pandas 2]
if: ${{ matrix.pandas-version == '2.0.1' }}
if: ${{ matrix.pandas-version == '2.0.3' }}
run: |
mamba install -c conda-forge asv pandas geopandas bokeh
mamba env update -n pandera-dev -f environment.yml
pip install pandas==${{ matrix.pandas-version }}
pip install --user dask>=2023.3.2
- name: Install Conda Deps
if: ${{ matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.pandas-version != '2.0.3' }}
run: |
mamba install -c conda-forge asv pandas==${{ matrix.pandas-version }} geopandas bokeh
mamba env update -n pandera-dev -f environment.yml
- name: Install Pydantic Deps
run: pip install -U --upgrade-strategy only-if-needed pydantic==${{ matrix.pydantic-version }}

- name: Install Pydantic v2 Deps
if : ${{ matrix.pydantic-version == '2.3.0' }}
run: pip install fastapi>=0.100.0

- run: |
conda info
conda list
@@ -200,21 +208,24 @@ jobs:
run: pytest tests/strategies ${{ env.PYTEST_FLAGS }} ${{ env.HYPOTHESIS_FLAGS }}

- name: Unit Tests - FastAPI
# there's an issue with the fastapi tests in CI that's not reproducible locally
# when pydantic > v2
if: ${{ matrix.python-version != '3.7' }}
run: pytest tests/fastapi ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - GeoPandas
run: pytest tests/geopandas ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Dask
if: ${{ matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.pandas-version != '2.0.3' }}
run: pytest tests/dask ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Pyspark
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
run: pytest tests/pyspark ${{ env.PYTEST_FLAGS }}

- name: Unit Tests - Modin-Dask
if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.1' }}
if: ${{ !contains(fromJson('["3.11"]'), matrix.python-version) && matrix.pandas-version != '2.0.3' }}
run: pytest tests/modin ${{ env.PYTEST_FLAGS }}
env:
CI_MODIN_ENGINES: dask
@@ -233,9 +244,9 @@ jobs:
uses: codecov/codecov-action@v3

- name: Check Docstrings
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
run: nox ${{ env.NOX_FLAGS }} --session doctests

- name: Check Docs
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) }}
if: ${{ matrix.os != 'windows-latest' && !contains(fromJson('["3.7", "3.10", "3.11"]'), matrix.python-version) && matrix.pydantic-version != '2.0.2' }}
run: nox ${{ env.NOX_FLAGS }} --session docs
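
The fastapi job above is gated on python-version only; the commit-message note about skipping fastapi tests under pydantic v2 would more typically be expressed inside the test module with a version-gated pytest marker. A sketch only, not code from this diff (PYDANTIC_V2 is the flag added in pandera/engines/__init__.py further below):

import pytest

from pandera.engines import PYDANTIC_V2

# module-level marker: skip every test in this module when pydantic v2 is installed
pytestmark = pytest.mark.skipif(
    PYDANTIC_V2,
    reason="fastapi tests are flaky in CI under pydantic v2",
)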
3 changes: 2 additions & 1 deletion .pylintrc
@@ -47,4 +47,5 @@ disable=
arguments-differ,
unnecessary-dunder-call,
use-dict-literal,
invalid-name
invalid-name,
import-outside-toplevel
2 changes: 1 addition & 1 deletion docs/source/dtype_validation.rst
@@ -194,7 +194,7 @@ For example:

from typing import Dict, List, Tuple, NamedTuple

if sys.version_info >= (3, 9):
if sys.version_info >= (3, 12):
from typing import TypedDict
# use typing_extensions.TypedDict for python < 3.9 in order to support
# run-time availability of optional/required fields
2 changes: 1 addition & 1 deletion environment.yml
@@ -18,7 +18,7 @@ dependencies:
- typing_extensions >= 3.7.4.3
- frictionless <= 4.40.8 # v5.* introduces breaking changes
- pyarrow
- pydantic < 2.0.0
- pydantic
- multimethod

# mypy extra
25 changes: 21 additions & 4 deletions pandera/api/pandas/array.py
@@ -13,7 +13,12 @@
from pandera.api.hypotheses import Hypothesis
from pandera.api.pandas.types import CheckList, PandasDtypeInputTypes, is_field
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine
from pandera.engines import pandas_engine, PYDANTIC_V2

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler


TArraySchemaBase = TypeVar("TArraySchemaBase", bound="ArraySchema")

@@ -203,9 +208,21 @@ def __call__(
def __eq__(self, other):
return self.__dict__ == other.__dict__

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate
if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls._pydantic_validate, # type: ignore[misc]
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _pydantic_validate( # type: ignore
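
For context, here is a minimal, self-contained sketch of the v2 hook this diff adopts (the class and field names are illustrative, not part of pandera): any type that defines __get_pydantic_core_schema__ can be used as a pydantic field annotation, and core_schema.no_info_plain_validator_function routes validation through a single callable, replacing the v1 __get_validators__ generator.

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class PositiveInt:
    """Toy validated type: accepts only positive integers."""

    @classmethod
    def _validate(cls, value: Any) -> int:
        # plain validator: receives the raw field value and returns the validated one
        if not isinstance(value, int) or value <= 0:
            raise ValueError("expected a positive integer")
        return value

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # same shape as the ArraySchema hook above
        return core_schema.no_info_plain_validator_function(cls._validate)


class Model(BaseModel):
    count: PositiveInt


Model(count=3)   # ok
Model(count=-1)  # raises pydantic.ValidationError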
24 changes: 20 additions & 4 deletions pandera/api/pandas/container.py
@@ -21,7 +21,11 @@
StrictType,
)
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine
from pandera.engines import pandas_engine, PYDANTIC_V2

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetCoreSchemaHandler

N_INDENT_SPACES = 4

@@ -517,9 +521,21 @@ def _compare_dict(obj):

return _compare_dict(self) == _compare_dict(other)

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate
if PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls._pydantic_validate,
)

else:

@classmethod
def __get_validators__(cls):
yield cls._pydantic_validate

@classmethod
def _pydantic_validate(cls, schema: Any) -> "DataFrameSchema":
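
A usage sketch of what this container-level hook enables under both pydantic majors (the wrapper model below is hypothetical, not from this commit): a DataFrameSchema can sit directly on a pydantic model as a field, with _pydantic_validate (truncated above) expected to check that the supplied value really is a schema.

import pandera as pa
from pydantic import BaseModel


class PipelineConfig(BaseModel):
    # hypothetical wrapper model used only for illustration
    input_schema: pa.DataFrameSchema


config = PipelineConfig(
    input_schema=pa.DataFrameSchema({"price": pa.Column(float, pa.Check.ge(0))})
)
assert isinstance(config.input_schema, pa.DataFrameSchema)  # round-trips unchanged
PipelineConfig(input_schema="not a schema")  # raises a pydantic ValidationError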
73 changes: 47 additions & 26 deletions pandera/api/pandas/model.py
@@ -37,23 +37,21 @@
FieldInfo,
)
from pandera.api.pandas.model_config import BaseConfig
from pandera.engines import PYDANTIC_V2
from pandera.errors import SchemaInitError
from pandera.strategies import pandas_strategies as st
from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo
from pandera.typing.common import DataFrameBase

if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetJsonSchemaHandler, GetCoreSchemaHandler

try:
from typing_extensions import get_type_hints
except ImportError: # pragma: no cover
from typing import get_type_hints # type: ignore

try:
from pydantic.fields import ModelField # pylint:disable=unused-import

HAS_PYDANTIC = True
except ImportError:
HAS_PYDANTIC = False


SchemaIndex = Union[Index, MultiIndex]

@@ -538,8 +536,19 @@ def _extract_df_checks(cls, check_infos: List[CheckInfo]) -> List[Check]:
return [check_info.to_check(cls) for check_info in check_infos]

@classmethod
def __get_validators__(cls):
yield cls.pydantic_validate
def get_metadata(cls) -> Optional[dict]:
"""Provide metadata for columns and schema level"""
res: Dict[Any, Any] = {"columns": {}}
columns = cls._collect_fields()

for k, (_, v) in columns.items():
res["columns"][k] = v.properties["metadata"]

res["dataframe"] = cls.Config.metadata

meta = {}
meta[cls.Config.name] = res
return meta

@classmethod
def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":
@@ -562,25 +571,37 @@ def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":

return cast("DataFrameModel", schema_model)

@classmethod
def get_metadata(cls) -> Optional[dict]:
"""Provide metadata for columns and schema level"""
res: Dict[Any, Any] = {"columns": {}}
columns = cls._collect_fields()

for k, (_, v) in columns.items():
res["columns"][k] = v.properties["metadata"]
if PYDANTIC_V2:

res["dataframe"] = cls.Config.metadata

meta = {}
meta[cls.Config.name] = res
return meta
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.pydantic_validate,
)

@classmethod
def __modify_schema__(cls, field_schema):
"""Update pydantic field schema."""
field_schema.update(_to_json_schema(cls.to_schema()))
@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema,
_handler: GetJsonSchemaHandler,
):
"""Update pydantic field schema."""
json_schema = _handler(_core_schema)
json_schema = _handler.resolve_ref_schema(json_schema)
json_schema.update(_to_json_schema(cls.to_schema()))

else:

@classmethod
def __modify_schema__(cls, field_schema):
"""Update pydantic field schema."""
field_schema.update(_to_json_schema(cls.to_schema()))

@classmethod
def __get_validators__(cls):
yield cls.pydantic_validate


SchemaModel = DataFrameModel
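
The JSON-schema side changes in the same version-dispatched way: __get_pydantic_json_schema__ replaces the v1 __modify_schema__ hook. Below is a generic, runnable sketch of that pattern under pydantic v2; the names are illustrative, and pandera instead injects the output of _to_json_schema(cls.to_schema()).

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import core_schema


class TrimmedStr:
    """Toy annotated type whose JSON schema carries extra metadata."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # validate as a string, then strip surrounding whitespace
        return core_schema.no_info_after_validator_function(
            str.strip, core_schema.str_schema()
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, _handler: GetJsonSchemaHandler
    ):
        # v2 replacement for __modify_schema__: ask the handler for the base
        # schema, then update it in place
        json_schema = _handler(_core_schema)
        json_schema.update({"x-example": "extra metadata"})
        return json_schema


class Model(BaseModel):
    tag: TrimmedStr


# the field's JSON schema now includes {"type": "string", "x-example": "extra metadata"}
print(Model.model_json_schema()["properties"]["tag"])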
7 changes: 0 additions & 7 deletions pandera/api/pyspark/model.py
@@ -45,13 +45,6 @@
except ImportError: # pragma: no cover
from typing import get_type_hints # type: ignore

try:
from pydantic.fields import ModelField # pylint:disable=unused-import

HAS_PYDANTIC = True
except ImportError: # pragma: no cover
HAS_PYDANTIC = False


_CONFIG_KEY = "Config"
MODEL_CACHE: Dict[Type["DataFrameModel"], DataFrameSchema] = {}
21 changes: 13 additions & 8 deletions pandera/config.py
@@ -1,7 +1,9 @@
"""Pandera configuration."""

import os
from enum import Enum
from pydantic import BaseSettings

from pydantic import BaseModel


class ValidationDepth(Enum):
@@ -12,7 +14,7 @@ class ValidationDepth(Enum):
SCHEMA_AND_DATA = "SCHEMA_AND_DATA"


class PanderaConfig(BaseSettings):
class PanderaConfig(BaseModel):
"""Pandera config base class.
This should pick up environment variables automatically, e.g.:
@@ -23,11 +25,14 @@ class PanderaConfig(BaseSettings):
validation_enabled: bool = True
validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA

class Config:
"""Pydantic configuration settings."""

env_prefix = "pandera_"


# this config variable should be accessible globally
CONFIG = PanderaConfig()
CONFIG = PanderaConfig(
validation_enabled=os.environ.get(
"PANDERA_VALIDATION_ENABLED",
True,
),
validation_depth=os.environ.get(
"PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
),
)
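
With BaseSettings gone (it moved to the separate pydantic-settings package in v2), the environment override now happens explicitly at import time rather than through pydantic. A usage sketch, assuming the variables are set before pandera.config is first imported:

import os

# must be set before pandera.config is imported, since CONFIG is built at import time
os.environ["PANDERA_VALIDATION_ENABLED"] = "False"
os.environ["PANDERA_VALIDATION_DEPTH"] = "SCHEMA_AND_DATA"

from pandera.config import CONFIG

print(CONFIG.validation_enabled)  # False -- pydantic coerces the "False" string to a bool
print(CONFIG.validation_depth)    # ValidationDepth.SCHEMA_AND_DATA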
6 changes: 6 additions & 0 deletions pandera/engines/__init__.py
@@ -0,0 +1,6 @@
"""Pandera type engines."""

from pandera.engines.utils import pydantic_version


PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)
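
The helper behind this flag lives in pandera.engines.utils, which is not shown in this diff; one plausible implementation (an assumption for illustration, the real utility may differ) simply parses pydantic.VERSION with packaging:

import pydantic
from packaging.version import Version, parse


def pydantic_version() -> Version:
    """Return the installed pydantic version as a packaging.version.Version."""
    # pydantic exposes its version string as pydantic.VERSION in both v1 and v2
    return parse(pydantic.VERSION)


PYDANTIC_V2 = pydantic_version().release >= (2, 0, 0)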
2 changes: 1 addition & 1 deletion pandera/engines/engine.py
@@ -27,7 +27,7 @@


# register different TypedDict type depending on python version
if sys.version_info >= (3, 9):
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict # noqa