Skip to content

Commit

Permalink
interpreter: Use world-age-partitioned cache for @generated results (#…
Browse files Browse the repository at this point in the history
…54362)

This fixes #54360 by moving the interpreter's cache of `@generated`
results from `mi.uninferred` into `mi.cache` with a separate cache owner
to partition the cache from regular inference results. There are two
other uses of the `mi.uninferred` field:

1. As the place to store uninferred code for temporary top-level thunks
2. Is an uncompressed copy of m->source to avoid having to re-uncompress
every time in the interpreter.

In this PR, use case 1 is changed to use the same mechanism as generated
functions. Use case 2 is changed to just uncompress the source in place
in m->source. As a result, the `uninferred` field is unused and removed.

Note that I'm planning a somewhat larger refactor of `MethodInstance` in
the immediate future, so this might be a somewhat shortlived
representation, but that change should hopefully by largely transparent
to users of the wrappers introduced here.
  • Loading branch information
Keno committed May 7, 2024
1 parent c77671a commit dbf0bab
Show file tree
Hide file tree
Showing 25 changed files with 258 additions and 99 deletions.
16 changes: 14 additions & 2 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,28 @@ function get_staged(mi::MethodInstance, world::UInt)
may_invoke_generator(mi) || return nothing
try
# user code might throw errors – ignore them
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
ci = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)::CodeInfo
return ci
catch
return nothing
end
end

function get_cached_uninferred(mi::MethodInstance, world::UInt)
ccall(:jl_cached_uninferred, Any, (Any, UInt), mi.cache, world)::CodeInstance
end

function retrieve_code_info(mi::MethodInstance, world::UInt)
def = mi.def
isa(def, Method) || return mi.uninferred
if !isa(def, Method)
ci = get_cached_uninferred(mi, world)
src = ci.inferred
# Inference may corrupt the src, which is fine, because this is a
# (short-lived) top-level thunk, but set it to NULL anyway, so we
# can catch it if somebody tries to read it again by accident.
# @atomic ci.inferred = C_NULL
return src
end
c = isdefined(def, :generator) ? get_staged(mi, world) : nothing
if c === nothing && isdefined(def, :source)
src = def.source
Expand Down
1 change: 1 addition & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const __next_removal_version = v"1.12-alpha"
const __internal_changes_list = (
:invertedlinetables,
:codeinforefactor,
:miuninferredrm,
# Add new change names above this line
)

Expand Down
4 changes: 2 additions & 2 deletions base/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ rethrow() = ccall(:jl_rethrow, Bottom, ())
rethrow(@nospecialize(e)) = ccall(:jl_rethrow_other, Bottom, (Any,), e)

struct InterpreterIP
code::Union{CodeInfo,Core.MethodInstance,Nothing}
code::Union{CodeInfo,Core.MethodInstance,Core.CodeInstance,Nothing}
stmt::Csize_t
mod::Union{Module,Nothing}
end
Expand All @@ -96,7 +96,7 @@ function _reformat_bt(bt::Array{Ptr{Cvoid},1}, bt2::Array{Any,1})
tag = (entry_metadata >> 6) & 0xf
header = entry_metadata >> 10
if tag == 1 # JL_BT_INTERP_FRAME_TAG
code = bt2[j]::Union{CodeInfo,Core.MethodInstance,Nothing}
code = bt2[j]::Union{CodeInfo,Core.MethodInstance,Core.CodeInstance,Nothing}
mod = njlvalues == 2 ? bt2[j+1]::Union{Module,Nothing} : nothing
push!(ret, InterpreterIP(code, header, mod))
else
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
for m in method_instances(f, t, world)
if generated && hasgenerator(m)
if may_invoke_generator(m)
code = ccall(:jl_code_for_staged, Any, (Any, UInt), m, world)::CodeInfo
code = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)::CodeInfo
else
error("Could not expand generator for `@generated` method ", m, ". ",
"This can happen if the provided argument types (", t, ") are ",
Expand Down
9 changes: 5 additions & 4 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,10 @@ function show_mi(io::IO, mi::Core.MethodInstance, from_stackframe::Bool=false)
# MethodInstance is part of a stacktrace, it gets location info
# added by other means. But if it isn't, then we should try
# to print a little more identifying information.
if !from_stackframe
if isdefined(mi, :uninferred)
di = mi.uninferred.debuginfo
if !from_stackframe && isdefined(mi, :cache)
ci = mi.cache
if ci.owner === :uninferred
di = ci.inferred.debuginfo
file, line = IRShow.debuginfo_firstline(di)
file = string(file)
line = isempty(file) || line < 0 ? "<unknown>" : "$file:$line"
Expand All @@ -1381,7 +1382,7 @@ function show(io::IO, mi_info::Core.Compiler.Timings.InferenceFrameInfo)
show_tuple_as_call(io, def.name, mi.specTypes; argnames, qualified=true)
end
else
di = mi.uninferred.debuginfo
di = mi.cache.inferred.debuginfo
file, line = IRShow.debuginfo_firstline(di)
file = string(file)
line = isempty(file) || line < 0 ? "<unknown>" : "$file:$line"
Expand Down
7 changes: 6 additions & 1 deletion base/stacktraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,18 @@ function lookup(ip::Union{Base.InterpreterIP,Core.Compiler.InterpreterIP})
# interpreted top-level expression with no CodeInfo
return [StackFrame(top_level_scope_sym, empty_sym, 0, nothing, false, false, 0)]
end
codeinfo = (code isa MethodInstance ? code.uninferred : code)::CodeInfo
# prepare approximate code info
if code isa MethodInstance && (meth = code.def; meth isa Method)
func = meth.name
file = meth.file
line = meth.line
codeinfo = meth.source
else
if code isa Core.CodeInstance
codeinfo = code.inferred::CodeInfo
else
codeinfo = code::CodeInfo
end
func = top_level_scope_sym
file = empty_sym
line = Int32(0)
Expand Down
5 changes: 0 additions & 5 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -605,11 +605,6 @@ for important details on how to modify these fields safely.
For the `MethodInstance` at `Method.unspecialized`, this is the empty `SimpleVector`.
But for a runtime `MethodInstance` from the `MethodTable` cache, this will always be defined and indexable.

* `uninferred`

The uncompressed source code for a toplevel thunk. Additionally, for a generated function,
this is one of many places that the source code might be found.

* `backedges`

We store the reverse-list of cache dependencies for efficient tracking of incremental reanalysis/recompilation work that may be needed after a new method definitions.
Expand Down
3 changes: 2 additions & 1 deletion src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ JL_DLLEXPORT jl_sym_t *jl_acquire_sym;
JL_DLLEXPORT jl_sym_t *jl_release_sym;
JL_DLLEXPORT jl_sym_t *jl_acquire_release_sym;
JL_DLLEXPORT jl_sym_t *jl_sequentially_consistent_sym;

JL_DLLEXPORT jl_sym_t *jl_uninferred_sym;

static const uint8_t flisp_system_image[] = {
#include <julia_flisp.boot.inc>
Expand Down Expand Up @@ -416,6 +416,7 @@ void jl_init_common_symbols(void)
jl_release_sym = jl_symbol("release");
jl_acquire_release_sym = jl_symbol("acquire_release");
jl_sequentially_consistent_sym = jl_symbol("sequentially_consistent");
jl_uninferred_sym = jl_symbol("uninferred");
}

JL_DLLEXPORT void jl_lisp_prompt(void)
Expand Down
26 changes: 24 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,29 @@ JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMEN
return;
}

JL_DLLEXPORT int jl_mi_try_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *expected_ci,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED)
{
JL_GC_PUSH1(&ci);
if (jl_is_method(mi->def.method))
JL_LOCK(&mi->def.method->writelock);
jl_code_instance_t *oldci = jl_atomic_load_relaxed(&mi->cache);
int ret = 0;
if (oldci == expected_ci) {
jl_atomic_store_relaxed(&ci->next, oldci);
if (oldci)
jl_gc_wb(ci, oldci);
jl_atomic_store_release(&mi->cache, ci);
jl_gc_wb(mi, ci);
ret = 1;
}
if (jl_is_method(mi->def.method))
JL_UNLOCK(&mi->def.method->writelock);
JL_GC_POP();
return ret;
}

static int get_method_unspec_list(jl_typemap_entry_t *def, void *closure)
{
size_t world = jl_atomic_load_acquire(&jl_world_counter);
Expand Down Expand Up @@ -2587,8 +2610,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
jl_callptr_t ucache_invoke = jl_atomic_load_acquire(&ucache->invoke);
if (ucache_invoke == NULL) {
if ((!jl_is_method(def) || def->source == jl_nothing) &&
(jl_atomic_load_relaxed(&ucache->def->uninferred) == jl_nothing ||
jl_atomic_load_relaxed(&ucache->def->uninferred) == NULL)) {
!jl_cached_uninferred(jl_atomic_load_relaxed(&ucache->def->cache), world)) {
jl_throw(jl_new_struct(jl_missingcodeerror_type, (jl_value_t*)mi));
}
jl_generate_fptr_for_unspecialized(ucache);
Expand Down
70 changes: 54 additions & 16 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ extern "C" {
typedef struct {
jl_code_info_t *src; // contains the names and number of slots
jl_method_instance_t *mi; // MethodInstance we're executing, or NULL if toplevel
jl_code_instance_t *ci; // CodeInstance we're executing (for generated functions)
jl_module_t *module; // context for globals
jl_value_t **locals; // slots for holding local slots and ssavalues
jl_svec_t *sparam_vals; // method static parameters, if eval-ing a method body
Expand Down Expand Up @@ -683,31 +684,55 @@ static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip,

// preparing method IR for interpreter

jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *mi, size_t world)
jl_value_t *jl_code_or_ci_for_interpreter(jl_method_instance_t *mi, size_t world)
{
jl_code_info_t *src = (jl_code_info_t*)jl_atomic_load_relaxed(&mi->uninferred);
jl_value_t *ret = NULL;
jl_code_info_t *src = NULL;
if (jl_is_method(mi->def.value)) {
if (!src || (jl_value_t*)src == jl_nothing) {
if (mi->def.method->source) {
src = (jl_code_info_t*)mi->def.method->source;
if (mi->def.method->source) {
jl_method_t *m = mi->def.method;
src = (jl_code_info_t*)m->source;
if (!jl_is_code_info(src)) {
src = jl_uncompress_ir(mi->def.method, NULL, (jl_value_t*)src);
// Replace the method source by the uncompressed version,
// under the assumption that the interpreter may need to
// access it frequently. TODO: Have some sort of usage-based
// cache here.
m->source = (jl_value_t*)src;
jl_gc_wb(m, src);
}
else {
ret = (jl_value_t*)src;
}
else {
jl_code_instance_t *cache = jl_atomic_load_relaxed(&mi->cache);
jl_code_instance_t *uninferred = jl_cached_uninferred(cache, world);
if (!uninferred) {
assert(mi->def.method->generator);
src = jl_code_for_staged(mi, world);
src = jl_code_for_staged(mi, world, &uninferred);
}
ret = (jl_value_t*)uninferred;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred);
}
if (src && (jl_value_t*)src != jl_nothing) {
JL_GC_PUSH1(&src);
src = jl_uncompress_ir(mi->def.method, NULL, (jl_value_t*)src);
jl_atomic_store_release(&mi->uninferred, (jl_value_t*)src);
jl_gc_wb(mi, src);
JL_GC_POP();
}
else {
jl_code_instance_t *uninferred = jl_cached_uninferred(jl_atomic_load_relaxed(&mi->cache), world);
ret = (jl_value_t*)uninferred;
if (ret) {
src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred);
}
}
if (!src || !jl_is_code_info(src)) {
jl_throw(jl_new_struct(jl_missingcodeerror_type, (jl_value_t*)mi));
}
return src;
return ret;
}

jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *mi, size_t world)
{
jl_value_t *code_or_ci = jl_code_or_ci_for_interpreter(mi, world);
if (jl_is_code_instance(code_or_ci))
return (jl_code_info_t*)jl_atomic_load_relaxed(&((jl_code_instance_t*)code_or_ci)->inferred);
return (jl_code_info_t*)code_or_ci;
}

// interpreter entry points
Expand All @@ -718,7 +743,15 @@ jl_value_t *NOINLINE jl_fptr_interpret_call(jl_value_t *f, jl_value_t **args, ui
jl_method_instance_t *mi = codeinst->def;
jl_task_t *ct = jl_current_task;
size_t world = ct->world_age;
jl_code_info_t *src = jl_code_for_interpreter(mi, world);
jl_code_info_t *src = NULL;
jl_value_t *code = jl_code_or_ci_for_interpreter(mi, world);
jl_code_instance_t *ci = NULL;
if (jl_is_code_instance(code)) {
ci = (jl_code_instance_t*)code;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&ci->inferred);
} else {
src = (jl_code_info_t*)code;
}
jl_array_t *stmts = src->code;
assert(jl_typetagis(stmts, jl_array_any_type));
unsigned nroots = jl_source_nslots(src) + jl_source_nssavalues(src) + 2;
Expand Down Expand Up @@ -749,6 +782,7 @@ jl_value_t *NOINLINE jl_fptr_interpret_call(jl_value_t *f, jl_value_t **args, ui
s->preevaluation = 0;
s->continue_at = 0;
s->mi = mi;
s->ci = ci;
JL_GC_ENABLEFRAME(s);
jl_value_t *r = eval_body(stmts, s, 0, 0);
JL_GC_POP();
Expand Down Expand Up @@ -792,6 +826,7 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
s->preevaluation = 0;
s->continue_at = 0;
s->mi = NULL;
s->ci = NULL;
size_t defargs = source->nargs;
int isva = source->isva;
assert(isva ? nargs + 2 >= defargs : nargs + 1 == defargs);
Expand Down Expand Up @@ -823,6 +858,7 @@ jl_value_t *NOINLINE jl_interpret_toplevel_thunk(jl_module_t *m, jl_code_info_t
s->sparam_vals = jl_emptysvec;
s->continue_at = 0;
s->mi = NULL;
s->ci = NULL;
JL_GC_ENABLEFRAME(s);
jl_task_t *ct = jl_current_task;
size_t last_age = ct->world_age;
Expand All @@ -847,6 +883,7 @@ jl_value_t *NOINLINE jl_interpret_toplevel_expr_in(jl_module_t *m, jl_value_t *e
s->preevaluation = (sparam_vals != NULL);
s->continue_at = 0;
s->mi = NULL;
s->ci = NULL;
JL_GC_ENABLEFRAME(s);
jl_value_t *v = eval_value(e, s);
assert(v);
Expand All @@ -866,7 +903,8 @@ JL_DLLEXPORT size_t jl_capture_interp_frame(jl_bt_element_t *bt_entry,
uintptr_t entry_tags = jl_bt_entry_descriptor(njlvalues, 0, JL_BT_INTERP_FRAME_TAG, s->ip);
bt_entry[0].uintptr = JL_BT_NON_PTR_ENTRY;
bt_entry[1].uintptr = entry_tags;
bt_entry[2].jlvalue = s->mi ? (jl_value_t*)s->mi :
bt_entry[2].jlvalue = s->ci ? (jl_value_t*)s->ci :
s->mi ? (jl_value_t*)s->mi :
s->src ? (jl_value_t*)s->src : (jl_value_t*)jl_nothing;
if (need_module) {
// If we only have a CodeInfo (s->src), we are in a top level thunk and
Expand Down
6 changes: 5 additions & 1 deletion src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,11 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
src = jl_uncompress_ir(def, NULL, (jl_value_t*)src);
}
else {
src = (jl_code_info_t*)jl_atomic_load_relaxed(&unspec->def->uninferred);
jl_method_instance_t *mi = unspec->def;
jl_code_instance_t *uninferred = jl_cached_uninferred(
jl_atomic_load_relaxed(&mi->cache), 1);
assert(uninferred);
src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred);
assert(src);
}
if (src) {
Expand Down
10 changes: 4 additions & 6 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3283,21 +3283,19 @@ void jl_init_types(void) JL_GC_DISABLED
jl_method_instance_type =
jl_new_datatype(jl_symbol("MethodInstance"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(9,
jl_perm_symsvec(8,
"def",
"specTypes",
"sparam_vals",
"uninferred",
"backedges",
"cache",
"inInference",
"cache_with_orig",
"precompiled"),
jl_svec(9,
jl_svec(8,
jl_new_struct(jl_uniontype_type, jl_method_type, jl_module_type),
jl_any_type,
jl_simplevector_type,
jl_any_type,
jl_array_any_type,
jl_any_type/*jl_code_instance_type*/,
jl_bool_type,
Expand All @@ -3307,7 +3305,7 @@ void jl_init_types(void) JL_GC_DISABLED
0, 1, 3);
// These fields should be constant, but Serialization wants to mutate them in initialization
//const static uint32_t method_instance_constfields[1] = { 0x00000007 }; // (1<<0)|(1<<1)|(1<<2);
const static uint32_t method_instance_atomicfields[1] = { 0x00000128 }; // (1<<3)|(1<<5)|(1<<8);
const static uint32_t method_instance_atomicfields[1] = { 0x0000090 }; // (1<<4)|(1<<7);
//Fields 4 and 5 must be protected by method->write_lock, and thus all operations on jl_method_instance_t are threadsafe. TODO: except inInference
//jl_method_instance_type->name->constfields = method_instance_constfields;
jl_method_instance_type->name->atomicfields = method_instance_atomicfields;
Expand Down Expand Up @@ -3496,7 +3494,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svecset(jl_methtable_type->types, 10, jl_uint8_type);
jl_svecset(jl_method_type->types, 13, jl_method_instance_type);
//jl_svecset(jl_debuginfo_type->types, 0, jl_method_instance_type); // union(jl_method_instance_type, jl_method_type, jl_symbol_type)
jl_svecset(jl_method_instance_type->types, 5, jl_code_instance_type);
jl_svecset(jl_method_instance_type->types, 4, jl_code_instance_type);
jl_svecset(jl_code_instance_type->types, 16, jl_voidpointer_type);
jl_svecset(jl_code_instance_type->types, 17, jl_voidpointer_type);
jl_svecset(jl_binding_type->types, 1, jl_globalref_type);
Expand Down
3 changes: 1 addition & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ struct _jl_method_instance_t {
} def; // pointer back to the context for this code
jl_value_t *specTypes; // argument types this was specialized for
jl_svec_t *sparam_vals; // static parameter values, indexed by def.method->sig
_Atomic(jl_value_t*) uninferred; // cached uncompressed code, for generated functions, top-level thunks, or the interpreter
jl_array_t *backedges; // list of method-instances which call this method-instance; `invoke` records (invokesig, caller) pairs
_Atomic(struct _jl_code_instance_t*) cache;
uint8_t inInference; // flags to tell if inference is running on this object
Expand Down Expand Up @@ -1804,7 +1803,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
_Atomic(jl_value_t*) *bp,
jl_binding_t *bnd);
JL_DLLEXPORT jl_method_t *jl_method_def(jl_svec_t *argdata, jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module);
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world);
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache);
JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src);
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_box_bool(int8_t x) JL_NOTSAFEPOINT;
Expand Down
Loading

0 comments on commit dbf0bab

Please sign in to comment.