
[Major] Support Custom Learning Rate Scheduler #1637

Merged
44 commits merged on Aug 30, 2024
Commits (44): changes shown from all commits
2ae4506
enable re-training
Constantin343 Jun 24, 2024
900c8d5
update scheduler
Constantin343 Jun 29, 2024
f1355eb
change scheduler for continued training
Constantin343 Jun 30, 2024
da3a6d5
add test
Constantin343 Jul 1, 2024
492dee9
merge main
Constantin343 Jul 1, 2024
f996928
fix metrics logging
Constantin343 Jul 1, 2024
f9a77f8
include feedback
Constantin343 Jul 5, 2024
7ad761d
get correct optimizer states
Constantin343 Jul 5, 2024
b14d20b
fix tests
Constantin343 Jul 5, 2024
9fe3401
enable setting the scheduler
Constantin343 Jul 8, 2024
00f2e25
update for onecyclelr
Constantin343 Jul 8, 2024
5f103d8
add tests and adapt docstring
Constantin343 Jul 9, 2024
e043201
fix array mismatch
Constantin343 Jul 9, 2024
df74dc3
Merge branch 'main' into dynamic-weight-saving-for-retraining
ourownstory Aug 23, 2024
63c935c
robustify scheduler config
ourownstory Aug 24, 2024
6a74680
clean up train config setup
ourownstory Aug 24, 2024
420f8a6
restructure train model config
ourownstory Aug 24, 2024
1982089
remove continue train
ourownstory Aug 27, 2024
99e0355
fix regularization
ourownstory Aug 28, 2024
9575de1
fix regularization of holidays test
ourownstory Aug 28, 2024
6d76cb0
address events reg test
ourownstory Aug 28, 2024
19d6497
fixed reg tests
ourownstory Aug 28, 2024
9187f7f
fix save
ourownstory Aug 28, 2024
7a86edf
move to debug folder
ourownstory Aug 28, 2024
ee9e0e4
debugging
ourownstory Aug 28, 2024
c3e5ba2
Merge branch 'main' into custom-lr-scheduler
ourownstory Aug 28, 2024
c3f3c3c
fix custom lr
ourownstory Aug 28, 2024
db09100
set finding lr arg
ourownstory Aug 28, 2024
b8bf9b8
add logging of progress and lr
ourownstory Aug 28, 2024
a87651c
update lr schedulers to use epochs
ourownstory Aug 28, 2024
c83e7cc
Merge branch 'main' into custom-lr-scheduler
ourownstory Aug 28, 2024
b79b7e1
fix lr-finder
ourownstory Aug 29, 2024
9f03ed2
Merge branch 'main' into custom-lr-scheduler
ourownstory Aug 29, 2024
bc64a52
improve num_training calculation for lr-finder and remove loss-min fo…
ourownstory Aug 29, 2024
c7c6313
large changeset - isolate lr-finder
ourownstory Aug 30, 2024
ee28441
fix progressbar
ourownstory Aug 30, 2024
be3b6cf
remove dataloader from model
ourownstory Aug 30, 2024
462f80f
fix callbacks ProgressBar
ourownstory Aug 30, 2024
628b4ad
fixing tuner
ourownstory Aug 30, 2024
4b9b305
fix tuner
ourownstory Aug 30, 2024
bc27cfc
readd prep_or_copy
ourownstory Aug 30, 2024
56215b7
undo copy of model, loader, trainer
ourownstory Aug 30, 2024
fda417a
add comment about separate lr finder copies
ourownstory Aug 30, 2024
7020244
improve lr finder comment
ourownstory Aug 30, 2024
170 changes: 106 additions & 64 deletions neuralprophet/configure.py
@@ -23,6 +23,24 @@
@dataclass
class Model:
lagged_reg_layers: Optional[List[int]]
quantiles: Optional[List[float]] = None

def setup_quantiles(self):
# convert quantiles to empty list [] if None
if self.quantiles is None:
self.quantiles = []
# assert quantiles is a list type
assert isinstance(self.quantiles, list), "Quantiles must be provided as list."
# check if quantiles are float values in (0, 1)
assert all(
0 < quantile < 1 for quantile in self.quantiles
), "The quantiles specified need to be floats in-between (0, 1)."
# sort the quantiles
self.quantiles.sort()
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# 0 is the median quantile index
self.quantiles.insert(0, 0.5)
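
For reference, a minimal sketch (not part of the diff) of what the new setup_quantiles normalization produces for a user-supplied list; the steps mirror the lines above:

import math

quantiles = [0.25, 0.9, 0.5]
quantiles.sort()                                                 # [0.25, 0.5, 0.9]
quantiles = [q for q in quantiles if not math.isclose(0.5, q)]   # drop the median if present
quantiles.insert(0, 0.5)                                         # median goes back in at index 0
print(quantiles)                                                 # [0.5, 0.25, 0.9]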


@dataclass
@@ -92,30 +110,31 @@
batch_size: Optional[int]
loss_func: Union[str, torch.nn.modules.loss._Loss, Callable]
optimizer: Union[str, Type[torch.optim.Optimizer]]
quantiles: List[float] = field(default_factory=list)
# quantiles: List[float] = field(default_factory=list)
optimizer_args: dict = field(default_factory=dict)
scheduler: Optional[Type[torch.optim.lr_scheduler.OneCycleLR]] = None
scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
scheduler_args: dict = field(default_factory=dict)
early_stopping: Optional[bool] = False
newer_samples_weight: float = 1.0
newer_samples_start: float = 0.0
reg_delay_pct: float = 0.5
reg_lambda_trend: Optional[float] = None
trend_reg_threshold: Optional[Union[bool, float]] = None
n_data: int = field(init=False)
loss_func_name: str = field(init=False)
lr_finder_args: dict = field(default_factory=dict)
pl_trainer_config: dict = field(default_factory=dict)

def __post_init__(self):
# assert the uncertainty estimation params and then finalize the quantiles
self.set_quantiles()
assert self.newer_samples_weight >= 1.0
assert self.newer_samples_start >= 0.0
assert self.newer_samples_start < 1.0
self.set_loss_func()
self.set_optimizer()
self.set_scheduler()
# self.set_loss_func(self.quantiles)

def set_loss_func(self):
# called in TimeNet configure_optimizers:
# self.set_optimizer()
# self.set_scheduler()

def set_loss_func(self, quantiles: List[float]):
if isinstance(self.loss_func, str):
if self.loss_func.lower() in ["smoothl1", "smoothl1loss", "huber"]:
# keeping 'huber' for backwards compatibility, though not identical
@@ -135,25 +154,8 @@
self.loss_func_name = type(self.loss_func).__name__
else:
raise NotImplementedError(f"Loss function {self.loss_func} not found")
if len(self.quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)

def set_quantiles(self):
# convert quantiles to empty list [] if None
if self.quantiles is None:
self.quantiles = []
# assert quantiles is a list type
assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
# check if quantiles are float values in (0, 1)
assert all(
0 < quantile < 1 for quantile in self.quantiles
), "The quantiles specified need to be floats in-between (0, 1)."
# sort the quantiles
self.quantiles.sort()
# 0 is the median quantile index
self.quantiles.insert(0, 0.5)
if len(quantiles) > 1:
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=quantiles)

def set_auto_batch_epoch(
self,
@@ -182,51 +184,88 @@
"""
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding
torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet.

Parameters
----------
optimizer_name : int
Object provided to NeuralProphet as optimizer.
optimizer_args : dict
Arguments for the optimizer.

"""
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config(
self.optimizer, self.optimizer_args
)
if isinstance(self.optimizer, str):
if self.optimizer.lower() == "adamw":
# Tends to overfit, but reliable
self.optimizer = torch.optim.AdamW
self.optimizer_args["weight_decay"] = 1e-3
elif self.optimizer.lower() == "sgd":
# better validation performance, but diverges sometimes
self.optimizer = torch.optim.SGD
self.optimizer_args["momentum"] = 0.9
self.optimizer_args["weight_decay"] = 1e-4
else:
raise ValueError(
f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class."
)
elif not issubclass(self.optimizer, torch.optim.Optimizer):
raise ValueError("The provided optimizer is not supported.")

def set_scheduler(self):
"""
Set the scheduler and scheduler args.
Set the scheduler and scheduler arg depending on the user selection.
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
"""
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
self.scheduler_args.update(
{
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
)

def set_lr_finder_args(self, dataset_size, num_batches):
"""
Set the lr_finder_args.
This is the range of learning rates to test.
"""
num_training = 100 + int(np.log10(dataset_size) * 20)
if num_batches < num_training:
log.warning(
f"Learning rate finder: The number of batches ({num_batches}) is too small than the required number \
for the learning rate finder ({num_training}). The results might not be optimal."
)
# num_training = num_batches
self.lr_finder_args.update(
{
"min_lr": 1e-7,
"max_lr": 10,
"num_training": num_training,
"early_stop_threshold": None,
}
)
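
Worked example of the num_training formula above (a sketch, not part of the diff):

import numpy as np

dataset_size = 10_000
num_training = 100 + int(np.log10(dataset_size) * 20)  # log10(10_000) = 4, so 100 + 80
print(num_training)  # 180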
if self.scheduler is None:
log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.")
self.scheduler = "exponentiallr"

if isinstance(self.scheduler, str):
if self.scheduler.lower() in ["onecycle", "onecyclelr"]:
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
defaults = {
"pct_start": 0.3,
"anneal_strategy": "cos",
"div_factor": 10.0,
"final_div_factor": 10.0,
"three_phase": True,
}
elif self.scheduler.lower() == "steplr":
self.scheduler = torch.optim.lr_scheduler.StepLR
defaults = {
"step_size": 10,
"gamma": 0.1,
}
elif self.scheduler.lower() == "exponentiallr":
self.scheduler = torch.optim.lr_scheduler.ExponentialLR
defaults = {
"gamma": 0.9,
}
elif self.scheduler.lower() == "cosineannealinglr":
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
defaults = {
"T_max": 50,
}
elif self.scheduler.lower() == "cosineannealingwarmrestarts":
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
defaults = {
"T_0": 5,
"T_mult": 2,
}
else:
raise NotImplementedError(
f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class."
)
if self.scheduler_args is not None:
defaults.update(self.scheduler_args)
self.scheduler_args = defaults
else:
assert issubclass(
self.scheduler, torch.optim.lr_scheduler.LRScheduler
), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
def get_reg_delay_weight(self, progress, reg_start_pct: float = 0.66, reg_full_pct: float = 1.0):
# Ignore type warning of epochs possibly being None (does not work with dataclasses)
progress = (e + iter_progress) / float(self.epochs) # type: ignore
if reg_start_pct == reg_full_pct:
reg_progress = float(progress > reg_start_pct)
else:
@@ -239,6 +278,9 @@
delay_weight = 1
return delay_weight

def set_batches_per_epoch(self, batches_per_epoch: int):
self.batches_per_epoch = batches_per_epoch


@dataclass
class Trend:
@@ -304,7 +346,7 @@
log.error("Invalid growth for global_local mode '{}'. Set to 'global'".format(self.trend_global_local))
self.trend_global_local = "global"

if self.trend_local_reg < 0:

Check failure (GitHub Actions / pyright) on line 349: Operator "<" not supported for "None" (reportOptionalOperand)
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg))
self.trend_local_reg = False

@@ -353,13 +395,13 @@
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local))
self.global_local = "global"

self.periods = OrderedDict(

Check failure (GitHub Actions / pyright) on line 398: No overloads for "__init__" match the provided arguments (reportCallIssue)
{

Check failure (GitHub Actions / pyright) on line 399: Argument of type "dict[str, Season]" cannot be assigned to parameter "iterable" of type "Iterable[list[bytes]]" in function "__init__" (reportArgumentType)
"yearly": Season(
resolution=6,
period=365.25,
arg=self.yearly_arg,
global_local=(

Check failure (GitHub Actions / pyright) on line 404: Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.yearly_global_local
if self.yearly_global_local in ["global", "local"]
else self.global_local
@@ -370,7 +412,7 @@
resolution=3,
period=7,
arg=self.weekly_arg,
global_local=(

Check failure (GitHub Actions / pyright) on line 415: Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.weekly_global_local
if self.weekly_global_local in ["global", "local"]
else self.global_local
@@ -381,7 +423,7 @@
resolution=6,
period=1,
arg=self.daily_arg,
global_local=(

Check failure (GitHub Actions / pyright) on line 426: Argument of type "SeasonGlobalLocalMode | Literal['auto']" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local
),
condition_name=None,
@@ -389,7 +431,7 @@
}
)

assert self.seasonality_local_reg >= 0, "Invalid seasonality_local_reg '{}'.".format(self.seasonality_local_reg)

Check failure (GitHub Actions / pyright) on line 434: Operator ">=" not supported for "None" (reportOptionalOperand)

if self.seasonality_local_reg is True:
log.warning("seasonality_local_reg = True. Default seasonality_local_reg value set to 1")
@@ -407,7 +449,7 @@
resolution=resolution,
period=period,
arg=arg,
global_local=global_local if global_local in ["global", "local"] else self.global_local,

Check failure (GitHub Actions / pyright) on line 452: Argument of type "str" cannot be assigned to parameter "global_local" of type "SeasonGlobalLocalMode" in function "__init__" (reportArgumentType)
condition_name=condition_name,
)

@@ -483,7 +525,7 @@
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects

def __post_init__(self):
self.regressors = None

Check failure (GitHub Actions / pyright) on line 528: Cannot assign to attribute "regressors" for class "ConfigFutureRegressors*" ("None" is not assignable to "OrderedDict[Unknown, Unknown]") (reportAttributeAccessIssue)


@dataclass