diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py
index 2429eef..d8543ba 100644
--- a/mlpp_lib/datasets.py
+++ b/mlpp_lib/datasets.py
@@ -44,8 +44,8 @@ class DataModule:
         and targets are lists of names.
     filter: `DataFilter`, optional
         The object that handles the data filtering.
-    data_transformer: `DataTransformer`, optional
-        The object to transform data.
+    normalizer: `DataTransformer`, optional
+        The object to normalize data.
         Must be provided if `.setup("test")` is called.
     sample_weighting: list of str or str or xr.Dataset, optional
         Name(s) of the variable(s) used for weighting dataset samples or an xr.Dataset
@@ -63,7 +63,7 @@ def __init__(
         group_samples: Optional[dict[str:int]] = None,
         data_dir: Optional[str] = None,
         filter: Optional[DataFilter] = None,
-        data_transformer: Optional[DataTransformer] = None,
+        normalizer: Optional[DataTransformer] = None,
         sample_weighting: Optional[Sequence[Hashable] or Hashable or xr.Dataset] = None,
         thinning: Optional[Mapping[str, int]] = None,
     ):
@@ -75,7 +75,7 @@ def __init__(
         self.splitter = splitter
         self.group_samples = group_samples
         self.filter = filter
-        self.data_transformer = data_transformer
+        self.normalizer = normalizer
         self.sample_weighting = (
             list(sample_weighting)
             if isinstance(sample_weighting, str)
@@ -155,21 +155,21 @@ def apply_filter(self):
 
     def standardize(self, stage=None):
         LOGGER.info("Standardizing data.")
-        if self.data_transformer is None:
+        if self.normalizer is None:
             if stage == "test":
                 raise ValueError("Must provide standardizer for `test` stage.")
             else:
-                self.data_transformer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})})
+                self.normalizer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})})
         if stage == "fit" or stage is None:
-            self.data_transformer.fit(self.train[0])
+            self.normalizer.fit(self.train[0])
             self.train = (
-                tuple(self.data_transformer.transform(self.train[0])) + self.train[1:]
+                tuple(self.normalizer.transform(self.train[0])) + self.train[1:]
             )
-            self.val = tuple(self.data_transformer.transform(self.val[0])) + self.val[1:]
+            self.val = tuple(self.normalizer.transform(self.val[0])) + self.val[1:]
         if stage == "test" or stage is None:
-            self.test = tuple(self.data_transformer.transform(self.test[0])) + self.test[1:]
+            self.test = tuple(self.normalizer.transform(self.test[0])) + self.test[1:]
 
     def as_datasets(self, stage=None):
         batch_dims = self.batch_dims
diff --git a/mlpp_lib/train.py b/mlpp_lib/train.py
index 9de8c53..75a139a 100644
--- a/mlpp_lib/train.py
+++ b/mlpp_lib/train.py
@@ -124,4 +124,4 @@ def train(
     for k in history:
         history[k] = list(map(float, history[k]))
 
-    return model, custom_objects, datamodule.data_transformer, history
+    return model, custom_objects, datamodule.normalizer, history