From fad0e0d0da4da8233f466d76f3ce10755155a742 Mon Sep 17 00:00:00 2001 From: Jonathan Maack Date: Thu, 11 Apr 2024 09:13:33 -0600 Subject: [PATCH] Add test for user defined matrix multiplication --- test/runtests.jl | 65 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 30d1457..0aad926 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -292,6 +292,69 @@ end @test all(isapprox.(J1, J2, atol=3e-12)) end +@testset "linear_user_mmul" begin + + count = 0 + + function my_multiply(A, x) + (m, n) = size(A) + T = promote_type(eltype(A), eltype(x)) + y = zeros(T, m) + for j in 1:n + for i in 1:m + y[i] += A[i, j] * x[j] + end + end + # Provide a way to make sure this function was called + count += 1 + return y + end + + function test(a) + A = a[1] * [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * a[2:4] + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + function test2(a) + A = [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * a[2:4] + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + function test3(a) + A = a[1] * [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * ones(3) + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + a = [1.2, 2.3, 3.1, 4.3] + J1 = ForwardDiff.jacobian(test, a) + J2 = ReverseDiff.jacobian(test, a) + + @test count == 1 + @test all(isapprox.(J1, J2, atol=3e-12)) + + J1 = ForwardDiff.jacobian(test2, a) + J2 = ReverseDiff.jacobian(test2, a) + + @test count == 2 + @test all(isapprox.(J1, J2, atol=3e-12)) + + J1 = ForwardDiff.jacobian(test3, a) + J2 = ReverseDiff.jacobian(test3, a) + + @test count == 3 + @test all(isapprox.(J1, J2, atol=3e-12)) + +end + @testset "1d (also parameters)" begin residual(y, x, p) = y/x[1] + x[2]*cos(y) @@ -991,4 +1054,4 @@ end @test all(isapprox.(J1, J2, atol=1e-15)) @test all(isapprox.(J1, Jfd, atol=1e-9)) -end \ No newline at end of file +end