Make a transform base class for consistent API (#452)
Summary:

Transforming bounds requires additional logic that used to be part of ParameterTransforms; we move this logic to the individual parameter transforms themselves and have ParameterTransforms look for these special methods when transforming bounds.

We add a new ABC for our transforms, since going forward it is likely that all of our transforms will have unique capabilities beyond the BoTorch base classes. This includes how we handle finding options from configs.
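As a rough illustration of the delegation this enables, the sketch below mimics the new ParameterTransforms.transform_bounds loop from this diff; the container class here is a simplified, hypothetical stand-in rather than the real aepsych implementation.

# Simplified sketch: the container no longer special-cases Round, it just calls
# each transform's own transform_bounds. BoundsDelegatingContainer is a
# hypothetical stand-in for ParameterTransforms.
from typing import Literal, Optional

import torch


class BoundsDelegatingContainer(dict):
    def transform_bounds(
        self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None
    ) -> torch.Tensor:
        for tf in self.values():
            X = tf.transform_bounds(X, bound=bound)
        return X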

Reviewed By: crasanders

Differential Revision: D65897908
JasonKChow authored and facebook-github-bot committed Nov 20, 2024
1 parent b591b5f commit 399bd64
Showing 1 changed file with 115 additions and 72 deletions.
187 changes: 115 additions & 72 deletions aepsych/transforms/parameters.py
@@ -112,19 +112,7 @@ def transform_bounds(
A tensor of transformed inputs with the same shape as the input.
"""
for tf in self.values():
# This is the entire reason this method exists to help handle the
# continuous relaxation necessary for discrete parameters. But this is
# super awkward.
if isinstance(tf, Round):
if bound == "lb":
X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices))
elif bound == "ub":
X[0, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices))
else: # Both bounds
X[0, tf.indices] -= torch.tensor([0.5] * len(tf.indices))
X[1, tf.indices] += torch.tensor([0.5 - 1e-6] * len(tf.indices))
else:
X = tf.forward(X)
X = tf.transform_bounds(X, bound=bound)

return X

@@ -191,9 +179,10 @@ def get_config_options(
config=config, name=par, options=transform_options
)

# Nudge bounds
transform_options["bounds"][0, round.indices] -= 0.5
transform_options["bounds"][1, round.indices] += 0.5 - 1e-6
# Transform bounds
transform_options["bounds"] = round.transform_bounds(
transform_options["bounds"]
)
transform_dict[f"{par}_Round"] = round

# Log scale
@@ -784,7 +773,67 @@ def get_config_options(
return options


class Log10Plus(Log10, ConfigurableMixin):
class Transform(ReversibleInputTransform, ConfigurableMixin, ABC):
"""Base class for individual transforms. These transforms are intended to be stacked
together using the ParameterTransforms class.
"""

def transform_bounds(
self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs
) -> torch.Tensor:
r"""Return the bounds X transformed.
Args:
X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter
bounds.
bound (Literal["lb", "ub"], optional): Which bound is being transformed; if
None, X is the `[2, dim]` form with both bounds stacked.
**kwargs: Keyword arguments for specific transforms; they should have
default values.
Returns:
torch.Tensor: A transformed set of parameter bounds.
"""
return self.transform(X)

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Return a dictionary of the relevant options to initialize a Log10Plus
transform for the named parameter within the config.
Args:
config (Config): Config to look for options in.
name (str): Parameter to find options for.
options (Dict[str, Any]): Options to override from the config.
Returns:
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
if name is None:
raise ValueError(f"{name} must be set to initialize a transform.")

if options is None:
options = {}
else:
options = deepcopy(options)

# Figure out the index of this parameter
parnames = config.getlist("common", "parnames", element_type=str)
idx = parnames.index(name)

if "indices" not in options:
options["indices"] = [idx]

return options
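
The index lookup in the base get_config_options above reduces to finding the parameter's position in the config's parnames list; a minimal sketch with hypothetical parameter names:

# Minimal sketch of the parnames index lookup performed by Transform.get_config_options.
# The parameter names are made up; in aepsych they come from the [common] section
# of the Config.
parnames = ["contrast", "frequency"]  # assumed config contents
name = "frequency"

options = {"indices": [parnames.index(name)]}
print(options)  # {'indices': [1]}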


class Log10Plus(Log10, Transform):
"""Base-10 log transform that we add a constant to the values"""

def __init__(
@@ -867,7 +916,7 @@ def get_config_options(
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
options = _get_parameter_options(config, name, options)
options = super().get_config_options(config=config, name=name, options=options)

# Make sure we have bounds ready
if "bounds" not in options:
@@ -887,7 +936,7 @@ def get_config_options(
return options


class NormalizeScale(Normalize, ConfigurableMixin):
class NormalizeScale(Normalize, Transform):
def __init__(
self,
d: int,
@@ -965,20 +1014,19 @@ def get_config_options(
Dict[str, Any]: A dictionary of options to initialize this class with,
including the transformed bounds.
"""
options = _get_parameter_options(config, name, options)
options = super().get_config_options(config=config, name=name, options=options)

# Make sure we have bounds ready
if "bounds" not in options:
options["bounds"] = get_bounds(config)

if "d" not in options:
parnames = config.getlist("common", "parnames", element_type=str)
options["d"] = len(parnames)
options["d"] = options["bounds"].shape[1]

return options


class Round(ReversibleInputTransform, torch.nn.Module, ConfigurableMixin):
class Round(Transform, torch.nn.Module):
def __init__(
self,
indices: list[int],
@@ -1035,29 +1083,56 @@ def _untransform(self, X: Tensor) -> Tensor:
"""
return X.round()

@classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
def transform_bounds(
self, X: torch.Tensor, bound: Optional[Literal["lb", "ub"]] = None, **kwargs
) -> torch.Tensor:
r"""Return the bounds X transformed.
Args:
X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter
bounds.
bound (Literal["lb", "ub"], optional): Which bound this is; if None, we
will assume the input is both bounds with a `[2, dim]` X.
**kwargs: passed to _transform_bounds
epsilon: will modify the offset for the rounding to ensure each discrete
value has equal space in the parameter space.
Returns:
torch.Tensor: A transformed set of parameter bounds.
"""
Return a dictionary of the relevant options to initialize the Round transform
from the config for the named transform.
epsilon = kwargs.get("epsilon", 1e-6)
return self._transform_bounds(X, bound=bound, epsilon=epsilon)

def _transform_bounds(
self,
X: torch.Tensor,
bound: Optional[Literal["lb", "ub"]] = None,
epsilon: float = 1e-6,
) -> torch.Tensor:
r"""Return the bounds X transformed.
Args:
config (Config): Config to look for options in.
name (str, optional): The parameter to find options for.
options (Dict[str, Any], optional): Options to override from the config,
defaults to None.
X (torch.Tensor): Either a `[1, dim]` or `[2, dim]` tensor of parameter
bounds.
bound (Literal["lb", "ub"], optional): Which bound this is; if None, we
will assume the input is both bounds with a `[2, dim]` X.
epsilon (float): Modifies the offset applied to the upper bound so that each
discrete value keeps roughly equal space in the parameter space, defaults to 1e-6.
Return:
Dict[str, Any]: A dictionary of options to initialize this class.
Returns:
torch.Tensor: A transformed set of parameter bounds.
"""
options = _get_parameter_options(config, name, options)
X = X.clone()

return options
if bound == "lb":
X[0, self.indices] -= torch.tensor([0.5] * len(self.indices))
elif bound == "ub":
X[0, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices))
else: # Both bounds
X[0, self.indices] -= torch.tensor([0.5] * len(self.indices))
X[1, self.indices] += torch.tensor([0.5 - epsilon] * len(self.indices))

return X
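
To make the ±0.5 relaxation above concrete, here is a hedged usage sketch; the import path follows this diff's file, and the constructor is assumed to accept only indices, with other arguments defaulted.

# Hypothetical usage of Round.transform_bounds as changed in this diff.
# An integer parameter with bounds [1, 5] is relaxed to roughly [0.5, 5.5 - 1e-6],
# so each integer gets an (almost) equal-width slice of the continuous space and
# rounding maps samples back onto {1, ..., 5}.
import torch

from aepsych.transforms.parameters import Round  # path assumed from this diff's file

r = Round(indices=[0])
bounds = torch.tensor([[1.0], [5.0]])  # stacked [2, dim] lower/upper bounds
relaxed = r.transform_bounds(bounds)
print(relaxed)  # expected ~ [[0.5], [5.499999]]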


def transform_options(
@@ -1131,35 +1206,3 @@ def get_bounds(config: Config) -> torch.Tensor:
bounds = torch.stack((_lb, _ub))

return bounds
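
For reference, a minimal sketch of the [2, dim] bounds layout that get_bounds stacks; the parameter values here are made up. Note that bounds.shape[1] is also what NormalizeScale.get_config_options now uses for d.

# Sketch of the stacked bounds tensor used throughout this file: row 0 holds the
# lower bounds, row 1 the upper bounds, one column per parameter.
import torch

_lb = torch.tensor([0.0, 1.0])    # hypothetical lower bounds for two parameters
_ub = torch.tensor([10.0, 5.0])   # hypothetical upper bounds
bounds = torch.stack((_lb, _ub))  # shape [2, 2]
print(bounds.shape)  # torch.Size([2, 2])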


def _get_parameter_options(
config: Config, name: Optional[str] = None, options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Return options for a parameter in a config.
Args:
config (Config): Config to search for parameter.
name (str): Name of parameter.
options (Dict[str, Any], optional): dictionary of options to overwrite config
options, defaults to an empty dictionary.
Returns:
Dict[str, Any]: Dictionary of options to initialize a transform from config.
"""
if name is None:
raise ValueError(f"{name} must be set to initialize a transform.")

if options is None:
options = {}
else:
options = deepcopy(options)

# Figure out the index of this parameter
parnames = config.getlist("common", "parnames", element_type=str)
idx = parnames.index(name)

if "indices" not in options:
options["indices"] = [idx]

return options
