Skip to content

Commit

Permalink
Merge pull request #53 from MeteoSwiss/fix-model-ouput
Browse files Browse the repository at this point in the history
Fix model output
  • Loading branch information
dnerini authored Aug 21, 2024
2 parents aabc7ae + ad5b13a commit 03b589d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
29 changes: 28 additions & 1 deletion mlpp_lib/probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,41 @@
)


# these work out of the box
# these almost work out of the box
from tensorflow_probability.python.layers import (
IndependentNormal,
IndependentLogistic,
IndependentBernoulli,
IndependentPoisson,
)

@tf.keras.saving.register_keras_serializable()
class IndependentNormal(IndependentNormal):
@property
def output(self): # this is necessary to use the layer within shap
return super().output[0]


@tf.keras.saving.register_keras_serializable()
class IndependentLogistic(IndependentLogistic):
@property
def output(self):
return super().output[0]


@tf.keras.saving.register_keras_serializable()
class IndependentBernoulli(IndependentBernoulli):
@property
def output(self):
return super().output[0]


@tf.keras.saving.register_keras_serializable()
class IndependentPoisson(IndependentPoisson):
@property
def output(self):
return super().output[0]


@tf.keras.saving.register_keras_serializable()
class IndependentBeta(tfpl.DistributionLambda):
Expand Down
27 changes: 23 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
dropout=[None, 0.1, [0.1, 0.0]],
mc_dropout=[True, False],
out_bias_init=["zeros", np.array([0.2]), np.array([0.2, 2.1])],
probabilistic_layer=[None, "IndependentNormal", "MultivariateNormalTriL"],
probabilistic_layer=[None] + ["IndependentNormal", "IndependentGamma"],
skip_connection=[False, True],
)

Expand All @@ -34,6 +34,25 @@
]


def _test_model(model):
moodel_is_keras = (
str(type(model)).endswith("keras.engine.sequential.Sequential'>")
or str(type(model)).endswith("keras.models.Sequential'>")
or str(type(model)).endswith("keras.engine.training.Model'>")
or isinstance(model, tf.keras.Model)
)
assert moodel_is_keras
assert isinstance(model, Functional)
assert len(model.layers[-1]._inbound_nodes) > 0
model_output = model.layers[-1].output
assert not isinstance(
model_output, list
), "The model output must be a single tensor!"
assert (
len(model_output.shape) < 3
), "The model output must be a vector or a single value!"


def _test_prediction(model, scenario_kwargs, dummy_input, output_size):
pred = model(dummy_input)
assert pred.shape == (32, output_size)
Expand Down Expand Up @@ -76,8 +95,8 @@ def test_fully_connected_network(scenario_kwargs):
model = models.fully_connected_network(
input_shape, output_size, **scenario_kwargs
)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)


Expand Down Expand Up @@ -108,8 +127,8 @@ def test_fully_connected_multibranch_network(scenario_kwargs):
model = models.fully_connected_multibranch_network(
input_shape, output_size, **scenario_kwargs
)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)


Expand All @@ -133,6 +152,6 @@ def test_deep_cross_network(scenario_kwargs):

else:
model = models.deep_cross_network(input_shape, output_size, **scenario_kwargs)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)
7 changes: 7 additions & 0 deletions tests/test_probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def test_probabilistic_model(layer):
encoder.summary()
encoder.compile()
assert isinstance(encoder, tf.keras.Sequential)
model_output = encoder.layers[-1].output
assert not isinstance(
model_output, list
), "The model output must be a single tensor!"
assert (
len(model_output.shape) < 3
), "The model output must be a vector or a single value!"


@pytest.mark.parametrize("layer", LAYERS)
Expand Down

0 comments on commit 03b589d

Please sign in to comment.