Description & Motivation
My model allocates a bit more space for some examples than others. When I run BatchSizeFinder, the batch size it discovers works for a while, but eventually it hits a larger example and I run out of memory.
Pitch
I would like to add an argument to BatchSizeFinder that decreases the final batch size by a multiplicative factor. Possibilities are:
safety_margin of 0.1 would mean multiplying the final batch size by (1 - 0.1)
scale of 0.9 would mean multiplying the final batch size by 0.9
This would be turned off by default.
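For illustration, here is a minimal usage sketch assuming the proposed argument were added; safety_margin is not an existing BatchSizeFinder parameter, it is the addition suggested above:

from pytorch_lightning.callbacks import BatchSizeFinder

# Hypothetical usage of the proposed argument. `safety_margin` does not exist
# on BatchSizeFinder today; it is the parameter this issue proposes.
finder = BatchSizeFinder(mode="power", safety_margin=0.1)
# If the search settles on a batch size of 256, a safety_margin of 0.1 would
# reduce the value actually used to int(256 * (1 - 0.1)) == 230.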
Alternatives
I implemented a pretty hacky version of this using inheritance:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BatchSizeFinder

# Adds a safety margin. For example, `safety_margin` of 0.1 indicates that
# the final batch_size will be reduced by 10%.
class SafeBatchSizeFinder(BatchSizeFinder):
    def __init__(self, safety_margin=0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert 0 <= safety_margin <= 1.0
        self.safety_margin = safety_margin

    def scale_batch_size(self, trainer, *args, **kwargs):
        super().scale_batch_size(trainer, *args, **kwargs)
        original_batch_size = self.optimal_batch_size
        new_batch_size = int(self.optimal_batch_size * (1.0 - self.safety_margin))
        print(
            f"Found optimal batch size of {original_batch_size}, but with a safety margin of {self.safety_margin}, reducing it to {new_batch_size}"
        )
        self.optimal_batch_size = new_batch_size
        # This adjusts the data module batch_size.
        pl.tuner.batch_size_scaling._adjust_batch_size(trainer, value=new_batch_size)
        pl.tuner.batch_size_scaling._reset_dataloaders(trainer)
        trainer._active_loop.reset()
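For context, a minimal sketch of how the subclass above would be attached; the model and datamodule names are placeholders, not part of the original report:

import pytorch_lightning as pl

# Hypothetical usage: register the subclass as a callback so the batch size
# search (with the safety margin applied) runs before training starts.
# MyModel and MyDataModule are placeholder names.
trainer = pl.Trainer(callbacks=[SafeBatchSizeFinder(safety_margin=0.1, mode="power")])
trainer.fit(MyModel(), datamodule=MyDataModule())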
Additional context
I am willing to implement this. I don't think it would be hard.
cc @Borda