Monotonic rejection model and generator
Summary: Add GPU support to the monotonic rejection model; since it is tied to the generator, we also make sure the generators are GPU-ready.

Differential Revision: D65638150
JasonKChow authored and facebook-github-bot committed Nov 15, 2024
1 parent d096c6a commit 940cf31
Showing 7 changed files with 270 additions and 21 deletions.
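
The recurring fix across these files is a single device-placement pattern. A minimal standalone sketch of it (plain torch tensors, not AEPsych code):

import torch

# Freshly created tensors default to CPU, while ops like torch.cat require
# every input to be on one device, so new tensors are allocated on the
# model's device rather than the default.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer = torch.rand(2, 3, device=device)        # lives wherever the model lives
column = torch.zeros(2, 1, device=device)       # allocated to match
augmented = torch.cat((buffer, column), dim=1)  # no cross-device RuntimeError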
2 changes: 1 addition & 1 deletion aepsych/generators/monotonic_rejection_generator.py
@@ -101,7 +101,7 @@ def gen(
)

# Augment bounds with deriv indicator
bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1)
# Fix deriv indicator to 0 during optimization
fixed_features = {(bounds.shape[1] - 1): 0.0}
# Fix explore features to random values
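
A small sketch of what this hunk does (shapes assumed; names mirror the diff): the derivative-indicator column is allocated on the model's device before concatenation, and fixed_features then pins that column to 0 for the acquisition optimizer.

import torch

bounds_ = torch.zeros(2, 3)  # stand-in for model.bounds_ ([lb; ub] rows)
# The indicator column must live on the same device as bounds_, or torch.cat fails.
bounds = torch.cat((bounds_, torch.zeros(2, 1).to(bounds_.device)), dim=1)
# Hold the last (indicator) column at 0 during optimization.
fixed_features = {bounds.shape[1] - 1: 0.0}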
7 changes: 6 additions & 1 deletion aepsych/models/base.py
@@ -415,7 +415,12 @@ def set_train_data(self, inputs=None, targets=None, strict=False):
def device(self) -> torch.device:
# We assume all models have some parameters and all models will only use one device
# notice that this has no setting, don't let users set device, use .to().
return next(self.parameters()).device
try:
return next(self.parameters()).device
except (
AttributeError
): # Fallback for cases where we need device before we have params
return torch.device("cpu")

@property
def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
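
The fallback above matters because next(self.parameters()) assumes at least one registered parameter. A self-contained sketch of the same property (hypothetical module, not the AEPsych class; the sketch also catches StopIteration for the empty-parameter case):

import torch

class DeviceAware(torch.nn.Module):
    @property
    def device(self) -> torch.device:
        try:
            return next(self.parameters()).device
        except (AttributeError, StopIteration):
            # Queried before any parameters exist; fall back to CPU.
            return torch.device("cpu")

m = DeviceAware()
assert m.device.type == "cpu"  # no parameters yet, fallback applies
m.register_parameter("w", torch.nn.Parameter(torch.zeros(1)))
assert m.device == torch.device("cpu")  # now read from the parameter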
26 changes: 15 additions & 11 deletions aepsych/models/monotonic_rejection_gp.py
@@ -18,7 +18,7 @@
from aepsych.factory.monotonic import monotonic_mean_covar_factory
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
from torch import Tensor


class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
"""A monotonic GP using rejection sampling.
This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
@@ -83,15 +83,15 @@ def __init__(
objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
"""
self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
lb, ub, self.dim = _process_bounds(lb, ub, dim)
if likelihood is None:
likelihood = BernoulliLikelihood()

self.inducing_size = num_induc
self.inducing_point_method = inducing_point_method
inducing_points = select_inducing_points(
inducing_size=self.inducing_size,
bounds=self.bounds,
bounds=torch.stack((lb, ub)),
method="sobol",
)

@@ -134,7 +134,9 @@ def __init__(

super().__init__(variational_strategy)

self.bounds_ = torch.stack([self.lb, self.ub])
self.register_buffer("lb", lb)
self.register_buffer("ub", ub)
self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
self.mean_module = mean_module
self.covar_module = covar_module
self.likelihood = likelihood
@@ -144,7 +146,8 @@ def __init__(
self.num_samples = num_samples
self.num_rejection_samples = num_rejection_samples
self.fixed_prior_mean = fixed_prior_mean
self.inducing_points = inducing_points
# self.inducing_points = inducing_points
self.register_buffer("inducing_points", inducing_points)

def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
"""Fit the model
@@ -161,7 +164,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
X=self.train_inputs[0],
bounds=self.bounds,
method=self.inducing_point_method,
)
).to(self.device)
self._set_model(train_x, train_y)

def _set_model(
@@ -284,13 +287,14 @@ def predict_probability(
return self.predict(x, probability_space=True)

def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
x = x.to(self.device)
return torch.cat(
(x, indx * torch.ones(x.shape[0], 1)),
(x, indx * torch.ones(x.shape[0], 1).to(self.device)),
dim=1,
)

def _get_deriv_constraint_points(self) -> Tensor:
deriv_cp = torch.tensor([])
deriv_cp = torch.tensor([]).to(self.device)
for i in self.monotonic_idxs:
induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
@@ -299,8 +303,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
@classmethod
def from_config(cls, config: Config) -> MonotonicRejectionGP:
classname = cls.__name__
num_induc = config.gettensor(classname, "num_induc", fallback=25)
num_samples = config.gettensor(classname, "num_samples", fallback=250)
num_induc = config.getint(classname, "num_induc", fallback=25)
num_samples = config.getint(classname, "num_samples", fallback=250)
num_rejection_samples = config.getint(
classname, "num_rejection_samples", fallback=5000
)
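
Several hunks above convert plain tensor attributes (lb, ub, bounds_, inducing_points) into registered buffers. A minimal sketch of why (hypothetical module): buffers are not trained, but they move with .to()/.cuda() and are saved in state_dict, unlike bare tensor attributes.

import torch

class BoundedModule(torch.nn.Module):
    def __init__(self, lb: torch.Tensor, ub: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("lb", lb)
        self.register_buffer("ub", ub)
        self.register_buffer("bounds_", torch.stack([lb, ub]))

m = BoundedModule(torch.zeros(2), torch.ones(2))
if torch.cuda.is_available():
    m = m.cuda()
    assert m.bounds_.device.type == "cuda"  # buffers followed the module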
11 changes: 3 additions & 8 deletions aepsych/strategy.py
@@ -71,13 +71,6 @@ class Strategy(object):

_n_eval_points: int = 1000

no_gpu_acqfs = (
MonotonicMCAcquisition,
MonotonicBernoulliMCMutualInformation,
MonotonicMCPosteriorVariance,
MonotonicMCLSE,
)

def __init__(
self,
generator: Union[AEPsychGenerator, ParameterTransformedGenerator],
@@ -182,7 +175,7 @@ def __init__(
)
self.generator_device = torch.device("cpu")
else:
if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs:
if hasattr(generator, "acqf"):
warnings.warn(
f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.",
UserWarning,
@@ -283,9 +276,11 @@ def normalize_inputs(
x = x[None, :]

if self.x is not None:
x = x.to(self.x)
x = torch.cat((self.x, x), dim=0)

if self.y is not None:
y = y.to(self.y)
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
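
The normalize_inputs change relies on Tensor.to(other), which matches both the dtype and the device of other in a single call. A small sketch (standalone tensors):

import torch

stored_x = torch.zeros(3, 2, dtype=torch.float64)  # e.g. self.x on some device
new_x = torch.ones(1, 2, dtype=torch.float32)      # incoming data
new_x = new_x.to(stored_x)                         # match dtype and device
x = torch.cat((stored_x, new_x), dim=0)            # concatenation now safe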
6 changes: 6 additions & 0 deletions tests_gpu/acquisition/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
51 changes: 51 additions & 0 deletions tests_gpu/acquisition/test_monotonic.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.acquisition.objective import ProbitObjective
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
from botorch.acquisition.objective import IdentityMCObjective
from botorch.utils.testing import BotorchTestCase


class TestMonotonicAcq(BotorchTestCase):
def test_monotonic_acq_gpu(self):
# Init
train_X_aug = torch.tensor(
[[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
).cuda()
deriv_constraint_points = torch.tensor(
[[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
).cuda()
train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda()

m = MixedDerivativeVariationalGP(
train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug
).cuda()
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
)
self.assertTrue(isinstance(acq.objective, IdentityMCObjective))
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
objective=ProbitObjective(),
).cuda()
# forward
acq(train_X_aug)
Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0)
posterior = m.posterior(Xfull)
samples = acq.sampler(posterior)
self.assertEqual(samples.shape, torch.Size([5, 6, 1]))
188 changes: 188 additions & 0 deletions tests_gpu/models/test_monotonic_rejection_gp.py
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch

# run on single threads to keep us from deadlocking weirdly in CI
if "CI" in os.environ or "SANDCASTLE" in os.environ:
torch.set_num_threads(1)

import numpy as np
from aepsych import Config
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.acquisition.objective import ProbitObjective
from aepsych.generators import MonotonicRejectionGenerator
from aepsych.models import MonotonicRejectionGP
from aepsych.strategy import SequentialStrategy, Strategy
from botorch.acquisition.objective import IdentityMCObjective
from botorch.utils.testing import BotorchTestCase
from gpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood
from scipy.stats import norm


class MonotonicRejectionGPLSETest(BotorchTestCase):
def test_regression_gpu(self):
# Init
target = 1.5
model_gen_options = {"num_restarts": 1, "raw_samples": 3, "epochs": 5}
lb = torch.tensor([0, 0])
ub = torch.tensor([4, 4])
m = MonotonicRejectionGP(
lb=lb,
ub=ub,
likelihood=GaussianLikelihood(),
fixed_prior_mean=target,
monotonic_idxs=[1],
num_induc=2,
num_samples=3,
num_rejection_samples=4,
).cuda()
strat = Strategy(
lb=lb,
ub=ub,
model=m,
generator=MonotonicRejectionGenerator(
MonotonicMCLSE,
acqf_kwargs={"target": target},
model_gen_options=model_gen_options,
),
min_asks=1,
stimuli_per_trial=1,
outcome_types=["binary"],
use_gpu_modeling=True,
)
# Fit
train_x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
train_y = torch.tensor([[1.0], [2.0], [3.0]])
m.fit(train_x=train_x, train_y=train_y)
self.assertEqual(m.inducing_points.shape, torch.Size([2, 2]))
self.assertEqual(m.mean_module.constant.item(), 1.5)
# Predict
f, var = m.predict(train_x)
self.assertEqual(f.shape, torch.Size([3]))
self.assertEqual(var.shape, torch.Size([3]))
# Gen
strat.add_data(train_x, train_y)
Xopt = strat.gen()
self.assertEqual(Xopt.shape, torch.Size([1, 2]))
# Acquisition function
acq = strat.generator._instantiate_acquisition_fn(m)
self.assertEqual(acq.deriv_constraint_points.shape, torch.Size([2, 3]))
self.assertTrue(
torch.equal(acq.deriv_constraint_points[:, -1].cpu(), 2 * torch.ones(2))
)
self.assertEqual(acq.target, 1.5)
self.assertTrue(isinstance(acq.objective, IdentityMCObjective))

def test_classification_gpu(self):
# Init
target = 0.75
model_gen_options = {"num_restarts": 1, "raw_samples": 3, "epochs": 5}
lb = torch.tensor([0, 0])
ub = torch.tensor([4, 4])
m = MonotonicRejectionGP(
lb=lb,
ub=ub,
likelihood=BernoulliLikelihood(),
fixed_prior_mean=target,
monotonic_idxs=[1],
num_induc=2,
num_samples=3,
num_rejection_samples=4,
).cuda()
strat = Strategy(
lb=lb,
ub=ub,
model=m,
generator=MonotonicRejectionGenerator(
MonotonicMCLSE,
acqf_kwargs={"target": target, "objective": ProbitObjective()},
model_gen_options=model_gen_options,
),
min_asks=1,
stimuli_per_trial=1,
outcome_types=["binary"],
use_gpu_modeling=True,
)
# Fit
train_x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
train_y = torch.tensor([1.0, 1.0, 0.0])
m.fit(train_x=train_x, train_y=train_y)
self.assertEqual(m.inducing_points.shape, torch.Size([2, 2]))
self.assertAlmostEqual(m.mean_module.constant.item(), norm.ppf(0.75))
# Predict
f, var = m.predict(train_x)
self.assertEqual(f.shape, torch.Size([3]))
self.assertEqual(var.shape, torch.Size([3]))
# Gen
strat.add_data(train_x, train_y)
Xopt = strat.gen()
self.assertEqual(Xopt.shape, torch.Size([1, 2]))
# Acquisition function
acq = strat.generator._instantiate_acquisition_fn(m)
self.assertEqual(acq.deriv_constraint_points.shape, torch.Size([2, 3]))
self.assertTrue(
torch.equal(acq.deriv_constraint_points[:, -1].cpu(), 2 * torch.ones(2))
)
self.assertEqual(acq.target, 0.75)
self.assertTrue(isinstance(acq.objective, ProbitObjective))
# Update
m.update(train_x=train_x[:2, :2], train_y=train_y[:2], warmstart=True)
self.assertEqual(m.train_inputs[0].shape, torch.Size([2, 3]))

def test_classification_from_config_gpu(self):
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)

n_init = 15
n_opt = 1

config_str = f"""
[common]
parnames = [par1]
outcome_types = [binary]
stimuli_per_trial = 1
strategy_names = [init_strat, opt_strat]
[par1]
par_type = continuous
lower_bound = 0
upper_bound = 1
[init_strat]
generator = SobolGenerator
min_asks = {n_init}
[opt_strat]
generator = MonotonicRejectionGenerator
model = MonotonicRejectionGP
acqf = MonotonicMCLSE
min_asks = {n_opt}
[MonotonicRejectionGenerator]
use_gpu = True
[MonotonicRejectionGP]
num_induc = 2
num_samples = 3
num_rejection_samples = 4
monotonic_idxs = [0]
use_gpu = True
[MonotonicMCLSE]
target = 0.75
objective = ProbitObjective
"""
config = Config(config_str=config_str)
strat = SequentialStrategy.from_config(config)

for _i in range(n_init + n_opt):
next_x = strat.gen()
strat.add_data(next_x, int(np.random.rand() > next_x))
