diff --git a/mlpp_lib/normalizers.py b/mlpp_lib/normalizers.py index 4b87cb4..d501b3c 100644 --- a/mlpp_lib/normalizers.py +++ b/mlpp_lib/normalizers.py @@ -30,8 +30,8 @@ class DataTransformer: def __init__( self, method_vars_dict: Optional[dict[str, list[str]]] = None, - default: Optional[str] = "Standardizer", - fillvalue: float = -5, + default: str = "Standardizer", + fillvalue: Optional[float] = -5, ): self.method_vars_dict = method_vars_dict diff --git a/tests/test_train.py b/tests/test_train.py index c32b7c9..d6b99f7 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -156,7 +156,7 @@ def test_train_fromfile(tmp_path, cfg): splitter_options = ValidDataSplitterOptions(time="lists", station="lists") datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split) - datanormalizer = DataTransformer(cfg["normalizer"]) + datanormalizer = DataTransformer(**["normalizer"]) batch_dims = ["forecast_reference_time", "t", "station"] datamodule = DataModule( features=cfg["features"],