Ensure that inducing point allocators can handle cases where the inputs don't match their dim (#484)

Summary:
Pull Request resolved: #484

It's not guaranteed that the inputs will match the dimensionality of the inducing point allocator. Allocators that act directly on the inputs need to handle this case, which usually arises when the inputs are augmented with extra indices (e.g., in MonotonicRejectionGP).
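For illustration, here is a minimal sketch of how such a mismatch can arise (shapes and values are hypothetical, mirroring the derivative-index augmentation that MonotonicRejectionGP applies):

import torch

dim = 1  # the dimensionality the allocator was constructed with
inputs = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, dim)

# Augment each row with an extra index column (e.g., a derivative index of 0);
# the inputs now have dim + 1 columns and no longer match the allocator's dim.
inputs_aug = torch.hstack([inputs, torch.zeros(3, 1)])
assert inputs_aug.shape == (3, dim + 1)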

The KMeans allocator needs to slice the extra dimensions off the inputs before clustering.

The GreedyVarianceReduction allocator needs to keep the extra dimensions so it can evaluate the kernel, then slice them off the results so the inducing points come out the right shape.
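A sketch of the intended behavior after this fix, continuing the snippet above; the GreedyVarianceReduction case additionally needs a covar_module to evaluate, as exercised in the new test further down:

from aepsych.models.inducing_points import KMeansAllocator

# KMeans acts directly on the inputs, so the extra column is sliced off before
# clustering and the resulting inducing points have the allocator's dim.
allocator = KMeansAllocator(dim=1)
points = allocator.allocate_inducing_points(inputs=inputs_aug, num_inducing=2)
assert points.shape == (2, 1)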

Reviewed By: crasanders

Differential Revision: D67158938
JasonKChow authored and facebook-github-bot committed Dec 14, 2024
1 parent 4cf6735 commit 25dbfb3
Showing 5 changed files with 77 additions and 3 deletions.
10 changes: 9 additions & 1 deletion aepsych/models/inducing_points/greedy_variance_reduction.py
@@ -31,10 +31,18 @@ def allocate_inducing_points(
             return self._allocate_dummy_points(num_inducing=num_inducing)
         else:
             self.last_allocator_used = self.__class__
-            return BaseGreedyVarianceReduction.allocate_inducing_points(
+
+            points = BaseGreedyVarianceReduction.allocate_inducing_points(
                 self,
                 inputs=inputs,
                 covar_module=covar_module,
                 num_inducing=num_inducing,
                 input_batch_shape=input_batch_shape,
             )
+
+            if points.shape[1] != self.dim:
+                # We assume that if the shape doesn't match the dim, the points
+                # were augmented with extra dimensions appended to the end
+                points = points[:, : self.dim, ...]
+
+            return points
4 changes: 4 additions & 0 deletions aepsych/models/inducing_points/kmeans.py
@@ -30,6 +30,10 @@ def allocate_inducing_points(
         if inputs is None:  # Dummy points
             return self._allocate_dummy_points(num_inducing=num_inducing)
 
+        if inputs.shape[1] != self.dim:
+            # The inputs were augmented somehow; assume the extra dims were appended at the end
+            inputs = inputs[:, : self.dim, ...]
+
         self.last_allocator_used = self.__class__
 
         # Ensure inputs are unique to avoid duplication issues with k-means++
2 changes: 1 addition & 1 deletion aepsych/models/monotonic_rejection_gp.py
@@ -172,7 +172,7 @@ def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None:
             allocator=self.inducing_point_method,
             inducing_size=self.inducing_size,
             covar_module=self.covar_module,
-            X=self.train_inputs[0],
+            X=self._augment_with_deriv_index(self.train_inputs[0], 0),
         )
         self._set_model(train_x, train_y)
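Passing the augmented X here matters because the greedy variance allocator evaluates the model's covar_module (an RBFKernelPartialObsGrad), which expects the derivative-index column. A minimal sketch of the shape relationship, with hypothetical tensors standing in for self.train_inputs[0]:

import torch

train_x = torch.rand(100, 1)  # (n, dim)
deriv_idx = torch.zeros(100, 1)  # derivative-index column, all zeros

# What _augment_with_deriv_index effectively produces: (n, dim + 1) inputs that
# the kernel, and now the allocator, both operate on.
x_aug = torch.hstack([train_x, deriv_idx])
assert x_aug.shape == (100, 2)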

2 changes: 1 addition & 1 deletion tests/models/test_semi_p.py
@@ -279,7 +279,7 @@ def test_reset_variational_strategy(self, model_constructor):
         model = model_constructor(
             stim_dim=stim_dim,
             floor=0,
-            inducing_point_method=AutoAllocator(dim=1),
+            inducing_point_method=AutoAllocator(dim=2),
         )
         link = FloorLogitObjective(floor=0)
         y = torch.bernoulli(link(self.f))
62 changes: 62 additions & 0 deletions tests/test_points_allocators.py
@@ -1,7 +1,10 @@
 import unittest
 
+import gpytorch
 import numpy as np
 import torch
 from aepsych.config import Config
+from aepsych.kernels import RBFKernelPartialObsGrad
+from aepsych.models.gp_classification import GPClassificationModel
 from aepsych.models.inducing_points import (
     AutoAllocator,
@@ -136,6 +139,38 @@ def test_kmeans_allocator_from_model_config(self):
         strat = Strategy.from_config(config, "init_strat")
         self.assertTrue(isinstance(strat.model.inducing_point_method, KMeansAllocator))
 
+    def test_kmeans_shape_handling(self):
+        allocator = KMeansAllocator(dim=1)
+
+        inputs = torch.tensor([[1], [2], [3]])
+
+        inputs_aug = torch.hstack([inputs, torch.zeros(size=[3, 1])])
+
+        points = allocator.allocate_inducing_points(inputs=inputs_aug, num_inducing=2)
+        self.assertTrue(points.shape == (2, 1))
+
+        points = allocator.allocate_inducing_points(inputs=inputs_aug, num_inducing=100)
+        self.assertTrue(torch.equal(points, inputs))
+
+        [KMeansAllocator]
+        """
+        config = Config()
+        config.update(config_str=config_str)
+        allocator = AutoAllocator.from_config(config)
+        self.assertTrue(isinstance(allocator, AutoAllocator))
+        self.assertTrue(allocator.dim == 1)
 
     def test_auto_allocator_allocate_inducing_points(self):
         # Mock data for testing
         train_X = torch.randint(low=0, high=100, size=(100, 2), dtype=torch.float64)
         train_Y = torch.rand(100, 1)
         model = GPClassificationModel(
             inducing_point_method=AutoAllocator(dim=2),
             inducing_size=10,
             dim=2,
         )
@@ -258,6 +293,33 @@ def test_greedy_variance_from_config(self):
             isinstance(strat.model.inducing_point_method, GreedyVarianceReduction)
         )
 
+    def test_greedy_variance_shape_handling(self):
+        allocator = GreedyVarianceReduction(dim=1)
+
+        inputs = torch.tensor([[1], [2], [3]])
+        inputs_aug = torch.hstack([inputs, torch.zeros(size=[3, 1])])
+
+        ls_prior = gpytorch.priors.GammaPrior(
+            concentration=4.6, rate=1.0, transform=lambda x: 1 / x
+        )
+        ls_prior_mode = ls_prior.rate / (ls_prior.concentration + 1)
+        ls_constraint = gpytorch.constraints.GreaterThan(
+            lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
+        )
+        covar_module = gpytorch.kernels.ScaleKernel(
+            RBFKernelPartialObsGrad(
+                lengthscale_prior=ls_prior,
+                lengthscale_constraint=ls_constraint,
+                ard_num_dims=1,
+            ),
+            outputscale_prior=gpytorch.priors.SmoothedBoxPrior(a=1, b=4),
+        )
+
+        points = allocator.allocate_inducing_points(
+            inputs=inputs_aug, covar_module=covar_module, num_inducing=2
+        )
+        self.assertTrue(points.shape[1] == 1)
+
     def test_fixed_allocator_allocate_inducing_points(self):
         config_str = """
         [common]
