Skip to content

Commit

Permalink
Remove jl_LLVMContext (#43827)
Browse files Browse the repository at this point in the history
  • Loading branch information
pchintalapudi committed Mar 2, 2022
1 parent 1a9ad0a commit 12286e0
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 49 deletions.
15 changes: 8 additions & 7 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
// all reachable & inferrrable functions. The `policy` flag switches between the default
// mode `0`, the extern mode `1`, and imaging mode `2`.
extern "C" JL_DLLEXPORT
void *jl_create_native_impl(jl_array_t *methods, const jl_cgparams_t *cgparams, int _policy)
void *jl_create_native_impl(jl_array_t *methods, LLVMContextRef llvmctxt, const jl_cgparams_t *cgparams, int _policy)
{
if (cgparams == NULL)
cgparams = &jl_default_cgparams;
Expand All @@ -268,6 +268,7 @@ void *jl_create_native_impl(jl_array_t *methods, const jl_cgparams_t *cgparams,
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
JL_LOCK(&jl_codegen_lock);
auto &ctxt = llvmctxt ? *unwrap(llvmctxt) : *jl_ExecutionEngine->getContext().getContext();
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand All @@ -276,7 +277,7 @@ void *jl_create_native_impl(jl_array_t *methods, const jl_cgparams_t *cgparams,
CompilationPolicy policy = (CompilationPolicy) _policy;
if (policy == CompilationPolicy::ImagingMode)
imaging_mode = 1;
std::unique_ptr<Module> clone(jl_create_llvm_module("text"));
std::unique_ptr<Module> clone(jl_create_llvm_module("text", ctxt));

// compile all methods for the current world and type-inference world
size_t compile_for[] = { jl_typeinf_world, jl_atomic_load_acquire(&jl_world_counter) };
Expand All @@ -294,7 +295,7 @@ void *jl_create_native_impl(jl_array_t *methods, const jl_cgparams_t *cgparams,
jl_value_t *item = jl_array_ptr_ref(methods, i);
if (jl_is_simplevector(item)) {
if (worlds == 1)
jl_compile_extern_c(clone.get(), &params, NULL, jl_svecref(item, 0), jl_svecref(item, 1));
jl_compile_extern_c(wrap(clone.get()), &params, NULL, jl_svecref(item, 0), jl_svecref(item, 1));
continue;
}
mi = (jl_method_instance_t*)item;
Expand All @@ -309,15 +310,15 @@ void *jl_create_native_impl(jl_array_t *methods, const jl_cgparams_t *cgparams,
if (src && !emitted.count(codeinst)) {
// now add it to our compilation results
JL_GC_PROMISE_ROOTED(codeinst->rettype);
jl_compile_result_t result = jl_emit_code(mi, src, codeinst->rettype, params);
jl_compile_result_t result = jl_emit_code(mi, src, codeinst->rettype, params, ctxt);
if (std::get<0>(result))
emitted[codeinst] = std::move(result);
}
}
}

// finally, make sure all referenced methods also get compiled or fixed up
jl_compile_workqueue(emitted, params, policy);
jl_compile_workqueue(emitted, params, policy, clone->getContext());
}
JL_GC_POP();

Expand Down Expand Up @@ -967,7 +968,7 @@ llvmGetPassPluginInfo() {
// this is paired with jl_dump_function_ir, jl_dump_function_asm, jl_dump_method_asm in particular ways:
// misuse will leak memory or cause read-after-free
extern "C" JL_DLLEXPORT
void *jl_get_llvmf_defn_impl(jl_method_instance_t *mi, size_t world, char getwrapper, char optimize, const jl_cgparams_t params)
void *jl_get_llvmf_defn_impl(jl_method_instance_t *mi, LLVMContextRef ctxt, size_t world, char getwrapper, char optimize, const jl_cgparams_t params)
{
if (jl_is_method(mi->def.method) && mi->def.method->source == NULL &&
mi->def.method->generator == NULL) {
Expand Down Expand Up @@ -1019,7 +1020,7 @@ void *jl_get_llvmf_defn_impl(jl_method_instance_t *mi, size_t world, char getwra
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
compiler_start_time = jl_hrtime();
std::tie(m, decls) = jl_emit_code(mi, src, jlrettype, output);
std::tie(m, decls) = jl_emit_code(mi, src, jlrettype, output, *unwrap(ctxt));

Function *F = NULL;
if (m) {
Expand Down
4 changes: 2 additions & 2 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,9 @@ static Type *julia_type_to_llvm(jl_codectx_t &ctx, jl_value_t *jt, bool *isboxed
}

extern "C" JL_DLLEXPORT
Type *jl_type_to_llvm_impl(jl_value_t *jt, bool *isboxed)
Type *jl_type_to_llvm_impl(jl_value_t *jt, LLVMContextRef ctxt, bool *isboxed)
{
return _julia_type_to_llvm(NULL, jl_LLVMContext, jt, isboxed);
return _julia_type_to_llvm(NULL, *unwrap(ctxt), jt, isboxed);
}


Expand Down
10 changes: 6 additions & 4 deletions src/codegen-stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ JL_DLLEXPORT void jl_extern_c_fallback(jl_function_t *f, jl_value_t *rt, jl_valu
JL_DLLEXPORT jl_value_t *jl_dump_method_asm_fallback(jl_method_instance_t *linfo, size_t world,
char raw_mc, char getwrapper, const char* asm_variant, const char *debuginfo, char binary) UNAVAILABLE
JL_DLLEXPORT jl_value_t *jl_dump_function_ir_fallback(void *f, char strip_ir_metadata, char dump_module, const char *debuginfo) UNAVAILABLE
JL_DLLEXPORT void *jl_get_llvmf_defn_fallback(jl_method_instance_t *linfo, size_t world, char getwrapper, char optimize, const jl_cgparams_t params) UNAVAILABLE
JL_DLLEXPORT void *jl_get_llvmf_defn_fallback(jl_method_instance_t *linfo, LLVMContextRef ctxt, size_t world, char getwrapper, char optimize, const jl_cgparams_t params) UNAVAILABLE

JL_DLLEXPORT void *jl_LLVMCreateDisasm_fallback(const char *TripleName, void *DisInfo, int TagType, void *GetOpInfo, void *SymbolLookUp) UNAVAILABLE
JL_DLLEXPORT size_t jl_LLVMDisasmInstruction_fallback(void *DC, uint8_t *Bytes, uint64_t BytesSize, uint64_t PC, char *OutString, size_t OutStringSize) UNAVAILABLE
Expand Down Expand Up @@ -52,7 +52,7 @@ JL_DLLEXPORT uint32_t jl_get_LLVM_VERSION_fallback(void)
return 0;
}

JL_DLLEXPORT int jl_compile_extern_c_fallback(void *llvmmod, void *params, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
JL_DLLEXPORT int jl_compile_extern_c_fallback(LLVMModuleRef llvmmod, void *params, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
{
return 0;
}
Expand All @@ -74,7 +74,7 @@ JL_DLLEXPORT void jl_unlock_profile_fallback(void)
{
}

JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, const jl_cgparams_t *cgparams, int _policy) UNAVAILABLE
JL_DLLEXPORT void *jl_create_native_fallback(jl_array_t *methods, LLVMContextRef llvmctxt, const jl_cgparams_t *cgparams, int _policy) UNAVAILABLE

JL_DLLEXPORT void jl_dump_compiles_fallback(void *s)
{
Expand All @@ -92,6 +92,8 @@ JL_DLLEXPORT jl_value_t *jl_dump_fptr_asm_fallback(uint64_t fptr, char raw_mc, c

JL_DLLEXPORT jl_value_t *jl_dump_function_asm_fallback(void *F, char raw_mc, const char* asm_variant, const char *debuginfo, char binary) UNAVAILABLE

JL_DLLEXPORT LLVMContextRef jl_get_ee_context_fallback(void) UNAVAILABLE

JL_DLLEXPORT void jl_get_function_id_fallback(void *native_code, jl_code_instance_t *ncode,
int32_t *func_idx, int32_t *specfunc_idx) UNAVAILABLE

Expand All @@ -101,7 +103,7 @@ JL_DLLEXPORT void *jl_get_llvm_function_fallback(void *native_code, uint32_t idx

JL_DLLEXPORT void *jl_get_llvm_module_fallback(void *native_code) UNAVAILABLE

JL_DLLEXPORT void *jl_type_to_llvm_fallback(jl_value_t *jt, bool_t *isboxed) UNAVAILABLE
JL_DLLEXPORT void *jl_type_to_llvm_fallback(jl_value_t *jt, LLVMContextRef llvmctxt, bool_t *isboxed) UNAVAILABLE

JL_DLLEXPORT jl_value_t *jl_get_libllvm_fallback(void) JL_NOTSAFEPOINT
{
Expand Down
31 changes: 16 additions & 15 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ extern void _chkstk(void);
bool imaging_mode = false;

// shared llvm state
static LLVMContext &jl_LLVMContext = *(new LLVMContext());
TargetMachine *jl_TargetMachine;
static DataLayout &jl_data_layout = *(new DataLayout(""));
#define jl_Module ctx.f->getParent()
Expand Down Expand Up @@ -248,7 +247,7 @@ struct jl_typecache_t {
}
initialized = true;
T_ppint8 = PointerType::get(getInt8PtrTy(context), 0);
T_sigatomic = Type::getIntNTy(jl_LLVMContext, sizeof(sig_atomic_t) * 8);
T_sigatomic = Type::getIntNTy(context, sizeof(sig_atomic_t) * 8);
T_jlvalue = JuliaType::get_jlvalue_ty(context);
T_pjlvalue = PointerType::get(T_jlvalue, 0);
T_prjlvalue = PointerType::get(T_jlvalue, AddressSpace::Tracked);
Expand Down Expand Up @@ -1939,9 +1938,9 @@ Module *_jl_create_llvm_module(StringRef name, LLVMContext &context, const jl_cg
return M;
}

Module *jl_create_llvm_module(StringRef name)
Module *jl_create_llvm_module(StringRef name, LLVMContext &context)
{
return _jl_create_llvm_module(name, jl_LLVMContext, &jl_default_cgparams);
return _jl_create_llvm_module(name, context, &jl_default_cgparams);
}

static void jl_init_function(Function *F)
Expand Down Expand Up @@ -4599,10 +4598,10 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
if (GlobalValue *V = jl_Module->getNamedValue(fname)) {
F = cast<Function>(V);
} else {
F = Function::Create(get_func_sig(jl_LLVMContext),
F = Function::Create(get_func_sig(ctx.builder.getContext()),
Function::ExternalLinkage,
fname, jl_Module);
F->setAttributes(get_func_attrs(jl_LLVMContext));
F->setAttributes(get_func_attrs(ctx.builder.getContext()));
}
Function *specF = NULL;
if (!isspecsig) {
Expand Down Expand Up @@ -7766,7 +7765,8 @@ jl_compile_result_t jl_emit_code(
jl_method_instance_t *li,
jl_code_info_t *src,
jl_value_t *jlrettype,
jl_codegen_params_t &params)
jl_codegen_params_t &params,
LLVMContext &context)
{
JL_TIMING(CODEGEN);
// caller must hold codegen_lock
Expand All @@ -7776,7 +7776,7 @@ jl_compile_result_t jl_emit_code(
compare_cgparams(params.params, &jl_default_cgparams)) &&
"functions compiled with custom codegen params must not be cached");
JL_TRY {
std::tie(m, decls) = emit_function(li, src, jlrettype, params, jl_LLVMContext);
std::tie(m, decls) = emit_function(li, src, jlrettype, params, context);
if (dump_emitted_mi_name_stream != NULL) {
jl_printf(dump_emitted_mi_name_stream, "%s\t", decls.specFunctionObject.c_str());
// NOTE: We print the Type Tuple without surrounding quotes, because the quotes
Expand Down Expand Up @@ -7807,7 +7807,8 @@ jl_compile_result_t jl_emit_code(
jl_compile_result_t jl_emit_codeinst(
jl_code_instance_t *codeinst,
jl_code_info_t *src,
jl_codegen_params_t &params)
jl_codegen_params_t &params,
LLVMContext &context)
{
JL_TIMING(CODEGEN);
JL_GC_PUSH1(&src);
Expand All @@ -7821,7 +7822,7 @@ jl_compile_result_t jl_emit_codeinst(
return jl_compile_result_t(); // failed
}
}
jl_compile_result_t result = jl_emit_code(codeinst->def, src, codeinst->rettype, params);
jl_compile_result_t result = jl_emit_code(codeinst->def, src, codeinst->rettype, params, context);

const jl_llvm_functions_t &decls = std::get<1>(result);
const std::string &specf = decls.specFunctionObject;
Expand Down Expand Up @@ -7883,7 +7884,7 @@ jl_compile_result_t jl_emit_codeinst(

void jl_compile_workqueue(
std::map<jl_code_instance_t*, jl_compile_result_t> &emitted,
jl_codegen_params_t &params, CompilationPolicy policy)
jl_codegen_params_t &params, CompilationPolicy policy, LLVMContext &context)
{
JL_TIMING(CODEGEN);
jl_code_info_t *src = NULL;
Expand Down Expand Up @@ -7925,10 +7926,10 @@ void jl_compile_workqueue(
codeinst->inferred && codeinst->inferred == jl_nothing) {
src = jl_type_infer(codeinst->def, jl_atomic_load_acquire(&jl_world_counter), 0);
if (src)
result = jl_emit_code(codeinst->def, src, src->rettype, params);
result = jl_emit_code(codeinst->def, src, src->rettype, params, context);
}
else {
result = jl_emit_codeinst(codeinst, NULL, params);
result = jl_emit_codeinst(codeinst, NULL, params, context);
}
if (std::get<0>(result))
decls = &std::get<1>(result);
Expand Down Expand Up @@ -8314,7 +8315,7 @@ extern "C" void jl_init_llvm(void)
jl_TargetMachine->setFastISel(true);
#endif

jl_ExecutionEngine = new JuliaOJIT(*jl_TargetMachine, &jl_LLVMContext);
jl_ExecutionEngine = new JuliaOJIT(*jl_TargetMachine, new LLVMContext());

// Mark our address spaces as non-integral
jl_data_layout = jl_ExecutionEngine->getDataLayout();
Expand Down Expand Up @@ -8386,7 +8387,7 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
jl_init_jit();
init_jit_functions();

Module *m = _jl_create_llvm_module("julia", jl_LLVMContext, &jl_default_cgparams);
Module *m = _jl_create_llvm_module("julia", *jl_ExecutionEngine->getContext().getContext(), &jl_default_cgparams);
init_julia_llvm_env(m);

jl_init_intrinsic_functions_codegen();
Expand Down
34 changes: 24 additions & 10 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ void jl_jit_globals(std::map<void *, GlobalVariable*> &globals)
static jl_callptr_t _jl_compile_codeinst(
jl_code_instance_t *codeinst,
jl_code_info_t *src,
size_t world)
size_t world,
LLVMContext &context)
{
// caller must hold codegen_lock
// and have disabled finalizers
Expand All @@ -116,10 +117,10 @@ static jl_callptr_t _jl_compile_codeinst(
params.world = world;
std::map<jl_code_instance_t*, jl_compile_result_t> emitted;
{
jl_compile_result_t result = jl_emit_codeinst(codeinst, src, params);
jl_compile_result_t result = jl_emit_codeinst(codeinst, src, params, context);
if (std::get<0>(result))
emitted[codeinst] = std::move(result);
jl_compile_workqueue(emitted, params, CompilationPolicy::Default);
jl_compile_workqueue(emitted, params, CompilationPolicy::Default, context);

if (params._shared_module)
jl_add_to_ee(std::unique_ptr<Module>(params._shared_module));
Expand Down Expand Up @@ -204,7 +205,7 @@ const char *jl_generate_ccallable(void *llvmmod, void *sysimg_handle, jl_value_t

// compile a C-callable alias
extern "C" JL_DLLEXPORT
int jl_compile_extern_c_impl(void *llvmmod, void *p, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
int jl_compile_extern_c_impl(LLVMModuleRef llvmmod, void *p, void *sysimg, jl_value_t *declrt, jl_value_t *sigt)
{
JL_LOCK(&jl_codegen_lock);
uint64_t compiler_start_time = 0;
Expand All @@ -215,9 +216,9 @@ int jl_compile_extern_c_impl(void *llvmmod, void *p, void *sysimg, jl_value_t *d
jl_codegen_params_t *pparams = (jl_codegen_params_t*)p;
if (pparams == NULL)
pparams = &params;
Module *into = (Module*)llvmmod;
Module *into = unwrap(llvmmod);
if (into == NULL)
into = jl_create_llvm_module("cextern");
into = jl_create_llvm_module("cextern", *jl_ExecutionEngine->getContext().getContext());
const char *name = jl_generate_ccallable(into, sysimg, declrt, sigt, *pparams, into->getContext());
bool success = true;
if (!sysimg) {
Expand Down Expand Up @@ -289,6 +290,7 @@ extern "C" JL_DLLEXPORT
jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world)
{
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
auto &context = *jl_ExecutionEngine->getContext().getContext();
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand Down Expand Up @@ -324,7 +326,7 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
if (src->inferred && !codeinst->inferred)
codeinst->inferred = jl_nothing;
}
_jl_compile_codeinst(codeinst, src, world);
_jl_compile_codeinst(codeinst, src, world, context);
if (codeinst->invoke == NULL)
codeinst = NULL;
}
Expand All @@ -345,6 +347,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
return;
}
JL_LOCK(&jl_codegen_lock);
auto &context = *jl_ExecutionEngine->getContext().getContext();
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand All @@ -368,7 +371,7 @@ void jl_generate_fptr_for_unspecialized_impl(jl_code_instance_t *unspec)
src = (jl_code_info_t*)unspec->def->uninferred;
}
assert(src && jl_is_code_info(src));
_jl_compile_codeinst(unspec, src, unspec->min_world);
_jl_compile_codeinst(unspec, src, unspec->min_world, context);
if (unspec->invoke == NULL) {
// if we hit a codegen bug (or ran into a broken generated function or llvmcall), fall back to the interpreter as a last resort
jl_atomic_store_release(&unspec->invoke, jl_fptr_interpret_call_addr);
Expand Down Expand Up @@ -398,6 +401,7 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
// (using sentinel value `1` instead)
// so create an exception here so we can print pretty our lies
JL_LOCK(&jl_codegen_lock); // also disables finalizers, to prevent any unexpected recursion
auto &context = *jl_ExecutionEngine->getContext().getContext();
uint64_t compiler_start_time = 0;
uint8_t measure_compile_time_enabled = jl_atomic_load_relaxed(&jl_measure_compile_time_enabled);
if (measure_compile_time_enabled)
Expand All @@ -419,7 +423,7 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (src && jl_is_code_info(src)) {
if (fptr == (uintptr_t)jl_fptr_const_return_addr && specfptr == 0) {
fptr = (uintptr_t)_jl_compile_codeinst(codeinst, src, world);
fptr = (uintptr_t)_jl_compile_codeinst(codeinst, src, world, context);
specfptr = (uintptr_t)jl_atomic_load_relaxed(&codeinst->specptr.fptr);
}
}
Expand All @@ -434,7 +438,8 @@ jl_value_t *jl_dump_method_asm_impl(jl_method_instance_t *mi, size_t world,
}

// whatever, that didn't work - use the assembler output instead
void *F = jl_get_llvmf_defn(mi, world, getwrapper, true, jl_default_cgparams);
// just make a new context for this one operation
void *F = jl_get_llvmf_defn(mi, wrap(jl_ExecutionEngine->getContext().getContext()), world, getwrapper, true, jl_default_cgparams);
if (!F)
return jl_an_empty_string;
return jl_dump_function_asm(F, raw_mc, asm_variant, debuginfo, binary);
Expand Down Expand Up @@ -1044,6 +1049,10 @@ void JuliaOJIT::RegisterJITEventListener(JITEventListener *L)
}
#endif

orc::ThreadSafeContext &JuliaOJIT::getContext() {
return TSCtx;
}

const DataLayout& JuliaOJIT::getDataLayout() const
{
return DL;
Expand Down Expand Up @@ -1324,3 +1333,8 @@ size_t jl_jit_total_bytes_impl(void)
{
return jl_ExecutionEngine->getTotalBytes();
}

extern "C" JL_DLLEXPORT
LLVMContextRef jl_get_ee_context_impl(void) {
return wrap(jl_ExecutionEngine->getContext().getContext());
}
Loading

0 comments on commit 12286e0

Please sign in to comment.