Skip to content

Commit

Permalink
Revert initialize acqf method to not add bounds to calls (#490)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #490

Acqf_kwargs already gets bounds from config if it needs it so now we work with those instead of the generator's bounds.

Reviewed By: crasanders

Differential Revision: D67532416

fbshipit-source-id: 7fc57d320213dfedf874249da4f225371be40742
  • Loading branch information
JasonKChow committed Dec 20, 2024
1 parent afa6705 commit 35257c4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
35 changes: 14 additions & 21 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,33 +82,26 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti
Returns:
AcquisitionFunction: Configured acquisition function.
"""
if (
"lb" in inspect.signature(self.acqf).parameters
and "ub" in inspect.signature(self.acqf).parameters
):
if self.acqf == AnalyticExpectedUtilityOfBestOption:
return self.acqf(pref_model=model, lb=self.lb, ub=self.ub)
if self.acqf == AnalyticExpectedUtilityOfBestOption:
return self.acqf(pref_model=model)

self.lb = self.lb.to(model.device)
self.ub = self.ub.to(model.device)
if self.acqf in self.baseline_requiring_acqfs:
return self.acqf(
model,
model.train_inputs[0],
lb=self.lb,
ub=self.ub,
**self.acqf_kwargs,
)
if hasattr(model, "device"):
if "lb" in self.acqf_kwargs:
if not isinstance(self.acqf_kwargs["lb"], torch.Tensor):
self.acqf_kwargs["lb"] = torch.tensor(self.acqf_kwargs["lb"])

return self.acqf(model=model, lb=self.lb, ub=self.ub, **self.acqf_kwargs)
self.acqf_kwargs["lb"] = self.acqf_kwargs["lb"].to(model.device)

if self.acqf == AnalyticExpectedUtilityOfBestOption:
return self.acqf(pref_model=model)
if "ub" in self.acqf_kwargs:
if not isinstance(self.acqf_kwargs["ub"], torch.Tensor):
self.acqf_kwargs["ub"] = torch.tensor(self.acqf_kwargs["ub"])

self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device)

if self.acqf in self.baseline_requiring_acqfs:
return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs)

return self.acqf(model=model, **self.acqf_kwargs)
else:
return self.acqf(model=model, **self.acqf_kwargs)

def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Expand Down
2 changes: 2 additions & 0 deletions tests/models/test_semi_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def test_analytic_lookahead_generation(self):
"target": 0.75,
"query_set_size": 100,
"Xq": make_scaled_sobol(self.lb, self.ub, 100),
"lb": self.lb,
"ub": self.ub,
},
max_gen_time=0.2,
lb=self.lb,
Expand Down
8 changes: 7 additions & 1 deletion tests_gpu/generators/test_optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from inspect import signature

import torch
from aepsych.acquisition import (
Expand Down Expand Up @@ -51,9 +52,14 @@ class TestOptimizeAcqfGenerator(unittest.TestCase):
def test_gpu_smoketest(self, acqf, acqf_kwargs):
lb = torch.tensor([0.0])
ub = torch.tensor([1.0])
bounds = torch.stack([lb, ub])
inducing_size = 10

acqf_args_expected = list(signature(acqf).parameters.keys())
if "lb" in acqf_args_expected:
acqf_kwargs = acqf_kwargs.copy()
acqf_kwargs["lb"] = lb
acqf_kwargs["ub"] = ub

model = GPClassificationModel(
dim=1,
inducing_size=inducing_size,
Expand Down

0 comments on commit 35257c4

Please sign in to comment.