From 0dd53487e3607c08a63ff8da70845a44f1fd8eb2 Mon Sep 17 00:00:00 2001
From: Jason Chow
Date: Mon, 18 Nov 2024 10:40:05 -0800
Subject: [PATCH] derivativeGP gpu support (#444)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/444

Add GPU support for the derivative GP. This model isn't like a normal model
that can show up in a live experiment with a config, but it should still
work on GPU. Most of that was straightforward, but it required overriding
GPyTorch's underlying handling of train_inputs, which in turn required some
mypy workarounds.

Differential Revision: D65515631
---
 aepsych/kernels/pairwisekernel.py      | 12 ++++----
 aepsych/kernels/rbf_partial_grad.py    |  6 ++--
 aepsych/models/base.py                 | 12 +++++---
 aepsych/models/derivative_gp.py        |  4 ++-
 tests_gpu/models/test_derivative_gp.py | 39 ++++++++++++++++++++++++++
 tests_gpu/test_strategy.py             |  1 -
 6 files changed, 59 insertions(+), 15 deletions(-)
 create mode 100644 tests_gpu/models/test_derivative_gp.py

diff --git a/aepsych/kernels/pairwisekernel.py b/aepsych/kernels/pairwisekernel.py
index c30bbf512..f85a6bc3d 100644
--- a/aepsych/kernels/pairwisekernel.py
+++ b/aepsych/kernels/pairwisekernel.py
@@ -16,12 +16,12 @@ class PairwiseKernel(Kernel):
     """
 
     def __init__(
-        self, latent_kernel: Kernel, is_partial_obs: bool=False, **kwargs
+        self, latent_kernel: Kernel, is_partial_obs: bool = False, **kwargs
     ) -> None:
         """
-        Args:
-            latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
-            is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
+        Args:
+            latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
+            is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
         """
         super(PairwiseKernel, self).__init__(**kwargs)
 
@@ -40,11 +40,11 @@ def forward(
             x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
             x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
             diag (bool): Should the Kernel compute the whole covariance matrix or just the diagonal? Defaults to False.
-
+
         Returns:
             torch.Tensor (or :class:`gpytorch.lazy.LazyTensor`) : A `b x n x m` or `n x m` tensor representing
-                the covariance matrix between `x1` and `x2`.
+                the covariance matrix between `x1` and `x2`.
             The exact size depends on the kernel's evaluation mode:
             * `full_covar`: `n x m` or `b x n x m`
             * `diag`: `n` or `b x n`
 
diff --git a/aepsych/kernels/rbf_partial_grad.py b/aepsych/kernels/rbf_partial_grad.py
index 09a204130..43c574d75 100644
--- a/aepsych/kernels/rbf_partial_grad.py
+++ b/aepsych/kernels/rbf_partial_grad.py
@@ -31,14 +31,14 @@ def forward(
         self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params: Any
     ) -> torch.Tensor:
         """Computes the covariance matrix between x1 and x2 based on the RBF
-
+
         Args:
             x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
             x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
             diag (bool): Should the Kernel compute the whole covariance matrix (False) or just the diagonal (True)? Defaults to False.
-
-
+
+
         Returns:
             torch.Tensor: A `b x n x m` or `n x m` tensor representing the covariance matrix between `x1` and `x2`.
             The exact size depends on the kernel's evaluation mode:
 
diff --git a/aepsych/models/base.py b/aepsych/models/base.py
index 7fbedf9b1..d5d564253 100644
--- a/aepsych/models/base.py
+++ b/aepsych/models/base.py
@@ -116,7 +116,7 @@ class AEPsychMixin(GPyTorchModel):
 
     extremum_solver = "Nelder-Mead"
     outcome_types: List[str] = []
-    train_inputs: Optional[Tuple[torch.Tensor]]
+    train_inputs: Optional[Tuple[torch.Tensor, ...]]
     train_targets: Optional[torch.Tensor]
 
     @property
@@ -393,7 +393,7 @@ def p_below_threshold(
 
 
 class AEPsychModelDeviceMixin(AEPsychMixin):
-    _train_inputs: Optional[Tuple[torch.Tensor]]
+    _train_inputs: Optional[Tuple[torch.Tensor, ...]]
     _train_targets: Optional[torch.Tensor]
 
     def set_train_data(self, inputs=None, targets=None, strict=False):
@@ -423,13 +423,17 @@ def device(self) -> torch.device:
         return torch.device("cpu")
 
     @property
-    def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
+    def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
         if self._train_inputs is None:
             return None
 
         # makes sure the tensors are on the right device, move in place
+        _train_inputs = []
         for input in self._train_inputs:
-            input.to(self.device)
+            _train_inputs.append(input.to(self.device))
+
+        _tuple_inputs: Tuple[torch.Tensor, ...] = tuple(_train_inputs)
+        self._train_inputs = _tuple_inputs
 
         return self._train_inputs
 
diff --git a/aepsych/models/derivative_gp.py b/aepsych/models/derivative_gp.py
index 3e4250e87..eea6768da 100644
--- a/aepsych/models/derivative_gp.py
+++ b/aepsych/models/derivative_gp.py
@@ -23,7 +23,9 @@
 from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
 
 
-class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel):
+class MixedDerivativeVariationalGP(
+    gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel
+):
     """A variational GP with mixed derivative observations. For more on GPs with
     derivative observations, see e.g. Riihimaki & Vehtari 2010.
 
diff --git a/tests_gpu/models/test_derivative_gp.py b/tests_gpu/models/test_derivative_gp.py
new file mode 100644
index 000000000..200ef62eb
--- /dev/null
+++ b/tests_gpu/models/test_derivative_gp.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from aepsych import Config, SequentialStrategy
+from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
+from botorch.fit import fit_gpytorch_mll
+from botorch.utils.testing import BotorchTestCase
+from gpytorch.likelihoods import BernoulliLikelihood
+from gpytorch.mlls.variational_elbo import VariationalELBO
+
+
+class TestDerivativeGP(BotorchTestCase):
+    def test_MixedDerivativeVariationalGP_gpu(self):
+        train_x = torch.cat(
+            (torch.tensor([1.0, 2.0, 3.0, 4.0]).unsqueeze(1), torch.zeros(4, 1)), dim=1
+        )
+        train_y = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        m = MixedDerivativeVariationalGP(
+            train_x=train_x,
+            train_y=train_y,
+            inducing_points=train_x,
+            fixed_prior_mean=0.5,
+        ).cuda()
+
+        self.assertEqual(m.mean_module.constant.item(), 0.5)
+        self.assertEqual(
+            m.covar_module.base_kernel.raw_lengthscale.shape, torch.Size([1, 1])
+        )
+        mll = VariationalELBO(
+            likelihood=BernoulliLikelihood(), model=m, num_data=train_y.numel()
+        ).cuda()
+        mll = fit_gpytorch_mll(mll)
+        test_x = torch.tensor([[1.0, 0], [3.0, 1.0]]).cuda()
+        m(test_x)
diff --git a/tests_gpu/test_strategy.py b/tests_gpu/test_strategy.py
index 8d1107d56..ec3dc91ca 100644
--- a/tests_gpu/test_strategy.py
+++ b/tests_gpu/test_strategy.py
@@ -27,6 +27,5 @@ def test_gpu_no_model_generator_warn(self):
         )
 
 
-
 if __name__ == "__main__":
     unittest.main()
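
The interesting part of the patch is the train_inputs property override in
AEPsychModelDeviceMixin: GPyTorch stores training inputs as a tuple of
tensors, and Tensor.to() is out-of-place, so the property has to rebuild the
tuple and store it back before returning it. A minimal standalone sketch of
that pattern (the class name and the device logic below are illustrative
stand-ins, not the AEPsych API):

    from typing import Optional, Tuple

    import torch

    class DeviceMixinSketch:
        """Hypothetical stand-in for AEPsychModelDeviceMixin."""

        _train_inputs: Optional[Tuple[torch.Tensor, ...]] = None

        @property
        def device(self) -> torch.device:
            # Placeholder; the real mixin derives the device from the model.
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

        @property
        def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
            if self._train_inputs is None:
                return None
            # Tensor.to() returns a new tensor rather than moving the old one,
            # so rebuild the tuple and store it back before returning it.
            self._train_inputs = tuple(
                t.to(self.device) for t in self._train_inputs
            )
            return self._train_inputs

        @train_inputs.setter
        def train_inputs(self, inputs: Optional[Tuple[torch.Tensor, ...]]) -> None:
            self._train_inputs = inputs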
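
For context on why the base.py hunk rebuilds the tuple at all: the old loop
called input.to(self.device) and discarded the result, which moves nothing,
because Tensor.to() returns a new tensor instead of modifying its receiver.
A short illustration (assumes a CUDA device is available):

    import torch

    x = torch.zeros(3)             # created on CPU
    x.to("cuda")                   # returns a new CUDA tensor; x is unchanged
    assert x.device.type == "cpu"
    x = x.to("cuda")               # rebinding is required to actually move it
    assert x.device.type == "cuda"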