Skip to content

Commit

Permalink
Support boxdot with n neighboring indices
Browse files Browse the repository at this point in the history
  • Loading branch information
KeitaNakamura committed Oct 25, 2024
1 parent ca2e4ca commit ed9d83f
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 40 deletions.
99 changes: 59 additions & 40 deletions src/TensorCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
return dest
end

export boxdot, , boxdot!
export boxdot, , ₂, boxdot!

"""
boxdot(A,B) = A ⊡ B # \\boxdot
Expand Down Expand Up @@ -177,40 +177,54 @@ Float64
```
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
"""
function boxdot(A::AbstractArray, B::AbstractArray)
Amat = _squash_left(A)
Bmat = _squash_right(B)
function boxdot(A::AbstractArray, B::AbstractArray, nth::Val)
_check_boxdot_axes(A, B, nth)
Amat = _squash_left(A, nth)
Bmat = _squash_right(B, nth)

axA, axB = axes(Amat,2), axes(Bmat,1)
axA == axB || _throw_dmm(axA, axB)

return _boxdot_reshape(Amat * Bmat, A, B)
return _boxdot_reshape(Amat * Bmat, A, B, nth)
end

boxdot(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(1))
boxdot2(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(2))

const = boxdot
const = boxdot2

@noinline _throw_dmm(axA, axB) = throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB"))
@noinline _throw_boxdot_nth(n) = throw(ArgumentError("boxdot order should be ≥ 1, got $n"))

function _check_boxdot_axes(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,M}, ::Val{K}) where {N,M,K}
K::Int
(K >= 1) || _throw_boxdot_nth(K)
for i in 1:K
axA, axB = axes(A)[N-K+i], axes(B)[i]
axA == axB || _throw_dmm(axA, axB)
end
end

_squash_left(A::AbstractArray) = reshape(A, :,size(A,ndims(A)))
_squash_left(A::AbstractMatrix) = A
_squash_left(A::AbstractArray, ::Val{N}) where {N} = reshape(A, prod(size(A)[1:end-N]),:)
_squash_left(A::AbstractMatrix, ::Val{1}) = A

_squash_right(B::AbstractArray) = reshape(B, size(B,1),:)
_squash_right(B::AbstractVecOrMat) = B
_squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[1+N:end]))
_squash_right(B::AbstractVecOrMat, ::Val{1}) = B

function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M}
ax = ntuple(i -> i<N ? axes(A, i) : axes(B, i-N+2), Val(N+M-2))
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}, ::Val{K}) where {T,N,S,M,K}
ax = ntuple(i -> iN-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K))
reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
end

# These can skip final reshape:
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val) = AB

# These produce scalar output:
function boxdot(A::AbstractVector, B::AbstractVector)
axA, axB = axes(A,1), axes(B,1)
axA == axB || _throw_dmm(axA, axB)
function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N}
_check_boxdot_axes(A, B, Val(N))
if eltype(A) <: Number
return transpose(A)*B
return transpose(vec(A))*vec(B)
else
return sum(a*b for (a,b) in zip(A,B))
end
Expand All @@ -224,30 +238,30 @@ boxdot(a::Number, b::Number) = a*b
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

# Adjont and Transpose, vectors or almost (returning a scalar)
boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B
boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B
boxdot(A::AdjointAbsVec, B::AbstractVector, ::Val{1}) = A * B
boxdot(A::TransposeAbsVec, B::AbstractVector, ::Val{1}) = A * B

boxdot(A::AbstractVector, B::AdjointAbsVec) = A vec(B)
boxdot(A::AbstractVector, B::TransposeAbsVec) = A vec(B)
boxdot(A::AbstractVector, B::AdjointAbsVec, ::Val{1}) = A vec(B)
boxdot(A::AbstractVector, B::TransposeAbsVec, ::Val{1}) = A vec(B)

boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(adjoint(B) adjoint(A))
boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = vec(A) vec(B)
boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = vec(A) vec(B)
boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
boxdot(A::AdjointAbsVec, B::AdjointAbsVec, ::Val{1}) = adjoint(adjoint(B) adjoint(A))
boxdot(A::AdjointAbsVec, B::TransposeAbsVec, ::Val{1}) = vec(A) vec(B)
boxdot(A::TransposeAbsVec, B::AdjointAbsVec, ::Val{1}) = vec(A) vec(B)
boxdot(A::TransposeAbsVec, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))

# ... with a matrix (returning another such)
boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B
boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B
boxdot(A::AdjointAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
boxdot(A::TransposeAbsVec, B::AbstractMatrix, ::Val{1}) = A * B

boxdot(A::AbstractMatrix, B::AdjointAbsVec) = (B' A')'
boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
boxdot(A::AbstractMatrix, B::AdjointAbsVec, ::Val{1}) = (B' A')'
boxdot(A::AbstractMatrix, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))

# ... and with higher-dim (returning a plain array)
boxdot(A::AdjointAbsVec, B::AbstractArray) = vec(A) B
boxdot(A::TransposeAbsVec, B::AbstractArray) = vec(A) B
boxdot(A::AdjointAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B
boxdot(A::TransposeAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B

boxdot(A::AbstractArray, B::AdjointAbsVec) = A vec(B)
boxdot(A::AbstractArray, B::TransposeAbsVec) = A vec(B)
boxdot(A::AbstractArray, B::AdjointAbsVec, ::Val{1}) = A vec(B)
boxdot(A::AbstractArray, B::TransposeAbsVec, ::Val{1}) = A vec(B)


"""
Expand All @@ -260,25 +274,30 @@ function boxdot! end

if VERSION < v"1.3" # Then 5-arg mul! isn't defined

function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray)
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B))
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
_check_boxdot_axes(A, B, Val(N))
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)))
Y
end

boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B))
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) = boxdot!(Y, A, B, Val(1))
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B), Val(1))

else

function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false)
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β)
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}, α::Number=true, β::Number=false) where {N}
_check_boxdot_axes(A, B, Val(N))
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)), α, β)
Y
end

boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) = boxdot!(Y, A, B, Val(1), α, β)

# For boxdot!, only where mul! behaves differently:
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec,
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β)
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), Val(1), α, β)

end

Expand Down
69 changes: 69 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,75 @@ end
@test boxdot!(similar(c,1), c', d) == [dot(c, d)]
end

@testset "higher-order boxdot" begin

# Arrays
A = [1 2+im; 3 4im; 5 6+im]
B = [5im 6; 7+im 8; 9im 10]
E3 = cat(A, B, conj(A .+ 1), dims=3)
F4 = cat(E3, conj(E3 .+ 1), dims=4)
E3adjoint = conj(permutedims(E3, (3,2,1)))
F4adjoint = conj(permutedims(F4, (4,3,2,1)))
E3lazy = PermutedDimsArray(permutedims(E3, (3,2,1)), (3,2,1))
F4lazy = PermutedDimsArray(permutedims(F4, (4,3,2,1)), (4,3,2,1))
@test E3lazy == E3
@test F4lazy == F4

@test A ₂ A isa Complex
@test boxdot(E3, E3, Val(3)) isa Complex
@test boxdot(F4, F4, Val(4)) isa Complex
@test A ₂ A == sum(A .* A)
@test boxdot(E3, E3, Val(3)) == sum(E3 .* E3)
@test boxdot(F4, F4, Val(4)) == sum(F4 .* F4)

@test size(A ₂ E3) == (3,)
@test A ₂ E3 == vec(reshape(A, 1,:) * reshape(E3, :,3))
@test A ₂ E3lazy == A ₂ E3
@test E3 ₂ A' == vec((A ₂ E3adjoint)')
@test E3 transpose(A) == A conj(E3adjoint)

@test size(A ₂ F4) == (3,2)
@test A ₂ F4 == reshape(reshape(A, 1,:) * reshape(F4, :,6), 3,2)
@test F4 ₂ A == (A' ₂ F4adjoint)'
@test A ₂ F4lazy == A ₂ F4
@test F4lazy ₂ A == F4 ₂ A

@test size(F4 ₂ E3) == (3,2,3)
@test F4 ₂ E3 == reshape(reshape(F4, 6,:) * reshape(E3, :,3), 3,2,3)
@test F4 ₂ E3adjoint == conj(permutedims(E3 ₂ F4adjoint, (3,2,1)))
@test F4 ₂ E3 == F4lazy ₂ E3lazy

# In-place
c = A ₂ E3
@test boxdot!(similar(c), A, E3, Val(2)) == A ₂ E3
if VERSION >= v"1.3"
@test boxdot!(similar(c), A, E3, Val(2), 100) == A ₂ E3 * 100
@test boxdot!(copy(c), B, E3, Val(2), 100, -5) == B ₂ E3 * 100 .- 5 .* c
end

@test boxdot!(similar(c,1), A, A, Val(2)) == [A ₂ A]
@test boxdot!(similar(c,3,2), A, F4, Val(2)) == A ₂ F4
@test boxdot!(similar(c,3,2,3), F4, E3, Val(2)) == F4 ₂ E3

# Errors
@test_throws DimensionMismatch ones(2,2) ones(3,2)
@test_throws DimensionMismatch ones(2,2) ones(2,3)
@test_throws DimensionMismatch ones(2,2,2) ones(2,3,2)
@test_throws BoundsError ones(2,2) ones(2)
@test_throws BoundsError ones(2) ones(2,2)
@test_throws ArgumentError boxdot(ones(2), ones(2), Val(-1))
@test_throws TypeError boxdot(ones(2), ones(2), Val(UInt(1)))

@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(3,2), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(2,3), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,2,2), ones(2,2,2), ones(2,3,2), Val(2))
@test_throws BoundsError boxdot!(similar(c,1), ones(2,2), ones(2), Val(2))
@test_throws BoundsError boxdot!(similar(c,1), ones(2), ones(2,2), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,2,3), ones(2,2,3), ones(2,3,2), Val(2))
@test_throws ArgumentError boxdot!(similar(c,1), ones(2), ones(2), Val(-1))
@test_throws TypeError boxdot!(similar(c,1), ones(2), ones(2), Val(UInt(1)))
end

@testset "_adjoint" begin
A = [1 2+im; 3 4im]
E3 = cat(A, -A, dims=3)
Expand Down

0 comments on commit ed9d83f

Please sign in to comment.