Skip to content

Commit

Permalink
PT engine, cleanup types
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed May 4, 2023
1 parent 3d8c8ef commit e9a9920
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions returnn/torch/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
import typing
from typing import Set
from typing import Any, Set, Dict

from returnn.log import log
from returnn.util.basic import RefIdEq
Expand Down Expand Up @@ -35,7 +35,8 @@ def _init_optimizer_classes_dict():

def get_optimizer_class(class_name):
"""
:param str|function|type[torch.optim.Optimizer] class_name: Optimizer data, e.g. "adam", torch.optim.Adam...
:param str|()->torch.optim.Optimizer|type[torch.optim.Optimizer] class_name:
Optimizer data, e.g. "adam", torch.optim.Adam...
:return: Optimizer class
:rtype: type[torch.optim.Optimizer]
"""
Expand Down Expand Up @@ -156,7 +157,7 @@ def _create_optimizer(self, optimizer_opts):
if isinstance(optimizer_opts, torch.optim.Optimizer):
return optimizer_opts
elif callable(optimizer_opts):
optimizer_opts = {"class": optimizer_opts}
optimizer_opts: Dict[str, Any] = {"class": optimizer_opts}
else:
if not isinstance(optimizer_opts, dict):
raise ValueError("'optimizer' must of type dict, callable or torch.optim.Optimizer instance.")
Expand Down

0 comments on commit e9a9920

Please sign in to comment.