Skip to content

Commit

Permalink
Small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dnerini committed Aug 21, 2024
1 parent dd735a6 commit caf4365
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
9 changes: 2 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from mlpp_lib import models, probabilistic_layers


PROB_LAYERS = [obj[0] for obj in getmembers(probabilistic_layers, isclass)]


FCN_OPTIONS = dict(
input_shape=[(5,)],
output_size=[1, 2],
Expand All @@ -21,7 +18,7 @@
dropout=[None, 0.1, [0.1, 0.0]],
mc_dropout=[True, False],
out_bias_init=["zeros", np.array([0.2]), np.array([0.2, 2.1])],
probabilistic_layer=[None] + PROB_LAYERS,
probabilistic_layer=[None] + ["IndependentNormal", "IndependentGamma"],
skip_connection=[False, True],
)

Expand All @@ -46,6 +43,7 @@ def _test_model(model):
or isinstance(model, tf.keras.Model)
)
assert moodel_is_keras
assert isinstance(model, Functional)
assert len(model.layers[-1]._inbound_nodes) > 0
model_output = model.layers[-1].output
assert not isinstance(
Expand Down Expand Up @@ -98,7 +96,6 @@ def test_fully_connected_network(scenario_kwargs):
model = models.fully_connected_network(
input_shape, output_size, **scenario_kwargs
)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)
Expand Down Expand Up @@ -131,7 +128,6 @@ def test_fully_connected_multibranch_network(scenario_kwargs):
model = models.fully_connected_multibranch_network(
input_shape, output_size, **scenario_kwargs
)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)
Expand All @@ -157,7 +153,6 @@ def test_deep_cross_network(scenario_kwargs):

else:
model = models.deep_cross_network(input_shape, output_size, **scenario_kwargs)
assert isinstance(model, Functional)

_test_model(model)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)
7 changes: 7 additions & 0 deletions tests/test_probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def test_probabilistic_model(layer):
encoder.summary()
encoder.compile()
assert isinstance(encoder, tf.keras.Sequential)
model_output = encoder.layers[-1].output
assert not isinstance(
model_output, list
), "The model output must be a single tensor!"
assert (
len(model_output.shape) < 3
), "The model output must be a vector or a single value!"


@pytest.mark.parametrize("layer", LAYERS)
Expand Down

0 comments on commit caf4365

Please sign in to comment.