Skip to content

Commit

Permalink
faster inv and div for Complex{Union{Float16, Float32}} (#44111)
Browse files Browse the repository at this point in the history
* faster inv and div for Complex{Union{Float16, Float32}}
* fix float64 division bug
  • Loading branch information
oscardssmith committed Feb 13, 2022
1 parent befe38f commit 1e86463
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
50 changes: 34 additions & 16 deletions base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,30 +347,37 @@ muladd(z::Complex, w::Complex, x::Real) =

function /(a::Complex{T}, b::Complex{T}) where T<:Real
are = real(a); aim = imag(a); bre = real(b); bim = imag(b)
if abs(bre) <= abs(bim)
if isinf(bre) && isinf(bim)
r = sign(bre)/sign(bim)
else
r = bre / bim
if (isinf(bre) | isinf(bim))
if isfinite(a)
return complex(zero(T)*sign(are)*sign(bre), -zero(T)*sign(aim)*sign(bim))
end
return T(NaN)+T(NaN)*im
end
if abs(bre) <= abs(bim)
r = bre / bim
den = bim + r*bre
Complex((are*r + aim)/den, (aim*r - are)/den)
else
if isinf(bre) && isinf(bim)
r = sign(bim)/sign(bre)
else
r = bim / bre
end
r = bim / bre
den = bre + r*bim
Complex((are + aim*r)/den, (aim - are*r)/den)
end
end

# Reciprocal of a half- or single-precision complex number.
# The argument is widened, inverted with the `inv` method that applies to the
# widened value, and the result is converted back to the type of `z`.
function inv(z::Complex{<:Union{Float16,Float32}})
    wide_reciprocal = inv(widen(z))
    return oftype(z, wide_reciprocal)
end

# Division of half- or single-precision complex numbers.
# Both operands are widened, the quotient is formed as the widened dividend
# times the reciprocal of the widened divisor, and the result is converted
# back to the type of `z`.
function /(z::Complex{T}, w::Complex{T}) where {T<:Union{Float16,Float32}}
    wide_dividend = widen(z)
    wide_reciprocal = inv(widen(w))
    return oftype(z, wide_dividend * wide_reciprocal)
end
# Division of half- or single-precision complex numbers.
# The operands are widened so the intermediate arithmetic is carried out in
# the widened float type; the quotient is converted back to the type of `z`.
function /(z::Complex{T}, w::Complex{T}) where {T<:Union{Float16,Float32}}
    num_re, num_im = reim(widen(z))
    den_re, den_im = reim(widen(w))

    # Divisor with an infinite component: a finite dividend yields a signed
    # zero built from the component signs, anything else yields NaN + NaN*im.
    if isinf(den_re) | isinf(den_im)
        isfinite(z) || return T(NaN)+T(NaN)*im
        zero_re = zero(T)*sign(real(z))*sign(real(w))
        zero_im = -zero(T)*sign(imag(z))*sign(imag(w))
        return complex(zero_re, zero_im)
    end

    # z / w == z * conj(w) / abs2(w), evaluated in the widened type.
    inv_abs2 = inv(muladd(den_re, den_re, den_im^2))
    quot_re = muladd(num_re, den_re, num_im*den_im)
    quot_im = muladd(num_im, den_re, -num_re*den_im)
    return oftype(z, Complex(quot_re*inv_abs2, quot_im*inv_abs2))
end

# robust complex division for double precision
# variables are scaled & unscaled to avoid over/underflow, if necessary
Expand All @@ -382,7 +389,12 @@ function /(z::ComplexF64, w::ComplexF64)
a, b = reim(z); c, d = reim(w)
absa = abs(a); absb = abs(b); ab = absa >= absb ? absa : absb # equiv. to max(abs(a),abs(b)) but without NaN-handling (faster)
absc = abs(c); absd = abs(d); cd = absc >= absd ? absc : absd

if (isinf(c) | isinf(d))
if isfinite(z)
return complex(0.0*sign(a)*sign(c), -0.0*sign(b)*sign(d))
end
return NaN+NaN*im
end
halfov = 0.5*floatmax(Float64) # overflow threshold
twounϵ = floatmin(Float64)*2.0/eps(Float64) # underflow threshold

Expand Down Expand Up @@ -449,6 +461,12 @@ function robust_cdiv2(a::Float64, b::Float64, c::Float64, d::Float64, r::Float64
end
end

# Reciprocal of a half- or single-precision complex number.
# The components are widened, the reciprocal is formed as conj(z) / abs2(z)
# in the widened type, and the result is converted back to the type of `z`.
function inv(z::Complex{T}) where T<:Union{Float16,Float32}
    re, im_ = reim(widen(z))
    # An infinite component gives a zero result whose component signs are
    # taken from the input via copysign/flipsign.
    if isinf(re) | isinf(im_)
        return complex(copysign(zero(T), re), flipsign(-zero(T), im_))
    end
    inv_abs2 = inv(muladd(re, re, im_^2))
    return oftype(z, Complex(re*inv_abs2, -im_*inv_abs2))
end
function inv(w::ComplexF64)
c, d = reim(w)
(isinf(c) | isinf(d)) && return complex(copysign(0.0, c), flipsign(-0.0, d))
Expand Down
10 changes: 6 additions & 4 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ end
@testset "corner cases of division, issue #22983" begin
# These results abide by ISO/IEC 10967-3:2006(E) and
# mathematical definition of division of complex numbers.
for T in (Float32, Float64, BigFloat)
for T in (Float16, Float32, Float64, BigFloat)
@test isequal(one(T) / zero(Complex{T}), one(Complex{T}) / zero(Complex{T}))
@test isequal(one(T) / zero(Complex{T}), Complex{T}(NaN, NaN))
@test isequal(one(Complex{T}) / zero(T), Complex{T}(Inf, NaN))
Expand All @@ -1050,7 +1050,7 @@ end
end

@testset "division by Inf, issue#23134" begin
@testset "$T" for T in (Float32, Float64, BigFloat)
@testset "$T" for T in (Float16, Float32, Float64, BigFloat)
@test isequal(one(T) / complex(T(Inf)), complex(zero(T), -zero(T)))
@test isequal(one(T) / complex(T(Inf), one(T)), complex(zero(T), -zero(T)))
@test isequal(one(T) / complex(T(Inf), T(NaN)), complex(zero(T), -zero(T)))
Expand Down Expand Up @@ -1088,8 +1088,10 @@ end
@test isequal(one(T) / complex(T(-NaN), T(-Inf)), complex(-zero(T), zero(T)))

# divide complex by complex Inf
@test isequal(complex(one(T)) / complex(T(Inf), T(-Inf)), complex(zero(T), zero(T))) broken=(T==Float64)
@test isequal(complex(one(T)) / complex(T(-Inf), T(Inf)), complex(-zero(T), -zero(T))) broken=(T in (Float32, Float64))
@test isequal(complex(one(T)) / complex(T(Inf), T(-Inf)), complex(zero(T), zero(T)))
@test isequal(complex(one(T)) / complex(T(-Inf), T(Inf)), complex(-zero(T), -zero(T)))
@test isequal(complex(T(Inf)) / complex(T(Inf), T(-Inf)), complex(T(NaN), T(NaN)))
@test isequal(complex(T(NaN)) / complex(T(-Inf), T(Inf)), complex(T(NaN), T(NaN)))
end
end

Expand Down

0 comments on commit 1e86463

Please sign in to comment.