derivativeGP GPU support #444

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion aepsych/generators/monotonic_rejection_generator.py
@@ -101,7 +101,7 @@ def gen(
)

# Augment bounds with deriv indicator
-        bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
+        bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1)
# Fix deriv indicator to 0 during optimization
fixed_features = {(bounds.shape[1] - 1): 0.0}
# Fix explore features to random values
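This first hunk fixes a classic device mismatch: `torch.zeros` allocates on the CPU by default, and `torch.cat` refuses to mix devices. A minimal standalone sketch of the failure mode and the fix (assumes a CUDA device is available; `bounds_` here is a stand-in for `model.bounds_`):

```python
import torch

bounds_ = torch.zeros(2, 3, device="cuda")  # stand-in for model.bounds_
try:
    # The deriv-indicator column defaults to CPU, so cat sees two devices.
    torch.cat((bounds_, torch.zeros(2, 1)), dim=1)
except RuntimeError as e:
    print(e)  # "Expected all tensors to be on the same device ..."

# The fix: allocate the new column on the model's device first.
augmented = torch.cat((bounds_, torch.zeros(2, 1).to(bounds_.device)), dim=1)
```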
12 changes: 6 additions & 6 deletions aepsych/kernels/pairwisekernel.py
@@ -16,12 +16,12 @@ class PairwiseKernel(Kernel):
"""

    def __init__(
-        self, latent_kernel: Kernel, is_partial_obs: bool=False, **kwargs
+        self, latent_kernel: Kernel, is_partial_obs: bool = False, **kwargs
    ) -> None:
        """
        Args:
            latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
            is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
        """
super(PairwiseKernel, self).__init__(**kwargs)

@@ -40,11 +40,11 @@ def forward(
x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
diag (bool): Should the Kernel compute the whole covariance matrix or just the diagonal? Defaults to False.
Returns:
            torch.Tensor (or :class:`gpytorch.lazy.LazyTensor`): A `b x n x m` or `n x m` tensor representing
                the covariance matrix between `x1` and `x2`.
The exact size depends on the kernel's evaluation mode:
* `full_covar`: `n x m` or `b x n x m`
* `diag`: `n` or `b x n`
6 changes: 3 additions & 3 deletions aepsych/kernels/rbf_partial_grad.py
@@ -31,14 +31,14 @@ def forward(
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params: Any
) -> torch.Tensor:
"""Computes the covariance matrix between x1 and x2 based on the RBF

Args:
x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
diag (bool): Should the Kernel compute the whole covariance matrix (False) or just the diagonal (True)? Defaults to False.
Returns:
torch.Tensor: A `b x n x m` or `n x m` tensor representing the covariance matrix between `x1` and `x2`.
The exact size depends on the kernel's evaluation mode:
2 changes: 1 addition & 1 deletion aepsych/means/constant_partial_grad.py
@@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
idx = input[..., -1].to(dtype=torch.long) > 0
mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :])
sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]])
-        mean = torch.zeros(sz)
+        mean = torch.zeros(sz).to(input)
mean[~idx] = mean_fit
return mean
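The `.to(input)` overload used here is worth noting: given a tensor argument, `.to` matches both the device and the dtype of that tensor, so the zero-filled mean always agrees with the input. A small sketch:

```python
import torch

inp = torch.randn(4, 3, dtype=torch.float64)  # could equally live on a GPU
mean = torch.zeros(4).to(inp)                 # match device *and* dtype
assert mean.dtype == inp.dtype and mean.device == inp.device
```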
19 changes: 14 additions & 5 deletions aepsych/models/base.py
@@ -116,7 +116,7 @@ class AEPsychMixin(GPyTorchModel):

extremum_solver = "Nelder-Mead"
outcome_types: List[str] = []
-    train_inputs: Optional[Tuple[torch.Tensor]]
+    train_inputs: Optional[Tuple[torch.Tensor, ...]]
train_targets: Optional[torch.Tensor]

@property
@@ -393,7 +393,7 @@ def p_below_threshold(


class AEPsychModelDeviceMixin(AEPsychMixin):
-    _train_inputs: Optional[Tuple[torch.Tensor]]
+    _train_inputs: Optional[Tuple[torch.Tensor, ...]]
_train_targets: Optional[torch.Tensor]

def set_train_data(self, inputs=None, targets=None, strict=False):
@@ -415,16 +415,25 @@ def set_train_data(self, inputs=None, targets=None, strict=False):
def device(self) -> torch.device:
# We assume all models have some parameters and all models will only use one device
# notice that this has no setting, don't let users set device, use .to().
-        return next(self.parameters()).device
+        try:
+            return next(self.parameters()).device
+        except AttributeError:
+            # Fallback for cases where we need the device before params exist
+            return torch.device("cpu")

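For context, a standalone sketch of the device-derivation pattern this property implements (class name is illustrative, not from the PR). One edge worth knowing: a fully initialized module with zero parameters raises `StopIteration` rather than `AttributeError`, so a defensive variant catches both:

```python
import torch
import torch.nn as nn

class DeviceAware(nn.Module):
    @property
    def device(self) -> torch.device:
        # Derive the device from the first parameter; fall back to CPU
        # when no parameters are reachable yet.
        try:
            return next(self.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")
```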
@property
-    def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
+    def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
if self._train_inputs is None:
return None

        # make sure the tensors are on the right device; rebuild the cached tuple
_train_inputs = []
for input in self._train_inputs:
-            input.to(self.device)
+            _train_inputs.append(input.to(self.device))
+
+        _tuple_inputs: Tuple[torch.Tensor, ...] = tuple(_train_inputs)
+        self._train_inputs = _tuple_inputs

return self._train_inputs

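The `train_inputs` change above fixes a subtle no-op: `Tensor.to()` is not in-place, so the old loop moved each tensor and then discarded the result. A two-line demonstration:

```python
import torch

t = torch.ones(3)
t.to(torch.float64)      # result discarded; t is unchanged
assert t.dtype == torch.float32
t = t.to(torch.float64)  # the returned tensor must be reassigned (or cached)
assert t.dtype == torch.float64
```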
6 changes: 5 additions & 1 deletion aepsych/models/derivative_gp.py
@@ -13,6 +13,7 @@
import torch
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
+from aepsych.models.base import AEPsychModelDeviceMixin
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
@@ -22,7 +23,9 @@
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


-class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel):
+class MixedDerivativeVariationalGP(
+    gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel
+):
"""A variational GP with mixed derivative observations.

For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
@@ -99,6 +102,7 @@ def __init__(
self._num_outputs = 1
self.train_inputs = (train_x,)
self.train_targets = train_y
+        self.to(self.device)  # Needed to prep for the forward call below
self(train_x) # Necessary for CholeskyVariationalDistribution

def forward(self, x: torch.Tensor) -> MultivariateNormal:
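With `AEPsychModelDeviceMixin` in the bases, the model now exposes a `.device` property and device-aware `train_inputs`. A usage sketch mirroring the new GPU test at the bottom of this diff (assumes a CUDA device is available):

```python
import torch
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP

train_x = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]]).cuda()
train_y = torch.tensor([[1.0], [2.0]]).cuda()
m = MixedDerivativeVariationalGP(
    train_x=train_x, train_y=train_y, inducing_points=train_x
).cuda()
print(m.device)  # expected: cuda:0, derived from the model's parameters
```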
16 changes: 10 additions & 6 deletions aepsych/models/monotonic_projection_gp.py
@@ -136,14 +136,17 @@ def posterior(
# using numpy because torch doesn't support vectorized linspace,
# pytorch/issues/61292
grid: Union[np.ndarray, torch.Tensor] = np.linspace(
-            self.lb[dim],
-            X[:, dim].numpy(),
+            self.lb[dim].cpu().numpy(),
+            X[:, dim].cpu().numpy(),
s + 1,
) # (s+1 x n)
grid = torch.tensor(grid[:-1, :], dtype=X.dtype) # Drop x; (s x n)
X_aug[(1 + i * s) : (1 + (i + 1) * s), :, dim] = grid
# X_aug[0, :, :] is X, and then subsequent indices are points in the grids
# Predict marginal distributions on X_aug

+        X = X.to(self.device)
+        X_aug = X_aug.to(self.device)
with torch.no_grad():
post_aug = super().posterior(X=X_aug)
mu_aug = post_aug.mean.squeeze() # (m*s+1 x n)
@@ -158,12 +161,13 @@
        # Adjust the whole covariance matrix to accommodate the projected marginals
with torch.no_grad():
post = super().posterior(X=X)
-            R = cov2corr(post.distribution.covariance_matrix.squeeze().numpy())
-            S_proj = torch.tensor(corr2cov(R, sigma_proj.numpy()), dtype=X.dtype)
+            R = cov2corr(post.distribution.covariance_matrix.squeeze().cpu().numpy())
+            S_proj = torch.tensor(corr2cov(R, sigma_proj.cpu().numpy()), dtype=X.dtype)
mvn_proj = gpytorch.distributions.MultivariateNormal(
-            mu_proj.unsqueeze(0),
-            S_proj.unsqueeze(0),
+            mu_proj.unsqueeze(0).to(self.device),
+            S_proj.unsqueeze(0).to(self.device),
)

return GPyTorchPosterior(mvn_proj)

def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
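The added `.cpu()` calls exist because NumPy cannot ingest CUDA tensors: `Tensor.numpy()` raises a `TypeError` for GPU data, so anything routed through `cov2corr`/`corr2cov` has to round-trip through host memory. A minimal sketch:

```python
import torch

t = torch.arange(4.0)
if torch.cuda.is_available():
    try:
        t.cuda().numpy()  # raises: numpy arrays share host memory only
    except TypeError as e:
        print(e)  # "can't convert cuda:0 device type tensor to numpy ..."
host = t.cpu().numpy()  # always safe: copy back to host first
```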
25 changes: 14 additions & 11 deletions aepsych/models/monotonic_rejection_gp.py
@@ -18,7 +18,7 @@
from aepsych.factory.monotonic import monotonic_mean_covar_factory
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
from torch import Tensor


-class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
+class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
"""A monotonic GP using rejection sampling.
This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
@@ -83,15 +83,15 @@ def __init__(
objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
"""
-        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
+        lb, ub, self.dim = _process_bounds(lb, ub, dim)
if likelihood is None:
likelihood = BernoulliLikelihood()

self.inducing_size = num_induc
self.inducing_point_method = inducing_point_method
inducing_points = select_inducing_points(
inducing_size=self.inducing_size,
-            bounds=self.bounds,
+            bounds=torch.stack((lb, ub)),
method="sobol",
)

@@ -134,7 +146,7 @@ def __init__(

super().__init__(variational_strategy)

-        self.bounds_ = torch.stack([self.lb, self.ub])
+        self.register_buffer("lb", lb)
+        self.register_buffer("ub", ub)
+        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
self.mean_module = mean_module
self.covar_module = covar_module
self.likelihood = likelihood
Expand All @@ -144,7 +146,7 @@ def __init__(
self.num_samples = num_samples
self.num_rejection_samples = num_rejection_samples
self.fixed_prior_mean = fixed_prior_mean
-        self.inducing_points = inducing_points
+        self.register_buffer("inducing_points", inducing_points)

def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
"""Fit the model
@@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
X=self.train_inputs[0],
bounds=self.bounds,
method=self.inducing_point_method,
-        )
+        ).to(self.device)
self._set_model(train_x, train_y)

def _set_model(
@@ -284,13 +286,14 @@ def predict_probability(
return self.predict(x, probability_space=True)

def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
+        x = x.to(self.device)
return torch.cat(
-            (x, indx * torch.ones(x.shape[0], 1)),
+            (x, indx * torch.ones(x.shape[0], 1).to(self.device)),
dim=1,
)

def _get_deriv_constraint_points(self) -> Tensor:
-        deriv_cp = torch.tensor([])
+        deriv_cp = torch.tensor([]).to(self.device)
for i in self.monotonic_idxs:
induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
@@ -299,8 +302,8 @@
@classmethod
def from_config(cls, config: Config) -> MonotonicRejectionGP:
classname = cls.__name__
-        num_induc = config.gettensor(classname, "num_induc", fallback=25)
-        num_samples = config.gettensor(classname, "num_samples", fallback=250)
+        num_induc = config.getint(classname, "num_induc", fallback=25)
+        num_samples = config.getint(classname, "num_samples", fallback=250)
num_rejection_samples = config.getint(
classname, "num_rejection_samples", fallback=5000
)
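Replacing plain tensor attributes with `register_buffer` is what lets `lb`, `ub`, `bounds_`, and `inducing_points` follow the module across devices: buffers travel with `.to()`/`.cuda()` and appear in `state_dict`, while bare attributes are silently left behind. A minimal sketch (stock PyTorch, names illustrative):

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("bounds_", torch.stack([torch.zeros(2), torch.ones(2)]))
        self.plain = torch.zeros(2)  # not registered; .to() will skip it

m = WithBuffer()
if torch.cuda.is_available():
    m = m.cuda()
    print(m.bounds_.device)  # cuda:0 -- moved with the module
    print(m.plain.device)    # cpu -- stranded, the bug this PR avoids
```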
17 changes: 3 additions & 14 deletions aepsych/strategy.py
@@ -71,13 +71,6 @@ class Strategy(object):

_n_eval_points: int = 1000

-    no_gpu_acqfs = (
-        MonotonicMCAcquisition,
-        MonotonicBernoulliMCMutualInformation,
-        MonotonicMCPosteriorVariance,
-        MonotonicMCLSE,
-    )

def __init__(
self,
generator: Union[AEPsychGenerator, ParameterTransformedGenerator],
@@ -182,13 +175,7 @@ def __init__(
)
self.generator_device = torch.device("cpu")
else:
if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs:
warnings.warn(
f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.",
UserWarning,
)
self.generator_device = torch.device("cpu")
elif not torch.cuda.is_available():
if not torch.cuda.is_available():
warnings.warn(
f"GPU requested for generator {type(generator).__name__}, but no GPU found! Using CPU instead.",
UserWarning,
@@ -283,9 +270,11 @@ def normalize_inputs(
x = x[None, :]

if self.x is not None:
+            x = x.to(self.x)
x = torch.cat((self.x, x), dim=0)

if self.y is not None:
+            y = y.to(self.y)
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
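The two `.to()` casts in `normalize_inputs` put newly collected points on the same device and dtype as the data already stored on the strategy before `torch.cat` runs; mixing devices in `cat` raises outright, and mixing dtypes can silently promote the accumulated tensors. A short sketch of the pattern:

```python
import torch

stored = torch.zeros(2, 3, dtype=torch.float64)  # stand-in for self.x
incoming = torch.ones(1, 3)                      # float32, maybe another device
combined = torch.cat((stored, incoming.to(stored)), dim=0)
assert combined.dtype == stored.dtype and combined.device == stored.device
```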
6 changes: 6 additions & 0 deletions tests_gpu/acquisition/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
51 changes: 51 additions & 0 deletions tests_gpu/acquisition/test_monotonic.py
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.acquisition.objective import ProbitObjective
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
from botorch.acquisition.objective import IdentityMCObjective
from botorch.utils.testing import BotorchTestCase


class TestMonotonicAcq(BotorchTestCase):
def test_monotonic_acq_gpu(self):
# Init
train_X_aug = torch.tensor(
[[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
).cuda()
deriv_constraint_points = torch.tensor(
[[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
).cuda()
train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda()

m = MixedDerivativeVariationalGP(
train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug
).cuda()
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
)
self.assertTrue(isinstance(acq.objective, IdentityMCObjective))
acq = MonotonicMCLSE(
model=m,
deriv_constraint_points=deriv_constraint_points,
num_samples=5,
num_rejection_samples=8,
target=1.9,
objective=ProbitObjective(),
).cuda()
# forward
acq(train_X_aug)
Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0)
posterior = m.posterior(Xfull)
samples = acq.sampler(posterior)
self.assertEqual(samples.shape, torch.Size([5, 6, 1]))