diff --git a/pandera/engines/pyspark_engine.py b/pandera/engines/pyspark_engine.py
index a9668402b..4e48d28da 100644
--- a/pandera/engines/pyspark_engine.py
+++ b/pandera/engines/pyspark_engine.py
@@ -14,6 +14,7 @@
 from typing import Any, Iterable, Union, Optional
 import sys
 
+import pyspark
 import pyspark.sql.types as pst
 
 from pandera import dtypes, errors
@@ -341,10 +342,15 @@ class Date(DataType, dtypes.Date):  # type: ignore
 # timestamp
 ###############################################################################
 
+# Default timestamp equivalents
+equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType]  # type: ignore
 
-@Engine.register_dtype(
-    equivalents=["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType],  # type: ignore
-)
+# Spark 3.4 added TimestampNTZType; compare numerically ("3.10" < "3.4" as str)
+if tuple(int(p) for p in pyspark.__version__.split(".")[:2]) >= (3, 4):
+    equivalents += ["TimestampNTZType", "TimestampNTZType()", pst.TimestampNTZType, pst.TimestampNTZType()]  # type: ignore
+
+
+@Engine.register_dtype(equivalents=equivalents)  # type: ignore
 @immutable
 class Timestamp(DataType, dtypes.Timestamp):  # type: ignore
     """Semantic representation of a :class:`pyspark.sql.types.TimestampType`."""
diff --git a/tests/pyspark/test_pyspark_dtypes.py b/tests/pyspark/test_pyspark_dtypes.py
index 0b00536cf..702ce4e5e 100644
--- a/tests/pyspark/test_pyspark_dtypes.py
+++ b/tests/pyspark/test_pyspark_dtypes.py
@@ -1,6 +1,7 @@
 """Unit tests for pyspark container."""
 
 from typing import Any
+import pyspark
 import pyspark.sql.types as T
 from pyspark.sql import DataFrame
 
@@ -239,6 +240,18 @@ def test_pyspark_all_bytetint_types(
 class TestAllDatetimeTestClass(BaseClass):
     """This class is to test all the datetime types"""
 
+    # Spark 3.4 added TimestampNTZType; gate on the numeric (major, minor)
+    ntz_equivalents = (
+        [
+            {"pandera_equivalent": "TimestampNTZType"},
+            {"pandera_equivalent": "TimestampNTZType()"},
+            {"pandera_equivalent": T.TimestampNTZType},
+            {"pandera_equivalent": T.TimestampNTZType()},
+        ]
+        if tuple(int(p) for p in pyspark.__version__.split(".")[:2]) >= (3, 4)
+        else []
+    )
+
     # a map specifying multiple argument sets for a test method
     params = {
         "test_pyspark_all_date_types": [
@@ -253,7 +266,8 @@ class TestAllDatetimeTestClass(BaseClass):
             {"pandera_equivalent": T.TimestampType()},
             {"pandera_equivalent": "datetime"},
             {"pandera_equivalent": "timestamp"},
-        ],
+        ]
+        + ntz_equivalents,
     }
 
     def test_pyspark_all_date_types(