Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix method ambiguity for qr (#931) and lingering ambiguities for lu #932

Merged
merged 3 commits into from
Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix inferrence issues due to using @invoke for lu keyword arguments
  • Loading branch information
thchr committed Jul 2, 2021
commit 538dd13c4875c94b5b2b538a39a79c4969d872a5
31 changes: 12 additions & 19 deletions src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,21 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
end

# LU decomposition
function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true); check = true)
L, U, p = _lu(A, pivot, check)
LU(L, U, p)
end

# For the square version, return explicit lower and upper triangular matrices.
# We would do this for the rectangular case too, but Base doesn't support that.
function lu(A::StaticMatrix{N,N}, pivot::Union{Val{false},Val{true}}=Val(true);
check = true) where {N}
L, U, p = _lu(A, pivot, check)
LU(LowerTriangular(L), UpperTriangular(U), p)
end
for pv in (:true, :false)
# ... define each `pivot::Val{true/false}` method individually to avoid ambiguties
@eval function lu(A::StaticMatrix, pivot::Val{$pv}; check = true)
L, U, p = _lu(A, pivot, check)
LU(L, U, p)
end

@static if VERSION >= v"1.7-DEV"
# disambiguation
for p in (:true, :false)
@eval function lu(A::StaticMatrix{N,N}, pivot::Val{$p}; check = true) where {N}
Base.@invoke lu(A::StaticMatrix{N,N} where N,
pivot::Union{Val{false},Val{true}}; check)
end
# For the square version, return explicit lower and upper triangular matrices.
# We would do this for the rectangular case too, but Base doesn't support that.
@eval function lu(A::StaticMatrix{N,N}, pivot::Val{$pv}; check = true) where {N}
L, U, p = _lu(A, pivot, check)
LU(LowerTriangular(L), UpperTriangular(U), p)
end
end
lu(A::StaticMatrix; check = true) = lu(A, Val(true); check=check)

# location of the first zero on the diagonal, 0 when not found
function _first_zero_on_diagonal(A::StaticMatrix{M,N,T}) where {M,N,T}
Expand Down
42 changes: 19 additions & 23 deletions src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,26 @@ Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:p))
Base.iterate(S::QR, ::Val{:p}) = (S.p, Val(:done))
Base.iterate(S::QR, ::Val{:done}) = nothing

for pv in (:true, :false)
@eval begin
@inline function qr(A::StaticMatrix, pivot::Val{$pv})
QRp = _qr(Size(A), A, pivot)
if length(QRp) === 2
# create an identity permutation since that is cheap,
# and much safer since, in the case of isbits types, we can't
# safely leave the field undefined.
p = identity_perm(QRp[2])
return QR(QRp[1], QRp[2], p)
else # length(QRp) === 3
return QR(QRp[1], QRp[2], QRp[3])
end
end
end
end
"""
qr(A::StaticMatrix, pivot=Val(false))
qr(A::StaticMatrix, pivot::Union{Val{true}, Val{false}} = Val(false))

Compute the QR factorization of `A`. The factors can be obtain by iteration:
Compute the QR factorization of `A`. The factors can be obtained by iteration:

```julia
julia> A = @SMatrix rand(3,4);
Expand All @@ -34,27 +50,7 @@ julia> F.Q * F.R ≈ A
true
```
"""
@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false))
QRp = _qr(Size(A), A, pivot)
if length(QRp) === 2
# create an identity permutation since that is cheap,
# and much safer since, in the case of isbits types, we can't
# safely leave the field undefined.
p = identity_perm(QRp[2])
return QR(QRp[1], QRp[2], p)
else # length(QRp) === 3
return QR(QRp[1], QRp[2], QRp[3])
end
end

@static if VERSION >= v"1.7-DEV"
# disambiguation
for p in (:true, :false)
@eval function qr(A::StaticMatrix, pivot::Val{$p})
Base.@invoke qr(A::StaticMatrix, pivot::Union{Val{false},Val{true}})
end
end
end
qr(A::StaticMatrix) = qr(A, Val(false))

function identity_perm(R::StaticMatrix{N,M,T}) where {N,M,T}
return similar_type(R, Int, Size((M,)))(ntuple(x -> x, Val{M}()))
Expand Down
13 changes: 7 additions & 6 deletions test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ end

@testset "LU method ambiguity" begin
# Issue #920; just test that methods do not throw an ambiguity error when called
A = @SMatrix [1.0 2.0; 3.0 4.0]
@test isa(lu(A), StaticArrays.LU)
@test isa(lu(A, Val(true)), StaticArrays.LU)
@test isa(lu(A, Val(false)), StaticArrays.LU)
@test isa(lu(A; check=false), StaticArrays.LU)
@test isa(lu(A; check=true), StaticArrays.LU)
for A in ((@SMatrix [1.0 2.0; 3.0 4.0]), (@SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0]))
@test isa(lu(A), StaticArrays.LU)
@test isa(lu(A, Val(true)), StaticArrays.LU)
@test isa(lu(A, Val(false)), StaticArrays.LU)
@test isa(lu(A; check=false), StaticArrays.LU)
@test isa(lu(A; check=true), StaticArrays.LU)
end
end