Skip to content

Commit

Permalink
Specialize indexing triangular matrices with BandIndex (#55644)
Browse files Browse the repository at this point in the history
With this, certain indexing operations involving a `BandIndex` may be
evaluated as constants. This isn't used directly presently, but might
allow for more performant broadcasting in the future.
With this,
```julia
julia> n = 3; T = Tridiagonal(rand(n-1), rand(n), rand(n-1));

julia> @code_warntype ((T,j) -> UpperTriangular(T)[LinearAlgebra.BandIndex(2,j)])(T, 1)
MethodInstance for (::var"#17#18")(::Tridiagonal{Float64, Vector{Float64}}, ::Int64)
  from (::var"#17#18")(T, j) @ Main REPL[12]:1
Arguments
  #self#::Core.Const(var"#17#18"())
  T::Tridiagonal{Float64, Vector{Float64}}
  j::Int64
Body::Float64
1 ─ %1 = Main.UpperTriangular(T)::UpperTriangular{Float64, Tridiagonal{Float64, Vector{Float64}}}
│   %2 = LinearAlgebra.BandIndex::Core.Const(LinearAlgebra.BandIndex)
│   %3 = (%2)(2, j)::Core.PartialStruct(LinearAlgebra.BandIndex, Any[Core.Const(2), Int64])
│   %4 = Base.getindex(%1, %3)::Core.Const(0.0)
└──      return %4
```
The indexing operation may be evaluated at compile-time, as the band
index is constant-propagated.
  • Loading branch information
jishnub authored Sep 23, 2024
1 parent 9136bdd commit f62a380
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 4 deletions.
5 changes: 3 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,11 @@ end
end

@inline function getindex(A::Bidiagonal{T}, b::BandIndex) where T
@boundscheck checkbounds(A, _cartinds(b))
@boundscheck checkbounds(A, b)
if b.band == 0
return @inbounds A.dv[b.index]
elseif b.band == _offdiagind(A.uplo)
elseif b.band (-1,1) && b.band == _offdiagind(A.uplo)
# we explicitly compare the possible bands as b.band may be constant-propagated
return @inbounds A.ev[b.index]
else
return bidiagzero(A, Tuple(_cartinds(b))...)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ norm2(x::Union{Array{T},StridedVector{T}}) where {T<:BlasFloat} =
# Conservative assessment of types that have zero(T) defined for themselves
haszero(::Type) = false
haszero(::Type{T}) where {T<:Number} = isconcretetype(T)
@propagate_inbounds _zero(M::AbstractArray{T}, i, j) where {T} = haszero(T) ? zero(T) : zero(M[i,j])
@propagate_inbounds _zero(M::AbstractArray{T}, inds...) where {T} = haszero(T) ? zero(T) : zero(M[inds...])

"""
triu!(M, k::Integer)
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) =
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
i <= j ? A.data[i,j] : _zero(A.data,j,i)

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitLowerTriangular{T}, b::BandIndex) where {T}
b.band < 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::LowerTriangular, b::BandIndex)
b.band <= 0 ? A.data[b] : _zero(A.data, b)
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UnitUpperTriangular{T}, b::BandIndex) where {T}
b.band > 0 ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::UpperTriangular, b::BandIndex)
b.band >= 0 ? A.data[b] : _zero(A.data, b)
end

_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"

Expand Down
38 changes: 37 additions & 1 deletion stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ debug = false
using Test, LinearAlgebra, Random
using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!,
UnitUpperTriangular, UnitLowerTriangular,
mul!, rdiv!, rmul!, lmul!
mul!, rdiv!, rmul!, lmul!, BandIndex

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

Expand Down Expand Up @@ -1286,4 +1286,40 @@ end
end
end

@testset "indexing with a BandIndex" begin
# these tests should succeed even if the linear index along
# the band isn't a constant, or type-inferred at all
M = rand(Int,2,2)
f(A,j, v::Val{n}) where {n} = Val(A[BandIndex(n,j)])
function common_tests(M, ind)
j = ind[]
@test @inferred(f(UpperTriangular(M), j, Val(-1))) == Val(0)
@test @inferred(f(UnitUpperTriangular(M), j, Val(-1))) == Val(0)
@test @inferred(f(UnitUpperTriangular(M), j, Val(0))) == Val(1)
@test @inferred(f(LowerTriangular(M), j, Val(1))) == Val(0)
@test @inferred(f(UnitLowerTriangular(M), j, Val(1))) == Val(0)
@test @inferred(f(UnitLowerTriangular(M), j, Val(0))) == Val(1)
end
common_tests(M, Any[1])

M = Diagonal([1,2])
common_tests(M, Any[1])
# extra tests for banded structure of the parent
for T in (UpperTriangular, UnitUpperTriangular)
@test @inferred(f(T(M), 1, Val(1))) == Val(0)
end
for T in (LowerTriangular, UnitLowerTriangular)
@test @inferred(f(T(M), 1, Val(-1))) == Val(0)
end

M = Tridiagonal([1,2], [1,2,3], [1,2])
common_tests(M, Any[1])
for T in (UpperTriangular, UnitUpperTriangular)
@test @inferred(f(T(M), 1, Val(2))) == Val(0)
end
for T in (LowerTriangular, UnitLowerTriangular)
@test @inferred(f(T(M), 1, Val(-2))) == Val(0)
end
end

end # module TestTriangular

0 comments on commit f62a380

Please sign in to comment.