From 3672a3a24aa2f3b82a206c990a3dbe27725e4662 Mon Sep 17 00:00:00 2001 From: Justin Willmert Date: Wed, 4 Nov 2020 12:13:17 -0600 Subject: [PATCH] Tweak order of operations to get `nnz` to infer as `Int` return type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If the sparse array does not have a concrete index type, then union splitting occurs over the possible `<:Integer` types permitted by `SparseMatrixCSC`: ```julia julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none) Variables #self#::Core.Const(SparseArrays.nnz) S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer Body::Any 1 ── %1 = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer │ %2 = SparseArrays.getfield(S, :n)::Int64 │ %3 = Base.add_int(%2, 1)::Int64 │ %4 = Base.getindex(%1, %3)::Integer │ %5 = (isa)(%4, Int64)::Bool └─── goto #3 if not %5 2 ── %7 = π (%4, Int64) │ %8 = Base.sub_int(%7, 1)::Int64 └─── goto #15 3 ── %10 = (isa)(%4, BigInt)::Bool └─── goto #14 if not %10 4 ── %12 = π (%4, BigInt) │ %13 = Base.slt_int(1, 0)::Bool └─── goto #6 if not %13 5 ── %15 = Base.bitcast(UInt64, 1)::UInt64 │ %16 = Base.neg_int(%15)::UInt64 │ %17 = Base.GMP.MPZ.add_ui::typeof(Base.GMP.MPZ.add_ui) │ %18 = invoke %17(%12::BigInt, %16::UInt64)::BigInt └─── goto #13 6 ── %20 = Core.lshr_int(1, 63)::Int64 │ %21 = Core.trunc_int(Core.UInt8, %20)::UInt8 │ %22 = Core.eq_int(%21, 0x01)::Bool └─── goto #8 if not %22 7 ── invoke Core.throw_inexacterror(:check_top_bit::Symbol, UInt64::Type{UInt64}, 1::Int64) └─── unreachable 8 ── goto #9 9 ── %27 = Core.bitcast(Core.UInt64, 1)::UInt64 └─── goto #10 10 ─ goto #11 11 ─ goto #12 12 ─ %31 = Base.GMP.MPZ.sub_ui::typeof(Base.GMP.MPZ.sub_ui) │ %32 = invoke %31(%12::BigInt, %27::UInt64)::BigInt └─── goto #13 13 ┄ %34 = φ (#5 => %18, #12 => %32)::Any └─── goto #15 14 ─ %36 = (%4 - 1)::Any └─── goto #15 15 ┄ %38 = φ (#2 => %8, #13 => %34, #14 => %36)::Any │ %39 = SparseArrays.Int(%38)::Any └─── return %39 ``` It appears that union splitting over the subtraction by one includes an `Any` branch that widens the return type of `nnz`. By instead converting the index type to `Int` before subtracting, type inference is able to infer that all paths give an `Int` result: ```julia julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none) Variables #self#::Core.Const(SparseArrays.nnz) S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer Body::Int64 1 ── %1 = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer │ %2 = SparseArrays.getfield(S, :n)::Int64 │ %3 = Base.add_int(%2, 1)::Int64 │ %4 = Base.getindex(%1, %3)::Integer │ %5 = (isa)(%4, BigInt)::Bool └─── goto #14 if not %5 2 ── %7 = π (%4, BigInt) │ %8 = Base.getfield(%7, :size)::Int32 │ %9 = Base.flipsign_int(%8, %8)::Int32 │ %10 = Core.sext_int(Core.Int64, %9)::Int64 │ %11 = Base.sle_int(0, %10)::Bool └─── goto #4 if not %11 3 ── %13 = Core.sext_int(Core.Int64, %9)::Int64 │ %14 = Base.sle_int(%13, 1)::Bool └─── goto #5 4 ── nothing 5 ┄─ %17 = φ (#3 => %14, #4 => false)::Bool └─── goto #12 if not %17 6 ── %19 = Base.getfield(%7, :size)::Int32 │ %20 = Core.sext_int(Core.Int64, %19)::Int64 │ %21 = (%20 === 0)::Bool └─── goto #8 if not %21 7 ── goto #9 8 ── %24 = Base.getfield(%7, :d)::Ptr{UInt64} │ %25 = Base.pointerref(%24, 1, 1)::UInt64 │ %26 = Base.bitcast(Int64, %25)::Int64 │ %27 = Base.getfield(%7, :size)::Int32 │ %28 = Core.sext_int(Core.Int64, %27)::Int64 │ %29 = Base.flipsign_int(%26, %28)::Int64 └─── goto #9 9 ┄─ %31 = φ (#7 => 0, #8 => %29)::Int64 │ %32 = Base.getfield(%7, :size)::Int32 │ %33 = Core.sext_int(Core.Int64, %32)::Int64 │ %34 = Base.slt_int(0, %33)::Bool │ %35 = Base.slt_int(0, %31)::Bool │ %36 = (%34 === %35)::Bool │ %37 = Base.not_int(%36)::Bool └─── goto #11 if not %37 10 ─ %39 = Base.GMP.nameof(Int64)::Any │ %40 = Base.GMP.InexactError(%39, Int64, %7)::Any │ Base.GMP.throw(%40) └─── unreachable 11 ─ goto #13 12 ─ %44 = Base.GMP.nameof(Int64)::Any │ %45 = Base.GMP.InexactError(%44, Int64, %7)::Any │ Base.GMP.throw(%45) └─── unreachable 13 ─ goto #15 14 ─ %49 = SparseArrays.Int(%4)::Int64 └─── goto #15 15 ┄ %51 = φ (#13 => %31, #14 => %49)::Int64 │ %52 = Base.sub_int(%51, 1)::Int64 └─── return %52 ``` --- stdlib/SparseArrays/src/sparsematrix.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 628a6a496c099..48eea94d51526 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -108,7 +108,7 @@ julia> nnz(A) 3 ``` """ -nnz(S::AbstractSparseMatrixCSC) = Int(getcolptr(S)[size(S, 2) + 1] - 1) +nnz(S::AbstractSparseMatrixCSC) = Int(getcolptr(S)[size(S, 2) + 1]) - 1 nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S)) nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S) nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)