Skip to content

Commit

Permalink
optimize construction of InferenceResult for constant inference
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Apr 8, 2024
1 parent 62df400 commit a5356a4
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 116 deletions.
40 changes: 23 additions & 17 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ const_prop_result(inf_result::InferenceResult) =
ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
inf_result.ipo_effects, inf_result.linfo)

# return cached constant analysis result
# return cached result of constant analysis
return_cached_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) =
const_prop_result(inf_result)

Expand All @@ -1249,8 +1249,15 @@ function const_prop_call(interp::AbstractInterpreter,
inf_cache = get_inference_cache(interp)
𝕃ᡒ = typeinf_lattice(interp)
argtypes = has_conditional(𝕃ᡒ, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes)
given_argtypes, overridden_by_const = matching_cache_argtypes(𝕃ᡒ, mi, argtypes)
inf_result = cache_lookup(𝕃ᡒ, mi, given_argtypes, inf_cache)
# use `cache_argtypes` that has been constructed for fresh regular inference if available
volatile_inf_result = result.volatile_inf_result
if volatile_inf_result !== nothing
cache_argtypes = volatile_inf_result.inf_result.argtypes
else
cache_argtypes = matching_cache_argtypes(𝕃ᡒ, mi)
end
argtypes = matching_cache_argtypes(𝕃ᡒ, mi, argtypes, cache_argtypes)
inf_result = cache_lookup(𝕃ᡒ, mi, argtypes, inf_cache)
if inf_result !== nothing
# found the cache for this constant prop'
if inf_result.result === nothing
Expand All @@ -1260,12 +1267,18 @@ function const_prop_call(interp::AbstractInterpreter,
@assert inf_result.linfo === mi "MethodInstance for cached inference result does not match"
return return_cached_result(interp, inf_result, sv)
end
# perform fresh constant prop'
inf_result = InferenceResult(mi, given_argtypes, overridden_by_const)
if !any(inf_result.overridden_by_const)
overridden_by_const = falses(length(argtypes))
for i = 1:length(argtypes)
if argtypes[i] !== cache_argtypes[i]
overridden_by_const[i] = true
end
end
if !any(overridden_by_const)
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
return nothing
end
# perform fresh constant prop'
inf_result = InferenceResult(mi, argtypes, overridden_by_const)
frame = InferenceState(inf_result, #=cache_mode=#:local, interp)
if frame === nothing
add_remark!(interp, sv, "[constprop] Could not retrieve the source")
Expand All @@ -1287,26 +1300,19 @@ end

# TODO implement MustAlias forwarding

struct ConditionalArgtypes <: ForwardableArgtypes
struct ConditionalArgtypes
arginfo::ArgInfo
sv::InferenceState
end

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
conditional_argtypes::ConditionalArgtypes)
The implementation is able to forward `Conditional` of `conditional_argtypes`,
as well as the other general extended lattice information.
"""
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
conditional_argtypes::ConditionalArgtypes)
conditional_argtypes::ConditionalArgtypes,
cache_argtypes::Vector{Any})
(; arginfo, sv) = conditional_argtypes
(; fargs, argtypes) = arginfo
given_argtypes = Vector{Any}(undef, length(argtypes))
def = mi.def::Method
nargs = Int(def.nargs)
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
local condargs = nothing
for i in 1:length(argtypes)
argtype = argtypes[i]
Expand Down Expand Up @@ -1349,7 +1355,7 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
else
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
end
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

# This is only for use with `Conditional`.
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ include("compiler/ssair/ir.jl")
include("compiler/ssair/tarjan.jl")

include("compiler/abstractlattice.jl")
include("compiler/stmtinfo.jl")
include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
include("compiler/tfuncs.jl")
include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
Expand Down
110 changes: 33 additions & 77 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,30 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance) ->
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
Returns argument types `cache_argtypes::Vector{Any}` for `mi` that are in the native
Julia type domain. `overridden_by_const::BitVector` is all `false` meaning that
there is no additional extended lattice information there.
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::ForwardableArgtypes) ->
(cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
Returns cache-correct extended lattice argument types `cache_argtypes::Vector{Any}`
for `mi` given some `argtypes` accompanied by `overridden_by_const::BitVector`
that marks which argument contains additional extended lattice information.
In theory, there could be a `cache` containing a matching `InferenceResult`
for the provided `mi` and `given_argtypes`. The purpose of this function is
to return a valid value for `cache_lookup(𝕃, mi, argtypes, cache).argtypes`,
so that we can construct cache-correct `InferenceResult`s in the first place.
"""
function matching_cache_argtypes end

function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance)
method = isa(mi.def, Method) ? mi.def::Method : nothing
cache_argtypes = most_general_argtypes(method, mi.specTypes)
overridden_by_const = falses(length(cache_argtypes))
return cache_argtypes, overridden_by_const
(; def, specTypes) = mi
return most_general_argtypes(isa(def, Method) ? def : nothing, specTypes)
end

struct SimpleArgtypes <: ForwardableArgtypes
struct SimpleArgtypes
argtypes::Vector{Any}
end

"""
matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, argtypes::SimpleArgtypes)
The implementation for `argtypes` with general extended lattice information.
This is supposed to be used for debugging and testing or external `AbstractInterpreter`
usages and in general `matching_cache_argtypes(::MethodInstance, ::ConditionalArgtypes)`
is more preferred it can forward `Conditional` information.
"""
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance, simple_argtypes::SimpleArgtypes)
function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
simple_argtypes::SimpleArgtypes,
cache_argtypes::Vector{Any})
(; argtypes) = simple_argtypes
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
return pick_const_args(𝕃, mi, given_argtypes)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

function pick_const_args(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any})
cache_argtypes, overridden_by_const = matching_cache_argtypes(𝕃, mi)
return pick_const_args!(𝕃, cache_argtypes, overridden_by_const, given_argtypes)
end

function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, overridden_by_const::BitVector, given_argtypes::Vector{Any})
for i = 1:length(given_argtypes)
function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
nargtypes = length(given_argtypes)
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
for i = 1:nargtypes
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
Expand All @@ -66,13 +33,13 @@ function pick_const_args!(𝕃::AbstractLattice, cache_argtypes::Vector{Any}, ov
!⊏(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
end
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
else
given_argtypes[i] = cache_argtype
end
end
return cache_argtypes, overridden_by_const
return given_argtypes
end

function is_argtype_match(𝕃::AbstractLattice,
Expand All @@ -89,9 +56,9 @@ end
va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
def = mi.def
isva = isa(def, Method) ? def.isva : false
nargs = isa(def, Method) ? Int(def.nargs) : length(mi.specTypes.parameters)
def = mi.def::Method
isva = def.isva
nargs = Int(def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs-isva)
Expand All @@ -112,14 +79,11 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
return given_argtypes
end

function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes),
withfirst::Bool = true)
function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
toplevel = method === nothing
isva = !toplevel && method.isva
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
nargs::Int = toplevel ? 0 : method.nargs
# For opaque closure, the closure environment is processed elsewhere
withfirst || (nargs -= 1)
cache_argtypes = Vector{Any}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
Expand Down Expand Up @@ -162,17 +126,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
cache_argtypes[nargs] = vargtype
nargs -= 1
end
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
# type info as we go (where possible). Note that if we're dealing with a varargs method,
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
# we don't overwrite the result of that work here).
if mi_argtypes_length > 0
n = mi_argtypes_length > nargs ? nargs : mi_argtypes_length
tail_index = n
tail_index = nargtypes = min(mi_argtypes_length, nargs)
local lastatype
for i = 1:n
for i = 1:nargtypes
atyp = mi_argtypes[i]
if i == n && isvarargtype(atyp)
if i == nargtypes && isvarargtype(atyp)
atyp = unwrapva(atyp)
tail_index -= 1
end
Expand All @@ -185,16 +148,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == n && (lastatype = atyp)
i == nargtypes && (lastatype = atyp)
cache_argtypes[i] = atyp
end
for i = (tail_index + 1):nargs
for i = (tail_index+1):nargs
cache_argtypes[i] = lastatype
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
end
cache_argtypes
return cache_argtypes
end

# eliminate free `TypeVar`s in order to make the life much easier down the road:
Expand All @@ -213,22 +176,15 @@ end
function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any},
cache::Vector{InferenceResult})
method = mi.def::Method
nargs = Int(method.nargs)
method.isva && (nargs -= 1)
length(given_argtypes) β‰₯ nargs || return nothing
nargtypes = length(given_argtypes)
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
for cached_result in cache
cached_result.linfo === mi || continue
cached_result.linfo === mi || @goto next_cache
cache_argtypes = cached_result.argtypes
cache_overridden_by_const = cached_result.overridden_by_const
for i in 1:nargs
if !is_argtype_match(𝕃, widenmustalias(given_argtypes[i]),
cache_argtypes[i], cache_overridden_by_const[i])
@goto next_cache
end
end
if method.isva
if !is_argtype_match(𝕃, tuple_tfunc(𝕃, given_argtypes[(nargs + 1):end]),
cache_argtypes[end], cache_overridden_by_const[end])
@assert length(cache_argtypes) == nargtypes "invalid `cache_argtypes` for `mi`"
cache_overridden_by_const = cached_result.overridden_by_const::BitVector
for i in 1:nargtypes
if !is_argtype_match(𝕃, given_argtypes[i], cache_argtypes[i], cache_overridden_by_const[i])
@goto next_cache
end
end
Expand Down
5 changes: 4 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,10 @@ end
frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState}
frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState}

is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const)
function is_constproped(sv::InferenceState)
(;overridden_by_const) = sv.result
return overridden_by_const !== nothing
end
is_constproped(::IRInterpretationState) = true

is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ the original `ci::CodeInfo` are modified.
"""
function inflate_ir!(ci::CodeInfo, mi::MethodInstance)
sptypes = sptypes_from_meth_instance(mi)
argtypes, _ = matching_cache_argtypes(fallback_lattice, mi)
argtypes = matching_cache_argtypes(fallback_lattice, mi)
return inflate_ir!(ci, sptypes, argtypes)
end
function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any})
Expand Down
8 changes: 3 additions & 5 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ struct EdgeCallResult
end
end

# return cached regular inference result
# return cached result of regular inference
function return_cached_result(::AbstractInterpreter, codeinst::CodeInstance, caller::AbsIntState)
rt = cached_return_type(codeinst)
effects = ipo_effects(codeinst)
Expand Down Expand Up @@ -869,10 +869,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
effects = isinferred ? frame.result.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
# propagate newly inferred source to the inliner, allowing efficient inlining w/o deserialization:
# note that this result is cached globally exclusively, we can use this local result destructively
volatile_inf_result = (isinferred && (force_inline ||
src_inlining_policy(interp, result.src, NoCallInfo(), IR_FLAG_NULL))) ?
VolatileInferenceResult(result) : nothing
# note that this result is cached globally exclusively, so we can use this local result destructively
volatile_inf_result = isinferred ? VolatileInferenceResult(result) : nothing
return EdgeCallResult(frame.bestguess, exc_bestguess, edge, effects, volatile_inf_result)
elseif frame === true
# unresolvable cycle
Expand Down
29 changes: 16 additions & 13 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ struct VarState
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
end

abstract type ForwardableArgtypes end

struct AnalysisResults
result
next::AnalysisResults
Expand All @@ -70,16 +68,19 @@ end
const NULL_ANALYSIS_RESULTS = AnalysisResults(nothing)

"""
InferenceResult(mi::MethodInstance, [argtypes::ForwardableArgtypes, 𝕃::AbstractLattice])
result::InferenceResult
A type that represents the result of running type inference on a chunk of code.
See also [`matching_cache_argtypes`](@ref).
There are two constructor available:
- `InferenceResult(mi::MethodInstance, [𝕃::AbstractLattice])` for regular inference,
without extended lattice information included in `result.argtypes`.
- `InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::BitVector)`
for constant inference, with extended lattice information included in `result.argtypes`.
"""
mutable struct InferenceResult
const linfo::MethodInstance
const argtypes::Vector{Any}
const overridden_by_const::BitVector
const overridden_by_const::Union{Nothing,BitVector}
result # extended lattice element if inferred, nothing otherwise
exc_result # like `result`, but for the thrown value
src # ::Union{CodeInfo, IRCode, OptimizationState} if inferred copy is available, nothing otherwise
Expand All @@ -89,16 +90,18 @@ mutable struct InferenceResult
analysis_results::AnalysisResults # AnalysisResults with e.g. result::ArgEscapeCache if optimized, otherwise NULL_ANALYSIS_RESULTS
is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively
ci::CodeInstance # CodeInstance if this result has been added to the cache
function InferenceResult(mi::MethodInstance, cache_argtypes::Vector{Any}, overridden_by_const::BitVector)
# def = mi.def
# nargs = def isa Method ? Int(def.nargs) : 0
# @assert length(cache_argtypes) == nargs
return new(mi, cache_argtypes, overridden_by_const, nothing, nothing, nothing,
function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector})
def = mi.def
nargs = def isa Method ? Int(def.nargs) : 0
@assert length(argtypes) == nargs "invalid `argtypes` for `mi`"
return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing,
WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false)
end
end
InferenceResult(mi::MethodInstance, 𝕃::AbstractLattice=fallback_lattice) =
InferenceResult(mi, matching_cache_argtypes(𝕃, mi)...)
function InferenceResult(mi::MethodInstance, 𝕃::AbstractLattice=fallback_lattice)
argtypes = matching_cache_argtypes(𝕃, mi)
return InferenceResult(mi, argtypes, #=overridden_by_const=#nothing)
end

function stack_analysis_result!(inf_result::InferenceResult, @nospecialize(result))
return inf_result.analysis_results = AnalysisResults(result, inf_result.analysis_results)
Expand Down
Loading

0 comments on commit a5356a4

Please sign in to comment.