From be735bf7447ccdbf44d5e340a37aa17f693049d2 Mon Sep 17 00:00:00 2001
From: ned
Date: Wed, 1 Nov 2023 13:46:41 +0100
Subject: [PATCH] Add option to use batch normalization

---
 mlpp_lib/models.py | 36 +++++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/mlpp_lib/models.py b/mlpp_lib/models.py
index 5dd09e9..7bb3142 100644
--- a/mlpp_lib/models.py
+++ b/mlpp_lib/models.py
@@ -32,6 +32,7 @@ def call(self, inputs):
 def _build_fcn_block(
     inputs,
     hidden_layers,
+    batchnorm,
     activations,
     dropout,
     mc_dropout,
@@ -40,7 +41,10 @@
 ):
     x = inputs
     for i, units in enumerate(hidden_layers):
-        x = Dense(units, activation=activations[i], name=f"dense_{idx}_{i}")(x)
+        x = Dense(units, name=f"dense_{idx}_{i}")(x)
+        if batchnorm:
+            x = BatchNormalization()(x)
+        x = Activation(activations[i])(x)
         if i < len(dropout) and 0.0 < dropout[i] < 1.0:
             if mc_dropout:
                 x = MonteCarloDropout(dropout[i], name=f"mc_dropout_{idx}_{i}")(x)
@@ -82,6 +86,7 @@ def fully_connected_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -101,6 +106,8 @@
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -150,6 +157,7 @@
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
@@ -165,6 +173,7 @@ def fully_connected_multibranch_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -184,6 +193,8 @@
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -242,6 +253,7 @@
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
@@ -263,6 +275,7 @@ def deep_cross_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = True,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -282,6 +295,8 @@
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is True.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -338,16 +353,15 @@ def deep_cross_network(
 
     # deep part
     deep = inputs
-    for i, u in enumerate(hidden_layers):
-        deep = Dense(u)(deep)
-        deep = BatchNormalization()(deep)
-        deep = Activation(activations[i])(deep)
-        if i < len(dropout) and 0.0 < dropout[i] < 1.0:
-            if mc_dropout:
-                deep = MonteCarloDropout(dropout[i])(deep)
-            else:
-                deep = Dropout(dropout[i])(deep)
-    # deep = tf.keras.Model(inputs=inputs, outputs=deep, name="deepblock")
+    deep = _build_fcn_block(
+        deep,
+        hidden_layers,
+        batchnorm,
+        activations,
+        dropout,
+        mc_dropout,
+        skip_connection=False,
+    )
 
     # merge
     merge = tf.keras.layers.Concatenate()([cross, deep])
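
Note for reviewers: the new `batchnorm` flag inserts a `BatchNormalization` layer between each `Dense` layer and its activation, so batch statistics are computed on the pre-activation outputs; `deep_cross_network` defaults to `True` so its deep branch keeps the behaviour it had before this refactor. Below is a minimal standalone sketch of the hidden-layer pattern this patch adds, written with plain tf.keras; the `hidden_block` helper, layer sizes, and dropout rate are illustrative only and not part of the mlpp_lib API.

    import tensorflow as tf
    from tensorflow.keras.layers import (
        Activation,
        BatchNormalization,
        Dense,
        Dropout,
        Input,
    )

    def hidden_block(x, units, activation="relu", batchnorm=True, dropout=0.1):
        # Dense -> (optional) BatchNormalization -> Activation -> (optional) Dropout,
        # mirroring the ordering introduced in _build_fcn_block.
        x = Dense(units)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation(activation)(x)
        if 0.0 < dropout < 1.0:
            x = Dropout(dropout)(x)
        return x

    inputs = Input(shape=(10,))
    x = hidden_block(inputs, 32)
    x = hidden_block(x, 16)
    outputs = Dense(2)(x)
    model = tf.keras.Model(inputs, outputs)
    model.summary()

With the patch applied, the same pattern is switched on per model by passing the new keyword argument, e.g. `fully_connected_network(..., batchnorm=True)`, while the default (`False` everywhere except `deep_cross_network`) leaves existing configurations unchanged.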