Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix plotting to stop using methods/attributes removed from models #491

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
from aepsych.strategy import Strategy
from aepsych.utils import get_lse_contour, get_lse_interval, make_scaled_sobol
from aepsych.utils import dim_grid, get_lse_contour, get_lse_interval, make_scaled_sobol
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from scipy.stats import norm
Expand Down Expand Up @@ -174,7 +174,7 @@ def _plot_strat_1d(
assert x is not None and y is not None, "No data to plot!"

if strat.model is not None:
grid = strat.model.dim_grid(gridsize=gridsize).cpu()
grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
else:
Expand All @@ -196,9 +196,6 @@ def _plot_strat_1d(
if target_level is not None:
from aepsych.utils import interpolate_monotonic

lb = strat.transforms.untransform(strat.lb)[0]
ub = strat.transforms.untransform(strat.ub)[0]

threshold_samps = [
interpolate_monotonic(grid, s, target_level, strat.lb[0], strat.ub[0])
for s in samps
Expand Down Expand Up @@ -300,7 +297,7 @@ def _plot_strat_2d(
if strat.model is not None:
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

grid = strat.model.dim_grid(gridsize=gridsize)
grid = dim_grid(lower=strat.lb, upper=strat.ub, gridsize=gridsize).cpu()
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
else:
Expand Down Expand Up @@ -341,6 +338,8 @@ def _plot_strat_2d(
model=strat.model,
mono_grid=mono_grid,
target_level=target_level,
grid_lb=strat.lb,
grid_ub=strat.ub,
cred_level=cred_level,
mono_dim=1,
lb=mono_grid.min(),
Expand Down Expand Up @@ -513,7 +512,12 @@ def plot_slice(
"""
extent = np.c_[strat.lb, strat.ub].reshape(-1)
if strat.model is not None:
x = strat.model.dim_grid(gridsize=gridsize, slice_dims={slice_dim: slice_val})
x = dim_grid(
lower=strat.lb,
upper=strat.ub,
gridsize=gridsize,
slice_dims={slice_dim: slice_val},
).cpu()
else:
raise RuntimeError("Cannot plot without a model!")
if lse:
Expand Down
3 changes: 3 additions & 0 deletions aepsych/transforms/ops/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def get_config_options(
if "values" not in options:
value = config[name].get("value")

if value is None:
raise ValueError(f"Value option not found in {name} section.")

try:
options["values"] = [float(value)]
except ValueError:
Expand Down
22 changes: 17 additions & 5 deletions aepsych/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def dim_grid(
if i in slice_dims.keys():
mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1))
else:
mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j))
mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j)) # type: ignore

return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T)

Expand Down Expand Up @@ -168,13 +168,19 @@ def interpolate_monotonic(
y1 = y[idx]

x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0)
return x_star.cpu().item()

if isinstance(x_star, torch.Tensor):
return x_star.cpu().item()
else:
return x_star


def get_lse_interval(
model: GPyTorchModel,
mono_grid: Union[torch.Tensor, np.ndarray],
target_level: float,
grid_lb: torch.Tensor,
grid_ub: torch.Tensor,
cred_level: Optional[float] = None,
mono_dim: int = -1,
n_samps: int = 500,
Expand All @@ -189,11 +195,17 @@ def get_lse_interval(
model (GPyTorchModel): Model to use for sampling.
mono_grid (Union[torch.Tensor, np.ndarray]): Monotonic grid.
target_level (float): Target level.
grid_lb (torch.Tensor): The lower bound of the grid to sample from to calculate
LSE.
grid_ub (torch.Tensor): The upper bound of the grid to sample from to calculate
LSE.
cred_level (float, optional): Credibility level. Defaults to None.
mono_dim (int): Monotonic dimension. Defaults to -1.
n_samps (int): Number of samples. Defaults to 500.
lb (float): Lower bound. Defaults to -float("inf").
ub (float): Upper bound. Defaults to float("inf").
lb (float): Theoreticaly true lower bound for the parameter. Defaults to
-float("inf").
ub (float): Theoretical true uppper bound for the parameters. Defaults to
float("inf").
gridsize (int): Grid size. Defaults to 30.

Returns:
Expand All @@ -203,7 +215,7 @@ def get_lse_interval(
xgrid = torch.stack(
torch.meshgrid(
[
torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize)
torch.linspace(grid_lb[i].item(), grid_ub[i].item(), gridsize)
for i in range(model.dim)
]
),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"flake8",
"black",
"sqlalchemy-stubs", # for mypy stubs
"mypy",
"mypy==1.14.0",
"parameterized",
"scikit-learn", # used in unit tests
]
Expand Down