From bdcaee033e8913459b340a2116e667770b4f0bc7 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 23 Nov 2020 22:27:08 -0500 Subject: [PATCH] [AbstractInterpreter] Implement callbacks for cache invalidations (#38370) AbstractInterpreter allows down-stream compiler variants to implement their own caching of inference results. This leads to the issue of that invalidations of methods are not propagated to those caches, leaving downstream implementations with their own variant of 265, with limited ability to mitigate it. This PR extends the invalidation scheme by allowing downstream implementation to attach invalidation callbacks to MethodInstances. MethodInstances are used as the key to cache and can therefore be used to walk and invalidate the user cache. Co-authored-by: Jameson Nash --- src/dump.c | 4 ++++ src/gf.c | 37 +++++++++++++++++++++++++++++++++++++ src/jltypes.c | 8 +++++--- src/julia.h | 1 + src/method.c | 1 + 5 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/dump.c b/src/dump.c index 7e9ba2f34d79b..654d47f35a18a 100644 --- a/src/dump.c +++ b/src/dump.c @@ -641,6 +641,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li backedges = NULL; } jl_serialize_value(s, (jl_value_t*)backedges); + jl_serialize_value(s, (jl_value_t*)NULL); //callbacks jl_serialize_value(s, (jl_value_t*)mi->cache); } else if (jl_is_code_instance(v)) { @@ -1491,6 +1492,9 @@ static jl_value_t *jl_deserialize_value_method_instance(jl_serializer_state *s, mi->backedges = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&mi->backedges); if (mi->backedges) jl_gc_wb(mi, mi->backedges); + mi->callbacks = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&mi->callbacks); + if (mi->callbacks) + jl_gc_wb(mi, mi->callbacks); mi->cache = (jl_code_instance_t*)jl_deserialize_value(s, (jl_value_t**)&mi->cache); if (mi->cache) jl_gc_wb(mi, mi->cache); diff --git a/src/gf.c b/src/gf.c index 3aa6314fdfb36..d639637105891 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1318,6 +1318,40 @@ JL_DLLEXPORT jl_value_t *jl_debug_method_invalidation(int state) return jl_nothing; } +// call external callbacks registered with this method_instance +static void invalidate_external(jl_method_instance_t *mi, size_t max_world) { + jl_array_t *callbacks = mi->callbacks; + if (callbacks) { + // AbstractInterpreter allows for MethodInstances to be present in non-local caches + // inform those caches about the invalidation. + JL_TRY { + size_t i, l = jl_array_len(callbacks); + jl_value_t **args; + JL_GC_PUSHARGS(args, 3); + // these arguments are constant per call + args[1] = (jl_value_t*)mi; + args[2] = jl_box_uint32(max_world); + + size_t last_age = jl_get_ptls_states()->world_age; + jl_get_ptls_states()->world_age = jl_get_world_counter(); + + jl_value_t **cbs = (jl_value_t**)jl_array_ptr_data(callbacks); + for (i = 0; i < l; i++) { + args[0] = cbs[i]; + jl_apply(args, 3); + } + jl_get_ptls_states()->world_age = last_age; + JL_GC_POP(); + } + JL_CATCH { + jl_printf((JL_STREAM*)STDERR_FILENO, "error in invalidation callback: "); + jl_static_show((JL_STREAM*)STDERR_FILENO, jl_current_exception()); + jl_printf((JL_STREAM*)STDERR_FILENO, "\n"); + jlbacktrace(); // writen to STDERR_FILENO + } + } +} + // recursively invalidate cached methods that had an edge to a replaced method static void invalidate_method_instance(jl_method_instance_t *replaced, size_t max_world, int depth) { @@ -1526,6 +1560,7 @@ static void jl_method_table_invalidate(jl_methtable_t *mt, jl_typemap_entry_t *m jl_method_instance_t *mi = (jl_method_instance_t*)jl_svecref(specializations, i); if (mi) { invalidated = 1; + invalidate_external(mi, methodentry->max_world); invalidate_backedges(mi, methodentry->max_world, "jl_method_table_disable"); } } @@ -1644,6 +1679,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method } if (isect != jl_bottom_type) { jl_method_instance_t *backedge = (jl_method_instance_t*)backedges[i]; + invalidate_external(backedge, max_world); invalidate_method_instance(backedge, max_world, 0); invalidated = 1; if (_jl_debug_method_invalidation) @@ -1711,6 +1747,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method continue; } jl_array_ptr_1d_push(oldmi, (jl_value_t*)mi); + invalidate_external(mi, max_world); if (mi->backedges) { invalidated = 1; invalidate_backedges(mi, max_world, "jl_method_table_insert"); diff --git a/src/jltypes.c b/src/jltypes.c index 94eef9c50d9ce..a5d82dc87dc50 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2337,21 +2337,23 @@ void jl_init_types(void) JL_GC_DISABLED jl_method_instance_type = jl_new_datatype(jl_symbol("MethodInstance"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(7, + jl_perm_symsvec(8, "def", "specTypes", "sparam_vals", "uninferred", "backedges", + "callbacks", "cache", "inInference"), - jl_svec(7, + jl_svec(8, jl_new_struct(jl_uniontype_type, jl_method_type, jl_module_type), jl_any_type, jl_simplevector_type, jl_any_type, jl_any_type, jl_any_type, + jl_any_type, jl_bool_type), 0, 1, 3); @@ -2506,7 +2508,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_svecset(jl_methtable_type->types, 10, jl_uint8_type); jl_svecset(jl_methtable_type->types, 11, jl_uint8_type); jl_svecset(jl_method_type->types, 11, jl_method_instance_type); - jl_svecset(jl_method_instance_type->types, 5, jl_code_instance_type); + jl_svecset(jl_method_instance_type->types, 6, jl_code_instance_type); jl_svecset(jl_code_instance_type->types, 9, jl_voidpointer_type); jl_svecset(jl_code_instance_type->types, 10, jl_voidpointer_type); diff --git a/src/julia.h b/src/julia.h index d4e03eb282aba..a3addc636cbb6 100644 --- a/src/julia.h +++ b/src/julia.h @@ -347,6 +347,7 @@ struct _jl_method_instance_t { jl_svec_t *sparam_vals; // static parameter values, indexed by def.method->sparam_syms jl_value_t *uninferred; // cached uncompressed code, for generated functions, top-level thunks, or the interpreter jl_array_t *backedges; // list of method-instances which contain a call into this method-instance + jl_array_t *callbacks; // list of callback functions to inform external caches about invalidations struct _jl_code_instance_t *cache; uint8_t inInference; // flags to tell if inference is running on this object }; diff --git a/src/method.c b/src/method.c index 5f4a954f882b8..6ba2600f618b8 100644 --- a/src/method.c +++ b/src/method.c @@ -314,6 +314,7 @@ JL_DLLEXPORT jl_method_instance_t *jl_new_method_instance_uninit(void) li->sparam_vals = jl_emptysvec; li->uninferred = NULL; li->backedges = NULL; + li->callbacks = NULL; li->cache = NULL; li->inInference = 0; return li;