Skip to content

Commit

Permalink
Merge pull request #45 from MeteoSwiss/fix/station_partionning
Browse files Browse the repository at this point in the history
Fix/station partionning
  • Loading branch information
dnerini authored Jul 9, 2024
2 parents 5c2e52f + bffd616 commit cb56843
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mlpp_lib/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(
seed: Optional[int] = 10,
time_dim_name: str = "forecast_reference_time",
):

if not time_split.keys() == station_split.keys():
raise ValueError(
"Time split and station split must be defined "
Expand Down Expand Up @@ -362,15 +361,15 @@ def random_split(
seed: int = 10,
) -> dict[str, np.ndarray]:
"""Split an input index array randomly"""
np.random.seed(seed)
rng = np.random.default_rng(np.random.PCG64(seed))

assert np.isclose(sum(split_fractions.values()), 1.0)

n_samples = len(index)
partitions = list(split_fractions.keys())
fractions = np.array(list(split_fractions.values()))

shuffled_index = np.random.permutation(index)
shuffled_index = rng.permutation(index)
indices = np.floor(np.cumsum(fractions)[:-1] * n_samples).astype(int)
sub_arrays = np.split(shuffled_index, indices)
return dict(zip(partitions, sub_arrays))
Expand Down

0 comments on commit cb56843

Please sign in to comment.