Skip to content

Commit

Permalink
Fix a few minor problems for Triangular arithmetic. Fixes #16458 (#16562
Browse files Browse the repository at this point in the history
)
  • Loading branch information
andreasnoack authored and tkelman committed May 25, 2016
1 parent 5b9d8d6 commit 3a346a6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 34 deletions.
95 changes: 61 additions & 34 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1293,22 +1293,35 @@ end

for f in (:A_mul_B!, :A_ldiv_B!)
@eval begin
# Upper
$f(A::UpperTriangular, B::UpperTriangular) =
UpperTriangular($f(A, triu!(B.data)))
$f(A::UnitUpperTriangular, B::UpperTriangular) =
UpperTriangular($f(A, triu!(B.data)))
$f(A::UpperTriangular, B::UnitUpperTriangular) =
UpperTriangular($f(A, triu!(B.data)))
$f(A::UnitUpperTriangular, B::UnitUpperTriangular) =
UnitUpperTriangular($f(A, triu!(B.data)))
UpperTriangular($f(triu!(A.data), B))
function $f(A::UnitUpperTriangular, B::UnitUpperTriangular)
BB = triu!(B.data)
for i = 1:size(BB, 1)
BB[i,i] = 1
end
return UnitUpperTriangular($f(A, BB))
end

# Lower
$f(A::LowerTriangular, B::LowerTriangular) =
LowerTriangular($f(A, tril!(B.data)))
$f(A::UnitLowerTriangular, B::LowerTriangular) =
LowerTriangular($f(A, tril!(B.data)))
$f(A::LowerTriangular, B::UnitLowerTriangular) =
LowerTriangular($f(A, tril!(B.data)))
$f(A::UnitLowerTriangular, B::UnitLowerTriangular) =
LowerTriangular($f(A, tril!(B.data)))
LowerTriangular($f(tril!(A), B))
function $f(A::UnitLowerTriangular, B::UnitLowerTriangular)
BB = tril!(B.data)
for i = 1:size(BB, 1)
BB[i,i] = 1
end
UnitLowerTriangular($f(A, BB))
end
end
end

Expand Down Expand Up @@ -1345,25 +1358,29 @@ for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriang
end
end

for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B))
for (f1, f2) in ((:*, :A_mul_B!), (:\, :A_ldiv_B!))
@eval begin
function $f1(A::LowerTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::LowerTriangular, B::LowerTriangular)
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::UnitLowerTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function $(f1)(A::UnitLowerTriangular, B::LowerTriangular)
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::UpperTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::UpperTriangular, B::UpperTriangular)
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::UnitUpperTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::UnitUpperTriangular, B::UpperTriangular)
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end
end
Expand All @@ -1372,67 +1389,77 @@ end
for (f1, f2) in ((:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!),
(:Ac_ldiv_B, Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
@eval begin
function $f1(A::UpperTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::UpperTriangular, B::LowerTriangular)
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::UnitUpperTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::UnitUpperTriangular, B::LowerTriangular)
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::LowerTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::LowerTriangular, B::UpperTriangular)
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end

function $f1(A::UnitLowerTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
function ($f1)(A::UnitLowerTriangular, B::UpperTriangular)
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(convert(AbstractMatrix{TAB}, A), copy_oftype(B, TAB)))
end
end
end

function (/)(A::LowerTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/
one(eltype(A)))
TAB = typeof((/)(zero(eltype(A)), zero(eltype(B))) +
(/)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end
function (/)(A::LowerTriangular, B::UnitLowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end
function (/)(A::UpperTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B))/
one(eltype(A)))
TAB = typeof((/)(zero(eltype(A)), zero(eltype(B))) +
(/)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end
function (/)(A::UpperTriangular, B::UnitUpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular(A_rdiv_B!(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end

for (f1, f2) in ((:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!),
(:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
@eval begin
function $f1(A::LowerTriangular, B::UpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end

function $f1(A::LowerTriangular, B::UnitUpperTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return LowerTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end

function $f1(A::UpperTriangular, B::LowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof(($f1)(zero(eltype(A)), zero(eltype(B))) +
($f1)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end

function $f1(A::UpperTriangular, B::UnitLowerTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
TAB = typeof((*)(zero(eltype(A)), zero(eltype(B))) +
(*)(zero(eltype(A)), zero(eltype(B))))
return UpperTriangular($f2(copy_oftype(A, TAB), convert(AbstractMatrix{TAB}, B)))
end
end
Expand Down Expand Up @@ -1510,7 +1537,7 @@ end
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
@eval begin
function ($f)(A::$mat, B::Tuple{UnitUpperTriangular, UnitLowerTriangular})
function ($f)(A::$mat, B::Union{UnitUpperTriangular, UnitLowerTriangular})
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
AA = similar(A, TAB, size(A))
copy!(AA, A)
Expand Down
1 change: 1 addition & 0 deletions test/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ for elty1 in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloa
@test_approx_eq full(A1.'A2.') full(A1).'full(A2).'
@test_approx_eq full(A1'A2') full(A1)'full(A2)'
@test_approx_eq full(A1/A2) full(A1)/full(A2)
@test_approx_eq full(A1\A2) full(A1)\full(A2)
@test_throws DimensionMismatch eye(n+1)/A2
@test_throws DimensionMismatch eye(n+1)/A2.'
@test_throws DimensionMismatch eye(n+1)/A2'
Expand Down

0 comments on commit 3a346a6

Please sign in to comment.