From c24bbd918818d4dd4102aec4336fe073cf7fc8b6 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 20 Nov 2024 09:25:21 -0800 Subject: [PATCH] add support for categorical parameters (#449) Summary: Added support for categorical parameters. This required changing some server functions so that they can accept and return strings as part of their messages. Externally, categorical parameters are strings. Internally, categorical parameters are represented as index parameters (0 to nChoices - 1), while they are passed to models as one-hot encoded parameters. The intermediate index representation could be useful in the future for alternative ways to model categorical parameters. Differential Revision: D65779384 --- aepsych/config.py | 19 ++- aepsych/server/server.py | 21 ++- aepsych/transforms/parameters.py | 241 ++++++++++++++++++++++++++++++- tests/test_transforms.py | 202 +++++++++++++++++++++++++- 4 files changed, 468 insertions(+), 15 deletions(-) diff --git a/aepsych/config.py b/aepsych/config.py index e5b0ca18d..1464311b7 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -176,14 +176,19 @@ def update( par_names = self.getlist( "common", "parnames", element_type=str, fallback=[] ) - lb = [None] * len(par_names) - ub = [None] * len(par_names) + lb = [] + ub = [] for i, par_name in enumerate(par_names): # Validate the parameter-specific block self._check_param_settings(par_name) - lb[i] = self[par_name]["lower_bound"] - ub[i] = self[par_name]["upper_bound"] + if self[par_name]["par_type"] == "categorical": + choices = self.getlist(par_name, "choices", element_type=str) + lb.append("0") + ub.append(str(len(choices) - 1)) + else: + lb.append(self[par_name]["lower_bound"]) + ub.append(self[par_name]["upper_bound"]) self["common"]["lb"] = f"[{', '.join(lb)}]" self["common"]["ub"] = f"[{', '.join(ub)}]" @@ -276,6 +281,12 @@ def _check_param_settings(self, param_name: str) -> None: and self.getint(param_name, "upper_bound") % 1 == 0 ): raise ValueError(f"Parameter {param_name} has non-integer bounds.") + elif param_block["par_type"] == "categorical": + # Need a choices array + if "choices" not in param_block: + raise ValueError( + f"Parameter {param_name} is missing the choices setting." + ) else: raise ValueError( f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
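A minimal usage sketch (not part of the patch) of what the config.py change above does: for a categorical parameter, the bounds written into [common] are derived from the choices list as 0 and len(choices) - 1. The parameter names below (color, intensity) are hypothetical.

from aepsych.config import Config

# Hypothetical config with one categorical and one continuous parameter.
config_str = """
[common]
parnames = [color, intensity]
stimuli_per_trial = 1
outcome_types = [binary]
strategy_names = [init_strat]

[color]
par_type = categorical
choices = [red, green, blue]

[intensity]
par_type = continuous
lower_bound = 0
upper_bound = 1

[init_strat]
generator = SobolGenerator
min_asks = 10
"""

config = Config()
config.update(config_str=config_str)  # fills in [common] lb/ub from the parameter blocks

# The categorical contributes bounds 0 and len(choices) - 1 = 2.
print(config["common"]["lb"])  # expected: [0, 0]
print(config["common"]["ub"])  # expected: [2, 1]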
diff --git a/aepsych/server/server.py b/aepsych/server/server.py index fd2d33e3e..df1dfdfce 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -276,22 +276,31 @@ def can_pregen_ask(self): return self.strat is not None and self.enable_pregen def _tensor_to_config(self, next_x): + next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0] config = {} for name, val in zip(self.parnames, next_x): - if val.dim() == 0: + if isinstance(val, str): + config[name] = [val] + elif isinstance(val, (int, float)): config[name] = [float(val)] + elif isinstance(val[0], str): + config[name] = val else: - config[name] = np.array(val) + config[name] = np.array(val, dtype="float64") return config def _config_to_tensor(self, config): unpacked = [config[name] for name in self.parnames] - - # handle config elements being either scalars or length-1 lists if isinstance(unpacked[0], list): - x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1) + x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1) else: - x = torch.tensor(np.stack(unpacked)) + x = np.stack(unpacked, dtype="O") + + # Unsqueeze batch dimension + x = np.expand_dims(x, 0) + + x = self.strat.transforms.str_to_indices(x)[0] + return x def __getstate__(self): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 45e5029c5..66798b617 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -33,6 +33,7 @@ Log10, Normalize, ReversibleInputTransform, + InputTransform, ) from botorch.models.transforms.utils import subset_transform from botorch.posteriors import Posterior @@ -55,6 +56,30 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin): space back into raw space. """ + def __init__( + self, + **transforms: InputTransform, + ) -> None: + self.cat_map_raw = {} + transform_keys = list(transforms.keys()) + for key in transform_keys: + if isinstance(transforms[key], Categorical): + categorical = transforms.pop(key) + self.cat_map_raw.update(categorical.cat_map_raw) + + if len(self.cat_map_raw) > 0: + # Remake the categorical and put it at the end + transforms["_CombinedCategorical"] = Categorical( + indices=list(self.cat_map_raw.keys()), categorical_map=self.cat_map_raw + ) + self.cat_map_transformed = transforms[ + "_CombinedCategorical" + ].cat_map_transformed + else: + self.cat_map_transformed = {} + + super().__init__(**transforms) + def _temporary_reshape(func: Callable) -> Callable: # Decorator to reshape tensors to the expected 2D shape, even if the input was # 1D or 3D and after the transform reshape it back to the original. @@ -102,8 +127,9 @@ def transform_bounds( ) -> Tensor: r"""Transform bounds of a parameter. - Individual transforms are applied in sequence. Then an adjustment is applied to - ensure the bounds are correct. + Individual transforms are applied in sequence. Looks for a specific + transform_bounds method in each transform to apply that, otherwise uses the + normal transform. Args: X: A tensor of inputs. Either `[dim]`, `[batch, dim]`, or `[batch, dim, stimuli]`. @@ -130,6 +156,45 @@ def untransform(self, X: Tensor) -> Tensor: """ return super().untransform(X) + @_temporary_reshape + def indices_to_str(self, X: Tensor) -> np.ndarray: + r"""Return a NumPy array of objects where the categorical parameters will be + strings. + + Args: + X (Tensor): A tensor shaped `[batch, dim]` to turn into a mixed type NumPy + array. + + Returns: + np.ndarray: An array with the objet type where the categorical parameters + are strings. 
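+
+        Example (illustrative): with cat_map_raw = {1: ["red", "green", "blue"]},
+        indices_to_str(torch.tensor([[0.5, 2.0]])) would return
+        np.array([[0.5, "blue"]], dtype="O").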
+ """ + obj_arr = X.cpu().numpy().astype("O") + + for idx, cats in self.cat_map_raw.items(): + obj_arr[:, idx] = [cats[int(i)] for i in obj_arr[:, idx]] + + return obj_arr + + @_temporary_reshape + def str_to_indices(self, obj_arr: np.ndarray) -> Tensor: + r"""Return a Tensor where the categorical parameters are converted from strings + to indices. + + Args: + obj_arr (np.ndarray): A NumPy array `[batch, dim]` where the categorical + parameters are strings. + + Returns: + Tensor: A tensor with the categorical parameters converted to indices. + """ + obj_arr = obj_arr[:] + + for idx, cats in self.cat_map_raw.items(): + obj_arr[:, idx] = [cats.index(cat) for cat in obj_arr[:, idx]] + + return torch.tensor(obj_arr.astype("float64"), dtype=torch.float64) + @classmethod def get_config_options( cls, @@ -185,6 +250,14 @@ def get_config_options( ) transform_dict[f"{par}_Round"] = round + # Categorical variable + elif par_type == "categorical": + categorical = Categorical.from_config( + config=config, name=par, options=transform_options + ) + + transform_dict[f"{par}_Categorical"] = categorical + # Log scale if config.getboolean(par, "log_scale", fallback=False): log10 = Log10Plus.from_config( @@ -197,8 +270,10 @@ def get_config_options( ) transform_dict[f"{par}_Log10Plus"] = log10 - # Normalize scale (defaults true) - if config.getboolean(par, "normalize_scale", fallback=True): + # Normalize scale (defaults true), don't do this for categoricals + if par_type != "categorical" and config.getboolean( + par, "normalize_scale", fallback=True + ): normalize = NormalizeScale.from_config( config=config, name=par, options=transform_options ) @@ -1135,6 +1210,164 @@ def _transform_bounds( return X +class Categorical(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): + is_one_to_many = True + + def __init__( + self, + indices: list[int], + categorical_map: Dict[int, List[str]], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a Categorical transform. Takes the integer at the indices and + converts it to one_hot starting from that indices (and therefore pushing) + forward other indices. + + Args: + indices: The indices of the inputs to turn into categoricals. + categorical_map: A dictionary where the key is the index of the categorical + variable and the values are the possible categories. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: Currently will not do anything, here to conform to + API. + reverse: A boolean indicating whether the forward pass should untransform + the parameters. + **kwargs: Accepted to conform to API. 
+ """ + # indices needs to be sorted + indices = sorted(indices) + + # Multiple categoricals need to shift indices + categorical_offset = 0 + new_indices = [] + cat_map_transformed = {} + for idx in indices: + category_values = categorical_map[idx] + num_classes = len(categorical_map[idx]) + new_idx = idx + categorical_offset + categorical_offset += num_classes - 1 + + new_indices.append(new_idx) + cat_map_transformed[new_idx] = category_values + + super().__init__() + self.register_buffer("indices", torch.tensor(new_indices, dtype=torch.long)) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + self.cat_map_raw = categorical_map + self.cat_map_transformed = cat_map_transformed + + def _transform(self, X: torch.Tensor) -> torch.Tensor: + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + def _untransform(self, X: torch.Tensor) -> torch.Tensor: + for idx in reversed(self.indices): + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Chop up X around the one_hot + pre_categorical = X[:, :idx] + one_hot = X[:, idx : idx + num_classes] + post_categorical = X[:, idx + num_classes :] + + # Turn one_hot back into indices + idxs = torch.argmax(one_hot, dim=1).unsqueeze(-1).to(X) + + X = torch.cat((pre_categorical, idxs, post_categorical), dim=1) + + return X + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + if bound == "lb": + one_hot[:] = 0.0 + elif bound == "ub": + one_hot[:] = 1.0 + else: # Both bounds + one_hot[0, :] = 0.0 + one_hot[1, :] = 1.0 + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Return a dictionary of the relevant options to initialize the categorical + transform from the config for the named transform. + + Args: + config (Config): Config to look for options in. + name (str, optional): The parameter to find options for. + options (Dict[str, Any], optional): Options to override from the config, + defaults to None. + + Return: + Dict[str, Any]: A dictionary of options to initialize this class. 
+ """ + options = _get_parameter_options(config, name, options) + + if "categorical_map" not in options: + if name is None: + raise ValueError("name argument must be set to initialize from config.") + + options["categorical_map"] = { + options["indices"][0]: config.getlist(name, "choices", element_type=str) + } + + return options + + def transform_options( config: Config, transforms: Optional[ChainedInputTransform] = None ) -> Config: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 599df2595..4128f9e75 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -4,20 +4,26 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest +import uuid import numpy as np import torch +from aepsych import server, utils_logging from aepsych.config import Config from aepsych.generators import SobolGenerator from aepsych.models import GPClassificationModel +from aepsych.server.message_handlers.handle_ask import ask +from aepsych.server.message_handlers.handle_setup import configure +from aepsych.server.message_handlers.handle_tell import tell from aepsych.strategy import SequentialStrategy from aepsych.transforms import ( ParameterTransformedGenerator, ParameterTransformedModel, ParameterTransforms, ) -from aepsych.transforms.parameters import Log10Plus, NormalizeScale +from aepsych.transforms.parameters import Categorical, Log10Plus, NormalizeScale class TransformsConfigTest(unittest.TestCase): @@ -398,3 +404,197 @@ def test_integer_model(self): est_max = x[np.argmin((zhat - target) ** 2)] diff = np.abs(est_max / 100 - target) self.assertTrue(diff < 0.15, f"Diff = {diff}") + + +class TransformCategorical(unittest.TestCase): + def test_categorical_model(self): + np.random.seed(1) + torch.manual_seed(1) + + n_init = 50 + n_opt = 1 + target = 0.75 + config_str = f""" + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + target = {target} + + [signal1] + par_type = categorical + choices = [red, green, blue] + + [signal2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = {n_init} + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = {n_opt} + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + transforms = strat.transforms + while not strat.finished: + points = strat.gen() + points = transforms.indices_to_str(points) + + if points[0][0] == "blue": + response = int(np.random.rand() < points[0][1]) + else: + response = 0 + + strat.add_data(transforms.str_to_indices(points), response) + + _, loc = strat.model.get_max() + loc = transforms.indices_to_str(loc)[0] + + self.assertTrue(loc[0] == "blue") + self.assertTrue(loc[1] - target < 0.15) + + def test_standalone_transform(self): + categorical_map = {1: ["red", "green", "blue"], 3: ["big", "small"]} + input = torch.tensor([[0.2, 2, 4, 0, 1], [0.5, 0, 3, 0, 1], [0.9, 1, 0, 1, 0]]) + input_cats = np.array( + [ + [0.2, "blue", 4, "big", "right"], + [0.5, "red", 3, "big", "right"], + [0.9, "green", 0, "small", "left"], + ], + dtype="O", + ) + + transforms = ParameterTransforms( + categorical1=Categorical([1, 3], categorical_map=categorical_map), + categorical2=Categorical([4], categorical_map={4: ["left", "right"]}), + ) + + 
self.assertTrue("_CombinedCategorical" in list(transforms.keys())) + self.assertTrue("categorical1" not in list(transforms.keys())) + + transformed = transforms.transform(input) + untransformed = transforms.untransform(transformed) + + self.assertTrue(torch.equal(input, untransformed)) + + strings = transforms.indices_to_str(input) + self.assertTrue(np.all(input_cats == strings)) + + indices = transforms.str_to_indices(input_cats) + self.assertTrue(torch.all(indices == input)) + + +class TransformServer(unittest.TestCase): + def setUp(self): + # setup logger + server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + # random datebase path name without dashes + database_path = "./{}.db".format(str(uuid.uuid4().hex)) + self.s = server.AEPsychServer(database_path=database_path) + + def tearDown(self): + self.s.cleanup() + + # cleanup the db + if self.s.db is not None: + self.s.db.delete_db() + + def test_categorical_smoketest(self): + server = self.s + config_str = f""" + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + target = 0.75 + + [signal1] + par_type = categorical + choices = [red, green, blue] + + [signal2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = 1 + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = 1 + """ + configure( + server, + config_str=config_str, + ) + + for _ in range(2): + next_config = ask(server) + + self.assertTrue(isinstance(next_config["signal1"][0], str)) + + tell(server, config=next_config, outcome=0) + + def test_pairwise_categorical(self): + server = self.s + config_str = """ + [common] + stimuli_per_trial=2 + outcome_types=[binary] + parnames = [x, y, z] + strategy_names = [init_strat, opt_strat] + + [x] + par_type = continuous + lower_bound = 1 + upper_bound = 4 + normalize_scale = False + + [y] + par_type = categorical + choices = [red, green, blue] + + [z] + par_type = discrete + lower_bound = 1 + upper_bound = 1000 + log_scale = True + + [init_strat] + min_asks = 1 + generator = SobolGenerator + + [opt_strat] + model = PairwiseProbitModel + min_asks = 1 + generator = OptimizeAcqfGenerator + acqf = PairwiseMCPosteriorVariance + + [PairwiseProbitModel] + mean_covar_factory = default_mean_covar_factory + + [PairwiseMCPosteriorVariance] + objective = ProbitObjective + """ + configure(server, config_str=config_str) + + for _ in range(2): + next_config = ask(server) + self.assertTrue(all([isinstance(val, str) for val in next_config["y"]])) + tell(server, config=next_config, outcome=0)