Skip to content

Commit

Permalink
Add relocatable root compression (#43881)
Browse files Browse the repository at this point in the history
Currently we can't cache "external" CodeInstances, i.e., those generated
by compiling other modules' methods with externally-defined types.
For example, consider `push!([], MyPkg.MyType())`: Base owns
the method `push!(::Vector{Any}, ::Any)` but doesn't know about `MyType`.

While there are several obstacles to caching exteral CodeInstances,
the primary one is that in compressed IR, method roots are referenced
from a list by index, and the index is defined by order of insertion.
This order might change depending on package-loading sequence or other
history-dependent factors. If the order isn't consistent, our current
serialization techniques would result in corrupted code upon
decompression, and that would generally trigger catastrophic
failure. To avoid this problem, we simply avoid caching such
CodeInstances.

This enables roots to be referenced with respect to a `(key, index)`
pair, where `key` identifies the module and `index` numbers just those
roots with the same `key`. Roots with `key = 0` are considered to be
of unknown origin, and CodeInstances referencing such roots will remain
unserializable unless all such roots were added at the time of system
image creation.  To track this additional data, this adds two fields
to core types:

- to methods, it adds a `nroots_sysimg` field to count the number
  of roots defined at the time of writing the system image
  (such occur first in the list of `roots`)
- to CodeInstances, it adds a flag `relocatability` having value 1
  if every root is "safe," meaning it was either added at sysimg
  creation or is tagged with a non-zero `key`. Even a single
  unsafe root will cause this to have value 0.
  • Loading branch information
timholy committed Jan 30, 2022
1 parent 11bac43 commit 4abf26e
Show file tree
Hide file tree
Showing 19 changed files with 341 additions and 33 deletions.
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ eval(Core, :(LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line
$(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at))))
eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32,
min_world::UInt, max_world::UInt) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
min_world::UInt, max_world::UInt, relocatability::UInt8) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt, UInt8),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world, relocatability)))
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source))))
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end

function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
valid_worlds::WorldRange)
valid_worlds::WorldRange, relocatability::UInt8)
local const_flags::Int32
result_type = result.result
@assert !(result_type isa LimitedAccuracy)
Expand Down Expand Up @@ -310,7 +310,7 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
end
return CodeInstance(result.linfo,
widenconst(result_type), rettype_const, inferred_result,
const_flags, first(valid_worlds), last(valid_worlds))
const_flags, first(valid_worlds), last(valid_worlds), relocatability)
end

# For the NativeInterpreter, we don't need to do an actual cache query to know
Expand Down Expand Up @@ -384,7 +384,8 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult)
# TODO: also don't store inferred code if we've previously decided to interpret this function
if !already_inferred
inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src)
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds)
relocatability = isa(inferred_result, Vector{UInt8}) ? inferred_result[end] : UInt8(0)
code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds, relocatability)
end
unlock_mi_inference(interp, linfo)
nothing
Expand Down
5 changes: 4 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7795,8 +7795,11 @@ jl_compile_result_t jl_emit_codeinst(
jl_options.debug_level > 1) {
// update the stored code
if (codeinst->inferred != (jl_value_t*)src) {
if (jl_is_method(def))
if (jl_is_method(def)) {
src = (jl_code_info_t*)jl_compress_ir(def, src);
assert(jl_typeis(src, jl_array_uint8_type));
codeinst->relocatability = ((uint8_t*)jl_array_data(src))[jl_array_len(src)-1];
}
codeinst->inferred = (jl_value_t*)src;
jl_gc_wb(codeinst, src);
}
Expand Down
1 change: 0 additions & 1 deletion src/common_symbols1.inc
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,3 @@ jl_symbol("undef"),
jl_symbol("sizeof"),
jl_symbol("String"),
jl_symbol("namedtuple.jl"),
jl_symbol("pop"),
2 changes: 1 addition & 1 deletion src/common_symbols2.inc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
jl_symbol("pop"),
jl_symbol("inbounds"),
jl_symbol("strings/string.jl"),
jl_symbol("Ref"),
Expand Down Expand Up @@ -251,4 +252,3 @@ jl_symbol("GitError"),
jl_symbol("zeros"),
jl_symbol("InexactError"),
jl_symbol("LogLevel"),
jl_symbol("between"),
4 changes: 4 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_
jl_serialize_value(s, NULL);
jl_serialize_value(s, jl_any_type);
}
write_uint8(s->s, codeinst->relocatability);
jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque);
}

Expand Down Expand Up @@ -705,6 +706,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
jl_serialize_value(s, (jl_value_t*)m->roots);
jl_serialize_value(s, (jl_value_t*)m->root_blocks);
write_int32(s->s, m->nroots_sysimg);
jl_serialize_value(s, (jl_value_t*)m->ccallable);
jl_serialize_value(s, (jl_value_t*)m->source);
jl_serialize_value(s, (jl_value_t*)m->unspecialized);
Expand Down Expand Up @@ -1577,6 +1579,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->root_blocks = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->root_blocks);
if (m->root_blocks)
jl_gc_wb(m, m->root_blocks);
m->nroots_sysimg = read_int32(s->s);
m->ccallable = (jl_svec_t*)jl_deserialize_value(s, (jl_value_t**)&m->ccallable);
if (m->ccallable) {
jl_gc_wb(m, m->ccallable);
Expand Down Expand Up @@ -1661,6 +1664,7 @@ static jl_value_t *jl_deserialize_value_code_instance(jl_serializer_state *s, jl
codeinst->invoke = jl_fptr_const_return;
if ((flags >> 3) & 1)
codeinst->precompile = 1;
codeinst->relocatability = read_uint8(s->s);
codeinst->next = (jl_code_instance_t*)jl_deserialize_value(s, (jl_value_t**)&codeinst->next);
jl_gc_wb(codeinst, codeinst->next);
if (validate) {
Expand Down
15 changes: 8 additions & 7 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_value_t *typ
JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
int32_t const_flags, size_t min_world, size_t max_world);
int32_t const_flags, size_t min_world, size_t max_world, uint8_t relocatability);
JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);

Expand Down Expand Up @@ -243,7 +243,7 @@ jl_datatype_t *jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_a

jl_code_instance_t *codeinst = jl_new_codeinst(mi,
(jl_value_t*)jl_any_type, jl_nothing, jl_nothing,
0, 1, ~(size_t)0);
0, 1, ~(size_t)0, 0);
jl_mi_cache_insert(mi, codeinst);
codeinst->specptr.fptr1 = fptr;
codeinst->invoke = jl_fptr_args;
Expand Down Expand Up @@ -366,15 +366,15 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
}
codeinst = jl_new_codeinst(
mi, rettype, NULL, NULL,
0, min_world, max_world);
0, min_world, max_world, 0);
jl_mi_cache_insert(mi, codeinst);
return codeinst;
}

JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
int32_t const_flags, size_t min_world, size_t max_world
int32_t const_flags, size_t min_world, size_t max_world, uint8_t relocatability
/*, jl_array_t *edges, int absolute_max*/)
{
jl_task_t *ct = jl_current_task;
Expand All @@ -399,6 +399,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
codeinst->isspecsig = 0;
codeinst->precompile = 0;
codeinst->next = NULL;
codeinst->relocatability = relocatability;
return codeinst;
}

Expand Down Expand Up @@ -2008,7 +2009,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
if (unspec && jl_atomic_load_relaxed(&unspec->invoke)) {
jl_code_instance_t *codeinst = jl_new_codeinst(mi,
(jl_value_t*)jl_any_type, NULL, NULL,
0, 1, ~(size_t)0);
0, 1, ~(size_t)0, 0);
codeinst->isspecsig = 0;
codeinst->specptr = unspec->specptr;
codeinst->rettype_const = unspec->rettype_const;
Expand All @@ -2026,7 +2027,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
if (!jl_code_requires_compiler(src)) {
jl_code_instance_t *codeinst = jl_new_codeinst(mi,
(jl_value_t*)jl_any_type, NULL, NULL,
0, 1, ~(size_t)0);
0, 1, ~(size_t)0, 0);
codeinst->invoke = jl_fptr_interpret_call;
jl_mi_cache_insert(mi, codeinst);
record_precompile_statement(mi);
Expand Down Expand Up @@ -2061,7 +2062,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
return ucache;
}
codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, NULL,
0, 1, ~(size_t)0);
0, 1, ~(size_t)0, 0);
codeinst->isspecsig = 0;
codeinst->specptr = ucache->specptr;
codeinst->rettype_const = ucache->rettype_const;
Expand Down
43 changes: 34 additions & 9 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,37 @@ typedef struct {
// method we're compressing for
jl_method_t *method;
jl_ptls_t ptls;
uint8_t relocatability;
} jl_ircode_state;

// --- encoding ---

#define jl_encode_value(s, v) jl_encode_value_((s), (jl_value_t*)(v), 0)

static int literal_val_id(jl_ircode_state *s, jl_value_t *v) JL_GC_DISABLED
static void tagged_root(rle_reference *rr, jl_ircode_state *s, int i)
{
if (!get_root_reference(rr, s->method, i))
s->relocatability = 0;
}

static void literal_val_id(rle_reference *rr, jl_ircode_state *s, jl_value_t *v) JL_GC_DISABLED
{
jl_array_t *rs = s->method->roots;
int i, l = jl_array_len(rs);
if (jl_is_symbol(v) || jl_is_concrete_type(v)) {
for (i = 0; i < l; i++) {
if (jl_array_ptr_ref(rs, i) == v)
return i;
return tagged_root(rr, s, i);
}
}
else {
for (i = 0; i < l; i++) {
if (jl_egal(jl_array_ptr_ref(rs, i), v))
return i;
return tagged_root(rr, s, i);
}
}
jl_add_method_root(s->method, jl_precompile_toplevel_module, v);
return jl_array_len(rs) - 1;
return tagged_root(rr, s, jl_array_len(rs) - 1);
}

static void jl_encode_int32(jl_ircode_state *s, int32_t x)
Expand All @@ -67,6 +74,7 @@ static void jl_encode_int32(jl_ircode_state *s, int32_t x)
static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED
{
size_t i;
rle_reference rr;

if (v == NULL) {
write_uint8(s->s, TAG_NULL);
Expand Down Expand Up @@ -321,8 +329,13 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
if (!as_literal && !(jl_is_uniontype(v) || jl_is_newvarnode(v) || jl_is_tuple(v) ||
jl_is_linenode(v) || jl_is_upsilonnode(v) || jl_is_pinode(v) ||
jl_is_slot(v) || jl_is_ssavalue(v))) {
int id = literal_val_id(s, v);
literal_val_id(&rr, s, v);
int id = rr.index;
assert(id >= 0);
if (rr.key) {
write_uint8(s->s, TAG_RELOC_METHODROOT);
write_int64(s->s, rr.key);
}
if (id < 256) {
write_uint8(s->s, TAG_METHODROOT);
write_uint8(s->s, id);
Expand Down Expand Up @@ -577,6 +590,7 @@ static jl_value_t *jl_decode_value(jl_ircode_state *s) JL_GC_DISABLED
assert(!ios_eof(s->s));
jl_value_t *v;
size_t i, n;
uint64_t key;
uint8_t tag = read_uint8(s->s);
if (tag > LAST_TAG)
return jl_deser_tag(tag);
Expand All @@ -585,10 +599,15 @@ static jl_value_t *jl_decode_value(jl_ircode_state *s) JL_GC_DISABLED
case 0:
tag = read_uint8(s->s);
return jl_deser_tag(tag);
case TAG_RELOC_METHODROOT:
key = read_uint64(s->s);
tag = read_uint8(s->s);
assert(tag == TAG_METHODROOT || tag == TAG_LONG_METHODROOT);
return lookup_root(s->method, key, tag == TAG_METHODROOT ? read_uint8(s->s) : read_uint16(s->s));
case TAG_METHODROOT:
return jl_array_ptr_ref(s->method->roots, read_uint8(s->s));
return lookup_root(s->method, 0, read_uint8(s->s));
case TAG_LONG_METHODROOT:
return jl_array_ptr_ref(s->method->roots, read_uint16(s->s));
return lookup_root(s->method, 0, read_uint16(s->s));
case TAG_SVEC: JL_FALLTHROUGH; case TAG_LONG_SVEC:
return jl_decode_value_svec(s, tag);
case TAG_COMMONSYM:
Expand Down Expand Up @@ -706,7 +725,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
jl_ircode_state s = {
&dest,
m,
jl_current_task->ptls
jl_current_task->ptls,
1
};

jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inlineable, code->inferred, code->constprop);
Expand Down Expand Up @@ -756,6 +776,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
ios_write(s.s, (char*)jl_array_data(code->codelocs), nstmt * sizeof(int32_t));
}

write_uint8(s.s, s.relocatability);

ios_flush(s.s);
jl_array_t *v = jl_take_buffer(&dest);
ios_close(s.s);
Expand Down Expand Up @@ -786,7 +808,8 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
jl_ircode_state s = {
&src,
m,
jl_current_task->ptls
jl_current_task->ptls,
1
};

jl_code_info_t *code = jl_new_code_info_uninit();
Expand Down Expand Up @@ -831,6 +854,8 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
ios_readall(s.s, (char*)jl_array_data(code->codelocs), nstmt * sizeof(int32_t));
}

(void) read_uint8(s.s); // relocatability

assert(ios_getc(s.s) == -1);
ios_close(s.s);
JL_GC_PUSH1(&code);
Expand Down
16 changes: 10 additions & 6 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2393,7 +2393,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(27,
jl_perm_symsvec(28,
"name",
"module",
"file",
Expand All @@ -2410,6 +2410,7 @@ void jl_init_types(void) JL_GC_DISABLED
"generator", // !const
"roots", // !const
"root_blocks", // !const
"nroots_sysimg",
"ccallable", // !const
"invokes", // !const
"recursion_relation", // !const
Expand All @@ -2421,7 +2422,7 @@ void jl_init_types(void) JL_GC_DISABLED
"pure",
"is_for_opaque_closure",
"constprop"),
jl_svec(27,
jl_svec(28,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
Expand All @@ -2438,6 +2439,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_any_type,
jl_array_any_type,
jl_array_uint64_type,
jl_int32_type,
jl_simplevector_type,
jl_any_type,
jl_any_type,
Expand Down Expand Up @@ -2483,7 +2485,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_instance_type =
jl_new_datatype(jl_symbol("CodeInstance"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(11,
jl_perm_symsvec(12,
"def",
"next",
"min_world",
Expand All @@ -2493,8 +2495,9 @@ void jl_init_types(void) JL_GC_DISABLED
"inferred",
//"edges",
//"absolute_max",
"isspecsig", "precompile", "invoke", "specptr"), // function object decls
jl_svec(11,
"isspecsig", "precompile", "invoke", "specptr", // function object decls
"relocatability"),
jl_svec(12,
jl_method_instance_type,
jl_any_type,
jl_ulong_type,
Expand All @@ -2506,7 +2509,8 @@ void jl_init_types(void) JL_GC_DISABLED
//jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_any_type, jl_any_type), // fptrs
jl_any_type, jl_any_type, // fptrs
jl_uint8_type),
jl_emptysvec,
0, 1, 1);
jl_svecset(jl_code_instance_type->types, 1, jl_code_instance_type);
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ typedef struct _jl_method_t {
// Identify roots by module-of-origin. We only track the module for roots added during incremental compilation.
// May be NULL if no external roots have been added, otherwise it's a Vector{UInt64}
jl_array_t *root_blocks; // RLE (build_id, offset) pairs (even/odd indexing)
int32_t nroots_sysimg; // # of roots stored in the system image
jl_svec_t *ccallable; // svec(rettype, sig) if a ccallable entry point is requested for this

// cache of specializations of this method for invoke(), i.e.
Expand Down Expand Up @@ -381,6 +382,7 @@ typedef struct _jl_code_instance_t {
_Atomic(jl_fptr_sparam_t) fptr3;
// 4 interpreter
} specptr; // private data for `jlcall entry point
uint8_t relocatability; // nonzero if all roots are built into sysimg or tagged by module key
} jl_code_instance_t;

// all values are callable as Functions
Expand Down
3 changes: 3 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "support/ptrhash.h"
#include "support/strtod.h"
#include "gc-alloc-profiler.h"
#include "support/rle.h"
#include <uv.h>
#if !defined(_WIN32)
#include <unistd.h>
Expand Down Expand Up @@ -528,6 +529,8 @@ void jl_resolve_globals_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *spar
int binding_effects);

JL_DLLEXPORT void jl_add_method_root(jl_method_t *m, jl_module_t *mod, jl_value_t* root);
int get_root_reference(rle_reference *rr, jl_method_t *m, size_t i);
jl_value_t *lookup_root(jl_method_t *m, uint64_t key, int index);

int jl_valid_type_param(jl_value_t *v);

Expand Down
Loading

0 comments on commit 4abf26e

Please sign in to comment.