Skip to content

Commit

Permalink
add TimestampNTZType as equivalents and add parameters to test case
Browse files Browse the repository at this point in the history
Signed-off-by: Filipe Oliveira <[email protected]>
  • Loading branch information
filipeo2-mck committed Oct 18, 2023
1 parent ceeae10 commit 32f8c41
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
12 changes: 9 additions & 3 deletions pandera/engines/pyspark_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Iterable, Union, Optional
import sys

import pyspark
import pyspark.sql.types as pst

from pandera import dtypes, errors
Expand Down Expand Up @@ -341,10 +342,15 @@ class Date(DataType, dtypes.Date): # type: ignore
# timestamp
###############################################################################

# Default timestamp equivalents
equivalents = ["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType] # type: ignore

@Engine.register_dtype(
equivalents=["datetime", "timestamp", "TimestampType", "TimestampType()", pst.TimestampType(), pst.TimestampType], # type: ignore
)
# Include the Spark 3.4+ TimestampNTZType as equivalents.
# Detect the type by feature instead of comparing ``pyspark.__version__`` as a
# string: string comparison is lexicographic, so e.g. "3.10" >= "3.4" is False.
if hasattr(pst, "TimestampNTZType"):
    equivalents += [
        "TimestampNTZType",
        "TimestampNTZType()",
        pst.TimestampNTZType,
        pst.TimestampNTZType(),
    ]  # type: ignore


@Engine.register_dtype(equivalents=equivalents) # type: ignore
@immutable
class Timestamp(DataType, dtypes.Timestamp): # type: ignore
"""Semantic representation of a :class:`pyspark.sql.types.TimestampType`."""
Expand Down
16 changes: 15 additions & 1 deletion tests/pyspark/test_pyspark_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unit tests for pyspark container."""

from typing import Any
import pyspark
import pyspark.sql.types as T
from pyspark.sql import DataFrame

Expand Down Expand Up @@ -239,6 +240,18 @@ def test_pyspark_all_bytetint_types(
class TestAllDatetimeTestClass(BaseClass):
"""This class is to test all the datetime types"""

# Include the Spark 3.4+ TimestampNTZType as equivalents.
# Probe for the type itself rather than comparing ``pyspark.__version__`` as a
# string: string comparison is lexicographic, so e.g. "3.10" >= "3.4" is False.
# The conditional expression is lazy, so ``T.TimestampNTZType`` is only
# touched when the attribute exists.
ntz_equivalents = (
    [
        {"pandera_equivalent": "TimestampNTZType"},
        {"pandera_equivalent": "TimestampNTZType()"},
        {"pandera_equivalent": T.TimestampNTZType},
        {"pandera_equivalent": T.TimestampNTZType()},
    ]
    if hasattr(T, "TimestampNTZType")
    else []
)

# a map specifying multiple argument sets for a test method
params = {
"test_pyspark_all_date_types": [
Expand All @@ -253,7 +266,8 @@ class TestAllDatetimeTestClass(BaseClass):
{"pandera_equivalent": T.TimestampType()},
{"pandera_equivalent": "datetime"},
{"pandera_equivalent": "timestamp"},
],
]
+ ntz_equivalents,
}

def test_pyspark_all_date_types(
Expand Down

0 comments on commit 32f8c41

Please sign in to comment.