Delete select_inducing_point utility function (#480)
Summary:
Pull Request resolved: #480

The unnecessary utility function is replaced by calling the Allocator class method directly.

This diff also removes the use of strings to select allocators.
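
For illustration, a minimal before-and-after sketch of the call-site change. The old `select_inducing_point` signature and the import path are assumptions made for this sketch; `KMeansAllocator` and `allocate_inducing_points` appear in the tests below:

    import torch

    # NOTE: import path is an assumption; adjust to wherever the allocators live.
    from aepsych.models.inducing_points import KMeansAllocator

    inputs = torch.rand(100, 2)  # 100 points in 2D

    # Before (removed; hypothetical signature): a string chose the allocator.
    # points = select_inducing_point(allocator="kmeans", inputs=inputs, num_inducing=10)

    # After: instantiate the allocator class and call its method directly.
    allocator = KMeansAllocator(dim=2)
    points = allocator.allocate_inducing_points(inputs=inputs, num_inducing=10)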

Differential Revision: D67068021
JasonKChow authored and facebook-github-bot committed Dec 14, 2024
1 parent 4c090f5 commit 890d6f6
Showing 1 changed file with 69 additions and 102 deletions.
171 changes: 69 additions & 102 deletions tests/test_points_allocators.py
@@ -16,26 +16,6 @@


 class TestInducingPointAllocators(unittest.TestCase):
-    def test_sobol_allocator_from_config(self):
-        config_str = """
-        [common]
-        parnames = [par1]
-        [par1]
-        par_type = continuous
-        lower_bound = 0.0
-        upper_bound = 1.0
-        log_scale = true
-        """
-        config = Config()
-        config.update(config_str=config_str)
-        allocator = SobolAllocator.from_config(config)
-
-        # Check if bounds are correctly loaded
-        expected_bounds = torch.tensor([[0.0], [1.0]])
-        self.assertTrue(torch.equal(allocator.bounds, expected_bounds))
-
     def test_sobol_allocator_allocate_inducing_points(self):
         bounds = torch.tensor([[0.0], [1.0]])
         allocator = SobolAllocator(bounds=bounds, dim=1)
Expand Down Expand Up @@ -92,43 +72,39 @@ def test_sobol_allocator_from_model_config(self):
)
)

def test_kmeans_allocator_from_config(self):
config_str = """
[common]
parnames = [par1]
[par1]
par_type = continuous
lower_bound = 0.0
upper_bound = 1.0
log_scale = true
[KMeansAllocator]
"""
config = Config()
config.update(config_str=config_str)
allocator = KMeansAllocator.from_config(config)

self.assertTrue(isinstance(allocator, KMeansAllocator))
self.assertTrue(allocator.dim == 1)

def test_kmeans_allocator_allocate_inducing_points(self):
inputs = torch.rand(100, 2) # 100 points in 2D
allocator = KMeansAllocator(dim=2)
# Mock data for testing
train_X = torch.randint(low=0, high=100, size=(100, 2), dtype=torch.float64)
train_Y = torch.rand(100, 1)
model = GPClassificationModel(
inducing_point_method=KMeansAllocator(dim=2),
inducing_size=10,
dim=2,
)

# Test dummy
inducing_points = allocator.allocate_inducing_points(num_inducing=10)
# Check if model has dummy points
self.assertIsNone(model.inducing_point_method.last_allocator_used)
self.assertTrue(torch.all(model.variational_strategy.inducing_points == 0))
self.assertTrue(model.variational_strategy.inducing_points.shape == (10, 2))

self.assertEqual(inducing_points.shape, (10, 2))
self.assertIsNone(allocator.last_allocator_used)
# Fit with small data leess than inducing_size
model.fit(train_X[:9], train_Y[:9])

# Test real inducing points
inducing_points = allocator.allocate_inducing_points(
inputs=inputs, num_inducing=10
)
self.assertIs(model.inducing_point_method.last_allocator_used, KMeansAllocator)
inducing_points = model.variational_strategy.inducing_points
self.assertTrue(inducing_points.shape == (9, 2))
# We made ints, so mod 1 should be 0s, so we know these were the original inputs
self.assertTrue(torch.all(inducing_points % 1 == 0))

# Then fit the model and check that the inducing points are updated
model.fit(train_X, train_Y)

self.assertEqual(inducing_points.shape, (10, 2))
self.assertIs(allocator.last_allocator_used, KMeansAllocator)
self.assertIs(model.inducing_point_method.last_allocator_used, KMeansAllocator)
inducing_points = model.variational_strategy.inducing_points
self.assertTrue(inducing_points.shape == (10, 2))
# It's highly unlikely clustered will all be integers, so check against extents too
self.assertFalse(torch.all(inducing_points % 1 == 0))
self.assertTrue(torch.all((inducing_points >= 0) & (inducing_points <= 100)))

def test_kmeans_allocator_from_model_config(self):
config_str = """
@@ -160,43 +136,40 @@ def test_kmeans_allocator_from_model_config(self):
         strat = Strategy.from_config(config, "init_strat")
         self.assertTrue(isinstance(strat.model.inducing_point_method, KMeansAllocator))

-    def test_auto_allocator_from_config(self):
-        config_str = """
-        [common]
-        parnames = [par1]
-        [par1]
-        par_type = continuous
-        lower_bound = 0.0
-        upper_bound = 1.0
-        log_scale = true
-        [KMeansAllocator]
-        """
-        config = Config()
-        config.update(config_str=config_str)
-        allocator = AutoAllocator.from_config(config)
-
-        self.assertTrue(isinstance(allocator, AutoAllocator))
-        self.assertTrue(allocator.dim == 1)
-
     def test_auto_allocator_allocate_inducing_points(self):
-        inputs = torch.rand(100, 2)  # 100 points in 2D
-        allocator = AutoAllocator(dim=2)
+        # Mock data for testing
+        train_X = torch.randint(low=0, high=100, size=(100, 2), dtype=torch.float64)
+        train_Y = torch.rand(100, 1)
+        model = GPClassificationModel(
+            inducing_point_method=AutoAllocator(dim=2),
+            inducing_size=10,
+            dim=2,
+        )

-        # Test dummy
-        inducing_points = allocator.allocate_inducing_points(num_inducing=10)
+        # Check if model has dummy points
+        self.assertIsNone(model.inducing_point_method.last_allocator_used)
+        self.assertTrue(torch.all(model.variational_strategy.inducing_points == 0))
+        self.assertTrue(model.variational_strategy.inducing_points.shape == (10, 2))

-        self.assertEqual(inducing_points.shape, (10, 2))
-        self.assertIsNone(allocator.last_allocator_used)
+        # Fit with less data than inducing_size
+        model.fit(train_X[:9], train_Y[:9])

-        # Test real inducing points
-        inducing_points = allocator.allocate_inducing_points(
-            inputs=inputs, num_inducing=10
-        )
+        # We still check for the base allocator
+        self.assertIs(model.inducing_point_method.last_allocator_used, KMeansAllocator)
+        inducing_points = model.variational_strategy.inducing_points
+        self.assertTrue(inducing_points.shape == (9, 2))
+        # We made ints, so mod 1 should be 0s; that tells us these were the original inputs
+        self.assertTrue(torch.all(inducing_points % 1 == 0))

-        self.assertEqual(inducing_points.shape, (10, 2))
-        self.assertIs(allocator.last_allocator_used, KMeansAllocator)
+        # Then fit the model and check that the inducing points are updated
+        model.fit(train_X, train_Y)
+
+        self.assertIs(model.inducing_point_method.last_allocator_used, KMeansAllocator)
+        inducing_points = model.variational_strategy.inducing_points
+        self.assertTrue(inducing_points.shape == (10, 2))
+        # It's highly unlikely the clustered points will all be integers, so check the extents too
+        self.assertFalse(torch.all(inducing_points % 1 == 0))
+        self.assertTrue(torch.all((inducing_points >= 0) & (inducing_points <= 100)))

     def test_auto_allocator_from_model_config(self):
         config_str = """
@@ -230,34 +203,28 @@

     def test_greedy_variance_reduction_allocate_inducing_points(self):
         # Mock data for testing
-        train_X = torch.rand(100, 1)
+        train_X = torch.randint(low=0, high=100, size=(100, 2), dtype=torch.float64)
         train_Y = torch.rand(100, 1)
-        lb = torch.tensor([0])
-        ub = torch.tensor([1])
-        bounds = torch.stack([lb, ub])
         model = GPClassificationModel(
-            inducing_point_method=GreedyVarianceReduction(dim=1),
+            inducing_point_method=GreedyVarianceReduction(dim=2),
             inducing_size=10,
-            dim=1,
+            dim=2,
         )

-        # Instantiate GreedyVarianceReduction allocator
-        allocator = GreedyVarianceReduction(dim=1)
-
-        # Allocate inducing points and verify output shape
-        inducing_points = allocator.allocate_inducing_points(
-            inputs=train_X,
-            covar_module=model.covar_module,
-            num_inducing=10,
-            input_batch_shape=torch.Size([]),
-        )
+        # Check if model has dummy points
+        self.assertIsNone(model.inducing_point_method.last_allocator_used)
+        self.assertTrue(torch.all(model.variational_strategy.inducing_points == 0))
+        self.assertTrue(model.variational_strategy.inducing_points.shape == (10, 2))

         # Then fit the model and check that the inducing points are updated
         model.fit(train_X, train_Y)

-        self.assertTrue(
-            torch.allclose(inducing_points, model.variational_strategy.inducing_points)
+        self.assertIs(
+            model.inducing_point_method.last_allocator_used, GreedyVarianceReduction
         )
+        inducing_points = model.variational_strategy.inducing_points
+        self.assertTrue(inducing_points.shape == (10, 2))
+        self.assertTrue(torch.all((inducing_points >= 0) & (inducing_points <= 100)))

     def test_greedy_variance_from_config(self):
         config_str = """
