Skip to content

Commit

Permalink
Refactor compiler method lookup interface
Browse files Browse the repository at this point in the history
The primary motivation here is to clean up the notion of
"looking up a method in the method table" into a single
object that can be passed around. At the moment, it's
all a bit murky, with fairly large state objects being
passed around everywhere and implicit accesses to the
global environment. In my AD use case, I need to be a
bit careful to make sure the various inference and
optimization steps are looking things up in the correct
tables/caches, so being very explicit about where things
need to be looked up is quite helpful.

In particular, I would like to clean up the optimizer, to
not require the big `OptimizationState` which is currently
a bit of a mix of things that go into IRCode and information
needed for method lookup/edge tracking. That isn't part of
this PR, but will build on top of it.

More generally, with a bunch of the recent compiler work,
I've been trying to define more crisp boundaries between
the various components of the system, giving them clearer
interfaces, and at least a little bit of documentation.
The compiler is a very powerful bit of technology, but I
think people have been avoiding it, because the code
looks a bit scary. I'm hoping some of these cleanups will
make it easier for people to understand what's going on.
Here in particular, I'm using `findall(sig, table)` as
the predicate for method lookup. The idea being that
people familiar with the `findall(predicate, collection)`
idiom from regular julia will have a good intuitive
understanding of what's happening (a collection is searched
for a predicate), an array of matches is returned, etc.
Of course, it's not a perfect fit, but I think these
kinds of mental aids can be helpful in making it easier
for people to read compiler code (similar to how #35831
used `getindex` as the verb for cache lookup). While
I was at it, I also cleaned up the use of out-parameters
which leaked through too much of the underlying C API
and replaced them by a proper struct of results.
  • Loading branch information
Keno committed Jul 21, 2020
1 parent e4a3329 commit 3423bde
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 129 deletions.
52 changes: 15 additions & 37 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,13 @@ const _REF_NAME = Ref.body.name
# Return `true` when the statement at `pc` is a `:call` expression and no uses are
# recorded for it in `frame.ssavalue_uses`. `pc` defaults to the frame's current
# program counter (`frame.currpc`).
# The statement lookup indexes with `pc` rather than `frame.currpc`, so both halves of
# the check refer to the same statement when a caller passes an explicit `pc`.
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
    isexpr(frame.src.code[pc], :call) && isempty(frame.ssavalue_uses[pc])

# Look up the methods matching `atype` in `world`, memoizing the answer in `cache`
# under the key `atype`. The stored tuple is
# `(result of _methods_by_ftype, min valid world, max valid world, ambiguity flag)`.
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt)
    # The closure below captures this box rather than `atype` itself.
    boxed = Core.Box(atype)
    return get!(cache, atype) do
        # One-element arrays serve as out-parameters for the lookup.
        lo = UInt[typemin(UInt)]
        hi = UInt[typemax(UInt)]
        ambig_flag = Int32[0]
        found = _methods_by_ftype(boxed.contents, max_methods, world, false, lo, hi, ambig_flag)
        return found, lo[1], hi[1], ambig_flag[1] != 0
    end
end

# Variant of the cached lookup that also narrows the caller's validity bounds in place:
# `min_valid[1]` is raised and `max_valid[1]` is lowered to the bounds reported by the
# lookup. Returns the lookup result together with the ambiguity flag.
function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt})
    found, lo, hi, has_ambig = matching_methods(atype, cache, max_methods, world)
    min_valid[1] = max(min_valid[1], lo)
    max_valid[1] = min(max_valid[1], hi)
    return found, has_ambig
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.currpc in sv.throw_blocks
return CallMeta(Any, false)
end
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
valid_worlds = WorldRange()
atype_params = unwrap_unionall(atype).parameters
splitunions = 1 < countunionsplit(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING
mts = Core.MethodTable[]
Expand All @@ -56,15 +38,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
xapplicable, ambig = matching_methods(sig_n, sv.matching_methods_cache, max_methods,
get_world_counter(interp), min_valid, max_valid)
if xapplicable === false
matches = findall(sig_n, method_table(interp); limit=max_methods)
if matches === missing
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
return CallMeta(Any, false)
end
push!(infos, MethodMatchInfo(xapplicable, ambig))
append!(applicable, xapplicable)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, xapplicable)
push!(infos, MethodMatchInfo(matches))
append!(applicable, matches)
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
Expand All @@ -86,19 +68,20 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(Any, false)
end
mt = mt::Core.MethodTable
applicable, ambig = matching_methods(atype, sv.matching_methods_cache, max_methods,
get_world_counter(interp), min_valid, max_valid)
if applicable === false
matches = findall(atype, method_table(interp, sv); limit=max_methods)
if matches === missing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
add_remark!(interp, sv, "Too many methods matched")
return CallMeta(Any, false)
end
push!(mts, mt)
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, applicable))
info = MethodMatchInfo(applicable, ambig)
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
info = MethodMatchInfo(matches)
applicable = matches.matches
valid_worlds = matches.valid_worlds
end
update_valid_age!(min_valid[1], max_valid[1], sv)
update_valid_age!(sv, valid_worlds)
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
rettype = Bottom
Expand Down Expand Up @@ -1460,12 +1443,7 @@ function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, caller)
no_active_ips_in_callers = false
end
if caller.min_valid < frame.min_valid
caller.min_valid = frame.min_valid
end
if caller.max_valid > frame.max_valid
caller.max_valid = frame.max_valid
end
caller.valid_worlds = intersect(caller.valid_worlds, frame.valid_worlds)
end
end
return true
Expand Down
32 changes: 24 additions & 8 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ end

const GLOBAL_CI_CACHE = InternalCodeCache()

# A pair of world-age bounds, `min_world` and `max_world`. Membership (`in`) treats both
# bounds as inclusive.
struct WorldRange
    min_world::UInt
    max_world::UInt
end

# No arguments: the range covering every representable world age.
function WorldRange()
    return WorldRange(typemin(UInt), typemax(UInt))
end

# A single world age: the range whose two bounds are both `w`.
function WorldRange(w::UInt)
    return WorldRange(w, w)
end

# A unit range: take its first and last elements as the bounds.
function WorldRange(r::UnitRange)
    return WorldRange(first(r), last(r))
end

# Lower and upper bound accessors.
function first(wr::WorldRange)
    return wr.min_world
end

function last(wr::WorldRange)
    return wr.max_world
end

# `world in wr` holds when `world` lies between the two bounds, inclusive.
function in(world::UInt, wr::WorldRange)
    return wr.min_world <= world && world <= wr.max_world
end

# Intersection of two world ranges: the larger of the lower bounds and the smaller of
# the upper bounds. An empty intersection trips the assertion.
function intersect(a::WorldRange, b::WorldRange)
    lower = max(a.min_world, b.min_world)
    upper = min(a.max_world, b.max_world)
    ret = WorldRange(lower, upper)
    @assert ret.min_world <= ret.max_world
    return ret
end

"""
struct WorldView
Expand All @@ -22,20 +39,19 @@ range of world ages, rather than defaulting to the current active world age.
"""
struct WorldView{Cache}
cache::Cache
min_world::UInt
max_world::UInt
worlds::WorldRange
WorldView(cache::Cache, range::WorldRange) where Cache = new{Cache}(cache, range)
end
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
WorldView(cache, world::UInt) = WorldView(cache, world, world)
WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) =
WorldView(wvc.cache, min_world, max_world)
WorldView(cache, args...) = WorldView(cache, WorldRange(args...))
WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance}
if r === nothing
return default
end
Expand Down
4 changes: 3 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")

include("compiler/cicache.jl")
include("compiler/methodtable.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")
include("compiler/cicache.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
Expand Down
30 changes: 15 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ mutable struct InferenceState
# info on the state of inference and the linfo
src::CodeInfo
world::UInt
min_valid::UInt
max_valid::UInt
valid_worlds::WorldRange
nargs::Int
stmt_types::Vector{Any}
stmt_edges::Vector{Any}
Expand Down Expand Up @@ -44,9 +43,10 @@ mutable struct InferenceState
inferred::Bool
dont_work_on_me::Bool

# cached results of calling `_methods_by_ftype`, including `min_valid` and
# `max_valid`, to be used in inlining
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
# The place to look up methods while working on this function.
# In particular, we cache method lookup results for the same function to
# fast path repeated queries.
method_table::CachedMethodTable{InternalMethodTable}

# The interpreter that created this inference state. Not looked at by
# NativeInterpreter. But other interpreters may use this to detect cycles
Expand Down Expand Up @@ -100,13 +100,12 @@ mutable struct InferenceState
inmodule = linfo.def::Module
end

min_valid = src.min_world
max_valid = src.max_world == typemax(UInt) ?
get_world_counter() : src.max_world
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
frame = new(
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
src, get_world_counter(interp), min_valid, max_valid,
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Expand All @@ -115,14 +114,16 @@ mutable struct InferenceState
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false, false,
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(),
CachedMethodTable(method_table(interp)),
interp)
result.result = frame
cached && push!(get_inference_cache(interp), result)
return frame
end
end

# Method table view to use while working on the frame `sv`: the one stored in
# `sv.method_table`. `interp` is not consulted by this method.
function method_table(interp::AbstractInterpreter, sv::InferenceState)
    return sv.method_table
end

function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
src = retrieve_code_info(result.linfo)
Expand Down Expand Up @@ -202,14 +203,13 @@ end
# Forward to `_topmod` on the module recorded in the frame's `mod` field.
function _topmod(sv::InferenceState)
    return _topmod(sv.mod)
end

# work towards converging the valid age range for sv
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update")
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
sv.valid_worlds = intersect(worlds, sv.valid_worlds)
@assert(sv.world in sv.valid_worlds, "invalid age range update")
nothing
end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv)
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
old = frame.src.ssavaluetypes[ssa_id]
Expand Down
93 changes: 93 additions & 0 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Common supertype of the method table views that `findall` and `findsup` accept.
abstract type MethodTableView end

# Result of a method lookup: the matches, the world range reported for the lookup, and
# whether the lookup reported an ambiguity.
struct MethodLookupResult
    # Really Vector{Core.MethodMatch}, but it's easier to represent this as
    # and work with Vector{Any} on the C side.
    matches::Vector{Any}
    valid_worlds::WorldRange
    ambig::Bool
end

# Number of matches held by the result.
function length(result::MethodLookupResult)
    return length(result.matches)
end

# Iterate over the matches, asserting that every element is a `MethodMatch`.
function iterate(result::MethodLookupResult, args...)
    next = iterate(result.matches, args...)
    if next === nothing
        return nothing
    end
    item, st = next
    return (item::MethodMatch, st)
end

# Index into the matches, asserting that the element is a `MethodMatch`.
function getindex(result::MethodLookupResult, idx::Int)
    return getindex(result.matches, idx)::MethodMatch
end

"""
struct InternalMethodTable <: MethodTableView
A struct representing the state of the internal method table at a
particular world age.
"""
struct InternalMethodTable <: MethodTableView
# World age handed to the underlying lookups performed through this view.
world::UInt
end

"""
struct CachedMethodTable <: MethodTableView
Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T} <: MethodTableView
    # Lookup answers already computed; `missing` answers are stored as well.
    cache::IdDict{Any, Union{Missing, MethodLookupResult}}
    # The wrapped table that answers queries not found in `cache`.
    table::T
end

# Wrap `table` with a fresh, empty cache.
function CachedMethodTable(table::T) where T
    empty_cache = IdDict{Any, Union{Missing, MethodLookupResult}}()
    return CachedMethodTable{T}(empty_cache, table)
end

"""
findall(sig::Type{<:Tuple}, view::MethodTableView; limit=typemax(Int))
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeds the specified limit,
`missing` is returned.
"""
function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; limit::Int=typemax(Int))
    # Out-parameters filled in by the lookup: world bounds and an ambiguity flag.
    min_world = RefValue{UInt}(typemin(UInt))
    max_world = RefValue{UInt}(typemax(UInt))
    ambig_flag = RefValue{Int32}(0)
    found = _methods_by_ftype(sig, limit, table.world, false, min_world, max_world, ambig_flag)
    # The lookup signals failure with `false`; report that as `missing`.
    found === false && return missing
    matches = found::Vector{Any}
    worlds = WorldRange(min_world[], max_world[])
    return MethodLookupResult(matches, worlds, ambig_flag[] != 0)
end

# Cached lookup: answer from `table.cache` when the same query has been made before;
# otherwise query the wrapped `table.table` and remember the answer (including a
# `missing` answer).
function findall(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable; limit::Int=typemax(Int))
    # The closure below captures this box rather than `sig` itself
    # (presumably to avoid specializing the closure on `sig`).
    box = Core.Box(sig)
    # The answer depends on `limit` as well as on `sig` (a small limit can turn a
    # successful lookup into `missing`), so both are part of the cache key. Keying on
    # `sig` alone would hand a query the answer computed for a different limit.
    return get!(table.cache, (sig, limit)) do
        findall(box.contents, table.table; limit=limit)
    end
end

"""
findsup(sig::Type{<:Tuple}, view::MethodTableView)::Union{Tuple{Method, WorldRange}, Nothing}
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.
Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
"""
function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable)
    # Out-parameters for the world bounds reported by the C lookup.
    lo = RefValue{UInt}(typemin(UInt))
    hi = RefValue{UInt}(typemax(UInt))
    found = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
        sig, table.world, lo, hi)::Union{Method, Nothing}
    if found === nothing
        return nothing
    end
    # Pair the method with the world range reported for the lookup.
    return (found, WorldRange(lo[], hi[]))
end

# This query is not cached: it is forwarded unchanged to the wrapped table.
function findsup(sig::Type{<:Tuple}, table::CachedMethodTable)
    return findsup(sig, table.table)
end
24 changes: 9 additions & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ mutable struct OptimizationState
mod::Module
nargs::Int
world::UInt
min_valid::UInt
max_valid::UInt
valid_worlds::WorldRange
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
# cached results of calling `_methods_by_ftype` from inference, including
# `min_valid` and `max_valid`
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}
# TODO: This will be eliminated once optimization no longer needs to do method lookups
interp::AbstractInterpreter
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
Expand All @@ -33,9 +29,9 @@ mutable struct OptimizationState
return new(params, frame.linfo,
s_edges::Vector{Any},
src, frame.stmt_info, frame.mod, frame.nargs,
frame.world, frame.min_valid, frame.max_valid,
frame.world, frame.valid_worlds,
frame.sptypes, frame.slottypes, false,
frame.matching_methods_cache, interp)
interp)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
# prepare src for running optimization passes
Expand Down Expand Up @@ -64,9 +60,9 @@ mutable struct OptimizationState
return new(params, linfo,
s_edges::Vector{Any},
src, stmt_info, inmodule, nargs,
get_world_counter(), UInt(1), get_world_counter(),
get_world_counter(), WorldRange(UInt(1), get_world_counter()),
sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp)
interp)
end
end

Expand Down Expand Up @@ -110,11 +106,9 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)

# Forward to `_topmod` on the module recorded in the optimization state's `mod` field.
function _topmod(sv::OptimizationState)
    return _topmod(sv.mod)
end

function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.world <= sv.max_valid,
"invalid age range update")
function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange)
sv.valid_worlds = intersect(sv.valid_worlds, valid_worlds)
@assert(sv.world in sv.valid_worlds, "invalid age range update")
nothing
end

Expand All @@ -126,7 +120,7 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState)
end

function add_backedge!(li::CodeInstance, caller::OptimizationState)
update_valid_age!(min_world(li), max_world(li), caller)
update_valid_age!(caller, WorldRange(min_world(li), max_world(li)))
add_backedge!(li.def, caller)
nothing
end
Expand Down
Loading

0 comments on commit 3423bde

Please sign in to comment.