From fed45445638ae9ad850a08e81e0ede6e4c29f61b Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Thu, 10 Feb 2022 14:58:41 +0900 Subject: [PATCH] inference: follow up #43852 (#44101) This commit consists of minor follow up tweaks for #43852: - inlining: use `ConstResult` if available - refactor tests - simplify `CodeInstance` constructor signature - tweak `concrete_eval_const_proven_total_or_error` signature for JET integration --- base/compiler/abstractinterpretation.jl | 9 ++-- base/compiler/ssair/inlining.jl | 60 +++++++++++---------- base/compiler/typeinfer.jl | 15 +++--- base/expr.jl | 7 +-- test/compiler/inline.jl | 71 +++++++++++++++---------- test/compiler/irpasses.jl | 49 ++--------------- test/compiler/irutils.jl | 34 ++++++++++++ 7 files changed, 129 insertions(+), 116 deletions(-) create mode 100644 test/compiler/irutils.jl diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 7c3510263ca42..fe94a7f0600aa 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -619,7 +619,7 @@ struct MethodCallResult end end -function is_all_const_arg((; fargs, argtypes)::ArgInfo) +function is_all_const_arg((; argtypes)::ArgInfo) for a in argtypes if !isa(a, Const) && !isconstType(a) && !issingletontype(a) return false @@ -628,9 +628,8 @@ function is_all_const_arg((; fargs, argtypes)::ArgInfo) return true end -function concrete_eval_const_proven_total_or_error( - interp::AbstractInterpreter, - @nospecialize(f), argtypes::Vector{Any}) +function concrete_eval_const_proven_total_or_error(interp::AbstractInterpreter, + @nospecialize(f), (; argtypes)::ArgInfo, _::InferenceState) args = Any[ (a = widenconditional(argtypes[i]); isa(a, Const) ? a.val : isconstType(a) ? (a::DataType).parameters[1] : @@ -673,7 +672,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul return nothing end if f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo) - rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo.argtypes) + rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo, sv) add_backedge!(result.edge, sv) if rt === nothing # The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index a14a23326d58e..d3a322f7d44c2 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1034,18 +1034,22 @@ function inline_invoke!( # TODO: We could union split out the signature check and continue on return nothing end - argtypes = invoke_rewrite(sig.argtypes) result = info.result - if isa(result, InferenceResult) - (; mi) = item = InliningTodo(result, argtypes) - validate_sparams(mi.sparam_vals) || return nothing - if argtypes_to_type(argtypes) <: mi.def.sig - state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - handle_single_case!(ir, idx, stmt, item, todo, state.params, true) - return nothing + if isa(result, ConstResult) + item = const_result_item(result, state) + else + argtypes = invoke_rewrite(sig.argtypes) + if isa(result, InferenceResult) + (; mi) = item = InliningTodo(result, argtypes) + validate_sparams(mi.sparam_vals) || return nothing + if argtypes_to_type(argtypes) <: mi.def.sig + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + handle_single_case!(ir, idx, stmt, item, todo, state.params, true) + return nothing + end end + item = analyze_method!(match, argtypes, flag, state) end - item = analyze_method!(match, argtypes, flag, state) handle_single_case!(ir, idx, stmt, item, todo, state.params, true) return nothing end @@ -1241,28 +1245,18 @@ function handle_const_call!( for match in meth j += 1 result = results[j] - if result === false - # Inference determined that this call is guaranteed to throw. - # Do not inline. - fully_covered = false - continue - end if isa(result, ConstResult) - if !isdefined(result, :result) || !is_inlineable_constant(result.result) - case = compileable_specialization(state.et, result.mi, EFFECTS_TOTAL) - else - case = ConstantCase(quoted(result.result)) - end + case = const_result_item(result, state) signature_union = Union{signature_union, result.mi.specTypes} push!(cases, InliningCase(result.mi.specTypes, case)) continue - end - if result === nothing + elseif isa(result, InferenceResult) + signature_union = Union{signature_union, result.linfo.specTypes} + fully_covered &= handle_inf_result!(result, argtypes, flag, state, cases) + else + @assert result === nothing signature_union = Union{signature_union, match.spec_types} fully_covered &= handle_match!(match, argtypes, flag, state, cases) - else - signature_union = Union{signature_union, result.linfo.specTypes} - fully_covered &= handle_const_result!(result, argtypes, flag, state, cases) end end end @@ -1296,7 +1290,7 @@ function handle_match!( return true end -function handle_const_result!( +function handle_inf_result!( result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState, cases::Vector{InliningCase}) (; mi) = item = InliningTodo(result, argtypes) @@ -1309,6 +1303,14 @@ function handle_const_result!( return true end +function const_result_item(result::ConstResult, state::InliningState) + if !isdefined(result, :result) || !is_inlineable_constant(result.result) + return compileable_specialization(state.et, result.mi, EFFECTS_TOTAL) + else + return ConstantCase(quoted(result.result)) + end +end + function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}}, params::OptimizationParams) @@ -1375,7 +1377,11 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) ir, idx, stmt, result, flag, sig, state, todo) else - item = analyze_method!(info.match, sig.argtypes, flag, state) + if isa(result, ConstResult) + item = const_result_item(result, state) + else + item = analyze_method!(info.match, sig.argtypes, flag, state) + end handle_single_case!(ir, idx, stmt, item, todo, state.params) end continue diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index a15f81abde919..03ba383de4f61 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -277,9 +277,8 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) return true end -function CodeInstance(result::InferenceResult, @nospecialize(inferred_result), - valid_worlds::WorldRange, effects::Effects, ipo_effects::Effects, - relocatability::UInt8) +function CodeInstance( + result::InferenceResult, @nospecialize(inferred_result), valid_worlds::WorldRange) local const_flags::Int32 result_type = result.result @assert !(result_type isa LimitedAccuracy) @@ -309,10 +308,13 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result), const_flags = 0x00 end end + relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0) return CodeInstance(result.linfo, widenconst(result_type), rettype_const, inferred_result, const_flags, first(valid_worlds), last(valid_worlds), - encode_effects(effects), encode_effects(ipo_effects), relocatability) + # TODO: Actually do something with non-IPO effects + encode_effects(result.ipo_effects), encode_effects(result.ipo_effects), + relocatability) end # For the NativeInterpreter, we don't need to do an actual cache query to know @@ -386,10 +388,7 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult) # TODO: also don't store inferred code if we've previously decided to interpret this function if !already_inferred inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src) - relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0) - code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds, - # TODO: Actually do something with non-IPO effects - result.ipo_effects, result.ipo_effects, relocatability) + code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds) end unlock_mi_inference(interp, linfo) nothing diff --git a/base/expr.jl b/base/expr.jl index 59e6b075d8db9..38e89d284c989 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -378,9 +378,10 @@ end `@assume_effects` overrides the compiler's effect modeling for the given method. `ex` must be a method definition or `@ccall` expression. -WARNING: Improper use of this macro causes undefined behavior (including crashes, -incorrect answers, or other hard to track bugs). Use with care and only if absolutely -required. +!!! warning + Improper use of this macro causes undefined behavior (including crashes, + incorrect answers, or other hard to track bugs). Use with care and only if + absolutely required. In general, each `setting` value makes an assertion about the behavior of the function, without requiring the compiler to prove that this behavior is indeed diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index f2130e1c7eab4..4261665edc80e 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -4,6 +4,8 @@ using Test using Base.Meta using Core: ReturnNode +include(normpath(@__DIR__, "irutils.jl")) + """ Helper to walk the AST and call a function on every node. """ @@ -150,19 +152,6 @@ end @test !any(x -> x isa Expr && x.head === :invoke, src.code) end -function fully_eliminated(f, args) - @nospecialize f args - let code = code_typed(f, args)[1][1].code - return length(code) == 1 && isa(code[1], ReturnNode) - end -end -function fully_eliminated(f, args, retval) - @nospecialize f args - let code = code_typed(f, args)[1][1].code - return length(code) == 1 && isa(code[1], ReturnNode) && code[1].val == retval - end -end - # check that ismutabletype(type) can be fully eliminated f_mutable_nothrow(s::String) = Val{typeof(s).name.flags} @test fully_eliminated(f_mutable_nothrow, (String,)) @@ -246,7 +235,7 @@ function f_subtype() T = SomeArbitraryStruct T <: Bool end -@test fully_eliminated(f_subtype, Tuple{}, false) +@test fully_eliminated(f_subtype, Tuple{}; retval=false) # check that pointerref gets deleted if unused f_pointerref(T::Type{S}) where S = Val(length(T.parameters)) @@ -270,7 +259,7 @@ function foo_apply_apply_type_svec() B = Tuple{Float32, Float32} Core.apply_type(A..., B.types...) end -@test fully_eliminated(foo_apply_apply_type_svec, Tuple{}, NTuple{3, Float32}) +@test fully_eliminated(foo_apply_apply_type_svec, Tuple{}; retval=NTuple{3, Float32}) # The that inlining doesn't drop ambiguity errors (#30118) c30118(::Tuple{Ref{<:Type}, Vararg}) = nothing @@ -284,7 +273,7 @@ b30118(x...) = c30118(x) f34900(x::Int, y) = x f34900(x, y::Int) = y f34900(x::Int, y::Int) = invoke(f34900, Tuple{Int, Any}, x, y) -@test fully_eliminated(f34900, Tuple{Int, Int}, Core.Argument(2)) +@test fully_eliminated(f34900, Tuple{Int, Int}; retval=Core.Argument(2)) @testset "check jl_ir_flag_inlineable for inline macro" begin @test ccall(:jl_ir_flag_inlineable, Bool, (Any,), first(methods(@inline x -> x)).source) @@ -324,10 +313,7 @@ struct NonIsBitsDims dims::NTuple{N, Int} where N end NonIsBitsDims() = NonIsBitsDims(()) -let ci = code_typed(NonIsBitsDims, Tuple{})[1].first - @test length(ci.code) == 1 && isa(ci.code[1], ReturnNode) && - ci.code[1].val.value == NonIsBitsDims() -end +@test fully_eliminated(NonIsBitsDims, (); retval=QuoteNode(NonIsBitsDims())) struct NonIsBitsDimsUndef dims::NTuple{N, Int} where N @@ -923,7 +909,7 @@ end g() = Core.get_binding_type($m, :y) end - @test fully_eliminated(m.f, Tuple{}, Int) + @test fully_eliminated(m.f, Tuple{}; retval=Int) src = code_typed(m.g, ())[][1] @test count(iscall((src, Core.get_binding_type)), src.code) == 1 @test m.g() === Any @@ -962,17 +948,48 @@ end @test fully_eliminated(f_sin_perf, Tuple{}) # Test that we inline the constructor of something that is not const-inlineable +const THE_REF_NULL = Ref{Int}() const THE_REF = Ref{Int}(0) struct FooTheRef x::Ref - FooTheRef() = new(THE_REF) + FooTheRef(v) = new(v === nothing ? THE_REF_NULL : THE_REF) +end +let src = code_typed1() do + FooTheRef(nothing) + end + @test count(isnew, src.code) == 1 +end +let src = code_typed1() do + FooTheRef(0) + end + @test count(isnew, src.code) == 1 end -f_make_the_ref() = FooTheRef() -f_make_the_ref_but_dont_return_it() = (FooTheRef(); nothing) -let src = code_typed1(f_make_the_ref, ()) - @test count(x->isexpr(x, :new), src.code) == 1 +let src = code_typed1() do + Base.@invoke FooTheRef(nothing::Any) + end + @test count(isnew, src.code) == 1 +end +let src = code_typed1() do + Base.@invoke FooTheRef(0::Any) + end + @test count(isnew, src.code) == 1 +end +@test fully_eliminated() do + FooTheRef(nothing) + nothing +end +@test fully_eliminated() do + FooTheRef(0) + nothing +end +@test fully_eliminated() do + Base.@invoke FooTheRef(nothing::Any) + nothing +end +@test fully_eliminated() do + Base.@invoke FooTheRef(0::Any) + nothing end -@test fully_eliminated(f_make_the_ref_but_dont_return_it, Tuple{}) # Test that the Core._apply_iterate bail path taints effects function f_apply_bail(f) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 9eb77490cbdb4..128fd6cc84b7b 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -4,31 +4,7 @@ using Test using Base.Meta using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot -# utilities -# ========= - -import Core.Compiler: argextype, singleton_type - -argextype(@nospecialize args...) = argextype(args..., Any[]) -code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo -get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code - -# check if `x` is a statement with a given `head` -isnew(@nospecialize x) = Meta.isexpr(x, :new) - -# check if `x` is a dynamic call of a given function -iscall(y) = @nospecialize(x) -> iscall(y, x) -function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x)) - return iscall(x) do @nospecialize x - singleton_type(argextype(x, src)) === f - end -end -iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) - -# check if `x` is a statically-resolved call of a function whose name is `sym` -isinvoke(y) = @nospecialize(x) -> isinvoke(y, x) -isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x) -isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance) +include(normpath(@__DIR__, "irutils.jl")) # domsort # ======= @@ -473,9 +449,7 @@ struct FooPartial global f_partial f_partial(x) = new(x, 2).x end -let ci = code_typed(f_partial, Tuple{Float64})[1].first - @test length(ci.code) == 1 && isa(ci.code[1], ReturnNode) -end +@test fully_eliminated(f_partial, Tuple{Float64}) # A SSAValue after the compaction line let m = Meta.@lower 1 + 1 @@ -657,11 +631,7 @@ function no_op_refint(r) r[] return end -let code = code_typed(no_op_refint,Tuple{Base.RefValue{Int}})[1].first.code - @test length(code) == 1 - @test isa(code[1], Core.ReturnNode) - @test code[1].val === nothing -end +@test fully_eliminated(no_op_refint,Tuple{Base.RefValue{Int}}; retval=nothing) # check getfield elim handling of GlobalRef const _some_coeffs = (1,[2],3,4) @@ -773,19 +743,6 @@ end # test `stmt_effect_free` and DCE # =============================== -function fully_eliminated(f, args) - @nospecialize f args - let code = code_typed(f, args)[1][1].code - return length(code) == 1 && isa(code[1], ReturnNode) - end -end -function fully_eliminated(f, args, retval) - @nospecialize f args - let code = code_typed(f, args)[1][1].code - return length(code) == 1 && isa(code[1], ReturnNode) && code[1].val == retval - end -end - let # effect-freeness computation for array allocation # should eliminate dead allocations diff --git a/test/compiler/irutils.jl b/test/compiler/irutils.jl new file mode 100644 index 0000000000000..06d261720bdf8 --- /dev/null +++ b/test/compiler/irutils.jl @@ -0,0 +1,34 @@ +import Core: CodeInfo, ReturnNode, MethodInstance +import Core.Compiler: argextype, singleton_type +import Base.Meta: isexpr + +argextype(@nospecialize args...) = argextype(args..., Any[]) +code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::CodeInfo +get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code + +# check if `x` is a statement with a given `head` +isnew(@nospecialize x) = isexpr(x, :new) +isreturn(@nospecialize x) = isa(x, ReturnNode) + +# check if `x` is a dynamic call of a given function +iscall(y) = @nospecialize(x) -> iscall(y, x) +function iscall((src, f)::Tuple{CodeInfo,Base.Callable}, @nospecialize(x)) + return iscall(x) do @nospecialize x + singleton_type(argextype(x, src)) === f + end +end +iscall(pred::Base.Callable, @nospecialize(x)) = isexpr(x, :call) && pred(x.args[1]) + +# check if `x` is a statically-resolved call of a function whose name is `sym` +isinvoke(y) = @nospecialize(x) -> isinvoke(y, x) +isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x) +isinvoke(pred::Function, @nospecialize(x)) = isexpr(x, :invoke) && pred(x.args[1]::MethodInstance) + +function fully_eliminated(@nospecialize args...; retval=(@__FILE__), kwargs...) + code = code_typed1(args...; kwargs...).code + if retval !== (@__FILE__) + return length(code) == 1 && isreturn(code[1]) && code[1].val == retval + else + return length(code) == 1 && isreturn(code[1]) + end +end