
Type Error in configure_optimizers #20462

Open
LukasSalchow opened this issue Dec 3, 2024 · 0 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x


Bug description

As suggested in the docs, my `configure_optimizers` looks roughly like this:

    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
        ...
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

Note that if I want to annotate a return type, I have to choose `OptimizerLRSchedulerConfig` (or `OptimizerLRScheduler` itself), since that is the only subtype of `OptimizerLRScheduler` (the declared return type of `configure_optimizers`) that is a dict.

When I run this with a `ReduceLROnPlateau` scheduler, I get:

lightning_fabric.utilities.exceptions.MisconfigurationException: `configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used. For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}

But I cannot add `'monitor': 'metric_to_track'` to the returned dict, because then mypy complains:

error: Extra key "monitor" for TypedDict "OptimizerLRSchedulerConfig"  [typeddict-unknown-key]
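The conflict can be reproduced outside Lightning. Below is a hypothetical minimal sketch: the TypedDict mirrors the shape of `OptimizerLRSchedulerConfig` from `utilities/types.py`, while `Optimizer` and `LRScheduler` are simplified stand-ins for the torch classes, not the real ones.

```python
# Stand-ins for torch.optim.Optimizer and an LR scheduler (placeholders only).
try:
    from typing import NotRequired, TypedDict  # Python 3.11+
except ImportError:
    from typing_extensions import NotRequired, TypedDict


class Optimizer: ...


class LRScheduler: ...


class OptimizerLRSchedulerConfig(TypedDict):
    optimizer: Optimizer
    lr_scheduler: NotRequired[LRScheduler]


def configure_optimizers() -> OptimizerLRSchedulerConfig:
    # At runtime, Lightning demands a "monitor" key when the scheduler is
    # ReduceLROnPlateau, but adding it here makes mypy report:
    #   error: Extra key "monitor" for TypedDict "OptimizerLRSchedulerConfig"
    return {"optimizer": Optimizer(), "lr_scheduler": LRScheduler()}


cfg = configure_optimizers()
print(sorted(cfg))
```

The annotated return type simply has no legal place for the `"monitor"` key that the runtime check requires.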

Suggested fix

I suggest replacing

class OptimizerLRSchedulerConfig(TypedDict):
    optimizer: Optimizer
    lr_scheduler: NotRequired[Union[LRSchedulerTypeUnion, LRSchedulerConfigType]]

in `utilities/types.py` with

class OptimizerConfigDict(TypedDict):
    optimizer: Optimizer


class OptimizerLRSchedulerConfigDict(TypedDict):
    optimizer: Optimizer
    lr_scheduler: Union[LRSchedulerTypeUnion, LRSchedulerConfigType]
    monitor: str

and

OptimizerLRScheduler = Optional[
    Union[
        Optimizer,
        Sequence[Optimizer],
        Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]],
        OptimizerLRSchedulerConfig,
        Sequence[OptimizerLRSchedulerConfig],
    ]
]

with

OptimizerLRScheduler = Optional[
    Union[
        Optimizer,
        Sequence[Optimizer],
        Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]],
        OptimizerConfigDict,
        OptimizerLRSchedulerConfigDict,
        Sequence[OptimizerConfigDict],
        Sequence[OptimizerLRSchedulerConfigDict],
    ]
]
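For illustration, here is a sketch of the suggested fix in use. `OptimizerConfigDict` and `OptimizerLRSchedulerConfigDict` are the names proposed above; the `Optimizer` and `ReduceLROnPlateau` classes are again simplified stand-ins for the torch objects.

```python
from typing import TypedDict


class Optimizer: ...


class ReduceLROnPlateau: ...


class OptimizerConfigDict(TypedDict):
    optimizer: Optimizer


class OptimizerLRSchedulerConfigDict(TypedDict):
    optimizer: Optimizer
    lr_scheduler: ReduceLROnPlateau
    monitor: str


def configure_optimizers() -> OptimizerLRSchedulerConfigDict:
    # "monitor" is a declared key now, so mypy accepts the dict that
    # Lightning's runtime check requires for ReduceLROnPlateau.
    return {
        "optimizer": Optimizer(),
        "lr_scheduler": ReduceLROnPlateau(),
        "monitor": "metric_to_track",
    }


cfg = configure_optimizers()
```

With the union also listing `OptimizerConfigDict`, the optimizer-only dict return stays typeable without the scheduler keys.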

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

See the `MisconfigurationException` and the mypy error quoted above.

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

@LukasSalchow LukasSalchow added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Dec 3, 2024