Skip to content

Commit

Permalink
inference: correctly translate inter-procedural information at everyw…
Browse files Browse the repository at this point in the history
…here (JuliaLang#42847)

There are some remaining cases where we need to call `collect_limitations!`
within `abstract_invoke`.
Also it turns out we need to handle `InterConditional` in any inter-procedural
contexts, so refactored that part within `abstract_call_gf_by_type` and
apply the logic everywhere inter-procedural propagation happens.

Now `InterConditional` propagation and callsite argument type refinement
are enabled for `invoke` and opaque closure call sites, e.g.:
```julia
ispositive(a) = isa(a, Int) && a > 0
@test Base.return_types((Any,)) do a
    if Base.@invoke ispositive(a::Any)
    return a
    end
    return 0
end |> only == Int

@test Base.return_types((Any,)) do a
    f = Base.Experimental.@opaque a -> isa(a, Int) && a > 0
    if f(a)
        return a
    end
    return 0
end |> only == Int
```

Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
aviatesk and vtjnash committed Oct 29, 2021
1 parent 295a093 commit 11fc7ed
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 93 deletions.
216 changes: 136 additions & 80 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ function is_improvable(@nospecialize(rtype))
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
(; fargs, argtypes)::ArgInfo, @nospecialize(atype),
arginfo::ArgInfo, @nospecialize(atype),
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
add_remark!(interp, sv, "Skipped call in throw block")
return CallMeta(Any, false)
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
Expand All @@ -61,6 +62,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

fargs = arginfo.fargs
for i in 1:napplicable
match = applicable[i]::MethodMatch
method = match.method
Expand All @@ -85,8 +87,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(edges, edge)
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
this_arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
Expand All @@ -111,8 +113,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
this_arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, this_arginfo, match, sv, false)
if const_result !== nothing
const_this_rt, const_result = const_result
if const_this_rt !== this_rt && const_this_rt this_rt
Expand All @@ -134,19 +136,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals = Any[Bottom for _ in 1:length(argtypes)],
Any[Bottom for _ in 1:length(argtypes)]
end
condval = maybe_extract_const_bool(this_conditional)
for i = 1:length(argtypes)
fargs[i] isa SlotNumber || continue
if this_conditional isa InterConditional && this_conditional.slot == i
vtype = this_conditional.vtype
elsetype = this_conditional.elsetype
else
elsetype = vtype = tmeet(argtypes[i], fieldtype(sig, i))
condval === true && (elsetype = Union{})
condval === false && (vtype = Union{})
end
conditionals[1][i] = tmerge(conditionals[1][i], vtype)
conditionals[2][i] = tmerge(conditionals[2][i], elsetype)
cnd = conditional_argtype(this_conditional, sig, argtypes, i)
conditionals[1][i] = tmerge(conditionals[1][i], cnd.vtype)
conditionals[2][i] = tmerge(conditionals[2][i], cnd.elsetype)
end
end
if bail_out_call(interp, rettype, sv)
Expand All @@ -161,56 +154,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
info = ConstCallInfo(info, const_results)
end

rettype = collect_limitations!(rettype, sv)
# if we have argument refinement information, apply that now to get the result
if is_lattice_bool(rettype) && conditionals !== nothing && fargs !== nothing
slot = 0
vtype = elsetype = Any
condval = maybe_extract_const_bool(rettype)
for i in 1:length(fargs)
# find the first argument which supports refinement,
# and intersect all equivalent arguments with it
arg = fargs[i]
arg isa SlotNumber || continue # can't refine
old = argtypes[i]
old isa Type || continue # unlikely to refine
id = slot_id(arg)
if slot == 0 || id == slot
new_vtype = conditionals[1][i]
if condval === false
vtype = Union{}
elseif new_vtype vtype
vtype = new_vtype
else
vtype = tmeet(vtype, widenconst(new_vtype))
end
new_elsetype = conditionals[2][i]
if condval === true
elsetype = Union{}
elseif new_elsetype elsetype
elsetype = new_elsetype
else
elsetype = tmeet(elsetype, widenconst(new_elsetype))
end
if (slot > 0 || condval !== false) && !(old vtype) # essentially vtype ⋤ old
slot = id
elseif (slot > 0 || condval !== true) && !(old elsetype) # essentially elsetype ⋤ old
slot = id
else # reset: no new useful information for this slot
vtype = elsetype = Any
if slot > 0
slot = 0
end
end
end
end
if vtype === Bottom && elsetype === Bottom
rettype = Bottom # accidentally proved this call to be dead / throw !
elseif slot > 0
rettype = Conditional(SlotNumber(slot), vtype, elsetype) # record a Conditional improvement to this slot
end
end
@assert !(rettype isa InterConditional) "invalid lattice element returned from inter-procedural context"
rettype = from_interprocedural!(rettype, sv, arginfo, conditionals)

if call_result_unused(sv) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
Expand Down Expand Up @@ -322,6 +266,117 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
end
end

"""
from_interprocedural!(rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt
Converts inter-procedural return type `rt` into a local lattice element `newrt`,
that is appropriate in the context of current local analysis frame `sv`, especially:
- unwraps `rt::LimitedAccuracy` and collects its limitations into the current frame `sv`
- converts boolean `rt` to new boolean `newrt` in a way `newrt` can propagate extra conditional
refinement information, e.g. translating `rt::InterConditional` into `newrt::Conditional`
that holds a type constraint information about a variable in `sv`
This function _should_ be used wherever we propagate results returned from
`abstract_call_method` or `abstract_call_method_with_const_args`.
When `maybecondinfo !== nothing`, this function also tries extra conditional argument type refinement.
In such cases `maybecondinfo` should be either of:
- `maybecondinfo::Tuple{Vector{Any},Vector{Any}}`: precomputed argument type refinement information
- method call signature tuple type
When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by
`tmerge`ing argument signature type of each method call.
"""
function from_interprocedural!(@nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
rt = collect_limitations!(rt, sv)
if is_lattice_bool(rt)
if maybecondinfo === nothing
rt = widenconditional(rt)
else
rt = from_interconditional(rt, arginfo, maybecondinfo)
end
end
@assert !(rt isa InterConditional) "invalid lattice element returned from inter-procedural context"
return rt
end

function collect_limitations!(@nospecialize(typ), sv::InferenceState)
if isa(typ, LimitedAccuracy)
union!(sv.pclimitations, typ.causes)
return typ.typ
end
return typ
end

function from_interconditional(@nospecialize(typ), (; fargs, argtypes)::ArgInfo, @nospecialize(maybecondinfo))
fargs === nothing && return widenconditional(typ)
slot = 0
vtype = elsetype = Any
condval = maybe_extract_const_bool(typ)
for i in 1:length(fargs)
# find the first argument which supports refinement,
# and intersect all equivalent arguments with it
arg = fargs[i]
arg isa SlotNumber || continue # can't refine
old = argtypes[i]
old isa Type || continue # unlikely to refine
id = slot_id(arg)
if slot == 0 || id == slot
if isa(maybecondinfo, Tuple{Vector{Any},Vector{Any}})
# if we have already computed argument refinement information, apply that now to get the result
new_vtype = maybecondinfo[1][i]
new_elsetype = maybecondinfo[2][i]
else
# otherwise compute it on the fly
cnd = conditional_argtype(typ, maybecondinfo, argtypes, i)
new_vtype = cnd.vtype
new_elsetype = cnd.elsetype
end
if condval === false
vtype = Bottom
elseif new_vtype vtype
vtype = new_vtype
else
vtype = tmeet(vtype, widenconst(new_vtype))
end
if condval === true
elsetype = Bottom
elseif new_elsetype elsetype
elsetype = new_elsetype
else
elsetype = tmeet(elsetype, widenconst(new_elsetype))
end
if (slot > 0 || condval !== false) && !(old vtype) # essentially vtype ⋤ old
slot = id
elseif (slot > 0 || condval !== true) && !(old elsetype) # essentially elsetype ⋤ old
slot = id
else # reset: no new useful information for this slot
vtype = elsetype = Any
if slot > 0
slot = 0
end
end
end
end
if vtype === Bottom && elsetype === Bottom
return Bottom # accidentally proved this call to be dead / throw !
elseif slot > 0
return Conditional(SlotNumber(slot), vtype, elsetype) # record a Conditional improvement to this slot
end
return widenconditional(typ)
end

function conditional_argtype(@nospecialize(rt), @nospecialize(sig), argtypes::Vector{Any}, i::Int)
if isa(rt, InterConditional) && rt.slot == i
return rt
else
vtype = elsetype = tmeet(argtypes[i], fieldtype(sig, i))
condval = maybe_extract_const_bool(rt)
condval === true && (elsetype = Bottom)
condval === false && (vtype = Bottom)
return InterConditional(i, vtype, elsetype)
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::InferenceState)
Expand Down Expand Up @@ -1237,9 +1292,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
match = MethodMatch(ti, env, method, argtype <: method.sig)
# try constant propagation with manual inlinings of some of the heuristics
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
res = nothing
sig = match.spec_types
argtypes′ = argtypes[3:end]
argtypes′[1] = ft
if fargs === nothing
Expand All @@ -1249,7 +1303,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
fargs′[1] = fargs[1]
end
arginfo = ArgInfo(fargs′, argtypes′)
const_prop_argument_heuristic(interp, arginfo, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
# for i in 1:length(argtypes′)
# t, a = ti.parameters[i], argtypes′[i]
Expand All @@ -1259,10 +1312,10 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
return CallMeta(collect_limitations!(const_rt, sv), InvokeCallInfo(match, const_result))
rt, res = const_rt, const_result
end
end
return CallMeta(collect_limitations!(rt, sv), InvokeCallInfo(match, nothing))
return CallMeta(from_interprocedural!(rt, sv, arginfo, sig), InvokeCallInfo(match, res))
end

# call where the function is known exactly
Expand Down Expand Up @@ -1360,9 +1413,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
pushfirst!(argtypes, closure.env)
sig = argtypes_to_type(argtypes)
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, arginfo::ArgInfo, sv::InferenceState)
sig = argtypes_to_type(arginfo.argtypes)
(; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
Expand All @@ -1371,7 +1423,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
info = OpaqueClosureCallInfo(match)
if !result.edgecycle
const_result = abstract_call_method_with_const_args(interp, result, closure,
ArgInfo(nothing, argtypes), match, sv, closure.isva)
arginfo, match, sv, closure.isva)
if const_result !== nothing
const_rettype, const_result = const_result
if const_rettype rt
Expand All @@ -1380,7 +1432,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result])
end
end
return CallMeta(collect_limitations!(rt, sv), info)
return CallMeta(from_interprocedural!(rt, sv, arginfo, match.spec_types), info)
end

function most_general_argtypes(closure::PartialOpaque)
Expand All @@ -1400,7 +1452,9 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
ft = argtypes[1]
f = singleton_type(ft)
if isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
newargtypes = copy(argtypes)
newargtypes[1] = ft.env
return abstract_call_opaque_closure(interp, ft, ArgInfo(arginfo.fargs, newargtypes), sv)
elseif (uft = unwrap_unionall(ft); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], ft), false)
elseif f === nothing
Expand Down Expand Up @@ -1600,8 +1654,10 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if isa(t, PartialOpaque)
# Infer this now so that the specialization is available to
# optimization.
argtypes = most_general_argtypes(t)
pushfirst!(argtypes, t.env)
callinfo = abstract_call_opaque_closure(interp, t,
most_general_argtypes(t), sv)
ArgInfo(nothing, argtypes), sv)
sv.stmt_info[sv.currpc] = OpaqueClosureCreateInfo(callinfo)
end
end
Expand Down
8 changes: 0 additions & 8 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,6 @@ struct LimitedAccuracy
end
end

@inline function collect_limitations!(@nospecialize(typ), sv::InferenceState)
if isa(typ, LimitedAccuracy)
union!(sv.pclimitations, typ.causes)
return typ.typ
end
return typ
end

"""
struct NotFound end
const NOT_FOUND = NotFound()
Expand Down
34 changes: 29 additions & 5 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,35 @@ end
end == Any[Union{Bool,Nothing}]
end

@testset "`from_interprocedural!`: translate inter-procedural information" begin
# TODO come up with a test case to check the functionality of `collect_limitations!`
# one heavy test case would be to use https://github.com/aviatesk/JET.jl and
# check `julia /path/to/JET/jet /path/to/JET/src/JET.jl` doesn't result in errors
# because of nested `LimitedAccuracy`es

# `InterConditional` handling: `abstract_invoke`
ispositive(a) = isa(a, Int) && a > 0
@test Base.return_types((Any,)) do a
if Base.@invoke ispositive(a::Any)
return a
end
return 0
end |> only == Int
# the `fargs = nothing` edge case
@test Base.return_types((Any,)) do a
Core.Compiler.return_type(invoke, Tuple{typeof(ispositive), Type{Tuple{Any}}, Any})
end |> only == Type{Bool}

# `InterConditional` handling: `abstract_call_opaque_closure`
@test Base.return_types((Any,)) do a
f = Base.Experimental.@opaque a -> isa(a, Int) && a > 0
if f(a)
return a
end
return 0
end |> only === Int
end

function f25579(g)
h = g[]
t = (h === nothing)
Expand Down Expand Up @@ -3285,11 +3314,6 @@ function splat_lotta_unions()
end
@test Core.Compiler.return_type(splat_lotta_unions, Tuple{}) >: Tuple{Int,Int,Int}

# handle `fargs = nothing` edge cases
@test (code_typed(; optimize=false) do
Core.Compiler.return_type(invoke, Tuple{typeof(sin), Type{Tuple{Integer}}, Int})
end; true)

# Bare Core.Argument in IR
@eval f_bare_argument(x) = $(Core.Argument(2))
@test Base.return_types(f_bare_argument, (Int,))[1] == Int
Expand Down

0 comments on commit 11fc7ed

Please sign in to comment.