Skip to content

Commit

Permalink
[AbstractInterpreter] Implement callbacks for cache invalidations (Ju…
Browse files Browse the repository at this point in the history
…liaLang#38370)

AbstractInterpreter allows down-stream compiler variants to implement
their own caching of inference results. This leads to the issue of that
invalidations of methods are not propagated to those caches, leaving
downstream implementations with their own variant of 265, with limited
ability to mitigate it.

This PR extends the invalidation scheme by allowing downstream
implementation to attach invalidation callbacks to MethodInstances.
MethodInstances are used as the key to cache and can therefore be used
to walk and invalidate the user cache.


Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
vchuravy and vtjnash committed Nov 24, 2020
1 parent ef3c20d commit bdcaee0
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
backedges = NULL;
}
jl_serialize_value(s, (jl_value_t*)backedges);
jl_serialize_value(s, (jl_value_t*)NULL); //callbacks
jl_serialize_value(s, (jl_value_t*)mi->cache);
}
else if (jl_is_code_instance(v)) {
Expand Down Expand Up @@ -1491,6 +1492,9 @@ static jl_value_t *jl_deserialize_value_method_instance(jl_serializer_state *s,
mi->backedges = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&mi->backedges);
if (mi->backedges)
jl_gc_wb(mi, mi->backedges);
mi->callbacks = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&mi->callbacks);
if (mi->callbacks)
jl_gc_wb(mi, mi->callbacks);
mi->cache = (jl_code_instance_t*)jl_deserialize_value(s, (jl_value_t**)&mi->cache);
if (mi->cache)
jl_gc_wb(mi, mi->cache);
Expand Down
37 changes: 37 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,40 @@ JL_DLLEXPORT jl_value_t *jl_debug_method_invalidation(int state)
return jl_nothing;
}

// call external callbacks registered with this method_instance
static void invalidate_external(jl_method_instance_t *mi, size_t max_world) {
jl_array_t *callbacks = mi->callbacks;
if (callbacks) {
// AbstractInterpreter allows for MethodInstances to be present in non-local caches
// inform those caches about the invalidation.
JL_TRY {
size_t i, l = jl_array_len(callbacks);
jl_value_t **args;
JL_GC_PUSHARGS(args, 3);
// these arguments are constant per call
args[1] = (jl_value_t*)mi;
args[2] = jl_box_uint32(max_world);

size_t last_age = jl_get_ptls_states()->world_age;
jl_get_ptls_states()->world_age = jl_get_world_counter();

jl_value_t **cbs = (jl_value_t**)jl_array_ptr_data(callbacks);
for (i = 0; i < l; i++) {
args[0] = cbs[i];
jl_apply(args, 3);
}
jl_get_ptls_states()->world_age = last_age;
JL_GC_POP();
}
JL_CATCH {
jl_printf((JL_STREAM*)STDERR_FILENO, "error in invalidation callback: ");
jl_static_show((JL_STREAM*)STDERR_FILENO, jl_current_exception());
jl_printf((JL_STREAM*)STDERR_FILENO, "\n");
jlbacktrace(); // writen to STDERR_FILENO
}
}
}

// recursively invalidate cached methods that had an edge to a replaced method
static void invalidate_method_instance(jl_method_instance_t *replaced, size_t max_world, int depth)
{
Expand Down Expand Up @@ -1526,6 +1560,7 @@ static void jl_method_table_invalidate(jl_methtable_t *mt, jl_typemap_entry_t *m
jl_method_instance_t *mi = (jl_method_instance_t*)jl_svecref(specializations, i);
if (mi) {
invalidated = 1;
invalidate_external(mi, methodentry->max_world);
invalidate_backedges(mi, methodentry->max_world, "jl_method_table_disable");
}
}
Expand Down Expand Up @@ -1644,6 +1679,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method
}
if (isect != jl_bottom_type) {
jl_method_instance_t *backedge = (jl_method_instance_t*)backedges[i];
invalidate_external(backedge, max_world);
invalidate_method_instance(backedge, max_world, 0);
invalidated = 1;
if (_jl_debug_method_invalidation)
Expand Down Expand Up @@ -1711,6 +1747,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method
continue;
}
jl_array_ptr_1d_push(oldmi, (jl_value_t*)mi);
invalidate_external(mi, max_world);
if (mi->backedges) {
invalidated = 1;
invalidate_backedges(mi, max_world, "jl_method_table_insert");
Expand Down
8 changes: 5 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2337,21 +2337,23 @@ void jl_init_types(void) JL_GC_DISABLED
jl_method_instance_type =
jl_new_datatype(jl_symbol("MethodInstance"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(7,
jl_perm_symsvec(8,
"def",
"specTypes",
"sparam_vals",
"uninferred",
"backedges",
"callbacks",
"cache",
"inInference"),
jl_svec(7,
jl_svec(8,
jl_new_struct(jl_uniontype_type, jl_method_type, jl_module_type),
jl_any_type,
jl_simplevector_type,
jl_any_type,
jl_any_type,
jl_any_type,
jl_any_type,
jl_bool_type),
0, 1, 3);

Expand Down Expand Up @@ -2506,7 +2508,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svecset(jl_methtable_type->types, 10, jl_uint8_type);
jl_svecset(jl_methtable_type->types, 11, jl_uint8_type);
jl_svecset(jl_method_type->types, 11, jl_method_instance_type);
jl_svecset(jl_method_instance_type->types, 5, jl_code_instance_type);
jl_svecset(jl_method_instance_type->types, 6, jl_code_instance_type);
jl_svecset(jl_code_instance_type->types, 9, jl_voidpointer_type);
jl_svecset(jl_code_instance_type->types, 10, jl_voidpointer_type);

Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ struct _jl_method_instance_t {
jl_svec_t *sparam_vals; // static parameter values, indexed by def.method->sparam_syms
jl_value_t *uninferred; // cached uncompressed code, for generated functions, top-level thunks, or the interpreter
jl_array_t *backedges; // list of method-instances which contain a call into this method-instance
jl_array_t *callbacks; // list of callback functions to inform external caches about invalidations
struct _jl_code_instance_t *cache;
uint8_t inInference; // flags to tell if inference is running on this object
};
Expand Down
1 change: 1 addition & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ JL_DLLEXPORT jl_method_instance_t *jl_new_method_instance_uninit(void)
li->sparam_vals = jl_emptysvec;
li->uninferred = NULL;
li->backedges = NULL;
li->callbacks = NULL;
li->cache = NULL;
li->inInference = 0;
return li;
Expand Down

0 comments on commit bdcaee0

Please sign in to comment.