Make save_hyperparameters consistent for CLI and hardcoded training for custom python objects #20432

Open
cgebbe opened this issue Nov 19, 2024 · 0 comments
Labels: feature (Is an improvement or enhancement) · needs triage (Waiting to be triaged by maintainers)

cgebbe commented Nov 19, 2024

Description & Motivation

Problem

Given the script below, when running it with the Lightning CLI, hparams.yaml becomes

optimizer:
  class_path: __main__.SGD
  init_args:
    params:
    - 1
    - 2
    - 3
    lr: 123.0
myds:
  x: hello
_instantiator: lightning.pytorch.cli.instantiate_module

When running the hardcoded training script instead, hparams.yaml becomes

myds: !!python/object:__main__.MyDataclass
  x: hello
optimizer: !!python/object:__main__.SGD {}

In other words, even though the hyperparameters are the same, the resulting hparams.yaml files look different. An alternative way to phrase the question: what is the best practice for defining more complex hyperparameters?
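
One workaround I could imagine (just a sketch, not something I found in the Lightning docs) is to hand `save_hyperparameters()` an already-plain dict in the hardcoded case, so PyYAML never falls back to `!!python/object` tags. The `class_path` key below simply mimics what the CLI writes, the class name `LitAutoEncoderPlainHparams` is made up for illustration, and the `init_args` are left out because the dummy `SGD` below does not store `lr`/`params` on the instance:

from dataclasses import asdict

import lightning as L


class LitAutoEncoderPlainHparams(L.LightningModule):
    """Variant of LitAutoEncoder from the script below that saves plain hyperparameters."""

    def __init__(self, optimizer, myds):
        super().__init__()
        # Build a serializable dict by hand instead of letting save_hyperparameters()
        # capture the raw Python objects (which PyYAML dumps as !!python/object tags).
        # save_hyperparameters() accepts a dict, so these end up in hparams.yaml as plain keys.
        # The class_path layout is my own convention here, not an official Lightning API.
        self.save_hyperparameters(
            {
                "optimizer": {
                    "class_path": f"{type(optimizer).__module__}.{type(optimizer).__qualname__}",
                },
                "myds": asdict(myds),
            }
        )

But this duplicates by hand what the CLI already does via `_instantiator`, which is why I'd like to know the intended way.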

Script below

"""
# How to trigger hardcoded training

Comment out `main()` and uncomment `train_hardcode()` at the very end, then run `python script.py`.

# How to trigger CLI training

python script.py --config cfg.yaml

model:
  optimizer:
    class_path: SGD
    init_args:
      lr: 123
      params: [1, 2, 3]
"""

import os

import torch
from torch import optim, nn, utils
from torch.utils.data import random_split, DataLoader

import lightning as L
from lightning.pytorch import cli

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.transforms import v2

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# custom objects
from dataclasses import dataclass
import abc
from typing import Iterable, Callable

from lightning.pytorch.core.mixins import HyperparametersMixin


class Optimizer(abc.ABC):
    def __init__(self, params: Iterable = [1, 2, 3]):
        pass


class SGD(Optimizer):
    def __init__(self, params: Iterable, lr: float):
        super().__init__()


@dataclass
class MyDataclass:
    x: str = "hello"


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(
        self,
        # optimizer: Callable[[], Optimizer],
        optimizer: Optimizer,
        myds: MyDataclass,
    ):
        print(type(optimizer))
        # print(type(optimizer()))

        super().__init__()
        self.save_hyperparameters()

        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        transform = v2.ToTensor()  # mainly to convert PIL to tensor
        self.mnist_test = MNIST(
            self.data_dir,
            train=False,
            download=True,
            transform=transform,
        )
        self.mnist_predict = self.mnist_test
        mnist_full = MNIST(
            self.data_dir,
            train=True,
            download=True,
            transform=transform,
        )
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)


def train_hardcode():
    autoencoder = LitAutoEncoder(
        optimizer=SGD(params=[1, 2, 3], lr=123),
        myds=MyDataclass(),
    )

    # setup data
    dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
    train_loader = utils.data.DataLoader(dataset)

    # train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
    trainer = L.Trainer(limit_train_batches=10, max_epochs=1)
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)


def main():
    cli.LightningCLI(
        LitAutoEncoder,
        MNISTDataModule,
        trainer_defaults=dict(
            max_epochs=1,
            limit_train_batches=10,
            limit_val_batches=10,
        ),
    )


if __name__ == "__main__":
    main()
    # train_hardcode()

Pitch

No response

Alternatives

No response

Additional context

No response

cc @Borda
