Skip to content

Commit

Permalink
Fix LU factorization in-place operations (JuliaLang#22774)
Browse files Browse the repository at this point in the history
* Add tests for in-place operations of A_ldiv_B!(y, LU, x) for LU factorizations

* Do Ax_ldiv_B! for LU factorizations in-place

* Remove incorrect condition in LU test. The transpose should be tested for complex numbers, and might be skipped for real numbers; not the other way around. For clarity just test both unconditionally.

* Get rid of allocations, apply the (inverse) permutation in-place
  • Loading branch information
haampie authored and andreasnoack committed Aug 24, 2017
1 parent 6378d7d commit 98c7a06
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
57 changes: 39 additions & 18 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,34 +241,55 @@ function show(io::IO, F::LU)
print(io, "\nsuccessful: $(issuccess(F))")
end

_apply_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, 1 : length(A.ipiv), B)
_apply_inverse_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, length(A.ipiv) : -1 : 1, B)

function _ipiv!(A::LU, order::OrdinalRange, B::StridedVecOrMat)
for i = order
if i != A.ipiv[i]
_swap_rows!(B, i, A.ipiv[i])
end
end
B
end

function _swap_rows!(B::StridedVector, i::Integer, j::Integer)
B[i], B[j] = B[j], B[i]
B
end

function _swap_rows!(B::StridedMatrix, i::Integer, j::Integer)
for col = 1 : size(B, 2)
B[i,col], B[j,col] = B[j,col], B[i,col]
end
B
end

A_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('N', A.factors, A.ipiv, B) A.info
A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) =
A_ldiv_B!(UpperTriangular(A.factors),
A_ldiv_B!(UnitLowerTriangular(A.factors), b[ipiv2perm(A.ipiv, length(b))]))
A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) =
A_ldiv_B!(UpperTriangular(A.factors),
A_ldiv_B!(UnitLowerTriangular(A.factors), B[ipiv2perm(A.ipiv, size(B, 1)),:]))

function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
_apply_ipiv!(A, B)
A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), B))
end

At_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, B) A.info
At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) =
At_ldiv_B!(UnitLowerTriangular(A.factors),
At_ldiv_B!(UpperTriangular(A.factors), b))[invperm(ipiv2perm(A.ipiv, length(b)))]
At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) =
At_ldiv_B!(UnitLowerTriangular(A.factors),
At_ldiv_B!(UpperTriangular(A.factors), B))[invperm(ipiv2perm(A.ipiv, size(B,1))),:]

function At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
At_ldiv_B!(UnitLowerTriangular(A.factors), At_ldiv_B!(UpperTriangular(A.factors), B))
_apply_inverse_ipiv!(A, B)
end

Ac_ldiv_B!(F::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:Real} =
At_ldiv_B!(F, B)
Ac_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasComplex} =
@assertnonsingular LAPACK.getrs!('C', A.factors, A.ipiv, B) A.info
Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector) =
Ac_ldiv_B!(UnitLowerTriangular(A.factors),
Ac_ldiv_B!(UpperTriangular(A.factors), b))[invperm(ipiv2perm(A.ipiv, length(b)))]
Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix) =
Ac_ldiv_B!(UnitLowerTriangular(A.factors),
Ac_ldiv_B!(UpperTriangular(A.factors), B))[invperm(ipiv2perm(A.ipiv, size(B,1))),:]

function Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
Ac_ldiv_B!(UnitLowerTriangular(A.factors), Ac_ldiv_B!(UpperTriangular(A.factors), B))
_apply_inverse_ipiv!(A, B)
end

At_ldiv_Bt(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, transpose(B)) A.info
Expand Down
27 changes: 23 additions & 4 deletions test/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,30 @@ dimg = randn(n)/2
@test norm(a*(lua\c) - c, 1) < ε*κ*n # c is a vector
@test norm(a'*(lua'\c) - c, 1) < ε*κ*n # c is a vector
@test AbstractArray(lua) a
if eltya <: Real && eltyb <: Real
@test norm(a.'*(lua.'\b) - b,1) < ε*κ*n*2 # Two because the right hand side has two columns
@test norm(a.'*(lua.'\c) - c,1) < ε*κ*n
end
@test norm(a.'*(lua.'\b) - b,1) < ε*κ*n*2 # Two because the right hand side has two columns
@test norm(a.'*(lua.'\c) - c,1) < ε*κ*n
end

# Test whether Ax_ldiv_B!(y, LU, x) indeed overwrites y
resultT = typeof(oneunit(eltyb) / oneunit(eltya))

b_dest = similar(b, resultT)
c_dest = similar(c, resultT)

A_ldiv_B!(b_dest, lua, b)
A_ldiv_B!(c_dest, lua, c)
@test norm(b_dest - lua \ b, 1) < ε*κ*2n
@test norm(c_dest - lua \ c, 1) < ε*κ*n

At_ldiv_B!(b_dest, lua, b)
At_ldiv_B!(c_dest, lua, c)
@test norm(b_dest - lua.' \ b, 1) < ε*κ*2n
@test norm(c_dest - lua.' \ c, 1) < ε*κ*n

Ac_ldiv_B!(b_dest, lua, b)
Ac_ldiv_B!(c_dest, lua, c)
@test norm(b_dest - lua' \ b, 1) < ε*κ*2n
@test norm(c_dest - lua' \ c, 1) < ε*κ*n
end
if eltya <: BlasFloat && eltyb <: BlasFloat
e = rand(eltyb,n,n)
Expand Down

0 comments on commit 98c7a06

Please sign in to comment.