From 4d91df5a63b5af81e482c7d25928896d516201eb Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Mon, 18 Nov 2024 11:49:06 -0800 Subject: [PATCH] Monotonic rejection model and generator (#458) Summary: monotonic rejection model GPU support, since they're tied to the generator, we also ensure the generators are gpu ready as well. Differential Revision: D65638150 --- .../monotonic_rejection_generator.py | 2 +- aepsych/kernels/pairwisekernel.py | 12 +- aepsych/kernels/rbf_partial_grad.py | 6 +- aepsych/means/constant_partial_grad.py | 2 +- aepsych/models/base.py | 7 +- aepsych/models/derivative_gp.py | 6 +- aepsych/models/monotonic_projection_gp.py | 16 +- aepsych/models/monotonic_rejection_gp.py | 25 ++- aepsych/strategy.py | 17 +- tests_gpu/acquisition/__init__.py | 6 + tests_gpu/acquisition/test_monotonic.py | 51 +++++ .../models/test_monotonic_rejection_gp.py | 188 ++++++++++++++++++ tests_gpu/test_strategy.py | 13 -- 13 files changed, 294 insertions(+), 57 deletions(-) create mode 100644 tests_gpu/acquisition/__init__.py create mode 100644 tests_gpu/acquisition/test_monotonic.py create mode 100644 tests_gpu/models/test_monotonic_rejection_gp.py diff --git a/aepsych/generators/monotonic_rejection_generator.py b/aepsych/generators/monotonic_rejection_generator.py index 3df99be2e..cc74649e1 100644 --- a/aepsych/generators/monotonic_rejection_generator.py +++ b/aepsych/generators/monotonic_rejection_generator.py @@ -101,7 +101,7 @@ def gen( ) # Augment bounds with deriv indicator - bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1) + bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1) # Fix deriv indicator to 0 during optimization fixed_features = {(bounds.shape[1] - 1): 0.0} # Fix explore features to random values diff --git a/aepsych/kernels/pairwisekernel.py b/aepsych/kernels/pairwisekernel.py index c30bbf512..f85a6bc3d 100644 --- a/aepsych/kernels/pairwisekernel.py +++ b/aepsych/kernels/pairwisekernel.py @@ -16,12 +16,12 @@ class PairwiseKernel(Kernel): """ def __init__( - self, latent_kernel: Kernel, is_partial_obs: bool=False, **kwargs + self, latent_kernel: Kernel, is_partial_obs: bool = False, **kwargs ) -> None: """ - Args: - latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP. - is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False. + Args: + latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP. + is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False. """ super(PairwiseKernel, self).__init__(**kwargs) @@ -40,11 +40,11 @@ def forward( x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space. x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space. diag (bool): Should the Kernel compute the whole covariance matrix or just the diagonal? Defaults to False. - + Returns: torch.Tensor (or :class:`gpytorch.lazy.LazyTensor`) : A `b x n x m` or `n x m` tensor representing - the covariance matrix between `x1` and `x2`. + the covariance matrix between `x1` and `x2`. 
The exact size depends on the kernel's evaluation mode: * `full_covar`: `n x m` or `b x n x m` * `diag`: `n` or `b x n` diff --git a/aepsych/kernels/rbf_partial_grad.py b/aepsych/kernels/rbf_partial_grad.py index 09a204130..43c574d75 100644 --- a/aepsych/kernels/rbf_partial_grad.py +++ b/aepsych/kernels/rbf_partial_grad.py @@ -31,14 +31,14 @@ def forward( self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params: Any ) -> torch.Tensor: """Computes the covariance matrix between x1 and x2 based on the RBF - + Args: x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space. x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space. diag (bool): Should the Kernel compute the whole covariance matrix (False) or just the diagonal (True)? Defaults to False. - - + + Returns: torch.Tensor: A `b x n x m` or `n x m` tensor representing the covariance matrix between `x1` and `x2`. The exact size depends on the kernel's evaluation mode: diff --git a/aepsych/means/constant_partial_grad.py b/aepsych/means/constant_partial_grad.py index ead7ee6ed..e0af2c29a 100644 --- a/aepsych/means/constant_partial_grad.py +++ b/aepsych/means/constant_partial_grad.py @@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: idx = input[..., -1].to(dtype=torch.long) > 0 mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :]) sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]]) - mean = torch.zeros(sz) + mean = torch.zeros(sz).to(input) mean[~idx] = mean_fit return mean diff --git a/aepsych/models/base.py b/aepsych/models/base.py index 4d554be81..7fbedf9b1 100644 --- a/aepsych/models/base.py +++ b/aepsych/models/base.py @@ -415,7 +415,12 @@ def set_train_data(self, inputs=None, targets=None, strict=False): def device(self) -> torch.device: # We assume all models have some parameters and all models will only use one device # notice that this has no setting, don't let users set device, use .to(). - return next(self.parameters()).device + try: + return next(self.parameters()).device + except ( + AttributeError + ): # Fallback for cases where we need device before we have params + return torch.device("cpu") @property def train_inputs(self) -> Optional[Tuple[torch.Tensor]]: diff --git a/aepsych/models/derivative_gp.py b/aepsych/models/derivative_gp.py index b338f5a7a..90d3042c3 100644 --- a/aepsych/models/derivative_gp.py +++ b/aepsych/models/derivative_gp.py @@ -13,6 +13,7 @@ import torch from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad +from aepsych.models.base import AEPsychModelDeviceMixin from botorch.models.gpytorch import GPyTorchModel from gpytorch.distributions import MultivariateNormal from gpytorch.kernels import Kernel @@ -22,7 +23,9 @@ from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy -class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel): +class MixedDerivativeVariationalGP( + gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel +): """A variational GP with mixed derivative observations. For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010. 
@@ -99,6 +102,7 @@ def __init__( self._num_outputs = 1 self.train_inputs = (train_x,) self.train_targets = train_y + self.to(self.device) # Needed to prep for below self(train_x) # Necessary for CholeskyVariationalDistribution def forward(self, x: torch.Tensor) -> MultivariateNormal: diff --git a/aepsych/models/monotonic_projection_gp.py b/aepsych/models/monotonic_projection_gp.py index 61bec1648..3de672c8e 100644 --- a/aepsych/models/monotonic_projection_gp.py +++ b/aepsych/models/monotonic_projection_gp.py @@ -136,14 +136,17 @@ def posterior( # using numpy because torch doesn't support vectorized linspace, # pytorch/issues/61292 grid: Union[np.ndarray, torch.Tensor] = np.linspace( - self.lb[dim], - X[:, dim].numpy(), + self.lb[dim].cpu().numpy(), + X[:, dim].cpu().numpy(), s + 1, ) # (s+1 x n) grid = torch.tensor(grid[:-1, :], dtype=X.dtype) # Drop x; (s x n) X_aug[(1 + i * s) : (1 + (i + 1) * s), :, dim] = grid # X_aug[0, :, :] is X, and then subsequent indices are points in the grids # Predict marginal distributions on X_aug + + X = X.to(self.device) + X_aug = X_aug.to(self.device) with torch.no_grad(): post_aug = super().posterior(X=X_aug) mu_aug = post_aug.mean.squeeze() # (m*s+1 x n) @@ -158,12 +161,13 @@ def posterior( # Adjust the whole covariance matrix to accomadate the projected marginals with torch.no_grad(): post = super().posterior(X=X) - R = cov2corr(post.distribution.covariance_matrix.squeeze().numpy()) - S_proj = torch.tensor(corr2cov(R, sigma_proj.numpy()), dtype=X.dtype) + R = cov2corr(post.distribution.covariance_matrix.squeeze().cpu().numpy()) + S_proj = torch.tensor(corr2cov(R, sigma_proj.cpu().numpy()), dtype=X.dtype) mvn_proj = gpytorch.distributions.MultivariateNormal( - mu_proj.unsqueeze(0), - S_proj.unsqueeze(0), + mu_proj.unsqueeze(0).to(self.device), + S_proj.unsqueeze(0).to(self.device), ) + return GPyTorchPosterior(mvn_proj) def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor: diff --git a/aepsych/models/monotonic_rejection_gp.py b/aepsych/models/monotonic_rejection_gp.py index 9bf761abe..2a075f374 100644 --- a/aepsych/models/monotonic_rejection_gp.py +++ b/aepsych/models/monotonic_rejection_gp.py @@ -18,7 +18,7 @@ from aepsych.factory.monotonic import monotonic_mean_covar_factory from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad -from aepsych.models.base import AEPsychMixin +from aepsych.models.base import AEPsychModelDeviceMixin from aepsych.models.utils import select_inducing_points from aepsych.utils import _process_bounds, promote_0d from botorch.fit import fit_gpytorch_mll @@ -32,7 +32,7 @@ from torch import Tensor -class MonotonicRejectionGP(AEPsychMixin, ApproximateGP): +class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP): """A monotonic GP using rejection sampling. This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP @@ -83,7 +83,7 @@ def __init__( objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli. extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None. 
""" - self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim) + lb, ub, self.dim = _process_bounds(lb, ub, dim) if likelihood is None: likelihood = BernoulliLikelihood() @@ -91,7 +91,7 @@ def __init__( self.inducing_point_method = inducing_point_method inducing_points = select_inducing_points( inducing_size=self.inducing_size, - bounds=self.bounds, + bounds=torch.stack((lb, ub)), method="sobol", ) @@ -134,7 +134,9 @@ def __init__( super().__init__(variational_strategy) - self.bounds_ = torch.stack([self.lb, self.ub]) + self.register_buffer("lb", lb) + self.register_buffer("ub", ub) + self.register_buffer("bounds_", torch.stack([self.lb, self.ub])) self.mean_module = mean_module self.covar_module = covar_module self.likelihood = likelihood @@ -144,7 +146,7 @@ def __init__( self.num_samples = num_samples self.num_rejection_samples = num_rejection_samples self.fixed_prior_mean = fixed_prior_mean - self.inducing_points = inducing_points + self.register_buffer("inducing_points", inducing_points) def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None: """Fit the model @@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None: X=self.train_inputs[0], bounds=self.bounds, method=self.inducing_point_method, - ) + ).to(self.device) self._set_model(train_x, train_y) def _set_model( @@ -284,13 +286,14 @@ def predict_probability( return self.predict(x, probability_space=True) def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor: + x = x.to(self.device) return torch.cat( - (x, indx * torch.ones(x.shape[0], 1)), + (x, indx * torch.ones(x.shape[0], 1).to(self.device)), dim=1, ) def _get_deriv_constraint_points(self) -> Tensor: - deriv_cp = torch.tensor([]) + deriv_cp = torch.tensor([]).to(self.device) for i in self.monotonic_idxs: induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1) deriv_cp = torch.cat((deriv_cp, induc_i), dim=0) @@ -299,8 +302,8 @@ def _get_deriv_constraint_points(self) -> Tensor: @classmethod def from_config(cls, config: Config) -> MonotonicRejectionGP: classname = cls.__name__ - num_induc = config.gettensor(classname, "num_induc", fallback=25) - num_samples = config.gettensor(classname, "num_samples", fallback=250) + num_induc = config.getint(classname, "num_induc", fallback=25) + num_samples = config.getint(classname, "num_samples", fallback=250) num_rejection_samples = config.getint( classname, "num_rejection_samples", fallback=5000 ) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index aaa5da335..e75f62f3d 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -71,13 +71,6 @@ class Strategy(object): _n_eval_points: int = 1000 - no_gpu_acqfs = ( - MonotonicMCAcquisition, - MonotonicBernoulliMCMutualInformation, - MonotonicMCPosteriorVariance, - MonotonicMCLSE, - ) - def __init__( self, generator: Union[AEPsychGenerator, ParameterTransformedGenerator], @@ -182,13 +175,7 @@ def __init__( ) self.generator_device = torch.device("cpu") else: - if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs: - warnings.warn( - f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.", - UserWarning, - ) - self.generator_device = torch.device("cpu") - elif not torch.cuda.is_available(): + if not torch.cuda.is_available(): warnings.warn( f"GPU requested for generator {type(generator).__name__}, but no GPU found! 
Using CPU instead.", UserWarning, @@ -283,9 +270,11 @@ def normalize_inputs( x = x[None, :] if self.x is not None: + x = x.to(self.x) x = torch.cat((self.x, x), dim=0) if self.y is not None: + y = y.to(self.y) y = torch.cat((self.y, y), dim=0) # Ensure the correct dtype diff --git a/tests_gpu/acquisition/__init__.py b/tests_gpu/acquisition/__init__.py new file mode 100644 index 000000000..500a0829c --- /dev/null +++ b/tests_gpu/acquisition/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests_gpu/acquisition/test_monotonic.py b/tests_gpu/acquisition/test_monotonic.py new file mode 100644 index 000000000..fd04326c3 --- /dev/null +++ b/tests_gpu/acquisition/test_monotonic.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE +from aepsych.acquisition.objective import ProbitObjective +from aepsych.models.derivative_gp import MixedDerivativeVariationalGP +from botorch.acquisition.objective import IdentityMCObjective +from botorch.utils.testing import BotorchTestCase + + +class TestMonotonicAcq(BotorchTestCase): + def test_monotonic_acq_gpu(self): + # Init + train_X_aug = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]] + ).cuda() + deriv_constraint_points = torch.tensor( + [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]] + ).cuda() + train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda() + + m = MixedDerivativeVariationalGP( + train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug + ).cuda() + acq = MonotonicMCLSE( + model=m, + deriv_constraint_points=deriv_constraint_points, + num_samples=5, + num_rejection_samples=8, + target=1.9, + ) + self.assertTrue(isinstance(acq.objective, IdentityMCObjective)) + acq = MonotonicMCLSE( + model=m, + deriv_constraint_points=deriv_constraint_points, + num_samples=5, + num_rejection_samples=8, + target=1.9, + objective=ProbitObjective(), + ).cuda() + # forward + acq(train_X_aug) + Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0) + posterior = m.posterior(Xfull) + samples = acq.sampler(posterior) + self.assertEqual(samples.shape, torch.Size([5, 6, 1])) diff --git a/tests_gpu/models/test_monotonic_rejection_gp.py b/tests_gpu/models/test_monotonic_rejection_gp.py new file mode 100644 index 000000000..da1998051 --- /dev/null +++ b/tests_gpu/models/test_monotonic_rejection_gp.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os + +import torch + +# run on single threads to keep us from deadlocking weirdly in CI +if "CI" in os.environ or "SANDCASTLE" in os.environ: + torch.set_num_threads(1) + +import numpy as np +from aepsych import Config +from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE +from aepsych.acquisition.objective import ProbitObjective +from aepsych.generators import MonotonicRejectionGenerator +from aepsych.models import MonotonicRejectionGP +from aepsych.strategy import SequentialStrategy, Strategy +from botorch.acquisition.objective import IdentityMCObjective +from botorch.utils.testing import BotorchTestCase +from gpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood +from scipy.stats import norm + + +class MonotonicRejectionGPLSETest(BotorchTestCase): + def test_regression_gpu(self): + # Init + target = 1.5 + model_gen_options = {"num_restarts": 1, "raw_samples": 3, "epochs": 5} + lb = torch.tensor([0, 0]) + ub = torch.tensor([4, 4]) + m = MonotonicRejectionGP( + lb=lb, + ub=ub, + likelihood=GaussianLikelihood(), + fixed_prior_mean=target, + monotonic_idxs=[1], + num_induc=2, + num_samples=3, + num_rejection_samples=4, + ).cuda() + strat = Strategy( + lb=lb, + ub=ub, + model=m, + generator=MonotonicRejectionGenerator( + MonotonicMCLSE, + acqf_kwargs={"target": target}, + model_gen_options=model_gen_options, + ), + min_asks=1, + stimuli_per_trial=1, + outcome_types=["binary"], + use_gpu_modeling=True, + ) + # Fit + train_x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) + train_y = torch.tensor([[1.0], [2.0], [3.0]]) + m.fit(train_x=train_x, train_y=train_y) + self.assertEqual(m.inducing_points.shape, torch.Size([2, 2])) + self.assertEqual(m.mean_module.constant.item(), 1.5) + # Predict + f, var = m.predict(train_x) + self.assertEqual(f.shape, torch.Size([3])) + self.assertEqual(var.shape, torch.Size([3])) + # Gen + strat.add_data(train_x, train_y) + Xopt = strat.gen() + self.assertEqual(Xopt.shape, torch.Size([1, 2])) + # Acquisition function + acq = strat.generator._instantiate_acquisition_fn(m) + self.assertEqual(acq.deriv_constraint_points.shape, torch.Size([2, 3])) + self.assertTrue( + torch.equal(acq.deriv_constraint_points[:, -1].cpu(), 2 * torch.ones(2)) + ) + self.assertEqual(acq.target, 1.5) + self.assertTrue(isinstance(acq.objective, IdentityMCObjective)) + + def test_classification_gpu(self): + # Init + target = 0.75 + model_gen_options = {"num_restarts": 1, "raw_samples": 3, "epochs": 5} + lb = torch.tensor([0, 0]) + ub = torch.tensor([4, 4]) + m = MonotonicRejectionGP( + lb=lb, + ub=ub, + likelihood=BernoulliLikelihood(), + fixed_prior_mean=target, + monotonic_idxs=[1], + num_induc=2, + num_samples=3, + num_rejection_samples=4, + ).cuda() + strat = Strategy( + lb=lb, + ub=ub, + model=m, + generator=MonotonicRejectionGenerator( + MonotonicMCLSE, + acqf_kwargs={"target": target, "objective": ProbitObjective()}, + model_gen_options=model_gen_options, + ), + min_asks=1, + stimuli_per_trial=1, + outcome_types=["binary"], + use_gpu_modeling=True, + ) + # Fit + train_x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]) + train_y = torch.tensor([1.0, 1.0, 0.0]) + m.fit(train_x=train_x, train_y=train_y) + self.assertEqual(m.inducing_points.shape, torch.Size([2, 2])) + self.assertAlmostEqual(m.mean_module.constant.item(), norm.ppf(0.75)) + # Predict + f, var = m.predict(train_x) + self.assertEqual(f.shape, torch.Size([3])) + self.assertEqual(var.shape, torch.Size([3])) + # Gen + strat.add_data(train_x, train_y) + Xopt = strat.gen() + 
self.assertEqual(Xopt.shape, torch.Size([1, 2])) + # Acquisition function + acq = strat.generator._instantiate_acquisition_fn(m) + self.assertEqual(acq.deriv_constraint_points.shape, torch.Size([2, 3])) + self.assertTrue( + torch.equal(acq.deriv_constraint_points[:, -1].cpu(), 2 * torch.ones(2)) + ) + self.assertEqual(acq.target, 0.75) + self.assertTrue(isinstance(acq.objective, ProbitObjective)) + # Update + m.update(train_x=train_x[:2, :2], train_y=train_y[:2], warmstart=True) + self.assertEqual(m.train_inputs[0].shape, torch.Size([2, 3])) + + def test_classification_from_config_gpu(self): + seed = 1 + torch.manual_seed(seed) + np.random.seed(seed) + + n_init = 15 + n_opt = 1 + + config_str = f""" + [common] + parnames = [par1] + outcome_types = [binary] + stimuli_per_trial = 1 + strategy_names = [init_strat, opt_strat] + + [par1] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = {n_init} + + [opt_strat] + generator = MonotonicRejectionGenerator + model = MonotonicRejectionGP + acqf = MonotonicMCLSE + min_asks = {n_opt} + + [MonotonicRejectionGenerator] + use_gpu = True + + [MonotonicRejectionGP] + num_induc = 2 + num_samples = 3 + num_rejection_samples = 4 + monotonic_idxs = [0] + use_gpu = True + + [MonotonicMCLSE] + target = 0.75 + objective = ProbitObjective + """ + config = Config(config_str=config_str) + strat = SequentialStrategy.from_config(config) + + for _i in range(n_init + n_opt): + next_x = strat.gen() + strat.add_data(next_x, int(np.random.rand() > next_x)) diff --git a/tests_gpu/test_strategy.py b/tests_gpu/test_strategy.py index b1e229eaf..ec3dc91ca 100644 --- a/tests_gpu/test_strategy.py +++ b/tests_gpu/test_strategy.py @@ -26,19 +26,6 @@ def test_gpu_no_model_generator_warn(self): use_gpu_generating=True, ) - def test_no_gpu_acqf(self): - with self.assertWarns(UserWarning): - Strategy( - lb=[0], - ub=[1], - stimuli_per_trial=1, - outcome_types=["binary"], - min_asks=1, - model=GPClassificationModel(lb=[0], ub=[1]), - generator=OptimizeAcqfGenerator(acqf=MonotonicMCLSE), - use_gpu_generating=True, - ) - if __name__ == "__main__": unittest.main()
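
---
Usage note (not part of the patch): a minimal sketch of the GPU path this diff enables, mirroring tests_gpu/models/test_monotonic_rejection_gp.py above. It assumes a CUDA-capable PyTorch install; the parameter values are illustrative only and taken from the added tests.

    import torch
    from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
    from aepsych.acquisition.objective import ProbitObjective
    from aepsych.generators import MonotonicRejectionGenerator
    from aepsych.models import MonotonicRejectionGP
    from aepsych.strategy import Strategy
    from gpytorch.likelihoods import BernoulliLikelihood

    lb, ub = torch.tensor([0, 0]), torch.tensor([4, 4])

    # The model is moved to the GPU explicitly; after this change its
    # registered buffers (lb, ub, bounds_, inducing_points) travel with it.
    model = MonotonicRejectionGP(
        lb=lb,
        ub=ub,
        likelihood=BernoulliLikelihood(),
        fixed_prior_mean=0.75,
        monotonic_idxs=[1],
        num_induc=2,
        num_samples=3,
        num_rejection_samples=4,
    ).cuda()

    strat = Strategy(
        lb=lb,
        ub=ub,
        model=model,
        generator=MonotonicRejectionGenerator(
            MonotonicMCLSE,
            acqf_kwargs={"target": 0.75, "objective": ProbitObjective()},
            model_gen_options={"num_restarts": 1, "raw_samples": 3, "epochs": 5},
        ),
        min_asks=1,
        stimuli_per_trial=1,
        outcome_types=["binary"],
        use_gpu_modeling=True,  # as in the GPU tests added by this diff
    )

    train_x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    train_y = torch.tensor([1.0, 1.0, 0.0])
    model.fit(train_x=train_x, train_y=train_y)
    strat.add_data(train_x, train_y)
    next_x = strat.gen()  # candidate generated with the GPU-resident model

The config-driven path works the same way: setting use_gpu = True under [MonotonicRejectionGP] and [MonotonicRejectionGenerator], as in the test_classification_from_config_gpu config string above, routes both fitting and candidate generation to the GPU.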