Skip to content

Commit

Permalink
Merge pull request #2990 from Steven-Roberts/sundials_7.1
Browse files Browse the repository at this point in the history
Fix compilation warnings with SUNDIALS 7.1.0
  • Loading branch information
bendudson authored Nov 6, 2024
2 parents 9d8c036 + 9e2a7ce commit 36e3da5
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 88 deletions.
33 changes: 20 additions & 13 deletions include/bout/sundials_backports.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,26 @@
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sunnonlinsol/sunnonlinsol_newton.h>

#if SUNDIALS_VERSION_MAJOR >= 6
#ifndef SUNDIALS_VERSION
#error "Unable to determine SUNDIALS version"
#endif

// NOLINTBEGIN(cppcoreguidelines-macro-usage)
#define SUNDIALS_VERSION_AT_LEAST(major, minor, patch) \
((major) < SUNDIALS_VERSION_MAJOR \
|| ((major) == SUNDIALS_VERSION_MAJOR \
&& ((minor) < SUNDIALS_VERSION_MINOR \
|| ((minor) == SUNDIALS_VERSION_MINOR && (patch) <= SUNDIALS_VERSION_PATCH))))
#define SUNDIALS_VERSION_LESS_THAN(major, minor, patch) \
(!SUNDIALS_VERSION_AT_LEAST(major, minor, patch))
// NOLINTEND(cppcoreguidelines-macro-usage)

#if SUNDIALS_VERSION_AT_LEAST(6, 0, 0)
#include <sundials/sundials_context.hpp>
#endif
// IWYU pragma: end_exports

#if SUNDIALS_VERSION_MAJOR < 6
#if SUNDIALS_VERSION_LESS_THAN(6, 0, 0)
using sundials_real_type = realtype;
#else
using sundials_real_type = sunrealtype;
Expand All @@ -40,14 +54,7 @@ using sundials_real_type = sunrealtype;
static_assert(std::is_same_v<BoutReal, sundials_real_type>,
"BOUT++ and SUNDIALS real types do not match");

#define SUNDIALS_CONTROLLER_SUPPORT \
(SUNDIALS_VERSION_MAJOR > 6 \
|| SUNDIALS_VERSION_MAJOR == 6 && SUNDIALS_VERSION_MINOR >= 7)
#define SUNDIALS_TABLE_BY_NAME_SUPPORT \
(SUNDIALS_VERSION_MAJOR > 6 \
|| SUNDIALS_VERSION_MAJOR == 6 && SUNDIALS_VERSION_MINOR >= 4)

#if SUNDIALS_VERSION_MAJOR < 6
#if SUNDIALS_VERSION_LESS_THAN(6, 0, 0)
constexpr auto SUN_PREC_RIGHT = PREC_RIGHT;
constexpr auto SUN_PREC_LEFT = PREC_LEFT;
constexpr auto SUN_PREC_NONE = PREC_NONE;
Expand All @@ -58,9 +65,9 @@ using Context = std::nullptr_t;
#endif

inline sundials::Context createSUNContext([[maybe_unused]] MPI_Comm& comm) {
#if SUNDIALS_VERSION_MAJOR < 6
#if SUNDIALS_VERSION_LESS_THAN(6, 0, 0)
return nullptr;
#elif SUNDIALS_VERSION_MAJOR < 7
#elif SUNDIALS_VERSION_LESS_THAN(7, 0, 0)
// clang-tidy can see through `MPI_Comm` which might be a typedef to
// a pointer. We don't care, so tell it to be quiet
// NOLINTNEXTLINE(bugprone-multi-level-implicit-pointer-conversion)
Expand All @@ -73,7 +80,7 @@ inline sundials::Context createSUNContext([[maybe_unused]] MPI_Comm& comm) {
template <typename Func, typename... Args>
inline decltype(auto) callWithSUNContext(Func f, [[maybe_unused]] sundials::Context& ctx,
Args&&... args) {
#if SUNDIALS_VERSION_MAJOR < 6
#if SUNDIALS_VERSION_LESS_THAN(6, 0, 0)
return f(std::forward<Args>(args)...);
#else
return f(std::forward<Args>(args)..., ctx);
Expand Down
129 changes: 56 additions & 73 deletions src/solver/impls/arkode/arkode.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ ArkodeSolver::ArkodeSolver(Options* opts)
"not recommended except for code comparison")
.withDefault(false)),
order((*options)["order"].doc("Order of internal step").withDefault(4)),
#if SUNDIALS_TABLE_BY_NAME_SUPPORT
#if ARKODE_TABLE_BY_NAME_SUPPORT
implicit_table((*options)["implicit_table"]
.doc("Name of the implicit Butcher table")
.withDefault("")),
Expand Down Expand Up @@ -140,8 +140,10 @@ ArkodeSolver::ArkodeSolver(Options* opts)
use_jacobian((*options)["use_jacobian"]
.doc("Use user-supplied Jacobian function")
.withDefault(false)),
#if ARKODE_OPTIMAL_PARAMS_SUPPORT
optimize(
(*options)["optimize"].doc("Use ARKode optimal parameters").withDefault(false)),
#endif
suncontext(createSUNContext(BoutComm::get())) {
has_constraints = false; // This solver doesn't have constraints

Expand All @@ -160,11 +162,11 @@ ArkodeSolver::ArkodeSolver(Options* opts)

ArkodeSolver::~ArkodeSolver() {
N_VDestroy(uvec);
ARKStepFree(&arkode_mem);
ARKodeFree(&arkode_mem);
SUNLinSolFree(sun_solver);
SUNNonlinSolFree(nonlinear_solver);

#if SUNDIALS_CONTROLLER_SUPPORT
#if ARKODE_CONTROLLER_SUPPORT
SUNAdaptController_Destroy(controller);
#endif
}
Expand Down Expand Up @@ -222,51 +224,29 @@ int ArkodeSolver::init() {
throw BoutException("ARKStepCreate failed\n");
}

switch (treatment) {
case Treatment::ImEx:
output_info.write("\tUsing ARKode ImEx solver \n");
if (ARKStepSetImEx(arkode_mem) != ARK_SUCCESS) {
throw BoutException("ARKStepSetImEx failed\n");
}
break;
case Treatment::Explicit:
output_info.write("\tUsing ARKStep Explicit solver \n");
if (ARKStepSetExplicit(arkode_mem) != ARK_SUCCESS) {
throw BoutException("ARKStepSetExplicit failed\n");
}
break;
case Treatment::Implicit:
output_info.write("\tUsing ARKStep Implicit solver \n");
if (ARKStepSetImplicit(arkode_mem) != ARK_SUCCESS) {
throw BoutException("ARKStepSetImplicit failed\n");
}
break;
default:
throw BoutException("Invalid treatment: {}\n", toString(treatment));
}

// For callbacks, need pointer to solver object
if (ARKStepSetUserData(arkode_mem, this) != ARK_SUCCESS) {
throw BoutException("ARKStepSetUserData failed\n");
if (ARKodeSetUserData(arkode_mem, this) != ARK_SUCCESS) {
throw BoutException("ARKodeSetUserData failed\n");
}

if (ARKStepSetLinear(arkode_mem, set_linear) != ARK_SUCCESS) {
throw BoutException("ARKStepSetLinear failed\n");
if (ARKodeSetLinear(arkode_mem, static_cast<int>(set_linear))
!= ARK_SUCCESS) {
throw BoutException("ARKodeSetLinear failed\n");
}

if (fixed_step) {
// If not given, default to adaptive timestepping
const auto fixed_timestep = (*options)["timestep"].withDefault(0.0);
if (ARKStepSetFixedStep(arkode_mem, fixed_timestep) != ARK_SUCCESS) {
throw BoutException("ARKStepSetFixedStep failed\n");
if (ARKodeSetFixedStep(arkode_mem, fixed_timestep) != ARK_SUCCESS) {
throw BoutException("ARKodeSetFixedStep failed\n");
}
}

if (ARKStepSetOrder(arkode_mem, order) != ARK_SUCCESS) {
throw BoutException("ARKStepSetOrder failed\n");
if (ARKodeSetOrder(arkode_mem, order) != ARK_SUCCESS) {
throw BoutException("ARKodeSetOrder failed\n");
}

#if SUNDIALS_TABLE_BY_NAME_SUPPORT
#if ARKODE_TABLE_BY_NAME_SUPPORT
if (!implicit_table.empty() || !explicit_table.empty()) {
if (ARKStepSetTableName(
arkode_mem,
Expand All @@ -278,11 +258,11 @@ int ArkodeSolver::init() {
}
#endif

if (ARKStepSetCFLFraction(arkode_mem, cfl_frac) != ARK_SUCCESS) {
throw BoutException("ARKStepSetCFLFraction failed\n");
if (ARKodeSetCFLFraction(arkode_mem, cfl_frac) != ARK_SUCCESS) {
throw BoutException("ARKodeSetCFLFraction failed\n");
}

#if SUNDIALS_CONTROLLER_SUPPORT
#if ARKODE_CONTROLLER_SUPPORT
switch (adap_method) {
case AdapMethod::PID:
controller = SUNAdaptController_PID(suncontext);
Expand All @@ -306,12 +286,12 @@ int ArkodeSolver::init() {
throw BoutException("Invalid adap_method\n");
}

if (ARKStepSetAdaptController(arkode_mem, controller) != ARK_SUCCESS) {
throw BoutException("ARKStepSetAdaptController failed\n");
if (ARKodeSetAdaptController(arkode_mem, controller) != ARK_SUCCESS) {
throw BoutException("ARKodeSetAdaptController failed\n");
}

if (ARKStepSetAdaptivityAdjustment(arkode_mem, 0) != ARK_SUCCESS) {
throw BoutException("ARKStepSetAdaptivityAdjustment failed\n");
if (ARKodeSetAdaptivityAdjustment(arkode_mem, 0) != ARK_SUCCESS) {
throw BoutException("ARKodeSetAdaptivityAdjustment failed\n");
}
#else
int adap_method_int;
Expand Down Expand Up @@ -374,36 +354,36 @@ int ArkodeSolver::init() {

set_abstol_values(N_VGetArrayPointer(abstolvec), f2dtols, f3dtols);

if (ARKStepSVtolerances(arkode_mem, reltol, abstolvec) != ARK_SUCCESS) {
throw BoutException("ARKStepSVtolerances failed\n");
if (ARKodeSVtolerances(arkode_mem, reltol, abstolvec) != ARK_SUCCESS) {
throw BoutException("ARKodeSVtolerances failed\n");
}

N_VDestroy(abstolvec);
} else {
if (ARKStepSStolerances(arkode_mem, reltol, abstol) != ARK_SUCCESS) {
throw BoutException("ARKStepSStolerances failed\n");
if (ARKodeSStolerances(arkode_mem, reltol, abstol) != ARK_SUCCESS) {
throw BoutException("ARKodeSStolerances failed\n");
}
}

if (ARKStepSetMaxNumSteps(arkode_mem, mxsteps) != ARK_SUCCESS) {
throw BoutException("ARKStepSetMaxNumSteps failed\n");
if (ARKodeSetMaxNumSteps(arkode_mem, mxsteps) != ARK_SUCCESS) {
throw BoutException("ARKodeSetMaxNumSteps failed\n");
}

if (max_timestep > 0.0) {
if (ARKStepSetMaxStep(arkode_mem, max_timestep) != ARK_SUCCESS) {
throw BoutException("ARKStepSetMaxStep failed\n");
if (ARKodeSetMaxStep(arkode_mem, max_timestep) != ARK_SUCCESS) {
throw BoutException("ARKodeSetMaxStep failed\n");
}
}

if (min_timestep > 0.0) {
if (ARKStepSetMinStep(arkode_mem, min_timestep) != ARK_SUCCESS) {
throw BoutException("ARKStepSetMinStep failed\n");
if (ARKodeSetMinStep(arkode_mem, min_timestep) != ARK_SUCCESS) {
throw BoutException("ARKodeSetMinStep failed\n");
}
}

if (start_timestep > 0.0) {
if (ARKStepSetInitStep(arkode_mem, start_timestep) != ARK_SUCCESS) {
throw BoutException("ARKStepSetInitStep failed");
if (ARKodeSetInitStep(arkode_mem, start_timestep) != ARK_SUCCESS) {
throw BoutException("ARKodeSetInitStep failed");
}
}

Expand All @@ -414,8 +394,8 @@ int ArkodeSolver::init() {
if (nonlinear_solver == nullptr) {
throw BoutException("Creating SUNDIALS fixed point nonlinear solver failed\n");
}
if (ARKStepSetNonlinearSolver(arkode_mem, nonlinear_solver) != ARK_SUCCESS) {
throw BoutException("ARKStepSetNonlinearSolver failed\n");
if (ARKodeSetNonlinearSolver(arkode_mem, nonlinear_solver) != ARK_SUCCESS) {
throw BoutException("ARKodeSetNonlinearSolver failed\n");
}
} else {
output.write("\tUsing Newton iteration\n");
Expand All @@ -426,18 +406,18 @@ int ArkodeSolver::init() {
if (sun_solver == nullptr) {
throw BoutException("Creating SUNDIALS linear solver failed\n");
}
if (ARKStepSetLinearSolver(arkode_mem, sun_solver, nullptr) != ARKLS_SUCCESS) {
throw BoutException("ARKStepSetLinearSolver failed\n");
if (ARKodeSetLinearSolver(arkode_mem, sun_solver, nullptr) != ARKLS_SUCCESS) {
throw BoutException("ARKodeSetLinearSolver failed\n");
}

/// Set Preconditioner
if (use_precon) {
if (hasPreconditioner()) {
output.write("\tUsing user-supplied preconditioner\n");

if (ARKStepSetPreconditioner(arkode_mem, nullptr, arkode_pre)
if (ARKodeSetPreconditioner(arkode_mem, nullptr, arkode_pre)
!= ARKLS_SUCCESS) {
throw BoutException("ARKStepSetPreconditioner failed\n");
throw BoutException("ARKodeSetPreconditioner failed\n");
}
} else {
output.write("\tUsing BBD preconditioner\n");
Expand Down Expand Up @@ -489,20 +469,23 @@ int ArkodeSolver::init() {
if (use_jacobian and hasJacobian()) {
output.write("\tUsing user-supplied Jacobian function\n");

if (ARKStepSetJacTimes(arkode_mem, nullptr, arkode_jac) != ARKLS_SUCCESS) {
throw BoutException("ARKStepSetJacTimes failed\n");
if (ARKodeSetJacTimes(arkode_mem, nullptr, arkode_jac) != ARKLS_SUCCESS) {
throw BoutException("ARKodeSetJacTimes failed\n");
}
} else {
output.write("\tUsing difference quotient approximation for Jacobian\n");
}
}

#if ARKODE_OPTIMAL_PARAMS_SUPPORT
if (optimize) {
output.write("\tUsing ARKode inbuilt optimization\n");
if (ARKStepSetOptimalParams(arkode_mem) != ARK_SUCCESS) {
throw BoutException("ARKStepSetOptimalParams failed");
}
}
#endif

return 0;
}

Expand Down Expand Up @@ -532,17 +515,17 @@ int ArkodeSolver::run() {

// Get additional diagnostics
long int temp_long_int, temp_long_int2;
ARKStepGetNumSteps(arkode_mem, &temp_long_int);
ARKodeGetNumSteps(arkode_mem, &temp_long_int);
nsteps = int(temp_long_int);
ARKStepGetNumRhsEvals(arkode_mem, &temp_long_int, &temp_long_int2);
nfe_evals = int(temp_long_int);
nfi_evals = int(temp_long_int2);
if (treatment == Treatment::ImEx or treatment == Treatment::Implicit) {
ARKStepGetNumNonlinSolvIters(arkode_mem, &temp_long_int);
ARKodeGetNumNonlinSolvIters(arkode_mem, &temp_long_int);
nniters = int(temp_long_int);
ARKStepGetNumPrecEvals(arkode_mem, &temp_long_int);
ARKodeGetNumPrecEvals(arkode_mem, &temp_long_int);
npevals = int(temp_long_int);
ARKStepGetNumLinIters(arkode_mem, &temp_long_int);
ARKodeGetNumLinIters(arkode_mem, &temp_long_int);
nliters = int(temp_long_int);
}

Expand Down Expand Up @@ -580,15 +563,15 @@ BoutReal ArkodeSolver::run(BoutReal tout) {
int flag;
if (!monitor_timestep) {
// Run in normal mode
flag = ARKStepEvolve(arkode_mem, tout, uvec, &simtime, ARK_NORMAL);
flag = ARKodeEvolve(arkode_mem, tout, uvec, &simtime, ARK_NORMAL);
} else {
// Run in single step mode, to call timestep monitors
BoutReal internal_time;
ARKStepGetCurrentTime(arkode_mem, &internal_time);
ARKodeGetCurrentTime(arkode_mem, &internal_time);
while (internal_time < tout) {
// Run another step
const BoutReal last_time = internal_time;
flag = ARKStepEvolve(arkode_mem, tout, uvec, &internal_time, ARK_ONE_STEP);
flag = ARKodeEvolve(arkode_mem, tout, uvec, &internal_time, ARK_ONE_STEP);

if (flag != ARK_SUCCESS) {
output_error.write("ERROR ARKODE solve failed at t = {:e}, flag = {:d}\n",
Expand All @@ -600,7 +583,7 @@ BoutReal ArkodeSolver::run(BoutReal tout) {
call_timestep_monitors(internal_time, internal_time - last_time);
}
// Get output at the desired time
flag = ARKStepGetDky(arkode_mem, tout, 0, uvec);
flag = ARKodeGetDky(arkode_mem, tout, 0, uvec);
simtime = tout;
}

Expand Down Expand Up @@ -630,7 +613,7 @@ void ArkodeSolver::rhs_e(BoutReal t, BoutReal* udata, BoutReal* dudata) {

// Get the current timestep
// Note: ARKodeGetCurrentStep updated too late in older versions
ARKStepGetLastStep(arkode_mem, &hcur);
ARKodeGetLastStep(arkode_mem, &hcur);

// Call RHS function
run_convective(t);
Expand All @@ -647,7 +630,7 @@ void ArkodeSolver::rhs_i(BoutReal t, BoutReal* udata, BoutReal* dudata) {
TRACE("Running RHS: ArkodeSolver::rhs_i({:e})", t);

load_vars(udata);
ARKStepGetLastStep(arkode_mem, &hcur);
ARKodeGetLastStep(arkode_mem, &hcur);
// Call Implicit RHS function
run_diffusive(t);
save_derivs(dudata);
Expand All @@ -660,7 +643,7 @@ void ArkodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) {
TRACE("Running RHS: ArkodeSolver::rhs({:e})", t);

load_vars(udata);
ARKStepGetLastStep(arkode_mem, &hcur);
ARKodeGetLastStep(arkode_mem, &hcur);
// Call Implicit RHS function
run_rhs(t);
save_derivs(dudata);
Expand Down
Loading

0 comments on commit 36e3da5

Please sign in to comment.