diff --git a/Project.toml b/Project.toml index 23fbf0801..1072d0294 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["JuliaStats"] version = "0.25.107" [deps] +AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -30,6 +31,7 @@ DistributionsDensityInterfaceExt = "DensityInterface" DistributionsTestExt = "Test" [compat] +AliasTables = "1" Aqua = "0.8" Calculus = "0.5" ChainRulesCore = "1" diff --git a/src/Distributions.jl b/src/Distributions.jl index cf3ec3288..8d344c415 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -27,6 +27,8 @@ import PDMats: dim, PDMat, invquad using SpecialFunctions using Base.MathConstants: eulergamma +import AliasTables + export # re-export Statistics mean, median, quantile, std, var, cov, cor, diff --git a/src/samplers/aliastable.jl b/src/samplers/aliastable.jl index 56042930d..1f633cf71 100644 --- a/src/samplers/aliastable.jl +++ b/src/samplers/aliastable.jl @@ -1,22 +1,7 @@ struct AliasTable <: Sampleable{Univariate,Discrete} - accept::Vector{Float64} - alias::Vector{Int} + at::AliasTables.AliasTable{UInt64, Int} + AliasTable(probs::AbstractVector{<:Real}) = new(AliasTables.AliasTable(probs)) end -ncategories(s::AliasTable) = length(s.alias) - -function AliasTable(probs::AbstractVector) - n = length(probs) - n > 0 || throw(ArgumentError("The input probability vector is empty.")) - accp = Vector{Float64}(undef, n) - alias = Vector{Int}(undef, n) - StatsBase.make_alias_table!(probs, 1.0, accp, alias) - AliasTable(accp, alias) -end - -function rand(rng::AbstractRNG, s::AliasTable) - i = rand(rng, 1:length(s.alias)) % Int - # using `ifelse` improves performance here: github.com/JuliaStats/Distributions.jl/pull/1831/ - ifelse(rand(rng) < s.accept[i], i, s.alias[i]) -end - +ncategories(s::AliasTable) = length(s.at) +rand(rng::AbstractRNG, s::AliasTable) = rand(rng, s.at) show(io::IO, s::AliasTable) = @printf(io, "AliasTable with %d entries", ncategories(s)) diff --git a/test/samplers.jl b/test/samplers.jl index 2744ae9ac..749b45d0d 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -31,9 +31,11 @@ import Distributions: @testset "p=$p" for p in Any[[1.0], [0.3, 0.7], [0.2, 0.3, 0.4, 0.1]] test_samples(S(p), Categorical(p), n_tsamples) test_samples(S(p), Categorical(p), n_tsamples, rng=rng) + @test ncategories(S(p)) == length(p) end end + @test string(AliasTable(Float16[1,2,3])) == "AliasTable with 3 entries" ## Binomial samplers diff --git a/test/univariate/discrete/categorical.jl b/test/univariate/discrete/categorical.jl index a835f7c13..6d87d4dc8 100644 --- a/test/univariate/discrete/categorical.jl +++ b/test/univariate/discrete/categorical.jl @@ -93,9 +93,9 @@ end end @testset "reproducibility across julia versions" begin - d= Categorical([0.1, 0.2, 0.7]) + d = Categorical([0.1, 0.2, 0.7]) rng = StableRNGs.StableRNG(600) - @test rand(rng, d, 10) == [2, 1, 3, 3, 2, 3, 3, 3, 3, 3] + @test rand(rng, d, 10) == [3, 1, 1, 2, 3, 2, 3, 3, 2, 3] end @testset "comparisons" begin @@ -124,4 +124,17 @@ end @test Categorical([0.5, 0.5]) ≈ Categorical([0.5f0, 0.5f0]) end +@testset "issue #832" begin + priorities = collect(Float64, 1:1000) + priorities[1:50] .= 1e8 + + at = Distributions.AliasTable(priorities) + iat = rand(at, 16) + + # failure rate of a single sample is sum(51:1000)/50e8 = 9.9845e-5 + # failure rate of 4 out of 16 samples is 1-cdf(Binomial(16, 9.9845e-5), 3) = 1.8074430840897548e-13 + # this test should randomly fail with a probability of 1.8074430840897548e-13 + @test count(==(1e8), priorities[iat]) >= 13 +end + end