Skip to content

Commit

Permalink
inference: follow up #43852 (#44101)
Browse files Browse the repository at this point in the history
This commit consists of minor follow up tweaks for #43852:
- inlining: use `ConstResult` if available
- refactor tests
- simplify `CodeInstance` constructor signature
- tweak `concrete_eval_const_proven_total_or_error` signature for JET integration
  • Loading branch information
aviatesk committed Feb 10, 2022
1 parent 708873a commit fed4544
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 116 deletions.
9 changes: 4 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ struct MethodCallResult
end
end

function is_all_const_arg((; fargs, argtypes)::ArgInfo)
function is_all_const_arg((; argtypes)::ArgInfo)
for a in argtypes
if !isa(a, Const) && !isconstType(a) && !issingletontype(a)
return false
Expand All @@ -628,9 +628,8 @@ function is_all_const_arg((; fargs, argtypes)::ArgInfo)
return true
end

function concrete_eval_const_proven_total_or_error(
interp::AbstractInterpreter,
@nospecialize(f), argtypes::Vector{Any})
function concrete_eval_const_proven_total_or_error(interp::AbstractInterpreter,
@nospecialize(f), (; argtypes)::ArgInfo, _::InferenceState)
args = Any[ (a = widenconditional(argtypes[i]);
isa(a, Const) ? a.val :
isconstType(a) ? (a::DataType).parameters[1] :
Expand Down Expand Up @@ -673,7 +672,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
return nothing
end
if f !== nothing && result.edge !== nothing && is_total_or_error(result.edge_effects) && is_all_const_arg(arginfo)
rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo.argtypes)
rt = concrete_eval_const_proven_total_or_error(interp, f, arginfo, sv)
add_backedge!(result.edge, sv)
if rt === nothing
# The evaulation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
Expand Down
60 changes: 33 additions & 27 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1034,18 +1034,22 @@ function inline_invoke!(
# TODO: We could union split out the signature check and continue on
return nothing
end
argtypes = invoke_rewrite(sig.argtypes)
result = info.result
if isa(result, InferenceResult)
(; mi) = item = InliningTodo(result, argtypes)
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
return nothing
if isa(result, ConstResult)
item = const_result_item(result, state)
else
argtypes = invoke_rewrite(sig.argtypes)
if isa(result, InferenceResult)
(; mi) = item = InliningTodo(result, argtypes)
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
return nothing
end
end
item = analyze_method!(match, argtypes, flag, state)
end
item = analyze_method!(match, argtypes, flag, state)
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
return nothing
end
Expand Down Expand Up @@ -1241,28 +1245,18 @@ function handle_const_call!(
for match in meth
j += 1
result = results[j]
if result === false
# Inference determined that this call is guaranteed to throw.
# Do not inline.
fully_covered = false
continue
end
if isa(result, ConstResult)
if !isdefined(result, :result) || !is_inlineable_constant(result.result)
case = compileable_specialization(state.et, result.mi, EFFECTS_TOTAL)
else
case = ConstantCase(quoted(result.result))
end
case = const_result_item(result, state)
signature_union = Union{signature_union, result.mi.specTypes}
push!(cases, InliningCase(result.mi.specTypes, case))
continue
end
if result === nothing
elseif isa(result, InferenceResult)
signature_union = Union{signature_union, result.linfo.specTypes}
fully_covered &= handle_inf_result!(result, argtypes, flag, state, cases)
else
@assert result === nothing
signature_union = Union{signature_union, match.spec_types}
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
else
signature_union = Union{signature_union, result.linfo.specTypes}
fully_covered &= handle_const_result!(result, argtypes, flag, state, cases)
end
end
end
Expand Down Expand Up @@ -1296,7 +1290,7 @@ function handle_match!(
return true
end

function handle_const_result!(
function handle_inf_result!(
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase})
(; mi) = item = InliningTodo(result, argtypes)
Expand All @@ -1309,6 +1303,14 @@ function handle_const_result!(
return true
end

function const_result_item(result::ConstResult, state::InliningState)
if !isdefined(result, :result) || !is_inlineable_constant(result.result)
return compileable_specialization(state.et, result.mi, EFFECTS_TOTAL)
else
return ConstantCase(quoted(result.result))
end
end

function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype),
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}},
params::OptimizationParams)
Expand Down Expand Up @@ -1375,7 +1377,11 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
ir, idx, stmt, result, flag,
sig, state, todo)
else
item = analyze_method!(info.match, sig.argtypes, flag, state)
if isa(result, ConstResult)
item = const_result_item(result, state)
else
item = analyze_method!(info.match, sig.argtypes, flag, state)
end
handle_single_case!(ir, idx, stmt, item, todo, state.params)
end
continue
Expand Down
15 changes: 7 additions & 8 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,8 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
return true
end

function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
valid_worlds::WorldRange, effects::Effects, ipo_effects::Effects,
relocatability::UInt8)
function CodeInstance(
result::InferenceResult, @nospecialize(inferred_result), valid_worlds::WorldRange)
local const_flags::Int32
result_type = result.result
@assert !(result_type isa LimitedAccuracy)
Expand Down Expand Up @@ -309,10 +308,13 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
const_flags = 0x00
end
end
relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0)
return CodeInstance(result.linfo,
widenconst(result_type), rettype_const, inferred_result,
const_flags, first(valid_worlds), last(valid_worlds),
encode_effects(effects), encode_effects(ipo_effects), relocatability)
# TODO: Actually do something with non-IPO effects
encode_effects(result.ipo_effects), encode_effects(result.ipo_effects),
relocatability)
end

# For the NativeInterpreter, we don't need to do an actual cache query to know
Expand Down Expand Up @@ -386,10 +388,7 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src)
relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0)
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds,
# TODO: Actually do something with non-IPO effects
result.ipo_effects, result.ipo_effects, relocatability)
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds)
end
unlock_mi_inference(interp, linfo)
nothing
Expand Down
7 changes: 4 additions & 3 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ end
`@assume_effects` overrides the compiler's effect modeling for the given method.
`ex` must be a method definition or `@ccall` expression.
WARNING: Improper use of this macro causes undefined behavior (including crashes,
incorrect answers, or other hard to track bugs). Use with care and only if absolutely
required.
!!! warning
Improper use of this macro causes undefined behavior (including crashes,
incorrect answers, or other hard to track bugs). Use with care and only if
absolutely required.
In general, each `setting` value makes an assertion about the behavior of the
function, without requiring the compiler to prove that this behavior is indeed
Expand Down
71 changes: 44 additions & 27 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using Test
using Base.Meta
using Core: ReturnNode

include(normpath(@__DIR__, "irutils.jl"))

"""
Helper to walk the AST and call a function on every node.
"""
Expand Down Expand Up @@ -150,19 +152,6 @@ end
@test !any(x -> x isa Expr && x.head === :invoke, src.code)
end

function fully_eliminated(f, args)
@nospecialize f args
let code = code_typed(f, args)[1][1].code
return length(code) == 1 && isa(code[1], ReturnNode)
end
end
function fully_eliminated(f, args, retval)
@nospecialize f args
let code = code_typed(f, args)[1][1].code
return length(code) == 1 && isa(code[1], ReturnNode) && code[1].val == retval
end
end

# check that ismutabletype(type) can be fully eliminated
f_mutable_nothrow(s::String) = Val{typeof(s).name.flags}
@test fully_eliminated(f_mutable_nothrow, (String,))
Expand Down Expand Up @@ -246,7 +235,7 @@ function f_subtype()
T = SomeArbitraryStruct
T <: Bool
end
@test fully_eliminated(f_subtype, Tuple{}, false)
@test fully_eliminated(f_subtype, Tuple{}; retval=false)

# check that pointerref gets deleted if unused
f_pointerref(T::Type{S}) where S = Val(length(T.parameters))
Expand All @@ -270,7 +259,7 @@ function foo_apply_apply_type_svec()
B = Tuple{Float32, Float32}
Core.apply_type(A..., B.types...)
end
@test fully_eliminated(foo_apply_apply_type_svec, Tuple{}, NTuple{3, Float32})
@test fully_eliminated(foo_apply_apply_type_svec, Tuple{}; retval=NTuple{3, Float32})

# The that inlining doesn't drop ambiguity errors (#30118)
c30118(::Tuple{Ref{<:Type}, Vararg}) = nothing
Expand All @@ -284,7 +273,7 @@ b30118(x...) = c30118(x)
f34900(x::Int, y) = x
f34900(x, y::Int) = y
f34900(x::Int, y::Int) = invoke(f34900, Tuple{Int, Any}, x, y)
@test fully_eliminated(f34900, Tuple{Int, Int}, Core.Argument(2))
@test fully_eliminated(f34900, Tuple{Int, Int}; retval=Core.Argument(2))

@testset "check jl_ir_flag_inlineable for inline macro" begin
@test ccall(:jl_ir_flag_inlineable, Bool, (Any,), first(methods(@inline x -> x)).source)
Expand Down Expand Up @@ -324,10 +313,7 @@ struct NonIsBitsDims
dims::NTuple{N, Int} where N
end
NonIsBitsDims() = NonIsBitsDims(())
let ci = code_typed(NonIsBitsDims, Tuple{})[1].first
@test length(ci.code) == 1 && isa(ci.code[1], ReturnNode) &&
ci.code[1].val.value == NonIsBitsDims()
end
@test fully_eliminated(NonIsBitsDims, (); retval=QuoteNode(NonIsBitsDims()))

struct NonIsBitsDimsUndef
dims::NTuple{N, Int} where N
Expand Down Expand Up @@ -923,7 +909,7 @@ end
g() = Core.get_binding_type($m, :y)
end

@test fully_eliminated(m.f, Tuple{}, Int)
@test fully_eliminated(m.f, Tuple{}; retval=Int)
src = code_typed(m.g, ())[][1]
@test count(iscall((src, Core.get_binding_type)), src.code) == 1
@test m.g() === Any
Expand Down Expand Up @@ -962,17 +948,48 @@ end
@test fully_eliminated(f_sin_perf, Tuple{})

# Test that we inline the constructor of something that is not const-inlineable
const THE_REF_NULL = Ref{Int}()
const THE_REF = Ref{Int}(0)
struct FooTheRef
x::Ref
FooTheRef() = new(THE_REF)
FooTheRef(v) = new(v === nothing ? THE_REF_NULL : THE_REF)
end
let src = code_typed1() do
FooTheRef(nothing)
end
@test count(isnew, src.code) == 1
end
let src = code_typed1() do
FooTheRef(0)
end
@test count(isnew, src.code) == 1
end
f_make_the_ref() = FooTheRef()
f_make_the_ref_but_dont_return_it() = (FooTheRef(); nothing)
let src = code_typed1(f_make_the_ref, ())
@test count(x->isexpr(x, :new), src.code) == 1
let src = code_typed1() do
Base.@invoke FooTheRef(nothing::Any)
end
@test count(isnew, src.code) == 1
end
let src = code_typed1() do
Base.@invoke FooTheRef(0::Any)
end
@test count(isnew, src.code) == 1
end
@test fully_eliminated() do
FooTheRef(nothing)
nothing
end
@test fully_eliminated() do
FooTheRef(0)
nothing
end
@test fully_eliminated() do
Base.@invoke FooTheRef(nothing::Any)
nothing
end
@test fully_eliminated() do
Base.@invoke FooTheRef(0::Any)
nothing
end
@test fully_eliminated(f_make_the_ref_but_dont_return_it, Tuple{})

# Test that the Core._apply_iterate bail path taints effects
function f_apply_bail(f)
Expand Down
49 changes: 3 additions & 46 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,7 @@ using Test
using Base.Meta
using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot

# utilities
# =========

import Core.Compiler: argextype, singleton_type

argextype(@nospecialize args...) = argextype(args..., Any[])
code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code

# check if `x` is a statement with a given `head`
isnew(@nospecialize x) = Meta.isexpr(x, :new)

# check if `x` is a dynamic call of a given function
iscall(y) = @nospecialize(x) -> iscall(y, x)
function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x))
return iscall(x) do @nospecialize x
singleton_type(argextype(x, src)) === f
end
end
iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])

# check if `x` is a statically-resolved call of a function whose name is `sym`
isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x)
isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance)
include(normpath(@__DIR__, "irutils.jl"))

# domsort
# =======
Expand Down Expand Up @@ -473,9 +449,7 @@ struct FooPartial
global f_partial
f_partial(x) = new(x, 2).x
end
let ci = code_typed(f_partial, Tuple{Float64})[1].first
@test length(ci.code) == 1 && isa(ci.code[1], ReturnNode)
end
@test fully_eliminated(f_partial, Tuple{Float64})

# A SSAValue after the compaction line
let m = Meta.@lower 1 + 1
Expand Down Expand Up @@ -657,11 +631,7 @@ function no_op_refint(r)
r[]
return
end
let code = code_typed(no_op_refint,Tuple{Base.RefValue{Int}})[1].first.code
@test length(code) == 1
@test isa(code[1], Core.ReturnNode)
@test code[1].val === nothing
end
@test fully_eliminated(no_op_refint,Tuple{Base.RefValue{Int}}; retval=nothing)

# check getfield elim handling of GlobalRef
const _some_coeffs = (1,[2],3,4)
Expand Down Expand Up @@ -773,19 +743,6 @@ end
# test `stmt_effect_free` and DCE
# ===============================

function fully_eliminated(f, args)
@nospecialize f args
let code = code_typed(f, args)[1][1].code
return length(code) == 1 && isa(code[1], ReturnNode)
end
end
function fully_eliminated(f, args, retval)
@nospecialize f args
let code = code_typed(f, args)[1][1].code
return length(code) == 1 && isa(code[1], ReturnNode) && code[1].val == retval
end
end

let # effect-freeness computation for array allocation

# should eliminate dead allocations
Expand Down
Loading

0 comments on commit fed4544

Please sign in to comment.