add support for categorical parameters (#449)
Summary:

Added support for categorical parameters. This required changing some server functions to accept and return strings as part of requests and responses.

Externally, categorical parameters are strings. Internally, they are represented as index parameters (0 to nChoices - 1) and are passed to models as one-hot encoded parameters. The index intermediate could be useful in the future for alternative ways to model categorical parameters.
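
For illustration, a minimal config sketch exercising the settings this change reads (par_type = categorical plus a choices list, validated by _check_param_settings); the parameter names, the continuous block, and all values here are hypothetical:

    [common]
    parnames = [color, intensity]

    [color]
    par_type = categorical
    choices = [red, green, blue]

    [intensity]
    par_type = continuous
    lower_bound = 0
    upper_bound = 1

With three choices, Config.update() derives index-space bounds of 0 and 2 for color.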

Differential Revision: D65779384
JasonKChow authored and facebook-github-bot committed Nov 20, 2024
1 parent 399bd64 commit c24bbd9
Showing 4 changed files with 468 additions and 15 deletions.
19 changes: 15 additions & 4 deletions aepsych/config.py
@@ -176,14 +176,19 @@ def update(
             par_names = self.getlist(
                 "common", "parnames", element_type=str, fallback=[]
             )
-            lb = [None] * len(par_names)
-            ub = [None] * len(par_names)
+            lb = []
+            ub = []
             for i, par_name in enumerate(par_names):
                 # Validate the parameter-specific block
                 self._check_param_settings(par_name)
 
-                lb[i] = self[par_name]["lower_bound"]
-                ub[i] = self[par_name]["upper_bound"]
+                if self[par_name]["par_type"] == "categorical":
+                    choices = self.getlist(par_name, "choices", element_type=str)
+                    lb.append("0")
+                    ub.append(str(len(choices) - 1))
+                else:
+                    lb.append(self[par_name]["lower_bound"])
+                    ub.append(self[par_name]["upper_bound"])
 
             self["common"]["lb"] = f"[{', '.join(lb)}]"
             self["common"]["ub"] = f"[{', '.join(ub)}]"
@@ -276,6 +281,12 @@ def _check_param_settings(self, param_name: str) -> None:
                 and self.getint(param_name, "upper_bound") % 1 == 0
             ):
                 raise ValueError(f"Parameter {param_name} has non-integer bounds.")
+        elif param_block["par_type"] == "categorical":
+            # Need a choices array
+            if "choices" not in param_block:
+                raise ValueError(
+                    f"Parameter {param_name} is missing the choices setting."
+                )
         else:
             raise ValueError(
                 f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
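
A minimal sketch of the bound derivation in the update() hunk above, with a hypothetical choices list:

    choices = ["red", "green", "blue"]   # hypothetical categorical choices
    lb, ub = "0", str(len(choices) - 1)  # -> ("0", "2"): categorical bounds in index space
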
21 changes: 15 additions & 6 deletions aepsych/server/server.py
@@ -276,22 +276,31 @@ def can_pregen_ask(self):
         return self.strat is not None and self.enable_pregen
 
     def _tensor_to_config(self, next_x):
+        next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0]
         config = {}
         for name, val in zip(self.parnames, next_x):
-            if val.dim() == 0:
+            if isinstance(val, str):
+                config[name] = [val]
+            elif isinstance(val, (int, float)):
                 config[name] = [float(val)]
+            elif isinstance(val[0], str):
+                config[name] = val
             else:
-                config[name] = np.array(val)
+                config[name] = np.array(val, dtype="float64")
         return config
 
     def _config_to_tensor(self, config):
         unpacked = [config[name] for name in self.parnames]
 
         # handle config elements being either scalars or length-1 lists
         if isinstance(unpacked[0], list):
-            x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1)
+            x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1)
         else:
-            x = torch.tensor(np.stack(unpacked))
+            x = np.stack(unpacked, dtype="O")
 
+        # Unsqueeze batch dimension
+        x = np.expand_dims(x, 0)
+
+        x = self.strat.transforms.str_to_indices(x)[0]
 
         return x
 
     def __getstate__(self):
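
For illustration, a standalone mimic of the new round trip, assuming one continuous and one categorical parameter (names, choices, and values are hypothetical):

    import numpy as np
    import torch

    parnames = ["intensity", "color"]
    choices = ["red", "green", "blue"]

    # _config_to_tensor: stack into an object array, add a batch dimension,
    # then map strings to indices (what str_to_indices does internally)
    config = {"intensity": [0.5], "color": ["blue"]}
    unpacked = [config[name] for name in parnames]
    x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1)  # array([0.5, 'blue'], dtype=object)
    x = np.expand_dims(x, 0)
    x[:, 1] = [choices.index(c) for c in x[:, 1]]
    x = torch.tensor(x.astype("float64"))  # -> tensor([[0.5000, 2.0000]], dtype=torch.float64)

    # _tensor_to_config reverses this via indices_to_str, yielding
    # {"intensity": [0.5], "color": ["blue"]} again.
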
241 changes: 237 additions & 4 deletions aepsych/transforms/parameters.py
@@ -33,6 +33,7 @@
     Log10,
     Normalize,
     ReversibleInputTransform,
+    InputTransform,
 )
 from botorch.models.transforms.utils import subset_transform
 from botorch.posteriors import Posterior
@@ -55,6 +56,30 @@ class ParameterTransforms(ChainedInputTransform, ConfigurableMixin):
     space back into raw space.
     """
 
+    def __init__(
+        self,
+        **transforms: InputTransform,
+    ) -> None:
+        self.cat_map_raw = {}
+        transform_keys = list(transforms.keys())
+        for key in transform_keys:
+            if isinstance(transforms[key], Categorical):
+                categorical = transforms.pop(key)
+                self.cat_map_raw.update(categorical.cat_map_raw)
+
+        if len(self.cat_map_raw) > 0:
+            # Remake the categorical and put it at the end
+            transforms["_CombinedCategorical"] = Categorical(
+                indices=list(self.cat_map_raw.keys()), categorical_map=self.cat_map_raw
+            )
+            self.cat_map_transformed = transforms[
+                "_CombinedCategorical"
+            ].cat_map_transformed
+        else:
+            self.cat_map_transformed = {}
+
+        super().__init__(**transforms)
+
     def _temporary_reshape(func: Callable) -> Callable:
         # Decorator to reshape tensors to the expected 2D shape, even if the input was
         # 1D or 3D and after the transform reshape it back to the original.
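
A quick sketch of what this pooling does (hypothetical transform names and choices): all Categorical transforms are merged into a single _CombinedCategorical placed at the end of the chain.

    tf = ParameterTransforms(
        a_Categorical=Categorical(indices=[0], categorical_map={0: ["x", "y"]}),
        b_Categorical=Categorical(indices=[2], categorical_map={2: ["p", "q", "r"]}),
    )
    list(tf.keys())  # -> ['_CombinedCategorical']
    tf.cat_map_raw   # -> {0: ['x', 'y'], 2: ['p', 'q', 'r']}
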
@@ -102,8 +127,9 @@ def transform_bounds(
     ) -> Tensor:
         r"""Transform bounds of a parameter.
 
-        Individual transforms are applied in sequence. Then an adjustment is applied to
-        ensure the bounds are correct.
+        Individual transforms are applied in sequence. Looks for a specific
+        transform_bounds method in each transform to apply that, otherwise uses the
+        normal transform.
 
         Args:
             X: A tensor of inputs. Either `[dim]`, `[batch, dim]`, or `[batch, dim, stimuli]`.
@@ -130,6 +156,45 @@ def untransform(self, X: Tensor) -> Tensor:
"""
return super().untransform(X)

@_temporary_reshape
def indices_to_str(self, X: Tensor) -> np.ndarray:
r"""Return a NumPy array of objects where the categorical parameters will be
strings.
Args:
X (Tensor): A tensor shaped `[batch, dim]` to turn into a mixed type NumPy
array.
Returns:
np.ndarray: An array with the objet type where the categorical parameters
are strings.
"""
obj_arr = X.cpu().numpy().astype("O")

for idx, cats in self.cat_map_raw.items():
obj_arr[:, idx] = [cats[int(i)] for i in obj_arr[:, idx]]

return obj_arr

@_temporary_reshape
def str_to_indices(self, obj_arr: np.ndarray) -> Tensor:
r"""Return a Tensor where the categorical parameters are converted from strings
to indices.
Args:
obj_arr (np.ndarray): A NumPy array `[batch, dim]` where the categorical
parameters are strings.
Returns:
Tensor: A tensor with the categorical parameters converted to indices.
"""
obj_arr = obj_arr[:]

for idx, cats in self.cat_map_raw.items():
obj_arr[:, idx] = [cats.index(cat) for cat in obj_arr[:, idx]]

return torch.tensor(obj_arr.astype("float64"), dtype=torch.float64)

@classmethod
def get_config_options(
cls,
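
For illustration, a round trip through these helpers, assuming a single hypothetical categorical at column 1:

    tf = ParameterTransforms(
        color_Categorical=Categorical(indices=[1], categorical_map={1: ["red", "green", "blue"]})
    )
    obj = tf.indices_to_str(torch.tensor([[0.5, 2.0]]))  # -> array([[0.5, 'blue']], dtype=object)
    tf.str_to_indices(obj)                               # -> tensor([[0.5, 2.]], dtype=torch.float64)
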
Expand Down Expand Up @@ -185,6 +250,14 @@ def get_config_options(
                 )
                 transform_dict[f"{par}_Round"] = round
 
+            # Categorical variable
+            elif par_type == "categorical":
+                categorical = Categorical.from_config(
+                    config=config, name=par, options=transform_options
+                )
+
+                transform_dict[f"{par}_Categorical"] = categorical
+
             # Log scale
             if config.getboolean(par, "log_scale", fallback=False):
                 log10 = Log10Plus.from_config(
@@ -197,8 +270,10 @@
                 )
                 transform_dict[f"{par}_Log10Plus"] = log10
 
-            # Normalize scale (defaults true)
-            if config.getboolean(par, "normalize_scale", fallback=True):
+            # Normalize scale (defaults true), don't do this for categoricals
+            if par_type != "categorical" and config.getboolean(
+                par, "normalize_scale", fallback=True
+            ):
                 normalize = NormalizeScale.from_config(
                     config=config, name=par, options=transform_options
                 )
@@ -1135,6 +1210,164 @@ def _transform_bounds(
         return X
 
 
+class Categorical(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin):
+    is_one_to_many = True
+
+    def __init__(
+        self,
+        indices: list[int],
+        categorical_map: Dict[int, List[str]],
+        transform_on_train: bool = True,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = True,
+        reverse: bool = False,
+        **kwargs,
+    ) -> None:
+        """Initialize a Categorical transform. Takes the integer at the indices and
+        converts it to one_hot starting from that index (and therefore pushing
+        forward other indices).
+
+        Args:
+            indices: The indices of the inputs to turn into categoricals.
+            categorical_map: A dictionary where the key is the index of the categorical
+                variable and the values are the possible categories.
+            transform_on_train: A boolean indicating whether to apply the
+                transforms in train() mode. Default: True.
+            transform_on_eval: A boolean indicating whether to apply the
+                transform in eval() mode. Default: True.
+            transform_on_fantasize: Currently will not do anything, here to conform to
+                API.
+            reverse: A boolean indicating whether the forward pass should untransform
+                the parameters.
+            **kwargs: Accepted to conform to API.
+        """
+        # indices needs to be sorted
+        indices = sorted(indices)
+
+        # Multiple categoricals need to shift indices
+        categorical_offset = 0
+        new_indices = []
+        cat_map_transformed = {}
+        for idx in indices:
+            category_values = categorical_map[idx]
+            num_classes = len(categorical_map[idx])
+            new_idx = idx + categorical_offset
+            categorical_offset += num_classes - 1
+
+            new_indices.append(new_idx)
+            cat_map_transformed[new_idx] = category_values
+
+        super().__init__()
+        self.register_buffer("indices", torch.tensor(new_indices, dtype=torch.long))
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+        self.reverse = reverse
+        self.cat_map_raw = categorical_map
+        self.cat_map_transformed = cat_map_transformed
+
+    def _transform(self, X: torch.Tensor) -> torch.Tensor:
+        for idx in self.indices:
+            num_classes = len(self.cat_map_transformed[idx.item()])
+
+            # Turns indices into one hot
+            idxs = X[:, idx].to(torch.long)
+            one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes)
+            one_hot = one_hot.view(X.shape[0], num_classes)
+
+            # Chop up X and stick one_hot in
+            pre_categorical = X[:, :idx]
+            post_categorical = X[:, idx + 1 :]
+            X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1)
+
+        return X
+
+    def _untransform(self, X: torch.Tensor) -> torch.Tensor:
+        for idx in reversed(self.indices):
+            num_classes = len(self.cat_map_transformed[idx.item()])
+
+            # Chop up X around the one_hot
+            pre_categorical = X[:, :idx]
+            one_hot = X[:, idx : idx + num_classes]
+            post_categorical = X[:, idx + num_classes :]
+
+            # Turn one_hot back into indices
+            idxs = torch.argmax(one_hot, dim=1).unsqueeze(-1).to(X)
+
+            X = torch.cat((pre_categorical, idxs, post_categorical), dim=1)
+
+        return X
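
For illustration, a round trip through this transform with a hypothetical 3-choice categorical at column 1:

    tf = Categorical(indices=[1], categorical_map={1: ["red", "green", "blue"]})
    x = torch.tensor([[0.5, 2.0]])   # [continuous value, categorical index]
    tf.transform(x)                  # -> tensor([[0.5, 0., 0., 1.]]): one-hot expansion
    tf.untransform(tf.transform(x))  # -> tensor([[0.5, 2.]]): recovered via argmax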

+    def transform_bounds(
+        self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None
+    ) -> torch.Tensor:
+        r"""Return the bounds X transformed.
+
+        Args:
+            X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter
+                bounds.
+            bound (Literal["lb", "ub"], optional): The bound that this is, if None, we
+                will assume the input is both bounds with a `[2, dim]` X.
+
+        Returns:
+            torch.Tensor: A transformed set of parameter bounds.
+        """
+        for idx in self.indices:
+            num_classes = len(self.cat_map_transformed[idx.item()])
+
+            # Turns indices into one hot
+            idxs = X[:, idx].to(torch.long)
+            one_hot = torch.nn.functional.one_hot(idxs, num_classes=num_classes)
+            one_hot = one_hot.view(X.shape[0], num_classes)
+
+            if bound == "lb":
+                one_hot[:] = 0.0
+            elif bound == "ub":
+                one_hot[:] = 1.0
+            else:  # Both bounds
+                one_hot[0, :] = 0.0
+                one_hot[1, :] = 1.0
+
+            # Chop up X and stick one_hot in
+            pre_categorical = X[:, :idx]
+            post_categorical = X[:, idx + 1 :]
+            X = torch.cat((pre_categorical, one_hot, post_categorical), dim=1)
+
+        return X
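
Continuing the hypothetical sketch above, index-space bounds expand to one-hot-space bounds:

    bounds = torch.tensor([[0.0, 0.0], [1.0, 2.0]])  # lb/ub rows for [continuous, categorical]
    tf.transform_bounds(bounds)
    # -> tensor([[0., 0., 0., 0.],
    #            [1., 1., 1., 1.]])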

+    @classmethod
+    def get_config_options(
+        cls,
+        config: Config,
+        name: Optional[str] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Return a dictionary of the relevant options to initialize the categorical
+        transform from the config for the named transform.
+
+        Args:
+            config (Config): Config to look for options in.
+            name (str, optional): The parameter to find options for.
+            options (Dict[str, Any], optional): Options to override from the config,
+                defaults to None.
+
+        Return:
+            Dict[str, Any]: A dictionary of options to initialize this class.
+        """
+        options = _get_parameter_options(config, name, options)
+
+        if "categorical_map" not in options:
+            if name is None:
+                raise ValueError("name argument must be set to initialize from config.")
+
+            options["categorical_map"] = {
+                options["indices"][0]: config.getlist(name, "choices", element_type=str)
+            }
+
+        return options
+
+
 def transform_options(
     config: Config, transforms: Optional[ChainedInputTransform] = None
 ) -> Config:
