Skip to content

Commit

Permalink
Merge pull request #51 from MeteoSwiss/fix/mixed-splits
Browse files Browse the repository at this point in the history
Fix mixed splits
  • Loading branch information
dnerini authored Jul 12, 2024
2 parents b31f878 + c09afd3 commit 4580984
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
34 changes: 23 additions & 11 deletions mlpp_lib/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ def _time_partitioning(self) -> None:
self._time_indexers = self.time_split
else:
self._time_indexers = {}
if all(
[isinstance(v, float) for v in self.time_split.values()]
): # only fractions
if all([isinstance(v, float) for v in self.time_split.values()]):
res = self._time_partition_method(self.time_split)
self._time_indexers.update(res)
else: # mixed
else: # mixed fractions and labels
_time_split = self.time_split.copy()
self._time_indexers.update({"test": _time_split.pop("test")})
test_indexers = _time_split.pop("test")
test_indexers = [t for t in test_indexers if t in self.time_index]
self._time_indexers.update({"test": test_indexers})
res = self._time_partition_method(_time_split)
self._time_indexers.update(res)

Expand All @@ -222,8 +222,11 @@ def _time_partitioning(self) -> None:
self.partitions[partition].update(indexer)

def _time_partition_method(self, fractions: Mapping[str, float]):
time_index = [
t for t in self.time_index if t not in self._time_indexers.get("test", [])
]
if self.time_split_method == "sequential":
return sequential_split(self.time_index, fractions)
return sequential_split(time_index, fractions)

def _station_partitioning(self):
"""
Expand All @@ -235,9 +238,11 @@ def _station_partitioning(self):
if all([isinstance(v, float) for v in self.station_split.values()]):
res = self._station_partition_method(self.station_split)
self._station_indexers.update(res)
else:
else: # mixed fractions and labels
_station_split = self.station_split.copy()
self._station_indexers.update({"test": _station_split.pop("test")})
test_indexers = _station_split.pop("test")
test_indexers = [s for s in test_indexers if s in self.station_index]
self._station_indexers.update({"test": test_indexers})
res = self._station_partition_method(_station_split)
self._station_indexers.update(res)
else:
Expand All @@ -253,10 +258,15 @@ def _station_partition_method(
self, fractions: Mapping[str, float]
) -> Mapping[str, np.ndarray]:

station_index = [
sta
for sta in self.station_index
if sta not in self._station_indexers.get("test", [])
]
if self.station_split_method == "random":
out = random_split(self.station_index, fractions, seed=self.seed)
out = random_split(station_index, fractions, seed=self.seed)
elif self.station_split_method == "sequential":
out = sequential_split(self.station_index, fractions)
out = sequential_split(station_index, fractions)
return out

def _check_time(self, time_split: dict, time_split_method: str):
Expand Down Expand Up @@ -312,7 +322,9 @@ def from_dict(cls, splits: dict) -> Self:

def to_dict(self, sort_values=False):
if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.")
raise ValueError(
"DataSplitter wasn't applied on any data yet, run `fit` first."
)
if not hasattr(self, "partitions"):
self._time_partitioning()
self._station_partitioning()
Expand Down
52 changes: 50 additions & 2 deletions tests/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,39 @@
import mlpp_lib.model_selection as ms


def check_splits(splits: dict):
    """
    Check if the data splits are valid.

    Args:
        splits (dict): A dictionary containing train, val, and test splits.
            Each split should contain 'station' and 'forecast_reference_time' keys.

    Raises:
        AssertionError: If there are overlapping stations or forecast reference
            times between splits.
    """
    partitions = ("train", "val", "test")

    # Collect the station and reference-time memberships of each partition once.
    stations = {name: set(splits[name]["station"]) for name in partitions}
    reftimes = {
        name: set(splits[name]["forecast_reference_time"]) for name in partitions
    }

    # Every pair of partitions must be disjoint: first in stations,
    # then in forecast reference times (same check order as before).
    pairs = (("train", "test"), ("train", "val"), ("val", "test"))
    for first, second in pairs:
        assert not (
            stations[first] & stations[second]
        ), f"{first.capitalize()} and {second} stations overlap."
    for first, second in pairs:
        assert not (
            reftimes[first] & reftimes[second]
        ), f"{first.capitalize()} and {second} forecast reference times overlap."


@dataclass
class ValidDataSplitterOptions:

Expand Down Expand Up @@ -49,8 +82,8 @@ def time_split_lists(self):
def time_split_slices(self):
out = {
"train": [self.reftimes[0], self.reftimes[30]],
"val": [self.reftimes[30], self.reftimes[40]],
"test": [self.reftimes[40], self.reftimes[50]],
"val": [self.reftimes[31], self.reftimes[40]],
"test": [self.reftimes[41], self.reftimes[50]],
}
return out

Expand All @@ -67,10 +100,25 @@ class TestDataSplitter:
scenarios = [
ValidDataSplitterOptions(time="fractions", station="lists"),
ValidDataSplitterOptions(time="slices", station="fractions"),
ValidDataSplitterOptions(time="lists", station="fractions"),
ValidDataSplitterOptions(time="lists", station="mixed"),
ValidDataSplitterOptions(time="mixed", station="fractions"),
]

    @pytest.mark.parametrize(
        "options", scenarios, ids=ValidDataSplitterOptions.pytest_id
    )
    def test_valid_split(self, options, features_dataset, targets_dataset):
        """For each valid scenario, fit a DataSplitter and verify that the
        resulting train/val/test partitions are pairwise disjoint."""
        # NOTE(review): `features_dataset`/`targets_dataset` look like pytest
        # fixtures; `targets_dataset` is unused here — presumably kept for
        # signature symmetry with sibling tests. Verify against the fixtures.
        splitter = ms.DataSplitter(
            options.time_split,
            options.station_split,
            options.time_split_method,
            options.station_split_method,
        )
        # Fit on the features, then export partitions as plain dicts.
        splits = splitter.fit(features_dataset).to_dict()

        # check_splits asserts no station / forecast_reference_time overlap
        # between train, val, and test.
        check_splits(splits)

@pytest.mark.parametrize(
"options", scenarios, ids=ValidDataSplitterOptions.pytest_id
)
Expand Down

0 comments on commit 4580984

Please sign in to comment.