Skip to content

Commit

Permalink
add support for categorical parameters
Browse files Browse the repository at this point in the history
Summary:
Added support for categorical parameters. This required changing some server functions to be able to take and give strings as a part of the responses.

Externally, categorical parameters are strings. Internally, categorical parameters are represented as index parameters (0, nChoices -1) while they are passed to models as one_hot vectored parameters. The index intermediate could be useful in the future for alternative ways to model categorical parameters.

Differential Revision: D65779384
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Nov 13, 2024
1 parent 7975523 commit 6cad20c
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 31 deletions.
19 changes: 15 additions & 4 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,19 @@ def update(
par_names = self.getlist(
"common", "parnames", element_type=str, fallback=[]
)
lb = [None] * len(par_names)
ub = [None] * len(par_names)
lb = []
ub = []
for i, par_name in enumerate(par_names):
# Validate the parameter-specific block
self._check_param_settings(par_name)

lb[i] = self[par_name]["lower_bound"]
ub[i] = self[par_name]["upper_bound"]
if self[par_name]["par_type"] == "categorical":
choices = self.getlist(par_name, "choices", element_type=str)
lb.append("0")
ub.append(str(len(choices) - 1))
else:
lb.append(self[par_name]["lower_bound"])
ub.append(self[par_name]["upper_bound"])

self["common"]["lb"] = f"[{', '.join(lb)}]"
self["common"]["ub"] = f"[{', '.join(ub)}]"
Expand Down Expand Up @@ -279,6 +284,12 @@ def _check_param_settings(self, param_name: str) -> None:
and self.getint(param_name, "upper_bound") % 1 == 0
):
raise ValueError(f"Parameter {param_name} has non-discrete bounds.")
elif param_block["par_type"] == "categorical":
# Need a choices array
if "choices" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the choices setting."
)
else:
raise ValueError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
Expand Down
21 changes: 15 additions & 6 deletions aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,22 +276,31 @@ def can_pregen_ask(self):
return self.strat is not None and self.enable_pregen

def _tensor_to_config(self, next_x):
next_x = self.strat.transforms.indices_to_str(next_x.unsqueeze(0))[0]
config = {}
for name, val in zip(self.parnames, next_x):
if val.dim() == 0:
if isinstance(val, str):
config[name] = [val]
elif isinstance(val, (int, float)):
config[name] = [float(val)]
elif isinstance(val[0], str):
config[name] = val
else:
config[name] = np.array(val)
config[name] = np.array(val, dtype="float64")
return config

def _config_to_tensor(self, config):
unpacked = [config[name] for name in self.parnames]

# handle config elements being either scalars or length-1 lists
if isinstance(unpacked[0], list):
x = torch.tensor(np.stack(unpacked, axis=0)).squeeze(-1)
x = np.stack(unpacked, axis=0, dtype="O").squeeze(-1)
else:
x = torch.tensor(np.stack(unpacked))
x = np.stack(unpacked, dtype="O")

# Unsqueeze batch dimension
x = np.expand_dims(x, 0)

x = self.strat.transforms.str_to_indices(x)[0]

return x

def __getstate__(self):
Expand Down
Loading

0 comments on commit 6cad20c

Please sign in to comment.