Commit
[BUGFIX] [PYSPARK] Avoid running nullable checks if nullable=True (#1403)

* avoid running nullable checks if nullable=true

Signed-off-by: Filipe Oliveira <[email protected]>

* add corresponding test cases for nullable fields

Signed-off-by: Filipe Oliveira <[email protected]>

* check schema-level information from both pyspark df and pandera schema before applying nullable check

Signed-off-by: Filipe Oliveira <[email protected]>

---------

Signed-off-by: Filipe Oliveira <[email protected]>
filipeo2-mck authored Nov 10, 2023
1 parent 2a5257b commit 58a3309
Showing 2 changed files with 59 additions and 5 deletions.
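
As a rough illustration of the decision this change encodes, here is a tiny pure-Python sketch (variable names are made up; it mirrors the condition added to column.py below): the costly null scan only runs when the Pandera field is declared non-nullable and the Spark column itself is nullable.

    # Hypothetical names; mirrors `(not schema.nullable) and (check_obj.schema[name].nullable)`
    for pandera_nullable in (True, False):
        for spark_nullable in (True, False):
            runs_null_scan = (not pandera_nullable) and spark_nullable
            print(
                f"pandera nullable={pandera_nullable}, "
                f"spark nullable={spark_nullable} -> scan data for nulls: {runs_null_scan}"
            )
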
16 changes: 12 additions & 4 deletions pandera/backends/pyspark/column.py
@@ -125,10 +125,18 @@ def coerce_dtype(

     @validate_scope(scope=ValidationScope.SCHEMA)
     def check_nullable(self, check_obj: DataFrame, schema):
-        isna = (
-            check_obj.filter(col(schema.name).isNull()).limit(1).count() == 0
-        )
-        passed = schema.nullable or isna
+        passed = True
+
+        # Use schema level information to optimize execution of the `nullable` check:
+        # ignore this check if Pandera Field's `nullable` property is True
+        # (check not necessary) or if df column's `nullable` property is False
+        # (PySpark's nullable ensures the presence of values when creating the df)
+        if (not schema.nullable) and (check_obj.schema[schema.name].nullable):
+            passed = (
+                check_obj.filter(col(schema.name).isNull()).limit(1).count()
+                == 0
+            )
+
         return CoreCheckResult(
             check="not_nullable",
             reason_code=SchemaErrorReason.SERIES_CONTAINS_NULLS,
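
A standalone look at the schema-level metadata the new branch consults, as a hedged sketch against stock PySpark (the SparkSession, sample rows, and column names here are illustrative, not taken from pandera). Reading `StructField.nullable` is a free metadata lookup, while the `count()` fallback triggers a Spark job, which is why the check is short-circuited whenever the metadata already answers the question.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col
    import pyspark.sql.types as T

    spark = SparkSession.builder.getOrCreate()

    spark_schema = T.StructType(
        [
            T.StructField("product", T.StringType(), False),  # non-nullable in Spark
            T.StructField("price", T.IntegerType(), True),    # nullable in Spark
        ]
    )
    df = spark.createDataFrame([("Bread", 9), ("Butter", None)], schema=spark_schema)

    # Schema-level metadata is available without touching the data:
    print(df.schema["product"].nullable)  # False -> Spark already enforces non-null values
    print(df.schema["price"].nullable)    # True  -> only a data scan can tell

    # The scan the change avoids when the metadata makes it unnecessary:
    has_nulls = df.filter(col("price").isNull()).limit(1).count() > 0
    print(has_nulls)  # True for this sample data
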
48 changes: 47 additions & 1 deletion tests/pyspark/test_pyspark_container.py
@@ -1,5 +1,6 @@
"""Unit tests for pyspark container."""

from contextlib import nullcontext as does_not_raise
from pyspark.sql import DataFrame, SparkSession
import pyspark.sql.types as T
import pytest
@@ -142,7 +143,7 @@ def test_pyspark_sample():
("Butter", 15),
("Ice Cream", 10),
("Cola", 12),
("Choclate", 7),
("Chocolate", 7),
]

spark_schema = T.StructType(
@@ -185,3 +186,48 @@ def test_pyspark_regex_column():
     df_out = schema.validate(df2)
 
     assert not df_out.pandera.errors
+
+
+def test_pyspark_nullable():
+    """
+    Test the nullable functionality of pyspark
+    """
+
+    data = [
+        ("Bread", 9),
+        ("Butter", 15),
+        ("Ice Cream", None),
+        ("Cola", 12),
+        ("Chocolate", None),
+    ]
+    spark_schema = T.StructType(
+        [
+            T.StructField("product", T.StringType(), False),
+            T.StructField("price", T.IntegerType(), True),
+        ],
+    )
+    df = spark.createDataFrame(data=data, schema=spark_schema)
+
+    # Check for `nullable=False`
+    schema_nullable_false = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=False),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_false.validate(df)
+        assert isinstance(df_out, DataFrame)
+        assert "SERIES_CONTAINS_NULLS" in str(dict(df_out.pandera.errors))
+
+    # Check for `nullable=True`
+    schema_nullable_true = DataFrameSchema(
+        columns={
+            "product": Column("str"),
+            "price": Column("int", nullable=True),
+        },
+    )
+    with does_not_raise():
+        df_out = schema_nullable_true.validate(df)
+        assert isinstance(df_out, DataFrame)
+        assert df_out.pandera.errors == {}
