From 5f40a29916904f0724f3c009329b8f318bb8f7be Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Mon, 23 Sep 2019 03:05:19 -0400 Subject: [PATCH] bug fixes in matrix log (#32327) * bug fixes in matrix log * patches to matrix log (#33245) * patches to matrix log Avoid integer overflow if `s > 63`. Correct logic for `s == 0`. Only use fancy divided difference formulae if eigenvalues are close - avoids dangerous roundoff error if they are in opposite sectors. * add tests (cherry picked from commit 318affa294e8d493b1b3edcc3e06df26c97eb4bb) --- stdlib/LinearAlgebra/src/triangular.jl | 50 +++++++++++--------------- stdlib/LinearAlgebra/test/dense.jl | 15 ++++++++ 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index fdf135a4e0fa8..e3df91cacaa0c 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -2138,32 +2138,14 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat end # Compute accurate superdiagonal of T - p = 1 / 2^s - for k = 1:n-1 - Ak = A0[k,k] - Akp1 = A0[k+1,k+1] - Akp = Ak^p - Akp1p = Akp1^p - A[k,k] = Akp - A[k+1,k+1] = Akp1p - if Ak == Akp1 - A[k,k+1] = p * A0[k,k+1] * Ak^(p-1) - elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) - A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak) - else - logAk = log(Ak) - logAkp1 = log(Akp1) - w = atanh((Akp1 - Ak)/(Akp1 + Ak)) + im*pi*ceil((imag(logAkp1-logAk)-pi)/(2*pi)) - dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak) - A[k,k+1] = A0[k,k+1] * dd - end - end + blockpower!(A, A0, 0.5^s) # Compute accurate diagonal of T for i = 1:n a = A0[i,i] if s == 0 - r = a - 1 + A[i,i] = a - 1 + continue end s0 = s if angle(a) >= pi / 2 @@ -2200,7 +2182,7 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat end # Scale back - lmul!(2^s, Y) + lmul!(2.0^s, Y) # Compute accurate diagonal and superdiagonal of log(T) for k = 1:n-1 @@ -2212,11 +2194,16 @@ function log(A0::UpperTriangular{T}) where T<:BlasFloat Y[k+1,k+1] = logAkp1 if Ak == Akp1 Y[k,k+1] = A0[k,k+1] / Ak - elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) + elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) || iszero(Akp1 + Ak) Y[k,k+1] = A0[k,k+1] * (logAkp1 - logAk) / (Akp1 - Ak) else - w = atanh((Akp1 - Ak)/(Akp1 + Ak) + im*pi*(ceil((imag(logAkp1-logAk) - pi)/(2*pi)))) - Y[k,k+1] = 2 * A0[k,k+1] * w / (Akp1 - Ak) + z = (Akp1 - Ak)/(Akp1 + Ak) + if abs(z) > 1 + Y[k,k+1] = A0[k,k+1] * (logAkp1 - logAk) / (Akp1 - Ak) + else + w = atanh(z) + im * pi * (unw(logAkp1-logAk) - unw(log1p(z)-log1p(-z))) + Y[k,k+1] = 2 * A0[k,k+1] * w / (Akp1 - Ak) + end end end @@ -2363,14 +2350,19 @@ function blockpower!(A::UpperTriangular, A0::UpperTriangular, p) if Ak == Akp1 A[k,k+1] = p * A0[k,k+1] * Ak^(p-1) - elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) + elseif 2 * abs(Ak) < abs(Akp1) || 2 * abs(Akp1) < abs(Ak) || iszero(Akp1 + Ak) A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak) else logAk = log(Ak) logAkp1 = log(Akp1) - w = atanh((Akp1 - Ak)/(Akp1 + Ak)) + im * pi * unw(logAkp1-logAk) - dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak); - A[k,k+1] = A0[k,k+1] * dd + z = (Akp1 - Ak)/(Akp1 + Ak) + if abs(z) > 1 + A[k,k+1] = A0[k,k+1] * (Akp1p - Akp) / (Akp1 - Ak) + else + w = atanh(z) + im * pi * (unw(logAkp1-logAk) - unw(log1p(z)-log1p(-z))) + dd = 2 * exp(p*(logAk+logAkp1)/2) * sinh(p*w) / (Akp1 - Ak); + A[k,k+1] = A0[k,k+1] * dd + end end end end diff --git a/stdlib/LinearAlgebra/test/dense.jl b/stdlib/LinearAlgebra/test/dense.jl index 4d3f1a9b470b4..ee4f7cd84f09c 100644 --- a/stdlib/LinearAlgebra/test/dense.jl +++ b/stdlib/LinearAlgebra/test/dense.jl @@ -877,4 +877,19 @@ end @test_broken inv(transpose(B))*transpose(B) ≈ I end +@testset "Matrix log issue #32313" begin + for A in ([30 20; -50 -30], [10.0im 0; 0 -10.0im], randn(6,6)) + @test exp(log(A)) ≈ A + end +end + +@testset "Matrix log PR #33245" begin + # edge case for divided difference + A1 = triu(ones(3,3),1) + diagm([1.0, -2eps()-1im, -eps()+0.75im]) + @test exp(log(A1)) ≈ A1 + # case where no sqrt is needed (s=0) + A2 = [1.01 0.01 0.01; 0 1.01 0.01; 0 0 1.01] + @test exp(log(A2)) ≈ A2 +end + end # module TestDense