Skip to content

Commit

Permalink
Merge pull request #9132 from JuliaLang/rf/randn-splitbranch
Browse files Browse the repository at this point in the history
faster randn by separating out unlikely branch in a function
  • Loading branch information
Viral B. Shah committed Nov 25, 2014
2 parents f06b4a4 + b99ea92 commit a02fdc7
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type Close1Open2 <: FloatInterval end

@inline rand_ui52_raw_inbounds(r::MersenneTwister) = reinterpret(UInt64, rand_inbounds(r, Close1Open2))
@inline rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r))
@inline rand_ui52(r::MersenneTwister) = rand_ui52_raw(r) & 0x000fffffffffffff
@inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r)

function srand(r::MersenneTwister, seed::Vector{UInt32})
Expand Down Expand Up @@ -943,36 +944,37 @@ const ziggurat_nor_r = 3.6541528853610087963519472518
const ziggurat_nor_inv_r = inv(ziggurat_nor_r)
const ziggurat_exp_r = 7.6971174701310497140446280481

@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand(rng, Close1Open2)) & 0x000fffffffffffff

function randmtzig_randn(rng::MersenneTwister=GLOBAL_RNG)
@inbounds begin
r = rand_ui52(rng)
rabs = int64(r>>1) # One bit for the sign
idx = rabs & 0xFF
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
rabs < ki[idx+1] && return x # 99.3% of the time we return here 1st try
return randmtzig_randn_unlikely(rng, idx, rabs, x)
end
end

# this unlikely branch is put in a separate function for better efficiency
function randmtzig_randn_unlikely(rng, idx, rabs, x)
@inbounds if idx == 0
while true
r = randi(rng)
rabs = int64(r>>1) # One bit for the sign
idx = rabs & 0xFF
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
if rabs < ki[idx+1]
return x # 99.3% of the time we return here 1st try
elseif idx == 0
while true
xx = -ziggurat_nor_inv_r*log(rand(rng))
yy = -log(rand(rng))
if yy+yy > xx*xx
return (rabs & 0x100) != 0x000000000 ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
end
end
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
return x # return from the triangular area
end
xx = -ziggurat_nor_inv_r*log(rand(rng))
yy = -log(rand(rng))
yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
end
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
return x # return from the triangular area
else
return randmtzig_randn(rng)
end
end

function randmtzig_exprnd(rng::MersenneTwister=GLOBAL_RNG)
@inbounds begin
while true
ri = randi(rng)
ri = rand_ui52(rng)
idx = ri & 0xFF
x = ri*we[idx+1]
if ri < ke[idx+1]
Expand Down

0 comments on commit a02fdc7

Please sign in to comment.