
BatchSizeFinder safety margin #20447

Open
edmcman opened this issue Nov 25, 2024 · 0 comments
Labels
feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers

Comments

edmcman commented Nov 25, 2024

Description & Motivation

My model allocates a bit more memory for some examples than for others. When I run BatchSizeFinder, the batch size it discovers works for a while, but training eventually hits a larger example and runs out of memory.

Pitch

I would like to add an argument to BatchSizeFinder that reduces the final batch size by a multiplicative factor. Possibilities are:

  1. a safety_margin of 0.1 would multiply the final batch size by (1 - 0.1)
  2. a scale of 0.9 would multiply the final batch size by 0.9

This would be turned off by default.
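To make the two options concrete, here is a minimal sketch of the arithmetic each one implies (the function names are hypothetical, purely for illustration):

```python
def apply_safety_margin(batch_size: int, safety_margin: float) -> int:
    # Option 1: safety_margin=0.1 keeps 90% of the found batch size.
    return int(batch_size * (1 - safety_margin))

def apply_scale(batch_size: int, scale: float) -> int:
    # Option 2: scale=0.9 keeps 90% of the found batch size directly.
    return int(batch_size * scale)

print(apply_safety_margin(128, 0.1))  # 115
print(apply_scale(128, 0.9))          # 115
```

Both spellings express the same computation; the choice is only about which parameterization reads more naturally.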

Alternatives

I implemented a pretty hacky version of this using inheritance:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import BatchSizeFinder

    # Adds a safety margin.  For example, a `safety_margin` of 0.1 indicates
    # that the final batch_size will be reduced by 10%.
    class SafeBatchSizeFinder(BatchSizeFinder):
        # `safety_margin` is keyword-only so that positional arguments still
        # reach BatchSizeFinder unchanged.
        def __init__(self, *args, safety_margin=0.1, **kwargs):
            super().__init__(*args, **kwargs)
            assert 0.0 <= safety_margin <= 1.0
            self.safety_margin = safety_margin

        def scale_batch_size(self, trainer, *args, **kwargs):
            super().scale_batch_size(trainer, *args, **kwargs)
            original_batch_size = self.optimal_batch_size
            new_batch_size = int(self.optimal_batch_size * (1.0 - self.safety_margin))
            print(
                f"Found optimal batch size of {original_batch_size}, but with a safety "
                f"margin of {self.safety_margin}, reducing it to {new_batch_size}"
            )
            self.optimal_batch_size = new_batch_size
            # Propagate the reduced batch size to the data module and rebuild
            # the dataloaders so the new value actually takes effect.
            pl.tuner.batch_size_scaling._adjust_batch_size(trainer, value=new_batch_size)
            pl.tuner.batch_size_scaling._reset_dataloaders(trainer)
            trainer._active_loop.reset()
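One detail worth flagging in any implementation: `int()` truncates, so a small optimal batch size combined with a large margin could round down to zero. A minimal sketch of the rounding behavior, using a hypothetical `reduced_batch_size` helper:

```python
def reduced_batch_size(optimal: int, safety_margin: float) -> int:
    # int() truncates toward zero, so clamp to at least 1 to keep
    # the dataloader batch size valid.
    return max(1, int(optimal * (1.0 - safety_margin)))

print(reduced_batch_size(64, 0.1))  # 57
print(reduced_batch_size(1, 0.5))   # clamped to 1
```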

Additional context

I am willing to implement this. I don't think it would be hard.

cc @Borda

@edmcman edmcman added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Nov 25, 2024