Skip to content

Commit

Permalink
Fix plotting to stop using methods/attributes removed from models (#491)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #491

Dim grid no longer exists in models, plotting code updated to use standalone dim_grid.

Bounds are gone from models so get_lse_interval needs to separately define the bounds of the grid to sample from.

Reviewed By: crasanders

Differential Revision: D67535775

fbshipit-source-id: 4ffdbfff7f1911c8c5c4293499d11da02653b4c6
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Dec 20, 2024
1 parent 7874b44 commit 2f6ee28
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
18 changes: 11 additions & 7 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
from aepsych.strategy import Strategy
from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol
from aepsych.utils import dim_grid, get_lse_contour, get_lse_interval, make_scaled_sobol
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from scipy.stats import norm
Expand Down Expand Up @@ -174,7 +174,7 @@ def _plot_strat_1d(
assert x is not None and y is not None, "No data to plot!"

if strat.model is not None:
grid = strat.model.dim_grid(gridsize=gridsize).cpu()
grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
else:
Expand All @@ -196,9 +196,6 @@ def _plot_strat_1d(
if target_level is not None:
from aepsych.utils import interpolate_monotonic

lb = strat.transforms.untransform(strat.lb)[0]
ub = strat.transforms.untransform(strat.ub)[0]

threshold_samps = [
interpolate_monotonic(grid, s, target_level, strat.lb[0], strat.ub[0])
for s in samps
Expand Down Expand Up @@ -300,7 +297,7 @@ def _plot_strat_2d(
if strat.model is not None:
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

grid = strat.model.dim_grid(gridsize=gridsize)
grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
else:
Expand Down Expand Up @@ -341,6 +338,8 @@ def _plot_strat_2d(
model=strat.model,
mono_grid=mono_grid,
target_level=target_level,
grid_lb=strat.lb,
grid_ub=strat.ub,
cred_level=cred_level,
mono_dim=1,
lb=mono_grid.min(),
Expand Down Expand Up @@ -513,7 +512,12 @@ def plot_slice(
"""
extent = np.c_[strat.lb, strat.ub].reshape(-1)
if strat.model is not None:
x = strat.model.dim_grid(gridsize=gridsize, slice_dims={slice_dim: slice_val})
x = dim_grid(
lower=strat.lb,
upper=strat.ub,
gridsize=gridsize,
slice_dims={slice_dim: slice_val},
).cpu()
else:
raise RuntimeError("Cannot plot without a model!")
if lse:
Expand Down
20 changes: 16 additions & 4 deletions aepsych/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,19 @@ def interpolate_monotonic(
y1 = y[idx]

x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0)
return x_star.cpu().item()

if isinstance(x_star, torch.Tensor):
return x_star.cpu().item()
else:
return x_star


def get_lse_interval(
model: GPyTorchModel,
mono_grid: Union[torch.Tensor, np.ndarray],
target_level: float,
grid_lb: torch.Tensor,
grid_ub: torch.Tensor,
cred_level: Optional[float] = None,
mono_dim: int = -1,
n_samps: int = 500,
Expand All @@ -189,11 +195,17 @@ def get_lse_interval(
model (GPyTorchModel): Model to use for sampling.
mono_grid (Union[torch.Tensor, np.ndarray]): Monotonic grid.
target_level (float): Target level.
grid_lb (torch.Tensor): The lower bound of the grid to sample from to calculate
LSE.
grid_ub (torch.Tensor): The upper bound of the grid to sample from to calculate
LSE.
cred_level (float, optional): Credibility level. Defaults to None.
mono_dim (int): Monotonic dimension. Defaults to -1.
n_samps (int): Number of samples. Defaults to 500.
lb (float): Lower bound. Defaults to -float("inf").
ub (float): Upper bound. Defaults to float("inf").
lb (float): Theoreticaly true lower bound for the parameter. Defaults to
-float("inf").
ub (float): Theoretical true uppper bound for the parameters. Defaults to
float("inf").
gridsize (int): Grid size. Defaults to 30.
Returns:
Expand All @@ -203,7 +215,7 @@ def get_lse_interval(
xgrid = torch.stack(
torch.meshgrid(
[
torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize)
torch.linspace(grid_lb[i].item(), grid_ub[i].item(), gridsize)
for i in range(model.dim)
]
),
Expand Down

0 comments on commit 2f6ee28

Please sign in to comment.