diff --git a/aepsych/generators/acqf_thompson_sampler_generator.py b/aepsych/generators/acqf_thompson_sampler_generator.py index 9a744f56d..3e5ea7b68 100644 --- a/aepsych/generators/acqf_thompson_sampler_generator.py +++ b/aepsych/generators/acqf_thompson_sampler_generator.py @@ -76,7 +76,8 @@ def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Ten num_points (int, optional): Number of points to query. model (ModelProtocol): Fitted model of the data. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim] + or shape [num_points, dim, 2] if pairing is applied. """ if self.stimuli_per_trial == 2: diff --git a/aepsych/generators/manual_generator.py b/aepsych/generators/manual_generator.py index f794db2ab..079d3b938 100644 --- a/aepsych/generators/manual_generator.py +++ b/aepsych/generators/manual_generator.py @@ -57,7 +57,7 @@ def gen( Args: num_points (int): Number of points to query. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim]. """ if num_points > (len(self.points) - self._idx): warnings.warn( diff --git a/aepsych/generators/monotonic_rejection_generator.py b/aepsych/generators/monotonic_rejection_generator.py index 6676feaf9..d103d1845 100644 --- a/aepsych/generators/monotonic_rejection_generator.py +++ b/aepsych/generators/monotonic_rejection_generator.py @@ -80,7 +80,7 @@ def gen( num_points (int, optional): Number of points to query. model (AEPsychMixin): Fitted model of the data. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim]. """ options = self.model_gen_options or {} diff --git a/aepsych/generators/monotonic_thompson_sampler_generator.py b/aepsych/generators/monotonic_thompson_sampler_generator.py index b08d50bbb..62a0bce6e 100644 --- a/aepsych/generators/monotonic_thompson_sampler_generator.py +++ b/aepsych/generators/monotonic_thompson_sampler_generator.py @@ -60,7 +60,7 @@ def gen( num_points (int, optional): Number of points to query. model (AEPsychMixin): Fitted model of the data. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: The next set of point(s) to evaluate, with shape [num_points, dim]. """ # Generate the points at which to sample diff --git a/aepsych/generators/optimize_acqf_generator.py b/aepsych/generators/optimize_acqf_generator.py index 77ab1a42d..a787c097c 100644 --- a/aepsych/generators/optimize_acqf_generator.py +++ b/aepsych/generators/optimize_acqf_generator.py @@ -80,7 +80,8 @@ def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Ten num_points (int, optional): Number of points to query. model (ModelProtocol): Fitted model of the data. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim] + or shape [num_points, dim, 2] if pairing is applied. """ if self.stimuli_per_trial == 2: diff --git a/aepsych/generators/random_generator.py b/aepsych/generators/random_generator.py index 41acc8546..f4f1cf849 100644 --- a/aepsych/generators/random_generator.py +++ b/aepsych/generators/random_generator.py @@ -45,7 +45,7 @@ def gen( Args: num_points (int, optional): Number of points to query. Currently, only 1 point can be queried at a time. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim]. """ X = self.bounds_[0] + torch.rand((num_points, self.bounds_.shape[1])) * ( self.bounds_[1] - self.bounds_[0] diff --git a/aepsych/generators/sobol_generator.py b/aepsych/generators/sobol_generator.py index ce54150f3..549b6308f 100644 --- a/aepsych/generators/sobol_generator.py +++ b/aepsych/generators/sobol_generator.py @@ -55,7 +55,8 @@ def gen( Args: num_points (int, optional): Number of points to query. Returns: - np.ndarray: Next set of point(s) to evaluate, [num_points x dim]. + torch.Tensor: Next set of point(s) to evaluate, with shape [num_points, dim] + or shape [num_points, dim , stimuli_per_trial] if `stimuli_per_trial` is greater than 1. """ grid = self.engine.draw(num_points) grid = self.lb + (self.ub - self.lb) * grid