Skip to content

Commit

Permalink
make sure rand(sampler, dims) works (JuliaLang#51643)
Browse files Browse the repository at this point in the history
For any object `x` from which one can sample, the `Random` API also
requires that `rand(rng, Sampler(typeof(rng), x), [dims])` works. So add
tests for that, and fix `rand(Tuple{...})` accordingly, which was not
using `SamplerTag` fully correctly.

More precisely, the `Sampler` constructor for tuple types was returning
a `SamplerTag` object whose `gentype` was returning the wrong type,
leading to the wrong eltype of the output array for a call like
`rand(rng, Sampler(rng, Tuple{...}), dims)`, so filling this array with
random values was failing.
  • Loading branch information
rfourquet committed Oct 10, 2023
1 parent ecaf457 commit a857a86
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 35 deletions.
9 changes: 5 additions & 4 deletions stdlib/Random/src/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,19 @@ end

function Sampler(::Type{RNG}, ::Type{T}, n::Repetition) where {T<:Tuple, RNG<:AbstractRNG}
tail_sp_ = Sampler(RNG, Tuple{Base.tail(fieldtypes(T))...}, n)
SamplerTag{T}((Sampler(RNG, fieldtype(T, 1), n), tail_sp_.data...))
SamplerTag{Ref{T}}((Sampler(RNG, fieldtype(T, 1), n), tail_sp_.data...))
# Ref so that the gentype is `T` in SamplerTag's constructor
end

function Sampler(::Type{RNG}, ::Type{Tuple{Vararg{T, N}}}, n::Repetition) where {T, N, RNG<:AbstractRNG}
if N > 0
SamplerTag{Tuple{Vararg{T, N}}}((Sampler(RNG, T, n),))
SamplerTag{Ref{Tuple{Vararg{T, N}}}}((Sampler(RNG, T, n),))
else
SamplerTag{Tuple{}}(())
SamplerTag{Ref{Tuple{}}}(())
end
end

function rand(rng::AbstractRNG, sp::SamplerTag{T}) where T<:Tuple
function rand(rng::AbstractRNG, sp::SamplerTag{Ref{T}}) where T<:Tuple
ntuple(i -> rand(rng, sp.data[min(i, length(sp.data))]), Val{fieldcount(T)}())::T
end

Expand Down
68 changes: 37 additions & 31 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ end

# test all rand APIs
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
realrng = rng == [] ? default_rng() : only(rng)
ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat]
cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...]
types = [Bool, Char, BigFloat, Tuple{Bool, Tuple{Int, Char}}, Pair{Int8, UInt32},
Expand Down Expand Up @@ -321,42 +322,47 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
@test size(f2) == (5,)
@test size(f3) == size(f4) == (2, 3)
for T in functypes[f]
a0 = f(rng..., T) ::T
a1 = f(rng..., T, 5) ::Vector{T}
a2 = f(rng..., T, 2, 3) ::Array{T, 2}
a3 = f(rng..., T, b2, u3) ::Array{T, 2}
a4 = f(rng..., T, (2, 3)) ::Array{T, 2}
if T <: Number
@test size(a0) == ()
end
@test size(a1) == (5,)
@test size(a2) == size(a3) == size(a4) == (2, 3)
if T <: AbstractFloat && f === rand
for a in T[a0, a1..., a2..., a3..., a4...]
@test 0.0 <= a < 1.0
tts = f == rand ? (T, Sampler(realrng, T, Val(1)), Sampler(realrng, T, Val(Inf))) : (T,)
for tt in tts
a0 = f(rng..., tt) ::T
a1 = f(rng..., tt, 5) ::Vector{T}
a2 = f(rng..., tt, 2, 3) ::Array{T, 2}
a3 = f(rng..., tt, b2, u3) ::Array{T, 2}
a4 = f(rng..., tt, (2, 3)) ::Array{T, 2}
if T <: Number
@test size(a0) == ()
end
@test size(a1) == (5,)
@test size(a2) == size(a3) == size(a4) == (2, 3)
if T <: AbstractFloat && f === rand
for a in T[a0, a1..., a2..., a3..., a4...]
@test 0.0 <= a < 1.0
end
end
end
end
end
for (C, T) in collections
a0 = rand(rng..., C) ::T
a1 = rand(rng..., C, 5) ::Vector{T}
a2 = rand(rng..., C, 2, 3) ::Array{T, 2}
a3 = rand(rng..., C, (2, 3)) ::Array{T, 2}
a4 = rand(rng..., C, b2, u3) ::Array{T, 2}
a5 = rand!(rng..., Array{T}(undef, 5), C) ::Vector{T}
a6 = rand!(rng..., Array{T}(undef, 2, 3), C) ::Array{T, 2}
a7 = rand!(rng..., GenericArray{T}(undef, 5), C) ::GenericArray{T, 1}
a8 = rand!(rng..., GenericArray{T}(undef, 2, 3), C) ::GenericArray{T, 2}
a9 = rand!(rng..., OffsetArray(Array{T}(undef, 5), 9), C) ::OffsetArray{T, 1}
a10 = rand!(rng..., OffsetArray(Array{T}(undef, 2, 3), (-2, 4)), C) ::OffsetArray{T, 2}
@test size(a1) == (5,)
@test size(a2) == size(a3) == (2, 3)
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10...]
if C isa Type
@test a isa C
else
@test a in C
for cc = (C, Sampler(realrng, C, Val(1)), Sampler(realrng, C, Val(Inf)))
a0 = rand(rng..., cc) ::T
a1 = rand(rng..., cc, 5) ::Vector{T}
a2 = rand(rng..., cc, 2, 3) ::Array{T, 2}
a3 = rand(rng..., cc, (2, 3)) ::Array{T, 2}
a4 = rand(rng..., cc, b2, u3) ::Array{T, 2}
a5 = rand!(rng..., Array{T}(undef, 5), cc) ::Vector{T}
a6 = rand!(rng..., Array{T}(undef, 2, 3), cc) ::Array{T, 2}
a7 = rand!(rng..., GenericArray{T}(undef, 5), cc) ::GenericArray{T, 1}
a8 = rand!(rng..., GenericArray{T}(undef, 2, 3), cc) ::GenericArray{T, 2}
a9 = rand!(rng..., OffsetArray(Array{T}(undef, 5), 9), cc) ::OffsetArray{T, 1}
a10 = rand!(rng..., OffsetArray(Array{T}(undef, 2, 3), (-2, 4)), cc) ::OffsetArray{T, 2}
@test size(a1) == (5,)
@test size(a2) == size(a3) == (2, 3)
for a in [a0, a1..., a2..., a3..., a4..., a5..., a6..., a7..., a8..., a9..., a10...]
if C isa Type
@test a isa C
else
@test a in C
end
end
end
end
Expand Down

0 comments on commit a857a86

Please sign in to comment.