Skip to content

Commit

Permalink
Avoid allocation in ldiv! with QR (#38389)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
dlfivefifty and dkarrasch committed Dec 9, 2020
1 parent ce795bc commit b1a2847
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,7 @@ end
@inline function reflector!(x::AbstractVector)
require_one_based_indexing(x)
n = length(x)
n == 0 && return zero(eltype(x))
@inbounds begin
ξ1 = x[1]
normu = abs2(ξ1)
Expand Down Expand Up @@ -1514,6 +1515,7 @@ end
if length(x) != m
throw(DimensionMismatch("reflector has length $(length(x)), which must match the first dimension of matrix A, $m"))
end
m == 0 && return A
@inbounds begin
for j = 1:n
# dot
Expand Down
37 changes: 30 additions & 7 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -768,10 +768,16 @@ mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, Q::AbstractQ{T}) where {T} =
mul!(C::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}, B::StridedVecOrMat{T}) where {T} = lmul!(adjQ, copyto!(C, B))
mul!(C::StridedVecOrMat{T}, A::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:AbstractQ{T}}) where {T} = rmul!(copyto!(C, A), adjQ)

ldiv!(A::QRCompactWY{T}, b::StridedVector{T}) where {T<:BlasFloat} =
(ldiv!(UpperTriangular(A.R), view(lmul!(adjoint(A.Q), b), 1:size(A, 2))); b)
ldiv!(A::QRCompactWY{T}, B::StridedMatrix{T}) where {T<:BlasFloat} =
(ldiv!(UpperTriangular(A.R), view(lmul!(adjoint(A.Q), B), 1:size(A, 2), 1:size(B, 2))); B)
function ldiv!(A::QRCompactWY{T}, b::StridedVector{T}) where {T<:BlasFloat}
m,n = size(A)
ldiv!(UpperTriangular(view(A.factors, 1:min(m,n), 1:n)), view(lmul!(adjoint(A.Q), b), 1:size(A, 2)))
return b
end
function ldiv!(A::QRCompactWY{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
m,n = size(A)
ldiv!(UpperTriangular(view(A.factors, 1:min(m,n), 1:n)), view(lmul!(adjoint(A.Q), B), 1:size(A, 2), 1:size(B, 2)))
return B
end

# Julia implementation similar to xgelsy
function ldiv!(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Real) where T<:BlasFloat
Expand Down Expand Up @@ -813,12 +819,12 @@ ldiv!(A::QRPivoted{T}, B::StridedVector{T}) where {T<:BlasFloat} =
vec(ldiv!(A,reshape(B,length(B),1)))
ldiv!(A::QRPivoted{T}, B::StridedVecOrMat{T}) where {T<:BlasFloat} =
ldiv!(A, B, min(size(A)...)*eps(real(float(one(eltype(B))))))[1]
function ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
function _wide_qr_ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
m, n = size(A)
minmn = min(m,n)
mB, nB = size(B)
lmul!(adjoint(A.Q), view(B, 1:m, :))
R = A.R
R = A.R # makes a copy, used as a buffer below
@inbounds begin
if n > m # minimum norm solution
τ = zeros(T,m)
Expand All @@ -839,7 +845,7 @@ function ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
end
end
end
LinearAlgebra.ldiv!(UpperTriangular(view(R, :, 1:minmn)), view(B, 1:minmn, :))
ldiv!(UpperTriangular(view(R, :, 1:minmn)), view(B, 1:minmn, :))
if n > m # Apply elementary transformation to solution
B[m + 1:mB,1:nB] .= zero(T)
for j = 1:nB
Expand All @@ -859,6 +865,23 @@ function ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
end
return B
end


function ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
m, n = size(A)
m < n && return _wide_qr_ldiv!(A, B)

mB, nB = size(B)
lmul!(adjoint(A.Q), view(B, 1:m, :))
R = A.factors
ldiv!(UpperTriangular(view(R,1:n,:)), view(B, 1:n, :))
return B
end
function ldiv!(A::QR, B::StridedVector)
ldiv!(A, reshape(B, length(B), 1))
return B
end

function ldiv!(A::QR, B::StridedVector)
ldiv!(A, reshape(B, length(B), 1))
B
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,17 @@ end
@test c0 == c
end

@testset "Issue reflector of zero-length vector" begin
a = [2.0]
x = view(a,1:0)
τ = LinearAlgebra.reflector!(view(x,1:0))
@test τ == 0.0

b = reshape([3.0],1,1)
@test isempty(LinearAlgebra.reflectorApply!(x, τ, view(b,1:0,:)))
@test b[1] == 3.0
end

@testset "det(Q::Union{QRCompactWYQ, QRPackedQ})" begin
# 40 is the number larger than the default block size 36 of QRCompactWY
@testset for n in [1:3; 40], m in [1:3; 40], pivot in [false, true]
Expand Down

0 comments on commit b1a2847

Please sign in to comment.