Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: forward Conditional inter-procedurally #42529

Merged
merged 4 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
inference: forward Conditional inter-procedurally
The PR #38905 only "back-propagates" conditional constraint
(from callee to caller), but currently we don't "forward" it
(caller to callee), and so inter-procedural constraint propagation
won't happen for e.g.:
```julia
ifelselike(cnd, x, y) = cnd ? x : y
@test Base.return_types((Any,Int,)) do x, y
    ifelselike(isa(x, Int), x, y)
end |> only == Int
```

This commit complements #38905 and enables further inter-procedural
conditional constraint propagation by forwarding `Conditional` to
callees when it imposes a constraint on any other argument,
during constant propagation.
  • Loading branch information
aviatesk committed Oct 21, 2021
commit a955380bd7ea3a8cac2cc43ce85817453e58fc75
116 changes: 73 additions & 43 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype))
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
(; fargs, argtypes)::ArgInfo, @nospecialize(atype),
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
add_remark!(interp, sv, "Skipped call in throw block")
Expand Down Expand Up @@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(edges, edge)
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt ⊑ rt
Expand All @@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
if const_result !== nothing
const_this_rt, const_result = const_result
if const_this_rt !== this_rt && const_this_rt ⊑ this_rt
Expand Down Expand Up @@ -523,13 +525,13 @@ struct MethodCallResult
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, va_override::Bool)
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try constant prop'
inf_cache = get_inference_cache(interp)
inf_result = cache_lookup(mi, argtypes, inf_cache)
inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache)
if inf_result === nothing
# if there might be a cycle, check to make sure we don't end up
# calling ourselves here.
Expand All @@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
return nothing
end
end
inf_result = InferenceResult(mi, argtypes, va_override)
inf_result = InferenceResult(mi; arginfo, va_override)
if !any(inf_result.overridden_by_const)
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
return nothing
Expand All @@ -565,7 +567,7 @@ end
# if there's a possibility we could get a better result (hopefully without doing too much work)
# returns `MethodInstance` with constant arguments, returns nothing otherwise
function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState)
if !InferenceParams(interp).ipo_constant_propagation
add_remark!(interp, sv, "[constprop] Disabled by parameter")
Expand All @@ -580,14 +582,14 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
nargs::Int = method.nargs
method.isva && (nargs -= 1)
length(argtypes) < nargs && return nothing
if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt))
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
length(arginfo.argtypes) < nargs && return nothing
if !(const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, result.rt))
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
return nothing
end
allconst = is_allconst(argtypes)
allconst = is_allconst(arginfo)
if !force
if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst)
if !const_prop_function_heuristic(interp, f, arginfo, nargs, allconst)
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
return nothing
end
Expand All @@ -599,7 +601,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand All @@ -617,8 +619,11 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC
end

# see if propagating constants may be worthwhile
function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any})
function const_prop_argument_heuristic(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo)
for a in argtypes
if isa(a, Conditional) && fargs !== nothing
return is_const_prop_profitable_conditional(a, fargs)
end
a = widenconditional(a)
if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a)
return true
Expand All @@ -642,13 +647,34 @@ function is_const_prop_profitable_arg(@nospecialize(arg))
return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val))
end

function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any})
slotid = find_constrained_arg(cnd, fargs)
if slotid !== nothing
return true
end
return is_const_prop_profitable_arg(widenconditional(cnd))
aviatesk marked this conversation as resolved.
Show resolved Hide resolved
end

function find_constrained_arg(cnd::Conditional, fargs::Vector{Any})
slot = cnd.var
return findfirst(fargs) do @nospecialize(x)
x === slot
end
end

function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype))
return improvable_via_constant_propagation(rettype)
end

function is_allconst(argtypes::Vector{Any})
function is_allconst((; fargs, argtypes)::ArgInfo)
for a in argtypes
if isa(a, Conditional) && fargs !== nothing
if is_const_prop_profitable_conditional(a, fargs)
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
continue
end
end
a = widenconditional(a)
# TODO unify these condition with `has_nontrivial_const_info`
if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
return false
end
Expand All @@ -663,7 +689,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
istopfunction(f, :setproperty!)
end

function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool)
function const_prop_function_heuristic(
interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
nargs::Int, allconst::Bool)
if nargs > 1
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
arrty = argtypes[2]
Expand Down Expand Up @@ -705,7 +733,7 @@ end
# result anyway.
function const_prop_methodinstance_heuristic(
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
argtypes::Vector{Any}, sv::InferenceState)
(; argtypes)::ArgInfo, sv::InferenceState)
method = match.method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
Expand Down Expand Up @@ -832,7 +860,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
return Any[Vararg{Any}], nothing
end
@assert !isvarargtype(itertype)
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv)
stateordonet = call.rt
info = call.info
# Return Bottom if this is not an iterator.
Expand Down Expand Up @@ -866,7 +894,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
valtype = getfield_tfunc(stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv)
stateordonet = call.rt
stateordonet_widened = widenconst(stateordonet)
push!(calls, call)
Expand Down Expand Up @@ -901,7 +929,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt
stateordonet_widened = widenconst(stateordonet)
end
if valtype !== Union{}
Expand Down Expand Up @@ -990,7 +1018,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
break
end
end
call = abstract_call(interp, nothing, ct, sv, max_methods)
call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods)
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if bail_out_apply(interp, res, sv)
Expand Down Expand Up @@ -1054,8 +1082,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
return argtypes[i:n]
end

function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo,
sv::InferenceState, max_methods::Int)
@nospecialize f
la = length(argtypes)
if f === ifelse && fargs isa Vector{Any} && la == 4
Expand Down Expand Up @@ -1188,7 +1216,7 @@ function abstract_call_unionall(argtypes::Vector{Any})
return Any
end

function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState)
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, false)
Expand Down Expand Up @@ -1216,14 +1244,17 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
argtypes′ = argtypes[4:end]
aviatesk marked this conversation as resolved.
Show resolved Hide resolved
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
pushfirst!(argtypes′, ft)
fargs′ = fargs[4:end]
pushfirst!(fargs′, fargs[1])
aviatesk marked this conversation as resolved.
Show resolved Hide resolved
arginfo = ArgInfo(fargs′, argtypes′)
const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
# for i in 1:length(argtypes′)
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false)
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt ⊑ rt
Expand All @@ -1235,21 +1266,20 @@ end

# call where the function is known exactly
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
sv::InferenceState,
arginfo::ArgInfo, sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)

(; fargs, argtypes) = arginfo
la = length(argtypes)

if isa(f, Builtin)
if f === _apply_iterate
return abstract_apply(interp, argtypes, sv, max_methods)
elseif f === invoke
return abstract_invoke(interp, argtypes, sv)
return abstract_invoke(interp, arginfo, sv)
elseif f === modifyfield!
return abstract_modifyfield!(interp, argtypes, sv)
end
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false)
elseif f === Core.kwfunc
if la == 2
ft = widenconst(argtypes[2])
Expand Down Expand Up @@ -1282,12 +1312,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt
rty = abstract_call_known(interp, (===), arginfo, sv).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
elseif isa(rty, Const)
Expand All @@ -1303,7 +1333,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs = nothing
end
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false)
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false)
elseif la == 2 &&
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
istopfunction(f, :length)
Expand All @@ -1326,7 +1356,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(val === false ? Type : val, MethodResultPure())
end
atype = argtypes_to_type(argtypes)
return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods)
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
Expand All @@ -1339,8 +1369,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
info = OpaqueClosureCallInfo(match)
if !result.edgecycle
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
match, sv, closure.isva)
const_result = abstract_call_method_with_const_args(interp, result, closure,
ArgInfo(nothing, argtypes), match, sv, closure.isva)
if const_result !== nothing
const_rettype, const_result = const_result
if const_rettype ⊑ rt
Expand All @@ -1363,9 +1393,9 @@ function most_general_argtypes(closure::PartialOpaque)
end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
#print("call ", e.args[1], argtypes, "\n\n")
argtypes = arginfo.argtypes
ft = argtypes[1]
f = singleton_type(ft)
if isa(ft, PartialOpaque)
Expand All @@ -1379,9 +1409,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods)
return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods)
end
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
return abstract_call_known(interp, f, arginfo, sv, max_methods)
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
Expand Down Expand Up @@ -1428,7 +1458,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V
# this may be the wrong world for the call,
# but some of the result is likely to be valid anyways
# and that may help generate better codegen
abstract_call(interp, nothing, at, sv)
abstract_call(interp, ArgInfo(nothing, at), sv)
nothing
end

Expand Down Expand Up @@ -1502,7 +1532,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if argtypes === nothing
t = Bottom
else
callinfo = abstract_call(interp, ea, argtypes, sv)
callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv)
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
end
Expand Down
Loading