Skip to content

Commit

Permalink
Merge pull request #28 from schrimpf/fixStaticArrays
Browse files Browse the repository at this point in the history
corrected rand! for mutable static arrays and moved them to an extension
  • Loading branch information
chriselrod authored Oct 4, 2023
2 parents c34cf44 + bfe3c8f commit 33880b1
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 91 deletions.
18 changes: 13 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
name = "VectorizedRNG"
uuid = "33b4df10-0173-11e9-2a0c-851a7edac40e"
authors = ["Chris Elrod <[email protected]>"]
version = "0.2.24"
version = "0.2.25"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

[weakdeps]
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
VectorizedRNGStaticArraysExt = ["StaticArraysCore"]

[compat]
Requires = "1"
SLEEFPirates = "0.6.29"
StaticArraysCore = "1"
StrideArraysCore = "0.3, 0.4"
UnPack = "1"
VectorizationBase = "0.19.38, 0.20.1, 0.21"
julia = "1.6"

[extras]
RNGTest = "97cc5700-e6cb-5ca1-8fb2-7f6b45264ecd"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "RNGTest"]
test = ["Test", "RNGTest", "StaticArrays"]
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"

[compat]
Documenter = "1"
4 changes: 2 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ using Documenter
makedocs(;
modules = [VectorizedRNG],
authors = "Chris Elrod",
repo = "https://github.com/JuliaSIMD/VectorizedRNG.jl/blob/{commit}{path}#L{line}",
repo = Remotes.GitHub("JuliaSIMD","VectorizedRNG.jl"),
sitename = "VectorizedRNG.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://JuliaSIMD.github.io/VectorizedRNG.jl",
assets = String[]
),
pages = ["Home" => "index.md"],
strict = false
warnonly = true
)

deploydocs(; repo = "github.com/JuliaSIMD/VectorizedRNG.jl")
54 changes: 54 additions & 0 deletions ext/VectorizedRNGStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
module VectorizedRNGStaticArraysExt

using VectorizedRNG: samplevector!, random_uniform, random_normal, AbstractVRNG, random_unsigned
if isdefined(Base, :get_extension)
using StaticArraysCore
else
using ..StaticArraysCore
end
using VectorizationBase: StaticInt
import Random

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T},
α::Number = StaticInt{0}(),
β = StaticInt{0}(),
γ = StaticInt{1}()
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_uniform, rng, x, α, β, γ, identity)
end
return x
end

function Random.randn!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T},
α::Number = StaticInt{0}(),
β = StaticInt{0}(),
γ = StaticInt{1}()
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_normal, rng, x, α, β, γ, identity)
end
return x
end

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,UInt64}
)
samplevector!(
random_unsigned,
rng,
x,
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end


end
9 changes: 9 additions & 0 deletions src/VectorizedRNG.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ using SLEEFPirates

using Distributed: myid

if !isdefined(Base, :get_extension)
using Requires
end

export local_rng, rand!, randn!#, randexp, randexp!

abstract type AbstractVRNG{N} <: Random.AbstractRNG end
Expand Down Expand Up @@ -119,6 +123,11 @@ function __init()
end
end
function __init__()
@static if !isdefined(Base, :get_extension)
@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin
include("../ext/VectorizedRNGStaticArraysExt.jl")
end
end
ccall(:jl_generating_output, Cint, ()) == 1 && return
__init()
end
Expand Down
83 changes: 2 additions & 81 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,77 +375,16 @@ function Random.randn!(
) where {T<:Union{Float32,Float64}}
samplevector!(random_normal, rng, x, α, β, γ, identity)
end

@inline function random_unsigned(
state::AbstractState,
::Val{N},
::Type{T}
) where {N,T}
nextstate(state, Val{N}())
end
function Random.rand!(rng::AbstractVRNG, x::AbstractArray{UInt64})
samplevector!(
random_unsigned,
rng,
x,
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end

using StaticArraysCore, StrideArraysCore
function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T}
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_uniform, rng, PtrArray(x), α, β, γ, identity)
end
return x
end
function Random.rand!(
rng::AbstractVRNG,
x::SA
) where {
S<:Tuple,
T<:Union{Float32,Float64},
SA<:StaticArraysCore.StaticArray{S,T}
}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(random_uniform, rng, PtrArray(a), α, β, γ, identity)
end
x .= a
end
function Random.randn!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,T}
) where {T<:Union{Float32,Float64}}
GC.@preserve x begin
samplevector!(random_normal, rng, PtrArray(x), α, β, γ, identity)
end
return x
end
function Random.randn!(
rng::AbstractVRNG,
x::SA
) where {
S<:Tuple,
T<:Union{Float32,Float64},
SA<:StaticArraysCore.StaticArray{S,T}
}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(random_normal, rng, PtrArray(a), α, β, γ, identity)
end
x .= a
end

function Random.rand!(
rng::AbstractVRNG,
x::StaticArraysCore.MArray{<:Tuple,UInt64}
)
function Random.rand!(rng::AbstractVRNG, x::AbstractArray{UInt64})
samplevector!(
random_unsigned,
rng,
Expand All @@ -456,24 +395,6 @@ function Random.rand!(
identity
)
end
function Random.rand!(
rng::AbstractVRNG,
x::SA
) where {S<:Tuple,SA<:StaticArraysCore.StaticArray{S,UInt64}}
a = MArray{S,UInt64}(undef)
GC.@preserve a begin
samplevector!(
random_unsigned,
rng,
PtrArray(a),
StaticInt{0}(),
StaticInt{0}(),
StaticInt{1}(),
identity
)
end
x .= a
end

Random.rand(rng::AbstractVRNG, d1::Integer, dims::Vararg{Integer,N}) where {N} =
rand!(rng, Array{Float64}(undef, d1, dims...))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
RNGTest = "97cc5700-e6cb-5ca1-8fb2-7f6b45264ecd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31 changes: 28 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test
using InteractiveUtils: versioninfo
versioninfo(; verbose = true)

using RNGTest, Random, SpecialFunctions, Aqua, Distributions
using RNGTest, Random, SpecialFunctions, Aqua, Distributions, StaticArrays

const α = 1e-4

Expand Down Expand Up @@ -171,11 +171,36 @@ end
vrng = local_rng()
σ = 0.5
for i = 1:N
randn!(vrng, x, VectorizedRNG.static(0), VectorizedRNG.static(0), σ)
randn!(vrng, x, VectorizedRNG.StaticInt(0), VectorizedRNG.StaticInt(0), σ)
s += std(x)
end
s /= N
@test s σ rtol = 1e-1
end
end
end

@testset "StaticArrays" begin
seed = 1234
rng = local_rng()
for T in (Float32, Float64, UInt64, Int)
for dim in ((10),(10,10), (10,10,10))
A = zeros(T, dim)
mA = MArray{Tuple{dim...}}(A)
VectorizedRNG.seed!(seed)
rand!(rng, A)
VectorizedRNG.seed!(seed)
rand!(rng, mA)
@test all(A .== mA)

if T <: AbstractFloat
VectorizedRNG.seed!(seed)
randn!(rng, A)
VectorizedRNG.seed!(seed)
randn!(rng, mA)
@test all(A .== mA)
end
end
end
end

end

2 comments on commit 33880b1

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92788

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.25 -m "<description of version>" 33880b1497227d3bae27c5b6bc64085cf0ed02c8
git push origin v0.2.25

Please sign in to comment.