Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rand sampling from an array accepts an optional rng #9065

Merged
merged 1 commit into from
Nov 19, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ rand(T::Type, dims::Dims) = rand(GLOBAL_RNG, T, dims)
rand(T::Type, d1::Int, dims::Int...) = rand(T, tuple(d1, dims...))
rand!(A::AbstractArray) = rand!(GLOBAL_RNG, A)

rand(r::AbstractArray) = rand(GLOBAL_RNG, r)
rand!(r::Range, A::AbstractArray) = rand!(GLOBAL_RNG, r, A)

rand(r::AbstractArray, dims::Dims) = rand(GLOBAL_RNG, r, dims)
rand(r::AbstractArray, dims::Int...) = rand(GLOBAL_RNG, r, dims)

## random floating point values

@inline rand(r::AbstractRNG) = rand(r, CloseOpen)
Expand Down Expand Up @@ -344,58 +350,58 @@ end

# this function uses 32 bit entropy for small ranges of length <= typemax(UInt32) + 1
# RandIntGen is responsible for providing the right value of k
function rand{T<:Union(UInt64, Int64)}(g::RandIntGen{T,UInt64})
function rand{T<:Union(UInt64, Int64)}(mt::MersenneTwister, g::RandIntGen{T,UInt64})
local x::UInt64
if (g.k - 1) >> 32 == 0
x = rand(UInt32)
x = rand(mt, UInt32)
while x > g.u
x = rand(UInt32)
x = rand(mt, UInt32)
end
else
x = rand(UInt64)
x = rand(mt, UInt64)
while x > g.u
x = rand(UInt64)
x = rand(mt, UInt64)
end
end
return reinterpret(T, reinterpret(UInt64, g.a) + rem_knuth(x, g.k))
end

function rand{T<:Integer, U<:Unsigned}(g::RandIntGen{T,U})
x = rand(U)
function rand{T<:Integer, U<:Unsigned}(mt::MersenneTwister, g::RandIntGen{T,U})
x = rand(mt, U)
while x > g.u
x = rand(U)
x = rand(mt, U)
end
(unsigned(g.a) + rem_knuth(x, g.k)) % T
end

rand{T<:Union(Signed,Unsigned,Bool,Char)}(r::UnitRange{T}) = rand(RandIntGen(r))
rand{T<:Union(Signed,Unsigned,Bool,Char)}(mt::MersenneTwister, r::UnitRange{T}) = rand(mt, RandIntGen(r))

# Randomly draw a sample from an AbstractArray r
# (e.g. r is a range 0:2:8 or a vector [2, 3, 5, 7])
rand(r::AbstractArray) = @inbounds return r[rand(1:length(r))]
rand(mt::MersenneTwister, r::AbstractArray) = @inbounds return r[rand(mt, 1:length(r))]

function rand!(g::RandIntGen, A::AbstractArray)
function rand!(mt::MersenneTwister, g::RandIntGen, A::AbstractArray)
for i = 1 : length(A)
@inbounds A[i] = rand(g)
@inbounds A[i] = rand(mt, g)
end
return A
end

rand!{T<:Union(Signed,Unsigned,Bool,Char)}(r::UnitRange{T}, A::AbstractArray) = rand!(RandIntGen(r), A)
rand!{T<:Union(Signed,Unsigned,Bool,Char)}(mt::MersenneTwister, r::UnitRange{T}, A::AbstractArray) = rand!(mt, RandIntGen(r), A)

rand!(r::Range, A::AbstractArray) = _rand!(r, A)
rand!(mt::MersenneTwister, r::Range, A::AbstractArray) = _rand!(mt, r, A)

# TODO: this more general version is "disabled" until #8246 is resolved
function _rand!(r::AbstractArray, A::AbstractArray)
function _rand!(mt::MersenneTwister, r::AbstractArray, A::AbstractArray)
g = RandIntGen(1:(length(r)))
for i = 1 : length(A)
@inbounds A[i] = r[rand(g)]
@inbounds A[i] = r[rand(mt, g)]
end
return A
end

rand{T}(r::AbstractArray{T}, dims::Dims) = _rand!(r, Array(T, dims))
rand(r::AbstractArray, dims::Int...) = rand(r, dims)
rand{T}(mt::MersenneTwister, r::AbstractArray{T}, dims::Dims) = _rand!(mt, r, Array(T, dims))
rand(mt::MersenneTwister, r::AbstractArray, dims::Int...) = rand(mt, r, dims)

## random BitArrays (AbstractRNG)

Expand Down
16 changes: 9 additions & 7 deletions doc/stdlib/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4071,19 +4071,21 @@ A ``MersenneTwister`` RNG can generate random numbers of the following types: ``

Create a ``MersenneTwister`` RNG object. Different RNG objects can have their own seeds, which may be useful for generating different streams of random numbers.

.. function:: rand([rng], [t::Type], [dims...])
.. function:: rand([rng], [S], [dims...])

Generate a random value or an array of random values of the given type, ``t``, which defaults to ``Float64``.
Pick a random element or array of random elements from the set of values specified by ``S``; ``S`` can be

.. function:: rand!([rng], A)
* an indexable collection (for example ``1:n`` or ``['x','y','z']``), or

Populate the array A with random values.
* a type: the set of values to pick from is then equivalent to ``typemin(S):typemax(S)`` for integers, and to [0,1) for floating point numbers;

.. function:: rand(coll, [dims...])
``S`` defaults to ``Float64``.

Pick a random element or array of random elements from the indexable collection ``coll`` (for example, ``1:n`` or ``['x','y','z']``).
.. function:: rand!([rng], A)

Populate the array A with random values.

.. function:: rand!(r, A)
.. function:: rand!([rng], r, A)

Populate the array A with random values drawn uniformly from the range ``r``.

Expand Down
18 changes: 14 additions & 4 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,20 @@ rand!(MersenneTwister(0), A)
8690327730555225005 8435109092665372532]

# rand from AbstractArray
@test rand(0:3:1000) in 0:3:1000
coll = Any[2, UInt128(128), big(619), "string", 'c']
@test rand(coll) in coll
@test issubset(rand(coll, 2, 3), coll)
let mt = MersenneTwister()
srand(mt)
@test rand(mt, 0:3:1000) in 0:3:1000
@test issubset(rand!(mt, 0:3:1000, Array(Int, 100)), 0:3:1000)
coll = Any[2, UInt128(128), big(619), "string", 'c']
@test rand(mt, coll) in coll
@test issubset(rand(mt, coll, 2, 3), coll)

# check API with default RNG:
rand(0:3:1000)
rand!(0:3:1000, Array(Int, 100))
rand(coll)
rand(coll, 2, 3)
end

# randn
@test randn(MersenneTwister(42)) == -0.5560268761463861
Expand Down