diff --git a/mlpp_lib/model_selection.py b/mlpp_lib/model_selection.py index b9a279a..181577d 100644 --- a/mlpp_lib/model_selection.py +++ b/mlpp_lib/model_selection.py @@ -197,12 +197,10 @@ def _time_partitioning(self) -> None: self._time_indexers = self.time_split else: self._time_indexers = {} - if all( - [isinstance(v, float) for v in self.time_split.values()] - ): # only fractions + if all([isinstance(v, float) for v in self.time_split.values()]): res = self._time_partition_method(self.time_split) self._time_indexers.update(res) - else: # mixed + else: # mixed fractions and labels _time_split = self.time_split.copy() self._time_indexers.update({"test": _time_split.pop("test")}) res = self._time_partition_method(_time_split) @@ -222,8 +220,11 @@ def _time_partitioning(self) -> None: self.partitions[partition].update(indexer) def _time_partition_method(self, fractions: Mapping[str, float]): + time_index = [ + t for t in self.time_index if t not in self._time_indexers.get("test", []) + ] if self.time_split_method == "sequential": - return sequential_split(self.time_index, fractions) + return sequential_split(time_index, fractions) def _station_partitioning(self): """ @@ -235,7 +236,7 @@ def _station_partitioning(self): if all([isinstance(v, float) for v in self.station_split.values()]): res = self._station_partition_method(self.station_split) self._station_indexers.update(res) - else: + else: # mixed fractions and labels _station_split = self.station_split.copy() self._station_indexers.update({"test": _station_split.pop("test")}) res = self._station_partition_method(_station_split) @@ -253,10 +254,15 @@ def _station_partition_method( self, fractions: Mapping[str, float] ) -> Mapping[str, np.ndarray]: + station_index = [ + sta + for sta in self.station_index + if sta not in self._station_indexers.get("test", []) + ] if self.station_split_method == "random": - out = random_split(self.station_index, fractions, seed=self.seed) + out = random_split(station_index, fractions, seed=self.seed) elif self.station_split_method == "sequential": - out = sequential_split(self.station_index, fractions) + out = sequential_split(station_index, fractions) return out def _check_time(self, time_split: dict, time_split_method: str): @@ -312,7 +318,9 @@ def from_dict(cls, splits: dict) -> Self: def to_dict(self, sort_values=False): if not hasattr(self, "time_index") or not hasattr(self, "station_index"): - raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.") + raise ValueError( + "DataSplitter wasn't applied on any data yet, run `fit` first." + ) if not hasattr(self, "partitions"): self._time_partitioning() self._station_partitioning() diff --git a/tests/test_model_selection.py b/tests/test_model_selection.py index 298aa7b..18c23f2 100644 --- a/tests/test_model_selection.py +++ b/tests/test_model_selection.py @@ -8,6 +8,39 @@ import mlpp_lib.model_selection as ms +def check_splits(splits: dict): + """ + Check if the data splits are valid. + + Args: + splits (dict): A dictionary containing train, val, and test splits. + Each split should contain 'station' and 'forecast_reference_time' keys. + + Raises: + AssertionError: If there are overlapping stations or forecast reference times between splits. + """ + train_stations = set(splits["train"]["station"]) + val_stations = set(splits["val"]["station"]) + test_stations = set(splits["test"]["station"]) + + train_reftimes = set(splits["train"]["forecast_reference_time"]) + val_reftimes = set(splits["val"]["forecast_reference_time"]) + test_reftimes = set(splits["test"]["forecast_reference_time"]) + + assert len(train_stations & test_stations) == 0, "Train and test stations overlap." + assert len(train_stations & val_stations) == 0, "Train and val stations overlap." + assert len(val_stations & test_stations) == 0, "Val and test stations overlap." + assert ( + len(train_reftimes & test_reftimes) == 0 + ), "Train and test forecast reference times overlap." + assert ( + len(train_reftimes & val_reftimes) == 0 + ), "Train and val forecast reference times overlap." + assert ( + len(val_reftimes & test_reftimes) == 0 + ), "Val and test forecast reference times overlap." + + @dataclass class ValidDataSplitterOptions: @@ -49,8 +82,8 @@ def time_split_lists(self): def time_split_slices(self): out = { "train": [self.reftimes[0], self.reftimes[30]], - "val": [self.reftimes[30], self.reftimes[40]], - "test": [self.reftimes[40], self.reftimes[50]], + "val": [self.reftimes[31], self.reftimes[40]], + "test": [self.reftimes[41], self.reftimes[50]], } return out @@ -67,10 +100,25 @@ class TestDataSplitter: scenarios = [ ValidDataSplitterOptions(time="fractions", station="lists"), ValidDataSplitterOptions(time="slices", station="fractions"), + ValidDataSplitterOptions(time="lists", station="fractions"), ValidDataSplitterOptions(time="lists", station="mixed"), ValidDataSplitterOptions(time="mixed", station="fractions"), ] + @pytest.mark.parametrize( + "options", scenarios, ids=ValidDataSplitterOptions.pytest_id + ) + def test_valid_split(self, options, features_dataset, targets_dataset): + splitter = ms.DataSplitter( + options.time_split, + options.station_split, + options.time_split_method, + options.station_split_method, + ) + splits = splitter.fit(features_dataset).to_dict() + + check_splits(splits) + @pytest.mark.parametrize( "options", scenarios, ids=ValidDataSplitterOptions.pytest_id )