Skip to content

Commit

Permalink
inference: forward Conditional inter-procedurally (JuliaLang#42529)
Browse files Browse the repository at this point in the history
The PR JuliaLang#38905 only "back-propagates" conditional constraint
(from callee to caller), but currently we don't "forward" it
(caller to callee), and so inter-procedural constraint propagation
won't happen for e.g.:
```julia
ifelselike(cnd, x, y) = cnd ? x : y
@test Base.return_types((Any,Int,)) do x, y
    ifelselike(isa(x, Int), x, y)
end |> only == Int
```

This commit complements JuliaLang#38905 and enables further inter-procedural
conditional constraint propagation by forwarding `Conditional` to
callees when it imposes a constraint on any other argument,
during constant propagation.

We also improve constant-prop' heuristics in these ways:

- remove `const_prop_rettype_heuristic` since it handles rare cases,
  where const-prop' doens't seem to be worthwhile, e.g. it won't be
  so useful to try to propagate `Const(Tuple{DataType,DataType})` for
  `Const(convert)(::Const(Tuple{DataType,DataType}), ::Tuple{DataType,DataType} -> Tuple{DataType,DataType}`
- rename `is_allconst` to `is_all_overridden`
- also minor refactors and improvements added
  • Loading branch information
aviatesk authored and LilithHafner committed Mar 8, 2022
1 parent ed1cb97 commit 06b55e1
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 79 deletions.
147 changes: 86 additions & 61 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

54 changes: 49 additions & 5 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,57 @@
function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
overridden_by_const::Bool)
if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || isa(given_argtype, PartialOpaque)
if is_forwardable_argtype(given_argtype)
return is_lattice_equal(given_argtype, cache_argtype)
end
return !overridden_by_const
end

function is_forwardable_argtype(@nospecialize x)
return isa(x, Const) ||
isa(x, Conditional) ||
isa(x, PartialStruct) ||
isa(x, PartialOpaque)
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
# so that we can construct cache-correct `InferenceResult`s in the first place.
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
function matching_cache_argtypes(
linfo::MethodInstance, (; fargs, argtypes)::ArgInfo, va_override::Bool)
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
given_argtypes = anymap(widenconditional, given_argtypes)
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
given_argtypes = Vector{Any}(undef, length(argtypes))
local condargs = nothing
for i in 1:length(argtypes)
argtype = argtypes[i]
# forward `Conditional` if it conveys a constraint on any other argument
if isa(argtype, Conditional) && fargs !== nothing
cnd = argtype
slotid = find_constrained_arg(cnd, fargs)
if slotid !== nothing
# using union-split signature, we may be able to narrow down `Conditional`
sigt = widenconst(slotid > nargs ? argtypes[slotid] : cache_argtypes[slotid])
vtype = tmeet(cnd.vtype, sigt)
elsetype = tmeet(cnd.elsetype, sigt)
if vtype === Bottom && elsetype === Bottom
# we accidentally proved this method match is impossible
# TODO bail out here immediately rather than just propagating Bottom ?
given_argtypes[i] = Bottom
else
if condargs === nothing
condargs = Tuple{Int,Int}[]
end
push!(condargs, (slotid, i))
given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype)
end
continue
end
end
given_argtypes[i] = widenconditional(argtype)
end
isva = va_override || linfo.def.isva
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
Expand All @@ -30,15 +67,22 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector,
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
if slotid last
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
given_argtypes = isva_given_argtypes
end
@assert length(given_argtypes) == nargs
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
for i in 1:nargs
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
if !is_argtype_match(given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `linfo`
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any
v = unwrapva(argtypes[5])
TF = getfield_tfunc(o, f)
push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call
callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1)
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), sv, #=max_methods=# 1)
pop!(sv.ssavalue_uses[sv.currpc], sv.currpc)
TF2 = tmeet(callinfo.rt, widenconst(TF))
if TF2 === Bottom
Expand Down Expand Up @@ -1747,7 +1747,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
if contains_is(argtypes_vec, Union{})
return CallMeta(Const(Union{}), false)
end
call = abstract_call(interp, nothing, argtypes_vec, sv, -1)
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), sv, -1)
info = verbose_stmt_info(interp) ? ReturnTypeCallInfo(call.info) : false
rt = widenconditional(call.rt)
if isa(rt, Const)
Expand Down
11 changes: 9 additions & 2 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ If `interp` is an `AbstractInterpreter`, it is expected to provide at least the
"""
abstract type AbstractInterpreter end

struct ArgInfo
fargs::Union{Nothing,Vector{Any}}
argtypes::Vector{Any}
end

"""
InferenceResult
Expand All @@ -29,8 +34,10 @@ mutable struct InferenceResult
result # ::Type, or InferenceState if WIP
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
valid_worlds::WorldRange # if inference and optimization is finished
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing, va_override=false)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes, va_override)
function InferenceResult(linfo::MethodInstance,
arginfo::Union{Nothing,ArgInfo} = nothing,
va_override::Bool = false)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override)
return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange())
end
end
Expand Down
9 changes: 0 additions & 9 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,6 @@ unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity
unioncomplexity(t::TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T)::Int : 0
unioncomplexity(@nospecialize(x)) = 0

function improvable_via_constant_propagation(@nospecialize(t))
if isconcretetype(t) && t <: Tuple
for p in t.parameters
p === DataType && return true
end
end
return false
end

# convert a Union of Tuple types to a Tuple of Unions
function unswitchtupleunion(u::Union)
ts = uniontypes(u)
Expand Down
55 changes: 55 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,61 @@ function _g_ifelse_isa_()
end
@test Base.return_types(_g_ifelse_isa_, ()) == [Int]

@testset "Conditional forwarding" begin
# forward `Conditional` if it conveys a constraint on any other argument
ifelselike(cnd, x, y) = cnd ? x : y

@test Base.return_types((Any,Int,)) do x, y
ifelselike(isa(x, Int), x, y)
end |> only == Int

# should work nicely with union-split
@test Base.return_types((Union{Int,Nothing},)) do x
ifelselike(isa(x, Int), x, 0)
end |> only == Int

@test Base.return_types((Any,Int)) do x, y
ifelselike(!isa(x, Int), y, x)
end |> only == Int

@test Base.return_types((Any,Int)) do x, y
a = ifelselike(x === 0, x, 0) # ::Const(0)
if a == 0
return y
else
return nothing # dead branch
end
end |> only == Int

# pick up the first if there are multiple constrained arguments
@test Base.return_types((Any,)) do x
ifelselike(isa(x, Int), x, x)
end |> only == Any

# just propagate multiple constraints
ifelselike2(cnd1, cnd2, x, y, z) = cnd1 ? x : cnd2 ? y : z
@test Base.return_types((Any,Any)) do x, y
ifelselike2(isa(x, Int), isa(y, Int), x, y, 0)
end |> only == Int

# work with `invoke`
@test Base.return_types((Any,Any)) do x, y
Base.@invoke ifelselike(isa(x, Int), x, y::Int)
end |> only == Int

# don't be confused with vararg method
vacond(cnd, va...) = cnd ? va : 0
@test Base.return_types((Any,)) do x
# at runtime we will see `va::Tuple{Tuple{Int,Int}, Tuple{Int,Int}}`
vacond(isa(x, Tuple{Int,Int}), x, x)
end |> only == Union{Int,Tuple{Any,Any}}

# demonstrate extra constraint propagation for Base.ifelse
@test Base.return_types((Any,Int,)) do x, y
ifelse(isa(x, Int), x, y)
end |> only == Int
end

# Equivalence of Const(T.instance) and T for singleton types
@test Const(nothing) Nothing && Nothing Const(nothing)

Expand Down

0 comments on commit 06b55e1

Please sign in to comment.