diff --git a/mlpp_lib/model_selection.py b/mlpp_lib/model_selection.py
index 8d32fea..9d2020b 100644
--- a/mlpp_lib/model_selection.py
+++ b/mlpp_lib/model_selection.py
@@ -114,6 +114,28 @@ def __init__(
         self.seed = seed
         self.time_dim_name = time_dim_name
 
+    def fit(self, *args: xr.Dataset) -> Self:
+        """Compute splits based on the input datasets.
+
+        Parameters
+        ----------
+        *args: `xr.Dataset`
+            Input datasets; only the first one is read for the time and station indices.
+
+        Returns
+        -------
+        self: `DataSplitter`
+            The fitted DataSplitter instance.
+        """
+
+        self.time_index = args[0][self.time_dim_name].values.copy()
+        self.station_index = args[0].station.values.copy()
+
+        self._time_partitioning()
+        self._station_partitioning()
+
+        return self
+
     def get_partition(
         self, *args: xr.Dataset, partition=None, thinning: Optional[Mapping] = None
     ) -> tuple[xr.Dataset, ...]:
@@ -141,11 +163,8 @@ def get_partition(
         if partition is None:
             raise ValueError("Keyword argument `partition` must be provided.")
 
-        self.time_index = args[0][self.time_dim_name].values.copy()
-        self.station_index = args[0].station.values.copy()
-
-        self._time_partitioning()
-        self._station_partitioning()
+        if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
+            self.fit(*args)
 
         # avoid out-of-order indexing (leads to bad performance with xarray/dask)
         station_idx = self.partitions[partition]["station"]
@@ -289,9 +308,9 @@ def from_dict(cls, splits: dict) -> Self:
         splitter._station_partitioning()
         return splitter
 
-    def to_dict(self):
+    def to_dict(self, sort_values=False):
         if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
-            raise ValueError("DataSplitter wasn't applied on data yet")
+            raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.")
         if not hasattr(self, "partitions"):
             self._time_partitioning()
             self._station_partitioning()
@@ -302,6 +321,8 @@ def to_dict(self):
                     partitions[split_key][dim] = [str(value.start), str(value.stop)]
                 elif hasattr(value, "tolist"):
                     partitions[split_key][dim] = value.astype(str).tolist()
+                    if sort_values:
+                        partitions[split_key][dim] = sorted(partitions[split_key][dim])
         return partitions
 
     @classmethod