diff --git a/aepsych/models/utils.py b/aepsych/models/utils.py index 8677221c8..d6eb6ac1d 100644 --- a/aepsych/models/utils.py +++ b/aepsych/models/utils.py @@ -123,7 +123,7 @@ def get_extremum( # Transform locked dims tmp = {} for key, value in locked_dims.items(): - tensor = torch.zeros(len(bounds)) + tensor = torch.zeros(model.dim) tensor[key] = value tensor = model.transforms.transform(tensor) tmp[key] = tensor[key].item() diff --git a/tests/server/message_handlers/test_query_handlers.py b/tests/server/message_handlers/test_query_handlers.py index 038731181..d03c52472 100644 --- a/tests/server/message_handlers/test_query_handlers.py +++ b/tests/server/message_handlers/test_query_handlers.py @@ -19,7 +19,7 @@ def test_strat_query(self): [common] stimuli_per_trial=2 outcome_types=[binary] - parnames = [par1, par2] + parnames = [par1, par2, par3] strategy_names = [opt_strat] acqf = PairwiseMCPosteriorVariance @@ -33,6 +33,11 @@ def test_strat_query(self): lower_bound = -1 upper_bound = 1 + [par3] + par_type = continuous + lower_bound = 10 + upper_bound = 100 + [opt_strat] min_asks = 1 model = PairwiseProbitModel @@ -57,9 +62,26 @@ def test_strat_query(self): tell_request = { "type": "tell", "message": [ - {"config": {"par1": [0.5, 0.5], "par2": [-0.5, -0.5]}, "outcome": 1}, - {"config": {"par1": [0.0, 0.75], "par2": [0.0, -1]}, "outcome": 0}, - {"config": {"par1": [1, -1], "par2": [0, 0.0]}, "outcome": 0}, + { + "config": { + "par1": [0.5, 0.5], + "par2": [-0.5, -0.5], + "par3": [40, 50], + }, + "outcome": 1, + }, + { + "config": { + "par1": [0.0, 0.75], + "par2": [0.0, -1], + "par3": [11, 99], + }, + "outcome": 0, + }, + { + "config": {"par1": [1, -1], "par2": [0, 0.0], "par3": [40, 12]}, + "outcome": 0, + }, ], } @@ -84,7 +106,7 @@ def test_strat_query(self): "type": "query", "message": { "query_type": "prediction", - "x": {"par1": [0.0], "par2": [-0.5]}, + "x": {"par1": [0.0], "par2": [-0.5], "par3": [45]}, }, } query_inv_req = { @@ -105,7 +127,7 @@ def test_strat_query(self): } query_inv_const = { "type": "query", - "message": {"query_type": "inverse", "y": 5.0, "constraints": {1: 0}}, + "message": {"query_type": "inverse", "y": 5.0, "constraints": {2: 20}}, } response = self.s.handle_request(query_min_req) @@ -131,7 +153,7 @@ def test_strat_query(self): self.assertTrue(response["x"]["par1"][0] == 0.25) response = self.s.handle_request(query_inv_const) - self.assertTrue(response["x"]["par2"][0] == 0) + self.assertTrue(response["x"]["par3"][0] == 20) def test_grad_model_smoketest(self): # Some models return values with gradients that need to be handled