Skip to content

Commit

Permalink
gpc
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Nov 26, 2024
1 parent 57f9026 commit 039c16f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/anemoi/training/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def datamodule(self) -> AnemoiDatasetsDataModule:
"""DataModule instance and DataSets."""
datamodule = AnemoiDatasetsDataModule(self.config)
self.config.data.num_features = len(datamodule.ds_train.data.variables)
LOGGER.info("Data has ", len(datamodule.ds_train.data.variables), " variables: ", datamodule.ds_train.data.variables)
LOGGER.info(
"Data has ",
len(datamodule.ds_train.data.variables),
" variables: ",
datamodule.ds_train.data.variables,
)
return datamodule

@cached_property
Expand Down
7 changes: 4 additions & 3 deletions src/anemoi/training/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path:

for key in state_dict.copy():
if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
LOGGER.debug("Skipping loading parameter: {}, checkpoint shape: {}, model shape: {}".format(
key, state_dict[key].shape, model_state_dict[key].shape
)
LOGGER.debug(
"Skipping loading parameter: ", key,
", checkpoint shape: ", state_dict[key].shape,
", model shape: ", model_state_dict[key].shape,
)
del state_dict[key] # Remove the mismatched key

Expand Down

0 comments on commit 039c16f

Please sign in to comment.