Add option to use batch normalization
dnerini committed Nov 1, 2023
1 parent 3399e66 commit be735bf
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions mlpp_lib/models.py
@@ -32,6 +32,7 @@ def call(self, inputs):
 def _build_fcn_block(
     inputs,
     hidden_layers,
+    batchnorm,
     activations,
     dropout,
     mc_dropout,
@@ -40,7 +41,10 @@
 ):
     x = inputs
     for i, units in enumerate(hidden_layers):
-        x = Dense(units, activation=activations[i], name=f"dense_{idx}_{i}")(x)
+        x = Dense(units, name=f"dense_{idx}_{i}")(x)
+        if batchnorm:
+            x = BatchNormalization()(x)
+        x = Activation(activations[i])(x)
         if i < len(dropout) and 0.0 < dropout[i] < 1.0:
             if mc_dropout:
                 x = MonteCarloDropout(dropout[i], name=f"mc_dropout_{idx}_{i}")(x)
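
Note on the new ordering: the activation is no longer fused into the Dense layer, so that when batchnorm is enabled the normalization sits between the linear transform and the nonlinearity. A minimal standalone sketch of this pattern (layer sizes and the relu activation here are illustrative, not taken from the diff):

import tensorflow as tf
from tensorflow.keras.layers import Activation, BatchNormalization, Dense, Input

inputs = Input(shape=(8,))
x = Dense(32)(inputs)        # linear transform only, no fused activation
x = BatchNormalization()(x)  # normalize pre-activations
x = Activation("relu")(x)    # nonlinearity applied after normalization
model = tf.keras.Model(inputs=inputs, outputs=x)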
@@ -82,6 +86,7 @@ def fully_connected_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -101,6 +106,8 @@ def fully_connected_network(
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -150,6 +157,7 @@ def fully_connected_network(
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
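
With the new flag, batch normalization can be toggled per model. A hedged usage sketch, assuming only the arguments visible in this diff (the full signature may accept further keyword arguments):

from mlpp_lib.models import fully_connected_network

# Illustrative shapes and sizes; batchnorm=True inserts a BatchNormalization
# layer after each Dense layer, before its activation.
model = fully_connected_network(
    input_shape=(10,),
    output_size=2,
    hidden_layers=[64, 32],
    batchnorm=True,
    activations="relu",
    dropout=0.1,
)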
@@ -165,6 +173,7 @@ def fully_connected_multibranch_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -184,6 +193,8 @@ def fully_connected_multibranch_network(
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -242,6 +253,7 @@ def fully_connected_multibranch_network(
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
@@ -263,6 +275,7 @@ def deep_cross_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = True,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -282,6 +295,8 @@ def deep_cross_network(
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is True.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -338,16 +353,15 @@ def deep_cross_network(
 
     # deep part
     deep = inputs
-    for i, u in enumerate(hidden_layers):
-        deep = Dense(u)(deep)
-        deep = BatchNormalization()(deep)
-        deep = Activation(activations[i])(deep)
-        if i < len(dropout) and 0.0 < dropout[i] < 1.0:
-            if mc_dropout:
-                deep = MonteCarloDropout(dropout[i])(deep)
-            else:
-                deep = Dropout(dropout[i])(deep)
-    # deep = tf.keras.Model(inputs=inputs, outputs=deep, name="deepblock")
+    deep = _build_fcn_block(
+        deep,
+        hidden_layers,
+        batchnorm,
+        activations,
+        dropout,
+        mc_dropout,
+        skip_connection=False,
+    )
 
     # merge
     merge = tf.keras.layers.Concatenate()([cross, deep])
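
Note that the defaults differ by architecture: deep_cross_network keeps batchnorm=True, so replacing its hand-written loop with _build_fcn_block preserves the previous always-normalized behavior, while the two fully connected builders default to False. A hedged sketch (arguments limited to those visible in this diff):

from mlpp_lib.models import deep_cross_network

# Batch normalization stays on by default here, matching the old behavior
# in which BatchNormalization() was applied unconditionally.
dcn = deep_cross_network(
    input_shape=(10,),
    output_size=2,
    hidden_layers=[64, 32],
)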
