Skip to content

Commit

Permalink
Change keyword to mmul
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaack24 committed Apr 11, 2024
1 parent 437f93d commit 39ad951
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ This version is for linear equations Ay = b
- `A::matrix`, `b::vector`: components of linear system ``A y = b``
- `lsolve::function`: lsolve(A, b). Function to solve the linear system, default is backslash operator.
- `Af::factorization`: An optional factorization of A, useful to override default factorize, or if multiple linear solves will be performed with same A matrix.
- `matvec_mul`: Function to compute ``A*y`` for a vector ``y``. Defaults to the julia multipy operator.
- `mmul`: Function to compute ``A*y`` for a vector ``y``. Defaults to the julia multipy operator.
"""
implicit_linear(A, b; lsolve=linear_solve, Af=nothing, matvec_mul=*) = _implicit_linear(A, b, lsolve, matvec_mul, Af)
implicit_linear(A, b; lsolve=linear_solve, Af=nothing, mmul=*) = _implicit_linear(A, b, lsolve, mmul, Af)


# If no AD, just solve normally.
_implicit_linear(A, b, lsolve, matvec_mul, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b)
_implicit_linear(A, b, lsolve, mmul, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b)
# function implicit_linear(A, b, lsolve, Af, cache)
# if isnothing(cache)
# if isnothing(Af)
Expand All @@ -91,9 +91,9 @@ _implicit_linear(A, b, lsolve, matvec_mul, Af) = isnothing(Af) ? lsolve(A, b) :
# end

# catch three cases where one or both contain duals
_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, matvec_mul, Af) where {T} = linear_dual(A, b, lsolve, matvec_mul, Af, T)
_implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, matvec_mul, Af) where {T} = linear_dual(A, b, lsolve, matvec_mul, Af, T)
_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, matvec_mul, Af) where {T} = linear_dual(A, b, lsolve, matvec_mul, Af, T)
_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T)
_implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T)
_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T)
# implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
# implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
# implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
Expand All @@ -111,7 +111,7 @@ _implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, matvec_mul,
# implicit_linear!(ydot, A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where {T} = linear_dual!(ydot, A, b, lsolve, Af, T)

# Both A and b contain duals
function linear_dual(A, b, lsolve, matvec_mul, Af, T)
function linear_dual(A, b, lsolve, mmul, Af, T)

# unpack dual numbers (if not dual numbers, since only one might be, just returns itself)
bv = fd_value(b)
Expand All @@ -123,7 +123,7 @@ function linear_dual(A, b, lsolve, matvec_mul, Af, T)
yv = lsolve(Afact, bv)

# extract Partials of b - A * y i.e., bdot - Adot * y (since y does not contain duals)
rhs = fd_partials(b - matvec_mul(A,yv))
rhs = fd_partials(b - mmul(A,yv))

# solve for new derivatives
ydot = lsolve(Afact, rhs)
Expand Down Expand Up @@ -183,7 +183,7 @@ end


# Provide a ChainRule rule for reverse mode
function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, matvec_mul, Af)
function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, mmul, Af)

# save factorization
Afact = isnothing(Af) ? factorize(ReverseDiff.value(A)) : Af
Expand All @@ -200,9 +200,9 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, matvec_m
end

# register above rule for ReverseDiff
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, matvec_mul, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, matvec_mul, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, matvec_mul, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, mmul, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af)
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af)


# function implicit_linear_inplace(A, b, y, Af)
Expand Down

0 comments on commit 39ad951

Please sign in to comment.