Skip to content

Commit

Permalink
Add nansem
Browse files Browse the repository at this point in the history
  • Loading branch information
brenhinkeller committed Jan 29, 2024
1 parent b331c4f commit e7aff7d
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 1 deletion.
272 changes: 272 additions & 0 deletions src/ArrayStats/nansem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
"""
```julia
nansem(A; dims=:, mean=nothing, corrected=true)
```
Compute the standard error of the mean of all non-`NaN` elements in `A`, optionally
over dimensions specified by `dims`.
A precomputed `mean` may optionally be provided, which results in a somewhat faster
calculation. If `corrected` is `true`, then _Bessel's correction_ is applied, such
that the sum is divided by `n-1` rather than `n`.
As an alternative to `dims`, `nansem` also supports the `dim` keyword, which
behaves identically to `dims`, but also drops any singleton dimensions that have
been reduced over (as is the convention in some other languages).
## Examples
```julia
julia> using NaNStatistics
julia> A = [1 2; 3 4]
2×2 Matrix{Int64}:
1 2
3 4
julia> nansem(A, dims=1)
1×2 Matrix{Float64}:
 1.0  1.0
julia> nansem(A, dims=2)
2×1 Matrix{Float64}:
0.5
0.5
```
"""
function nansem(A; dims=:, dim=:, mean=nothing, corrected=true)
    return __nansem(mean, corrected, A, dims, dim)
end

# Neither `dims` nor `dim` was given: reduce over the whole array
function __nansem(mean, corrected, A, ::Colon, ::Colon)
    return _nansem(mean, corrected, A, :)
end
# Only `dims` was given: reduce over `region`, keeping the reduced dimensions
function __nansem(mean, corrected, A, region, ::Colon)
    return _nansem(mean, corrected, A, region)
end
# Only `dim` was given: reduce over `region`, then pass the result through `reducedims`
function __nansem(mean, corrected, A, ::Colon, region)
    reduced = _nansem(mean, corrected, A, region)
    return reducedims(reduced, region)
end
export nansem

# A single integer dimension is normalized to a one-element tuple of dimensions
function _nansem(μ, corrected::Bool, A, dims::Int)
    return _nansem(μ, corrected, A, (dims,))
end

# No mean supplied: compute it with `_nanmean`, then hand it to the in-place kernel
function _nansem(::Nothing, corrected::Bool, A, dims::Tuple)
    μ = _nanmean(A, dims)
    return _nansem!(μ, corrected, A, dims)
end
# Reduce all the dims!
# Mean not supplied; `A` is a strided array with `PrimitiveFloat` elements.
# Returns sqrt(σ² / (n - corrected) / n), where `n` counts the non-NaN elements.
function _nansem(::Nothing, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
    Tₒ = Base.promote_op(/, T, Int)
    n = 0
    # `∅` is a typed zero, substituted for NaN elements so they add nothing to the sum
    Σ = ∅ = zero(Tₒ)
    @turbo check_empty=true for i ∈ eachindex(A)
        Aᵢ = A[i]
        notnan = Aᵢ==Aᵢ     # false only for NaN
        n += notnan
        Σ += ifelse(notnan, Aᵢ, ∅)
    end
    μ = Σ / n
    σ² = ∅ = zero(typeof(μ))
    @turbo check_empty=true for i ∈ eachindex(A)
        δ = A[i] - μ
        notnan = δ==δ       # δ is NaN whenever A[i] (or μ) is NaN
        σ² += ifelse(notnan, δ * δ, ∅)
    end
    # Clamping the denominator at zero makes the result NaN when too few elements remain
    return sqrt(σ² / max(n-corrected,0) / n)
end
# Mean not supplied; `A` is a strided array with `PrimitiveInteger` elements.
# No NaN filtering is performed in this method: every element is counted (n = length(A)).
function _nansem(::Nothing, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
    Tₒ = Base.promote_op(/, T, Int)
    n = length(A)
    Σ = zero(Tₒ)
    @turbo check_empty=true for i ∈ eachindex(A)
        Σ += A[i]
    end
    μ = Σ / n
    σ² = zero(typeof(μ))
    @turbo check_empty=true for i ∈ eachindex(A)
        δ = A[i] - μ
        σ² += δ * δ
    end
    # Clamping the denominator at zero makes the result NaN when too few elements remain
    return sqrt(σ² / max(n-corrected,0) / n)
end
# Fallback method for non-StridedArrays
# Mean not supplied; works on any indexable collection (ranges, tuples, ...).
# Returns sqrt(σ² / (n - corrected) / n), where `n` counts the non-NaN elements.
function _nansem(::Nothing, corrected::Bool, A, ::Colon)
    Tₒ = Base.promote_op(/, eltype(A), Int)
    n = 0
    # `∅` is a typed zero, substituted for NaN elements so they add nothing to the sum
    Σ = ∅ = zero(Tₒ)
    @inbounds for i ∈ eachindex(A)
        Aᵢ = A[i]
        notnan = Aᵢ==Aᵢ     # false only for NaN
        n += notnan
        Σ += ifelse(notnan, Aᵢ, ∅)
    end
    μ = Σ / n
    σ² = ∅ = zero(typeof(μ))
    @inbounds for i ∈ eachindex(A)
        δ = A[i] - μ
        notnan = δ==δ       # δ is NaN whenever A[i] (or μ) is NaN
        σ² += ifelse(notnan, δ * δ, ∅)
    end
    # Clamping the denominator at zero makes the result NaN when too few elements remain
    return sqrt(σ² / max(n-corrected,0) / n)
end


# If the mean is known, pass it on in the appropriate form
# (the in-place kernel `_nansem!` receives its own array to work on in every case)
_nansem(μ, corrected::Bool, A, dims::Tuple) = _nansem!(collect(μ), corrected, A, dims)
_nansem(μ::Array, corrected::Bool, A, dims::Tuple) = _nansem!(copy(μ), corrected, A, dims)
_nansem(μ::Number, corrected::Bool, A, dims::Tuple) = _nansem!([μ], corrected, A, dims)
# Reduce all the dims!
# Mean `μ` supplied; `A` is a strided array with `PrimitiveFloat` elements.
# Returns sqrt(σ² / (n - corrected) / n), where `n` counts the elements with non-NaN deviation.
function _nansem(μ::Number, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveFloat
    n = 0
    # `∅` is a typed zero, substituted for NaN deviations so they add nothing to the sum
    σ² = ∅ = zero(typeof(μ))
    @turbo check_empty=true for i ∈ eachindex(A)
        δ = A[i] - μ
        notnan = δ==δ       # δ is NaN whenever A[i] (or μ) is NaN
        n += notnan
        σ² += ifelse(notnan, δ * δ, ∅)
    end
    # Clamping the denominator at zero makes the result NaN when too few elements remain
    return sqrt(σ² / max(n-corrected, 0) / n)
end
# Mean `μ` supplied; `A` is a strided array with `PrimitiveInteger` elements.
# The elements are not NaN-checked here, so only the supplied mean is tested.
function _nansem(μ::Number, corrected::Bool, A::StridedArray{T}, ::Colon) where T<:PrimitiveInteger
    σ² = zero(typeof(μ))
    if μ==μ
        @turbo check_empty=true for i ∈ eachindex(A)
            δ = A[i] - μ
            σ² += δ * δ
        end
        n = length(A)
    else
        # A NaN mean leaves n = 0, so the final division below yields NaN
        n = 0
    end
    return sqrt(σ² / max(n-corrected, 0) / n)
end
# Fallback method for non-strided-arrays
# Mean `μ` supplied; works on any indexable collection (ranges, tuples, ...).
# Returns sqrt(σ² / (n - corrected) / n), where `n` counts the elements with non-NaN deviation.
function _nansem(μ::Number, corrected::Bool, A, ::Colon)
    n = 0
    # `∅` is a typed zero, substituted for NaN deviations so they add nothing to the sum
    σ² = ∅ = zero(typeof(μ))
    @inbounds for i ∈ eachindex(A)
        δ = A[i] - μ
        notnan = δ==δ       # δ is NaN whenever A[i] (or μ) is NaN
        n += notnan
        σ² += ifelse(notnan, δ * δ, ∅)
    end
    # Clamping the denominator at zero makes the result NaN when too few elements remain
    return sqrt(σ² / max(n-corrected, 0) / n)
end

# # Fallback method for overly-complex reductions
# function _nansem_fallback!(B::AbstractArray, corrected::Bool, A::AbstractArray,region)
# mask = nanmask(A)
# N = sum(mask, dims=region)
# Σ = sum(A.*mask, dims=region)./N
# δ = A .- Σ # Subtract mean, using broadcasting
# @turbo check_empty=true for i ∈ eachindex(δ)
# δᵢ = δ[i]
# δ[i] = ifelse(mask[i], δᵢ * δᵢ, 0)
# end
# B .= sum(δ, dims=region)
# @turbo check_empty=true for i ∈ eachindex(B)
# B[i] = B[i] / max(N[i] - corrected, 0)
# end
# return B
# end



# Build the expression for the loop kernel used by the `@generated` `_nansem!`.
#
# `static_dims` lists the dimensions being reduced over and `N` is the number of array
# dimensions. The returned expression refers to the variables `A`, `B` and `corrected`.
# It reads the mean for each output element from `B` and overwrites that element with
# the standard error of the mean.
# Requires at least one non-reduced dimension (`first(nonreduct_inds)` throws otherwise).
function staticdim_nansem_quote(static_dims::Vector{Int}, N::Int)
    # `static_dims` now contains every dim we're taking the standard error over.
    Bᵥ = Expr(:call, :view, :B)
    reduct_inds = Int[]
    nonreduct_inds = Int[]
    # Firstly, build our expressions for indexing each array
    Aind = :(A[])
    Bind = :(Bᵥ[])
    inds = Vector{Symbol}(undef, N)
    for n ∈ 1:N
        ind = Symbol(:i_,n)
        inds[n] = ind
        push!(Aind.args, ind)
        if n ∈ static_dims
            # Reduced dims are pinned to their first index in the view of `B`
            push!(reduct_inds, n)
            push!(Bᵥ.args, :(firstindex(B,$n)))
        else
            push!(nonreduct_inds, n)
            push!(Bᵥ.args, :)
            push!(Bind.args, ind)
        end
    end
    firstn = first(nonreduct_inds)
    # Secondly, build up our set of loops
    block = Expr(:block)
    loops = Expr(:for, :($(inds[firstn]) = indices((A,B),$firstn)), block)
    if length(nonreduct_inds) > 1
        for n ∈ @view(nonreduct_inds[2:end])
            newblock = Expr(:block)
            push!(block.args, Expr(:for, :($(inds[n]) = indices((A,B),$n)), newblock))
            block = newblock
        end
    end
    rblock = block
    # Push more things here if you want them at the beginning of the reduction loop
    push!(rblock.args, :(μ = $Bind))
    push!(rblock.args, :(n = 0))
    push!(rblock.args, :(σ² = ∅))
    # Build the reduction loop
    for n ∈ reduct_inds
        newblock = Expr(:block)
        push!(block.args, Expr(:for, :($(inds[n]) = axes(A,$n)), newblock))
        block = newblock
    end
    # Push more things here if you want them in the innermost loop
    push!(block.args, :(δ = $Aind - μ))
    push!(block.args, :(notnan = δ==δ))
    push!(block.args, :(n += notnan))
    push!(block.args, :(σ² += ifelse(notnan, δ * δ, ∅)))
    # Push more things here if you want them at the end of the reduction loop
    push!(rblock.args, :($Bind = sqrt(σ² * inv(max(n-corrected,0)) * inv(n))))

    # Put it all together
    quote
        ∅ = zero(eltype(B))
        Bᵥ = $Bᵥ
        @inbounds $loops
        return B
    end
end

# Build the body of the `@generated` `_nansem!` for an `N`-dimensional array and `M`
# reduction dims whose tuple type is `D`.
#
# Dims given as `StaticInt`s are collected into `static_dims`. At the first dim that
# is not a `StaticInt`, the function returns an if/elseif chain. The chain compares
# the runtime value `dims[m]` against every remaining candidate dimension and
# re-dispatches to `_nansem!` with that dim converted to a `StaticInt`.
# Once every dim is static, the actual loop kernel is emitted.
function branches_nansem_quote(N::Int, M::Int, D)
    static_dims = Int[]
    for m ∈ 1:M
        param = D.parameters[m]
        if param <: StaticInt
            new_dim = _dim(param)::Int
            # The same dimension must not be reduced over twice
            @assert new_dim ∉ static_dims
            push!(static_dims, new_dim)
        else
            # Tuple of the dims already known statically
            t = Expr(:tuple)
            for n ∈ static_dims
                push!(t.args, :(StaticInt{$n}()))
            end
            q = Expr(:block, :(dimm = dims[$m]))
            qold = q
            ifsym = :if
            for n ∈ 1:N
                n ∈ static_dims && continue
                tc = copy(t)
                push!(tc.args, :(StaticInt{$n}()))
                qnew = Expr(ifsym, :(dimm == $n), :(return _nansem!(B, corrected, A, $tc)))
                # `tc` is spliced into `qnew` by reference, so the dims pushed here
                # still end up in the generated call
                for r ∈ m+1:M
                    push!(tc.args, :(dims[$r]))
                end
                push!(qold.args, qnew)
                qold = qnew
                ifsym = :elseif
            end
            push!(qold.args, Expr(:block, :(throw("Dimension `$dimm` not found."))))
            return q
        end
    end
    staticdim_nansem_quote(static_dims, N)
end

# Efficient @generated in-place standard error of the mean.
# On entry `B` holds the mean(s); the generated code overwrites `B` with the result.
@generated function _nansem!(B::AbstractArray{Tₒ,N}, corrected::Bool, A::AbstractArray{T,N}, dims::D) where {Tₒ,T,N,M,D<:Tuple{Vararg{IntOrStaticInt,M}}}
    if N == M
        # Every dimension is being reduced: defer to the scalar all-dims method
        return :(B[1] = _nansem(B[1], corrected, A, :); B)
    end
    # total_combinations = binomial(N,M)
    # if total_combinations > 6
    #   # Fallback, for overly-complex reductions
    #   return :(_nansem_fallback!(B, corrected, A, dims))
    # else
    return branches_nansem_quote(N, M, D)
    # end
end
1 change: 1 addition & 0 deletions src/NaNStatistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module NaNStatistics
include("ArrayStats/nancumsum.jl")
include("ArrayStats/nanvar.jl")
include("ArrayStats/nanstd.jl")
include("ArrayStats/nansem.jl")
include("ArrayStats/nancov.jl")
include("Sorting/quicksort.jl")
include("Sorting/nanmedian.jl")
Expand Down
25 changes: 24 additions & 1 deletion test/testArrayStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
@test nanextrema(A) === (1.0, 10.0)
@test nanvar([1,2,3,NaN]) === 1.0
@test nanstd([1,2,3,NaN]) === 1.0
@test nansem([1,2,3,NaN]) ≈ 1/sqrt(3) # std of 1.0 over the 3 non-NaN elements
@test nanstd([1,2,3,NaN], ones(4)) === 1.0 # weighted
@test nanmad([1,2,3,NaN]) === 1.0
@test nanaad([1,2,3,NaN]) ≈ 2/3
Expand All @@ -62,6 +63,7 @@
@test isnan(nanvar(A))
@test isnan(nanvar(A, mean=NaN))
@test isnan(nanstd(A))
@test isnan(nansem(A))
@test isnan(nanstd(A, ones(10))) # weighted
@test isnan(nanaad(A))
@test isnan(nanmad(A))
Expand All @@ -81,6 +83,7 @@
@test isnan(nanvar(A))
@test isnan(nanvar(A, mean=0))
@test isnan(nanstd(A))
@test isnan(nansem(A))
@test isnan(nanstd(A, copy(A))) # weighted
@test isnan(nanaad(A))
@test isnan(nanmad(A))
Expand All @@ -107,6 +110,7 @@
@test nanvar([1,2,3]) === 1.0
@test nanvar([1,2,3], mean=2) === 1.0
@test nanstd([1,2,3]) === 1.0
@test nansem([1,2,3]) ≈ 1/sqrt(3)
@test nanstd([1,2,3], ones(3)) === 1.0 # weighted
@test nanmad([1,2,3]) === 1.0
@test nanaad([1,2,3]) ≈ 2/3
Expand All @@ -131,6 +135,7 @@
@test nanvar(1:3) === 1.0
@test nanvar(1:3, mean=2) === 1.0
@test nanstd(1:3) === 1.0
@test nansem(1:3) ≈ 1/sqrt(3)
@test nanstd(1:3, ones(3)) === 1.0 # weighted
@test nanmad(1:3) === 1.0
@test nanaad(1:3) ≈ 2/3
Expand All @@ -151,6 +156,7 @@
@test nanvar(1:3.) === 1.0
@test nanvar(1:3., mean=2.0) === 1.0
@test nanstd(1:3.) === 1.0
@test nansem(1:3.) ≈ 1/sqrt(3)
@test nanstd(1:3., ones(3)) === 1.0 # weighted
@test nanmad(1:3.) === 1.0
@test nanaad(1:3.) ≈ 2/3
Expand Down Expand Up @@ -179,6 +185,10 @@
@test nanstd(A, dims=2) ≈ std(A, dims=2)
@test nanstd(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)
@test nanstd(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)
# Standard error of the mean: standard deviation over sqrt of the reduced length
@test nansem(A, dims=1) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nansem(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nanstd(A, ones(size(A)), dims=1) ≈ std(A, dims=1) # weighted
@test nanstd(A, ones(size(A)), dims=2) ≈ std(A, dims=2) # weighted
@test nanmad(A, dims=1) == [25.0 25.0 25.0]
Expand Down Expand Up @@ -213,6 +223,10 @@
@test nanstd(A, dims=2) ≈ std(A, dims=2)
@test nanstd(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)
@test nanstd(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)
# Standard error of the mean: standard deviation over sqrt of the reduced length
@test nansem(A, dims=1) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nansem(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nanstd(A, ones(size(A)), dims=1) ≈ std(A, dims=1) # weighted
@test nanstd(A, ones(size(A)), dims=2) ≈ std(A, dims=2) # weighted
@test nanmedian(A, dims=1) == median(A, dims=1)
Expand Down Expand Up @@ -258,6 +272,8 @@
@test nanvar(A, dim=2, mean=nanmean(A,dims=2)) ≈ vec(var(A, dims=2))
@test nanstd(A, dim=1) ≈ vec(std(A, dims=1))
@test nanstd(A, dim=2) ≈ vec(std(A, dims=2))
# With `dim`, the reduced dimension is dropped, hence the comparison against `vec(...)`
@test nansem(A, dim=1) ≈ vec(std(A, dims=1)./sqrt(size(A, 1)))
@test nansem(A, dim=2) ≈ vec(std(A, dims=2)./sqrt(size(A, 2)))
@test nanstd(A, ones(size(A)), dim=1) ≈ vec(std(A, dims=1)) # weighted
@test nanstd(A, ones(size(A)), dim=2) ≈ vec(std(A, dims=2)) # weighted
@test nanmedian(A, dim=1) == vec(median(A, dims=1))
Expand Down Expand Up @@ -305,6 +321,10 @@
@test nanstd(A, dims=2) ≈ std(A, dims=2)
@test nanstd(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)
@test nanstd(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)
# Standard error of the mean: standard deviation over sqrt of the reduced length
@test nansem(A, dims=1) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nansem(A, dims=1, mean=nanmean(A,dims=1)) ≈ std(A, dims=1)./sqrt(size(A,1))
@test nansem(A, dims=2, mean=nanmean(A,dims=2)) ≈ std(A, dims=2)./sqrt(size(A,2))
@test nanstd(A, ones(size(A)), dims=1) ≈ std(A, dims=1) # weighted
@test nanstd(A, ones(size(A)), dims=2) ≈ std(A, dims=2) # weighted
@test nanmedian(A, dims=1) == median(A, dims=1)
Expand Down Expand Up @@ -337,6 +357,8 @@
@test nanmean(A, dims=(4,5,6)) ≈ mean(A, dims=(4,5,6))
@test nanstd(A, dims=(4,5,6)) ≈ std(A, dims=(4,5,6))
@test nanstd(A, dims=(4,5,6)) ≈ nanstd(A, dims=(4,5,6), mean=nanmean(A, dims=(4,5,6)))
# The element count per reduction is the product of the reduced dimension sizes
@test nansem(A, dims=(4,5,6)) ≈ std(A, dims=(4,5,6))./sqrt(size(A,4)*size(A,5)*size(A,6))
@test nansem(A, dims=(4,5,6)) ≈ nansem(A, dims=(4,5,6), mean=nanmean(A, dims=(4,5,6)))

## --- Test in-place vs. out-of-place versions

Expand Down Expand Up @@ -374,7 +396,8 @@
@test nanmean((1,2,3,4,5)) === 3.0
@test nanstd((1,2,3)) === 1.0
@test nanstd((1,2,3), mean=2.0) === 1.0

@test nansem((1,2,3)) ≈ 1/sqrt(3)
@test nansem((1,2,3), mean=2.0) ≈ 1/sqrt(3)

## --- Standardization

Expand Down

0 comments on commit e7aff7d

Please sign in to comment.