Skip to content

Commit

Permalink
Sensible error if nan in objective evaluation from trust-region step
Browse files Browse the repository at this point in the history
  • Loading branch information
lindon authored and lindon committed Jan 25, 2024
1 parent 34e709b commit 2273895
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions dfols/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

__all__ = ['Controller', 'ExitInformation', 'EXIT_SLOW_WARNING', 'EXIT_MAXFUN_WARNING', 'EXIT_SUCCESS',
'EXIT_INPUT_ERROR', 'EXIT_TR_INCREASE_ERROR', 'EXIT_LINALG_ERROR', 'EXIT_FALSE_SUCCESS_WARNING',
'EXIT_AUTO_DETECT_RESTART_WARNING']
'EXIT_AUTO_DETECT_RESTART_WARNING', 'EXIT_EVAL_ERROR']

module_logger = logging.getLogger(__name__)

Expand All @@ -54,6 +54,7 @@
EXIT_INPUT_ERROR = -1 # error, bad inputs
EXIT_TR_INCREASE_ERROR = -2 # error, trust region step increased model value
EXIT_LINALG_ERROR = -3 # error, linalg error (singular matrix encountered)
EXIT_EVAL_ERROR = -4 # error, objective evaluation error (e.g. nan result received)


class ExitInformation(object):
Expand Down Expand Up @@ -83,11 +84,13 @@ def message(self, with_stem=True):
return "Error (linear algebra): " + self.msg
elif self.flag == EXIT_FALSE_SUCCESS_WARNING:
return "Warning (max false good steps): " + self.msg
elif self.flag == EXIT_EVAL_ERROR:
return "Error (function evaluation): " + self.msg
else:
return "Unknown exit flag: " + self.msg

def able_to_do_restart(self):
if self.flag in [EXIT_TR_INCREASE_ERROR, EXIT_TR_INCREASE_WARNING, EXIT_LINALG_ERROR, EXIT_SLOW_WARNING, EXIT_AUTO_DETECT_RESTART_WARNING]:
if self.flag in [EXIT_TR_INCREASE_ERROR, EXIT_TR_INCREASE_WARNING, EXIT_LINALG_ERROR, EXIT_SLOW_WARNING, EXIT_AUTO_DETECT_RESTART_WARNING, EXIT_EVAL_ERROR]:
return True
elif self.flag in [EXIT_MAXFUN_WARNING, EXIT_INPUT_ERROR]:
return False
Expand Down
11 changes: 11 additions & 0 deletions dfols/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,17 @@ def solve_main(objfun, x0, args, xl, xu, projections, npt, rhobeg, rhoend, maxfu
x = control.model.as_absolute_coordinates(xnew)
number_of_samples = max(nsamples(control.delta, control.rho, current_iter, nruns_so_far), 1)
rvec_list, f_list, num_samples_run, exit_info = control.evaluate_objective(x, number_of_samples, params)
if np.any(np.isnan(rvec_list)):
# Just exit without saving the current point
# We should be able to do a hard restart though, because it's unlikely
# that we will get the same trust-region step after expanding the radius/re-initialising
module_logger.warning("NaN encountered in evaluation of trust-region step")
if params("interpolation.throw_error_on_nans"):
raise np.linalg.LinAlgError("NaN encountered in objective evaluations")

exit_info = ExitInformation(EXIT_EVAL_ERROR, "NaN received from objective function evaluation")
nruns_so_far += 1
break # quit
if exit_info is not None:
if num_samples_run > 0:
control.model.save_point(x, np.mean(rvec_list[:num_samples_run, :], axis=0), num_samples_run,
Expand Down

0 comments on commit 2273895

Please sign in to comment.