diff --git a/tests/pyspark/test_pyspark_container.py b/tests/pyspark/test_pyspark_container.py
index 6d9f65d89..87243f8d4 100644
--- a/tests/pyspark/test_pyspark_container.py
+++ b/tests/pyspark/test_pyspark_container.py
@@ -1,5 +1,6 @@
 """Unit tests for pyspark container."""
 
+from contextlib import nullcontext as does_not_raise
 from pyspark.sql import DataFrame, SparkSession
 import pyspark.sql.types as T
 import pytest
@@ -142,7 +143,7 @@ def test_pyspark_sample():
         ("Bread", 9),
         ("Butter", 15),
         ("Ice Cream", 10),
         ("Cola", 12),
-        ("Choclate", 7),
+        ("Chocolate", 7),
     ]
     spark_schema = T.StructType(
@@ -185,3 +186,48 @@ def test_pyspark_regex_column():
     df_out = schema.validate(df2)
 
     assert not df_out.pandera.errors
+
+
+def test_pyspark_nullable():
+    """
+    Test the nullable functionality of pyspark
+    """
+
+    data = [
+        ("Bread", 9),
+        ("Butter", 15),
+        ("Ice Cream", None),
+        ("Cola", 12),
+        ("Chocolate", None),
+    ]
+    spark_schema = T.StructType(
+        [
+            T.StructField("product", T.StringType(), False),
+            T.StructField("price", T.IntegerType(), True),
+        ],
+    )
+    df = spark.createDataFrame(data=data, schema=spark_schema)
+
+    # Check for `nullable=False`
+    schema_nullable_false = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=False),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_false.validate(df)
+    assert isinstance(df_out, DataFrame)
+    assert "SERIES_CONTAINS_NULLS" in str(dict(df_out.pandera.errors))
+
+    # Check for `nullable=True`
+    schema_nullable_true = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=True),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_true.validate(df)
+    assert isinstance(df_out, DataFrame)
+    assert df_out.pandera.errors == {}
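
For context, a minimal standalone sketch (not part of the patch) of the behavior the new test exercises. Pandera's pyspark backend collects validation failures on the DataFrame's pandera.errors accessor instead of raising, which is why the test wraps validate() in does_not_raise() and then inspects the errors dict. Names below mirror the test; the SparkSession setup is an assumption added to make the sketch self-contained.

from pyspark.sql import SparkSession
import pyspark.sql.types as T

from pandera.pyspark import Column, DataFrameSchema

spark = SparkSession.builder.getOrCreate()

# The Spark schema allows nulls in "price"; the pandera schema forbids them.
spark_schema = T.StructType(
    [
        T.StructField("product", T.StringType(), False),
        T.StructField("price", T.IntegerType(), True),
    ]
)
df = spark.createDataFrame(
    [("Bread", 9), ("Ice Cream", None)], schema=spark_schema
)

schema = DataFrameSchema({"price": Column("int", nullable=False)})

# validate() still returns a DataFrame; the nullability violation is
# reported through the accessor rather than by raising a SchemaError.
df_out = schema.validate(df)
print(dict(df_out.pandera.errors))  # expected to mention SERIES_CONTAINS_NULLS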