Skip to content

Commit

Permalink
Completely remove auxiliary methods in models and separate them to th…
Browse files Browse the repository at this point in the history
…eir own functions (facebookresearch#481)

Summary:
Pull Request resolved: facebookresearch#481

Methods like get_min, inv_query, and dim_grid have been completely removed from models and made into utility functions. Strategy used to call these now call the utility functions. The functions are better encapsulated such that they work exactly the same way whether they're being called separately or from within a strategy.

Reviewed By: crasanders

Differential Revision: D67118446
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Dec 14, 2024
1 parent e4be2e6 commit 728269c
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 263 deletions.
16 changes: 0 additions & 16 deletions aepsych/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import gpytorch
import torch
from aepsych.utils import dim_grid
from aepsych.utils_logging import getLogger
from botorch.fit import fit_gpytorch_mll, fit_gpytorch_mll_scipy
from botorch.models.gpytorch import GPyTorchModel
Expand Down Expand Up @@ -114,21 +113,6 @@ class AEPsychMixin(GPyTorchModel):
train_inputs: Optional[Tuple[torch.Tensor]]
train_targets: Optional[torch.Tensor]

# Only used for PairwiseProbitModel, as it is the only one that still uses lower and upper bounds
# TODO remove this method and move the logic to PairwiseProbitModel or find a way to update PairwiseProbitModel to work with lb and ub.
def dim_grid(
self: ModelProtocol,
gridsize: int = 30,
slice_dims: Optional[Mapping[int, float]] = None,
) -> torch.Tensor:
"""Generate a grid based on lower, upper, and dim.
Args:
gridsize (int): Number of points in each dimension. Defaults to 30.
slice_dims (Mapping[int, float], optional): Dimensions to fix at a certain value. Defaults to None.
"""
return dim_grid(self.lb, self.ub, gridsize, slice_dims)

def set_train_data(
self,
inputs: Optional[torch.Tensor] = None,
Expand Down
91 changes: 84 additions & 7 deletions aepsych/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

import warnings
from typing import Mapping, Optional, Tuple, Union
from typing import List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -162,8 +162,77 @@ def get_extremum(
return best_val, best_point.squeeze(0)


def get_min(
model: ModelProtocol,
bounds: torch.Tensor,
locked_dims: Optional[Mapping[int, float]] = None,
probability_space: bool = False,
n_samples: int = 1000,
max_time: Optional[float] = None,
) -> Tuple[float, torch.Tensor]:
"""Return the minimum of the modeled function, subject to constraints
Args:
model (ModelProtocol): AEPsychModel to get the minimum of.
bounds (torch.Tensor): Bounds of the space to find the minimum.
locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the
inverse is along a slice of the full surface.
probability_space (bool): Is y (and therefore the returned nearest_y) in
probability space instead of latent function space? Defaults to False.
n_samples (int): number of coarse grid points to sample for optimization estimate.
max_time (float, optional): Maximum time to spend optimizing. Defaults to None.
Returns:
Tuple[float, torch.Tensor]: Tuple containing the min and its location (argmin).
"""
_, _arg = get_extremum(
model, "min", bounds, locked_dims, n_samples, max_time=max_time
)
arg = torch.tensor(_arg.reshape(1, bounds.shape[1]))
if probability_space:
val, _ = model.predict_probability(arg)
else:
val, _ = model.predict(arg)

return float(val.item()), arg


def get_max(
model: ModelProtocol,
bounds: torch.Tensor,
locked_dims: Optional[Mapping[int, float]] = None,
probability_space: bool = False,
n_samples: int = 1000,
max_time: Optional[float] = None,
) -> Tuple[float, torch.Tensor]:
"""Return the maximum of the modeled function, subject to constraints
Args:
model (ModelProtocol): AEPsychModel to get the maximum of.
bounds (torch.Tensor): Bounds of the space to find the maximum.
locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the
inverse is along a slice of the full surface. Defaults to None.
probability_space (bool): Is y (and therefore the returned nearest_y) in
probability space instead of latent function space? Defaults to False.
n_samples (int): number of coarse grid points to sample for optimization estimate.
max_time (float, optional): Maximum time to spend optimizing. Defaults to None.
Returns:
Tuple[float, torch.Tensor]: Tuple containing the max and its location (argmax).
"""
_, _arg = get_extremum(
model, "max", bounds, locked_dims, n_samples, max_time=max_time
)
arg = torch.tensor(_arg.reshape(1, bounds.shape[1]))
if probability_space:
val, _ = model.predict_probability(arg)
else:
val, _ = model.predict(arg)

return float(val.item()), arg


def inv_query(
model: Model,
model: ModelProtocol,
y: Union[float, torch.Tensor],
bounds: torch.Tensor,
locked_dims: Optional[Mapping[int, float]] = None,
Expand All @@ -176,9 +245,10 @@ def inv_query(
Return nearest x such that f(x) = queried y, and also return the
value of f at that point.
Args:
model (ModelProtocol): AEPsychModel to get the find the inverse from y.
y (Union[float, torch.Tensor]): Points at which to find the inverse.
bounds (torch.Tensor): Lower and upper bounds of the search space.
locked_dims (Mapping[int, List[float]], optional): Dimensions to fix, so that the
locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the
inverse is along a slice of the full surface. Defaults to None.
probability_space (bool): Is y (and therefore the
returned nearest_y) in probability space instead of latent
Expand All @@ -191,17 +261,17 @@ def inv_query(
nearest to queried y and the x position of this value.
"""
locked_dims = locked_dims or {}
if model.num_outputs > 1:
if model._num_outputs > 1:
if weights is None:
weights = torch.Tensor([1] * model.num_outputs)
weights = torch.Tensor([1] * model._num_outputs)
if probability_space:
warnings.warn(
"Inverse querying with probability_space=True assumes that the model uses Probit-Bernoulli likelihood!"
)
posterior_transform = TargetProbabilityDistancePosteriorTransform(y, weights)
else:
posterior_transform = TargetDistancePosteriorTransform(y, weights)
val, arg = get_extremum(
_, _arg = get_extremum(
model,
"min",
bounds,
Expand All @@ -211,7 +281,14 @@ def inv_query(
max_time,
weights,
)
return val, arg

arg = torch.tensor(_arg.reshape(1, bounds.shape[1]))
if probability_space:
val, _ = model.predict_probability(arg)
else:
val, _ = model.predict(arg)

return float(val.item()), arg


def get_jnd(
Expand Down
45 changes: 12 additions & 33 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.models.utils import get_extremum, get_jnd, inv_query
from aepsych.models.utils import get_extremum, get_jnd, get_max, get_min, inv_query
from aepsych.transforms import (
ParameterTransformedGenerator,
ParameterTransformedModel,
Expand Down Expand Up @@ -338,28 +338,20 @@ def get_max(
Returns:
Tuple[float, torch.Tensor]: Tuple containing the max and its location (argmax).
"""
constraints = constraints or {}
assert (
self.model is not None
), "model is None! Cannot get the max without a model!"
self.model.to(self.model_device)

locked_dims = constraints or {}
_, _arg = get_extremum(
val, arg = get_max(
self.model,
"max",
self.bounds,
locked_dims,
locked_dims=constraints,
probability_space=probability_space,
max_time=max_time,
n_samples=1000,
)
arg = torch.tensor(_arg.reshape(1, self.dim))
if probability_space:
val, _ = self.model.predict_probability(arg)
else:
val, _ = self.model.predict(arg)

return float(val.item()), arg
return val, arg

@ensure_model_is_fresh
def get_min(
Expand All @@ -375,28 +367,20 @@ def get_min(
probability_space (bool): Whether to return the min in probability space. Defaults to False.
max_time (float, optional): Maximum time to run the optimization. Defaults to None.
"""
constraints = constraints or {}
assert (
self.model is not None
), "model is None! Cannot get the min without a model!"
self.model.to(self.model_device)

locked_dims = constraints or {}
_, _arg = get_extremum(
val, arg = get_min(
self.model,
"min",
self.bounds,
locked_dims,
locked_dims=constraints,
probability_space=probability_space,
max_time=max_time,
n_samples=1000,
)
arg = torch.tensor(_arg.reshape(1, self.dim))
if probability_space:
val, _ = self.model.predict_probability(arg)
else:
val, _ = self.model.predict(arg)

return float(val.item()), arg
return val, arg

@ensure_model_is_fresh
def inv_query(
Expand All @@ -417,26 +401,21 @@ def inv_query(
Returns:
Tuple[float, torch.Tensor]: The input that corresponds to the given output value and the corresponding output.
"""
constraints = constraints or {}
assert (
self.model is not None
), "model is None! Cannot get the inv_query without a model!"
self.model.to(self.model_device)

_, _arg = inv_query(
val, arg = inv_query(
model=self.model,
y=y,
bounds=self.bounds,
locked_dims=constraints,
probability_space=probability_space,
max_time=max_time,
)
arg = torch.tensor(_arg.reshape(1, self.dim))
if probability_space:
val, _ = self.model.predict_probability(arg.reshape(1, self.dim))
else:
val, _ = self.model.predict(arg.reshape(1, self.dim))
return float(val.item()), arg

return val, arg

@ensure_model_is_fresh
def predict(self, x: torch.Tensor, probability_space: bool = False) -> torch.Tensor:
Expand Down
Loading

0 comments on commit 728269c

Please sign in to comment.