From b591b5f13474abf759431ed66f5a177436b3f806 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 20 Nov 2024 09:25:21 -0800 Subject: [PATCH 1/3] Add support for discrete parameters (#445) Summary: Discrete parameter support added via rounding transform. Reviewed By: crasanders Differential Revision: D65699942 --- aepsych/config.py | 16 ++++ aepsych/transforms/parameters.py | 159 ++++++++++++++++++++++++++++++- docs/parameters.md | 14 +++ tests/test_transforms.py | 88 +++++++++++++++++ 4 files changed, 272 insertions(+), 5 deletions(-) diff --git a/aepsych/config.py b/aepsych/config.py index 4753d60ba..e5b0ca18d 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -260,6 +260,22 @@ def _check_param_settings(self, param_name: str) -> None: raise ValueError( f"Parameter {param_name} is missing the upper_bound setting." ) + elif param_block["par_type"] == "integer": + # Check if bounds exist and actaully integers + if "lower_bound" not in param_block: + raise ValueError( + f"Parameter {param_name} is missing the lower_bound setting." + ) + if "upper_bound" not in param_block: + raise ValueError( + f"Parameter {param_name} is missing the upper_bound setting." + ) + + if not ( + self.getint(param_name, "lower_bound") % 1 == 0 + and self.getint(param_name, "upper_bound") % 1 == 0 + ): + raise ValueError(f"Parameter {param_name} has non-integer bounds.") else: raise ValueError( f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}." diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 066e90e3c..8d1fd75a5 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -9,7 +9,18 @@ from abc import ABC from configparser import NoOptionError from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) import numpy as np import torch @@ -17,7 +28,12 @@ from aepsych.generators.base import AEPsychGenerator from aepsych.models.base import AEPsychMixin, ModelProtocol from botorch.acquisition import AcquisitionFunction -from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize +from botorch.models.transforms.input import ( + ChainedInputTransform, + Log10, + Normalize, + ReversibleInputTransform, +) from botorch.models.transforms.utils import subset_transform from botorch.posteriors import Posterior from torch import Tensor @@ -42,7 +58,7 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin): def _temporary_reshape(func: Callable) -> Callable: # Decorator to reshape tensors to the expected 2D shape, even if the input was # 1D or 3D and after the transform reshape it back to the original. - def wrapper(self, X: Tensor) -> Tensor: + def wrapper(self, X: Tensor, **kwargs) -> Tensor: squeeze = False if len(X.shape) == 1: # For 1D inputs, primarily for transforming arguments X = X.unsqueeze(0) @@ -54,7 +70,7 @@ def wrapper(self, X: Tensor) -> Tensor: X = X.swapaxes(-2, -1).reshape(-1, dim) reshape = True - X = func(self, X) + X = func(self, X, **kwargs) if reshape: X = X.reshape(batch, stim, -1).swapaxes(-1, -2) @@ -80,6 +96,38 @@ def transform(self, X: Tensor) -> Tensor: """ return super().transform(X) + @_temporary_reshape + def transform_bounds( + self, X: Tensor, bound: Optional[Literal["lb", "ub"]] = None + ) -> Tensor: + r"""Transform bounds of a parameter. 
+ + Individual transforms are applied in sequence. Then an adjustment is applied to + ensure the bounds are correct. + + Args: + X: A tensor of inputs. Either `[dim]`, `[batch, dim]`, or `[batch, dim, stimuli]`. + + Returns: + A tensor of transformed inputs with the same shape as the input. + """ + for tf in self.values(): + # This is the entire reason this method exists to help handle the + # continuous relaxation necessary for discrete parameters. But this is + # super awkward. + if isinstance(tf, Round): + if bound == "lb": + X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices)) + elif bound == "ub": + X[0, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices)) + else: # Both bounds + X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices)) + X[1, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices)) + else: + X = tf.forward(X) + + return X + @_temporary_reshape def untransform(self, X: Tensor) -> Tensor: r"""Un-transform the inputs to a model. @@ -132,6 +180,22 @@ def get_config_options( for par in parnames: # This is the order that transforms are potentially applied, order matters + try: + par_type = config[par]["par_type"] + except KeyError: # Probably because par doesn't have its own section + par_type = "continuous" + + # Integer variable + if par_type == "integer": + round = Round.from_config( + config=config, name=par, options=transform_options + ) + + # Nudge bounds + transform_options["bounds"][0, round.indices] -= 0.5 + transform_options["bounds"][1, round.indices] += 0.5 - 1e-6 + transform_dict[f"{par}_Round"] = round + # Log scale if config.getboolean(par, "log_scale", fallback=False): log10 = Log10Plus.from_config( @@ -914,6 +978,88 @@ def get_config_options( return options +class Round(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): + def __init__( + self, + indices: list[int], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a round transform. This operation rounds the inputs at the indices + in both direction. + + Args: + indices: The indices of the inputs to round. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: Currently will not do anything, here to conform to + API. + reverse: Whether to round in forward or backward passes. + **kwargs: Accepted to conform to API. + """ + super().__init__() + self.register_buffer("indices", torch.tensor(indices, dtype=torch.long)) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + + @subset_transform + def _transform(self, X: torch.Tensor) -> torch.Tensor: + r"""Round the inputs to a model to be discrete. This rounding is the same both + in the forward and the backward pass. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + torch.Tensor: The input tensor with values rounded. + """ + return X.round() + + @subset_transform + def _untransform(self, X: Tensor) -> Tensor: + r"""Round the inputs to a model to be discrete. This rounding is the same both + in the forward and the backward pass. + + Args: + X (torch.Tensor): A `batch_shape x n x d`-dim tensor of transformed inputs. 
+ + Returns: + torch.Tensor: The input tensor with values rounded. + """ + return X.round() + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Return a dictionary of the relevant options to initialize the Round transform + from the config for the named transform. + + Args: + config (Config): Config to look for options in. + name (str, optional): The parameter to find options for. + options (Dict[str, Any], optional): Options to override from the config, + defaults to None. + + Return: + Dict[str, Any]: A dictionary of options to initialize this class. + """ + options = _get_parameter_options(config, name, options) + + return options + + def transform_options( config: Config, transforms: Optional[ChainedInputTransform] = None ) -> Config: @@ -938,7 +1084,10 @@ def transform_options( value = np.array(value, dtype=float) value = torch.tensor(value).to(torch.float64) - value = transforms.transform(value) + if option in ["ub", "lb"]: + value = transforms.transform_bounds(value, bound=option) + else: + value = transforms.transform(value) def _arr_to_list(iter): if hasattr(iter, "__iter__"): diff --git a/docs/parameters.md b/docs/parameters.md index ef7dfaf14..f93526b66 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -27,6 +27,19 @@ parameters can have any non-infinite ranges. This means that continuous paramete include negative values (e.g., lower bound = -1, upper bound = 1) or have very large ranges (e.g., lower bound = 0, upper bound = 1,000,000). +

+<h3>Integer</h3>
+
+```ini
+[parameter]
+par_type = integer
+lower_bound = -5
+upper_bound = 5
+```
+
+Integer parameters are similar to continuous parameters in that they can take any
+non-infinite range and require bounds. However, integer parameters use a continuous
+relaxation under the hood so that models and generators can handle integer
+inputs/outputs. For example, an integer parameter could represent the number of lights
+that are on in a detection threshold experiment.
+
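+Under the hood, the relaxation works roughly as in the following sketch (a simplified,
+hypothetical illustration; the variable names are not part of the API):
+
+```python
+import torch
+
+# Integer bounds from the config, e.g. lower_bound = -5, upper_bound = 5
+lb, ub = torch.tensor([-5.0]), torch.tensor([5.0])
+
+# The bounds are widened so that each integer owns an equal slice of continuous space
+relaxed_lb = lb - 0.5
+relaxed_ub = ub + 0.5 - 1e-6
+
+# Models and generators operate on the relaxed, continuous interval...
+candidate = relaxed_lb + torch.rand(1) * (relaxed_ub - relaxed_lb)
+
+# ...and candidate points are rounded back to integers before they reach the experiment
+suggestion = candidate.round()
+```
+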

 <h2>Parameter Transformations</h2>
Currently, we only support a log scale transformation to parameters. More parameter transformations to come! In general, you can define your parameters in the raw @@ -83,5 +96,6 @@ of operation, regardless of how the options were set in the config file. Each pa is transformed entirely separately. Currently, the order is as follows: +* Rounding for integer parameters (rounding is applied in both directions) * Log scale * Normalize scale \ No newline at end of file diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e40c5e489..599df2595 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -310,3 +310,91 @@ def test_normalize_scale(self): self.assertTrue(torch.allclose(transformed, expected)) self.assertTrue(torch.allclose(transforms.untransform(transformed), values)) + + +class TransformInteger(unittest.TestCase): + def test_integer_bounds(self): + config_str = """ + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat] + + [signal1] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [signal2] + par_type = integer + lower_bound = 1 + upper_bound = 5 + + [init_strat] + generator = SobolGenerator + min_asks = 1 + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + points = strat.gen()[0] + + self.assertTrue((points[0] % 1).item() != 0.0) + self.assertTrue((points[1] % 1).item() == 0.0) + self.assertTrue(torch.all(strat._strat.generator.lb == 0)) + self.assertTrue(torch.all(strat._strat.generator.ub == 1)) + + def test_integer_model(self): + np.random.seed(1) + torch.manual_seed(1) + + lower_bound = 1 + upper_bound = 100 + target = 0.75 + + config_str = f""" + [common] + parnames = [signal1] + stimuli_per_trial = 1 + outcome_types = [binary] + target = {target} + strategy_names = [init_strat, opt_strat] + + [signal1] + par_type = integer + lower_bound = {lower_bound} + upper_bound = {upper_bound} + + [init_strat] + generator = SobolGenerator + min_total_tells = 50 + + [SobolGenerator] + seed = 1 + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_total_tells = 1 + """ + + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + + while not strat.finished: + next_x = strat.gen() + self.assertTrue((next_x % 1).item() == 0.0) + response = int(np.random.rand() < (next_x / 100)) + strat.add_data(next_x, [response]) + + x = torch.linspace(lower_bound, upper_bound, 100) + + zhat, _ = strat.predict(x) + est_max = x[np.argmin((zhat - target) ** 2)] + diff = np.abs(est_max / 100 - target) + self.assertTrue(diff < 0.15, f"Diff = {diff}") From 399bd641cf3f8961ff013be4d7fd41c913f58841 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 20 Nov 2024 09:25:21 -0800 Subject: [PATCH 2/3] Make a transform base class for consistent API (#452) Summary: Transforming bounds requires additional logic that used to be part of ParameterTransforms, we move these to the parameters itself and have the ParameterTransforms look for these special methods when transforming bounds. We add a new ABC for our transforms as going forward it is likely that all of our transforms will have unique capabilities over the BoTorch base. This includes how we handle some finding options from configs. 
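In condensed form, the pattern this diff introduces looks roughly like the sketch below
(toy classes for illustration only; the real classes take more arguments and register
buffers):

```python
from abc import ABC

import torch


class Transform(ABC):
    """Base behaviour: bounds transform exactly like any other input."""

    def transform_bounds(self, X: torch.Tensor, bound=None, **kwargs) -> torch.Tensor:
        return self.transform(X)


class Scale(Transform):
    """Toy transform with no special bounds handling."""

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        return X * 2


class Round(Transform):
    """Toy stand-in for the rounding transform, which widens the bounds instead."""

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        return X.round()

    def transform_bounds(self, X: torch.Tensor, bound=None, epsilon=1e-6) -> torch.Tensor:
        X = X.clone()
        X[0] -= 0.5
        X[1] += 0.5 - epsilon
        return X


# A chained container only needs to defer to each transform's own bounds logic.
bounds = torch.tensor([[1.0], [5.0]])
for tf in (Scale(), Round()):
    bounds = tf.transform_bounds(bounds)
print(bounds)  # roughly tensor([[1.5], [10.5]])
```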
Reviewed By: crasanders Differential Revision: D65897908 --- aepsych/transforms/parameters.py | 187 +++++++++++++++++++------------ 1 file changed, 115 insertions(+), 72 deletions(-) diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 8d1fd75a5..45e5029c5 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -112,19 +112,7 @@ def transform_bounds( A tensor of transformed inputs with the same shape as the input. """ for tf in self.values(): - # This is the entire reason this method exists to help handle the - # continuous relaxation necessary for discrete parameters. But this is - # super awkward. - if isinstance(tf, Round): - if bound == "lb": - X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices)) - elif bound == "ub": - X[0, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices)) - else: # Both bounds - X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices)) - X[1, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices)) - else: - X = tf.forward(X) + X = tf.transform_bounds(X, bound=bound) return X @@ -191,9 +179,10 @@ def get_config_options( config=config, name=par, options=transform_options ) - # Nudge bounds - transform_options["bounds"][0, round.indices] -= 0.5 - transform_options["bounds"][1, round.indices] += 0.5 - 1e-6 + # Transform bounds + transform_options["bounds"] = round.transform_bounds( + transform_options["bounds"] + ) transform_dict[f"{par}_Round"] = round # Log scale @@ -784,7 +773,67 @@ def get_config_options( return options -class Log10Plus(Log10, ConfigurableMixin): +class Transform(ReversibleInputTransform, ConfigurableMixin, ABC): + """Base class for individual transforms. These transforms are intended to be stacked + together using the ParameterTransforms class. + """ + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): Which bound this is to transform, if + None, it's the `[2, dim]` form with both bounds stacked. + **kwargs: Keyword arguments for specific transforms, they should have + default values. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + return self.transform(X) + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Return a dictionary of the relevant options to initialize a Log10Plus + transform for the named parameter within the config. + + Args: + config (Config): Config to look for options in. + name (str): Parameter to find options for. + options (Dict[str, Any]): Options to override from the config. + + Returns: + Dict[str, Any]: A diciontary of options to initialize this class with, + including the transformed bounds. 
+ """ + if name is None: + raise ValueError(f"{name} must be set to initialize a transform.") + + if options is None: + options = {} + else: + options = deepcopy(options) + + # Figure out the index of this parameter + parnames = config.getlist("common", "parnames", element_type=str) + idx = parnames.index(name) + + if "indices" not in options: + options["indices"] = [idx] + + return options + + +class Log10Plus(Log10, Transform): """Base-10 log transform that we add a constant to the values""" def __init__( @@ -867,7 +916,7 @@ def get_config_options( Dict[str, Any]: A diciontary of options to initialize this class with, including the transformed bounds. """ - options = _get_parameter_options(config, name, options) + options = super().get_config_options(config=config, name=name, options=options) # Make sure we have bounds ready if "bounds" not in options: @@ -887,7 +936,7 @@ def get_config_options( return options -class NormalizeScale(Normalize, ConfigurableMixin): +class NormalizeScale(Normalize, Transform): def __init__( self, d: int, @@ -965,20 +1014,19 @@ def get_config_options( Dict[str, Any]: A diciontary of options to initialize this class with, including the transformed bounds. """ - options = _get_parameter_options(config, name, options) + options = super().get_config_options(config=config, name=name, options=options) # Make sure we have bounds ready if "bounds" not in options: options["bounds"] = get_bounds(config) if "d" not in options: - parnames = config.getlist("common", "parnames", element_type=str) - options["d"] = len(parnames) + options["d"] = options["bounds"].shape[1] return options -class Round(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): +class Round(Transform, torch.nn.Module): def __init__( self, indices: list[int], @@ -1035,29 +1083,56 @@ def _untransform(self, X: Tensor) -> Tensor: """ return X.round() - @classmethod - def get_config_options( - cls, - config: Config, - name: Optional[str] = None, - options: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + **kwargs: passed to _transform_bounds + epsilon: will modify the offset for the rounding to ensure each discrete + value has equal space in the parameter space. + + Returns: + torch.Tensor: A transformed set of parameter bounds. """ - Return a dictionary of the relevant options to initialize the Round transform - from the config for the named transform. + epsilon = kwargs.get("epsilon", 1e-6) + return self._transform_bounds(X, bound=bound, epsilon=epsilon) + + def _transform_bounds( + self, + X: torch.Tensor, + bound: Optional[Literal["lb", "ub"]] = None, + epsilon: float = 1e-6, + ) -> torch.Tensor: + r"""Return the bounds X transformed. Args: - config (Config): Config to look for options in. - name (str, optional): The parameter to find options for. - options (Dict[str, Any], optional): Options to override from the config, - defaults to None. + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. 
+ epsilon: + **kwargs: other kwargs - Return: - Dict[str, Any]: A dictionary of options to initialize this class. + Returns: + torch.Tensor: A transformed set of parameter bounds. """ - options = _get_parameter_options(config, name, options) + X = X.clone() - return options + if bound == "lb": + X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) + elif bound == "ub": + X[0, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices)) + else: # Both bounds + X[0, self.indices] -= torch.tensor([0.5] * len(self.indices)) + X[1, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices)) + + return X def transform_options( @@ -1131,35 +1206,3 @@ def get_bounds(config: Config) -> torch.Tensor: bounds = torch.stack((_lb, _ub)) return bounds - - -def _get_parameter_options( - config: Config, name: Optional[str] = None, options: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - """Return options for a parameter in a config. - - Args: - config (Config): Config to search for parameter. - name (str): Name of parameter. - options (Dict[str, Any], optional): dictionary of options to overwrite config - options, defaults to an empty dictionary. - - Returns: - Dict[str, Any]: Dictionary of options to initialize a transform from config. - """ - if name is None: - raise ValueError(f"{name} must be set to initialize a transform.") - - if options is None: - options = {} - else: - options = deepcopy(options) - - # Figure out the index of this parameter - parnames = config.getlist("common", "parnames", element_type=str) - idx = parnames.index(name) - - if "indices" not in options: - options["indices"] = [idx] - - return options From c24bbd918818d4dd4102aec4336fe073cf7fc8b6 Mon Sep 17 00:00:00 2001 From: Jason Chow Date: Wed, 20 Nov 2024 09:25:21 -0800 Subject: [PATCH 3/3] add support for categorical parameters (#449) Summary: Added support for categorical parameters. This required changing some server functions to be able to take and give strings as a part of the responses. Externally, categorical parameters are strings. Internally, categorical parameters are represented as index parameters (0, nChoices -1) while they are passed to models as one_hot vectored parameters. The index intermediate could be useful in the future for alternative ways to model categorical parameters. 
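As a rough sketch of the three representations (the names below are illustrative, not
the actual API), a parameter with choices `[red, green, blue]` moves through them like
so:

```python
import torch

choices = ["red", "green", "blue"]

# External representation (client <-> server): the raw string
raw = "blue"

# Internal representation (strategy/transform space): the index into the choices
index = torch.tensor([choices.index(raw)])  # tensor([2])

# Model representation: the index expanded to a one-hot vector
one_hot = torch.nn.functional.one_hot(index, num_classes=len(choices)).float()
# tensor([[0., 0., 1.]])

# Untransforming: argmax recovers the index, which maps back to the string
recovered = choices[int(one_hot.argmax(dim=-1))]  # "blue"
```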
Differential Revision: D65779384 --- aepsych/config.py | 19 ++- aepsych/server/server.py | 21 ++- aepsych/transforms/parameters.py | 241 ++++++++++++++++++++++++++++++- tests/test_transforms.py | 202 +++++++++++++++++++++++++- 4 files changed, 468 insertions(+), 15 deletions(-) diff --git a/aepsych/config.py b/aepsych/config.py index e5b0ca18d..1464311b7 100644 --- a/aepsych/config.py +++ b/aepsych/config.py @@ -176,14 +176,19 @@ def update( par_names = self.getlist( "common", "parnames", element_type=str, fallback=[] ) - lb = [None] * len(par_names) - ub = [None] * len(par_names) + lb = [] + ub = [] for i, par_name in enumerate(par_names): # Validate the parameter-specific block self._check_param_settings(par_name) - lb[i] = self[par_name]["lower_bound"] - ub[i] = self[par_name]["upper_bound"] + if self[par_name]["par_type"] == "categorical": + choices = self.getlist(par_name, "choices", element_type=str) + lb.append("0") + ub.append(str(len(choices) - 1)) + else: + lb.append(self[par_name]["lower_bound"]) + ub.append(self[par_name]["upper_bound"]) self["common"]["lb"] = f"[{', '.join(lb)}]" self["common"]["ub"] = f"[{', '.join(ub)}]" @@ -276,6 +281,12 @@ def _check_param_settings(self, param_name: str) -> None: and self.getint(param_name, "upper_bound") % 1 == 0 ): raise ValueError(f"Parameter {param_name} has non-integer bounds.") + elif param_block["par_type"] == "categorical": + # Need a choices array + if "choices" not in param_block: + raise ValueError( + f"Parameter {param_name} is missing the choices setting." + ) else: raise ValueError( f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}." diff --git a/aepsych/server/server.py b/aepsych/server/server.py index fd2d33e3e..df1dfdfce 100644 --- a/aepsych/server/server.py +++ b/aepsych/server/server.py @@ -276,22 +276,31 @@ def can_pregen_ask(self): return self.strat is not None and self.enable_pregen def _tensor_to_config(self, next_x): + next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0] config = {} for name, val in zip(self.parnames, next_x): - if val.dim() == 0: + if isinstance(val, str): + config[name] = [val] + elif isinstance(val, (int, float)): config[name] = [float(val)] + elif isinstance(val[0], str): + config[name] = val else: - config[name] = np.array(val) + config[name] = np.array(val, dtype="float64") return config def _config_to_tensor(self, config): unpacked = [config[name] for name in self.parnames] - - # handle config elements being either scalars or length-1 lists if isinstance(unpacked[0], list): - x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1) + x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1) else: - x = torch.tensor(np.stack(unpacked)) + x = np.stack(unpacked, dtype="O") + + # Unsqueeze batch dimension + x = np.expand_dims(x, 0) + + x = self.strat.transforms.str_to_indices(x)[0] + return x def __getstate__(self): diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index 45e5029c5..66798b617 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -33,6 +33,7 @@ Log10, Normalize, ReversibleInputTransform, + InputTransform, ) from botorch.models.transforms.utils import subset_transform from botorch.posteriors import Posterior @@ -55,6 +56,30 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin): space back into raw space. 
""" + def __init__( + self, + **transforms: InputTransform, + ) -> None: + self.cat_map_raw = {} + transform_keys = list(transforms.keys()) + for key in transform_keys: + if isinstance(transforms[key], Categorical): + categorical = transforms.pop(key) + self.cat_map_raw.update(categorical.cat_map_raw) + + if len(self.cat_map_raw) > 0: + # Remake the categorical and put it at the end + transforms["_CombinedCategorical"] = Categorical( + indices=list(self.cat_map_raw.keys()), categorical_map=self.cat_map_raw + ) + self.cat_map_transformed = transforms[ + "_CombinedCategorical" + ].cat_map_transformed + else: + self.cat_map_transformed = {} + + super().__init__(**transforms) + def _temporary_reshape(func: Callable) -> Callable: # Decorator to reshape tensors to the expected 2D shape, even if the input was # 1D or 3D and after the transform reshape it back to the original. @@ -102,8 +127,9 @@ def transform_bounds( ) -> Tensor: r"""Transform bounds of a parameter. - Individual transforms are applied in sequence. Then an adjustment is applied to - ensure the bounds are correct. + Individual transforms are applied in sequence. Looks for a specific + transform_bounds method in each transform to apply that, otherwise uses the + normal transform. Args: X: A tensor of inputs. Either `[dim]`, `[batch, dim]`, or `[batch, dim, stimuli]`. @@ -130,6 +156,45 @@ def untransform(self, X: Tensor) -> Tensor: """ return super().untransform(X) + @_temporary_reshape + def indices_to_str(self, X: Tensor) -> np.ndarray: + r"""Return a NumPy array of objects where the categorical parameters will be + strings. + + Args: + X (Tensor): A tensor shaped `[batch, dim]` to turn into a mixed type NumPy + array. + + Returns: + np.ndarray: An array with the objet type where the categorical parameters + are strings. + """ + obj_arr = X.cpu().numpy().astype("O") + + for idx, cats in self.cat_map_raw.items(): + obj_arr[:, idx] = [cats[int(i)] for i in obj_arr[:, idx]] + + return obj_arr + + @_temporary_reshape + def str_to_indices(self, obj_arr: np.ndarray) -> Tensor: + r"""Return a Tensor where the categorical parameters are converted from strings + to indices. + + Args: + obj_arr (np.ndarray): A NumPy array `[batch, dim]` where the categorical + parameters are strings. + + Returns: + Tensor: A tensor with the categorical parameters converted to indices. 
+ """ + obj_arr = obj_arr[:] + + for idx, cats in self.cat_map_raw.items(): + obj_arr[:, idx] = [cats.index(cat) for cat in obj_arr[:, idx]] + + return torch.tensor(obj_arr.astype("float64"), dtype=torch.float64) + @classmethod def get_config_options( cls, @@ -185,6 +250,14 @@ def get_config_options( ) transform_dict[f"{par}_Round"] = round + # Categorical variable + elif par_type == "categorical": + categorical = Categorical.from_config( + config=config, name=par, options=transform_options + ) + + transform_dict[f"{par}_Categorical"] = categorical + # Log scale if config.getboolean(par, "log_scale", fallback=False): log10 = Log10Plus.from_config( @@ -197,8 +270,10 @@ def get_config_options( ) transform_dict[f"{par}_Log10Plus"] = log10 - # Normalize scale (defaults true) - if config.getboolean(par, "normalize_scale", fallback=True): + # Normalize scale (defaults true), don't do this for categoricals + if par_type != "categorical" and config.getboolean( + par, "normalize_scale", fallback=True + ): normalize = NormalizeScale.from_config( config=config, name=par, options=transform_options ) @@ -1135,6 +1210,164 @@ def _transform_bounds( return X +class Categorical(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin): + is_one_to_many = True + + def __init__( + self, + indices: list[int], + categorical_map: Dict[int, List[str]], + transform_on_train: bool = True, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + reverse: bool = False, + **kwargs, + ) -> None: + """Initialize a Categorical transform. Takes the integer at the indices and + converts it to one_hot starting from that indices (and therefore pushing) + forward other indices. + + Args: + indices: The indices of the inputs to turn into categoricals. + categorical_map: A dictionary where the key is the index of the categorical + variable and the values are the possible categories. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: Currently will not do anything, here to conform to + API. + reverse: A boolean indicating whether the forward pass should untransform + the parameters. + **kwargs: Accepted to conform to API. 
+ """ + # indices needs to be sorted + indices = sorted(indices) + + # Multiple categoricals need to shift indices + categorical_offset = 0 + new_indices = [] + cat_map_transformed = {} + for idx in indices: + category_values = categorical_map[idx] + num_classes = len(categorical_map[idx]) + new_idx = idx + categorical_offset + categorical_offset += num_classes - 1 + + new_indices.append(new_idx) + cat_map_transformed[new_idx] = category_values + + super().__init__() + self.register_buffer("indices", torch.tensor(new_indices, dtype=torch.long)) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + self.reverse = reverse + self.cat_map_raw = categorical_map + self.cat_map_transformed = cat_map_transformed + + def _transform(self, X: torch.Tensor) -> torch.Tensor: + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + def _untransform(self, X: torch.Tensor) -> torch.Tensor: + for idx in reversed(self.indices): + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Chop up X around the one_hot + pre_categorical = X[:, :idx] + one_hot = X[:, idx : idx + num_classes] + post_categorical = X[:, idx + num_classes :] + + # Turn one_hot back into indices + idxs = torch.argmax(one_hot, dim=1).unsqueeze(-1).to(X) + + X = torch.cat((pre_categorical, idxs, post_categorical), dim=1) + + return X + + def transform_bounds( + self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None + ) -> torch.Tensor: + r"""Return the bounds X transformed. + + Args: + X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter + bounds. + bound (Literal["lb", "ub"], optional): The bound that this is, if None, we + will assume the input is both bounds with a `[2, dim]` X. + + Returns: + torch.Tensor: A transformed set of parameter bounds. + """ + for idx in self.indices: + num_classes = len(self.cat_map_transformed[idx.item()]) + + # Turns indices into one hot + idxs = X[:, idx].to(torch.long) + one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes) + one_hot = one_hot.view(X.shape[0], num_classes) + + if bound == "lb": + one_hot[:] = 0.0 + elif bound == "ub": + one_hot[:] = 1.0 + else: # Both bounds + one_hot[0, :] = 0.0 + one_hot[1, :] = 1.0 + + # Chop up X and stick one_hot in + pre_categorical = X[:, :idx] + post_categorical = X[:, idx + 1 :] + X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1) + + return X + + @classmethod + def get_config_options( + cls, + config: Config, + name: Optional[str] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Return a dictionary of the relevant options to initialize the categorical + transform from the config for the named transform. + + Args: + config (Config): Config to look for options in. + name (str, optional): The parameter to find options for. + options (Dict[str, Any], optional): Options to override from the config, + defaults to None. + + Return: + Dict[str, Any]: A dictionary of options to initialize this class. 
+ """ + options = _get_parameter_options(config, name, options) + + if "categorical_map" not in options: + if name is None: + raise ValueError("name argument must be set to initialize from config.") + + options["categorical_map"] = { + options["indices"][0]: config.getlist(name, "choices", element_type=str) + } + + return options + + def transform_options( config: Config, transforms: Optional[ChainedInputTransform] = None ) -> Config: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 599df2595..4128f9e75 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -4,20 +4,26 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import logging import unittest +import uuid import numpy as np import torch +from aepsych import server, utils_logging from aepsych.config import Config from aepsych.generators import SobolGenerator from aepsych.models import GPClassificationModel +from aepsych.server.message_handlers.handle_ask import ask +from aepsych.server.message_handlers.handle_setup import configure +from aepsych.server.message_handlers.handle_tell import tell from aepsych.strategy import SequentialStrategy from aepsych.transforms import ( ParameterTransformedGenerator, ParameterTransformedModel, ParameterTransforms, ) -from aepsych.transforms.parameters import Log10Plus, NormalizeScale +from aepsych.transforms.parameters import Categorical, Log10Plus, NormalizeScale class TransformsConfigTest(unittest.TestCase): @@ -398,3 +404,197 @@ def test_integer_model(self): est_max = x[np.argmin((zhat - target) ** 2)] diff = np.abs(est_max / 100 - target) self.assertTrue(diff < 0.15, f"Diff = {diff}") + + +class TransformCategorical(unittest.TestCase): + def test_categorical_model(self): + np.random.seed(1) + torch.manual_seed(1) + + n_init = 50 + n_opt = 1 + target = 0.75 + config_str = f""" + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + target = {target} + + [signal1] + par_type = categorical + choices = [red, green, blue] + + [signal2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = {n_init} + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = {n_opt} + """ + config = Config() + config.update(config_str=config_str) + + strat = SequentialStrategy.from_config(config) + transforms = strat.transforms + while not strat.finished: + points = strat.gen() + points = transforms.indices_to_str(points) + + if points[0][0] == "blue": + response = int(np.random.rand() < points[0][1]) + else: + response = 0 + + strat.add_data(transforms.str_to_indices(points), response) + + _, loc = strat.model.get_max() + loc = transforms.indices_to_str(loc)[0] + + self.assertTrue(loc[0] == "blue") + self.assertTrue(loc[1] - target < 0.15) + + def test_standalone_transform(self): + categorical_map = {1: ["red", "green", "blue"], 3: ["big", "small"]} + input = torch.tensor([[0.2, 2, 4, 0, 1], [0.5, 0, 3, 0, 1], [0.9, 1, 0, 1, 0]]) + input_cats = np.array( + [ + [0.2, "blue", 4, "big", "right"], + [0.5, "red", 3, "big", "right"], + [0.9, "green", 0, "small", "left"], + ], + dtype="O", + ) + + transforms = ParameterTransforms( + categorical1=Categorical([1, 3], categorical_map=categorical_map), + categorical2=Categorical([4], categorical_map={4: ["left", "right"]}), + ) + + 
self.assertTrue("_CombinedCategorical" in list(transforms.keys())) + self.assertTrue("categorical1" not in list(transforms.keys())) + + transformed = transforms.transform(input) + untransformed = transforms.untransform(transformed) + + self.assertTrue(torch.equal(input, untransformed)) + + strings = transforms.indices_to_str(input) + self.assertTrue(np.all(input_cats == strings)) + + indices = transforms.str_to_indices(input_cats) + self.assertTrue(torch.all(indices == input)) + + +class TransformServer(unittest.TestCase): + def setUp(self): + # setup logger + server.logger = utils_logging.getLogger(logging.DEBUG, "logs") + # random datebase path name without dashes + database_path = "./{}.db".format(str(uuid.uuid4().hex)) + self.s = server.AEPsychServer(database_path=database_path) + + def tearDown(self): + self.s.cleanup() + + # cleanup the db + if self.s.db is not None: + self.s.db.delete_db() + + def test_categorical_smoketest(self): + server = self.s + config_str = f""" + [common] + parnames = [signal1, signal2] + stimuli_per_trial = 1 + outcome_types = [binary] + strategy_names = [init_strat, opt_strat] + target = 0.75 + + [signal1] + par_type = categorical + choices = [red, green, blue] + + [signal2] + par_type = continuous + lower_bound = 0 + upper_bound = 1 + + [init_strat] + generator = SobolGenerator + min_asks = 1 + + [opt_strat] + generator = OptimizeAcqfGenerator + acqf = MCLevelSetEstimation + model = GPClassificationModel + min_asks = 1 + """ + configure( + server, + config_str=config_str, + ) + + for _ in range(2): + next_config = ask(server) + + self.assertTrue(isinstance(next_config["signal1"][0], str)) + + tell(server, config=next_config, outcome=0) + + def test_pairwise_categorical(self): + server = self.s + config_str = """ + [common] + stimuli_per_trial=2 + outcome_types=[binary] + parnames = [x, y, z] + strategy_names = [init_strat, opt_strat] + + [x] + par_type = continuous + lower_bound = 1 + upper_bound = 4 + normalize_scale = False + + [y] + par_type = categorical + choices = [red, green, blue] + + [z] + par_type = discrete + lower_bound = 1 + upper_bound = 1000 + log_scale = True + + [init_strat] + min_asks = 1 + generator = SobolGenerator + + [opt_strat] + model = PairwiseProbitModel + min_asks = 1 + generator = OptimizeAcqfGenerator + acqf = PairwiseMCPosteriorVariance + + [PairwiseProbitModel] + mean_covar_factory = default_mean_covar_factory + + [PairwiseMCPosteriorVariance] + objective = ProbitObjective + """ + configure(server, config_str=config_str) + + for _ in range(2): + next_config = ask(server) + self.assertTrue(all([isinstance(val, str) for val in next_config["y"]])) + tell(server, config=next_config, outcome=0)