From 9312dfac4c51bfef89df9de3bdbabf78c5deeb4e Mon Sep 17 00:00:00 2001 From: ned Date: Mon, 27 May 2024 09:19:49 +0200 Subject: [PATCH 1/2] Add .fit() method --- mlpp_lib/model_selection.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/mlpp_lib/model_selection.py b/mlpp_lib/model_selection.py index 8d32fea..88a2381 100644 --- a/mlpp_lib/model_selection.py +++ b/mlpp_lib/model_selection.py @@ -114,6 +114,28 @@ def __init__( self.seed = seed self.time_dim_name = time_dim_name + def fit(self, *args: xr.Dataset) -> Self: + """Compute splits based on the input datasets. + + Parameters + ---------- + *args: `xr.Dataset` + The datasets defining the time and station indices. + + Returns + ------- + self: `DataSplitter` + The fitted DataSplitter instance. + """ + + self.time_index = args[0][self.time_dim_name].values.copy() + self.station_index = args[0].station.values.copy() + + self._time_partitioning() + self._station_partitioning() + + return self + def get_partition( self, *args: xr.Dataset, partition=None, thinning: Optional[Mapping] = None ) -> tuple[xr.Dataset, ...]: @@ -141,11 +163,8 @@ def get_partition( if partition is None: raise ValueError("Keyword argument `partition` must be provided.") - self.time_index = args[0][self.time_dim_name].values.copy() - self.station_index = args[0].station.values.copy() - - self._time_partitioning() - self._station_partitioning() + if not hasattr(self, "time_index") or not hasattr(self, "station_index"): + self = self.fit(*args) # avoid out-of-order indexing (leads to bad performance with xarray/dask) station_idx = self.partitions[partition]["station"] @@ -291,7 +310,7 @@ def from_dict(cls, splits: dict) -> Self: def to_dict(self): if not hasattr(self, "time_index") or not hasattr(self, "station_index"): - raise ValueError("DataSplitter wasn't applied on data yet") + raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.") if not hasattr(self, "partitions"): self._time_partitioning() self._station_partitioning() From 6e53fea3f29d54531f675dcbecb342999fcf56e2 Mon Sep 17 00:00:00 2001 From: ned Date: Mon, 27 May 2024 09:52:48 +0200 Subject: [PATCH 2/2] Add option to sort values --- mlpp_lib/model_selection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlpp_lib/model_selection.py b/mlpp_lib/model_selection.py index 88a2381..9d2020b 100644 --- a/mlpp_lib/model_selection.py +++ b/mlpp_lib/model_selection.py @@ -308,7 +308,7 @@ def from_dict(cls, splits: dict) -> Self: splitter._station_partitioning() return splitter - def to_dict(self): + def to_dict(self, sort_values=False): if not hasattr(self, "time_index") or not hasattr(self, "station_index"): raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.") if not hasattr(self, "partitions"): @@ -321,6 +321,8 @@ def to_dict(self): partitions[split_key][dim] = [str(value.start), str(value.stop)] elif hasattr(value, "tolist"): partitions[split_key][dim] = value.astype(str).tolist() + if sort_values: + partitions[split_key][dim] = sorted(partitions[split_key][dim]) return partitions @classmethod