diff --git a/aepsych/generators/optimize_acqf_generator.py b/aepsych/generators/optimize_acqf_generator.py
index 77ab1a42d..d78c16f6e 100644
--- a/aepsych/generators/optimize_acqf_generator.py
+++ b/aepsych/generators/optimize_acqf_generator.py
@@ -108,7 +108,7 @@ def _gen(
         new_candidate, _ = optimize_acqf(
             acq_function=acqf,
-            bounds=torch.tensor(np.c_[model.lb, model.ub]).T.to(train_x),
+            bounds=torch.stack([model.lb, model.ub]).to(train_x),
             q=num_points,
             num_restarts=self.restarts,
             raw_samples=self.samps,
diff --git a/aepsych/generators/sobol_generator.py b/aepsych/generators/sobol_generator.py
index ce54150f3..98ca38f73 100644
--- a/aepsych/generators/sobol_generator.py
+++ b/aepsych/generators/sobol_generator.py
@@ -58,7 +58,12 @@ def gen(
             np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
         """
         grid = self.engine.draw(num_points)
+
+        # PyTorch's Sobol engine always returns tensors on CPU,
+        # so we have to move the draw to the correct device manually.
+        grid = grid.to(self.lb.device)
         grid = self.lb + (self.ub - self.lb) * grid
+
         if self.stimuli_per_trial == 1:
             return grid
diff --git a/aepsych/models/gp_classification.py b/aepsych/models/gp_classification.py
index 7f8785536..55645fdd7 100644
--- a/aepsych/models/gp_classification.py
+++ b/aepsych/models/gp_classification.py
@@ -195,6 +195,7 @@ def _reset_variational_strategy(self):
             variational_distribution,
             learn_inducing_locations=False,
         )
+        self.variational_strategy.to(self.train_targets.device)
 
     def fit(
         self,
@@ -268,8 +269,9 @@ def predict(
             a_star = fmean / torch.sqrt(1 + fvar)
             pmean = Normal(0, 1).cdf(a_star)
             t_term = torch.tensor(
-                owens_t(a_star.numpy(), 1 / np.sqrt(1 + 2 * fvar.numpy())),
+                owens_t(a_star.cpu().numpy(), 1 / np.sqrt(1 + 2 * fvar.cpu().numpy())),
                 dtype=a_star.dtype,
+                device=pmean.device,
             )
             pvar = pmean - 2 * t_term - pmean.square()
             return promote_0d(pmean), promote_0d(pvar)
diff --git a/aepsych/models/utils.py b/aepsych/models/utils.py
index 874845f49..b62532066 100644
--- a/aepsych/models/utils.py
+++ b/aepsych/models/utils.py
@@ -96,7 +96,7 @@ def select_inducing_points(
    elif method == "kmeans++":
        # initialize using kmeans
        inducing_points = torch.tensor(
-            kmeans2(unique_X.numpy(), inducing_size, minit="++")[0],
+            kmeans2(unique_X.cpu().numpy(), inducing_size, minit="++")[0],
            dtype=X.dtype,
        )
        return inducing_points
diff --git a/aepsych/strategy.py b/aepsych/strategy.py
index 704dd09fd..55b57009b 100644
--- a/aepsych/strategy.py
+++ b/aepsych/strategy.py
@@ -19,6 +19,7 @@
 from aepsych.generators.base import AEPsychGenerator
 from aepsych.generators.sobol_generator import SobolGenerator
 from aepsych.models.base import ModelProtocol
+from aepsych.models import GPClassificationModel
 from aepsych.utils import (
     _process_bounds,
     make_scaled_sobol,
@@ -178,30 +179,25 @@ def normalize_inputs(self, x, y):
             y (np.ndarray): training outputs
 
         Returns:
-            x (np.ndarray): training inputs, normalized
-            y (np.ndarray): training outputs, normalized
+            x (tensor): training inputs, normalized
+            y (tensor): training outputs, normalized
             n (int): number of observations
         """
         assert (
             x.shape == self.event_shape or x.shape[1:] == self.event_shape
         ), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"
 
-        if x.shape == self.event_shape:
-            x = x[None, :]
-
-        if self.x is None:
-            x = np.r_[x]
-        else:
-            x = np.r_[self.x, x]
+        if isinstance(y, (np.ndarray, list, int, float)):
+            x = torch.tensor(x, dtype=torch.float64)
+            y = torch.tensor(y, dtype=torch.float64).view(-1)
 
-        if self.y is None:
-            y = np.r_[y]
-        else:
-            y = np.r_[self.y, y]
+        if x.shape == self.event_shape:
+            x = x.unsqueeze(0)
 
-        n = y.shape[0]
+        x = torch.cat([self.x, x], dim=0) if self.x is not None else x
+        y = torch.cat([self.y, y], dim=0) if self.y is not None else y
 
-        return torch.Tensor(x), torch.Tensor(y), n
+        return x, y, y.size(0)
 
     # TODO: allow user to pass in generator options
     @ensure_model_is_fresh
@@ -310,6 +306,22 @@ def add_data(self, x, y):
         self.x, self.y, self.n = self.normalize_inputs(x, y)
         self._model_is_fresh = False
 
+        if self.x.size(0) >= 100:
+            # TODO: Support more models beyond GPClassificationModel
+            if (
+                isinstance(self.model, GPClassificationModel)
+                and self.model.variational_strategy.inducing_points.size(0) >= 100
+            ):
+                # Move the model and data to the GPU once the number of training
+                # points and the number of inducing points are both at least 100.
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                self.model.to(device)
+                self.model.lb = self.model.lb.to(device)
+                self.model.ub = self.model.ub.to(device)
+
+                self.x = self.x.to(device)
+                self.y = self.y.to(device)
+
     def fit(self):
         if self.can_fit:
             if self.keep_most_recent is not None:
diff --git a/tests/models/test_gp_classification.py b/tests/models/test_gp_classification.py
index 6bd334b4a..3356f6a6d 100644
--- a/tests/models/test_gp_classification.py
+++ b/tests/models/test_gp_classification.py
@@ -117,35 +117,43 @@ def test_1d_classification(self):
         """
         Just see if we memorize the training set
         """
-        X, y = self.X, self.y
-        model = GPClassificationModel(
-            torch.Tensor([-3]), torch.Tensor([3]), inducing_size=10
-        )
+        if torch.cuda.is_available():
+            lst_devices = ["cuda:0", "cpu"]
+        else:
+            lst_devices = ["cpu"]
 
-        model.fit(X[:50], y[:50])
+        for device in lst_devices:
+            X, y = self.X, self.y
+            X, y = X.to(device), y.to(device)
 
-        # pspace
-        pm, _ = model.predict_probability(X[:50])
-        pred = (pm > 0.5).numpy()
-        npt.assert_allclose(pred, y[:50])
+            model = GPClassificationModel(
+                torch.Tensor([-3]).to(device), torch.Tensor([3]).to(device), inducing_size=10
+            ).to(device)
 
-        # fspace
-        pm, _ = model.predict(X[:50], probability_space=False)
-        pred = (pm > 0).numpy()
-        npt.assert_allclose(pred, y[:50])
+            model.fit(X[:50], y[:50])
 
-        # smoke test update
-        model.update(X, y)
+            # pspace
+            pm, _ = model.predict_probability(X[:50])
+            pred = (pm > 0.5).cpu().numpy()
+            npt.assert_allclose(pred, y[:50].cpu().numpy())
 
-        # pspace
-        pm, _ = model.predict_probability(X)
-        pred = (pm > 0.5).numpy()
-        npt.assert_allclose(pred, y)
+            # fspace
+            pm, _ = model.predict(X[:50], probability_space=False)
+            pred = (pm > 0).cpu().numpy()
+            npt.assert_allclose(pred, y[:50].cpu().numpy())
 
-        # fspace
-        pm, _ = model.predict(X, probability_space=False)
-        pred = (pm > 0).numpy()
-        npt.assert_allclose(pred, y)
+            # smoke test update
+            model.update(X, y)
+
+            # pspace
+            pm, _ = model.predict_probability(X)
+            pred = (pm > 0.5).cpu().numpy()
+            npt.assert_allclose(pred, y.cpu().numpy())
+
+            # fspace
+            pm, _ = model.predict(X, probability_space=False)
+            pred = (pm > 0).cpu().numpy()
+            npt.assert_allclose(pred, y.cpu().numpy())
 
     def test_1d_classification_pytorchopt(self):
         """
@@ -646,7 +654,7 @@ def obj(x):
 
         for _i in range(n_init + n_opt):
             next_x = strat.gen()
-            strat.add_data(next_x, [bernoulli.rvs(norm.cdf(next_x / 1.5))])
+            strat.add_data(next_x, Normal(0, 1).cdf(next_x / 1.5).bernoulli().view(-1))
 
         x = torch.linspace(-4, 4, 100)
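
Note: the SobolGenerator fix above hinges on torch.quasirandom.SobolEngine.draw() always returning CPU tensors, so the draw has to follow the bounds' device before the affine rescale. The following is a minimal, self-contained sketch of that pattern; ToySobolGenerator is a hypothetical stand-in, not the AEPsych class.

    import torch
    from torch.quasirandom import SobolEngine

    class ToySobolGenerator:
        def __init__(self, lb: torch.Tensor, ub: torch.Tensor):
            # lb/ub define the box; their device decides where draws should live.
            self.lb, self.ub = lb, ub
            self.engine = SobolEngine(dimension=lb.numel(), scramble=True)

        def gen(self, num_points: int) -> torch.Tensor:
            grid = self.engine.draw(num_points)  # always returned on CPU
            grid = grid.to(self.lb.device)       # follow the bounds' device
            # rescale the unit-cube draw into [lb, ub]
            return self.lb + (self.ub - self.lb) * grid

For example, with CUDA bounds (gen = ToySobolGenerator(torch.zeros(3, device="cuda"), torch.ones(3, device="cuda"))), gen.gen(8) comes back on the GPU rather than silently mixing devices in the rescale.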
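The owens_t and kmeans2 hunks share a second pattern: SciPy only accepts NumPy arrays, so a tensor that may live on the GPU must round-trip through .cpu().numpy() and the result must be placed back on the original device. Below is an illustrative sketch of that round-trip using the same probit-variance math as predict(); the function and argument names (probit_probability_moments, f_mean, f_var) are hypothetical, not library API.

    import numpy as np
    import torch
    from scipy.special import owens_t
    from torch.distributions import Normal

    def probit_probability_moments(f_mean: torch.Tensor, f_var: torch.Tensor):
        """Mean and variance of Phi(f) for f ~ N(f_mean, f_var), on any device."""
        a_star = f_mean / torch.sqrt(1 + f_var)
        p_mean = Normal(0.0, 1.0).cdf(a_star)
        # SciPy runs on CPU NumPy arrays: detach to CPU for the call ...
        t_np = owens_t(a_star.cpu().numpy(), 1 / np.sqrt(1 + 2 * f_var.cpu().numpy()))
        # ... then restore the original dtype/device so downstream math stays put.
        t_term = torch.tensor(t_np, dtype=a_star.dtype, device=a_star.device)
        p_var = p_mean - 2 * t_term - p_mean.square()
        return p_mean, p_var

The device= argument on the reconstructing torch.tensor call is the crux: without it, a GPU caller would get a CPU t_term and the subtraction would raise a device-mismatch error, which is exactly what the patch fixes.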