Skip to content

Commit

Permalink
inference: parameterize some of hard-coded inference logic (JuliaLang…
Browse files Browse the repository at this point in the history
…#39439)

This commit parameterizes some of hard-coded inference logic:
- to bail out from inference when a lattice element can't be refined or
  a current inference frame is proven to throw or to be dead
- to add call backedges when the call return type won't be refined

Those `AbstractInterpreter`s used for code optimization (including
`NativeInterpreter`) usually just want the methods defined for
`AbstractInterpreter`, but some other `AbstractInterpreter` may want
other implementations and heuristics to control inference process.
For example, [`JETInterpreter`](https://github.com/aviatesk/JET.jl) is
used for code analysis and wants to add call backedges even when a call
return type is `Any`.
  • Loading branch information
aviatesk committed Feb 10, 2021
1 parent 5d7e13f commit 1ef49c8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 21 deletions.
61 changes: 40 additions & 21 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
napplicable = length(applicable)
rettype = Bottom
edgecycle = false
edges = Any[]
edges = MethodInstance[]
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
Expand All @@ -115,7 +114,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
match = applicable[i]::MethodMatch
method = match.method
sig = match.spec_types
if istoplevel && !isdispatchtuple(sig)
if bail_out_toplevel_call(interp, sig, sv)
# only infer concrete call sites in top-level expressions
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
rettype = Any
Expand All @@ -135,7 +134,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
edgecycle |= edgecycle1::Bool
this_rt = tmerge(this_rt, rt)
this_rt === Any && break
if bail_out_call(interp, this_rt, sv)
break
end
end
else
this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
Expand All @@ -153,7 +154,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
seen += 1
rettype = tmerge(rettype, this_rt)
rettype === Any && break
if bail_out_call(interp, rettype, sv)
break
end
end
# try constant propagation if only 1 method is inferred to non-Bottom
# this is in preparation for inlining, or improving the return result
Expand All @@ -179,18 +182,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
if !(rettype === Any) # adding a new method couldn't refine (widen) this type
for edge in edges
add_backedge!(edge::MethodInstance, sv)
end
for (thisfullmatch, mt) in zip(fullmatch, mts)
if !thisfullmatch
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge!(mt, atype, sv)
end
end
end
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
#print("=> ", rettype, "\n")
if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
Expand All @@ -205,6 +197,27 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(rettype, info)
end

function add_call_backedges!(interp::AbstractInterpreter,
@nospecialize(rettype),
edges::Vector{MethodInstance},
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
sv::InferenceState)
if rettype === Any
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
# (widen) this type
return
end
for edge in edges
add_backedge!(edge, sv)
end
for (thisfullmatch, mt) in zip(fullmatch, mts)
if !thisfullmatch
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge!(mt, atype, sv)
end
end
end

function const_prop_profitable(@nospecialize(arg))
# have new information from argtypes that wasn't available from the signature
Expand Down Expand Up @@ -746,7 +759,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
call = abstract_call(interp, nothing, ct, sv, max_methods)
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if res === Any
if bail_out_apply(interp, res, sv)
# No point carrying forward the info, we're not gonna inline it anyway
retinfo = nothing
break
Expand Down Expand Up @@ -1171,7 +1184,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
if ai === Bottom
if bail_out_statement(interp, ai, sv)
return Bottom
end
argtypes[i] = ai
Expand Down Expand Up @@ -1349,6 +1362,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
if condt === Bottom
empty!(frame.pclimitations)
end
if bail_out_local(interp, condt, frame)
break
end
condval = maybe_extract_const_bool(condt)
Expand Down Expand Up @@ -1440,7 +1455,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
else
if hd === :(=)
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
t === Bottom && break
if bail_out_local(interp, t, frame)
break
end
frame.src.ssavaluetypes[pc] = t
lhs = stmt.args[1]
if isa(lhs, Slot)
Expand All @@ -1455,7 +1472,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
t === Bottom && break
if bail_out_local(interp, t, frame)
break
end
if !isempty(frame.ssavalue_uses[pc])
record_ssa_assign(pc, t, frame)
else
Expand Down
13 changes: 13 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,16 @@ may_compress(ni::NativeInterpreter) = true
may_discard_trees(ni::NativeInterpreter) = true

method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai))

# define inference bail out logic
# `NativeInterpreter` bails out from inference when
# - a lattice element grows up to `Any` (inter-procedural call, abstract apply)
# - a lattice element gets down to `Bottom` (statement inference, local frame inference)
# - inferring non-concrete toplevel call sites
bail_out_call(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
bail_out_apply(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
bail_out_statement(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
bail_out_local(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
function bail_out_toplevel_call(interp::AbstractInterpreter, @nospecialize(sig), sv)
return isa(sv.linfo.def, Module) && !isdispatchtuple(sig)
end

0 comments on commit 1ef49c8

Please sign in to comment.