diff --git a/aepsych/plotting.py b/aepsych/plotting.py index 726d1fdc4..bca71f1cb 100644 --- a/aepsych/plotting.py +++ b/aepsych/plotting.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt import numpy as np from aepsych.strategy import Strategy -from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol +from aepsych.utils import dim_grid, get_lse_contour, get_lse_interval, make_scaled_sobol from matplotlib.axes import Axes from matplotlib.image import AxesImage from scipy.stats import norm @@ -174,7 +174,7 @@ def _plot_strat_1d( assert x is not None and y is not None, "No data to plot!" if strat.model is not None: - grid = strat.model.dim_grid(gridsize=gridsize).cpu() + grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu() samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach()) phimean = samps.mean(0) else: @@ -196,9 +196,6 @@ def _plot_strat_1d( if target_level is not None: from aepsych.utils import interpolate_monotonic - lb = strat.transforms.untransform(strat.lb)[0] - ub = strat.transforms.untransform(strat.ub)[0] - threshold_samps = [ interpolate_monotonic(grid, s, target_level, strat.lb[0], strat.ub[0]) for s in samps @@ -300,7 +297,7 @@ def _plot_strat_2d( if strat.model is not None: strat.model.fit(train_x=x, train_y=y, max_fit_time=None) - grid = strat.model.dim_grid(gridsize=gridsize) + grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu() fmean, _ = strat.model.predict(grid) phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T else: @@ -341,6 +338,8 @@ def _plot_strat_2d( model=strat.model, mono_grid=mono_grid, target_level=target_level, + grid_lb=strat.lb, + grid_ub=strat.ub, cred_level=cred_level, mono_dim=1, lb=mono_grid.min(), @@ -513,7 +512,12 @@ def plot_slice( """ extent = np.c_[strat.lb, strat.ub].reshape(-1) if strat.model is not None: - x = strat.model.dim_grid(gridsize=gridsize, slice_dims={slice_dim: slice_val}) + x = dim_grid( + lower=strat.lb, + upper=strat.ub, + gridsize=gridsize, + slice_dims={slice_dim: slice_val}, + ).cpu() else: raise RuntimeError("Cannot plot without a model!") if lse: diff --git a/aepsych/utils.py b/aepsych/utils.py index 7383a8a7a..044e8c878 100644 --- a/aepsych/utils.py +++ b/aepsych/utils.py @@ -168,13 +168,19 @@ def interpolate_monotonic( y1 = y[idx] x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0) - return x_star.cpu().item() + + if isinstance(x_star, torch.Tensor): + return x_star.cpu().item() + else: + return x_star def get_lse_interval( model: GPyTorchModel, mono_grid: Union[torch.Tensor, np.ndarray], target_level: float, + grid_lb: torch.Tensor, + grid_ub: torch.Tensor, cred_level: Optional[float] = None, mono_dim: int = -1, n_samps: int = 500, @@ -189,11 +195,17 @@ def get_lse_interval( model (GPyTorchModel): Model to use for sampling. mono_grid (Union[torch.Tensor, np.ndarray]): Monotonic grid. target_level (float): Target level. + grid_lb (torch.Tensor): The lower bound of the grid to sample from to calculate + LSE. + grid_ub (torch.Tensor): The upper bound of the grid to sample from to calculate + LSE. cred_level (float, optional): Credibility level. Defaults to None. mono_dim (int): Monotonic dimension. Defaults to -1. n_samps (int): Number of samples. Defaults to 500. - lb (float): Lower bound. Defaults to -float("inf"). - ub (float): Upper bound. Defaults to float("inf"). + lb (float): Theoreticaly true lower bound for the parameter. Defaults to + -float("inf"). + ub (float): Theoretical true uppper bound for the parameters. Defaults to + float("inf"). gridsize (int): Grid size. Defaults to 30. Returns: @@ -203,7 +215,7 @@ def get_lse_interval( xgrid = torch.stack( torch.meshgrid( [ - torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize) + torch.linspace(grid_lb[i].item(), grid_ub[i].item(), gridsize) for i in range(model.dim) ] ),