improve code coverage through new test file for decorators
Signed-off-by: Filipe Oliveira <[email protected]>
filipeo2-mck committed Nov 10, 2023
1 parent 74f6c33 commit dc652e8
Showing 3 changed files with 115 additions and 40 deletions.
Binary file modified docs/.DS_Store
Binary file not shown.
42 changes: 2 additions & 40 deletions tests/pyspark/test_pyspark_config.py
@@ -1,9 +1,7 @@
"""This module is to test the behaviour change based on defined config in pandera"""
# pylint:disable=import-outside-toplevel,abstract-method

import logging
import pyspark.sql.types as T
from pyspark.sql import DataFrame
import pytest

from pandera.config import CONFIG, ValidationDepth
@@ -338,27 +336,13 @@ class TestSchema(DataFrameModel):
            == expected_dataframemodel["SCHEMA"]
        )

    @pytest.mark.parametrize(
        "cache_enabled,unpersist_enabled,"
        "expected_caching_message,expected_unpersisting_message",
        [
            (True, True, "Caching dataframe...", "Unpersisting dataframe..."),
            (True, False, "Caching dataframe...", ""),
            (False, True, "", ""),
            (False, False, "", ""),
        ],
        scope="function",
    )
    @pytest.mark.parametrize("cache_enabled", [True, False])
    @pytest.mark.parametrize("unpersist_enabled", [True, False])
    # pylint:disable=too-many-locals
    def test_pyspark_cache_settings(
        self,
        spark,
        sample_spark_schema,
        cache_enabled,
        unpersist_enabled,
        expected_caching_message,
        expected_unpersisting_message,
        caplog,
    ):
        """This function validates that caching/unpersisting works as expected."""
        # Set expected properties in Config object
@@ -373,25 +357,3 @@ def test_pyspark_cache_settings(
"pyspark_unpersist": unpersist_enabled,
}
assert CONFIG.dict() == expected

        # Prepare test data
        input_df = spark_df(spark, self.sample_data, sample_spark_schema)
        pandera_schema = DataFrameSchema(
            {
                "product": Column(T.StringType(), Check.str_startswith("B")),
                "price_val": Column(T.IntegerType()),
            }
        )

        # Capture log message
        with caplog.at_level(logging.DEBUG, logger="pandera"):
            df_out = pandera_schema.validate(input_df)

        # Assertions
        assert isinstance(df_out, DataFrame)
        assert (
            expected_caching_message in caplog.text
        ), "Debugging info has no information about caching the dataframe."
        assert (
            expected_unpersisting_message in caplog.text
        ), "Debugging info has no information about unpersisting the dataframe."
113 changes: 113 additions & 0 deletions tests/pyspark/test_pyspark_decorators.py
@@ -0,0 +1,113 @@
"""This module is to test the behaviour change based on defined config in pandera"""
# pylint:disable=import-outside-toplevel,abstract-method

from contextlib import nullcontext as does_not_raise
import logging
import pyspark.sql.types as T
from pyspark.sql import DataFrame
import pytest

from pandera.backends.pyspark.decorators import cache_check_obj
from pandera.config import CONFIG
from pandera.pyspark import (
    Check,
    DataFrameSchema,
    Column,
)
from tests.pyspark.conftest import spark_df


class TestPanderaDecorators:
    """Class to test the pyspark backend decorators."""

    sample_data = [("Bread", 9), ("Cutter", 15)]

    def test_pyspark_cache_requirements(self, spark, sample_spark_schema):
        """Validates that the decorator can only be applied to a suitable function."""
        # Set expected properties in Config object
        CONFIG.pyspark_cache = True
        input_df = spark_df(spark, self.sample_data, sample_spark_schema)

        class FakeDataFrameSchemaBackend:
            """Class that simulates the DataFrameSchemaBackend class."""

            @cache_check_obj()
            def func_w_check_obj(self, check_obj: DataFrame):
                """Correct function on which to use this decorator."""
                return check_obj.columns

            @cache_check_obj()
            def func_wo_check_obj(self, message: str):
                """Wrong function on which to use this decorator."""
                return message

        # Check for a function that does have a `check_obj` argument
        with does_not_raise():
            instance = FakeDataFrameSchemaBackend()
            _ = instance.func_w_check_obj(check_obj=input_df)

        # Check for a wrong function, one that does not have a `check_obj` argument
        with pytest.raises(KeyError):
            instance = FakeDataFrameSchemaBackend()
            _ = instance.func_wo_check_obj("wrong")

    @pytest.mark.parametrize(
        "cache_enabled,unpersist_enabled,"
        "expected_caching_message,expected_unpersisting_message",
        [
            (True, True, True, True),
            (True, False, True, None),
            (False, True, None, None),
            (False, False, None, None),
        ],
        scope="function",
    )
    # pylint:disable=too-many-locals
    def test_pyspark_cache_settings(
        self,
        spark,
        sample_spark_schema,
        cache_enabled,
        unpersist_enabled,
        expected_caching_message,
        expected_unpersisting_message,
        caplog,
    ):
        """This function validates that caching/unpersisting works as expected."""
        # Set expected properties in Config object
        CONFIG.pyspark_cache = cache_enabled
        CONFIG.pyspark_unpersist = unpersist_enabled

        # Prepare test data
        input_df = spark_df(spark, self.sample_data, sample_spark_schema)
        pandera_schema = DataFrameSchema(
            {
                "product": Column(T.StringType(), Check.str_startswith("B")),
                "price_val": Column(T.IntegerType()),
            }
        )

        # Capture log message
        with caplog.at_level(logging.DEBUG, logger="pandera"):
            df_out = pandera_schema.validate(input_df)

        # Assertions
        assert isinstance(df_out, DataFrame)
        if expected_caching_message:
            assert (
                "Caching dataframe..." in caplog.text
            ), "Debugging info has no information about caching the dataframe."
        else:
            assert (
                "Caching dataframe..." not in caplog.text
            ), "Debugging info has information about caching. It shouldn't."

        if expected_unpersisting_message:
            assert (
                "Unpersisting dataframe..." in caplog.text
            ), "Debugging info has no information about unpersisting the dataframe."
        else:
            assert (
                "Unpersisting dataframe..." not in caplog.text
            ), "Debugging info has information about unpersisting. It shouldn't."

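Outside of pytest, the config-driven behaviour exercised by test_pyspark_cache_settings can be reproduced with a short script. The snippet below is a minimal sketch that mirrors the test fixtures (sample rows, schema, and the debug messages the assertions look for); it is illustrative, not an official usage recipe.

import logging

import pyspark.sql.types as T
from pyspark.sql import SparkSession

from pandera.config import CONFIG
from pandera.pyspark import Check, Column, DataFrameSchema

# Surface pandera's debug-level messages, where the caching notices are logged.
logging.basicConfig(level=logging.DEBUG)

spark = SparkSession.builder.getOrCreate()
input_df = spark.createDataFrame(
    [("Bread", 9), ("Cutter", 15)],
    schema=T.StructType(
        [
            T.StructField("product", T.StringType()),
            T.StructField("price_val", T.IntegerType()),
        ]
    ),
)

pandera_schema = DataFrameSchema(
    {
        "product": Column(T.StringType(), Check.str_startswith("B")),
        "price_val": Column(T.IntegerType()),
    }
)

CONFIG.pyspark_cache = True      # expect "Caching dataframe..." in the debug output
CONFIG.pyspark_unpersist = True  # expect "Unpersisting dataframe..." after validation
validated_df = pandera_schema.validate(input_df)

With both flags set to False, neither message should appear, matching the (False, False, None, None) case in the parametrization above.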