Skip to content

Commit

Permalink
fix query constraints to use dims to make dummies (#488)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #488

Fix to make queries respect transforms in constraints was not creating bounds correctly. This uses the new dims in the models to solve that.

Reviewed By: crasanders

Differential Revision: D67497154

fbshipit-source-id: b5965a764f04e7476b8945efb9ad259be7ad08d1
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Dec 20, 2024
1 parent 669738f commit 115dddb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
2 changes: 1 addition & 1 deletion aepsych/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_extremum(
# Transform locked dims
tmp = {}
for key, value in locked_dims.items():
tensor = torch.zeros(len(bounds))
tensor = torch.zeros(model.dim)
tensor[key] = value
tensor = model.transforms.transform(tensor)
tmp[key] = tensor[key].item()
Expand Down
36 changes: 29 additions & 7 deletions tests/server/message_handlers/test_query_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_strat_query(self):
[common]
stimuli_per_trial=2
outcome_types=[binary]
parnames = [par1, par2]
parnames = [par1, par2, par3]
strategy_names = [opt_strat]
acqf = PairwiseMCPosteriorVariance
Expand All @@ -33,6 +33,11 @@ def test_strat_query(self):
lower_bound = -1
upper_bound = 1
[par3]
par_type = continuous
lower_bound = 10
upper_bound = 100
[opt_strat]
min_asks = 1
model = PairwiseProbitModel
Expand All @@ -57,9 +62,26 @@ def test_strat_query(self):
tell_request = {
"type": "tell",
"message": [
{"config": {"par1": [0.5, 0.5], "par2": [-0.5, -0.5]}, "outcome": 1},
{"config": {"par1": [0.0, 0.75], "par2": [0.0, -1]}, "outcome": 0},
{"config": {"par1": [1, -1], "par2": [0, 0.0]}, "outcome": 0},
{
"config": {
"par1": [0.5, 0.5],
"par2": [-0.5, -0.5],
"par3": [40, 50],
},
"outcome": 1,
},
{
"config": {
"par1": [0.0, 0.75],
"par2": [0.0, -1],
"par3": [11, 99],
},
"outcome": 0,
},
{
"config": {"par1": [1, -1], "par2": [0, 0.0], "par3": [40, 12]},
"outcome": 0,
},
],
}

Expand All @@ -84,7 +106,7 @@ def test_strat_query(self):
"type": "query",
"message": {
"query_type": "prediction",
"x": {"par1": [0.0], "par2": [-0.5]},
"x": {"par1": [0.0], "par2": [-0.5], "par3": [45]},
},
}
query_inv_req = {
Expand All @@ -105,7 +127,7 @@ def test_strat_query(self):
}
query_inv_const = {
"type": "query",
"message": {"query_type": "inverse", "y": 5.0, "constraints": {1: 0}},
"message": {"query_type": "inverse", "y": 5.0, "constraints": {2: 20}},
}

response = self.s.handle_request(query_min_req)
Expand All @@ -131,7 +153,7 @@ def test_strat_query(self):
self.assertTrue(response["x"]["par1"][0] == 0.25)

response = self.s.handle_request(query_inv_const)
self.assertTrue(response["x"]["par2"][0] == 0)
self.assertTrue(response["x"]["par3"][0] == 20)

def test_grad_model_smoketest(self):
# Some models return values with gradients that need to be handled
Expand Down

0 comments on commit 115dddb

Please sign in to comment.