From d7d5610c10fa1f7019b86daf5eb793183c6a15ef Mon Sep 17 00:00:00 2001 From: Alex Anthony Date: Thu, 26 Oct 2023 16:09:55 +0100 Subject: [PATCH] Fix validating pyspark dataframes with regex columns --- pandera/api/pyspark/container.py | 4 +++- pandera/backends/pyspark/container.py | 8 +++---- tests/pyspark/test_pyspark_container.py | 28 +++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/pandera/api/pyspark/container.py b/pandera/api/pyspark/container.py index 6a272f7f1..cef8675c9 100644 --- a/pandera/api/pyspark/container.py +++ b/pandera/api/pyspark/container.py @@ -234,7 +234,9 @@ def get_dtypes(self, dataframe: DataFrame) -> Dict[str, DataType]: regex_dtype.update( { c: column.dtype - for c in column.BACKEND.get_regex_columns( + for c in column.get_backend( + dataframe + ).get_regex_columns( column, dataframe.columns, ) diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index 6d2ef2683..9fde2bb2d 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -252,7 +252,7 @@ def collect_column_info( if col_schema.regex: try: column_names.extend( - col_schema.BACKEND.get_regex_columns( + col_schema.get_backend(check_obj).get_regex_columns( col_schema, check_obj.columns ) ) @@ -457,9 +457,9 @@ def _try_coercion(obj, colname, col_schema): for colname, col_schema in schema.columns.items(): if col_schema.regex: try: - matched_columns = col_schema.BACKEND.get_regex_columns( - col_schema, obj.columns - ) + matched_columns = col_schema.get_backend( + obj + ).get_regex_columns(col_schema, obj.columns) except SchemaError: matched_columns = None diff --git a/tests/pyspark/test_pyspark_container.py b/tests/pyspark/test_pyspark_container.py index e0306df40..6d9f65d89 100644 --- a/tests/pyspark/test_pyspark_container.py +++ b/tests/pyspark/test_pyspark_container.py @@ -157,3 +157,31 @@ def test_pyspark_sample(): df_out = schema.validate(df, sample=0.5) assert 
isinstance(df_out, DataFrame)
+
+
+def test_pyspark_regex_column():
+    """
+    Test validating pyspark DataFrames against a schema with regex columns
+    """
+
+    schema = DataFrameSchema(
+        {
+            # Columns with all caps names must have string values
+            "[A-Z]+": Column(T.StringType(), regex=True),
+        }
+    )
+
+    # AGE matches the regex but holds integers, so errors must be reported
+    data = [("Neeraj", 35), ("Jask", 30)]
+    df = spark.createDataFrame(data=data, schema=["NAME", "AGE"])
+    df_out = schema.validate(df)
+
+    assert df_out.pandera.errors
+
+    # Both matched columns hold strings, so no errors are expected
+    data = [("Neeraj", "35"), ("Jask", "a")]
+    df2 = spark.createDataFrame(data=data, schema=["NAME", "AGE"])
+
+    df_out = schema.validate(df2)
+
+    assert not df_out.pandera.errors