Skip to content

Commit

Permalink
codegen: manage gc-safe-region implicitly in cfunction (#45550)
Browse files Browse the repository at this point in the history
This makes cfunction safe to call from anywhere, including unmanaged
code callbacks.
  • Loading branch information
vtjnash committed Jun 2, 2022
1 parent 86e3f0a commit 8bfb42a
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 43 deletions.
6 changes: 2 additions & 4 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ GlobalVariable *jl_emit_RTLD_DEFAULT_var(Module *M)
return prepare_global_in(M, jlRTLD_DEFAULT_var);
}


// Find or create the GVs for the library and symbol lookup.
// Return `runtime_lib` (whether the library name is a string)
// The `lib` and `sym` GV returned may not be in the current module.
Expand Down Expand Up @@ -1546,10 +1547,7 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
assert(lrt == getVoidTy(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 0);
JL_GC_POP();
ctx.builder.CreateCall(prepare_call(gcroot_flush_func));
emit_signal_fence(ctx);
ctx.builder.CreateLoad(getSizeTy(ctx.builder.getContext()), get_current_signal_page(ctx), true);
emit_signal_fence(ctx);
emit_gc_safepoint(ctx);
return ghostValue(ctx, jl_nothing_type);
}
else if (is_libjulia_func("jl_get_ptls_states")) {
Expand Down
68 changes: 66 additions & 2 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3802,12 +3802,76 @@ static Value *emit_defer_signal(jl_codectx_t &ctx)
{
++EmittedDeferSignal;
Value *ptls = emit_bitcast(ctx, get_current_ptls(ctx),
PointerType::get(ctx.types().T_sigatomic, 0));
PointerType::get(ctx.types().T_sigatomic, 0));
Constant *offset = ConstantInt::getSigned(getInt32Ty(ctx.builder.getContext()),
offsetof(jl_tls_states_t, defer_signal) / sizeof(sig_atomic_t));
offsetof(jl_tls_states_t, defer_signal) / sizeof(sig_atomic_t));
return ctx.builder.CreateInBoundsGEP(ctx.types().T_sigatomic, ptls, ArrayRef<Value*>(offset), "jl_defer_signal");
}

static void emit_gc_safepoint(jl_codectx_t &ctx)
{
ctx.builder.CreateCall(prepare_call(gcroot_flush_func));
emit_signal_fence(ctx);
ctx.builder.CreateLoad(getSizeTy(ctx.builder.getContext()), get_current_signal_page(ctx), true);
emit_signal_fence(ctx);
}

static Value *emit_gc_state_set(jl_codectx_t &ctx, Value *state, Value *old_state)
{
Type *T_int8 = state->getType();
Value *ptls = emit_bitcast(ctx, get_current_ptls(ctx), getInt8PtrTy(ctx.builder.getContext()));
Constant *offset = ConstantInt::getSigned(getInt32Ty(ctx.builder.getContext()), offsetof(jl_tls_states_t, gc_state));
Value *gc_state = ctx.builder.CreateInBoundsGEP(T_int8, ptls, ArrayRef<Value*>(offset), "gc_state");
if (old_state == nullptr) {
old_state = ctx.builder.CreateLoad(T_int8, gc_state);
cast<LoadInst>(old_state)->setOrdering(AtomicOrdering::Monotonic);
}
ctx.builder.CreateAlignedStore(state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
if (auto *C = dyn_cast<ConstantInt>(old_state))
if (C->isZero())
return old_state;
if (auto *C = dyn_cast<ConstantInt>(state))
if (!C->isZero())
return old_state;
BasicBlock *passBB = BasicBlock::Create(ctx.builder.getContext(), "safepoint", ctx.f);
BasicBlock *exitBB = BasicBlock::Create(ctx.builder.getContext(), "after_safepoint", ctx.f);
Constant *zero8 = ConstantInt::get(T_int8, 0);
ctx.builder.CreateCondBr(ctx.builder.CreateAnd(ctx.builder.CreateICmpNE(old_state, zero8), // if (old_state && !state)
ctx.builder.CreateICmpEQ(state, zero8)),
passBB, exitBB);
ctx.builder.SetInsertPoint(passBB);
emit_gc_safepoint(ctx);
ctx.builder.CreateBr(exitBB);
ctx.builder.SetInsertPoint(exitBB);
return old_state;
}

static Value *emit_gc_unsafe_enter(jl_codectx_t &ctx)
{
Value *state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), 0);
return emit_gc_state_set(ctx, state, nullptr);
}

static Value *emit_gc_unsafe_leave(jl_codectx_t &ctx, Value *state)
{
Value *old_state = ConstantInt::get(state->getType(), 0);
return emit_gc_state_set(ctx, state, old_state);
}

//static Value *emit_gc_safe_enter(jl_codectx_t &ctx)
//{
// Value *state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), JL_GC_STATE_SAFE);
// return emit_gc_state_set(ctx, state, nullptr);
//}
//
//static Value *emit_gc_safe_leave(jl_codectx_t &ctx, Value *state)
//{
// Value *old_state = ConstantInt::get(state->getType(), JL_GC_STATE_SAFE);
// return emit_gc_state_set(ctx, state, old_state);
//}



#ifndef JL_NDEBUG
static int compare_cgparams(const jl_cgparams_t *a, const jl_cgparams_t *b)
{
Expand Down
86 changes: 49 additions & 37 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,8 @@ static const auto pointer_from_objref_func = new JuliaFunction{

static const auto jltuple_func = new JuliaFunction{XSTR(jl_f_tuple), get_func_sig, get_func_attrs};
static const auto &builtin_func_map() {
static std::map<jl_fptr_args_t, JuliaFunction*> builtins = { { jl_f_is_addr, new JuliaFunction{XSTR(jl_f_is), get_func_sig, get_func_attrs} },
static std::map<jl_fptr_args_t, JuliaFunction*> builtins = {
{ jl_f_is_addr, new JuliaFunction{XSTR(jl_f_is), get_func_sig, get_func_attrs} },
{ jl_f_typeof_addr, new JuliaFunction{XSTR(jl_f_typeof), get_func_sig, get_func_attrs} },
{ jl_f_sizeof_addr, new JuliaFunction{XSTR(jl_f_sizeof), get_func_sig, get_func_attrs} },
{ jl_f_issubtype_addr, new JuliaFunction{XSTR(jl_f_issubtype), get_func_sig, get_func_attrs} },
Expand Down Expand Up @@ -1372,8 +1373,8 @@ class jl_codectx_t {
int nvargs = -1;
bool is_opaque_closure = false;

CallInst *pgcstack = NULL;
Value *world_age_field = NULL;
Value *pgcstack = NULL;
Instruction *topalloca = NULL;

bool debug_enabled = false;
bool use_cache = false;
Expand Down Expand Up @@ -1423,6 +1424,7 @@ static Value *emit_condition(jl_codectx_t &ctx, const jl_cgval_t &condV, const s
static void allocate_gc_frame(jl_codectx_t &ctx, BasicBlock *b0);
static Value *get_current_task(jl_codectx_t &ctx);
static Value *get_current_ptls(jl_codectx_t &ctx);
static Value *get_last_age_field(jl_codectx_t &ctx);
static Value *get_current_signal_page(jl_codectx_t &ctx);
static void CreateTrap(IRBuilder<> &irbuilder, bool create_new_block = true);
static CallInst *emit_jlcall(jl_codectx_t &ctx, Function *theFptr, Value *theF,
Expand Down Expand Up @@ -1502,7 +1504,7 @@ static GlobalVariable *get_pointer_to_constant(jl_codegen_params_t &emission_con
static AllocaInst *emit_static_alloca(jl_codectx_t &ctx, Type *lty)
{
++EmittedAllocas;
return new AllocaInst(lty, 0, "", /*InsertBefore=*/ctx.pgcstack);
return new AllocaInst(lty, 0, "", /*InsertBefore=*/ctx.topalloca);
}

static void undef_derived_strct(IRBuilder<> &irbuilder, Value *ptr, jl_datatype_t *sty, MDNode *tbaa)
Expand Down Expand Up @@ -4765,13 +4767,6 @@ static void emit_stmtpos(jl_codectx_t &ctx, jl_value_t *expr, int ssaval_result)
return;
}
else {
if (!jl_is_method(ctx.linfo->def.method) && !ctx.is_opaque_closure) {
// TODO: inference is invalid if this has any effect (which it often does)
LoadInst *world = ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()),
prepare_global_in(jl_Module, jlgetworld_global), Align(sizeof(size_t)));
world->setOrdering(AtomicOrdering::Acquire);
ctx.builder.CreateAlignedStore(world, ctx.world_age_field, Align(sizeof(size_t)));
}
assert(ssaval_result != -1);
emit_ssaval_assign(ctx, ssaval_result, expr);
}
Expand Down Expand Up @@ -5150,7 +5145,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
jl_cgval_t world_age = mark_julia_type(ctx,
tbaa_decorate(ctx.tbaa().tbaa_gcframe,
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), ctx.world_age_field, Align(sizeof(size_t)))),
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), get_last_age_field(ctx), Align(sizeof(size_t)))),
false,
jl_long_type);
jl_cgval_t fptr(ctx.builder.getContext());
Expand Down Expand Up @@ -5284,9 +5279,10 @@ JL_GCC_IGNORE_STOP
// gc frame emission
static void allocate_gc_frame(jl_codectx_t &ctx, BasicBlock *b0)
{
// TODO: requires the runtime, but is generated unconditionally
// allocate a placeholder gc instruction
ctx.pgcstack = ctx.builder.CreateCall(prepare_call(jlpgcstack_func));
// this will require the runtime, but it gets deleted later if unused
ctx.topalloca = ctx.builder.CreateCall(prepare_call(jlpgcstack_func));
ctx.pgcstack = ctx.topalloca;
}

static Value *get_current_task(jl_codectx_t &ctx)
Expand All @@ -5304,15 +5300,13 @@ static Value *get_current_ptls(jl_codectx_t &ctx)
return get_current_ptls_from_task(ctx.builder, get_current_task(ctx), ctx.tbaa().tbaa_gcframe);
}

// Store world age at the entry block of the function. This function should be
// called right after `allocate_gc_frame` and there should be no context switch.
static void emit_last_age_field(jl_codectx_t &ctx)
// Get the address of the world age of the current task
static Value *get_last_age_field(jl_codectx_t &ctx)
{
auto ptls = get_current_task(ctx);
assert(ctx.builder.GetInsertBlock() == ctx.pgcstack->getParent());
ctx.world_age_field = ctx.builder.CreateInBoundsGEP(
Value *ct = get_current_task(ctx);
return ctx.builder.CreateInBoundsGEP(
getSizeTy(ctx.builder.getContext()),
ctx.builder.CreateBitCast(ptls, getSizePtrTy(ctx.builder.getContext())),
ctx.builder.CreateBitCast(ct, getSizePtrTy(ctx.builder.getContext())),
ConstantInt::get(getSizeTy(ctx.builder.getContext()), offsetof(jl_task_t, world_age) / sizeof(size_t)),
"world_age");
}
Expand All @@ -5321,7 +5315,7 @@ static void emit_last_age_field(jl_codectx_t &ctx)
static Value *get_current_signal_page(jl_codectx_t &ctx)
{
// return ctx.builder.CreateCall(prepare_call(reuse_signal_page_func));
auto ptls = get_current_ptls(ctx);
Value *ptls = get_current_ptls(ctx);
int nthfield = offsetof(jl_tls_states_t, safepoint) / sizeof(void *);
return emit_nthptr_recast(ctx, ptls, nthfield, ctx.tbaa().tbaa_const, getSizePtrTy(ctx.builder.getContext()));
}
Expand Down Expand Up @@ -5612,14 +5606,19 @@ static Function* gen_cfun_wrapper(
DebugLoc noDbg;
ctx.builder.SetCurrentDebugLocation(noDbg);
allocate_gc_frame(ctx, b0);
emit_last_age_field(ctx);

Value *dummy_world = ctx.builder.CreateAlloca(getSizeTy(ctx.builder.getContext()));
Value *have_tls = ctx.builder.CreateIsNotNull(ctx.pgcstack);
// TODO: in the future, try to initialize a full TLS context here
// for now, just use a dummy field to avoid a branch in this function
ctx.world_age_field = ctx.builder.CreateSelect(have_tls, ctx.world_age_field, dummy_world);
Value *last_age = tbaa_decorate(ctx.tbaa().tbaa_gcframe, ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), ctx.world_age_field, Align(sizeof(size_t))));
// TODO: in the future, initialize a full TLS context here
Value *world_age_field = get_last_age_field(ctx);
world_age_field = ctx.builder.CreateSelect(have_tls, world_age_field, dummy_world);
Value *last_age = tbaa_decorate(ctx.tbaa().tbaa_gcframe,
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), world_age_field, Align(sizeof(size_t))));
Value *last_gc_state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), JL_GC_STATE_SAFE);
last_gc_state = emit_guarded_test(ctx, have_tls, last_gc_state, [&] {
return emit_gc_unsafe_enter(ctx);
});

Value *world_v = ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()),
prepare_global_in(jl_Module, jlgetworld_global), Align(sizeof(size_t)));
cast<LoadInst>(world_v)->setOrdering(AtomicOrdering::Acquire);
Expand All @@ -5640,7 +5639,7 @@ static Function* gen_cfun_wrapper(
world_v = ctx.builder.CreateSelect(ctx.builder.CreateOr(have_tls, age_ok), world_v, lam_max);
age_ok = ctx.builder.CreateOr(ctx.builder.CreateNot(have_tls), age_ok);
}
ctx.builder.CreateStore(world_v, ctx.world_age_field);
ctx.builder.CreateStore(world_v, world_age_field);

// first emit code to record the arguments
Function::arg_iterator AI = cw->arg_begin();
Expand Down Expand Up @@ -5996,7 +5995,13 @@ static Function* gen_cfun_wrapper(
r = NULL;
}

ctx.builder.CreateStore(last_age, ctx.world_age_field);
ctx.builder.CreateStore(last_age, world_age_field);
if (!sig.retboxed) {
emit_guarded_test(ctx, have_tls, nullptr, [&] {
emit_gc_unsafe_leave(ctx, last_gc_state);
return nullptr;
});
}
ctx.builder.CreateRet(r);

ctx.builder.SetCurrentDebugLocation(noDbg);
Expand Down Expand Up @@ -6921,10 +6926,10 @@ static jl_llvm_functions_t
// step 6. set up GC frame
allocate_gc_frame(ctx, b0);
Value *last_age = NULL;
emit_last_age_field(ctx);
Value *world_age_field = get_last_age_field(ctx);
if (toplevel || ctx.is_opaque_closure) {
last_age = tbaa_decorate(ctx.tbaa().tbaa_gcframe, ctx.builder.CreateAlignedLoad(
getSizeTy(ctx.builder.getContext()), ctx.world_age_field, Align(sizeof(size_t))));
getSizeTy(ctx.builder.getContext()), world_age_field, Align(sizeof(size_t))));
}

// step 7. allocate local variables slots
Expand Down Expand Up @@ -6969,10 +6974,10 @@ static jl_llvm_functions_t
Type *vtype = julia_type_to_llvm(ctx, jt, &isboxed);
assert(!isboxed);
assert(!type_is_ghost(vtype) && "constants should already be handled");
Value *lv = new AllocaInst(vtype, 0, jl_symbol_name(s), /*InsertBefore*/ctx.pgcstack);
Value *lv = new AllocaInst(vtype, 0, jl_symbol_name(s), /*InsertBefore*/ctx.topalloca);
if (CountTrackedPointers(vtype).count) {
StoreInst *SI = new StoreInst(Constant::getNullValue(vtype), lv, false, Align(sizeof(void*)));
SI->insertAfter(ctx.pgcstack);
SI->insertAfter(ctx.topalloca);
}
varinfo.value = mark_julia_slot(lv, jt, NULL, ctx.tbaa(), ctx.tbaa().tbaa_stack);
alloc_def_flag(ctx, varinfo);
Expand All @@ -6989,9 +6994,9 @@ static jl_llvm_functions_t
(va && (int)i == ctx.vaSlot) || // or it's the va arg tuple
i == 0) { // or it is the first argument (which isn't in `argArray`)
AllocaInst *av = new AllocaInst(ctx.types().T_prjlvalue, 0,
jl_symbol_name(s), /*InsertBefore*/ctx.pgcstack);
jl_symbol_name(s), /*InsertBefore*/ctx.topalloca);
StoreInst *SI = new StoreInst(Constant::getNullValue(ctx.types().T_prjlvalue), av, false, Align(sizeof(void*)));
SI->insertAfter(ctx.pgcstack);
SI->insertAfter(ctx.topalloca);
varinfo.boxroot = av;
if (ctx.debug_enabled && varinfo.dinfo) {
DIExpression *expr;
Expand Down Expand Up @@ -7149,7 +7154,7 @@ static jl_llvm_functions_t

jl_cgval_t closure_world = typed_load(ctx, worldaddr, NULL, (jl_value_t*)jl_long_type,
theArg.tbaa, nullptr, false, AtomicOrdering::NotAtomic, false, sizeof(size_t));
emit_unbox(ctx, getSizeTy(ctx.builder.getContext()), closure_world, (jl_value_t*)jl_long_type, ctx.world_age_field, ctx.tbaa().tbaa_gcframe);
emit_unbox(ctx, getSizeTy(ctx.builder.getContext()), closure_world, (jl_value_t*)jl_long_type, world_age_field, ctx.tbaa().tbaa_gcframe);

// Load closure env
Value *envaddr = ctx.builder.CreateInBoundsGEP(
Expand Down Expand Up @@ -7624,7 +7629,7 @@ static jl_llvm_functions_t

mallocVisitStmt(debuginfoloc, sync_bytes);
if (toplevel || ctx.is_opaque_closure)
ctx.builder.CreateStore(last_age, ctx.world_age_field);
ctx.builder.CreateStore(last_age, world_age_field);
assert(type_is_ghost(retty) || returninfo.cc == jl_returninfo_t::SRet ||
retval->getType() == ctx.f->getReturnType());
ctx.builder.CreateRet(retval);
Expand Down Expand Up @@ -7684,6 +7689,13 @@ static jl_llvm_functions_t
ctx.builder.SetInsertPoint(tryblk);
}
else {
if (!jl_is_method(ctx.linfo->def.method) && !ctx.is_opaque_closure) {
// TODO: inference is invalid if this has any effect (which it often does)
LoadInst *world = ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()),
prepare_global_in(jl_Module, jlgetworld_global), Align(sizeof(size_t)));
world->setOrdering(AtomicOrdering::Acquire);
ctx.builder.CreateAlignedStore(world, world_age_field, Align(sizeof(size_t)));
}
emit_stmtpos(ctx, stmt, cursor);
mallocVisitStmt(debuginfoloc, nullptr);
}
Expand Down

0 comments on commit 8bfb42a

Please sign in to comment.