Skip to content

Commit

Permalink
Add Experiment.with_config_keys()
Browse files Browse the repository at this point in the history
`Experiment.with_config_keys(new_config_keys)` returns a new Experiment
object with the same hypotheses but with then another list of config
keys; useful for reordering the hyperparameters.

See #14.
  • Loading branch information
wookayin committed Dec 18, 2022
1 parent 680e9a0 commit c89509c
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 8 deletions.
55 changes: 54 additions & 1 deletion expt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@

T = TypeVar('T')

if hasattr(types, 'EllipsisType'):
EllipsisType = types.EllipsisType
else:
EllipsisType = type(...)

#########################################################################
# Data Classes
#########################################################################
Expand Down Expand Up @@ -803,6 +808,17 @@ def __init__(
for h in hypotheses:
self.add_hypothesis(h, extend_if_conflict=False)

def _replace(self, **kwargs) -> Experiment:
ex = Experiment(
name=kwargs.pop('name', self._name),
hypotheses=list(self._hypotheses.values()),
config_keys=kwargs.pop('config_keys', self._config_keys),
summary_columns=kwargs.pop('summary_columns', self._summary_columns),
)
if kwargs:
raise ValueError("Unknown fields: {}".format(list(kwargs.keys())))
return ex

@property
def _df(self) -> pd.DataFrame:
df = pd.DataFrame({
Expand All @@ -826,7 +842,9 @@ def _df(self) -> pd.DataFrame:
}),
], axis=1) # yapf: disable

df = df.set_index([*self._config_keys, 'name'])
# Need to sort index w.r.t the multi-index level hierarchy, because
# the order of hypotheses being added is not guaranteed
df = df.set_index([*self._config_keys, 'name']).sort_index()
return df

@classmethod
Expand Down Expand Up @@ -1272,6 +1290,41 @@ def aggregate_h(series):
f"The columns of summary DataFrame must be unique. Found: {df.columns}")
return df

@typechecked
def with_config_keys(
self,
new_config_keys: Sequence[str | EllipsisType], # type: ignore
) -> Experiment:
"""Create a new Experiment with the same set of Hypotheses, but a different
config keys in the multi-index (usually reordering).
Note that the underlying hypothesis objects in the new Experiment object
won't change, e.g., their name, config, etc. would remain the same.
Args:
new_config_keys: The new list of config keys. This can contain `...`
(Ellipsis) as the last element, which refers to all the other keys
in the current Experiment that was not included in the list.
"""

if new_config_keys[-1] is ...:
keys_requested = [x for x in new_config_keys if x is not ...]
keys_appended = [x for x in self._config_keys if x not in keys_requested]
new_config_keys = keys_requested + keys_appended

for key in new_config_keys:
if not isinstance(key, str):
raise TypeError(f"Invalid config key: {type(key)}")
for h in self._hypotheses.values():
if h.config is None:
raise ValueError(f"{h} does not have a config.")
if key not in h.config.keys():
raise ValueError(f"'{key}' not found in the config of {h}. "
"Close matches: " +
str(difflib.get_close_matches(key, h.config.keys())))

return self._replace(config_keys=new_config_keys)

def interpolate(self,
x_column: Optional[str] = None,
*,
Expand Down
63 changes: 56 additions & 7 deletions expt/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def runs_gridsearch() -> RunList:
return RunList(runs)


def _runs_gridsearch_config_fn(r: Run):
config = {}
config['algo'], config['env_id'], config['seed'] = r.name.split('-')
config['common_hparam'] = 1
return config


class TestRun(_TestBase):

def test_creation(self):
Expand Down Expand Up @@ -609,16 +616,10 @@ def config_fn(r: Run):
def test_create_from_runs(self, runs_gridsearch: RunList):
"""Tests Experiment.from_runs with the minimal defaults."""

def config_fn(r: Run):
config = {}
config['algo'], config['env_id'], config['seed'] = r.name.split('-')
config['common_hparam'] = 1
return config

# Uses default for config_keys: see varied_config_keys.
ex = Experiment.from_runs(
runs_gridsearch,
config_fn=config_fn,
config_fn=_runs_gridsearch_config_fn,
name="ex_from_runs",
)
assert ex.name == "ex_from_runs"
Expand Down Expand Up @@ -715,6 +716,54 @@ def test_indexing(self):
r = V(ex[['hyp1', 'hyp0'], 'a'])
# pylint: enable=unsubscriptable-object

def test_with_config_keys(self, runs_gridsearch: RunList):
ex_base = Experiment.from_runs(
runs_gridsearch,
config_fn=_runs_gridsearch_config_fn,
config_keys=['algo', 'env_id', 'common_hparam'],
)

assert ex_base._config_keys == ['algo', 'env_id', 'common_hparam']
assert ex_base._df.index.get_level_values( \
'algo').tolist() == ['ppo', 'ppo', 'ppo', 'sac', 'sac', 'sac']
assert ex_base._df.index.get_level_values( \
'env_id').tolist() == ['halfcheetah', 'hopper', 'humanoid'] * 2

def validate_env_id_algo(ex):
assert ex._summary_columns == ex_base._summary_columns
assert set(ex._hypotheses.values()) == set(ex_base._hypotheses.values())

# rows should be sorted according to the new multiindex
assert ex._df.index.get_level_values('env_id').tolist() == [
'halfcheetah', 'halfcheetah', \
'hopper', 'hopper', \
'humanoid', 'humanoid',
]
assert ex._df.index.get_level_values('common_hparam').tolist() == [1] * 6
assert ex._df.index.get_level_values('algo').tolist() == [
'ppo', 'sac', 'ppo', 'sac', 'ppo', 'sac'
]

# (1) full reordering
ex1 = ex_base.with_config_keys(['env_id', 'common_hparam', 'algo'])
assert ex1._config_keys == ['env_id', 'common_hparam', 'algo']
validate_env_id_algo(ex1)

# (2) partial with ellipsis
ex2 = ex_base.with_config_keys(['env_id', ...])
assert ex2._config_keys == ['env_id', 'algo', 'common_hparam']
validate_env_id_algo(ex2)

# (3) partial subset. TODO: Things to decide:
# - To reduce or not to reduce?
# - Hypothesis objects should remain the same or changes in
# name, config, etc.?

# (4) not existing keys: error
with pytest.raises(ValueError, \
match="'foo' not found in the config of") as e:
ex_base.with_config_keys(['env_id', 'foo', 'algo'])

def test_select_top(self):
# yapf: disable
hypos = [
Expand Down

0 comments on commit c89509c

Please sign in to comment.