diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 225bd46c6f262..96b41c0a2ae45 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig))) + if i != length(cases) || (!fully_covered || (!params.trust_inference)) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector spec = item.spec::ResolvedInliningSpec sparam_vals = item.mi.sparam_vals def = item.mi.def::Method - inline_cfg = spec.ir.cfg linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) @@ -459,6 +458,47 @@ end const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)") +""" + ir_inline_unionsplit! + +The core idea of this function is to simulate the dispatch semantics by generating +(flat) `isa`-checks corresponding to the signatures of union-split dispatch candidates, +and then inline their bodies into each `isa`-conditional block. + +This `isa`-based virtual dispatch requires some pre-conditions to hold in order to simulate +the actual semantics correctly. + +The first one is that these dispatch candidates need to be processed in order of their specificity, +and the corresponding `isa`-checks should reflect the method specificities, since now their +signatures are not necessarily concrete. 
+Fortunately, `ml_matches` should have already sorted them in that way, except for cases where there is +any ambiguity, from which we already bail out at this point. + +Another consideration is the type equality constraint from type variables: the `isa`-checks are +not enough to simulate the dispatch semantics in cases like: + +Given a definition: + + f(x::T, y::T) where T<:Integer = ... + +Transform a callsite: + + f(x::Any, y::Any) + +Into the optimized form: + + if isa(x, Integer) && isa(y, Integer) + f(x::Integer, y::Integer) + else + f(x, y) # fallback dynamic dispatch + end + +But again, we should already bail out from such cases at this point, essentially by +excluding cases where `case.sig::UnionAll`. + +In short, here we can process the dispatch candidates in order, assuming we haven't changed +their order somehow somewhere up to this point. +""" function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, (; fully_covered, atype, cases, bbs)::UnionSplit, @@ -468,8 +508,9 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, join_bb = bbs[end] pn = PhiNode() local bb = compact.active_result_bb - @assert length(bbs) >= length(cases) - for i in 1:length(cases) + ncases = length(cases) + @assert length(bbs) >= ncases + for i = 1:ncases ithcase = cases[i] mtype = ithcase.sig::DataType # checked within `handle_cases!` case = ithcase.item @@ -477,8 +518,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true nparams = fieldcount(atype) @assert nparams == fieldcount(mtype) - if i != length(cases) || !fully_covered || - (!params.trust_inference && isdispatchtuple(cases[i].sig)) + if i != ncases || !fully_covered || !params.trust_inference for i = 1:nparams a, m = fieldtype(atype, i), fieldtype(mtype, i) # If this is always true, we don't need to check for it @@ -535,7 +575,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the 
fall through block, decide what to do if fully_covered - if !params.trust_inference && isdispatchtuple(cases[end].sig) + if !params.trust_inference e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -558,7 +598,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect state = CFGInliningState(ir) for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params) + cfg_inline_unionsplit!(ir, idx, item, state, params) else item = item::InliningTodo spec = item.spec::ResolvedInliningSpec @@ -1172,12 +1212,8 @@ function analyze_single_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) argtypes = sig.argtypes cases = InliningCase[] - local only_method = nothing # keep track of whether there is one matching method - local meth::MethodLookupResult + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false - local revisit_idx = nothing - for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1188,66 +1224,20 @@ function analyze_single_call!( # No applicable methods; try next union split handled_all_cases = false continue - else - if length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method - only_method = false - end - else - only_method = false - end end - for (j, match) in enumerate(meth) - any_covers_full |= match.fully_covers - if !isdispatchtuple(match.spec_types) - if !match.fully_covers - handled_all_cases = false - continue - end - if revisit_idx === nothing - revisit_idx = (i, j) - else - handled_all_cases = false - revisit_idx = nothing - end - else - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) - end + for match in meth + handled_all_cases &= 
handle_match!(match, argtypes, flag, state, cases, true) + any_fully_covered |= match.fully_covers end end - atype = argtypes_to_type(argtypes) - if handled_all_cases && revisit_idx !== nothing - # If there's only one case that's not a dispatchtuple, we can - # still unionsplit by visiting all the other cases first. - # This is useful for code like: - # foo(x::Int) = 1 - # foo(@nospecialize(x::Any)) = 2 - # where we where only a small number of specific dispatchable - # cases are split off from an ::Any typed fallback. - (i, j) = revisit_idx - match = infos[i].results[j] - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) - elseif length(cases) == 0 && only_method isa Method - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple. - # -- But don't try it if we already tried to handle the match in the revisit_idx - # case, because that'll (necessarily) be the same method. 
- if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) - else - @assert length(meth) == 1 - match = meth[1] - end - handle_match!(match, argtypes, flag, state, cases, true) || return nothing - any_covers_full = handled_all_cases = match.fully_covers + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1258,8 +1248,8 @@ function handle_const_call!( (; call, results) = cinfo infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches cases = InliningCase[] + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false local j = 0 for i in 1:length(infos) meth = infos[i].results @@ -1275,32 +1265,26 @@ function handle_const_call!( for match in meth j += 1 result = results[j] - any_covers_full |= match.fully_covers + any_fully_covered |= match.fully_covers if isa(result, ConstResult) case = const_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, InferenceResult) - handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) + handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) end end end - # if the signature is fully covered and there is only one applicable method, - # we 
can try to inline it even if the signature is not a dispatch tuple - atype = argtypes_to_type(argtypes) - if length(cases) == 0 - length(results) == 1 || return nothing - result = results[1] - isa(result, InferenceResult) || return nothing - handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing - spec_types = cases[1].sig - any_covers_full = handled_all_cases = atype <: spec_types + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end function handle_match!( @@ -1308,9 +1292,12 @@ function handle_match!( cases::Vector{InliningCase}, allow_abstract::Bool = false) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false + # we may see duplicated dispatch signatures here when a signature gets widened + # during abstract interpretation: for the purpose of inlining, we can just skip + # processing this dispatch candidate + _any(case->case.sig === spec_types, cases) && return true item = analyze_method!(match, argtypes, flag, state) item === nothing && return false - _any(case->case.sig === spec_types, cases) && return true push!(cases, InliningCase(spec_types, item)) return true end @@ -1346,7 +1333,9 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), handle_single_case!(ir, idx, stmt, cases[1].item, todo, params) elseif length(cases) > 0 isa(atype, DataType) || return nothing - all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing + for case in cases + isa(case.sig, DataType) || return nothing + end push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end return nothing @@ -1442,7 +1431,8 @@ function 
assemble_inline_todo!(ir::IRCode, state::InliningState) analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo) end - todo + + return todo end function linear_inline_eligible(ir::IRCode) diff --git a/base/sort.jl b/base/sort.jl index d26e9a4b09332..981eea35d96ab 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -5,7 +5,7 @@ module Sort import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base.Order -using .Base: copymutable, LinearIndices, length, (:), +using .Base: copymutable, LinearIndices, length, (:), iterate, eachindex, axes, first, last, similar, zip, OrdinalRange, AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline, AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !, diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index fa4425893767c..a69e7c08e58dd 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -810,6 +810,103 @@ let @test invoke(Any[10]) === false end +# test union-split, non-dispatchtuple callsite inlining + +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + 
abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Type) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do 
x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +abstract_diagonal_dispatch(x::Int, y::Int) = 1 +abstract_diagonal_dispatch(x::Real, y::Int) = 2 +abstract_diagonal_dispatch(x::Int, y::Real) = 3 +function test_abstract_diagonal_dispatch(xs) + @test abstract_diagonal_dispatch(xs[1], xs[2]) == 1 + @test abstract_diagonal_dispatch(xs[3], xs[4]) == 3 + @test abstract_diagonal_dispatch(xs[5], xs[6]) == 2 + @test_throws MethodError abstract_diagonal_dispatch(xs[7], xs[8]) +end +test_abstract_diagonal_dispatch(Any[ + 1, 1, # => 1 + 1, 1.0, # => 3 + 1.0, 1, # => 2 + 1.0, 1.0 # => MethodError +]) + +constrained_dispatch(x::T, y::T) where T<:Real = 0 +let src = code_typed1((Real,Real,)) do x, y + constrained_dispatch(x, y) + end + @test any(iscall((src, constrained_dispatch)), src.code) # should account for MethodError +end +@test_throws MethodError let + x, y = 1.0, 1 + constrained_dispatch(x, y) +end + # issue 43104 @inline isGoodType(@nospecialize x::Type) = @@ -1097,11 +1194,11 @@ end global x44200::Int = 0 function f44200() - global x = 0 - while x < 10 - x += 1 + global x44200 = 0 + while x44200 < 10 + x44200 += 1 end - x + x44200 end let src = code_typed1(f44200) @test count(x -> isa(x, Core.PiNode), src.code) == 0