Skip to content

Commit

Permalink
Refactor compiler method lookup interface (JuliaLang#36743)
Browse files Browse the repository at this point in the history
The primary motivation here is to clean up the notion of
"looking up a method in the method table" into a single
object that can be passed around. At the moment, it's
all a bit murky, with fairly large state objects being
passed around everywhere and implicit accesses to the
global environment. In my AD use case, I need to be a
bit careful to make sure the various inference and
optimization steps are looking things up in the correct
tables/caches, so being very explicit about where things
need to be looked up is quite helpful.

In particular, I would like to clean up the optimizer, to
not require the big `OptimizationState` which is currently
a bit of a mix of things that go into IRCode and information
needed for method lookup/edge tracking. That isn't part of
this PR, but will build on top of it.

More generally, with a bunch of the recent compiler work,
I've been trying to define more crisp boundaries between
the various components of the system, giving them clearer
interfaces, and at least a little bit of documentation.
The compiler is a very powerful bit of technology, but I
think people have been avoiding it, because the code
looks a bit scary. I'm hoping some of these cleanups will
make it easier for people to understand what's going on.
Here in particular, I'm using `findall(sig, table)` as
the predicate for method lookup. The idea being that
people familiar with the `findall(predicate, collection)`
idiom from regular julia will have a good intuitive
understanding of what's happening (a collection is searched
for a predicate), an array of matches is returned, etc.
Of course, it's not a perfect fit, but I think these
kinds of mental aids can be helpful in making it easier
for people to read compiler code (similar to how JuliaLang#35831
used `getindex` as the verb for cache lookup). While
I was at it, I also cleaned up the use of out-parameters
which leaked through too much of the underlying C API
and replaced them by a proper struct of results.
  • Loading branch information
Keno authored and simeonschaub committed Aug 11, 2020
1 parent 581b766 commit d4c63d2
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 129 deletions.
52 changes: 15 additions & 37 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,13 @@ const _REF_NAME = Ref.body.name
# Return `true` when the statement at `pc` is a `:call` expression and the entry
# of `frame.ssavalue_uses` for `pc` is empty. `pc` defaults to the frame's
# current program counter.
# Both lookups are keyed on `pc` so that an explicitly passed `pc` is answered
# for that statement rather than mixing it with `frame.currpc`.
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
    isexpr(frame.src.code[pc], :call) && isempty(frame.ssavalue_uses[pc])

# Look up the methods matching `atype`, memoizing the answer in `cache` under
# the key `atype`. The cached value is the tuple
# `(matches, min_world, max_world, has_ambiguity)`.
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt)
    # NOTE(review): the Box presumably keeps the closure below from being
    # specialized on `atype` -- confirm.
    boxed = Core.Box(atype)
    function lookup()
        # One-element arrays serve as out-parameters for the lookup.
        lo = UInt[typemin(UInt)]
        hi = UInt[typemax(UInt)]
        ambig_flag = Int32[0]
        found = _methods_by_ftype(boxed.contents, max_methods, world, false, lo, hi, ambig_flag)
        return (found, lo[1], hi[1], ambig_flag[1] != 0)
    end
    return get!(lookup, cache, atype)
end

# Variant of `matching_methods` that also narrows the caller's one-element
# `min_valid` / `max_valid` arrays to the world bounds reported by the lookup.
# Returns `(matches, has_ambiguity)`.
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt})
    found, lo, hi, has_ambig = matching_methods(atype, cache, max_methods, world)
    # Raise the lower bound and lower the upper bound only when the lookup is
    # more restrictive than what the caller already holds.
    if lo > min_valid[1]
        min_valid[1] = lo
    end
    if hi < max_valid[1]
        max_valid[1] = hi
    end
    return found, has_ambig
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.currpc in sv.throw_blocks
return CallMeta(Any, false)
end
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
valid_worlds = WorldRange()
atype_params = unwrap_unionall(atype).parameters
splitunions = 1 < countunionsplit(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING
mts = Core.MethodTable[]
Expand All @@ -56,15 +38,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
xapplicable, ambig = matching_methods(sig_n, sv.matching_methods_cache, max_methods,
get_world_counter(interp), min_valid, max_valid)
if xapplicable === false
matches = findall(sig_n, method_table(interp); limit=max_methods)
if matches === missing
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
return CallMeta(Any, false)
end
push!(infos, MethodMatchInfo(xapplicable, ambig))
append!(applicable, xapplicable)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, xapplicable)
push!(infos, MethodMatchInfo(matches))
append!(applicable, matches)
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
Expand All @@ -86,19 +68,20 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
applicable, ambig = matching_methods(atype, sv.matching_methods_cache, max_methods,
get_world_counter(interp), min_valid, max_valid)
if applicable === false
matches = findall(atype, method_table(interp, sv); limit=max_methods)
if matches === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
add_remark!(interp, sv, "Too many methods matched")
return CallMeta(Any, false)
end
push!(mts, mt)
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, applicable))
info = MethodMatchInfo(applicable, ambig)
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
info = MethodMatchInfo(matches)
applicable = matches.matches
valid_worlds = matches.valid_worlds
end
update_valid_age!(min_valid[1], max_valid[1], sv)
update_valid_age!(sv, valid_worlds)
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
rettype = Bottom
Expand Down Expand Up @@ -1460,12 +1443,7 @@ function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, caller)
no_active_ips_in_callers = false
end
if caller.min_valid < frame.min_valid
caller.min_valid = frame.min_valid
end
if caller.max_valid > frame.max_valid
caller.max_valid = frame.max_valid
end
caller.valid_worlds = intersect(caller.valid_worlds, frame.valid_worlds)
end
end
return true
Expand Down
32 changes: 24 additions & 8 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ end

const GLOBAL_CI_CACHE = InternalCodeCache()

"""
    struct WorldRange

An inclusive range of world ages, stored as its two endpoints `min_world` and
`max_world`.
"""
struct WorldRange
    min_world::UInt
    max_world::UInt
end

# With no arguments the range spans every representable `UInt` world age.
WorldRange() = WorldRange(typemin(UInt), typemax(UInt))
# A single world age is the range whose two endpoints are both that age.
WorldRange(w::UInt) = WorldRange(w, w)
WorldRange(r::UnitRange) = WorldRange(first(r), last(r))

first(wr::WorldRange) = wr.min_world
last(wr::WorldRange) = wr.max_world

function in(world::UInt, wr::WorldRange)
    return wr.min_world <= world && world <= wr.max_world
end

# Intersection of two world ranges: the larger lower bound paired with the
# smaller upper bound. The assertion fires when the two ranges do not overlap.
function intersect(a::WorldRange, b::WorldRange)
    lo = max(a.min_world, b.min_world)
    hi = min(a.max_world, b.max_world)
    ret = WorldRange(lo, hi)
    @assert ret.min_world <= ret.max_world
    return ret
end

"""
struct WorldView
Expand All @@ -22,20 +39,19 @@ range of world ages, rather than defaulting to the current active world age.
"""
struct WorldView{Cache}
cache::Cache
min_world::UInt
max_world::UInt
worlds::WorldRange
WorldView(cache::Cache, range::WorldRange) where Cache = new{Cache}(cache, range)
end
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
WorldView(cache, world::UInt) = WorldView(cache, world, world)
WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) =
WorldView(wvc.cache, min_world, max_world)
WorldView(cache, args...) = WorldView(cache, WorldRange(args...))
WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance}
if r === nothing
return default
end
Expand Down
4 changes: 3 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")

include("compiler/cicache.jl")
include("compiler/methodtable.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")
include("compiler/cicache.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
Expand Down
30 changes: 15 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ mutable struct InferenceState
# info on the state of inference and the linfo
src::CodeInfo
world::UInt
min_valid::UInt
max_valid::UInt
valid_worlds::WorldRange
nargs::Int
stmt_types::Vector{Any}
stmt_edges::Vector{Any}
Expand Down Expand Up @@ -44,9 +43,10 @@ mutable struct InferenceState
inferred::Bool
dont_work_on_me::Bool

# cached results of calling `_methods_by_ftype`, including `min_valid` and
# `max_valid`, to be used in inlining
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
# The place to look up methods while working on this function.
# In particular, we cache method lookup results for the same function to
# fast path repeated queries.
method_table::CachedMethodTable{InternalMethodTable}

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
Expand Down Expand Up @@ -100,13 +100,12 @@ mutable struct InferenceState
inmodule = linfo.def::Module
end

min_valid = src.min_world
max_valid = src.max_world == typemax(UInt) ?
get_world_counter() : src.max_world
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
src, get_world_counter(interp), min_valid, max_valid,
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Expand All @@ -115,14 +114,16 @@ mutable struct InferenceState
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false, false,
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(),
CachedMethodTable(method_table(interp)),
interp)
result.result = frame
cached && push!(get_inference_cache(interp), result)
return frame
end
end

# Method lookups made on behalf of the frame `sv` use the table stored on the
# frame itself; `interp` is not consulted by this method.
function method_table(interp::AbstractInterpreter, sv::InferenceState)
    return sv.method_table
end

function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
src = retrieve_code_info(result.linfo)
Expand Down Expand Up @@ -202,14 +203,13 @@ end
_topmod(sv::InferenceState) = _topmod(sv.mod)

# work towards converging the valid age range for sv
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update")
# Narrow `sv.valid_worlds` to its intersection with `worlds`, then check that
# the frame's own world is still inside the narrowed range.
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
    narrowed = intersect(worlds, sv.valid_worlds)
    sv.valid_worlds = narrowed
    @assert(sv.world in sv.valid_worlds, "invalid age range update")
    return nothing
end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv)
# Narrow `sv`'s valid world range by the range recorded on the frame `edge`.
function update_valid_age!(edge::InferenceState, sv::InferenceState)
    return update_valid_age!(sv, edge.valid_worlds)
end

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
old = frame.src.ssavaluetypes[ssa_id]
Expand Down
93 changes: 93 additions & 0 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Supertype of the method table views defined in this file; `findall` and
# `findsup` take one as the table to search.
abstract type MethodTableView end

# The outcome of one method lookup: the matches themselves, the world range
# the lookup reported, and whether an ambiguity was flagged.
struct MethodLookupResult
    # Really Vector{Core.MethodMatch}, but it's easier to represent this as
    # and work with Vector{Any} on the C side.
    matches::Vector{Any}
    valid_worlds::WorldRange
    ambig::Bool
end

# Collection interface: a result behaves like its `matches` vector, with each
# element asserted to be a `MethodMatch` on the way out.
length(result::MethodLookupResult) = length(result.matches)

function iterate(result::MethodLookupResult, args...)
    next = iterate(result.matches, args...)
    if next === nothing
        return nothing
    end
    return (next[1]::MethodMatch, next[2])
end

function getindex(result::MethodLookupResult, idx::Int)
    return getindex(result.matches, idx)::MethodMatch
end

"""
    struct InternalMethodTable <: MethodTableView

A view of the internal method table as it stands at one particular world age,
recorded in the `world` field.
"""
struct InternalMethodTable <: MethodTableView
    world::UInt
end

"""
    struct CachedMethodTable <: MethodTableView

Wraps another method table view `table` together with a local `cache`, so that
repeated, identical queries can be answered from the cache instead of going
back to the wrapped table.
"""
struct CachedMethodTable{T} <: MethodTableView
    cache::IdDict{Any, Union{Missing, MethodLookupResult}}
    table::T
end

# Wrap `table` with a cache that starts out empty.
function CachedMethodTable(table::T) where T
    empty_cache = IdDict{Any, Union{Missing, MethodLookupResult}}()
    return CachedMethodTable{T}(empty_cache, table)
end

"""
    findall(sig::Type{<:Tuple}, view::MethodTableView; limit::Int=Int(typemax(Int32)))

Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
"""
function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
    # NOTE(review): the default is capped at `typemax(Int32)` because
    # `_methods_by_ftype` is believed to hand `limit` to the runtime as a C
    # `int`; `typemax(Int)` would not fit -- confirm against reflection.jl.
    # The three references below are out-parameters filled in by the lookup.
    _min_val = RefValue{UInt}(typemin(UInt))
    _max_val = RefValue{UInt}(typemax(UInt))
    _ambig = RefValue{Int32}(0)
    ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig)
    if ms === false
        # A `false` result is translated to `missing` (too many matches).
        return missing
    end
    return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable; limit::Int=Int(typemax(Int32)))
    box = Core.Box(sig)
    # The answer depends on `limit` as well as on `sig` (a small limit can turn
    # a result into `missing`), so both are part of the cache key. Keying on
    # `sig` alone would hand a result computed under one limit to a query made
    # with a different one.
    return get!(table.cache, (sig, limit)) do
        findall(box.contents, table.table; limit=limit)
    end
end

"""
    findsup(sig::Type{<:Tuple}, view::MethodTableView)::Union{Tuple{Method, WorldRange}, Nothing}

Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried signature `sig`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.

Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned. Otherwise the method
is returned together with the world range reported by the lookup.
"""
function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable)
    # Out-parameters that the runtime lookup fills with the world bounds.
    min_valid = RefValue{UInt}(typemin(UInt))
    max_valid = RefValue{UInt}(typemax(UInt))
    result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
        sig, table.world, min_valid, max_valid)::Union{Method, Nothing}
    result === nothing && return nothing
    return (result, WorldRange(min_valid[], max_valid[]))
end

# This query is not cached
findsup(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable) = findsup(sig, table.table)
24 changes: 9 additions & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ mutable struct OptimizationState
mod::Module
nargs::Int
world::UInt
min_valid::UInt
max_valid::UInt
valid_worlds::WorldRange
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
# cached results of calling `_methods_by_ftype` from inference, including
# `min_valid` and `max_valid`
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
# TODO: This will be eliminated once optimization no longer needs to do method lookups
interp::AbstractInterpreter
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
Expand All @@ -33,9 +29,9 @@ mutable struct OptimizationState
return new(params, frame.linfo,
s_edges::Vector{Any},
src, frame.stmt_info, frame.mod, frame.nargs,
frame.world, frame.min_valid, frame.max_valid,
frame.world, frame.valid_worlds,
frame.sptypes, frame.slottypes, false,
frame.matching_methods_cache, interp)
interp)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
# prepare src for running optimization passes
Expand Down Expand Up @@ -64,9 +60,9 @@ mutable struct OptimizationState
return new(params, linfo,
s_edges::Vector{Any},
src, stmt_info, inmodule, nargs,
get_world_counter(), UInt(1), get_world_counter(),
get_world_counter(), WorldRange(UInt(1), get_world_counter()),
sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp)
interp)
end
end

Expand Down Expand Up @@ -110,11 +106,9 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)

_topmod(sv::OptimizationState) = _topmod(sv.mod)

function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.world <= sv.max_valid,
"invalid age range update")
# Narrow `sv.valid_worlds` to its intersection with `valid_worlds`, then check
# that the state's own world is still inside the narrowed range.
function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange)
    narrowed = intersect(sv.valid_worlds, valid_worlds)
    sv.valid_worlds = narrowed
    @assert(sv.world in sv.valid_worlds, "invalid age range update")
    return nothing
end

Expand All @@ -126,7 +120,7 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState)
end

function add_backedge!(li::CodeInstance, caller::OptimizationState)
update_valid_age!(min_world(li), max_world(li), caller)
update_valid_age!(caller, WorldRange(min_world(li), max_world(li)))
add_backedge!(li.def, caller)
nothing
end
Expand Down
Loading

0 comments on commit d4c63d2

Please sign in to comment.