diff --git a/base/dSFMT.jl b/base/dSFMT.jl index b78b1b1931c7b..b0eb6e7d41089 100644 --- a/base/dSFMT.jl +++ b/base/dSFMT.jl @@ -2,6 +2,8 @@ module dSFMT +import Base: copy, copy!, == + export DSFMT_state, dsfmt_get_min_array_size, dsfmt_get_idstring, dsfmt_init_gen_rand, dsfmt_init_by_array, dsfmt_gv_init_by_array, dsfmt_fill_array_close_open!, dsfmt_fill_array_close1_open2!, @@ -21,10 +23,16 @@ const JPOLY1e21 = "e172e20c5d2de26b567c0cace9e7c6cc4407bd5ffcd22ca59d37b73d54fd type DSFMT_state val::Vector{Int32} - DSFMT_state() = new(Array{Int32}(JN32)) - DSFMT_state(val::Vector{Int32}) = new(val) + + DSFMT_state(val::Vector{Int32} = zeros(Int32, JN32)) = + new(length(val) == JN32 ? val : throw(DomainError())) end +copy!(dst::DSFMT_state, src::DSFMT_state) = (copy!(dst.val, src.val); dst) +copy(src::DSFMT_state) = DSFMT_state(copy(src.val)) + +==(s1::DSFMT_state, s2::DSFMT_state) = s1.val == s2.val + function dsfmt_get_idstring() idstring = ccall((:dsfmt_get_idstring,:libdSFMT), Ptr{UInt8}, diff --git a/base/random.jl b/base/random.jl index c14fd9a96b389..b8e8c90d0a220 100644 --- a/base/random.jl +++ b/base/random.jl @@ -4,7 +4,7 @@ module Random using Base.dSFMT using Base.GMP: GMP_VERSION, Limb -import Base.copymutable +import Base: copymutable, copy, copy!, == export srand, rand, rand!, @@ -64,16 +64,37 @@ rand(rng::RandomDevice, ::Type{CloseOpen}) = rand(rng, Close1Open2) - 1.0 const MTCacheLength = dsfmt_get_min_array_size() type MersenneTwister <: AbstractRNG + seed::Vector{UInt32} state::DSFMT_state vals::Vector{Float64} idx::Int - seed::Vector{UInt32} - MersenneTwister(state::DSFMT_state, seed) = new(state, Array{Float64}(MTCacheLength), MTCacheLength, seed) - MersenneTwister(seed) = srand(new(DSFMT_state(), Array{Float64}(MTCacheLength)), seed) - MersenneTwister() = MersenneTwister(0) + function MersenneTwister(seed, state, vals, idx) + length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength || throw(DomainError()) + new(seed, state, vals, idx) + end +end + +MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) = + MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength) + +MersenneTwister(seed=0) = srand(MersenneTwister(Vector{UInt32}(), DSFMT_state()), seed) + +function copy!(dst::MersenneTwister, src::MersenneTwister) + copy!(resize!(dst.seed, length(src.seed)), src.seed) + copy!(dst.state, src.state) + copy!(dst.vals, src.vals) + dst.idx = src.idx + dst end +copy(src::MersenneTwister) = + MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx) + +==(r1::MersenneTwister, r2::MersenneTwister) = + r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) && r1.idx == r2.idx + + ## Low level API for MersenneTwister @inline mt_avail(r::MersenneTwister) = MTCacheLength - r.idx @@ -105,7 +126,7 @@ end @inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r) function srand(r::MersenneTwister, seed::Vector{UInt32}) - r.seed = seed + copy!(resize!(r.seed, length(seed)), seed) dsfmt_init_by_array(r.state, r.seed) mt_setempty!(r) return r @@ -117,7 +138,7 @@ function randjump(mt::MersenneTwister, jumps::Integer, jumppoly::AbstractString) push!(mts, mt) for i in 1:jumps-1 cmt = mts[end] - push!(mts, MersenneTwister(dSFMT.dsfmt_jump(cmt.state, jumppoly), cmt.seed)) + push!(mts, MersenneTwister(cmt.seed, dSFMT.dsfmt_jump(cmt.state, jumppoly))) end return mts end diff --git a/test/random.jl b/test/random.jl index 8ec79ef76d2f2..91358626f5601 100644 --- a/test/random.jl +++ b/test/random.jl @@ -414,3 +414,33 @@ end # test that the following is not an error (#16925) srand(typemax(UInt)) srand(typemax(UInt128)) + +# copy and == +let seed = rand(UInt32, 10) + r = MersenneTwister(seed) + @test r == MersenneTwister(seed) # r.vals should be all zeros + s = copy(r) + @test s == r && s !== r + skip, len = rand(0:2000, 2) + for j=1:skip + rand(r) + rand(s) + end + @test rand(r, len) == rand(s, len) + @test s == r +end + +# MersenneTwister initialization with invalid values +@test_throws DomainError Base.dSFMT.DSFMT_state(zeros(Int32, rand(0:Base.dSFMT.JN32-1))) +@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), + zeros(Float64, 10), 0) +@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), + zeros(Float64, Base.Random.MTCacheLength), -1) + +# seed is private to MersenneTwister +let seed = rand(UInt32, 10) + r = MersenneTwister(seed) + @test r.seed == seed && r.seed !== seed + resize!(seed, 4) + @test r.seed != seed +end