Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move models and data to GPUs automatically depending on the dataset size #374

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added 4b40a34b0cfb4203acd556e92c7fbe8f.db
Binary file not shown.
Binary file added 521277e99e1945a59ed884b9d2a0014f.db
Binary file not shown.
2 changes: 1 addition & 1 deletion aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _gen(

new_candidate, _ = optimize_acqf(
acq_function=acqf,
bounds=torch.tensor(np.c_[model.lb, model.ub]).T.to(train_x),
bounds=torch.stack([model.lb, model.ub]).to(train_x),
q=num_points,
num_restarts=self.restarts,
raw_samples=self.samps,
Expand Down
5 changes: 5 additions & 0 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ def gen(
np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
"""
grid = self.engine.draw(num_points)

# PyTorch's Sobol engine always returns tensors on the CPU,
# so we have to move the drawn points to the correct device manually.
grid = grid.to(self.lb.device)
grid = self.lb + (self.ub - self.lb) * grid

if self.stimuli_per_trial == 1:
return grid

Expand Down
4 changes: 3 additions & 1 deletion aepsych/models/gp_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _reset_variational_strategy(self):
variational_distribution,
learn_inducing_locations=False,
)
self.variational_strategy.to(self.train_targets.device)

def fit(
self,
Expand Down Expand Up @@ -268,8 +269,9 @@ def predict(
a_star = fmean / torch.sqrt(1 + fvar)
pmean = Normal(0, 1).cdf(a_star)
t_term = torch.tensor(
owens_t(a_star.numpy(), 1 / np.sqrt(1 + 2 * fvar.numpy())),
owens_t(a_star.cpu().numpy(), 1 / np.sqrt(1 + 2 * fvar.cpu().numpy())),
dtype=a_star.dtype,
device=pmean.device,
)
pvar = pmean - 2 * t_term - pmean.square()
return promote_0d(pmean), promote_0d(pvar)
Expand Down
2 changes: 1 addition & 1 deletion aepsych/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def select_inducing_points(
elif method == "kmeans++":
# initialize using kmeans
inducing_points = torch.tensor(
kmeans2(unique_X.numpy(), inducing_size, minit="++")[0],
kmeans2(unique_X.cpu().numpy(), inducing_size, minit="++")[0],
dtype=X.dtype,
)
return inducing_points
Expand Down
42 changes: 27 additions & 15 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aepsych.generators.base import AEPsychGenerator
from aepsych.generators.sobol_generator import SobolGenerator
from aepsych.models.base import ModelProtocol
from aepsych.models import GPClassificationModel
from aepsych.utils import (
_process_bounds,
make_scaled_sobol,
Expand Down Expand Up @@ -178,30 +179,25 @@ def normalize_inputs(self, x, y):
y (np.ndarray): training outputs

Returns:
x (np.ndarray): training inputs, normalized
y (np.ndarray): training outputs, normalized
x (tensor): training inputs, normalized
y (tensor): training outputs, normalized
n (int): number of observations
"""
assert (
x.shape == self.event_shape or x.shape[1:] == self.event_shape
), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"

if x.shape == self.event_shape:
x = x[None, :]

if self.x is None:
x = np.r_[x]
else:
x = np.r_[self.x, x]
if isinstance(y, np.ndarray) or isinstance(y, list) or isinstance(y, int) or isinstance(y, float):
x = torch.tensor(x, dtype=torch.float64)
y = torch.tensor(y, dtype=torch.float64).view(-1)

if self.y is None:
y = np.r_[y]
else:
y = np.r_[self.y, y]
if x.shape == self.event_shape:
x = x.unsqueeze(0)

n = y.shape[0]
x = torch.cat([self.x, x], dim=0) if self.x is not None else x
y = torch.cat([self.y, y], dim=0) if self.y is not None else y

return torch.Tensor(x), torch.Tensor(y), n
return x, y, y.size(0)

# TODO: allow user to pass in generator options
@ensure_model_is_fresh
Expand Down Expand Up @@ -310,6 +306,22 @@ def add_data(self, x, y):
self.x, self.y, self.n = self.normalize_inputs(x, y)
self._model_is_fresh = False

if self.x.size(0) >= 100:
# TODO: Support more models beyond GPClassificationModel
if (
isinstance(self.model, GPClassificationModel) and
self.model.variational_strategy.inducing_points.size(0) >= 100
):
# Move the model and data to the GPU (falling back to the CPU when CUDA is unavailable)
# once there are at least 100 training points and at least 100 inducing points.
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(device)
self.model.lb = self.model.lb.to(device)
self.model.ub = self.model.ub.to(device)

self.x = self.x.to(device)
self.y = self.y.to(device)

def fit(self):
if self.can_fit:
if self.keep_most_recent is not None:
Expand Down
56 changes: 32 additions & 24 deletions tests/models/test_gp_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,35 +117,43 @@ def test_1d_classification(self):
"""
Just see if we memorize the training set
"""
X, y = self.X, self.y
model = GPClassificationModel(
torch.Tensor([-3]), torch.Tensor([3]), inducing_size=10
)
if torch.cuda.is_available():
lst_devices = ["cuda:0", "cpu"]
else:
lst_devices = ["cpu"]

model.fit(X[:50], y[:50])
for device in lst_devices:
X, y = self.X, self.y
X, y = X.to(device), y.to(device)

# pspace
pm, _ = model.predict_probability(X[:50])
pred = (pm > 0.5).numpy()
npt.assert_allclose(pred, y[:50])
model = GPClassificationModel(
torch.Tensor([-3]).to(device), torch.Tensor([3]).to(device), inducing_size=10
).to(device)

# fspace
pm, _ = model.predict(X[:50], probability_space=False)
pred = (pm > 0).numpy()
npt.assert_allclose(pred, y[:50])
model.fit(X[:50], y[:50])

# smoke test update
model.update(X, y)
# pspace
pm, _ = model.predict_probability(X[:50])
pred = (pm > 0.5).cpu().numpy()
npt.assert_allclose(pred, y[:50].cpu().numpy())

# pspace
pm, _ = model.predict_probability(X)
pred = (pm > 0.5).numpy()
npt.assert_allclose(pred, y)
# fspace
pm, _ = model.predict(X[:50], probability_space=False)
pred = (pm > 0).cpu().numpy()
npt.assert_allclose(pred, y[:50].cpu().numpy())

# fspace
pm, _ = model.predict(X, probability_space=False)
pred = (pm > 0).numpy()
npt.assert_allclose(pred, y)
# smoke test update
model.update(X, y)

# pspace
pm, _ = model.predict_probability(X)
pred = (pm > 0.5).cpu().numpy()
npt.assert_allclose(pred, y.cpu().numpy())

# fspace
pm, _ = model.predict(X, probability_space=False)
pred = (pm > 0).cpu().numpy()
npt.assert_allclose(pred, y.cpu().numpy())

def test_1d_classification_pytorchopt(self):
"""
Expand Down Expand Up @@ -646,7 +654,7 @@ def obj(x):

for _i in range(n_init + n_opt):
next_x = strat.gen()
strat.add_data(next_x, [bernoulli.rvs(norm.cdf(next_x / 1.5))])
strat.add_data(next_x, Normal(0, 1).cdf(next_x / 1.5).bernoulli().view(-1))

x = torch.linspace(-4, 4, 100)

Expand Down