diff --git a/aepsych/generators/optimize_acqf_generator.py b/aepsych/generators/optimize_acqf_generator.py
index 82630dfa7..af97b1237 100644
--- a/aepsych/generators/optimize_acqf_generator.py
+++ b/aepsych/generators/optimize_acqf_generator.py
@@ -82,33 +82,26 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
         Returns:
             AcquisitionFunction: Configured acquisition function.
         """
-        if (
-            "lb" in inspect.signature(self.acqf).parameters
-            and "ub" in inspect.signature(self.acqf).parameters
-        ):
-            if self.acqf == AnalyticExpectedUtilityOfBestOption:
-                return self.acqf(pref_model=model, lb=self.lb, ub=self.ub)
+        if self.acqf == AnalyticExpectedUtilityOfBestOption:
+            return self.acqf(pref_model=model)
 
-            self.lb = self.lb.to(model.device)
-            self.ub = self.ub.to(model.device)
-            if self.acqf in self.baseline_requiring_acqfs:
-                return self.acqf(
-                    model,
-                    model.train_inputs[0],
-                    lb=self.lb,
-                    ub=self.ub,
-                    **self.acqf_kwargs,
-                )
+        if hasattr(model, "device"):
+            if "lb" in self.acqf_kwargs:
+                if not isinstance(self.acqf_kwargs["lb"], torch.Tensor):
+                    self.acqf_kwargs["lb"] = torch.tensor(self.acqf_kwargs["lb"])
 
-            return self.acqf(model=model, lb=self.lb, ub=self.ub, **self.acqf_kwargs)
+                self.acqf_kwargs["lb"] = self.acqf_kwargs["lb"].to(model.device)
 
-        if self.acqf == AnalyticExpectedUtilityOfBestOption:
-            return self.acqf(pref_model=model)
+            if "ub" in self.acqf_kwargs:
+                if not isinstance(self.acqf_kwargs["ub"], torch.Tensor):
+                    self.acqf_kwargs["ub"] = torch.tensor(self.acqf_kwargs["ub"])
+
+                self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device)
 
         if self.acqf in self.baseline_requiring_acqfs:
             return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs)
-
-        return self.acqf(model=model, **self.acqf_kwargs)
+        else:
+            return self.acqf(model=model, **self.acqf_kwargs)
 
     def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Tensor:
         """Query next point(s) to run by optimizing the acquisition function.
diff --git a/tests/models/test_semi_p.py b/tests/models/test_semi_p.py
index 1f2caf12f..aaf8f0f20 100644
--- a/tests/models/test_semi_p.py
+++ b/tests/models/test_semi_p.py
@@ -141,6 +141,8 @@ def test_analytic_lookahead_generation(self):
                 "target": 0.75,
                 "query_set_size": 100,
                 "Xq": make_scaled_sobol(self.lb, self.ub, 100),
+                "lb": self.lb,
+                "ub": self.ub,
             },
             max_gen_time=0.2,
             lb=self.lb,
diff --git a/tests_gpu/generators/test_optimize_acqf_generator.py b/tests_gpu/generators/test_optimize_acqf_generator.py
index 9caef06e4..6fc2de55e 100644
--- a/tests_gpu/generators/test_optimize_acqf_generator.py
+++ b/tests_gpu/generators/test_optimize_acqf_generator.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from inspect import signature
 
 import torch
 from aepsych.acquisition import (
@@ -51,9 +52,14 @@ class TestOptimizeAcqfGenerator(unittest.TestCase):
     def test_gpu_smoketest(self, acqf, acqf_kwargs):
         lb = torch.tensor([0.0])
         ub = torch.tensor([1.0])
-        bounds = torch.stack([lb, ub])
         inducing_size = 10
 
+        acqf_args_expected = list(signature(acqf).parameters.keys())
+        if "lb" in acqf_args_expected:
+            acqf_kwargs = acqf_kwargs.copy()
+            acqf_kwargs["lb"] = lb
+            acqf_kwargs["ub"] = ub
+
         model = GPClassificationModel(
             dim=1,
             inducing_size=inducing_size,