Skip to content

Commit

Permalink
Merge pull request #16 from jmaack24/main
Browse files Browse the repository at this point in the history
Add user defined matrix multiplication function
  • Loading branch information
andrewning authored Jun 12, 2024
2 parents 62acd97 + fad0e0d commit 3550cb3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
25 changes: 13 additions & 12 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ This version is for linear equations Ay = b
- `A::matrix`, `b::vector`: components of linear system ``A y = b``
- `lsolve::function`: lsolve(A, b). Function to solve the linear system, default is backslash operator.
- `Af::factorization`: An optional factorization of A, useful to override default factorize, or if multiple linear solves will be performed with same A matrix.
- `mmul::function`: mmul(A, y). Function to compute ``A*y`` for a vector ``y``. Defaults to the Julia multiply operator `*`.
"""
function implicit_linear(A, b; lsolve=linear_solve, Af=nothing)
    # Keyword options are handed on positionally; the internal methods select the
    # plain or dual-number path from the types of `A` and `b`.
    return _implicit_linear(A, b, lsolve, Af)
end
function implicit_linear(A, b; lsolve=linear_solve, Af=nothing, mmul=*)
    # Keyword options (including the user-supplied product `mmul`, default `*`) are
    # handed on positionally; the internal methods select the plain or dual-number
    # path from the types of `A` and `b`.
    return _implicit_linear(A, b, lsolve, mmul, Af)
end


# If no AD, just solve normally.
function _implicit_linear(A, b, lsolve, Af)
    # Without a supplied factorization, solve against the matrix itself.
    if isnothing(Af)
        return lsolve(A, b)
    end
    # Otherwise solve against the supplied factorization.
    return lsolve(Af, b)
end
function _implicit_linear(A, b, lsolve, mmul, Af)
    # `mmul` is accepted for signature parity with the dual-number methods but is not
    # needed for a plain solve.
    if isnothing(Af)
        # Without a supplied factorization, solve against the matrix itself.
        return lsolve(A, b)
    end
    # Otherwise solve against the supplied factorization.
    return lsolve(Af, b)
end
# function implicit_linear(A, b, lsolve, Af, cache)
# if isnothing(cache)
# if isnothing(Af)
Expand All @@ -90,9 +91,9 @@ _implicit_linear(A, b, lsolve, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b
# end

# catch three cases where one or both contain duals
# A and b both hold ForwardDiff duals (sharing the tag T)
function _implicit_linear(
    A::AbstractArray{<:ForwardDiff.Dual{T}},
    b::AbstractArray{<:ForwardDiff.Dual{T}},
    lsolve,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, Af, T)
end

# only b holds duals
function _implicit_linear(
    A,
    b::AbstractArray{<:ForwardDiff.Dual{T}},
    lsolve,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, Af, T)
end

# only A holds duals
function _implicit_linear(
    A::AbstractArray{<:ForwardDiff.Dual{T}},
    b,
    lsolve,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, Af, T)
end
# A and b both hold ForwardDiff duals (sharing the tag T)
function _implicit_linear(
    A::AbstractArray{<:ForwardDiff.Dual{T}},
    b::AbstractArray{<:ForwardDiff.Dual{T}},
    lsolve,
    mmul,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, mmul, Af, T)
end

# only b holds duals
function _implicit_linear(
    A,
    b::AbstractArray{<:ForwardDiff.Dual{T}},
    lsolve,
    mmul,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, mmul, Af, T)
end

# only A holds duals
function _implicit_linear(
    A::AbstractArray{<:ForwardDiff.Dual{T}},
    b,
    lsolve,
    mmul,
    Af,
) where {T}
    return linear_dual(A, b, lsolve, mmul, Af, T)
end
# implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
# implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
# implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache)
Expand All @@ -110,7 +111,7 @@ _implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where {
# implicit_linear!(ydot, A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where {T} = linear_dual!(ydot, A, b, lsolve, Af, T)

# Both A and b contain duals
function linear_dual(A, b, lsolve, Af, T)
function linear_dual(A, b, lsolve, mmul, Af, T)

# unpack dual numbers (if not dual numbers, since only one might be, just returns itself)
bv = fd_value(b)
Expand All @@ -122,7 +123,7 @@ function linear_dual(A, b, lsolve, Af, T)
yv = lsolve(Afact, bv)

# extract Partials of b - A * y i.e., bdot - Adot * y (since y does not contain duals)
rhs = fd_partials(b - A*yv)
rhs = fd_partials(b - mmul(A,yv))

# solve for new derivatives
ydot = lsolve(Afact, rhs)
Expand Down Expand Up @@ -182,7 +183,7 @@ end


# Provide a ChainRule rule for reverse mode
function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af)
function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, mmul, Af)

# save factorization
Afact = isnothing(Af) ? factorize(ReverseDiff.value(A)) : Af
Expand All @@ -192,16 +193,16 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af)

function implicit_pullback(ybar)
u = lsolve(Afact', ybar)
return NoTangent(), -u*y', u, NoTangent(), NoTangent()
return NoTangent(), -u*y', u, NoTangent(), NoTangent(), NoTangent()
end

return y, implicit_pullback
end

# register above rule for ReverseDiff
# 4-argument form (no `mmul`): one registration per combination of tracked inputs.
# A tracked, b unconstrained
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, Af)
# A unconstrained, b tracked
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, Af)
# both tracked
# NOTE(review): b is constrained with AbstractVector{<:TrackedReal} here but with AbstractArray{<:TrackedReal} in the b-only registration — confirm the difference is intended.
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, Af)
# 5-argument form (with `mmul`): one registration per combination of tracked inputs.
# A tracked, b unconstrained
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, mmul, Af)
# A unconstrained, b tracked
ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af)
# both tracked
# NOTE(review): b is constrained with AbstractVector{<:TrackedReal} here but with AbstractArray{<:TrackedReal} in the b-only registration — confirm the difference is intended.
ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af)


# function implicit_linear_inplace(A, b, y, Af)
Expand Down
65 changes: 64 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,69 @@ end
@test all(isapprox.(J1, J2, atol=3e-12))
end

@testset "linear_user_mmul" begin

    # Number of times `my_multiply` has run. A `Ref` is mutated in place, so the closure
    # below does not rebind a captured local, and the name does not shadow `Base.count`.
    ncalls = Ref(0)

    # Hand-written dense matrix-vector product, passed as the `mmul` keyword of
    # `implicit_linear`. Every invocation is recorded in `ncalls` so the assertions
    # below can confirm that the user-supplied function was actually called.
    function my_multiply(A, x)
        T = promote_type(eltype(A), eltype(x))
        y = zeros(T, size(A, 1))
        # column-major traversal: the row index varies fastest
        for j in axes(A, 2), i in axes(A, 1)
            y[i] += A[i, j] * x[j]
        end
        ncalls[] += 1
        return y
    end

    # Constant coefficient matrix shared by the three cases.
    M = [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7]

    # A and b both depend on the differentiated input.
    function depends_on_A_and_b(a)
        A = a[1] * M
        b = 2.0 * a[2:4]
        x = implicit_linear(A, b; mmul=my_multiply)
        return 2 * x
    end

    # Only b depends on the differentiated input.
    function depends_on_b(a)
        b = 2.0 * a[2:4]
        x = implicit_linear(M, b; mmul=my_multiply)
        return 2 * x
    end

    # Only A depends on the differentiated input.
    function depends_on_A(a)
        A = a[1] * M
        b = 2.0 * ones(3)
        x = implicit_linear(A, b; mmul=my_multiply)
        return 2 * x
    end

    a = [1.2, 2.3, 3.1, 4.3]

    # For each case the forward- and reverse-mode Jacobians must agree, and the pair of
    # Jacobian evaluations is expected to invoke `my_multiply` exactly once in total, so
    # the running call count equals the case index.
    for (k, f) in enumerate((depends_on_A_and_b, depends_on_b, depends_on_A))
        J1 = ForwardDiff.jacobian(f, a)
        J2 = ReverseDiff.jacobian(f, a)

        @test ncalls[] == k
        @test all(isapprox.(J1, J2, atol=3e-12))
    end

end

@testset "1d (also parameters)" begin

residual(y, x, p) = y/x[1] + x[2]*cos(y)
Expand Down Expand Up @@ -991,4 +1054,4 @@ end
@test all(isapprox.(J1, J2, atol=1e-15))
@test all(isapprox.(J1, Jfd, atol=1e-9))

end
end

0 comments on commit 3550cb3

Please sign in to comment.