Skip to content

Commit

Permalink
Added sanification of checkpoint, effective batch size, git pre commit
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Nov 26, 2024
1 parent db2a14f commit 57f9026
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/anemoi/training/config/training/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
run_id: null
fork_run_id: null
load_weights_only: null # only load model weights, do not restore optimiser states etc.
transfer_learning: null # activate to perform transfer learning

# run in deterministic mode ; slows down
deterministic: False
Expand Down
21 changes: 11 additions & 10 deletions src/anemoi/training/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(self, config: DictConfig) -> None:
self.config.dataloader.validation.start - 1,
)
self.config.dataloader.training.end = self.config.dataloader.validation.start - 1


if not self.config.dataloader.get("pin_memory", True):
LOGGER.info("Data loader memory pinning disabled.")
Expand Down Expand Up @@ -177,15 +176,17 @@ def _get_dataset(
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:

r = max(rollout, self.rollout)

# Compute effective batch size
effective_bs = self.config.dataloader.batch_size['training'] *\
self.config.hardware.num_gpus_per_node *\
self.config.hardware.num_nodes //\
self.config.hardware.num_gpus_per_model

# Compute effective batch size
effective_bs = (
self.config.dataloader.batch_size["training"]
* self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
)

data = NativeGridDataset(
data_reader=data_reader,
rollout=r,
Expand All @@ -196,9 +197,9 @@ def _get_dataset(
model_comm_num_groups=self.model_comm_num_groups,
shuffle=shuffle,
label=label,
effective_bs=effective_bs
effective_bs=effective_bs,
)
# self._check_resolution(data.resolution)
self._check_resolution(data.resolution)
return data

def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
Expand Down
8 changes: 4 additions & 4 deletions src/anemoi/training/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import random
from functools import cached_property
from typing import Callable
from omegaconf import DictConfig

import numpy as np
import torch
Expand Down Expand Up @@ -42,7 +41,7 @@ def __init__(
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1
effective_bs: int = 1,
) -> None:
"""Initialize (part of) the dataset state.
Expand All @@ -66,7 +65,8 @@ def __init__(
Shuffle batches, by default True
label : str, optional
label for the dataset, by default "generic"
effective_bs : int, default 1
effective batch size, used to compute the length of the dataset
"""
self.label = label
self.effective_bs = effective_bs
Expand Down Expand Up @@ -250,7 +250,7 @@ def __repr__(self) -> str:
Multistep: {self.multi_step}
Timeincrement: {self.timeincrement}
"""

def __len__(self) -> int:
"""Estimate the total number of samples based on valid indices."""
return len(self.valid_date_indices) // self.effective_bs
Expand Down
21 changes: 12 additions & 9 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,13 @@ def __init__(
self.loss.register_full_backward_hook(grad_scaler, prepend=False)

self.multi_step = config.training.multistep_input
self.lr = config.hardware.num_nodes * config.hardware.num_gpus_per_node * config.training.lr.rate / config.hardware.num_gpus_per_model

self.lr = (
config.hardware.num_nodes
* config.hardware.num_gpus_per_node
* config.training.lr.rate
/ config.hardware.num_gpus_per_model
)

self.lr_iterations = config.training.lr.iterations
self.lr_min = config.training.lr.min
self.rollout = config.training.rollout.start
Expand Down Expand Up @@ -376,8 +381,7 @@ def rollout_step(

y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full]
# y includes the auxiliary variables, so we must leave those out when computing the loss
tmp_loss = checkpoint(self.loss, y_pred, y, use_reentrant=False)
loss += tmp_loss
loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) if training_mode else None

x = self.advance_input(x, y_pred, batch, rollout_step)

Expand Down Expand Up @@ -437,11 +441,10 @@ def calculate_val_metrics(
validation metrics and predictions
"""
metrics = {}
y_preds = []


# Added to impute nans
nan_locations = torch.isnan(y[..., self.data_indices.internal_data.output.full])
self.model.post_processors.processors['imputer'].set_nan_locations(nan_locations)
self.model.post_processors.processors["imputer"].set_nan_locations(nan_locations)

y_postprocessed = self.model.post_processors(y, in_place=False)
y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False)
Expand Down Expand Up @@ -514,7 +517,7 @@ def on_train_epoch_end(self) -> None:
self.rollout = min(self.rollout, self.rollout_max)

def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:

with torch.no_grad():
val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True)

Expand All @@ -529,7 +532,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
sync_dist=True,
)

for i, (mname, mvalue) in enumerate(metrics.items()):
for mname, mvalue in metrics.items():
self.log(
"val_" + mname,
mvalue,
Expand Down
41 changes: 28 additions & 13 deletions src/anemoi/training/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from anemoi.training.diagnostics.logger import get_wandb_logger
from anemoi.training.distributed.strategy import DDPGroupStrategy
from anemoi.training.train.forecaster import GraphForecaster
from anemoi.training.utils.checkpoint import sanify_checkpoint
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.seeding import get_base_seed

Expand All @@ -42,6 +43,7 @@

LOGGER = logging.getLogger(__name__)


class AnemoiTrainer:
"""Utility class for training the model."""

Expand All @@ -61,8 +63,13 @@ def __init__(self, config: DictConfig) -> None:
OmegaConf.resolve(config)
self.config = config

# Default to not warm-starting from a checkpoint
self.start_from_checkpoint = (bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)) and self.config.training.resume
# If transfer_learning is not set explicitly, infer it: enabled when a run_id or fork_run_id is given and only the weights are loaded
if self.config.training.transfer_learning is None:
self.config.training.transfer_learning = (
bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)
) and self.load_weights_only

self.start_from_checkpoint = bool(self.config.training.run_id) or bool(self.config.training.fork_run_id)
self.load_weights_only = config.training.load_weights_only
self.parent_uuid = None

Expand All @@ -82,9 +89,7 @@ def datamodule(self) -> AnemoiDatasetsDataModule:
"""DataModule instance and DataSets."""
datamodule = AnemoiDatasetsDataModule(self.config)
self.config.data.num_features = len(datamodule.ds_train.data.variables)
LOGGER.info(
f"Data has {len(datamodule.ds_train.data.variables)} variables: {datamodule.ds_train.data.variables}"
)
LOGGER.info("Data has ", len(datamodule.ds_train.data.variables), " variables: ", datamodule.ds_train.data.variables)
return datamodule

@cached_property
Expand Down Expand Up @@ -146,11 +151,22 @@ def model(self) -> GraphForecaster:
"metadata": self.metadata,
"statistics": self.datamodule.statistics,
}

model = GraphForecaster(**kwargs)

if self.load_weights_only:
# Sanify the checkpoint for transfer learning
if self.config.training.transfer_learning:
save_path = Path(
self.config.hardware.paths.checkpoints.parent,
(self.fork_run_server2server or self.config.training.fork_run_id) or self.lineage_run,
)
self.last_checkpoint = sanify_checkpoint(model, self.last_checkpoint, save_path)

LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
return model.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)

return GraphForecaster(**kwargs)
return model

@rank_zero_only
def _get_mlflow_run_id(self) -> str:
Expand Down Expand Up @@ -200,9 +216,9 @@ def last_checkpoint(self) -> str | None:
checkpoint = Path(
self.config.hardware.paths.checkpoints.parent,
fork_id or self.lineage_run,
self.config.hardware.files.warm_start or "transfer.ckpt" or "last.ckpt",
self.config.hardware.files.warm_start or "last.ckpt",
)

# Check if the last checkpoint exists
if Path(checkpoint).exists():
LOGGER.info("Resuming training from last checkpoint: %s", checkpoint)
Expand Down Expand Up @@ -297,7 +313,7 @@ def _log_information(self) -> None:
total_number_of_model_instances = int(
self.config.hardware.num_nodes
* self.config.hardware.num_gpus_per_node
/ self.config.hardware.num_gpus_per_model
/ self.config.hardware.num_gpus_per_model,
)

LOGGER.debug(
Expand Down Expand Up @@ -355,8 +371,7 @@ def strategy(self) -> DDPGroupStrategy:

def train(self) -> None:
"""Training entry point."""

print('Setting up trainer..')
LOGGER.debug("Setting up trainer..")

trainer = pl.Trainer(
accelerator=self.accelerator,
Expand Down Expand Up @@ -384,7 +399,7 @@ def train(self) -> None:
enable_progress_bar=self.config.diagnostics.enable_progress_bar,
)

print('Starting training..')
LOGGER.debug("Starting training..")

trainer.fit(
self.model,
Expand Down
28 changes: 28 additions & 0 deletions src/anemoi/training/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

from __future__ import annotations

import logging
from pathlib import Path

import torch
from anemoi.utils.checkpoints import save_metadata

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)


def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Module, dict]:
"""Load the lightning checkpoint and extract the pytorch model and its metadata.
Expand Down Expand Up @@ -65,3 +68,28 @@ def save_inference_checkpoint(model: torch.nn.Module, metadata: dict, save_path:
torch.save(model, inference_filepath)
save_metadata(inference_filepath, metadata)
return inference_filepath


def sanify_checkpoint(model: torch.nn.Module, ckpt_path: Path | str, save_path: Path | str) -> Path:
    """Write a copy of a checkpoint without shape-mismatched parameters.

    Every entry of the checkpoint's ``"state_dict"`` whose key also exists in
    ``model.state_dict()`` but whose shape differs is removed. Entries whose
    key is unknown to the model, and all other top-level checkpoint content,
    are kept untouched. The result is saved as ``transfer.ckpt`` inside
    ``save_path``.

    Parameters
    ----------
    model : torch.nn.Module
        Model providing the reference parameter shapes. It must expose a
        ``device`` attribute, which is used as ``map_location`` when loading.
    ckpt_path : Path | str
        Checkpoint file to read; it must contain a ``"state_dict"`` entry.
    save_path : Path | str
        Directory in which ``transfer.ckpt`` is written.

    Returns
    -------
    Path
        Location of the filtered checkpoint (``save_path / "transfer.ckpt"``).
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=model.device)

    state_dict = checkpoint["state_dict"]
    model_state_dict = model.state_dict()

    # Iterate over a snapshot of the keys because entries are deleted in the loop.
    for key in list(state_dict):
        if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
            LOGGER.debug(
                "Skipping loading parameter: %s, checkpoint shape: %s, model shape: %s",
                key,
                state_dict[key].shape,
                model_state_dict[key].shape,
            )
            del state_dict[key]  # Remove the mismatched key

    new_ckpt_path = Path(save_path, "transfer.ckpt")
    torch.save(checkpoint, new_ckpt_path)
    # Log only after the file has actually been written, with a proper placeholder for the path.
    LOGGER.info("Saved modified checkpoint at %s", new_ckpt_path)

    return new_ckpt_path

0 comments on commit 57f9026

Please sign in to comment.