Refactor cache logic for easy replacement #35831

Merged: 4 commits, merged Jun 9, 2020
Changes from 1 commit
Refactor cache logic for easy replacement
This is the next step in the line of work started by #33955,
though a lot of enabling work towards this was previously done
by Jameson in his codegen-norecursion branch. The basic
thrust here is to allow external packages to manage their own
cache of compiled code that may have been generated using
entirely different inference or compiler options. The GPU
compilers are one such example, but there are several others,
including generating code using offload compilers such as XLA,
or compilers for secure computation. A lot of this is just
moving code around to make it clear exactly which parts of the
code access the internal code cache (which is now its own type,
to make it obvious when it is being accessed), as well as
providing clear extension points for custom cache
implementations. The second part is to refactor CodeInstance
construction to separate construction from insertion into the
internal cache (so it can be inserted into an external cache
instead, if desired). The last part of the change is to give
cgparams another hook that lets the caller replace the cache
lookup used by codegen.
Keno committed May 12, 2020
commit c79044c3d52d0689dbbcd6c7ae0e75bc8f869bfe
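To make the intended extension point concrete, here is a minimal sketch (not part of this commit) of what an external code cache could look like against the interface introduced below. The name ExternalCodeCache and its IdDict backing store are hypothetical; WorldView, CodeInstance and MethodInstance are the types this PR works with, and world-age validity is ignored for brevity.

const CC = Core.Compiler

# Hypothetical external cache: a plain dictionary keyed by MethodInstance.
struct ExternalCodeCache
    dict::IdDict{Core.MethodInstance,Core.CodeInstance}
end
ExternalCodeCache() = ExternalCodeCache(IdDict{Core.MethodInstance,Core.CodeInstance}())

# The compiler consumes caches through WorldView, so an external cache only
# needs to answer these queries (world-age bounds are ignored in this sketch).
CC.haskey(wvc::CC.WorldView{ExternalCodeCache}, mi::Core.MethodInstance) =
    haskey(wvc.cache.dict, mi)
CC.get(wvc::CC.WorldView{ExternalCodeCache}, mi::Core.MethodInstance, default) =
    get(wvc.cache.dict, mi, default)
CC.getindex(wvc::CC.WorldView{ExternalCodeCache}, mi::Core.MethodInstance) =
    wvc.cache.dict[mi]
CC.setindex!(wvc::CC.WorldView{ExternalCodeCache}, ci::Core.CodeInstance, mi::Core.MethodInstance) =
    setindex!(wvc.cache.dict, ci, mi)

A custom AbstractInterpreter would then point code_cache at a WorldView over such a cache; see the types.jl hunk below for the NativeInterpreter wiring this mirrors.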
5 changes: 5 additions & 0 deletions base/boot.jl
@@ -380,6 +380,11 @@ eval(Core, :(UpsilonNode(val) = $(Expr(:new, :UpsilonNode, :val))))
eval(Core, :(UpsilonNode() = $(Expr(:new, :UpsilonNode))))
eval(Core, :(LineInfoNode(@nospecialize(method), file::Symbol, line::Int, inlined_at::Int) =
$(Expr(:new, :LineInfoNode, :method, :file, :line, :inlined_at))))
eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32,
min_world::UInt, max_world::UInt) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))

Module(name::Symbol=:anonymous, std_imports::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool), name, std_imports)

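This constructor is what allows a CodeInstance to be created without touching any cache. A hedged sketch of calling it directly; the rettype, constant and world-age arguments here are placeholders for illustration, not values this commit uses:

# mi::Core.MethodInstance obtained elsewhere, e.g. via Core.Compiler.specialize_method
ci = Core.CodeInstance(mi, Any, nothing, nothing,   # rettype Any, no constant, no inferred source
                       Int32(0), UInt(1), typemax(UInt))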
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
@@ -242,7 +242,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
mi = mi::MethodInstance
# decide if it's likely to be worthwhile
if !force_inference
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
declared_inline = isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
cache_inlineable = declared_inline
if isdefined(code, :inferred) && !cache_inlineable
52 changes: 52 additions & 0 deletions base/compiler/cicache.jl
@@ -0,0 +1,52 @@
"""
struct InternalCodeCache

Internally, each `MethodInstance` keeps a unique global cache of code instances
that have been created for the given method instance, stratified by world age
ranges. This struct abstracts over access to this cache.
"""
struct InternalCodeCache
end

function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
end

const GLOBAL_CI_CACHE = InternalCodeCache()

"""
struct WorldView

Takes a given cache and provides access to the cache contents for the given
range of world ages, rather than defaulting to the current active world age.
"""
struct WorldView{Cache}
cache::Cache
min_world::UInt
max_world::UInt
end
WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r))
WorldView(cache, world::UInt) = WorldView(cache, world, world)
WorldView(wvc::WorldView{InternalCodeCache}, min_world::UInt, max_world::UInt) =
WorldView(wvc.cache, min_world, max_world)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance}
if r === nothing
return default
end
return r::CodeInstance
end

function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
r = get(wvc, mi, nothing)
r === nothing && throw(KeyError(mi))
return r::CodeInstance
end

setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance) =
setindex!(wvc.cache, ci, mi)
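As a rough usage sketch (assuming some mi::MethodInstance and world::UInt are in scope, with Core.Compiler's names available), the internal cache is consulted through a WorldView like this:

wvc = WorldView(GLOBAL_CI_CACHE, world)   # view of the internal cache at a single world age
haskey(wvc, mi)                           # true if a CodeInstance valid in `world` exists
ci = get(wvc, mi, nothing)                # that CodeInstance, or `nothing` if none
wvc[mi]                                   # same lookup, but throws KeyError(mi) when missing

All of these bottom out in the jl_rettype_inferred C entry point shown above, while insertion goes through jl_mi_cache_insert; an external cache is free to store its entries however it likes.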
1 change: 1 addition & 0 deletions base/compiler/compiler.jl
@@ -100,6 +100,7 @@ include("compiler/validation.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")
include("compiler/cicache.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
8 changes: 4 additions & 4 deletions base/compiler/ssair/inlining.jl
@@ -798,7 +798,7 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...)
return (sig, state)
end

function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any}, sv::OptimizationState)
function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Any})
if isa(case, ConstantCase)
ir[SSAValue(idx)] = case.val
elseif isa(case, MethodInstance)
@@ -935,7 +935,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
methsp = methsp::SimpleVector
result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data,
calltype)
handle_single_case!(ir, stmt, idx, result, true, todo, sv)
handle_single_case!(ir, stmt, idx, result, true, todo)
update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv)
return nothing
end
@@ -1103,7 +1103,7 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo, sv)
handle_single_case!(ir, stmt, idx, cases[1][2], false, todo)
continue
end
length(cases) == 0 && continue
@@ -1318,7 +1318,7 @@ function find_inferred(mi::MethodInstance, @nospecialize(atypes), sv::Optimizati
end
end

linfo = inf_for_methodinstance(sv.interp, mi, sv.world)
linfo = get(WorldView(code_cache(sv.interp), sv.world), mi, nothing)
if linfo isa CodeInstance
if invoke_api(linfo) == 2
# in this case function can be inlined to a constant
124 changes: 71 additions & 53 deletions base/compiler/typeinfer.jl
@@ -4,7 +4,7 @@
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool)
frame = InferenceState(result, cached, interp)
frame === nothing && return false
cached && (result.linfo.inInference = true)
cached && lock_mi_inference(interp, result.linfo)
return typeinf(interp, frame)
end

@@ -64,7 +64,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
caller.src.min_world = min_valid
caller.src.max_world = max_valid
if cached
cache_result(interp, caller.result, min_valid, max_valid)
cache_result!(interp, caller.result, min_valid, max_valid)
end
if max_valid == typemax(UInt)
# if we aren't cached, we don't need this edge
@@ -79,60 +79,78 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
return true
end

# inference completed on `me`
# update the MethodInstance and notify the edges
function cache_result(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
def = result.linfo.def
toplevel = !isa(result.linfo.def, Method)

# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
already_inferred = !result.linfo.inInference
if inf_for_methodinstance(interp, result.linfo, min_valid, max_valid) isa CodeInstance
already_inferred = true
end

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
inferred_result = result.src
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt,
may_compress=true, always_cache_tree=false)
[Review comment (Sponsor Member)] "always_cache" sounds like a bit of a misnomer, since there's always going to be intermediate work that we discarded because it'd be invalid to call cache_result on it.
[Keno (Member Author)] allow_discard_tree=true?
[Keno (Member Author)] I'll just rename this to allow_discard_tree and then get this merged so I can put up some of the follow on work to this.
inferred_result = result.src
local const_flags::Int32
if inferred_result isa Const
# use constant calling convention
rettype_const = (result.src::Const).val
const_flags = 0x3
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
const_flags = 0x2
elseif isconstType(result.result)
rettype_const = result.result.parameters[1]
const_flags = 0x2
else
if isa(result.result, Const)
rettype_const = (result.result::Const).val
const_flags = 0x2
elseif isconstType(result.result)
rettype_const = result.result.parameters[1]
const_flags = 0x2
else
rettype_const = nothing
const_flags = 0x00
end
if !toplevel && inferred_result isa CodeInfo
cache_the_tree = result.src.inferred &&
rettype_const = nothing
const_flags = 0x00
end
if inferred_result isa CodeInfo
def = result.linfo.def
toplevel = !isa(def, Method)
if !toplevel
cache_the_tree = always_cache_tree || (result.src.inferred &&
(result.src.inlineable ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0)
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), result.linfo.specTypes, def) != 0))
if cache_the_tree
# compress code for non-toplevel thunks
nslots = length(inferred_result.slotflags)
resize!(inferred_result.slottypes, nslots)
resize!(inferred_result.slotnames, nslots)
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
if may_compress
nslots = length(inferred_result.slotflags)
resize!(inferred_result.slottypes, nslots)
resize!(inferred_result.slotnames, nslots)
inferred_result = ccall(:jl_compress_ir, Any, (Any, Any), def, inferred_result)
end
else
inferred_result = nothing
end
end
end
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
end
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, inferred_result,
const_flags, min_valid, max_valid)
end
result.linfo.inInference = false
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
end
return CodeInstance(result.linfo,
widenconst(result.result), rettype_const, inferred_result,
const_flags, min_valid, max_valid)
end

# For the NativeInterpreter, we don't need to do an actual cache query to know
# if something was already inferred. If we reach this point, but the inference
# flag has been turned off, then it's in the cache. This is purely a performance
# optimization.
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) =
!mi.inInference
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) =
false

# inference completed on `me`
# update the MethodInstance and notify the edges
function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt)
# check if the existing linfo metadata is also sufficient to describe the current inference result
# to decide if it is worth caching this
already_inferred = already_inferred_quick_test(interp, result.linfo)
if !already_inferred && haskey(WorldView(code_cache(interp), min_valid, max_valid), result.linfo)
already_inferred = true
end

# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
code_cache(interp)[result.linfo] = CodeInstance(result, min_valid, max_valid)
end
unlock_mi_inference(interp, result.linfo)
nothing
end

@@ -142,7 +160,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# a top parent will be cached still, but not this intermediate work
# we can throw everything else away now
me.cached = false
me.linfo.inInference = false
unlock_mi_inference(interp, me.linfo)
me.src.inlineable = false
else
# annotate fulltree with type information
@@ -452,7 +470,7 @@ end
# compute (and cache) an inferred AST and return the current best estimate of the result type
function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize(atypes), sparams::SimpleVector, caller::InferenceState)
mi = specialize_method(method, atypes, sparams)::MethodInstance
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance # return existing rettype if the code is already inferred
update_valid_age!(min_world(code), max_world(code), caller)
if isdefined(code, :rettype_const)
@@ -470,12 +488,12 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
end
if frame === false
# completely new
mi.inInference = true
lock_mi_inference(interp, mi)
[Review comment (Sponsor Member)] I don't know that this is a lock in the proper sense. It's a bootstrapping device, but it's also merely an advisory hint.
[Keno (Member Author)] It's not a lock in the atomic sense, but I do think it's a semantic lock.
result = InferenceResult(mi)
frame = InferenceState(result, #=cached=#true, interp) # always use the cache for edge targets
if frame === nothing
# can't get the source for this, so we know nothing
mi.inInference = false
unlock_mi_inference(interp, mi)
return Any, nothing
end
if caller.cached || caller.limited # don't involve uncached functions in cycle resolution
@@ -524,7 +542,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
method = mi.def::Method
for i = 1:2 # test-and-lock-and-test
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
# see if this code already exists in the cache
inf = code.inferred
@@ -565,7 +583,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
end
end
end
mi.inInference = true
lock_mi_inference(interp, mi)
frame = InferenceState(InferenceResult(mi), #=cached=#true, interp)
frame === nothing && return nothing
typeinf(interp, frame)
@@ -582,7 +600,7 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
mi = specialize_method(method, atypes, sparams)::MethodInstance
for i = 1:2 # test-and-lock-and-test
i == 2 && ccall(:jl_typeinf_begin, Cvoid, ())
code = inf_for_methodinstance(interp, mi, get_world_counter(interp))
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
# see if this rettype already exists in the cache
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
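The net effect of the typeinfer.jl changes is that building a CodeInstance and deciding where it lives are now separate steps. Schematically (a sketch; result, interp, min_valid and max_valid are assumed in scope, as inside cache_result! above):

ci = CodeInstance(result, min_valid, max_valid)   # construct the inference result, cache-agnostic
code_cache(interp)[result.linfo] = ci             # insert it into whichever cache the interpreter provides

For the NativeInterpreter the second line reaches the internal cache via jl_mi_cache_insert; an external AbstractInterpreter can point code_cache somewhere else entirely.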
22 changes: 22 additions & 0 deletions base/compiler/types.jl
@@ -175,3 +175,25 @@ InferenceParams(ni::NativeInterpreter) = ni.inf_params
OptimizationParams(ni::NativeInterpreter) = ni.opt_params
get_world_counter(ni::NativeInterpreter) = ni.world
get_inference_cache(ni::NativeInterpreter) = ni.cache

code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)

"""
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)

Locks `mi`'s inference flag to prevent infinite recursion inside inference.
While running inference, we must occasionally compile additional code, which
may in turn request additional code to be inferred. This can happen during
bootstrap (for part of inference itself), but also during regular execution (e.g.
to expand generated functions). The inference flag lets the runtime know that
it should not attempt to re-infer the requested function, since it is already
being worked on higher in the stack. Note that inference itself does not look at
this flag, instead checking its own working caches - it is only used for locking
the C runtime.
"""
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = true; nothing)

"""
See lock_mi_inference
"""
unlock_mi_inference(ni::NativeInterpreter, mi::MethodInstance) = (mi.inInference = false; nothing)
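These are the hooks an alternative interpreter overrides. A hypothetical sketch of the wiring (MyInterpreter and MY_CACHE are made-up names, ExternalCodeCache refers to the sketch near the top of this commit, and the other required AbstractInterpreter methods are elided):

const CC = Core.Compiler

struct MyInterpreter <: CC.AbstractInterpreter
    world::UInt
    # InferenceParams, OptimizationParams, inference-result cache, etc. elided
end

const MY_CACHE = ExternalCodeCache()   # the hypothetical cache sketched earlier

# Route the compiler's code cache accesses to the external cache.
CC.code_cache(interp::MyInterpreter) = CC.WorldView(MY_CACHE, interp.world)

# An external cache has no `inInference` flag to maintain, so the lock hooks are no-ops.
CC.lock_mi_inference(::MyInterpreter, ::Core.MethodInstance) = nothing
CC.unlock_mi_inference(::MyInterpreter, ::Core.MethodInstance) = nothing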
5 changes: 0 additions & 5 deletions base/compiler/utilities.jl
@@ -118,11 +118,6 @@ function retrieve_code_info(linfo::MethodInstance)
end
end

function inf_for_methodinstance(interp::AbstractInterpreter, mi::MethodInstance, min_world::UInt, max_world::UInt=min_world)
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, min_world, max_world)::Union{Nothing, CodeInstance}
end


# get a handle to the unique specialization object representing a particular instantiation of a call
function specialize_method(method::Method, @nospecialize(atypes), sparams::SimpleVector, preexisting::Bool=false)
if preexisting
7 changes: 5 additions & 2 deletions base/reflection.jl
@@ -981,17 +981,20 @@ struct CodegenParams
emit_function::Any
emitted_function::Any

lookup::Ptr{Cvoid}

function CodegenParams(; track_allocations::Bool=true, code_coverage::Bool=true,
static_alloc::Bool=true, prefer_specsig::Bool=false,
gnu_pubnames=true, debug_info_kind::Cint = default_debug_info_kind(),
module_setup=nothing, module_activation=nothing, raise_exception=nothing,
emit_function=nothing, emitted_function=nothing)
emit_function=nothing, emitted_function=nothing,
lookup::Ptr{Cvoid}=cglobal(:jl_rettype_inferred))
return new(
Cint(track_allocations), Cint(code_coverage),
Cint(static_alloc), Cint(prefer_specsig),
Cint(gnu_pubnames), debug_info_kind,
module_setup, module_activation, raise_exception,
emit_function, emitted_function)
emit_function, emitted_function, lookup)
end
end

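Spelled out, the new lookup keyword defaults to the C entry point backing the internal cache; an external compiler would instead pass a pointer to its own C-callable lookup with the same (MethodInstance, min_world, max_world) shape. A minimal sketch of the default, which adds nothing beyond what the diff already shows:

params = Base.CodegenParams(lookup = cglobal(:jl_rettype_inferred))

Swapping that pointer is what lets codegen consult an external cache rather than jl_rettype_inferred when it looks up compiled code.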