Skip to content

Commit

Permalink
define show(::MersenneTwister)
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Oct 4, 2020
1 parent 55a6dab commit 8ef335d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 25 deletions.
7 changes: 5 additions & 2 deletions stdlib/Future/src/Future.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ One such step corresponds to the generation of two `Float64` numbers.
For each different value of `steps`, a large polynomial has to be generated internally.
One is already pre-computed for `steps=big(10)^20`.
"""
randjump(r::MersenneTwister, steps::Integer) =
Random._randjump(r, Random.DSFMT.calc_jump(steps))
function randjump(r::MersenneTwister, steps::Integer)
j = Random._randjump(r, Random.DSFMT.calc_jump(steps))
j.adv_jump += 2*big(steps) # convert to BigInt to prevent overflow
j
end

end # module Future
102 changes: 84 additions & 18 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,32 @@ mutable struct MersenneTwister <: AbstractRNG
idxF::Int
idxI::Int

function MersenneTwister(seed, state, vals, ints, idxF, idxI)
# counters for show
adv::Int64 # state of advance at the DSFMT_state level
adv_jump::BigInt # number of skipped Float64 values via randjump
adv_vals::Int64 # state of advance when vals is filled-up
adv_ints::Int64 # state of advance when ints is filled-up
adv_vals_pre::Int64 # state of advance when vals is filled-up before ints
adv_idxF_pre::Int # value of idxF before ints is filled-up

function MersenneTwister(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre)
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
throw(DomainError((length(vals), idxF),
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
throw(DomainError((length(ints), idxI),
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
new(seed, state, vals, ints, idxF, idxI)
new(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state,
Vector{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
MT_CACHE_F, 0)
MT_CACHE_F, 0, 0, 0, -1, -1, -1, -1)

"""
MersenneTwister(seed)
Expand Down Expand Up @@ -147,12 +157,19 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
copyto!(dst.ints, src.ints)
dst.idxF = src.idxF
dst.idxI = src.idxI
dst.adv = src.adv
dst.adv_jump = src.adv_jump
dst.adv_vals = src.adv_vals
dst.adv_ints = src.adv_ints
dst.adv_vals_pre = src.adv_vals_pre
dst.adv_idxF_pre = src.adv_idxF_pre
dst
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI)
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints,
src.adv_vals_pre, src.adv_idxF_pre)


==(r1::MersenneTwister, r2::MersenneTwister) =
Expand All @@ -164,17 +181,47 @@ copy(src::MersenneTwister) =
hash(r::MersenneTwister, h::UInt) =
foldr(hash, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI); init=h)

function fillcache_zeros!(r::MersenneTwister)
# the use of this function is not strictly necessary, but it makes
# comparing two MersenneTwister RNGs easier
function show(io::IO, rng::MersenneTwister)
# seed
seed = from_seed(rng.seed)
seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM
if rng.adv_jump == 0 && rng.adv == 0
return print(io, "MersenneTwister($seed_str)")
end
print(io, "MersenneTwister($seed_str, (")
# state
adv = Integer[rng.adv_jump, rng.adv]
if rng.adv_vals != -1
push!(adv, rng.adv_vals, rng.idxF)
end
if rng.adv_ints != -1 # then rng.adv_vals is always != -1
idxI = (length(rng.ints)*16 - rng.idxI) / 8 # 8 represents one Int64
idxI = Int(idxI) # idxI should always be an integer when using public APIs
push!(adv,
rng.adv_ints,
rng.adv_vals_pre == -1 ? 0 : rng.adv_vals_pre,
rng.adv_vals_pre == -1 ? 0 : rng.adv_idxF_pre,
idxI)
end
join(io, adv, ", ")
print(io, "))")
end

### low level API

function reset_caches!(r::MersenneTwister)
# zeroing the caches makes comparing two MersenneTwister RNGs easier
fill!(r.vals, 0.0)
fill!(r.ints, zero(UInt128))
mt_setempty!(r)
mt_setempty!(r, UInt128)
r.adv_vals = -1
r.adv_ints = -1
r.adv_vals_pre = -1
r.adv_idxF_pre = -1
r
end


### low level API

#### floats

mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF
Expand All @@ -184,7 +231,8 @@ mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1]

function gen_rand(r::MersenneTwister)
GC.@preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
r.adv_vals = r.adv
GC.@preserve r fill_array!(r, pointer(r.vals), length(r.vals), CloseOpen12())
mt_setfull!(r)
end

Expand Down Expand Up @@ -212,6 +260,9 @@ mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} =
r.idxI >> logsizeof(T)

function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger})
r.adv_ints = r.adv
r.adv_vals_pre = r.adv_vals
r.adv_idxF_pre = r.idxF
rand!(r, r.ints)
r.idxI = MT_CACHE_I
end
Expand Down Expand Up @@ -275,14 +326,18 @@ function make_seed(n::Integer)
end
end

# inverse of make_seed(::Integer)
from_seed(a::Vector{UInt32})::BigInt = sum(a[i] * big(2)^(32*(i-1)) for i in 1:length(a))


#### seed!()

function seed!(r::MersenneTwister, seed::Vector{UInt32})
copyto!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
mt_setempty!(r)
mt_setempty!(r, UInt128)
fillcache_zeros!(r)
reset_caches!(r)
r.adv = 0
r.adv_jump = 0
return r
end

Expand Down Expand Up @@ -464,6 +519,10 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
A
end

function fill_array!(rng::MersenneTwister, A::Ptr{Float64}, n::Int, I)
rng.adv += n
fill_array!(rng.state, A, n, I)
end

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen01_64) =
dsfmt_fill_array_close_open!(s, A, n)
Expand All @@ -488,10 +547,10 @@ function rand!(r::MersenneTwister, A::UnsafeView{Float64},
align = Csize_t(pA) % 16
if align > 0
pA2 = pA + 16 - align
fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted
fill_array!(r, pA2, n2, I[]) # generate the data in-place, but shifted
unsafe_copyto!(pA, pA2, n2) # move the data to the beginning of the array
else
fill_array!(r.state, pA, n2, I[])
fill_array!(r, pA, n2, I[])
end
for i=n2+1:n
A[i] = rand(r, I[])
Expand Down Expand Up @@ -653,5 +712,12 @@ end

# Old randjump methods are deprecated, the scalar version is in the Future module.

_randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X) =
fillcache_zeros!(MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly)))
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
adv = r.adv
adv_jump = r.adv_jump
s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))
reset_caches!(s)
s.adv = adv
s.adv_jump = adv_jump
s
end
2 changes: 1 addition & 1 deletion stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Base.GMP: Limb
using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing

import Base: copymutable, copy, copy!, ==, hash, convert,
rand, randn
rand, randn, show

export rand!, randn!,
randexp, randexp!,
Expand Down
8 changes: 4 additions & 4 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,16 +592,16 @@ end
@test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1)))

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0)
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1, -1, -1)

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)
Expand Down

0 comments on commit 8ef335d

Please sign in to comment.