diff --git a/pandera/engines/pyspark_engine.py b/pandera/engines/pyspark_engine.py index a9668402b..c682b1b63 100644 --- a/pandera/engines/pyspark_engine.py +++ b/pandera/engines/pyspark_engine.py @@ -13,7 +13,9 @@ import warnings from typing import Any, Iterable, Union, Optional import sys +from packaging import version +import pyspark import pyspark.sql.types as pst from pandera import dtypes, errors @@ -32,6 +34,8 @@ DEFAULT_PYSPARK_PREC = pst.DecimalType().precision DEFAULT_PYSPARK_SCALE = pst.DecimalType().scale +CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__) + @immutable(init=True) class DataType(dtypes.DataType): @@ -341,10 +345,15 @@ class Date(DataType, dtypes.Date): # type: ignore # timestamp ############################################################################### +# Default timestamp equivalents +equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType] # type: ignore -@Engine.register_dtype( - equivalents=["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType], # type: ignore -) +# Include new Spark 3.4 TimestampNTZType as equivalents +if CURRENT_PYSPARK_VERSION >= version.parse("3.4"): + equivalents += ["TimestampNTZType", "TimestampNTZType()", pst.TimestampNTZType, pst.TimestampNTZType()] # type: ignore + + +@Engine.register_dtype(equivalents=equivalents) # type: ignore @immutable class Timestamp(DataType, dtypes.Timestamp): # type: ignore """Semantic representation of a :class:`pyspark.sql.types.TimestampType`.""" diff --git a/tests/pyspark/test_pyspark_dtypes.py b/tests/pyspark/test_pyspark_dtypes.py index 0b00536cf..702ce4e5e 100644 --- a/tests/pyspark/test_pyspark_dtypes.py +++ b/tests/pyspark/test_pyspark_dtypes.py @@ -1,6 +1,7 @@ """Unit tests for pyspark container.""" from typing import Any +import pyspark import pyspark.sql.types as T from pyspark.sql import DataFrame @@ -239,6 +240,18 @@ def test_pyspark_all_bytetint_types( class TestAllDatetimeTestClass(BaseClass): """This class is to test all the datetime types""" + # Include new Spark 3.4 TimestampNTZType as equivalents + ntz_equivalents = ( + [ + {"pandera_equivalent": "TimestampNTZType"}, + {"pandera_equivalent": "TimestampNTZType()"}, + {"pandera_equivalent": T.TimestampNTZType}, + {"pandera_equivalent": T.TimestampNTZType()}, + ] + if pyspark.__version__ >= "3.4" + else [] + ) + # a map specifying multiple argument sets for a test method params = { "test_pyspark_all_date_types": [ @@ -253,7 +266,8 @@ class TestAllDatetimeTestClass(BaseClass): {"pandera_equivalent": T.TimestampType()}, {"pandera_equivalent": "datetime"}, {"pandera_equivalent": "timestamp"}, - ], + ] + + ntz_equivalents, } def test_pyspark_all_date_types(