Skip to content

Commit

Permalink
Replace InferenceResult varargs hack with new inference lattice ele…
Browse files Browse the repository at this point in the history
…ment for "partially constant" tuples

Previously, we hacked in an additional `InferenceResult` field to store varargs type information
in order to facilitate better constant propagation through varargs methods. There were many
other places, however, where constants moving in/out of tuples/varargs thwarted constant
propagation.

This commit removes the varargs hack, replacing it with a new inference lattice element
(`PartialTuple`) that represents tuples where some (but not all) of the elements are
constants. This allows us to follow through with constant propagation in more
situations involving tuple construction/destructuring, and also enabled a clean-up
of the `InferenceResult` caching code.
  • Loading branch information
jrevels committed Oct 6, 2018
1 parent ec574fb commit 55b500c
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 164 deletions.
44 changes: 9 additions & 35 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
haveconst = false
for a in argtypes
a = maybe_widen_conditional(a)
if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val))
if has_nontrivial_const_info(a)
# have new information from argtypes that wasn't available from the signature
if isa(a.val, Symbol) || isa(a.val, Type) || (!isa(a.val, String) && isimmutable(a.val))
if !isa(a, Const) || (isa(a.val, Symbol) || isa(a.val, Type) || (!isa(a.val, String) && isimmutable(a.val)))
# don't consider mutable values or Strings useful constants
haveconst = true
break
Expand Down Expand Up @@ -176,32 +176,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
end
inf_result = cache_lookup(code, argtypes, sv.params.cache)
if inf_result === nothing
inf_result = InferenceResult(code)
atypes = get_argtypes(inf_result)
if method.isva
vargs = argtypes[(nargs + 1):end]
all_vargs_const = true
for i in 1:length(vargs)
a = maybe_widen_conditional(vargs[i])
all_vargs_const &= a isa Const
if i > length(inf_result.vargs)
push!(inf_result.vargs, a)
elseif a isa Const
inf_result.vargs[i] = a
end
end
# If all vargs are const, the result may be a constant
# tuple. If so, we should make sure to treat it as such
if all_vargs_const
atypes[nargs + 1] = builtin_tfunction(tuple, inf_result.vargs, sv)
end
end
for i in 1:nargs
a = maybe_widen_conditional(argtypes[i])
if a isa Const
atypes[i] = a # inject Const argtypes into inference
end
end
inf_result = InferenceResult(code, argtypes)
frame = InferenceState(inf_result, #=cache=#false, sv.params)
frame.limited = true
frame.parent = sv
Expand Down Expand Up @@ -368,6 +343,10 @@ end
# Union of Tuples of the same length is converted to Tuple of Unions.
# returns an array of types
function precise_container_type(@nospecialize(arg), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
if isa(typ, PartialTuple)
return typ.fields
end

if isa(typ, Const)
val = typ.val
if isa(val, SimpleVector) || isa(val, Tuple)
Expand All @@ -376,14 +355,9 @@ function precise_container_type(@nospecialize(arg), @nospecialize(typ), vtypes::
end

arg = ssa_def_expr(arg, sv)
if is_specializable_vararg_slot(arg, sv.nargs, sv.result.vargs)
return sv.result.vargs
end

tti0 = widenconst(typ)
tti = unwrap_unionall(tti0)
if isa(arg, Expr) && arg.head === :call && (abstract_evals_to_constant(arg.args[1], svec, vtypes, sv) ||
abstract_evals_to_constant(arg.args[1], tuple, vtypes, sv))
if isa(arg, Expr) && arg.head === :call && abstract_evals_to_constant(arg.args[1], svec, vtypes, sv)
aa = arg.args
result = Any[ abstract_eval(aa[j],vtypes,sv) for j=2:length(aa) ]
if _any(isvarargtype, result)
Expand Down Expand Up @@ -1061,7 +1035,7 @@ function typeinf_local(frame::InferenceState)
elseif hd === :return
pc´ = n + 1
rt = maybe_widen_conditional(abstract_eval(stmt.args[1], s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type)
if !isa(rt, Const) && !isa(rt, Type) && (!isa(rt, PartialTuple) || frame.cached)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
Expand Down
175 changes: 105 additions & 70 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,122 @@ const EMPTY_VECTOR = Vector{Any}()

mutable struct InferenceResult
linfo::MethodInstance
args::Vector{Any}
vargs::Vector{Any} # Memoize vararg type info w/Consts here when calling get_argtypes
# on the InferenceResult, so that the optimizer can use this info
# later during inlining.
argtypes::Vector{Any}
overridden_by_const::BitVector
result # ::Type, or InferenceState if WIP
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
function InferenceResult(linfo::MethodInstance)
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
if isdefined(linfo, :inferred_const)
result = Const(linfo.inferred_const)
else
result = linfo.rettype
end
return new(linfo, EMPTY_VECTOR, Any[], result, nothing)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
return new(linfo, argtypes, overridden_by_const, result, nothing)
end
end

function get_argtypes(result::InferenceResult)
result.args === EMPTY_VECTOR || return result.args # already cached
argtypes, vargs = get_argtypes(result.linfo)
result.args = argtypes
if vargs !== nothing
result.vargs = vargs
function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
overridden_by_const::Bool)
if isa(given_argtype, Const) || isa(given_argtype, PartialTuple)
return is_lattice_equal(given_argtype, cache_argtype)
end
return argtypes
return !overridden_by_const
end

function get_argtypes(linfo::MethodInstance)
# In theory, there could be a `cache` containing a matching `InferenceResult`
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
# so that we can construct cache-correct `InferenceResult`s in the first place.
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector)
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
@assert length(given_argtypes) >= (nargs - 1)
given_argtypes = anymap(maybe_widen_conditional, given_argtypes)
if linfo.def.isva
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - 1)
isva_given_argtypes[i] = given_argtypes[i]
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[nargs:end])
given_argtypes = isva_given_argtypes
end
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing)
if nargs === length(given_argtypes)
for i in 1:nargs
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
# prefer the argtype we were given over the one computed from `linfo`
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
end
end
end
return cache_argtypes, overridden_by_const
end

function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
toplevel = !isa(linfo.def, Method)
atypes::SimpleVector = unwrap_unionall(linfo.specTypes).parameters
linfo_argtypes = Any[unwrap_unionall(linfo.specTypes).parameters...]
nargs::Int = toplevel ? 0 : linfo.def.nargs
args = Vector{Any}(undef, nargs)
vargs = nothing
cache_argtypes = Vector{Any}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialTuple` instance.
if !toplevel && linfo.def.isva
if linfo.specTypes == Tuple
if nargs > 1
atypes = svec(Any[ Any for i = 1:(nargs - 1) ]..., Tuple.parameters[1])
linfo_argtypes = svec(Any[Any for i = 1:(nargs - 1)]..., Tuple.parameters[1])
end
vararg_type = Tuple
vargtype = Tuple
else
laty = length(atypes)
if nargs > laty
va = atypes[laty]
linfo_argtypes_length = length(linfo_argtypes)
if nargs > linfo_argtypes_length
va = linfo_argtypes[linfo_argtypes_length]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes)
vararg_type_vec = Any[new_va]
vararg_type = Tuple{new_va}
vargtype_elements = Any[new_va]
vargtype = Tuple{new_va}
else
vararg_type_vec = Any[]
vararg_type = Tuple{}
vargtype_elements = Any[]
vargtype = Tuple{}
end
else
vararg_type_vec = Any[]
for p in atypes[nargs:laty]
vargtype_elements = Any[]
for p in linfo_argtypes[nargs:linfo_argtypes_length]
p = isvarargtype(p) ? unconstrain_vararg_length(p) : p
push!(vararg_type_vec, rewrap_unionall(p, linfo.specTypes))
push!(vargtype_elements, rewrap(p, linfo.specTypes))
end
vararg_type = tuple_tfunc(Tuple{vararg_type_vec...})
for i in 1:length(vararg_type_vec)
atyp = vararg_type_vec[i]
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
if isa(atyp, DataType) && isdefined(atyp, :instance)
# replace singleton types with their equivalent Const object
vararg_type_vec[i] = Const(atyp.instance)
vargtype_elements[i] = Const(atyp.instance)
elseif isconstType(atyp)
vararg_type_vec[i] = Const(atyp.parameters[1])
vargtype_elements[i] = Const(atyp.parameters[1])
end
end
vargtype = tuple_tfunc(vargtype_elements)
end
vargs = vararg_type_vec
end
args[nargs] = vararg_type
cache_argtypes[nargs] = vargtype
nargs -= 1
end
laty = length(atypes)
if laty > 0
if laty > nargs
laty = nargs
end
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
# type info as we go (where possible). Note that if we're dealing with a varargs method,
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
# we don't overwrite the result of that work here).
linfo_argtypes_length = length(linfo_argtypes)
if linfo_argtypes_length > 0
n = linfo_argtypes_length > nargs ? nargs : linfo_argtypes_length
tail_index = n
local lastatype
atail = laty
for i = 1:laty
atyp = atypes[i]
if i == laty && isvarargtype(atyp)
for i = 1:n
atyp = linfo_argtypes[i]
if i == n && isvarargtype(atyp)
atyp = unwrapva(atyp)
atail -= 1
tail_index -= 1
end
while isa(atyp, TypeVar)
atyp = atyp.ub
Expand All @@ -98,42 +130,45 @@ function get_argtypes(linfo::MethodInstance)
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = rewrap_unionall(atyp, linfo.specTypes)
atyp = rewrap(atyp, linfo.specTypes)
end
i == laty && (lastatype = atyp)
args[i] = atyp
i == n && (lastatype = atyp)
cache_argtypes[i] = atyp
end
for i = (atail + 1):nargs
args[i] = lastatype
for i = (tail_index + 1):nargs
cache_argtypes[i] = lastatype
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
end
return args, vargs
return cache_argtypes, falses(length(cache_argtypes))
end

function cache_lookup(code::MethodInstance, argtypes::Vector{Any}, cache::Vector{InferenceResult})
method = code.def::Method
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
method = linfo.def::Method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
for cache_code in cache
# try to search cache first
cache_args = cache_code.args
cache_vargs = cache_code.vargs
if cache_code.linfo === code && length(argtypes) === (length(cache_vargs) + nargs)
cache_match = true
for i in 1:length(argtypes)
a = maybe_widen_conditional(argtypes[i])
ca = i <= nargs ? cache_args[i] : cache_vargs[i - nargs]
# verify that all Const argument types match between the call and cache
if (isa(a, Const) || isa(ca, Const)) && !(a === ca)
cache_match = false
break
end
length(given_argtypes) >= nargs || return nothing
for cached_result in cache
cached_result.linfo === linfo || continue
cache_match = true
cache_argtypes = cached_result.argtypes
cache_overridden_by_const = cached_result.overridden_by_const
for i in 1:nargs
if !is_argtype_match(given_argtypes[i],
cache_argtypes[i],
cache_overridden_by_const[i])
cache_match = false
break
end
cache_match || continue
return cache_code
end
if method.isva && cache_match
cache_match = is_argtype_match(tuple_tfunc(given_argtypes[(nargs + 1):end]),
cache_argtypes[end],
cache_overridden_by_const[end])
end
cache_match || continue
return cached_result
end
return nothing
end
2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mutable struct InferenceState

# initial types
nslots = length(src.slotnames)
argtypes = get_argtypes(result)
argtypes = result.argtypes
nargs = length(argtypes)
s_argtypes = VarTable(undef, nslots)
slottypes = Vector{Any}(undef, nslots)
Expand Down
6 changes: 2 additions & 4 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

mutable struct OptimizationState
linfo::MethodInstance
result_vargs::Vector{Any}
calledges::Vector{Any}
src::CodeInfo
mod::Module
Expand All @@ -24,7 +23,7 @@ mutable struct OptimizationState
frame.stmt_edges[1] = s_edges
end
src = frame.src
return new(frame.linfo, frame.result.vargs,
return new(frame.linfo,
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
frame.min_valid, frame.max_valid,
Expand All @@ -51,8 +50,7 @@ mutable struct OptimizationState
inmodule = linfo.def::Module
nargs = 0
end
result_vargs = Any[] # if you want something more accurate, set it yourself :P
return new(linfo, result_vargs,
return new(linfo,
s_edges::Vector{Any},
src, inmodule, nargs,
min_world(linfo), max_world(linfo),
Expand Down
Loading

0 comments on commit 55b500c

Please sign in to comment.