Merge pull request #50 from MeteoSwiss/feature/scheduler
Feature/scheduler
dnerini authored Jul 9, 2024
2 parents cb56843 + bbb9dda commit b31f878
Showing 3 changed files with 74 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mlpp_lib/train.py
@@ -37,6 +37,13 @@ def get_log_params(param_run: dict) -> dict:
    return log_params


def get_lr(optimizer: tf.keras.optimizers.Optimizer) -> float:
    """Get the learning rate of the optimizer"""
    def lr(y_true, y_pred):
        return optimizer.lr
    return lr


def train(
    cfg: dict,
    datamodule: DataModule,
@@ -61,6 +68,7 @@ def train(
    loss = get_loss(loss_config)
    metrics = [get_metric(metric) for metric in cfg.get("metrics", [])]
    optimizer = get_optimizer(cfg.get("optimizer", "Adam"))
    metrics.append(get_lr(optimizer))
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    model.summary(print_fn=LOGGER.info)

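Side note on the get_lr trick: Keras accepts any callable with the signature (y_true, y_pred) as a metric, so appending the closure returned by get_lr makes the optimizer's current learning rate appear in the training logs and history under the key "lr". A minimal sketch of the same idea, assuming a TF 2.x setup where the optimizer still exposes the .lr alias; the toy model and random data are placeholders, not part of this commit:

import numpy as np
import tensorflow as tf

from mlpp_lib.train import get_lr

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# the closure is named "lr", so Keras records it under that name
model.compile(optimizer=optimizer, loss="mse", metrics=[get_lr(optimizer)])

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, epochs=2, verbose=0)
print(history.history["lr"])  # learning rate recorded once per epoch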
38 changes: 38 additions & 0 deletions mlpp_lib/utils.py
@@ -115,12 +115,50 @@ def get_metric(metric: Union[str, dict]) -> Callable:
    return metric


def get_scheduler(
    scheduler_config: Union[dict, None]
) -> Optional[tf.keras.optimizers.schedules.LearningRateSchedule]:
    """Create a learning rate scheduler from a config dictionary."""

    if not isinstance(scheduler_config, dict):
        LOGGER.info("Not using a scheduler.")
        return None

    if len(scheduler_config) != 1:
        raise ValueError(
            "Scheduler configuration should contain exactly one scheduler name with its options."
        )

    scheduler_name = next(
        iter(scheduler_config)
    )  # first key is the name of the scheduler
    scheduler_options = scheduler_config[scheduler_name]

    if not isinstance(scheduler_options, dict):
        raise ValueError(
            f"Scheduler options for '{scheduler_name}' should be a dictionary."
        )

    if hasattr(tf.keras.optimizers.schedules, scheduler_name):
        LOGGER.info(f"Using keras built-in learning rate scheduler: {scheduler_name}")
        scheduler_cls = getattr(tf.keras.optimizers.schedules, scheduler_name)
        scheduler = scheduler_cls(**scheduler_options)
    else:
        raise KeyError(
            f"The scheduler '{scheduler_name}' is not available in tf.keras.optimizers.schedules."
        )

    return scheduler


def get_optimizer(optimizer: Union[str, dict]) -> Callable:
    """Get the optimizer, keras built-in only."""

    if isinstance(optimizer, dict):
        optimizer_name = list(optimizer.keys())[0]
        optimizer_options = optimizer[optimizer_name]
        if scheduler := get_scheduler(optimizer_options.pop("learning_rate", None)):
            optimizer_options["learning_rate"] = scheduler
    else:
        optimizer_name = optimizer
        optimizer_options = {}
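For context, a short sketch of how the two helpers compose: get_scheduler maps a one-entry config dict onto the matching keras built-in schedule, and get_optimizer (its remaining lines are collapsed in this hunk) swaps the nested learning_rate dict for that schedule instance before building the optimizer. The ExponentialDecay values below are illustrative only:

import tensorflow as tf

from mlpp_lib.utils import get_scheduler

schedule_config = {
    "ExponentialDecay": {
        "initial_learning_rate": 0.01,
        "decay_steps": 100,
        "decay_rate": 0.9,
    }
}
schedule = get_scheduler(schedule_config)
assert isinstance(schedule, tf.keras.optimizers.schedules.ExponentialDecay)

# anything that is not a dict (e.g. a plain float) is not treated as a
# scheduler config: get_scheduler logs "Not using a scheduler." and returns None
assert get_scheduler(0.01) is None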
28 changes: 28 additions & 0 deletions tests/test_train.py
@@ -44,6 +44,34 @@
        "optimizer": {"Adam": {"learning_rate": 0.1, "beta_1": 0.95}},
        "metrics": ["bias", "mean_absolute_error", {"MAEBusts": {"threshold": 0.5}}],
    },
    # use a learning rate scheduler
    {
        "features": ["coe:x1"],
        "targets": ["obs:y1"],
        "model": {
            "fully_connected_network": {
                "hidden_layers": [10],
                "probabilistic_layer": "IndependentNormal",
            }
        },
        "loss": "crps_energy",
        "optimizer": {
            "Adam": {
                "learning_rate": {
                    "CosineDecayRestarts": {
                        "initial_learning_rate": 0.001,
                        "first_decay_steps": 20,
                        "t_mul": 1.5,
                        "m_mul": 1.1,
                        "alpha": 0,
                    }
                }
            }
        },
        "callbacks": [
            {"EarlyStopping": {"patience": 10, "restore_best_weights": True}}
        ],
    },
    #
    {
        "features": ["coe:x1"],
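The scheduler entry in the new test config resolves, via get_scheduler, to the keras built-in constructed below; shown here directly (with the same values as in the test) only to make the mapping concrete:

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=0.001,
    first_decay_steps=20,
    t_mul=1.5,
    m_mul=1.1,
    alpha=0,
)
# a LearningRateSchedule maps the global training step to a learning rate
print(float(schedule(0)))   # 0.001 at the first step
print(float(schedule(10)))  # halfway through the first cosine cycle, below 0.001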
