Skip to content

Commit

Permalink
Merge pull request #20 from utkarsh530/solverstats
Browse files Browse the repository at this point in the history
MATLAB solver stats
  • Loading branch information
ChrisRackauckas authored May 16, 2020
2 parents 43579fd + 6e63c4b commit 899cec0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "MATLABDiffEq"
uuid = "e2752cbe-bcf4-5895-8727-84ebc14a76bd"
version = "0.3.1"
version = "0.3.2"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
23 changes: 20 additions & 3 deletions src/MATLABDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ function DiffEqBase.__solve(

eval_string("options = odeset('RelTol',reltol,'AbsTol',abstol);")
algstr = string(typeof(alg).name.name)
#algstr = replace(string(typeof(alg)),"MATLABDiffEq.","")
eval_string("[t,u] = $(algstr)(diffeqf,tspan,u0,options);")
eval_string("mxsol = $(algstr)(diffeqf,tspan,u0,options);")
eval_string("mxsolstats = struct(mxsol.stats);")
solstats = get_variable(:mxsolstats)
eval_string("t = mxsol.x;")
ts = jvector(get_mvariable(:t))
eval_string("u = mxsol.y';")
timeseries_tmp = jarray(get_mvariable(:u))

# Reshape the result if needed
Expand All @@ -77,8 +80,22 @@ function DiffEqBase.__solve(
timeseries = timeseries_tmp
end

destats = buildDEStats(solstats)

DiffEqBase.build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
timeseries_errors = timeseries_errors,destats = destats)
end

function buildDEStats(solverstats::Dict)

destats = DiffEqBase.DEStats(0)
destats.nf = if (haskey(solverstats, "nfevals")) solverstats["nfevals"] else 0 end
destats.nreject = if (haskey(solverstats, "nfailed")) solverstats["nfailed"] else 0 end
destats.naccept = if (haskey(solverstats, "nsteps")) solverstats["nsteps"] else 0 end
destats.nsolve = if (haskey(solverstats, "nsolves")) solverstats["nsolves"] else 0 end
destats.njacs = if (haskey(solverstats, "npds")) solverstats["npds"] else 0 end
destats.nw = if (haskey(solverstats, "ndecomps")) solverstats["ndecomps"] else 0 end
destats
end

end # module

0 comments on commit 899cec0

Please sign in to comment.