Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use Monte Carlo dropout #37

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions mlpp_lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,28 @@
TCN_IMPORTED = True


class MonteCarloDropout(Dropout):
def call(self, inputs):
return super().call(inputs, training=True)


def _build_fcn_block(
inputs, hidden_layers, activations, dropout, skip_connection, idx=0
inputs,
hidden_layers,
activations,
dropout,
mc_dropout,
skip_connection,
idx=0,
):
x = inputs
for i, units in enumerate(hidden_layers):
x = Dense(units, activation=activations[i], name=f"dense_{idx}_{i}")(x)
if i < len(dropout) and 0.0 < dropout[i] < 1.0:
x = Dropout(dropout[i], name=f"dropout_{idx}_{i}")(x)
if mc_dropout:
x = MonteCarloDropout(dropout[i], name=f"mc_dropout_{idx}_{i}")(x)
else:
x = Dropout(dropout[i], name=f"dropout_{idx}_{i}")(x)

if skip_connection:
x = Dense(inputs.shape[1], name=f"skip_dense_{idx}")(x)
Expand Down Expand Up @@ -70,6 +84,7 @@ def fully_connected_network(
hidden_layers: list,
activations: Optional[Union[str, list[str]]] = "relu",
dropout: Optional[Union[float, list[float]]] = None,
mc_dropout: bool = False,
out_bias_init: Optional[Union[str, np.ndarray[Any, float]]] = "zeros",
probabilistic_layer: Optional[str] = None,
skip_connection: bool = False,
Expand All @@ -93,6 +108,9 @@ def fully_connected_network(
(Optional) Dropout rate for the optional dropout layers. If a `float` is passed,
dropout layers with the given rate are created after each Dense layer, except before the output layer.
Default is None.
mc_dropout: bool
Enable Monte Carlo dropout during inference. It has no effect during training.
It has no effect if `dropout=None`. Default is false.
out_bias_init: str or np.ndarray
(Optional) Specifies the initialization of the output layer bias. If a string is passed,
it must be a valid Keras built-in initializer (see https://keras.io/api/layers/initializers/).
Expand Down Expand Up @@ -129,7 +147,14 @@ def fully_connected_network(
)

inputs = tf.keras.Input(shape=input_shape)
x = _build_fcn_block(inputs, hidden_layers, activations, dropout, skip_connection)
x = _build_fcn_block(
inputs,
hidden_layers,
activations,
dropout,
mc_dropout,
skip_connection,
)
outputs = _build_fcn_output(x, output_size, probabilistic_layer, out_bias_init)
model = Model(inputs=inputs, outputs=outputs)

Expand All @@ -142,6 +167,7 @@ def fully_connected_multibranch_network(
hidden_layers: list,
activations: Optional[Union[str, list[str]]] = "relu",
dropout: Optional[Union[float, list[float]]] = None,
mc_dropout: bool = False,
out_bias_init: Optional[Union[str, np.ndarray[Any, float]]] = "zeros",
probabilistic_layer: Optional[str] = None,
skip_connection: bool = False,
Expand All @@ -165,6 +191,9 @@ def fully_connected_multibranch_network(
(Optional) Dropout rate for the optional dropout layers. If a `float` is passed,
dropout layers with the given rate are created after each Dense layer, except before the output layer.
Default is None.
mc_dropout: bool
Enable Monte Carlo dropout during inference. It has no effect during training.
It has no effect if `dropout=None`. Default is false.
out_bias_init: str or np.ndarray
(Optional) Specifies the initialization of the output layer bias. If a string is passed,
it must be a valid Keras built-in initializer (see https://keras.io/api/layers/initializers/).
Expand Down Expand Up @@ -211,7 +240,13 @@ def fully_connected_multibranch_network(

for idx in range(n_branches):
x = _build_fcn_block(
inputs, hidden_layers, activations, dropout, skip_connection, idx
inputs,
hidden_layers,
activations,
dropout,
mc_dropout,
skip_connection,
idx,
)
all_branch_outputs.append(x)

Expand All @@ -230,6 +265,7 @@ def deep_cross_network(
hidden_layers: list,
activations: Optional[Union[str, list[str]]] = "relu",
dropout: Optional[Union[float, list[float]]] = None,
mc_dropout: bool = False,
out_bias_init: Optional[Union[str, np.ndarray[Any, float]]] = "zeros",
probabilistic_layer: Optional[str] = None,
skip_connection: bool = False,
Expand All @@ -253,6 +289,9 @@ def deep_cross_network(
(Optional) Dropout rate for the optional dropout layers. If a `float` is passed,
dropout layers with the given rate are created after each Dense layer, except before the output layer.
Default is None.
mc_dropout: bool
Enable Monte Carlo dropout during inference. It has no effect during training.
It has no effect if `dropout=None`. Default is false.
out_bias_init: str or np.ndarray
(Optional) Specifies the initialization of the output layer bias. If a string is passed,
it must be a valid Keras built-in initializer (see https://keras.io/api/layers/initializers/).
Expand Down Expand Up @@ -304,7 +343,10 @@ def deep_cross_network(
deep = BatchNormalization()(deep)
deep = Activation(activations[i])(deep)
if i < len(dropout) and 0.0 < dropout[i] < 1.0:
deep = Dropout(dropout[i])(deep)
if mc_dropout:
deep = MonteCarloDropout(dropout[i])(deep)
else:
deep = Dropout(dropout[i])(deep)
# deep = tf.keras.Model(inputs=inputs, outputs=deep, name="deepblock")

# merge
Expand Down
23 changes: 20 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import tensorflow as tf
from keras.engine.functional import Functional
from numpy.testing import assert_array_equal

from mlpp_lib import models

Expand All @@ -14,6 +15,7 @@
hidden_layers=[[8, 8]],
activations=["relu", ["relu", "elu"]],
dropout=[None, 0.1, [0.1, 0.0]],
mc_dropout=[True, False],
out_bias_init=["zeros", np.array([0.2]), np.array([0.2, 2.1])],
probabilistic_layer=[None, "IndependentNormal", "MultivariateNormalTriL"],
skip_connection=[False, True],
Expand All @@ -32,6 +34,21 @@
]


def _test_prediction(model, scenario_kwargs, dummy_input, output_size):
pred = model(dummy_input)
assert pred.shape == (32, output_size)
pred2 = model(dummy_input)
if scenario_kwargs["probabilistic_layer"] is not None:
pred = pred.mean()
pred2 = pred2.mean()

if scenario_kwargs["dropout"] is not None and scenario_kwargs["mc_dropout"]:
with pytest.raises(AssertionError):
assert_array_equal(pred, pred2)
else:
assert_array_equal(pred, pred2)


@pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS)
def test_fully_connected_network(scenario_kwargs):

Expand Down Expand Up @@ -61,7 +78,7 @@ def test_fully_connected_network(scenario_kwargs):
)
assert isinstance(model, Functional)

assert model(dummy_input).shape == (32, output_size)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)


@pytest.mark.parametrize("scenario_kwargs", FCN_SCENARIOS)
Expand Down Expand Up @@ -93,7 +110,7 @@ def test_fully_connected_multibranch_network(scenario_kwargs):
)
assert isinstance(model, Functional)

assert model(dummy_input).shape == (32, output_size)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)


@pytest.mark.parametrize("scenario_kwargs", DCN_SCENARIOS)
Expand All @@ -118,4 +135,4 @@ def test_deep_cross_network(scenario_kwargs):
model = models.deep_cross_network(input_shape, output_size, **scenario_kwargs)
assert isinstance(model, Functional)

assert model(dummy_input).shape == (32, output_size)
_test_prediction(model, scenario_kwargs, dummy_input, output_size)