Skip to content

Commit

Permalink
Fix lack of support for new TimestampNTZType in Spark 3.4 datatypes (#…
Browse files Browse the repository at this point in the history
…1385)

* add TimestampNTZType as equivalents and add parameters to test case

Signed-off-by: Filipe Oliveira <[email protected]>

* parse version to improve robustness

Signed-off-by: Filipe Oliveira <[email protected]>

---------

Signed-off-by: Filipe Oliveira <[email protected]>
  • Loading branch information
filipeo2-mck authored Oct 19, 2023
1 parent 4425ad8 commit 0c48778
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
15 changes: 12 additions & 3 deletions pandera/engines/pyspark_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import warnings
from typing import Any, Iterable, Union, Optional
import sys
from packaging import version

import pyspark
import pyspark.sql.types as pst

from pandera import dtypes, errors
Expand All @@ -32,6 +34,8 @@
# Defaults read from a no-argument ``DecimalType`` instance.
DEFAULT_PYSPARK_PREC = pst.DecimalType().precision
DEFAULT_PYSPARK_SCALE = pst.DecimalType().scale

# Installed PySpark version, parsed so that it can be compared against other
# parsed versions instead of as a raw string (used further down to gate the
# Spark 3.4 ``TimestampNTZType`` equivalents).
CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__)


@immutable(init=True)
class DataType(dtypes.DataType):
Expand Down Expand Up @@ -341,10 +345,15 @@ class Date(DataType, dtypes.Date): # type: ignore
# timestamp
###############################################################################

# Default timestamp equivalents
equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType] # type: ignore

@Engine.register_dtype(
equivalents=["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType], # type: ignore
)
# Include new Spark 3.4 TimestampNTZType as equivalents
if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
equivalents += ["TimestampNTZType", "TimestampNTZType()", pst.TimestampNTZType, pst.TimestampNTZType()] # type: ignore


@Engine.register_dtype(equivalents=equivalents) # type: ignore
@immutable
class Timestamp(DataType, dtypes.Timestamp): # type: ignore
"""Semantic representation of a :class:`pyspark.sql.types.TimestampType`."""
Expand Down
16 changes: 15 additions & 1 deletion tests/pyspark/test_pyspark_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for pyspark container."""

from typing import Any
import pyspark
import pyspark.sql.types as T
from pyspark.sql import DataFrame

Expand Down Expand Up @@ -239,6 +240,18 @@ def test_pyspark_all_bytetint_types(
class TestAllDatetimeTestClass(BaseClass):
"""This class is to test all the datetime types"""

# Extra parameter sets for the Spark 3.4+ ``TimestampNTZType`` equivalents.
#
# The version gate compares the numeric (major, minor) pair rather than the raw
# version string: a lexicographic string comparison would treat "3.10.0" as
# older than "3.4" and silently drop these cases on such releases.
# The conditional expression is lazy, so ``T.TimestampNTZType`` is only looked
# up when the installed PySpark actually provides it.
ntz_equivalents = (
    [
        {"pandera_equivalent": "TimestampNTZType"},
        {"pandera_equivalent": "TimestampNTZType()"},
        {"pandera_equivalent": T.TimestampNTZType},
        {"pandera_equivalent": T.TimestampNTZType()},
    ]
    if tuple(map(int, pyspark.__version__.split(".")[:2])) >= (3, 4)
    else []
)

# a map specifying multiple argument sets for a test method
params = {
"test_pyspark_all_date_types": [
Expand All @@ -253,7 +266,8 @@ class TestAllDatetimeTestClass(BaseClass):
{"pandera_equivalent": T.TimestampType()},
{"pandera_equivalent": "datetime"},
{"pandera_equivalent": "timestamp"},
],
]
+ ntz_equivalents,
}

def test_pyspark_all_date_types(
Expand Down

0 comments on commit 0c48778

Please sign in to comment.