Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LU factorization in-place operations #22774

Merged
merged 4 commits into from
Aug 24, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Get rid of allocations, apply the (inverse) permutation in-place
  • Loading branch information
haampie committed Aug 23, 2017
commit 944f02af900bd3b0d722d999564497de78e9b7ce
59 changes: 31 additions & 28 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,51 +241,54 @@ function show(io::IO, F::LU)
print(io, "\nsuccessful: $(issuccess(F))")
end

A_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('N', A.factors, A.ipiv, B) A.info
_apply_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, 1 : length(A.ipiv), B)
_apply_inverse_ipiv!(A::LU, B::StridedVecOrMat) = _ipiv!(A, length(A.ipiv) : -1 : 1, B)

function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector)
b_permuted = b[ipiv2perm(A.ipiv, length(b))]
A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), b_permuted))
copy!(b, b_permuted)
function _ipiv!(A::LU, order::OrdinalRange, B::StridedVecOrMat)
for i = order
if i != A.ipiv[i]
_swap_rows!(B, i, A.ipiv[i])
end
end
B
end

function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix)
B_permuted = B[ipiv2perm(A.ipiv, size(B, 1)), :]
A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), B_permuted))
copy!(B, B_permuted)
function _swap_rows!(B::StridedVector, i::Integer, j::Integer)
B[i], B[j] = B[j], B[i]
B
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary in this pull request, but DRYing out these and the similar method definitions below might be nice. For example, as a first step perhaps something along the lines of

function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
    B_permuted = _ipivpermute(B, A.ipiv)
    A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), B_permuted))
    copy!(B, B_permuted)
end
_ipivpermute(b::StridedVector, ipiv) = b[ipiv2perm(ipiv, length(b))]
_ipivpermute(B::StridedMatrix, ipiv) = B[ipiv2perm(ipiv, size(B, 1)), :]

would do the trick? :)


At_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, B) A.info
function _swap_rows!(B::StridedMatrix, i::Integer, j::Integer)
for col = 1 : size(B, 2)
B[i,col], B[j,col] = B[j,col], B[i,col]
end
B
end

function At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector)
At_ldiv_B!(UnitLowerTriangular(A.factors), At_ldiv_B!(UpperTriangular(A.factors), b))
b_permuted = b[invperm(ipiv2perm(A.ipiv, length(b)))]
copy!(b, b_permuted)
A_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('N', A.factors, A.ipiv, B) A.info

function A_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
_apply_ipiv!(A, B)
A_ldiv_B!(UpperTriangular(A.factors), A_ldiv_B!(UnitLowerTriangular(A.factors), B))
end

function At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix)
At_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
@assertnonsingular LAPACK.getrs!('T', A.factors, A.ipiv, B) A.info

function At_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
At_ldiv_B!(UnitLowerTriangular(A.factors), At_ldiv_B!(UpperTriangular(A.factors), B))
B_permuted = B[invperm(ipiv2perm(A.ipiv, size(B,1))), :]
copy!(B, B_permuted)
_apply_inverse_ipiv!(A, B)
end

Ac_ldiv_B!(F::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:Real} =
At_ldiv_B!(F, B)
Ac_ldiv_B!(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasComplex} =
@assertnonsingular LAPACK.getrs!('C', A.factors, A.ipiv, B) A.info

function Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, b::StridedVector)
Ac_ldiv_B!(UnitLowerTriangular(A.factors), Ac_ldiv_B!(UpperTriangular(A.factors), b))
b_permuted = b[invperm(ipiv2perm(A.ipiv, length(b)))]
copy!(b, b_permuted)
end

function Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedMatrix)
function Ac_ldiv_B!(A::LU{<:Any,<:StridedMatrix}, B::StridedVecOrMat)
Ac_ldiv_B!(UnitLowerTriangular(A.factors), Ac_ldiv_B!(UpperTriangular(A.factors), B))
B_permuted = B[invperm(ipiv2perm(A.ipiv, size(B,1))), :]
copy!(B, B_permuted)
_apply_inverse_ipiv!(A, B)
end

At_ldiv_Bt(A::LU{T,<:StridedMatrix}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
Expand Down