Skip to content

Commit

Permalink
Revert to use "normalizer" as input name
Browse files — browse the repository at this point in the history
  • Loading branch information
louisPoulain committed Jun 25, 2024
1 parent 1cfd1b7 commit 5f76a3e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions mlpp_lib/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class DataModule:
and targets are lists of names.
filter: `DataFilter`, optional
The object that handles the data filtering.
data_transformer: `DataTransformer`, optional
The object to transform data.
normalizer: `DataTransformer`, optional
The object to normalize data.
Must be provided if `.setup("test")` is called.
sample_weighting: list of str or str or xr.Dataset, optional
Name(s) of the variable(s) used for weighting dataset samples or an xr.Dataset
Expand All @@ -63,7 +63,7 @@ def __init__(
group_samples: Optional[dict[str:int]] = None,
data_dir: Optional[str] = None,
filter: Optional[DataFilter] = None,
data_transformer: Optional[DataTransformer] = None,
normalizer: Optional[DataTransformer] = None,
sample_weighting: Optional[Sequence[Hashable] or Hashable or xr.Dataset] = None,
thinning: Optional[Mapping[str, int]] = None,
):
Expand All @@ -75,7 +75,7 @@ def __init__(
self.splitter = splitter
self.group_samples = group_samples
self.filter = filter
self.data_transformer = data_transformer
self.normalizer = normalizer
self.sample_weighting = (
list(sample_weighting)
if isinstance(sample_weighting, str)
Expand Down Expand Up @@ -155,21 +155,21 @@ def apply_filter(self):
def standardize(self, stage=None):
LOGGER.info("Standardizing data.")

if self.data_transformer is None:
if self.normalizer is None:
if stage == "test":
raise ValueError("Must provide standardizer for `test` stage.")
else:
self.data_transformer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})})
self.normalizer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})})


if stage == "fit" or stage is None:
self.data_transformer.fit(self.train[0])
self.normalizer.fit(self.train[0])
self.train = (
tuple(self.data_transformer.transform(self.train[0])) + self.train[1:]
tuple(self.normalizer.transform(self.train[0])) + self.train[1:]
)
self.val = tuple(self.data_transformer.transform(self.val[0])) + self.val[1:]
self.val = tuple(self.normalizer.transform(self.val[0])) + self.val[1:]
if stage == "test" or stage is None:
self.test = tuple(self.data_transformer.transform(self.test[0])) + self.test[1:]
self.test = tuple(self.normalizer.transform(self.test[0])) + self.test[1:]

def as_datasets(self, stage=None):
batch_dims = self.batch_dims
Expand Down
2 changes: 1 addition & 1 deletion mlpp_lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,4 @@ def train(
for k in history:
history[k] = list(map(float, history[k]))

return model, custom_objects, datamodule.data_transformer, history
return model, custom_objects, datamodule.normalizer, history

0 comments on commit 5f76a3e

Please sign in to comment.