diff --git a/mlpp_lib/model_selection.py b/mlpp_lib/model_selection.py index 06d1c56..b9a279a 100644 --- a/mlpp_lib/model_selection.py +++ b/mlpp_lib/model_selection.py @@ -102,7 +102,6 @@ def __init__( seed: Optional[int] = 10, time_dim_name: str = "forecast_reference_time", ): - if not time_split.keys() == station_split.keys(): raise ValueError( "Time split and station split must be defined " @@ -362,7 +361,7 @@ def random_split( seed: int = 10, ) -> dict[str, np.ndarray]: """Split an input index array randomly""" - np.random.seed(seed) + rng = np.random.default_rng(np.random.PCG64(seed)) assert np.isclose(sum(split_fractions.values()), 1.0) @@ -370,7 +369,7 @@ def random_split( partitions = list(split_fractions.keys()) fractions = np.array(list(split_fractions.values())) - shuffled_index = np.random.permutation(index) + shuffled_index = rng.permutation(index) indices = np.floor(np.cumsum(fractions)[:-1] * n_samples).astype(int) sub_arrays = np.split(shuffled_index, indices) return dict(zip(partitions, sub_arrays))