Fix formatting (#67)
dnerini authored and Louis Poulain--Auzéau committed Nov 5, 2024
1 parent 139f860 commit f3663aa
Showing 9 changed files with 46 additions and 56 deletions.
1 change: 1 addition & 0 deletions mlpp_lib/__init__.py
@@ -1,4 +1,5 @@
 __version__ = "0.1.0"
+
 import os
 
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
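Note on this hunk: TF_USE_LEGACY_KERAS only takes effect if it is set before TensorFlow is first imported in the process, which is why it lives at the top of the package's __init__.py. A minimal sketch of the required ordering (the tensorflow import is illustrative, not part of this file):

    import os

    # Must run before any TensorFlow import, otherwise the flag is read
    # too late and tf.keras resolves to the newer Keras 3.
    os.environ["TF_USE_LEGACY_KERAS"] = "1"

    import tensorflow as tf  # tf.keras now points at legacy Keras 2 (tf_keras)
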
4 changes: 3 additions & 1 deletion mlpp_lib/datasets.py
@@ -451,7 +451,9 @@ def drop_nans(self, group_size: int = 1):
         event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"]
         mask = da.any(~da.isfinite(da.from_array(x, name="x")), axis=event_axes)
         if y is not None:
-            mask = mask | da.any(~da.isfinite(da.from_array(y, name="y")), axis=event_axes)
+            mask = mask | da.any(
+                ~da.isfinite(da.from_array(y, name="y")), axis=event_axes
+            )
         mask = (~mask).compute()
 
         # with grouped samples, nans have to be removed in blocks:
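For readers unfamiliar with the masking idiom in drop_nans: any sample (along the "s" axis) that contains a non-finite value is flagged, and only the complement is kept. A self-contained sketch with assumed 2-D shapes:

    import dask.array as da
    import numpy as np

    x = np.array([[0.0, 1.0], [np.nan, 2.0], [3.0, 4.0]])  # (sample, feature)
    # True where a row contains NaN/inf; axis=1 reduces over the event axes
    bad = da.any(~da.isfinite(da.from_array(x, name="x")), axis=1)
    keep = (~bad).compute()
    print(x[keep])  # [[0. 1.] [3. 4.]]
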
4 changes: 1 addition & 3 deletions mlpp_lib/losses.py
@@ -432,9 +432,7 @@ class MultivariateLoss(tf.keras.losses.Loss):
     """
 
     def mse_metric(y_true, y_pred):
-        return tf.reduce_mean(
-            tf.square(y_true - y_pred), axis=0
-        )
+        return tf.reduce_mean(tf.square(y_true - y_pred), axis=0)
 
     def mae_metric(y_true, y_pred):
         return tf.reduce_mean(tf.abs(y_true - y_pred), axis=0)
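The axis=0 reduction is what makes these per-variable metrics: errors are averaged over the batch dimension, leaving one value per output column. A standalone check:

    import tensorflow as tf

    y_true = tf.constant([[0.0, 1.0], [2.0, 3.0]])
    y_pred = tf.constant([[0.0, 0.0], [2.0, 5.0]])
    # Mean over rows (the batch), one MSE per column (the variables)
    print(tf.reduce_mean(tf.square(y_true - y_pred), axis=0).numpy())  # [0.  2.5]
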
4 changes: 3 additions & 1 deletion mlpp_lib/normalizers.py
@@ -211,7 +211,9 @@ def f(ds: xr.Dataset) -> xr.Dataset:
                 ds = ds.fillna(self.fillvalue)
             else:
                 if ds.isnull().any():
-                    raise ValueError("Missing values found in the data. Please provide a fill value.")
+                    raise ValueError(
+                        "Missing values found in the data. Please provide a fill value."
+                    )
             return ds.astype("float32")
 
         return tuple(f(ds) for ds in datasets)
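The reformatted ValueError sits inside a fill-or-fail guard: NaNs are filled when a fill value was configured, and otherwise the transformation refuses to propagate them silently. A loose sketch of the same guard, not the library's exact control flow (bool(...) over to_array() is used here because truth-testing a whole Dataset is ambiguous):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"t": ("s", [1.0, np.nan])})
    fillvalue = -5  # assumed configured value
    if fillvalue is not None:
        ds = ds.fillna(fillvalue)
    elif bool(ds.to_array().isnull().any()):
        raise ValueError("Missing values found in the data. Please provide a fill value.")
    print(ds["t"].values)  # [ 1. -5.]
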
48 changes: 12 additions & 36 deletions mlpp_lib/probabilistic_layers.py
@@ -88,9 +88,7 @@ def new_from_t(t):
             return IndependentBeta.new(t, event_shape, validate_args)
 
         super(IndependentBeta, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -200,9 +198,7 @@ def new_from_t(t):
             return Independent4ParamsBeta.new(t, event_shape, validate_args)
 
         super(Independent4ParamsBeta, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -313,14 +309,10 @@ def __init__(
         kwargs.pop("make_distribution_fn", None)
 
         def new_from_t(t):
-            return IndependentDoublyCensoredNormal.new(
-                t, event_shape, validate_args
-            )
+            return IndependentDoublyCensoredNormal.new(t, event_shape, validate_args)
 
         super(IndependentDoublyCensoredNormal, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -505,9 +497,7 @@ def new_from_t(t):
             return IndependentConcaveBeta.new(t, event_shape, validate_args)
 
         super(IndependentConcaveBeta, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -621,9 +611,7 @@ def new_from_t(t):
            return IndependentGamma.new(t, event_shape, validate_args)
 
         super(IndependentGamma, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -734,9 +722,7 @@ def new_from_t(t):
             return IndependentLogNormal.new(t, event_shape, validate_args)
 
         super(IndependentLogNormal, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -845,9 +831,7 @@ def new_from_t(t):
             return IndependentLogitNormal.new(t, event_shape, validate_args)
 
         super(IndependentLogitNormal, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -959,9 +943,7 @@ def new_from_t(t):
             return IndependentMixtureNormal.new(t, event_shape, validate_args)
 
         super(IndependentMixtureNormal, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -1135,9 +1117,7 @@ def new_from_t(t):
             return IndependentTruncatedNormal.new(t, event_shape, validate_args)
 
         super(IndependentTruncatedNormal, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
        )
 
         self._event_shape = event_shape
@@ -1248,9 +1228,7 @@ def new_from_t(t):
             return IndependentWeibull.new(t, event_shape, validate_args)
 
         super(IndependentWeibull, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_shape = event_shape
@@ -1365,9 +1343,7 @@ def new_from_t(t):
             return MultivariateNormalDiag.new(t, event_size, validate_args)
 
         super(MultivariateNormalDiag, self).__init__(
-            new_from_t,
-            convert_to_tensor_fn,
-            **kwargs
+            new_from_t, convert_to_tensor_fn, **kwargs
         )
 
         self._event_size = event_size
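Every hunk above collapses the same three-line argument list, because all of these layers follow one pattern: a tfp.layers.DistributionLambda subclass whose constructor forwards a make-distribution closure plus convert_to_tensor_fn. A self-contained sketch of that shared pattern using a plain Normal (illustrative only, not one of the library's classes):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    class IndependentNormalSketch(tfp.layers.DistributionLambda):
        # Hypothetical example class mirroring the super().__init__ call style above
        def __init__(
            self,
            event_shape=(),
            convert_to_tensor_fn=tfp.distributions.Distribution.sample,
            **kwargs,
        ):
            def new_from_t(t):
                # t carries loc and unconstrained scale side by side
                loc, scale = tf.split(t, 2, axis=-1)
                return tfd.Independent(
                    tfd.Normal(loc=loc, scale=tf.math.softplus(scale)),
                    reinterpreted_batch_ndims=1,
                )

            super().__init__(new_from_t, convert_to_tensor_fn, **kwargs)
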
2 changes: 2 additions & 0 deletions mlpp_lib/train.py
@@ -39,8 +39,10 @@ def get_log_params(param_run: dict) -> dict:
 
 def get_lr(optimizer: tf.keras.optimizers.Optimizer) -> float:
     """Get the learning rate of the optimizer"""
+
     def lr(y_true, y_pred):
         return optimizer.lr
+
     return lr
 
 
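In use, the closure returned by get_lr is registered as a Keras metric, so the optimizer's current learning rate appears in the training logs alongside the loss. A sketch, assuming the legacy-Keras optimizer.lr attribute that this repo pins via TF_USE_LEGACY_KERAS:

    import tensorflow as tf

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    # lr(y_true, y_pred) ignores its arguments and just reports optimizer.lr
    model.compile(optimizer=optimizer, loss="mse", metrics=[get_lr(optimizer)])
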
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -1,2 +1,3 @@
 import os
+
 os.environ["TF_USE_LEGACY_KERAS"] = "1"
10 changes: 8 additions & 2 deletions tests/conftest.py
@@ -76,7 +76,9 @@ def datatransformations() -> list:
     import mlpp_lib.normalizers as no
 
     datatransformations = [
-        no.create_transformation_from_str(n.name, inputs={"fillvalue": -5} if n.name == "Identity" else {})  # temporary fix, do we want to let the user define different fillvalue for each transformation ?
+        no.create_transformation_from_str(
+            n.name, inputs={"fillvalue": -5} if n.name == "Identity" else {}
+        )  # temporary fix, do we want to let the user define different fillvalue for each transformation ?
         for n in no.DataTransformation.__subclasses__()
     ]
 
@@ -96,7 +98,11 @@ def data_transformer() -> xr.Dataset:
         for i, transformation in enumerate(transformations_list)
     }
     data_transformer = no.DataTransformer(method_var_dict)
-    data_transformer.transformers['Identity'][0].fillvalue = -5  # temporary fix, do we want to let the user define different fillvalue for each transformation ?
+    data_transformer.transformers["Identity"][
+        0
+    ].fillvalue = (
+        -5
+    )  # temporary fix, do we want to let the user define different fillvalue for each transformation ?
 
     return data_transformer
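The first fixture builds one transformation per subclass via __subclasses__(); a tiny demo of that discovery mechanism, reusing class names that appear in this repo:

    class DataTransformation:
        pass

    class Identity(DataTransformation):
        pass

    class Standardizer(DataTransformation):
        pass

    print([c.__name__ for c in DataTransformation.__subclasses__()])
    # ['Identity', 'Standardizer']
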
28 changes: 15 additions & 13 deletions tests/test_normalizers.py
@@ -165,20 +165,22 @@ def test_retro_compatibility(self, standardizer, features_multi):
         data_transformer = DataTransformer.from_dict(dict_stand)
 
         assert all(
-            [
-                np.allclose(
-                    getattr(data_transformer.transformers["Standardizer"][0], attr)[
-                        var
-                    ].values,
-                    getattr(standardizer, attr)[var].values,
-                    equal_nan=True,
-                )
-                for var in getattr(standardizer, attr).data_vars
-            ]
-            if isinstance(getattr(standardizer, attr), xr.Dataset)
-            else np.allclose(
-                getattr(data_transformer.transformers["Standardizer"][0], attr),
-                getattr(standardizer, attr),
-            )
+            (
+                [
+                    np.allclose(
+                        getattr(data_transformer.transformers["Standardizer"][0], attr)[
+                            var
+                        ].values,
+                        getattr(standardizer, attr)[var].values,
+                        equal_nan=True,
+                    )
+                    for var in getattr(standardizer, attr).data_vars
+                ]
+                if isinstance(getattr(standardizer, attr), xr.Dataset)
+                else np.allclose(
+                    getattr(data_transformer.transformers["Standardizer"][0], attr),
+                    getattr(standardizer, attr),
+                )
+            )
             for attr in get_class_attributes(standardizer)
         )