
Commit

Merge pull request #42 from MeteoSwiss/refactor-data-splitter
Add .fit() method to DataSplitter
dnerini authored May 27, 2024
2 parents 23b5847 + 6e53fea commit c7f9685
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions mlpp_lib/model_selection.py
@@ -114,6 +114,28 @@ def __init__(
         self.seed = seed
         self.time_dim_name = time_dim_name
 
+    def fit(self, *args: xr.Dataset) -> Self:
+        """Compute splits based on the input datasets.
+
+        Parameters
+        ----------
+        *args: `xr.Dataset`
+            The datasets defining the time and station indices.
+
+        Returns
+        -------
+        self: `DataSplitter`
+            The fitted DataSplitter instance.
+        """
+
+        self.time_index = args[0][self.time_dim_name].values.copy()
+        self.station_index = args[0].station.values.copy()
+
+        self._time_partitioning()
+        self._station_partitioning()
+
+        return self
+
     def get_partition(
         self, *args: xr.Dataset, partition=None, thinning: Optional[Mapping] = None
     ) -> tuple[xr.Dataset, ...]:
@@ -141,11 +163,8 @@ def get_partition
         if partition is None:
             raise ValueError("Keyword argument `partition` must be provided.")
 
-        self.time_index = args[0][self.time_dim_name].values.copy()
-        self.station_index = args[0].station.values.copy()
-
-        self._time_partitioning()
-        self._station_partitioning()
+        if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
+            self = self.fit(*args)
 
         # avoid out-of-order indexing (leads to bad performance with xarray/dask)
         station_idx = self.partitions[partition]["station"]
@@ -289,9 +308,9 @@ def from_dict(cls, splits: dict) -> Self:
         splitter._station_partitioning()
         return splitter
 
-    def to_dict(self):
+    def to_dict(self, sort_values=False):
         if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
-            raise ValueError("DataSplitter wasn't applied on data yet")
+            raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.")
         if not hasattr(self, "partitions"):
             self._time_partitioning()
             self._station_partitioning()
@@ -302,6 +321,8 @@ def to_dict(self):
                     partitions[split_key][dim] = [str(value.start), str(value.stop)]
                 elif hasattr(value, "tolist"):
                     partitions[split_key][dim] = value.astype(str).tolist()
+                if sort_values:
+                    partitions[split_key][dim] = sorted(partitions[split_key][dim])
         return partitions
 
     @classmethod
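For context, here is a minimal usage sketch of the API this commit touches. The toy dataset, the DataSplitter constructor arguments, and the dimension names below are illustrative assumptions, not taken from this diff; only fit(), get_partition(partition=...), and to_dict(sort_values=...) reflect the changed code.

# Hypothetical usage sketch; adjust DataSplitter(...) arguments to the actual signature.
import numpy as np
import pandas as pd
import xarray as xr

from mlpp_lib.model_selection import DataSplitter

# Toy dataset with time and station dimensions (names are assumptions).
ds = xr.Dataset(
    {"var": (("forecast_reference_time", "station"), np.random.rand(48, 5))},
    coords={
        "forecast_reference_time": pd.date_range("2024-01-01", periods=48, freq="6h"),
        "station": [f"STA{i}" for i in range(5)],
    },
)

# Constructor arguments below are assumed for illustration only.
splitter = DataSplitter(
    time_split={"train": 0.6, "val": 0.2, "test": 0.2},
    station_split={"train": 0.7, "val": 0.1, "test": 0.2},
)

# New in this commit: compute the split indices explicitly, once, via fit().
splitter.fit(ds)

# get_partition() now only falls back to fit() when the indices are missing.
(train_ds,) = splitter.get_partition(ds, partition="train")  # one dataset per input

# to_dict() gained an optional sort_values flag that returns sorted index values.
splits = splitter.to_dict(sort_values=True)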
