From 32f8c41c2fa9d048b7f66ad3ae68445a097b4950 Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Wed, 18 Oct 2023 09:55:21 -0300 Subject: [PATCH 1/2] add TimestampNTZType as equivalents and add parameters to test case Signed-off-by: Filipe Oliveira --- pandera/engines/pyspark_engine.py | 12 +++++++++--- tests/pyspark/test_pyspark_dtypes.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pandera/engines/pyspark_engine.py b/pandera/engines/pyspark_engine.py index a9668402b..4e48d28da 100644 --- a/pandera/engines/pyspark_engine.py +++ b/pandera/engines/pyspark_engine.py @@ -14,6 +14,7 @@ from typing import Any, Iterable, Union, Optional import sys +import pyspark import pyspark.sql.types as pst from pandera import dtypes, errors @@ -341,10 +342,15 @@ class Date(DataType, dtypes.Date): # type: ignore # timestamp ############################################################################### +# Default timestamp equivalents +equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType] # type: ignore -@Engine.register_dtype( - equivalents=["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType], # type: ignore -) +# Include new Spark 3.4 TimestampNTZType as equivalents +if pyspark.__version__ >= "3.4": + equivalents += ["TimestampNTZType", "TimestampNTZType()", pst.TimestampNTZType, pst.TimestampNTZType()] # type: ignore + + +@Engine.register_dtype(equivalents=equivalents) # type: ignore @immutable class Timestamp(DataType, dtypes.Timestamp): # type: ignore """Semantic representation of a :class:`pyspark.sql.types.TimestampType`.""" diff --git a/tests/pyspark/test_pyspark_dtypes.py b/tests/pyspark/test_pyspark_dtypes.py index 0b00536cf..702ce4e5e 100644 --- a/tests/pyspark/test_pyspark_dtypes.py +++ b/tests/pyspark/test_pyspark_dtypes.py @@ -1,6 +1,7 @@ """Unit tests for pyspark container.""" from typing import Any +import pyspark import pyspark.sql.types as T from pyspark.sql import DataFrame @@ -239,6 +240,18 @@ def test_pyspark_all_bytetint_types( class TestAllDatetimeTestClass(BaseClass): """This class is to test all the datetime types""" + # Include new Spark 3.4 TimestampNTZType as equivalents + ntz_equivalents = ( + [ + {"pandera_equivalent": "TimestampNTZType"}, + {"pandera_equivalent": "TimestampNTZType()"}, + {"pandera_equivalent": T.TimestampNTZType}, + {"pandera_equivalent": T.TimestampNTZType()}, + ] + if pyspark.__version__ >= "3.4" + else [] + ) + # a map specifying multiple argument sets for a test method params = { "test_pyspark_all_date_types": [ @@ -253,7 +266,8 @@ class TestAllDatetimeTestClass(BaseClass): {"pandera_equivalent": T.TimestampType()}, {"pandera_equivalent": "datetime"}, {"pandera_equivalent": "timestamp"}, - ], + ] + + ntz_equivalents, } def test_pyspark_all_date_types( From 0dce486ff2743d472b6e9038d7469f6c482ce5cf Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Wed, 18 Oct 2023 16:30:04 -0300 Subject: [PATCH 2/2] parse version to improve robustness Signed-off-by: Filipe Oliveira --- pandera/engines/pyspark_engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pandera/engines/pyspark_engine.py b/pandera/engines/pyspark_engine.py index 4e48d28da..c682b1b63 100644 --- a/pandera/engines/pyspark_engine.py +++ b/pandera/engines/pyspark_engine.py @@ -13,6 +13,7 @@ import warnings from typing import Any, Iterable, Union, Optional import sys +from packaging import version import pyspark import pyspark.sql.types as pst @@ -33,6 +34,8 @@ DEFAULT_PYSPARK_PREC = pst.DecimalType().precision DEFAULT_PYSPARK_SCALE = pst.DecimalType().scale +CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__) + @immutable(init=True) class DataType(dtypes.DataType): @@ -346,7 +349,7 @@ class Date(DataType, dtypes.Date): # type: ignore equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType] # type: ignore # Include new Spark 3.4 TimestampNTZType as equivalents -if pyspark.__version__ >= "3.4": +if CURRENT_PYSPARK_VERSION >= version.parse("3.4"): equivalents += ["TimestampNTZType", "TimestampNTZType()", pst.TimestampNTZType, pst.TimestampNTZType()] # type: ignore