diff --git a/pandera/backends/pyspark/column.py b/pandera/backends/pyspark/column.py
index 7c0ac168e..52c3081c1 100644
--- a/pandera/backends/pyspark/column.py
+++ b/pandera/backends/pyspark/column.py
@@ -125,10 +125,18 @@ def coerce_dtype(
 
     @validate_scope(scope=ValidationScope.SCHEMA)
     def check_nullable(self, check_obj: DataFrame, schema):
-        isna = (
-            check_obj.filter(col(schema.name).isNull()).limit(1).count() == 0
-        )
-        passed = schema.nullable or isna
+        passed = True
+
+        # Use schema-level information to optimize execution of the `nullable` check:
+        # skip this check if the Pandera field's `nullable` property is True
+        # (check not necessary) or if the df column's `nullable` property is False
+        # (PySpark's nullability ensures the presence of values when creating the df)
+        if (not schema.nullable) and (check_obj.schema[schema.name].nullable):
+            passed = (
+                check_obj.filter(col(schema.name).isNull()).limit(1).count()
+                == 0
+            )
+
         return CoreCheckResult(
             check="not_nullable",
             reason_code=SchemaErrorReason.SERIES_CONTAINS_NULLS,
diff --git a/tests/pyspark/test_pyspark_container.py b/tests/pyspark/test_pyspark_container.py
index 6d9f65d89..87243f8d4 100644
--- a/tests/pyspark/test_pyspark_container.py
+++ b/tests/pyspark/test_pyspark_container.py
@@ -1,5 +1,6 @@
 """Unit tests for pyspark container."""
 
+from contextlib import nullcontext as does_not_raise
 from pyspark.sql import DataFrame, SparkSession
 import pyspark.sql.types as T
 import pytest
@@ -142,7 +143,7 @@ def test_pyspark_sample():
         ("Butter", 15),
         ("Ice Cream", 10),
         ("Cola", 12),
-        ("Choclate", 7),
+        ("Chocolate", 7),
     ]
 
     spark_schema = T.StructType(
@@ -185,3 +186,48 @@ def test_pyspark_regex_column():
     df_out = schema.validate(df2)
 
     assert not df_out.pandera.errors
+
+
+def test_pyspark_nullable():
+    """
+    Test the `nullable` check on PySpark DataFrame columns.
+    """
+
+    data = [
+        ("Bread", 9),
+        ("Butter", 15),
+        ("Ice Cream", None),
+        ("Cola", 12),
+        ("Chocolate", None),
+    ]
+    spark_schema = T.StructType(
+        [
+            T.StructField("product", T.StringType(), False),
+            T.StructField("price", T.IntegerType(), True),
+        ],
+    )
+    df = spark.createDataFrame(data=data, schema=spark_schema)
+
+    # Check for `nullable=False`
+    schema_nullable_false = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=False),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_false.validate(df)
+        assert isinstance(df_out, DataFrame)
+        assert "SERIES_CONTAINS_NULLS" in str(dict(df_out.pandera.errors))
+
+    # Check for `nullable=True`
+    schema_nullable_true = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=True),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_true.validate(df)
+        assert isinstance(df_out, DataFrame)
+        assert df_out.pandera.errors == {}
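
The optimization in `check_nullable` rests on PySpark enforcing `StructField` nullability when a DataFrame is built, so the costly null scan (a real Spark action) only runs when the Spark schema itself allows nulls. Below is a minimal standalone sketch of that premise; it is illustrative only and not part of the patch, and it assumes a local SparkSession.

import pyspark.sql.types as T
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A column declared non-nullable in the Spark schema rejects nulls at
# DataFrame creation time (with the default verifySchema=True), so Pandera
# never needs to scan such a column for nulls.
strict = T.StructType([T.StructField("price", T.IntegerType(), False)])
try:
    spark.createDataFrame([(9,), (None,)], schema=strict)
except ValueError as exc:
    print("rejected at creation:", exc)

# Only a Spark-nullable column can actually hold nulls; only in that case
# does the patched check_nullable pay for the filter/limit/count action.
lenient = T.StructType([T.StructField("price", T.IntegerType(), True)])
df = spark.createDataFrame([(9,), (None,)], schema=lenient)
print(df.filter(df["price"].isNull()).limit(1).count() > 0)  # True

Short-circuiting on `schema.nullable` and on the Spark column's own nullability thus avoids triggering a Spark job in the two cases where the outcome is already known.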