Add option to use batch normalization #38

Merged
merged 1 commit into from Nov 1, 2023

36 changes: 25 additions & 11 deletions mlpp_lib/models.py
@@ -32,6 +32,7 @@ def call(self, inputs):
 def _build_fcn_block(
     inputs,
     hidden_layers,
+    batchnorm,
     activations,
     dropout,
     mc_dropout,
@@ -40,7 +41,10 @@ def _build_fcn_block(
 ):
     x = inputs
     for i, units in enumerate(hidden_layers):
-        x = Dense(units, activation=activations[i], name=f"dense_{idx}_{i}")(x)
+        x = Dense(units, name=f"dense_{idx}_{i}")(x)
+        if batchnorm:
+            x = BatchNormalization()(x)
+        x = Activation(activations[i])(x)
         if i < len(dropout) and 0.0 < dropout[i] < 1.0:
             if mc_dropout:
                 x = MonteCarloDropout(dropout[i], name=f"mc_dropout_{idx}_{i}")(x)
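
The core change is here: `Dense` no longer applies its activation inline. When `batchnorm` is set, a `BatchNormalization` layer sits between the linear projection and the `Activation`, so normalization acts on the pre-activations. A minimal standalone sketch of the resulting layer ordering (sizes and the dropout rate are illustrative, not from this PR):

import tensorflow as tf
from tensorflow.keras.layers import Activation, BatchNormalization, Dense, Dropout

inputs = tf.keras.Input(shape=(16,))
x = Dense(32)(inputs)         # linear projection only, no activation yet
x = BatchNormalization()(x)   # normalize the pre-activations
x = Activation("relu")(x)     # activation applied after normalization
x = Dropout(0.1)(x)           # dropout stays after the activation
model = tf.keras.Model(inputs, x)
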
@@ -82,6 +86,7 @@ def fully_connected_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -101,6 +106,8 @@
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
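
Hypothetical usage of the new flag, assuming the arguments visible in this diff are the only required ones and that the parameters not shown keep their defaults:

from mlpp_lib.models import fully_connected_network

model = fully_connected_network(
    input_shape=(16,),
    output_size=2,
    hidden_layers=[64, 32],
    batchnorm=True,   # new in this PR; defaults to False
    activations="relu",
    dropout=0.1,
)
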
@@ -150,6 +157,7 @@ def fully_connected_network(
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
@@ -165,6 +173,7 @@ def fully_connected_multibranch_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = False,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -184,6 +193,8 @@ def fully_connected_multibranch_network(
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is False.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -242,6 +253,7 @@ def fully_connected_multibranch_network(
     x = _build_fcn_block(
         inputs,
         hidden_layers,
+        batchnorm,
         activations,
         dropout,
         mc_dropout,
@@ -263,6 +275,7 @@ def deep_cross_network(
     input_shape: tuple[int],
     output_size: int,
     hidden_layers: list,
+    batchnorm: bool = True,
     activations: Optional[Union[str, list[str]]] = "relu",
     dropout: Optional[Union[float, list[float]]] = None,
     mc_dropout: bool = False,
@@ -282,6 +295,8 @@ def deep_cross_network(
     hidden_layers: list[int]
         List that is used to define the fully connected block. Each element creates
         a Dense layer with the corresponding units.
+    batchnorm: bool
+        Use batch normalization. Default is True.
     activations: str or list[str]
         (Optional) Activation function(s) for the Dense layer(s). See https://keras.io/api/layers/activations/#relu-function.
         If a string is passed, the same activation is used for all layers. Default is `relu`.
@@ -338,16 +353,15 @@ def deep_cross_network(

     # deep part
     deep = inputs
-    for i, u in enumerate(hidden_layers):
-        deep = Dense(u)(deep)
-        deep = BatchNormalization()(deep)
-        deep = Activation(activations[i])(deep)
-        if i < len(dropout) and 0.0 < dropout[i] < 1.0:
-            if mc_dropout:
-                deep = MonteCarloDropout(dropout[i])(deep)
-            else:
-                deep = Dropout(dropout[i])(deep)
-    # deep = tf.keras.Model(inputs=inputs, outputs=deep, name="deepblock")
+    deep = _build_fcn_block(
+        deep,
+        hidden_layers,
+        batchnorm,
+        activations,
+        dropout,
+        mc_dropout,
+        skip_connection=False,
+    )

     # merge
     merge = tf.keras.layers.Concatenate()([cross, deep])
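
The loop removed here applied `BatchNormalization` unconditionally, so delegating to `_build_fcn_block` with `batchnorm` defaulting to `True` for `deep_cross_network` preserves the previous behavior, while the other two constructors default to `False`. A rough sanity check, assuming the function returns a `tf.keras.Model` (as the surrounding code suggests) and using illustrative argument values:

import tensorflow as tf
from mlpp_lib.models import deep_cross_network

model = deep_cross_network(
    input_shape=(16,),
    output_size=2,
    hidden_layers=[64, 32],   # batchnorm defaults to True here
)
n_bn = sum(
    isinstance(layer, tf.keras.layers.BatchNormalization)
    for layer in model.layers
)
print(n_bn)  # expect one BatchNormalization per hidden layer in the deep part
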