Skip to content

Commit

Permalink
Fix validating pyspark dataframes with regex columns (#1397)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Anthony <[email protected]>
  • Loading branch information
lexanth authored Nov 1, 2023
1 parent be1e1ae commit 37c24d9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
4 changes: 3 additions & 1 deletion pandera/api/pyspark/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def get_dtypes(self, dataframe: DataFrame) -> Dict[str, DataType]:
regex_dtype.update(
{
c: column.dtype
for c in column.BACKEND.get_regex_columns(
for c in column.get_backend(
dataframe
).get_regex_columns(
column,
dataframe.columns,
)
Expand Down
8 changes: 4 additions & 4 deletions pandera/backends/pyspark/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def collect_column_info(
if col_schema.regex:
try:
column_names.extend(
col_schema.BACKEND.get_regex_columns(
col_schema.get_backend(check_obj).get_regex_columns(
col_schema, check_obj.columns
)
)
Expand Down Expand Up @@ -462,9 +462,9 @@ def _try_coercion(obj, colname, col_schema):
for colname, col_schema in schema.columns.items():
if col_schema.regex:
try:
matched_columns = col_schema.BACKEND.get_regex_columns(
col_schema, obj.columns
)
matched_columns = col_schema.get_backend(
obj
).get_regex_columns(col_schema, obj.columns)
except SchemaError:
matched_columns = None

Expand Down
28 changes: 28 additions & 0 deletions tests/pyspark/test_pyspark_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,31 @@ def test_pyspark_sample():
df_out = schema.validate(df, sample=0.5)

assert isinstance(df_out, DataFrame)


def test_pyspark_regex_column():
    """Validate a DataFrameSchema whose column key is a regex pattern.

    The pattern "[A-Z]+" matches every all-caps column name, so the
    StringType requirement applies to both NAME and AGE.
    """

    schema = DataFrameSchema(
        {
            # Columns with all caps names must have string values
            "[A-Z]+": Column(T.StringType(), regex=True),
        }
    )

    # AGE holds integers, so the regex-matched StringType requirement
    # must produce a validation error.
    data = [("Neeraj", 35), ("Jask", 30)]

    df = spark.createDataFrame(data=data, schema=["NAME", "AGE"])
    df_out = schema.validate(df)

    # NOTE(review): the success-path check below (`assert not ...errors`)
    # treats `errors` as a truthy/falsy collection, and an *empty*
    # collection is still `is not None` — so the original
    # `errors is not None` assertion could pass even when validation
    # recorded nothing. Assert truthiness to require at least one error.
    assert df_out.pandera.errors

    # With AGE as strings, every regex-matched column satisfies StringType.
    data = [("Neeraj", "35"), ("Jask", "a")]

    df2 = spark.createDataFrame(data=data, schema=["NAME", "AGE"])

    df_out = schema.validate(df2)

    assert not df_out.pandera.errors

0 comments on commit 37c24d9

Please sign in to comment.