[PySpark] Improve validation performance by enabling cache()/unpersist() toggles #1414

Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@
dask-worker-space
spark-warehouse
docs/source/_contents
**.DS_Store
Contributor Author: Ignoring macOS-specific files


# Byte-compiled / optimized / DLL files
__pycache__/
Binary file removed docs/.DS_Store
Binary file not shown.
40 changes: 40 additions & 0 deletions docs/source/pyspark_sql.rst
@@ -246,6 +246,46 @@ By default, validations are enabled and depth is set to ``SCHEMA_AND_DATA`` which
can be changed to ``SCHEMA_ONLY`` or ``DATA_ONLY`` as required by the use case.


Caching control
---------------

*new in 0.17.3*

Because of Spark's lazy-evaluation architecture, and because Pandera's PySpark
integration internally relies on filtering conditions and *count* commands, the
PySpark DataFrame being validated by a Pandera schema may be reprocessed multiple
times: each *count* triggers a new underlying *Spark action*. This processing
overhead grows with the number of *schema* and *data* checks added to the Pandera
schema.
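
As a rough illustration (the schema and data below are hypothetical, not part of
this change), every data-level check translates into a filtering condition plus a
*count* over the same DataFrame, so each check may trigger its own pass over the
source data:

.. code-block:: python

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("Banana", 5), ("Bread", 3)],
        schema="product string, price int",
    )

    schema = pa.DataFrameSchema(
        {
            "product": pa.Column(T.StringType(), pa.Check.str_startswith("B")),
            "price": pa.Column(T.IntegerType(), pa.Check.gt(0)),
        }
    )

    # Without caching, each of the two data checks above (plus the schema-level
    # checks) may re-read `df` from its source when its *count* action runs.
    validated = schema.validate(df)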

To avoid this reprocessing, Pandera lets you cache the PySpark DataFrame before
validation starts, through two environment variables:

.. code-block:: bash

    export PANDERA_PYSPARK_CACHE=True  # Defaults to False, do not `cache()` by default
    export PANDERA_PYSPARK_UNPERSIST=False  # Defaults to True, `unpersist()` by default

The first variable controls whether the current state of the DataFrame should be
cached in your Spark session before validation starts. The second controls whether
that cached state should be kept after validation ends.
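
These switches can also be flipped at runtime through Pandera's configuration
object. A minimal sketch (the ``CONFIG`` attributes shown are the ones backing
the environment variables above):

.. code-block:: python

    from pandera.config import CONFIG

    CONFIG.pyspark_cache = True       # cache the DataFrame before validation
    CONFIG.pyspark_unpersist = False  # keep the cached data after validation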

.. note::

    Whether to cache is a trade-off: if you have enough memory to keep the
    dataframe cached, validation will be faster, because every check reuses the
    cached state instead of recomputing the DataFrame.

    Keeping the cached state after validation ends matters when the Pandera
    validation of a dataset is not an isolated process but one step of a pipeline.
    If a pipeline running in a single Spark session uses Pandera to validate all
    input dataframes before transforming them into a result that is written to
    disk, it may make sense not to discard the cached inputs: their already
    computed state will be reused after validation ends, so keeping it in memory
    can be beneficial (see the sketch below).
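
A minimal sketch of that pipeline scenario, reusing the hypothetical ``schema``
and ``df`` from the example above:

.. code-block:: python

    from pandera.config import CONFIG

    CONFIG.pyspark_cache = True       # cache the input while it is validated
    CONFIG.pyspark_unpersist = False  # keep it cached for the later steps

    validated = schema.validate(df)

    # Later pipeline steps reuse the cached input instead of recomputing it.
    result = validated.groupBy("product").sum("price")
    result.write.mode("overwrite").parquet("/tmp/product_totals")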


Registering Custom Checks
-------------------------

7 changes: 6 additions & 1 deletion pandera/backends/pyspark/container.py
@@ -11,7 +11,11 @@
from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
from pandera.api.pyspark.types import is_table
from pandera.backends.pyspark.base import ColumnInfo, PysparkSchemaBackend
from pandera.backends.pyspark.decorators import ValidationScope, validate_scope
from pandera.backends.pyspark.decorators import (
ValidationScope,
validate_scope,
cache_check_obj,
)
from pandera.backends.pyspark.error_formatters import scalar_failure_case
from pandera.config import CONFIG
from pandera.errors import (
@@ -102,6 +106,7 @@

return check_obj

@cache_check_obj()
def validate(
self,
check_obj: DataFrame,
77 changes: 74 additions & 3 deletions pandera/backends/pyspark/decorators.py
@@ -1,16 +1,19 @@
"""This module holds the decorators only valid for pyspark"""

import functools
import logging
import warnings
from contextlib import contextmanager
from enum import Enum
from typing import List, Type

import pyspark.sql

from pyspark.sql import DataFrame
from pandera.api.pyspark.types import PysparkDefaultTypes
from pandera.config import CONFIG, ValidationDepth
from pandera.errors import SchemaError

logger = logging.getLogger(__name__)


class ValidationScope(Enum):
"""Indicates whether a check/validator operates at a schema of data level."""
@@ -90,7 +93,7 @@
"""
if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
if isinstance(value, DataFrame):
return value

if scope == ValidationScope.SCHEMA:
@@ -126,3 +129,71 @@
return wrapper

return _wrapper


def cache_check_obj():
"""This decorator evaluates if `check_obj` should be cached before validation.

As each new data check added to the Pandera schema by the user triggers a new
Spark action, Spark reprocesses the `check_obj` DataFrame multiple times.
To prevent this waste of processing resources and to reduce validation times in
complex scenarios, the decorator created by this factory caches the `check_obj`
DataFrame before validation and unpersists it afterwards.

This decorator is meant to be used primarily in the `validate()` function
entrypoint.

The behavior of the resulting decorator depends on the `PANDERA_PYSPARK_CACHE` and
`PANDERA_PYSPARK_UNPERSIST` (optional) environment variables.

Usage:
@cache_check_obj()
def validate(check_obj: DataFrame):
# ...
"""

def _wrapper(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# Skip if not enabled
if CONFIG.pyspark_cache is not True:
return func(self, *args, **kwargs)

check_obj: DataFrame = None

# Check if the decorated function received a DataFrame object as a positional arg
for arg in args:
if isinstance(arg, DataFrame):
check_obj = arg
break

# If it doesn't exist, fall back to kwargs and search for a `check_obj` key
if check_obj is None:
check_obj = kwargs.get("check_obj", None)

if not isinstance(check_obj, DataFrame):
raise ValueError(
"Expected to find a DataFrame object in a arg or a `check_obj` "
"kwarg in the decorated function "
f"`{func.__name__}`. Got {args=}/{kwargs=}"
)

@contextmanager
def cached_check_obj():
"""Cache the dataframe and unpersist it after function execution."""
logger.debug("Caching dataframe...")
check_obj.cache()

yield  # Execute the decorated function

if CONFIG.pyspark_unpersist:
# If not cached, `.unpersist()` does nothing
logger.debug("Unpersisting dataframe...")
check_obj.unpersist()

with cached_check_obj():
return func(self, *args, **kwargs)

return wrapper

return _wrapper
12 changes: 12 additions & 0 deletions pandera/config.py
@@ -20,10 +20,14 @@
This should pick up environment variables automatically, e.g.:
export PANDERA_VALIDATION_ENABLED=False
export PANDERA_VALIDATION_DEPTH=DATA_ONLY
export PANDERA_PYSPARK_CACHE=True
export PANDERA_PYSPARK_UNPERSIST=False
"""

validation_enabled: bool = True
validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA
pyspark_cache: bool = False
pyspark_unpersist: bool = True


# this config variable should be accessible globally
@@ -35,4 +39,12 @@
validation_depth=os.environ.get(
"PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
),
pyspark_cache=os.environ.get(
"PANDERA_PYSPARK_CACHE",
False,
),
pyspark_unpersist=os.environ.get(
"PANDERA_PYSPARK_UNPERSIST",
True,
),
)
43 changes: 37 additions & 6 deletions tests/pyspark/test_pyspark_config.py
@@ -2,6 +2,7 @@
# pylint:disable=import-outside-toplevel,abstract-method

import pyspark.sql.types as T
import pytest

from pandera.config import CONFIG, ValidationDepth
from pandera.pyspark import (
@@ -24,7 +25,7 @@ def test_disable_validation(self, spark, sample_spark_schema):

CONFIG.validation_enabled = False

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
Expand All @@ -41,10 +42,12 @@ class TestSchema(DataFrameModel):
expected = {
"validation_enabled": False,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": False,
"pyspark_unpersist": True,
}

assert CONFIG.dict() == expected
assert pandra_schema.validate(input_df)
assert pandera_schema.validate(input_df)
assert TestSchema.validate(input_df)

# pylint:disable=too-many-locals
Expand All @@ -63,6 +66,8 @@ def test_schema_only(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_ONLY,
"pyspark_cache": False,
"pyspark_unpersist": True,
}
assert CONFIG.dict() == expected

@@ -132,7 +137,7 @@ def test_data_only(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.DATA_ONLY

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
Expand All @@ -141,11 +146,13 @@ def test_data_only(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.DATA_ONLY,
"pyspark_cache": False,
"pyspark_unpersist": True,
}
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"DATA": {
"DATAFRAME_CHECK": [
@@ -217,7 +224,7 @@ def test_schema_and_data(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.SCHEMA_AND_DATA

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
Expand All @@ -226,11 +233,13 @@ def test_schema_and_data(self, spark, sample_spark_schema):
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": False,
"pyspark_unpersist": True,
}
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"DATA": {
"DATAFRAME_CHECK": [
@@ -326,3 +335,25 @@ class TestSchema(DataFrameModel):
dict(output_dataframemodel_df.pandera.errors["SCHEMA"])
== expected_dataframemodel["SCHEMA"]
)

@pytest.mark.parametrize("cache_enabled", [True, False])
@pytest.mark.parametrize("unpersist_enabled", [True, False])
# pylint:disable=too-many-locals
def test_pyspark_cache_settings(
self,
cache_enabled,
unpersist_enabled,
):
"""This function validates setter and getters of caching/unpersisting options."""
# Set expected properties in Config object
CONFIG.pyspark_cache = cache_enabled
CONFIG.pyspark_unpersist = unpersist_enabled

# Evaluate expected Config
expected = {
"validation_enabled": True,
"validation_depth": ValidationDepth.SCHEMA_AND_DATA,
"pyspark_cache": cache_enabled,
"pyspark_unpersist": unpersist_enabled,
}
assert CONFIG.dict() == expected