From 15b6d09bb6153f5205f45e09b513e92c9167c56e Mon Sep 17 00:00:00 2001 From: ned Date: Sat, 31 Aug 2024 17:53:17 +0200 Subject: [PATCH] Minor fixes --- mlpp_lib/normalizers.py | 4 ++-- tests/test_train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlpp_lib/normalizers.py b/mlpp_lib/normalizers.py index 4b87cb4..d501b3c 100644 --- a/mlpp_lib/normalizers.py +++ b/mlpp_lib/normalizers.py @@ -30,8 +30,8 @@ class DataTransformer: def __init__( self, method_vars_dict: Optional[dict[str, list[str]]] = None, - default: Optional[str] = "Standardizer", - fillvalue: float = -5, + default: str = "Standardizer", + fillvalue: Optional[float] = -5, ): self.method_vars_dict = method_vars_dict diff --git a/tests/test_train.py b/tests/test_train.py index c32b7c9..d6b99f7 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -156,7 +156,7 @@ def test_train_fromfile(tmp_path, cfg): splitter_options = ValidDataSplitterOptions(time="lists", station="lists") datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split) - datanormalizer = DataTransformer(cfg["normalizer"]) + datanormalizer = DataTransformer(**["normalizer"]) batch_dims = ["forecast_reference_time", "t", "station"] datamodule = DataModule( features=cfg["features"],