Skip to content

Commit

Permalink
Use a faster implementation of AliasTables (#1848)
Browse files Browse the repository at this point in the history
* switch to AliasTables.jl

* retune heuristic

* add test for #832

* add more tests

* move alias table import and tighten from using to import

* Back out multinomial heuristic adjustment at @adienes's request

* Update test/univariate/discrete/categorical.jl (style)

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
LilithHafner and devmotion authored Apr 20, 2024
1 parent f33af97 commit b670fee
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 21 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["JuliaStats"]
version = "0.25.107"

[deps]
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -30,6 +31,7 @@ DistributionsDensityInterfaceExt = "DensityInterface"
DistributionsTestExt = "Test"

[compat]
AliasTables = "1"
Aqua = "0.8"
Calculus = "0.5"
ChainRulesCore = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import PDMats: dim, PDMat, invquad
using SpecialFunctions
using Base.MathConstants: eulergamma

import AliasTables

export
# re-export Statistics
mean, median, quantile, std, var, cov, cor,
Expand Down
23 changes: 4 additions & 19 deletions src/samplers/aliastable.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
struct AliasTable <: Sampleable{Univariate,Discrete}
accept::Vector{Float64}
alias::Vector{Int}
at::AliasTables.AliasTable{UInt64, Int}
AliasTable(probs::AbstractVector{<:Real}) = new(AliasTables.AliasTable(probs))
end
ncategories(s::AliasTable) = length(s.alias)

function AliasTable(probs::AbstractVector)
n = length(probs)
n > 0 || throw(ArgumentError("The input probability vector is empty."))
accp = Vector{Float64}(undef, n)
alias = Vector{Int}(undef, n)
StatsBase.make_alias_table!(probs, 1.0, accp, alias)
AliasTable(accp, alias)
end

function rand(rng::AbstractRNG, s::AliasTable)
i = rand(rng, 1:length(s.alias)) % Int
# using `ifelse` improves performance here: github.com/JuliaStats/Distributions.jl/pull/1831/
ifelse(rand(rng) < s.accept[i], i, s.alias[i])
end

ncategories(s::AliasTable) = length(s.at)
rand(rng::AbstractRNG, s::AliasTable) = rand(rng, s.at)
show(io::IO, s::AliasTable) = @printf(io, "AliasTable with %d entries", ncategories(s))
2 changes: 2 additions & 0 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ import Distributions:
@testset "p=$p" for p in Any[[1.0], [0.3, 0.7], [0.2, 0.3, 0.4, 0.1]]
test_samples(S(p), Categorical(p), n_tsamples)
test_samples(S(p), Categorical(p), n_tsamples, rng=rng)
@test ncategories(S(p)) == length(p)
end
end

@test string(AliasTable(Float16[1,2,3])) == "AliasTable with 3 entries"

## Binomial samplers

Expand Down
17 changes: 15 additions & 2 deletions test/univariate/discrete/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ end
end

@testset "reproducibility across julia versions" begin
d= Categorical([0.1, 0.2, 0.7])
d = Categorical([0.1, 0.2, 0.7])
rng = StableRNGs.StableRNG(600)
@test rand(rng, d, 10) == [2, 1, 3, 3, 2, 3, 3, 3, 3, 3]
@test rand(rng, d, 10) == [3, 1, 1, 2, 3, 2, 3, 3, 2, 3]
end

@testset "comparisons" begin
Expand Down Expand Up @@ -124,4 +124,17 @@ end
@test Categorical([0.5, 0.5]) ≈ Categorical([0.5f0, 0.5f0])
end

@testset "issue #832" begin
priorities = collect(Float64, 1:1000)
priorities[1:50] .= 1e8

at = Distributions.AliasTable(priorities)
iat = rand(at, 16)

# failure rate of a single sample is approximately sum(51:1000)/50e8 = 9.9845e-5
# probability of at least 4 failures out of 16 samples is 1-cdf(Binomial(16, 9.9845e-5), 3) = 1.8074430840897548e-13
# so this test should randomly fail with a probability of about 1.8074430840897548e-13
@test count(==(1e8), priorities[iat]) >= 13
end

end

0 comments on commit b670fee

Please sign in to comment.