From 4a78f5339408c4adf587d7661060b040e950695c Mon Sep 17 00:00:00 2001 From: ned Date: Wed, 28 Aug 2024 16:23:32 +0200 Subject: [PATCH] Deprecate module standardizers --- mlpp_lib/datasets.py | 6 +- mlpp_lib/normalizers.py | 444 ++++++++++++++++++++++++++++++++ mlpp_lib/standardizers.py | 450 +-------------------------------- tests/conftest.py | 16 +- tests/test_data_transformer.py | 97 ------- tests/test_datasets.py | 2 +- tests/test_normalizers.py | 184 ++++++++++++++ tests/test_standardizers.py | 33 --- tests/test_train.py | 2 +- 9 files changed, 649 insertions(+), 585 deletions(-) create mode 100644 mlpp_lib/normalizers.py delete mode 100644 tests/test_data_transformer.py create mode 100644 tests/test_normalizers.py delete mode 100644 tests/test_standardizers.py diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py index cacfbe9..1de18ae 100644 --- a/mlpp_lib/datasets.py +++ b/mlpp_lib/datasets.py @@ -11,7 +11,7 @@ from typing_extensions import Self from .model_selection import DataSplitter -from .standardizers import DataTransformer +from .normalizers import DataTransformer LOGGER = logging.getLogger(__name__) @@ -153,11 +153,11 @@ def apply_filter(self): self.x, self.y = self.filter.apply(self.x, self.y) def normalize(self, stage=None): - LOGGER.info("Standardizing data.") + LOGGER.info("Normalizing data.") if self.normalizer is None: if stage == "test": - raise ValueError("Must provide standardizer for `test` stage.") + raise ValueError("Must provide normalizer for `test` stage.") else: self.normalizer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})}) diff --git a/mlpp_lib/normalizers.py b/mlpp_lib/normalizers.py new file mode 100644 index 0000000..c90889a --- /dev/null +++ b/mlpp_lib/normalizers.py @@ -0,0 +1,444 @@ +import json +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional +from abc import abstractmethod + +import numpy as np +import xarray as xr +from typing_extensions import Self + +LOGGER = logging.getLogger(__name__) + + + +def create_transformation_from_str(class_name: str, inputs: Optional[dict] = None): + + cls = globals()[class_name] + + if issubclass(cls, DataTransformation): + if inputs is None: + return cls(fillvalue=-5) + else: + if "fillvalue" not in inputs.keys(): + inputs["fillvalue"] = -5 + return cls(**inputs) + else: + raise ValueError(f"{class_name} is not a subclass of DataTransformation") + + +@dataclass +class DataTransformer: + """ + Class to handle the transformation of data in a xarray.Dataset object with different techniques. + """ + name = "DataTransformer" + + def __init__(self, method_var_dict: dict[str, tuple[list[str], dict[str, float]]] = None, + default_norma: Optional[str] = None, fillvalue: float = -5): + + self.all_vars = [] + self.parameters = [] + self.fillvalue = fillvalue + self.default_norma = default_norma if default_norma is not None else "Standardizer" + self.default_norma_index = None + + if method_var_dict is not None: + for i, (method, params) in enumerate(method_var_dict.items()): + variables, input_params = params + # handle the case of user passing the default norma with some features and forgetting to assign all features to a norma + if method == default_norma: + self.default_norma_index = i + + if input_params is None: + input_params = {} + if "fillvalue" not in input_params.keys(): + input_params["fillvalue"] = self.fillvalue + vars_to_remove = [var for var in variables if var in self.all_vars] + + method_cls = create_transformation_from_str(method, inputs=input_params) + + if len(vars_to_remove) > 0: + LOGGER.error(f"Variable(s) {[var for var in vars_to_remove]} are already assigned to another transformation method.\nRemoving them from this transformation.") + variables = [var for var in variables if var not in vars_to_remove] + + self.parameters.append((method_cls, variables, input_params)) + self.all_vars.extend(variables) + + + def fit(self, dataset: xr.Dataset, dims: Optional[list] = None): + + datavars = list(dataset.data_vars) + remaining_var = list(set(datavars) - set(self.all_vars)) + if len(remaining_var) > 0: + LOGGER.info(f"Variables {[var for var in remaining_var]} are not assigned to any transformation method. They will be assigned to {self.default_norma}") + if self.default_norma_index is not None: + self.parameters[self.default_norma_index][1].extend(remaining_var) + else: + self.parameters.append((create_transformation_from_str(self.default_norma), remaining_var, {"fillvalue": self.fillvalue})) + self.all_vars.extend(remaining_var) + + for i in range(len(self.parameters)): + transformation, variables, inputs = self.parameters[i] + transformation.fit(dataset=dataset, variables = variables, dims = dims) + self.parameters[i] = (transformation, variables, inputs) + + + def transform(self, *datasets: xr.Dataset) -> tuple[xr.Dataset, ...]: + + for parameter in self.parameters: + method, variables, _ = parameter + variables = list(set(variables) & set(datasets[0].data_vars)) # ensure that only variables that are in the dataset are transformed + datasets = method.transform(*datasets, variables=variables) + return datasets + + + def inverse_transform(self, *datasets: xr.Dataset) -> xr.Dataset: + + for parameter in self.parameters: + method, variables, _ = parameter + variables = list(set(variables) & set(datasets[0].data_vars)) # ensure that only variables that are in the dataset are inv-transformed + datasets = method.inverse_transform(*datasets, variables=variables) + + return datasets + + + @classmethod + def from_dict(cls, in_dict: dict) -> Self: + method_var_dict = {} + + # check whether dict corresponds to old Standardizer format + first_key = list(in_dict.keys())[0] + if first_key not in [cls.__name__ for cls in DataTransformation.__subclasses__()]: + subclass = Standardizer().from_dict(in_dict) + inputs = {"mean": subclass.mean, "std": subclass.std, "fillvalue": subclass.fillvalue} + method_var_dict[subclass.name] = (list(subclass.mean.data_vars), inputs) + else: + for method_name, inner_dict in in_dict.items(): + tmp_class = create_transformation_from_str(method_name).from_dict(inner_dict) + inputs = {key: getattr(tmp_class, key) for key in inner_dict if getattr(tmp_class, key, None) is not None} + method_var_dict[tmp_class.name] = (inner_dict["channels"], inputs) + return cls(method_var_dict) + + + def to_dict(self): + out_dict = {} + for parameter in self.parameters: + method, variables, _ = parameter + out_dict_tmp = method.to_dict() + out_dict_tmp["channels"] = variables + out_dict[method.name] = out_dict_tmp + return out_dict + + @classmethod + def from_json(cls, in_fn: str) -> Self: + with open(in_fn, "r") as f: + in_dict = json.load(f) + return cls.from_dict(in_dict) + + def save_json(self, out_fn: str) -> None: + out_dict = self.to_dict() + if len(out_dict) == 0 or out_dict[list(out_dict.keys())[0]] is None: + raise ValueError(f"{self.name} wasn't fit to data") + with open(out_fn, "w") as outfile: + json.dump(out_dict, outfile, indent=4) + +class DataTransformation: + """ + Abstract class for nromalization techniques in a xarray.Dataset object. + """ + + @abstractmethod + def fit(): + pass + + @abstractmethod + def transform(): + pass + + @abstractmethod + def inverse_transform(): + pass + + @abstractmethod + def from_dict(): + pass + + @abstractmethod + def to_dict(): + pass + +@dataclass +class Identity(DataTransformation): + """ + Identity transformation, returns the input data without any transformation. + """ + + fillvalue: float = field(default=-5) + name = "Identity" + + def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): + self.fillvalue = self.fillvalue + + def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + if variables is None: + variables = list(ds.data_vars) + for var in variables: + identity_value = ds[var].astype("float32") + if self.fillvalue is not None: + identity_value = identity_value.fillna(self.fillvalue) + ds[var] = identity_value + return ds + + return tuple(f(ds, variables) for ds in datasets) + + def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + return tuple(xr.where(ds > self.fillvalue, ds, np.nan) for ds in datasets) + + @classmethod + def from_dict(cls, in_dict: dict) -> Self: + fillvalue = in_dict["fillvalue"] + return cls(fillvalue) + + def to_dict(self): + out_dict = { + "fillvalue": self.fillvalue, + } + return out_dict + + +@dataclass +class Standardizer(DataTransformation): + """ + Tranforms data using a z-normalization in a xarray.Dataset object. + """ + + mean: xr.Dataset = field(default=None) + std: xr.Dataset = field(default=None) + fillvalue: dict[str, float] = field(init=True, default=-5) + name = "Standardizer" + + def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): + + if variables is None: + variables = list(dataset.data_vars) + if not all(var in dataset.data_vars for var in variables): + raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") + + self.mean = dataset[variables].mean(dims).compute().copy() + self.std = dataset[variables].std(dims).compute().copy() + self.fillvalue = self.fillvalue + # Check for near-zero standard deviations and set them equal to one + self.std = xr.where(self.std < 1e-6, 1, self.std) + + def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.mean is None: + raise ValueError("Standardizer wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None): + + ds_copy = ds.copy() + if variables is None: + variables = list(ds_copy.data_vars) + for var in variables: + assert var in self.mean.data_vars, f"{var} not in Standardizer" + + standardized_var = ((ds_copy[var] - self.mean[var]) / self.std[var]).astype("float32") + if self.fillvalue is not None: + standardized_var = standardized_var.fillna(self.fillvalue) + + ds_copy[var] = standardized_var + + return ds_copy + + return tuple(f(ds, variables) for ds in datasets) + + + def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.mean is None: + raise ValueError("Standardizer wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + if variables is None: + variables = ds.data_vars + + ds = xr.where(ds > self.fillvalue, ds, np.nan) + for var in variables: + assert var in self.mean.data_vars, f"{var} not in Standardizer" + + unstandardized_var = (ds[var] * self.std[var] + self.mean[var]).astype("float32") + ds[var] = unstandardized_var + + return ds + + return tuple(f(ds, variables) for ds in datasets) + + @classmethod + def from_dict(cls, in_dict: dict) -> Self: + mean = xr.Dataset.from_dict(in_dict["mean"]) + std = xr.Dataset.from_dict(in_dict["std"]) + fillvalue = in_dict["fillvalue"] + return cls(mean, std, fillvalue=fillvalue) + + def to_dict(self): + out_dict = { + "mean": self.mean.to_dict(), + "std": self.std.to_dict(), + "fillvalue": self.fillvalue, + } + return out_dict + + +@dataclass +class MinMaxScaler(DataTransformation): + """ + Tranforms data using a min/max scaling in a xarray.Dataset object. + """ + + minimum: xr.Dataset = field(default=None) + maximum: xr.Dataset = field(default=None) + fillvalue: dict[str, float] = field(init=True, default=-5) + name = "MinMaxScaler" + + def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): + if variables is None: + variables = list(dataset.data_vars) + if not all(var in dataset.data_vars for var in variables): + raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") + + self.minimum = dataset[variables].min(dims).compute().copy() + self.maximum = dataset[variables].max(dims).compute().copy() + self.fillvalue = self.fillvalue + + def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.minimum is None: + raise ValueError("MinMaxScaler wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + + ds_copy = ds.copy() + if variables is None: + variables = ds_copy.data_vars + for var in variables: + assert var in self.minimum.data_vars, f"{var} not in MinMaxScaler" + + scaled_var = ((ds_copy[var] - self.minimum[var]) / (self.maximum[var] - self.minimum[var])).astype("float32") + + if self.fillvalue is not None: + scaled_var = scaled_var.fillna(self.fillvalue) + + ds_copy[var] = scaled_var + return ds_copy + + return tuple(f(ds, variables) for ds in datasets) + + + def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.minimum is None: + raise ValueError("MinMaxScaler wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + if variables is None: + variables = ds.data_vars + + ds = xr.where(ds > self.fillvalue, ds, np.nan) + for var in variables: + assert var in self.minimum.data_vars, f"{var} not in MinMaxScaler" + + unscaled_var = (ds[var] * (self.maximum[var] - self.minimum[var]) + self.minimum[var]).astype("float32") + ds[var] = unscaled_var + + return ds + + return tuple(f(ds, variables) for ds in datasets) + + + @classmethod + def from_dict(cls, in_dict: dict) -> Self: + minimum = xr.Dataset.from_dict(in_dict["minimum"]) + maximum = xr.Dataset.from_dict(in_dict["maximum"]) + return cls(minimum, maximum) + + + def to_dict(self): + out_dict = { + "minimum": self.minimum.to_dict(), + "maximum": self.maximum.to_dict(), + "fillvalue": self.fillvalue, + } + return out_dict + + +@dataclass +class MaxAbsScaler(DataTransformation): + """ + Tranforms data using a max absolute scaling in a xarray.Dataset object. + """ + + absmax: xr.Dataset = field(default=None) + fillvalue: dict[str, float] = field(init=True, default=-5) + name = "MaxAbsScaler" + + def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): + if variables is None: + variables = list(dataset.data_vars) + if not all(var in dataset.data_vars for var in variables): + raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") + + self.absmax = abs(dataset[variables]).max(dims).compute().copy() + self.fillvalue = self.fillvalue + + + def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.absmax is None: + raise ValueError("MaxAbsScaler wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + + ds_copy = ds.copy() + if variables is None: + variables = ds_copy.data_vars + for var in variables: + assert var in self.absmax.data_vars, f"{var} not in MaxAbsScaler" + + scaled_var = (ds_copy[var] / self.absmax[var]).astype("float32") + + if self.fillvalue is not None: + scaled_var = scaled_var.fillna(self.fillvalue) + + ds_copy[var] = scaled_var + return ds_copy + + return tuple(f(ds, variables) for ds in datasets) + + def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: + if self.absmax is None: + raise ValueError("MaxAbsScaler wasn't fit to data") + + def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: + if variables is None: + variables = ds.data_vars + ds = xr.where(ds > self.fillvalue, ds, np.nan) + for var in variables: + assert var in self.absmax.data_vars, f"{var} not in MaxAbsScaler" + + unscaled_var = (ds[var] * self.absmax[var]).astype("float32") + ds[var] = unscaled_var + return ds + + return tuple(f(ds, variables) for ds in datasets) + + @classmethod + def from_dict(cls, in_dict: dict) -> Self: + absmax = xr.Dataset.from_dict(in_dict["absmax"]) + return cls(absmax) + + def to_dict(self): + out_dict = { + "absmax": self.absmax.to_dict(), + "fillvalue": self.fillvalue, + } + return out_dict diff --git a/mlpp_lib/standardizers.py b/mlpp_lib/standardizers.py index 06283e6..bee3f25 100644 --- a/mlpp_lib/standardizers.py +++ b/mlpp_lib/standardizers.py @@ -1,449 +1,15 @@ -import json -import logging -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional -from abc import abstractmethod +import warnings +from .normalizers import * -import numpy as np -import xarray as xr -from typing_extensions import Self -LOGGER = logging.getLogger(__name__) +warnings.warn( + "Module 'standardizers' is deprecated and will be removed in a future version. " + "Please use 'normalizers' instead.", + DeprecationWarning, + stacklevel=2, +) - -def create_transformation_from_str(class_name: str, inputs: Optional[dict] = None): - - cls = globals()[class_name] - - if issubclass(cls, DataTransformation): - if inputs is None: - return cls(fillvalue=-5) - else: - if "fillvalue" not in inputs.keys(): - inputs["fillvalue"] = -5 - return cls(**inputs) - else: - raise ValueError(f"{class_name} is not a subclass of DataTransformation") - - -@dataclass -class DataTransformer: - """ - Class to handle the transformation of data in a xarray.Dataset object with different techniques. - """ - name = "DataTransformer" - - def __init__(self, method_var_dict: dict[str, tuple[list[str], dict[str, float]]] = None, - default_norma: Optional[str] = None, fillvalue: float = -5): - - self.all_vars = [] - self.parameters = [] - self.fillvalue = fillvalue - self.default_norma = default_norma if default_norma is not None else "Standardizer" - self.default_norma_index = None - - if method_var_dict is not None: - for i, (method, params) in enumerate(method_var_dict.items()): - variables, input_params = params - # handle the case of user passing the default norma with some features and forgetting to assign all features to a norma - if method == default_norma: - self.default_norma_index = i - - if input_params is None: - input_params = {} - if "fillvalue" not in input_params.keys(): - input_params["fillvalue"] = self.fillvalue - vars_to_remove = [var for var in variables if var in self.all_vars] - - method_cls = create_transformation_from_str(method, inputs=input_params) - - if len(vars_to_remove) > 0: - LOGGER.error(f"Variable(s) {[var for var in vars_to_remove]} are already assigned to another transformation method.\nRemoving them from this transformation.") - variables = [var for var in variables if var not in vars_to_remove] - - self.parameters.append((method_cls, variables, input_params)) - self.all_vars.extend(variables) - - - def fit(self, dataset: xr.Dataset, dims: Optional[list] = None): - - datavars = list(dataset.data_vars) - remaining_var = list(set(datavars) - set(self.all_vars)) - if len(remaining_var) > 0: - LOGGER.info(f"Variables {[var for var in remaining_var]} are not assigned to any transformation method. They will be assigned to {self.default_norma}") - if self.default_norma_index is not None: - self.parameters[self.default_norma_index][1].extend(remaining_var) - else: - self.parameters.append((create_transformation_from_str(self.default_norma), remaining_var, {"fillvalue": self.fillvalue})) - self.all_vars.extend(remaining_var) - - for i in range(len(self.parameters)): - transformation, variables, inputs = self.parameters[i] - transformation.fit(dataset=dataset, variables = variables, dims = dims) - self.parameters[i] = (transformation, variables, inputs) - - - def transform(self, *datasets: xr.Dataset) -> tuple[xr.Dataset, ...]: - - for parameter in self.parameters: - method, variables, _ = parameter - variables = list(set(variables) & set(datasets[0].data_vars)) # ensure that only variables that are in the dataset are transformed - datasets = method.transform(*datasets, variables=variables) - return datasets - - - def inverse_transform(self, *datasets: xr.Dataset) -> xr.Dataset: - - for parameter in self.parameters: - method, variables, _ = parameter - variables = list(set(variables) & set(datasets[0].data_vars)) # ensure that only variables that are in the dataset are inv-transformed - datasets = method.inverse_transform(*datasets, variables=variables) - - return datasets - - - @classmethod - def from_dict(cls, in_dict: dict) -> Self: - method_var_dict = {} - - # check whether dict corresponds to old Standardizer format - first_key = list(in_dict.keys())[0] - if first_key not in [cls.__name__ for cls in DataTransformation.__subclasses__()]: - subclass = Standardizer().from_dict(in_dict) - inputs = {"mean": subclass.mean, "std": subclass.std, "fillvalue": subclass.fillvalue} - method_var_dict[subclass.name] = (list(subclass.mean.data_vars), inputs) - else: - for method_name, inner_dict in in_dict.items(): - tmp_class = create_transformation_from_str(method_name).from_dict(inner_dict) - inputs = {key: getattr(tmp_class, key) for key in inner_dict if getattr(tmp_class, key, None) is not None} - method_var_dict[tmp_class.name] = (inner_dict["channels"], inputs) - return cls(method_var_dict) - - - def to_dict(self): - out_dict = {} - for parameter in self.parameters: - method, variables, _ = parameter - out_dict_tmp = method.to_dict() - out_dict_tmp["channels"] = variables - out_dict[method.name] = out_dict_tmp - return out_dict - - @classmethod - def from_json(cls, in_fn: str) -> Self: - with open(in_fn, "r") as f: - in_dict = json.load(f) - return cls.from_dict(in_dict) - - def save_json(self, out_fn: str) -> None: - out_dict = self.to_dict() - if len(out_dict) == 0 or out_dict[list(out_dict.keys())[0]] is None: - raise ValueError(f"{self.name} wasn't fit to data") - with open(out_fn, "w") as outfile: - json.dump(out_dict, outfile, indent=4) - -class DataTransformation: - """ - Abstract class for nromalization techniques in a xarray.Dataset object. - """ - - @abstractmethod - def fit(): - pass - - @abstractmethod - def transform(): - pass - - @abstractmethod - def inverse_transform(): - pass - - @abstractmethod - def from_dict(): - pass - - @abstractmethod - def to_dict(): - pass - -@dataclass -class Identity(DataTransformation): - """ - Identity transformation, returns the input data without any transformation. - """ - - fillvalue: float = field(default=-5) - name = "Identity" - - def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): - self.fillvalue = self.fillvalue - - def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - if variables is None: - variables = list(ds.data_vars) - for var in variables: - identity_value = ds[var].astype("float32") - if self.fillvalue is not None: - identity_value = identity_value.fillna(self.fillvalue) - ds[var] = identity_value - return ds - - return tuple(f(ds, variables) for ds in datasets) - - def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - return tuple(xr.where(ds > self.fillvalue, ds, np.nan) for ds in datasets) - - @classmethod - def from_dict(cls, in_dict: dict) -> Self: - fillvalue = in_dict["fillvalue"] - return cls(fillvalue) - - def to_dict(self): - out_dict = { - "fillvalue": self.fillvalue, - } - return out_dict - - -@dataclass -class Standardizer(DataTransformation): - """ - Tranforms data using a z-normalization in a xarray.Dataset object. - """ - - mean: xr.Dataset = field(default=None) - std: xr.Dataset = field(default=None) - fillvalue: dict[str, float] = field(init=True, default=-5) - name = "Standardizer" - - def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): - - if variables is None: - variables = list(dataset.data_vars) - if not all(var in dataset.data_vars for var in variables): - raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") - - self.mean = dataset[variables].mean(dims).compute().copy() - self.std = dataset[variables].std(dims).compute().copy() - self.fillvalue = self.fillvalue - # Check for near-zero standard deviations and set them equal to one - self.std = xr.where(self.std < 1e-6, 1, self.std) - - def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.mean is None: - raise ValueError("Standardizer wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None): - - ds_copy = ds.copy() - if variables is None: - variables = list(ds_copy.data_vars) - for var in variables: - assert var in self.mean.data_vars, f"{var} not in Standardizer" - - standardized_var = ((ds_copy[var] - self.mean[var]) / self.std[var]).astype("float32") - if self.fillvalue is not None: - standardized_var = standardized_var.fillna(self.fillvalue) - - ds_copy[var] = standardized_var - - return ds_copy - - return tuple(f(ds, variables) for ds in datasets) - - - def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.mean is None: - raise ValueError("Standardizer wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - if variables is None: - variables = ds.data_vars - - ds = xr.where(ds > self.fillvalue, ds, np.nan) - for var in variables: - assert var in self.mean.data_vars, f"{var} not in Standardizer" - - unstandardized_var = (ds[var] * self.std[var] + self.mean[var]).astype("float32") - ds[var] = unstandardized_var - - return ds - - return tuple(f(ds, variables) for ds in datasets) - - @classmethod - def from_dict(cls, in_dict: dict) -> Self: - mean = xr.Dataset.from_dict(in_dict["mean"]) - std = xr.Dataset.from_dict(in_dict["std"]) - fillvalue = in_dict["fillvalue"] - return cls(mean, std, fillvalue=fillvalue) - - def to_dict(self): - out_dict = { - "mean": self.mean.to_dict(), - "std": self.std.to_dict(), - "fillvalue": self.fillvalue, - } - return out_dict - - -@dataclass -class MinMaxScaler(DataTransformation): - """ - Tranforms data using a min/max scaling in a xarray.Dataset object. - """ - - minimum: xr.Dataset = field(default=None) - maximum: xr.Dataset = field(default=None) - fillvalue: dict[str, float] = field(init=True, default=-5) - name = "MinMaxScaler" - - def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): - if variables is None: - variables = list(dataset.data_vars) - if not all(var in dataset.data_vars for var in variables): - raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") - - self.minimum = dataset[variables].min(dims).compute().copy() - self.maximum = dataset[variables].max(dims).compute().copy() - self.fillvalue = self.fillvalue - - def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.minimum is None: - raise ValueError("MinMaxScaler wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - - ds_copy = ds.copy() - if variables is None: - variables = ds_copy.data_vars - for var in variables: - assert var in self.minimum.data_vars, f"{var} not in MinMaxScaler" - - scaled_var = ((ds_copy[var] - self.minimum[var]) / (self.maximum[var] - self.minimum[var])).astype("float32") - - if self.fillvalue is not None: - scaled_var = scaled_var.fillna(self.fillvalue) - - ds_copy[var] = scaled_var - return ds_copy - - return tuple(f(ds, variables) for ds in datasets) - - - def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.minimum is None: - raise ValueError("MinMaxScaler wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - if variables is None: - variables = ds.data_vars - - ds = xr.where(ds > self.fillvalue, ds, np.nan) - for var in variables: - assert var in self.minimum.data_vars, f"{var} not in MinMaxScaler" - - unscaled_var = (ds[var] * (self.maximum[var] - self.minimum[var]) + self.minimum[var]).astype("float32") - ds[var] = unscaled_var - - return ds - - return tuple(f(ds, variables) for ds in datasets) - - - @classmethod - def from_dict(cls, in_dict: dict) -> Self: - minimum = xr.Dataset.from_dict(in_dict["minimum"]) - maximum = xr.Dataset.from_dict(in_dict["maximum"]) - return cls(minimum, maximum) - - - def to_dict(self): - out_dict = { - "minimum": self.minimum.to_dict(), - "maximum": self.maximum.to_dict(), - "fillvalue": self.fillvalue, - } - return out_dict - - -@dataclass -class MaxAbsScaler(DataTransformation): - """ - Tranforms data using a max absolute scaling in a xarray.Dataset object. - """ - - absmax: xr.Dataset = field(default=None) - fillvalue: dict[str, float] = field(init=True, default=-5) - name = "MaxAbsScaler" - - def fit(self, dataset: xr.Dataset, variables: Optional[list] = None, dims: Optional[list] = None): - if variables is None: - variables = list(dataset.data_vars) - if not all(var in dataset.data_vars for var in variables): - raise KeyError(f"There are variables not in dataset: {[var for var in variables if var not in dataset.data_vars]}") - - self.absmax = abs(dataset[variables]).max(dims).compute().copy() - self.fillvalue = self.fillvalue - - - def transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.absmax is None: - raise ValueError("MaxAbsScaler wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - - ds_copy = ds.copy() - if variables is None: - variables = ds_copy.data_vars - for var in variables: - assert var in self.absmax.data_vars, f"{var} not in MaxAbsScaler" - - scaled_var = (ds_copy[var] / self.absmax[var]).astype("float32") - - if self.fillvalue is not None: - scaled_var = scaled_var.fillna(self.fillvalue) - - ds_copy[var] = scaled_var - return ds_copy - - return tuple(f(ds, variables) for ds in datasets) - - def inverse_transform(self, *datasets: xr.Dataset, variables: Optional[list] = None) -> tuple[xr.Dataset, ...]: - if self.absmax is None: - raise ValueError("MaxAbsScaler wasn't fit to data") - - def f(ds: xr.Dataset, variables: Optional[list] = None) -> xr.Dataset: - if variables is None: - variables = ds.data_vars - ds = xr.where(ds > self.fillvalue, ds, np.nan) - for var in variables: - assert var in self.absmax.data_vars, f"{var} not in MaxAbsScaler" - - unscaled_var = (ds[var] * self.absmax[var]).astype("float32") - ds[var] = unscaled_var - return ds - - return tuple(f(ds, variables) for ds in datasets) - - @classmethod - def from_dict(cls, in_dict: dict) -> Self: - absmax = xr.Dataset.from_dict(in_dict["absmax"]) - return cls(absmax) - - def to_dict(self): - out_dict = { - "absmax": self.absmax.to_dict(), - "fillvalue": self.fillvalue, - } - return out_dict - - def standardize_split_dataset( split_dataset: dict[str, xr.Dataset], save_to_json: Optional[Path] = None, diff --git a/tests/conftest.py b/tests/conftest.py index 6d3ead5..02dc715 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,12 +45,12 @@ def features_dataset() -> xr.Dataset: def features_multi() -> xr.Dataset: """ Create a dataset as if it was loaded from `features.zarr`. - Coherent with the number of data transformations defined in the standardizers file. + Coherent with the number of data transformations defined in the normalizers module. """ - import mlpp_lib.standardizers as st + import mlpp_lib.normalizers as no rng = np.random.default_rng(1) - X = rng.standard_normal(size=(*SHAPE, len([n.name for n in st.DataTransformation.__subclasses__()]))) + X = rng.standard_normal(size=(*SHAPE, len([n.name for n in no.DataTransformation.__subclasses__()]))) X = np.float64(X) X[(X > 4.5) | (X < -4.5)] = np.nan features = xr.Dataset( @@ -73,9 +73,9 @@ def datatransformations() -> list: Create a list of data transformations. The list consists of all available data transformations. """ - import mlpp_lib.standardizers as st + import mlpp_lib.normalizers as no - datatransformations = [st.create_transformation_from_str(n.name) for n in st.DataTransformation.__subclasses__()] + datatransformations = [no.create_transformation_from_str(n.name) for n in no.DataTransformation.__subclasses__()] return datatransformations @@ -85,11 +85,11 @@ def data_transformer() -> xr.Dataset: """ Create a datatransformer. """ - import mlpp_lib.standardizers as st + import mlpp_lib.normalizers as no - transformations_list = [n.name for n in st.DataTransformation.__subclasses__()] + transformations_list = [n.name for n in no.DataTransformation.__subclasses__()] method_var_dict = {transformation: ([f"var{i}"],{}) for i, transformation in enumerate(transformations_list)} - data_transformer = st.DataTransformer(method_var_dict) + data_transformer = no.DataTransformer(method_var_dict) return data_transformer diff --git a/tests/test_data_transformer.py b/tests/test_data_transformer.py deleted file mode 100644 index ebcd144..0000000 --- a/tests/test_data_transformer.py +++ /dev/null @@ -1,97 +0,0 @@ -from mlpp_lib.standardizers import DataTransformer, Standardizer -import numpy as np -import xarray as xr - - -def get_class_attributes(cls): - class_attrs = {name: field.default for name, field in cls.__dataclass_fields__.items()} - return class_attrs - -def test_fit(datatransformations, data_transformer, features_multi): - - for i, datatransform in enumerate(datatransformations): - datatransform.fit(features_multi, variables=[f"var{i}"]) - - data_transformer.fit(features_multi) - - assert all( - (np.allclose(getattr(datatransform, attr), getattr(data_transformer.parameters[i][0], attr), equal_nan=True) - for attr in get_class_attributes(datatransform)) - for datatransform in datatransformations - ) - - -def test_transform(datatransformations, data_transformer, features_multi): - - features_individual = features_multi.copy() - for i, datatransform in enumerate(datatransformations): - datatransform.fit(features_multi, variables=[f"var{i}"]) - features_individual = datatransform.transform(features_individual, variables=[f"var{i}"])[0] - - data_transformer.fit(features_multi) - features_multi = data_transformer.transform(features_multi)[0] - - assert all( - np.allclose(features_individual[f"var{i}"].values, features_multi[f"var{i}"].values, equal_nan=True) - for i in range(len(datatransformations)) - ) - - -def test_inverse_transform(datatransformations, data_transformer, features_multi): - - original_data = features_multi.copy().astype("float32") - features_individual = features_multi.copy() - for i, datatransform in enumerate(datatransformations): - datatransform.fit(features_multi, variables=[f"var{i}"]) - features_individual = datatransform.transform(features_individual, variables=[f"var{i}"])[0] - inv_ds_individual = features_individual.copy() - for i, datatransform in enumerate(datatransformations): - inv_ds_individual = datatransform.inverse_transform(inv_ds_individual, variables=[f"var{i}"])[0] - - data_transformer.fit(features_multi) - ds_multi = data_transformer.transform(features_multi)[0] - inv_ds_multi = data_transformer.inverse_transform(ds_multi)[0] - - assert all( - np.allclose(inv_ds_individual[f"var{i}"].values, inv_ds_multi[f"var{i}"].values, equal_nan=True) - for i in range(len(datatransformations)) - ), "Inverse transform is not equal between individual data transformations and data_transformer" - - assert all( - np.allclose(original_data[f"var{i}"].values, inv_ds_individual[f"var{i}"].values, equal_nan=True, atol=1e-6) - for i in range(len(datatransformations)) - ), "Inverse transform is not equal between transformed individual data transformations and original features" - - assert all( - np.allclose(original_data[f"var{i}"].values, inv_ds_multi[f"var{i}"].values, equal_nan=True, atol=1e-6) - for i in range(len(datatransformations)) - ), "Inverse transform is not equal between transformed data_transformer and original features" - - -def test_serialization(data_transformer, features_multi, tmp_path): - - fn_multi = f"{tmp_path}/data_transformer.json" - - data_transformer.fit(features_multi) - data_transformer.save_json(fn_multi) - new_datatransformer = DataTransformer.from_json(fn_multi) - - assert all( - np.allclose(getattr(data_transformer, attr), getattr(new_datatransformer, attr), equal_nan=True) - for attr in get_class_attributes(data_transformer) - ) - - -def test_retro_compatibility(features_multi): - - standardizer = Standardizer() - standardizer.fit(features_multi) - dict_stand = standardizer.to_dict() - data_transformer = DataTransformer.from_dict(dict_stand) - - assert all( - [np.allclose(getattr(data_transformer.parameters[0][0], attr)[var].values, getattr(standardizer, attr)[var].values, equal_nan=True) - for var in getattr(standardizer, attr).data_vars] if type(getattr(standardizer, attr))==xr.Dataset - else np.allclose(getattr(data_transformer.parameters[0][0], attr), getattr(standardizer, attr)) - for attr in get_class_attributes(standardizer) - ) \ No newline at end of file diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 1ae0e05..e60d91a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -7,7 +7,7 @@ from mlpp_lib.datasets import Dataset, DataModule from mlpp_lib.model_selection import DataSplitter -from mlpp_lib.standardizers import DataTransformer +from mlpp_lib.normalizers import DataTransformer from .test_model_selection import ValidDataSplitterOptions ZARR_MISSING = "zarr" not in xr.backends.list_engines() diff --git a/tests/test_normalizers.py b/tests/test_normalizers.py new file mode 100644 index 0000000..8120598 --- /dev/null +++ b/tests/test_normalizers.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest +import xarray as xr + +from mlpp_lib.normalizers import DataTransformer + + +def get_class_attributes(cls): + class_attrs = { + name: field.default for name, field in cls.__dataclass_fields__.items() + } + return class_attrs + + +def test_fit(datatransformations, data_transformer, features_multi): + + for i, datatransform in enumerate(datatransformations): + datatransform.fit(features_multi, variables=[f"var{i}"]) + + data_transformer.fit(features_multi) + + assert all( + ( + np.allclose( + getattr(datatransform, attr), + getattr(data_transformer.parameters[i][0], attr), + equal_nan=True, + ) + for attr in get_class_attributes(datatransform) + ) + for datatransform in datatransformations + ) + + +def test_transform(datatransformations, data_transformer, features_multi): + + features_individual = features_multi.copy() + for i, datatransform in enumerate(datatransformations): + datatransform.fit(features_multi, variables=[f"var{i}"]) + features_individual = datatransform.transform( + features_individual, variables=[f"var{i}"] + )[0] + + data_transformer.fit(features_multi) + features_multi = data_transformer.transform(features_multi)[0] + + assert all( + np.allclose( + features_individual[f"var{i}"].values, + features_multi[f"var{i}"].values, + equal_nan=True, + ) + for i in range(len(datatransformations)) + ) + + +def test_inverse_transform(datatransformations, data_transformer, features_multi): + + original_data = features_multi.copy().astype("float32") + features_individual = features_multi.copy() + for i, datatransform in enumerate(datatransformations): + datatransform.fit(features_multi, variables=[f"var{i}"]) + features_individual = datatransform.transform( + features_individual, variables=[f"var{i}"] + )[0] + inv_ds_individual = features_individual.copy() + for i, datatransform in enumerate(datatransformations): + inv_ds_individual = datatransform.inverse_transform( + inv_ds_individual, variables=[f"var{i}"] + )[0] + + data_transformer.fit(features_multi) + ds_multi = data_transformer.transform(features_multi)[0] + inv_ds_multi = data_transformer.inverse_transform(ds_multi)[0] + + assert all( + np.allclose( + inv_ds_individual[f"var{i}"].values, + inv_ds_multi[f"var{i}"].values, + equal_nan=True, + ) + for i in range(len(datatransformations)) + ), "Inverse transform is not equal between individual data transformations and data_transformer" + + assert all( + np.allclose( + original_data[f"var{i}"].values, + inv_ds_individual[f"var{i}"].values, + equal_nan=True, + atol=1e-6, + ) + for i in range(len(datatransformations)) + ), "Inverse transform is not equal between transformed individual data transformations and original features" + + assert all( + np.allclose( + original_data[f"var{i}"].values, + inv_ds_multi[f"var{i}"].values, + equal_nan=True, + atol=1e-6, + ) + for i in range(len(datatransformations)) + ), "Inverse transform is not equal between transformed data_transformer and original features" + + +def test_serialization(data_transformer, features_multi, tmp_path): + + fn_multi = f"{tmp_path}/data_transformer.json" + + data_transformer.fit(features_multi) + data_transformer.save_json(fn_multi) + new_datatransformer = DataTransformer.from_json(fn_multi) + + assert all( + np.allclose( + getattr(data_transformer, attr), + getattr(new_datatransformer, attr), + equal_nan=True, + ) + for attr in get_class_attributes(data_transformer) + ) + + +class TestLegacyStandardizer: + @pytest.fixture + def standardizer(self): + from mlpp_lib.standardizers import Standardizer + + return Standardizer(fillvalue=-5) + + def test_fit(self, standardizer, features_dataset): + standardizer.fit(features_dataset) + assert all( + var in standardizer.mean.data_vars for var in features_dataset.data_vars + ) + assert all( + var in standardizer.std.data_vars for var in features_dataset.data_vars + ) + assert standardizer.fillvalue == -5 + + def test_transform(self, standardizer, features_dataset): + standardizer.fit(features_dataset) + ds = standardizer.transform(features_dataset)[0] + assert all(var in ds.data_vars for var in features_dataset.data_vars) + assert all(np.isclose(ds[var].mean().values, 0) for var in ds.data_vars) + assert all(np.isclose(ds[var].std().values, 1) for var in ds.data_vars) + + def test_inverse_transform(self, standardizer, features_dataset): + standardizer.fit(features_dataset) + ds = standardizer.transform(features_dataset)[0] + inv_ds = standardizer.inverse_transform(ds)[0] + + assert all( + np.allclose( + inv_ds[var].values, + features_dataset[var].values, + equal_nan=True, + atol=1e-6, + ) + for var in features_dataset.data_vars + ) + assert all(var in inv_ds.data_vars for var in features_dataset.data_vars) + + def test_retro_compatibility(self, standardizer, features_multi): + standardizer.fit(features_multi) + dict_stand = standardizer.to_dict() + data_transformer = DataTransformer.from_dict(dict_stand) + + assert all( + [ + np.allclose( + getattr(data_transformer.parameters[0][0], attr)[var].values, + getattr(standardizer, attr)[var].values, + equal_nan=True, + ) + for var in getattr(standardizer, attr).data_vars + ] + if isinstance(getattr(standardizer, attr), xr.Dataset) + else np.allclose( + getattr(data_transformer.parameters[0][0], attr), + getattr(standardizer, attr), + ) + for attr in get_class_attributes(standardizer) + ) diff --git a/tests/test_standardizers.py b/tests/test_standardizers.py deleted file mode 100644 index 4789246..0000000 --- a/tests/test_standardizers.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np - -from mlpp_lib.standardizers import Standardizer - - -def test_fit(features_dataset): - standardizer = Standardizer(fillvalue=-5) - standardizer.fit(features_dataset) - assert all(var in standardizer.mean.data_vars for var in features_dataset.data_vars) - assert all(var in standardizer.std.data_vars for var in features_dataset.data_vars) - assert standardizer.fillvalue == -5 - - -def test_transform(features_dataset): - standardizer = Standardizer(fillvalue=-5) - standardizer.fit(features_dataset) - ds = standardizer.transform(features_dataset)[0] - assert all(var in ds.data_vars for var in features_dataset.data_vars) - assert all(np.isclose(ds[var].mean().values, 0) for var in ds.data_vars) - assert all(np.isclose(ds[var].std().values, 1) for var in ds.data_vars) - - -def test_inverse_transform(features_dataset): - standardizer = Standardizer(fillvalue=-5) - standardizer.fit(features_dataset) - ds = standardizer.transform(features_dataset)[0] - inv_ds = standardizer.inverse_transform(ds)[0] - - assert all( - np.allclose(inv_ds[var].values, features_dataset[var].values, equal_nan=True, atol=1e-6) - for var in features_dataset.data_vars - ) - assert all(var in inv_ds.data_vars for var in features_dataset.data_vars) diff --git a/tests/test_train.py b/tests/test_train.py index 7b7e7b7..7b8e26c 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -7,7 +7,7 @@ import xarray as xr from mlpp_lib import train -from mlpp_lib.standardizers import DataTransformer +from mlpp_lib.normalizers import DataTransformer from mlpp_lib.datasets import DataModule, DataSplitter from .test_model_selection import ValidDataSplitterOptions