Skip to content

Commit

Permalink
Fix bug in multistart with least-squares optimizers (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg authored Nov 19, 2022
1 parent 737f4d7 commit 8ce1611
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
35 changes: 32 additions & 3 deletions src/estimagic/optimization/tiktak.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def run_multistart_optimization(
starts=starts,
results=batch_results,
convergence_criteria=convergence_criteria,
primary_key=primary_key,
)
opt_counter += len(batch)
scheduled_steps = scheduled_steps[len(batch) :]
Expand Down Expand Up @@ -427,7 +428,9 @@ def get_batched_optimization_sample(sorted_sample, n_optimizations, batch_size):
return batched


def update_convergence_state(current_state, starts, results, convergence_criteria):
def update_convergence_state(
current_state, starts, results, convergence_criteria, primary_key
):
"""Update the state of all quantities related to convergence.
Args:
Expand All @@ -442,6 +445,8 @@ def update_convergence_state(current_state, starts, results, convergence_criteri
starts (list): List of starting points for local optimizations.
results (list): List of results from local optimizations.
convergence_criteria (dict): Dict with the entries "xtol" and "max_discoveries"
primary_key: The primary criterion entry of the local optimizer. Needed to
interpret the output of the internal criterion function.
Returns:
Expand All @@ -456,17 +461,38 @@ def update_convergence_state(current_state, starts, results, convergence_criteri
best_y = current_state["best_y"]
best_res = current_state["best_res"]

# get indices of local optimizations that did not fail
valid_indices = [i for i, res in enumerate(results) if not isinstance(res, str)]

# If all local optimizations failed, return early so we don't have to worry about
# index errors later.
if not valid_indices:
return current_state, False

# ==================================================================================
    # reduce everything to valid optimizations
# ==================================================================================
valid_results = [results[i] for i in valid_indices]
valid_starts = [starts[i] for i in valid_indices]

valid_new_x = [res["solution_x"] for res in valid_results]
valid_new_y = [res["solution_criterion"] for res in valid_results]
valid_new_y = []

# make the criterion output scalar if a least squares optimizer returns an
# array as solution_criterion.
for res in valid_results:
if np.isscalar(res["solution_criterion"]):
valid_new_y.append(res["solution_criterion"])
else:
valid_new_y.append(
aggregate_func_output_to_value(
f_eval=res["solution_criterion"],
primary_key=primary_key,
)
)

# ==================================================================================
# accept new best point if we find a new lowest function value
# ==================================================================================
best_index = np.argmin(valid_new_y)
if valid_new_y[best_index] <= best_y:
best_x = valid_new_x[best_index]
Expand All @@ -478,6 +504,9 @@ def update_convergence_state(current_state, starts, results, convergence_criteri
elif best_res is None:
best_res = valid_results[best_index]

# ==================================================================================
# update history and state
# ==================================================================================
new_x_history = current_state["x_history"] + valid_new_x
all_x = np.array(new_x_history)
relative_diffs = (all_x - best_x) / np.clip(best_x, 0.1, np.inf)
Expand Down
14 changes: 14 additions & 0 deletions tests/optimization/test_multistart.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,17 @@ def ackley(x):
"convergence_max_discoveries": 10,
},
)


def test_multistart_with_least_squares_optimizers():
    """Regression test: multistart must work with a least-squares optimizer.

    A least-squares algorithm returns an array-valued solution_criterion, which
    previously broke the multistart convergence bookkeeping (see #411). The sum
    of squares criterion has its minimum at the origin, so the multistart run
    should converge there.
    """
    bound = np.full(2, 10.0)
    result = minimize(
        criterion=sos_dict_criterion,
        params=np.array([-1, 1.0]),
        lower_bounds=-bound,
        upper_bounds=bound,
        algorithm="scipy_ls_trf",
        multistart=True,
        multistart_options={"n_samples": 3, "share_optimizations": 1.0},
    )
    aaae(result.params, np.zeros(2))
2 changes: 2 additions & 0 deletions tests/optimization/test_tiktak.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_update_state_converged(current_state, starts, results):
starts=starts,
results=results,
convergence_criteria=criteria,
primary_key="value",
)

aaae(new_state["best_x"], np.arange(3))
Expand All @@ -176,6 +177,7 @@ def test_update_state_not_converged(current_state, starts, results):
starts=starts,
results=results,
convergence_criteria=criteria,
primary_key="value",
)

assert not is_converged

0 comments on commit 8ce1611

Please sign in to comment.