Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mixed splits #51

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions mlpp_lib/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ def _time_partitioning(self) -> None:
self._time_indexers = self.time_split
else:
self._time_indexers = {}
if all(
[isinstance(v, float) for v in self.time_split.values()]
): # only fractions
if all([isinstance(v, float) for v in self.time_split.values()]):
res = self._time_partition_method(self.time_split)
self._time_indexers.update(res)
else: # mixed
else: # mixed fractions and labels
_time_split = self.time_split.copy()
self._time_indexers.update({"test": _time_split.pop("test")})
test_indexers = _time_split.pop("test")
test_indexers = [t for t in test_indexers if t in self.time_index]
self._time_indexers.update({"test": test_indexers})
res = self._time_partition_method(_time_split)
self._time_indexers.update(res)

Expand All @@ -222,8 +222,11 @@ def _time_partitioning(self) -> None:
self.partitions[partition].update(indexer)

def _time_partition_method(self, fractions: Mapping[str, float]):
time_index = [
t for t in self.time_index if t not in self._time_indexers.get("test", [])
]
if self.time_split_method == "sequential":
return sequential_split(self.time_index, fractions)
return sequential_split(time_index, fractions)

def _station_partitioning(self):
"""
Expand All @@ -235,9 +238,11 @@ def _station_partitioning(self):
if all([isinstance(v, float) for v in self.station_split.values()]):
res = self._station_partition_method(self.station_split)
self._station_indexers.update(res)
else:
else: # mixed fractions and labels
_station_split = self.station_split.copy()
self._station_indexers.update({"test": _station_split.pop("test")})
test_indexers = _station_split.pop("test")
test_indexers = [s for s in test_indexers if s in self.station_index]
self._station_indexers.update({"test": test_indexers})
res = self._station_partition_method(_station_split)
self._station_indexers.update(res)
else:
Expand All @@ -253,10 +258,15 @@ def _station_partition_method(
self, fractions: Mapping[str, float]
) -> Mapping[str, np.ndarray]:

station_index = [
sta
for sta in self.station_index
if sta not in self._station_indexers.get("test", [])
]
if self.station_split_method == "random":
out = random_split(self.station_index, fractions, seed=self.seed)
out = random_split(station_index, fractions, seed=self.seed)
elif self.station_split_method == "sequential":
out = sequential_split(self.station_index, fractions)
out = sequential_split(station_index, fractions)
return out

def _check_time(self, time_split: dict, time_split_method: str):
Expand Down Expand Up @@ -312,7 +322,9 @@ def from_dict(cls, splits: dict) -> Self:

def to_dict(self, sort_values=False):
if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.")
raise ValueError(
"DataSplitter wasn't applied on any data yet, run `fit` first."
)
if not hasattr(self, "partitions"):
self._time_partitioning()
self._station_partitioning()
Expand Down
52 changes: 50 additions & 2 deletions tests/test_model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,39 @@
import mlpp_lib.model_selection as ms


def check_splits(splits: dict):
    """
    Validate that the data splits are mutually disjoint.

    Args:
        splits (dict): Mapping with 'train', 'val' and 'test' entries, each of
            which contains 'station' and 'forecast_reference_time' keys.

    Raises:
        AssertionError: If any pair of splits shares stations or forecast
            reference times.
    """
    names = ("train", "val", "test")
    # Collect the station and reference-time sets per split up front.
    stations = {name: set(splits[name]["station"]) for name in names}
    reftimes = {name: set(splits[name]["forecast_reference_time"]) for name in names}

    # Stations must not be shared between any two splits.
    assert not (stations["train"] & stations["test"]), "Train and test stations overlap."
    assert not (stations["train"] & stations["val"]), "Train and val stations overlap."
    assert not (stations["val"] & stations["test"]), "Val and test stations overlap."

    # Forecast reference times must not be shared between any two splits.
    assert not (
        reftimes["train"] & reftimes["test"]
    ), "Train and test forecast reference times overlap."
    assert not (
        reftimes["train"] & reftimes["val"]
    ), "Train and val forecast reference times overlap."
    assert not (
        reftimes["val"] & reftimes["test"]
    ), "Val and test forecast reference times overlap."


@dataclass
class ValidDataSplitterOptions:

Expand Down Expand Up @@ -49,8 +82,8 @@ def time_split_lists(self):
def time_split_slices(self):
out = {
"train": [self.reftimes[0], self.reftimes[30]],
"val": [self.reftimes[30], self.reftimes[40]],
"test": [self.reftimes[40], self.reftimes[50]],
"val": [self.reftimes[31], self.reftimes[40]],
"test": [self.reftimes[41], self.reftimes[50]],
}
return out

Expand All @@ -67,10 +100,25 @@ class TestDataSplitter:
scenarios = [
ValidDataSplitterOptions(time="fractions", station="lists"),
ValidDataSplitterOptions(time="slices", station="fractions"),
ValidDataSplitterOptions(time="lists", station="fractions"),
ValidDataSplitterOptions(time="lists", station="mixed"),
ValidDataSplitterOptions(time="mixed", station="fractions"),
]

@pytest.mark.parametrize(
"options", scenarios, ids=ValidDataSplitterOptions.pytest_id
)
def test_valid_split(self, options, features_dataset, targets_dataset):
splitter = ms.DataSplitter(
options.time_split,
options.station_split,
options.time_split_method,
options.station_split_method,
)
splits = splitter.fit(features_dataset).to_dict()

check_splits(splits)

@pytest.mark.parametrize(
"options", scenarios, ids=ValidDataSplitterOptions.pytest_id
)
Expand Down
Loading