Skip to content

Commit

Permalink
Refactor cache logic for easy replacement (JuliaLang#35831)
Browse files Browse the repository at this point in the history
* Refactor cache logic for easy replacement

This is the next step in the line of work started by JuliaLang#33955,
though a lot of enabling work towards this was previously done
by Jameson in his codegen-norecursion branch. The basic
thrust here is to allow external packages to manage their own
cache of compiled code that may have been generated using entirely
difference inference or compiler options. The GPU compilers are one
such example, but there are several others, including generating
code using offload compilers, such as XLA or compilers for secure
computation. A lot of this is just moving code arround to make it
clear exactly which parts of the code are accessing the internal
code cache (which is now its own type to make it obvious when
it's being accessed), as well as providing clear extension points
for custom cache implementations. The second part is to refactor
CodeInstance construction to separate construction and insertion
into the internal cache (so it can be inserted into an external
cache instead if desired). The last part of the change
is to give cgparams another hook that lets the caller replace
the cache lookup to be used by codegen.

* Update base/compiler/cicache.jl

Co-authored-by: Tim Besard <[email protected]>

* Apply suggestions from code review

Co-authored-by: Jameson Nash <[email protected]>

* Rename always_cache_tree -> !allow_discard_tree

Co-authored-by: Tim Besard <[email protected]>
Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
3 people authored and simeonschaub committed Aug 11, 2020
1 parent c5f2931 commit 684a3dc
Show file tree
Hide file tree
Showing 15 changed files with 227 additions and 100 deletions.
5 changes: 5 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ eval(Core, :(UpsilonNode(val) = $(Expr(:new, :UpsilonNode, :val))))
eval(Core, :(UpsilonNode() = $(Expr(:new, :UpsilonNode))))
eval(Core, :(LineInfoNode(@nospecialize(method), file::Symbol, line::Int, inlined_at::Int) =
$(Expr(:new, :LineInfoNode, :method, :file, :line, :inlined_at))))
eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32,
min_world::UInt, max_world::UInt) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))

Module(name::Symbol=:anonymous, std_imports::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool), name, std_imports)

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
mi = mi::MethodInstance
# decide if it's likely to be worthwhile
if !force_inference
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
cache_inlineable = declared_inline
if isdefined(code, :inferred) && !cache_inlineable
Expand Down
52 changes: 52 additions & 0 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
struct InternalCodeCache
Internally, each `MethodInstance` keep a unique global cache of code instances
that have been created for the given method instance, stratified by world age
ranges. This struct abstracts over access to this cache.
"""
struct InternalCodeCache
end

function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
end

const GLOBAL_CI_CACHE = InternalCodeCache()

"""
struct WorldView
Takes a given cache and provides access to the cache contents for the given
range of world ages, rather than defaulting to the current active world age.
"""
struct WorldView{Cache}
cache::Cache
min_world::UInt
max_world::UInt
end
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
WorldView(cache, world::UInt) = WorldView(cache, world, world)
WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) =
WorldView(wvc.cache, min_world, max_world)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
if r === nothing
return default
end
return r::CodeInstance
end

function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
r = get(wvc, mi, nothing)
r === nothing && throw(KeyError(mi))
return r::CodeInstance
end

setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance) =
setindex!(wvc.cache, ci, mi)
1 change: 1 addition & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ include("compiler/validation.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")
include("compiler/cicache.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...)
return (sig, state)
end

function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any}, sv::OptimizationState)
function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any})
if isa(case, ConstantCase)
ir[SSAValue(idx)] = case.val
elseif isa(case, MethodInstance)
Expand Down Expand Up @@ -949,7 +949,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
methsp = methsp::SimpleVector
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
calltype)
handle_single_case!(ir, stmt, idx, result, true, todo, sv)
handle_single_case!(ir, stmt, idx, result, true, todo)
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
return nothing
end
Expand Down Expand Up @@ -1117,7 +1117,7 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo, sv)
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo)
continue
end
length(cases) == 0 && continue
Expand Down Expand Up @@ -1332,7 +1332,7 @@ function find_inferred(mi::MethodInstance, @nospecialize(atypes), sv::Optimizati
end
end

linfo = inf_for_methodinstance(sv.interp, mi, sv.world)
linfo = get(WorldView(code_cache(sv.interp), sv.world), mi, nothing)
if linfo isa CodeInstance
if invoke_api(linfo) == 2
# in this case function can be inlined to a constant
Expand Down
124 changes: 71 additions & 53 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool)
frame = InferenceState(result, cached, interp)
frame === nothing && return false
cached && (result.linfo.inInference = true)
cached && lock_mi_inference(interp, result.linfo)
return typeinf(interp, frame)
end

Expand Down Expand Up @@ -64,7 +64,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
caller.src.min_world = min_valid
caller.src.max_world = max_valid
if cached
cache_result(interp, caller.result, min_valid, max_valid)
cache_result!(interp, caller.result, min_valid, max_valid)
end
if max_valid == typemax(UInt)
# if we aren't cached, we don't need this edge
Expand All @@ -79,60 +79,78 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
return true
end

# inference completed on `me`
# update the MethodInstance and notify the edges
function cache_result(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
def = result.linfo.def
toplevel = !isa(result.linfo.def, Method)

# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
already_inferred = !result.linfo.inInference
if inf_for_methodinstance(interp, result.linfo, min_valid, max_valid) isa CodeInstance
already_inferred = true
end

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
inferred_result = result.src
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt,
may_compress=true, allow_discard_tree=true)
inferred_result = result.src
local const_flags::Int32
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
const_flags = 0x2
elseif isconstType(result.result)
rettype_const = result.result.parameters[1]
const_flags = 0x2
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
const_flags = 0x2
elseif isconstType(result.result)
rettype_const = result.result.parameters[1]
const_flags = 0x2
else
rettype_const = nothing
const_flags = 0x00
end
if !toplevel && inferred_result isa CodeInfo
cache_the_tree = result.src.inferred &&
rettype_const = nothing
const_flags = 0x00
end
if inferred_result isa CodeInfo
def = result.linfo.def
toplevel = !isa(def, Method)
if !toplevel
cache_the_tree = !allow_discard_tree || (result.src.inferred &&
(result.src.inlineable ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0)
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0))
if cache_the_tree
# compress code for non-toplevel thunks
nslots = length(inferred_result.slotflags)
resize!(inferred_result.slottypes, nslots)
resize!(inferred_result.slotnames, nslots)
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
if may_compress
nslots = length(inferred_result.slotflags)
resize!(inferred_result.slottypes, nslots)
resize!(inferred_result.slotnames, nslots)
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
end
else
inferred_result = nothing
end
end
end
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
end
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, inferred_result,
const_flags, min_valid, max_valid)
end
result.linfo.inInference = false
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
end
return CodeInstance(result.linfo,
widenconst(result.result), rettype_const, inferred_result,
const_flags, min_valid, max_valid)
end

# For the NativeInterpreter, we don't need to do an actual cache query to know
# if something was already inferred. If we reach this point, but the inference
# flag has been turned off, then it's in the cache. This is purely a performance
# optimization.
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) =
!mi.inInference
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) =
false

# inference completed on `me`
# update the MethodInstance
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
already_inferred = already_inferred_quick_test(interp, result.linfo)
if !already_inferred && haskey(WorldView(code_cache(interp), min_valid, max_valid), result.linfo)
already_inferred = true
end

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
code_cache(interp)[result.linfo] = CodeInstance(result, min_valid, max_valid)
end
unlock_mi_inference(interp, result.linfo)
nothing
end

Expand All @@ -142,7 +160,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# a top parent will be cached still, but not this intermediate work
# we can throw everything else away now
me.cached = false
me.linfo.inInference = false
unlock_mi_inference(interp, me.linfo)
me.src.inlineable = false
else
# annotate fulltree with type information
Expand Down Expand Up @@ -452,7 +470,7 @@ end
# compute (and cache) an inferred AST and return the current best estimate of the result type
function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atypes), sparams::SimpleVector, caller::InferenceState)
mi = specialize_method(method, atypes, sparams)::MethodInstance
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
update_valid_age!(min_world(code), max_world(code), caller)
if isdefined(code, :rettype_const)
Expand All @@ -470,12 +488,12 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
end
if frame === false
# completely new
mi.inInference = true
lock_mi_inference(interp, mi)
result = InferenceResult(mi)
frame = InferenceState(result, #=cached=#true, interp) # always use the cache for edge targets
if frame === nothing
# can't get the source for this, so we know nothing
mi.inInference = false
unlock_mi_inference(interp, mi)
return Any, nothing
end
if caller.cached || caller.limited # don't involve uncached functions in cycle resolution
Expand Down Expand Up @@ -524,7 +542,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
method = mi.def::Method
for i = 1:2 # test-and-lock-and-test
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
# see if this code already exists in the cache
inf = code.inferred
Expand Down Expand Up @@ -565,7 +583,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
end
end
end
mi.inInference = true
lock_mi_inference(interp, mi)
frame = InferenceState(InferenceResult(mi), #=cached=#true, interp)
frame === nothing && return nothing
typeinf(interp, frame)
Expand All @@ -582,7 +600,7 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
mi = specialize_method(method, atypes, sparams)::MethodInstance
for i = 1:2 # test-and-lock-and-test
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
# see if this rettype already exists in the cache
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
Expand Down
14 changes: 14 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,17 @@ InferenceParams(ni::NativeInterpreter) = ni.inf_params
OptimizationParams(ni::NativeInterpreter) = ni.opt_params
get_world_counter(ni::NativeInterpreter) = ni.world
get_inference_cache(ni::NativeInterpreter) = ni.cache

code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)

"""
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)
Hint that `mi` is in inference to help accelerate bootstrapping. This helps limit the amount of wasted work we might do when inference is working on initially inferring itself by letting us detect when inference is already in progress and not running a second copy on it. This creates a data-race, but the entry point into this code from C (jl_type_infer) already includes detection and restriction on recursion, so it is hopefully mostly a benign problem (since it should really only happen during the first phase of bootstrapping that we encounter this flag).
"""
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = true; nothing)

"""
See lock_mi_inference
"""
unlock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = false; nothing)
5 changes: 0 additions & 5 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,6 @@ function retrieve_code_info(linfo::MethodInstance)
end
end

function inf_for_methodinstance(interp::AbstractInterpreter, mi::MethodInstance, min_world::UInt, max_world::UInt=min_world)
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, min_world, max_world)::Union{Nothing, CodeInstance}
end


# get a handle to the unique specialization object representing a particular instantiation of a call
function specialize_method(method::Method, @nospecialize(atypes), sparams::SimpleVector, preexisting::Bool=false)
if preexisting
Expand Down
7 changes: 5 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -981,17 +981,20 @@ struct CodegenParams
emit_function::Any
emitted_function::Any

lookup::Ptr{Cvoid}

function CodegenParams(; track_allocations::Bool=true, code_coverage::Bool=true,
static_alloc::Bool=true, prefer_specsig::Bool=false,
gnu_pubnames=true, debug_info_kind::Cint = default_debug_info_kind(),
module_setup=nothing, module_activation=nothing, raise_exception=nothing,
emit_function=nothing, emitted_function=nothing)
emit_function=nothing, emitted_function=nothing,
lookup::Ptr{Cvoid}=cglobal(:jl_rettype_inferred))
return new(
Cint(track_allocations), Cint(code_coverage),
Cint(static_alloc), Cint(prefer_specsig),
Cint(gnu_pubnames), debug_info_kind,
module_setup, module_activation, raise_exception,
emit_function, emitted_function)
emit_function, emitted_function, lookup)
end
end

Expand Down
Loading

0 comments on commit 684a3dc

Please sign in to comment.