Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LinearAlgebra: Type-stability in broadcasting numbers over Bidiagonal #54067

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
LinearAlgebra: Type-stability in broadcasting numbers over Bidiagonal
  • Loading branch information
jishnub committed Apr 12, 2024
commit 2588d66c3354660426c313e268d0994b731367f9
23 changes: 13 additions & 10 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,22 @@ structured_broadcast_alloc(bc, ::Type{Diagonal}, ::Type{ElType}, n) where {ElTyp
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
# system will return Tridiagonal when there's more than one Bidiagonal, but when
# there's only one, we need to make figure out upper or lower
merge_uplos(::Nothing, ::Nothing) = nothing
merge_uplos(a, ::Nothing) = a
merge_uplos(::Nothing, b) = b
merge_uplos(a, b) = a == b ? a : 'T'
# the Val flag checks if only one Bidiagonal is encountered in the broadcast expression,
# in which case we may preserve the type of the array
merge_uplos(::Tuple{Nothing, Val{A}}, ::Tuple{Nothing, Val{B}}) where {A,B} = (nothing, Val(A & B))
merge_uplos(a::Tuple{Any,Val{A}}, ::Tuple{Nothing, Val{B}}) where {A,B} = (first(a), Val(A & B))
merge_uplos(::Tuple{Nothing, Val{A}}, b::Tuple{Any,Val{B}}) where {A,B} = (first(b), Val(A & B))
merge_uplos(a::Tuple{Any,Val}, b::Tuple{Any,Val}) = (first(a) == first(b) ? first(a) : 'T', Val(false))

find_uplo(a::Bidiagonal) = a.uplo
find_uplo(a) = nothing
find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=nothing)
find_uplo(a::Bidiagonal) = (a.uplo, Val(true))
find_uplo(a) = (nothing, Val(true))
find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nested(bc), init=(nothing, Val(true)))

function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
uplo = n > 0 ? find_uplo(bc) : 'U'
uplo, val = find_uplo(bc)
uplo = n > 0 ? uplo : 'U'
n1 = max(n - 1, 0)
if uplo == 'T'
if val isa Val{false} && uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
Expand Down Expand Up @@ -170,7 +173,7 @@ isvalidstructbc(dest, bc::Broadcasted{T}) where {T<:StructuredMatrixStyle} =
(isstructurepreserving(bc) || fzeropreserving(bc))

isvalidstructbc(dest::Bidiagonal, bc::Broadcasted{StructuredMatrixStyle{Bidiagonal}}) =
(size(dest, 1) < 2 || find_uplo(bc) == dest.uplo) &&
(size(dest, 1) < 2 || first(find_uplo(bc)) == dest.uplo) &&
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
(isstructurepreserving(bc) || fzeropreserving(bc))

function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,24 @@ using Test, LinearAlgebra
@test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end

@testset "type-stability in Bidiagonal" begin
B2 = @inferred (B -> .- B)(B)
@test B2 isa Bidiagonal
@test B2 == -1 * B
B2 = @inferred (B -> B .* 2)(B)
@test B2 isa Bidiagonal
@test B2 == B + B
B2 = @inferred (B -> 2 .* B)(B)
@test B2 isa Bidiagonal
@test B2 == B + B
B2 = @inferred (B -> B ./ 1)(B)
@test B2 isa Bidiagonal
@test B2 == B
B2 = @inferred (B -> 1 .\ B)(B)
@test B2 isa Bidiagonal
@test B2 == B
end
end

@testset "broadcast! where the destination is a structured matrix" begin
Expand Down