Skip to content

Commit

Permalink
Move to another script to avoid circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
louisPoulain authored Nov 8, 2024
1 parent 85d0536 commit 6f0568f
Showing 1 changed file with 0 additions and 27 deletions.
27 changes: 0 additions & 27 deletions mlpp_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,6 @@ def get_callback(callback: Union[str, dict]) -> Callable:
return callback


def get_probabilistic_layer(
output_size,
probabilistic_layer: Union[str, dict]
) -> Callable:
"""Get the probabilistic layer."""

if isinstance(probabilistic_layer, dict):
probabilistic_layer_name = list(probabilistic_layer.keys())[0]
probabilistic_layer_options = probabilistic_layer[probabilistic_layer_name]
else:
probabilistic_layer_name = metric
probabilistic_layer_options = {}

if hasattr(probabilistic_layers, probabilistic_layer_name):
LOGGER.info(f"Using custom probabilistic layer: {probabiistic_layer_name}")
probabilistic_layer_obj = getattr(probabilistic_layers, probabilistic_layer_name)
probabilistic_layer = (
probabilistic_layer_obj(output_size, name="output", **probabilistic_layer_options) if isinstance(probabilistic_layer_obj, type)
else probabilistic_layer_obj(output_size, name="output")
)
else:
raise KeyError(f"The probabilistic layer {probabilistic_layer_name} is not available.")

return probabilistic_layer



def get_model(
input_shape: tuple[int],
output_shape: Union[int, tuple[int]],
Expand Down

0 comments on commit 6f0568f

Please sign in to comment.