From 2027761acd4a945d6caade42f36fc2168691fa40 Mon Sep 17 00:00:00 2001 From: Craig Sanders Date: Fri, 18 Oct 2024 11:37:35 -0700 Subject: [PATCH] add finished property to manual generators (#409) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/409 This adds a finished property to the manual generator classes so that they can keep track of whether they have generated all their points. Once this is hooked into the strategy's finishing logic, it should make writing configs simpler. Reviewed By: JasonKChow Differential Revision: D64600239 fbshipit-source-id: f8d2c18f516592a09cd545e76b52a8878b22b66b --- aepsych/generators/manual_generator.py | 13 +++++++++---- tests/generators/test_manual_generator.py | 6 ++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/aepsych/generators/manual_generator.py b/aepsych/generators/manual_generator.py index f794db2ab..a1c0ea778 100644 --- a/aepsych/generators/manual_generator.py +++ b/aepsych/generators/manual_generator.py @@ -6,15 +6,16 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union import numpy as np import torch +from torch.quasirandom import SobolEngine + from aepsych.config import Config from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import AEPsychMixin from aepsych.utils import _process_bounds -from torch.quasirandom import SobolEngine class ManualGenerator(AEPsychGenerator): @@ -95,6 +96,10 @@ def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: return options + @property + def finished(self): + return self._idx >= len(self.points) + class SampleAroundPointsGenerator(ManualGenerator): """Generator that samples in a window around reference points in a predefined list.""" @@ -131,9 +136,9 @@ def __init__( grid = self.engine.draw(samples_per_point) grid = p_lb + (p_ub - p_lb) * grid generated.append(grid) - generated = torch.Tensor(np.vstack(generated)) #type: ignore + generated = torch.Tensor(np.vstack(generated)) # type: ignore - super().__init__(lb, ub, generated, dim, shuffle, seed) #type: ignore + super().__init__(lb, ub, generated, dim, shuffle, seed) # type: ignore @classmethod def get_config_options(cls, config: Config, name: Optional[str] = None) -> Dict: diff --git a/tests/generators/test_manual_generator.py b/tests/generators/test_manual_generator.py index ee467a20d..06ed01143 100644 --- a/tests/generators/test_manual_generator.py +++ b/tests/generators/test_manual_generator.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt + from aepsych.config import Config from aepsych.generators import ManualGenerator, SampleAroundPointsGenerator @@ -50,6 +51,7 @@ def test_manual_generator(self): gen = ManualGenerator.from_config(config) npt.assert_equal(gen.lb.numpy(), np.array([0, 0])) npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) + self.assertFalse(gen.finished) p1 = list(gen.gen()[0]) p2 = list(gen.gen()[0]) @@ -60,6 +62,7 @@ def test_manual_generator(self): self.assertEqual(sorted([p1, p2, p3, p4]), points) self.assertEqual(gen.max_asks, len(points)) self.assertEqual(gen.seed, 123) + self.assertTrue(gen.finished) class TestSampleAroundPointsGenerator(unittest.TestCase): @@ -86,6 +89,7 @@ def test_sample_around_points_generator(self): npt.assert_equal(gen.ub.numpy(), np.array([1, 1])) self.assertEqual(gen.max_asks, len(points * samples_per_point)) self.assertEqual(gen.seed, 123) + self.assertFalse(gen.finished) points = gen.gen(gen.max_asks) for i in range(len(window)): @@ -93,6 +97,8 @@ def test_sample_around_points_generator(self): npt.assert_array_less(np.array([0] * len(points)), points[:, i]) npt.assert_array_less(points[:, i], np.array([1] * len(points))) + self.assertTrue(gen.finished) + if __name__ == "__main__": unittest.main()