Skip to content

Commit

Permalink
use specialized code when compiling opaque closure expressions (Julia…
Browse files Browse the repository at this point in the history
…Lang#44176)

invoke specialization when an OC is created at run time
  • Loading branch information
JeffBezanson committed Feb 15, 2022
1 parent 8d8d58f commit 88edb11
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 107 deletions.
186 changes: 95 additions & 91 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4554,6 +4554,68 @@ static void emit_stmtpos(jl_codectx_t &ctx, jl_value_t *expr, int ssaval_result)
}
}

// Emit the inferred specialization of an opaque closure method into its own
// LLVM module and return declarations for it in the current module.
//
// The specialization is looked up under the signature
// `Tuple{env_t, argt_typ.parameters...}`, i.e. the captured-environment type
// followed by the declared argument types.
//
// Parameters:
//   ctx            - codegen context of the function creating the closure.
//   closure_method - method holding the opaque closure's code.
//   env_t          - tuple type of the captured environment.
//   argt_typ       - tuple type of the closure's arguments.
//   rettype        - return type handed to `emit_function` and
//                    `get_specsig_function`.
//   vaOverride     - forwarded unchanged to `emit_function`.
//
// Returns a pair `(F, specF)`:
//   * `(NULL, NULL)` when no code instance with inferred source is available
//     for the specialization at `ctx.world`; nothing is emitted in that case.
//   * otherwise `F` is a function in `jl_Module` (reused if a global of that
//     name already exists, declared otherwise), and `specF` is either `F`
//     itself (when `emit_function` reported "jl_fptr_args"), the declaration
//     produced by `get_specsig_function`, or NULL if the emitted module has no
//     function named `specFunctionObject`.
//
// On success the emitted module is appended to `ctx.oc_modules`.
static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_method_t *closure_method, jl_tupletype_t *env_t, jl_tupletype_t *argt_typ, jl_value_t *rettype, bool vaOverride)
{
    jl_svec_t *sig_args = NULL;
    jl_value_t *sigtype = NULL;
    jl_code_info_t *ir = NULL;
    JL_GC_PUSH3(&sig_args, &sigtype, &ir);

    // Build Tuple{env_t, argt...}: slot 0 is the environment, the rest are the
    // declared argument types.
    size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
    sig_args = jl_alloc_svec_uninit(nsig);
    jl_svecset(sig_args, 0, env_t);
    for (size_t i = 0; i < jl_svec_len(argt_typ->parameters); ++i) {
        jl_svecset(sig_args, 1+i, jl_svecref(argt_typ->parameters, i));
    }
    sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);

    jl_method_instance_t *mi = jl_specializations_get_linfo(closure_method, sigtype, jl_emptysvec);
    jl_code_instance_t *ci = (jl_code_instance_t*)jl_rettype_inferred(mi, ctx.world, ctx.world);

    // Without inferred source there is nothing to compile here; the caller
    // checks for a NULL `F` and takes its own fallback path.
    if (ci == NULL || (jl_value_t*)ci == jl_nothing || ci->inferred == NULL || ci->inferred == jl_nothing) {
        JL_GC_POP();
        return std::make_pair((Function*)NULL, (Function*)NULL);
    }

    ir = jl_uncompress_ir(closure_method, ci, (jl_array_t*)ci->inferred);

    // TODO: Emit this inline and outline it late using LLVM's coroutine support.
    std::unique_ptr<Module> closure_m;
    jl_llvm_functions_t closure_decls;
    std::tie(closure_m, closure_decls) = emit_function(mi, ir, rettype, ctx.emission_context,
        ctx.builder.getContext(), vaOverride);

    assert(closure_decls.functionObject != "jl_fptr_sparam");
    bool isspecsig = closure_decls.functionObject != "jl_fptr_args";

    Function *F = NULL;
    std::string fname = isspecsig ?
        closure_decls.functionObject :
        closure_decls.specFunctionObject;
    if (GlobalValue *V = jl_Module->getNamedValue(fname)) {
        F = cast<Function>(V);
    }
    else {
        // Use the builder's LLVMContext (the same one handed to emit_function
        // above) instead of the global jl_LLVMContext, so the declaration's
        // type and attributes live in the context this code is emitted in.
        F = Function::Create(get_func_sig(ctx.builder.getContext()),
                             Function::ExternalLinkage,
                             fname, jl_Module);
        F->setAttributes(get_func_attrs(ctx.builder.getContext()));
    }

    Function *specF = NULL;
    if (!isspecsig) {
        specF = F;
    }
    else {
        // Only declare the specsig entry point if the closure's module
        // actually contains it; otherwise specF stays NULL.
        specF = closure_m->getFunction(closure_decls.specFunctionObject);
        if (specF) {
            jl_returninfo_t returninfo = get_specsig_function(ctx, jl_Module,
                closure_decls.specFunctionObject, sigtype, rettype, true);
            specF = returninfo.decl;
        }
    }
    ctx.oc_modules.push_back(std::move(closure_m));
    JL_GC_POP();
    return std::make_pair(F, specF);
}

// `expr` is not clobbered in JL_TRY
JL_GCC_IGNORE_START("-Wclobbered")
static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
Expand Down Expand Up @@ -4832,112 +4894,54 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
}

if (can_optimize) {
// TODO: Emit this inline and outline it late using LLVM's coroutine
// support.
jl_method_t *closure_method = (jl_method_t *)source.constant;
jl_code_info_t *closure_src = jl_uncompress_ir(closure_method, NULL,
(jl_array_t*)closure_method->source);

std::unique_ptr<Module> closure_m;
jl_llvm_functions_t closure_decls;

jl_method_instance_t *li = NULL;
jl_value_t *closure_t = NULL;
jl_tupletype_t *env_t = NULL;
jl_svec_t *sig_args = NULL;
JL_GC_PUSH5(&li, &closure_src, &closure_t, &env_t, &sig_args);

li = jl_new_method_instance_uninit();
li->def.method = closure_method;
jl_tupletype_t *argt_typ = (jl_tupletype_t *)argt.constant;

closure_t = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt_typ, ub.constant);

size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, closure_t);
for (size_t i = 0; i < jl_svec_len(argt_typ->parameters); ++i) {
jl_svecset(sig_args, 1+i, jl_svecref(argt_typ->parameters, i));
}
li->specTypes = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
jl_gc_wb(li, li->specTypes);

std::tie(closure_m, closure_decls) = emit_function(li, closure_src,
ub.constant, ctx.emission_context, ctx.builder.getContext(), jl_unbox_bool(isva.constant));
JL_GC_PUSH2(&closure_t, &env_t);

jl_value_t **env_component_ts = (jl_value_t**)alloca(sizeof(jl_value_t*) * (nargs-5));
for (size_t i = 0; i < nargs - 5; ++i) {
env_component_ts[i] = argv[5+i].typ;
}

env_t = jl_apply_tuple_type_v(env_component_ts, nargs-5);
jl_cgval_t env(ctx.builder.getContext());
// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
// we need to know the full env type to look up the right specialization
if (jl_is_concrete_type((jl_value_t*)env_t)) {
env = emit_new_struct(ctx, (jl_value_t*)env_t, nargs-5, &argv.data()[5]);
}
else {
Value *env_val = emit_jlcall(ctx, jltuple_func, Constant::getNullValue(ctx.types().T_prjlvalue),
&argv[5], nargs-5, JLCALL_F_CC);
env = mark_julia_type(ctx, env_val, true, env_t);
}

assert(closure_decls.functionObject != "jl_fptr_sparam");
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";

Function *F = NULL;
std::string fname = isspecsig ?
closure_decls.functionObject :
closure_decls.specFunctionObject;
if (GlobalValue *V = jl_Module->getNamedValue(fname)) {
F = cast<Function>(V);
}
else {
F = Function::Create(get_func_sig(ctx.builder.getContext()),
Function::ExternalLinkage,
fname, jl_Module);
F->setAttributes(get_func_attrs(ctx.builder.getContext()));
}
jl_cgval_t jlcall_ptr = mark_julia_type(ctx,
F, false, jl_voidpointer_type);

jl_cgval_t fptr(ctx.builder.getContext());
if (!isspecsig) {
fptr = jlcall_ptr;
} else {
Function *specptr = closure_m->getFunction(closure_decls.specFunctionObject);
if (specptr) {
jl_returninfo_t returninfo = get_specsig_function(ctx, jl_Module,
closure_decls.specFunctionObject, li->specTypes, ub.constant, true);
fptr = mark_julia_type(ctx, returninfo.decl, false, jl_voidpointer_type);
} else {
fptr = mark_julia_type(ctx,
(llvm::Value*)Constant::getNullValue(getSizeTy(ctx.builder.getContext())),
false, jl_voidpointer_type);
}
}
jl_tupletype_t *argt_typ = (jl_tupletype_t*)argt.constant;
Function *F, *specF;
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, env_t, argt_typ, ub.constant, jl_unbox_bool(isva.constant));
if (F) {
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
jl_cgval_t world_age = mark_julia_type(ctx,
tbaa_decorate(ctx.tbaa().tbaa_gcframe,
ctx.builder.CreateAlignedLoad(ctx.world_age_field, Align(sizeof(size_t)))),
false,
jl_long_type);
jl_cgval_t fptr(ctx.builder.getContext());
if (specF)
fptr = mark_julia_type(ctx, specF, false, jl_voidpointer_type);
else
fptr = mark_julia_type(ctx, (llvm::Value*)Constant::getNullValue(getSizeTy(ctx.builder.getContext())), false, jl_voidpointer_type);

jl_cgval_t world_age = mark_julia_type(ctx,
tbaa_decorate(ctx.tbaa().tbaa_gcframe,
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), ctx.world_age_field, Align(sizeof(size_t)))),
false,
jl_long_type);

jl_cgval_t closure_fields[6] = {
env,
isva,
world_age,
source,
jlcall_ptr,
fptr
};
// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
jl_cgval_t env = emit_new_struct(ctx, (jl_value_t*)env_t, nargs-5, &argv.data()[5]);

jl_cgval_t ret = emit_new_struct(ctx, closure_t, 6, closure_fields);
jl_cgval_t closure_fields[6] = {
env,
isva,
world_age,
source,
jlcall_ptr,
fptr
};

ctx.oc_modules.push_back(std::move(closure_m));
closure_t = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt_typ, ub.constant);
jl_cgval_t ret = emit_new_struct(ctx, closure_t, 6, closure_fields);

JL_GC_POP();
return ret;
}
}
JL_GC_POP();
return ret;
}

return mark_julia_type(ctx,
Expand Down
8 changes: 7 additions & 1 deletion src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,9 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
jl_code_info_t *code = jl_uncompress_ir(source, NULL, (jl_array_t*)source->source);
interpreter_state *s;
unsigned nroots = jl_source_nslots(code) + jl_source_nssavalues(code) + 2;
jl_task_t *ct = jl_current_task;
size_t last_age = ct->world_age;
ct->world_age = oc->world;
jl_value_t **locals = NULL;
JL_GC_PUSHFRAME(s, locals, nroots);
locals[0] = (jl_value_t*)oc;
Expand All @@ -710,7 +713,6 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
s->preevaluation = 0;
s->continue_at = 0;
s->mi = NULL;

size_t defargs = source->nargs;
int isva = !!oc->isva;
assert(isva ? nargs + 2 >= defargs : nargs + 1 == defargs);
Expand All @@ -722,6 +724,10 @@ jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **ar
}
JL_GC_ENABLEFRAME(s);
jl_value_t *r = eval_body(code->code, s, 0, 0);
locals[0] = r; // GC root
JL_GC_PROMISE_ROOTED(r);
jl_typeassert(r, jl_tparam1(jl_typeof(oc)));
ct->world_age = last_age;
JL_GC_POP();
return r;
}
Expand Down
67 changes: 52 additions & 15 deletions src/opaque_closure.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@
#include "julia.h"
#include "julia_internal.h"

JL_DLLEXPORT jl_value_t *jl_invoke_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
jl_value_t *jl_fptr_const_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
{
jl_value_t *ret = NULL;
JL_GC_PUSH1(&ret);
jl_task_t *ct = jl_current_task;
size_t last_age = ct->world_age;
ct->world_age = oc->world;
ret = jl_interpret_opaque_closure(oc, args, nargs);
jl_typeassert(ret, jl_tparam1(jl_typeof(oc)));
ct->world_age = last_age;
return oc->captures;
}

// TODO: remove
// Vararg adapter for opaque closures: repacks the incoming arguments so that
// the trailing ones are collected into a single tuple, then calls
// `oc->specptr` (as a `jl_fptr_args_t`) with `oc` as the function object.
// `jl_new_opaque_closure` installs this as `oc->invoke` when the closure is
// vararg and a compiled entry point was found, stashing that entry point in
// `oc->specptr`.
//
// `defargs` is `oc->source->nargs`; the callee receives `defargs - 1`
// arguments: the first `defargs - 2` of `args` unchanged, followed by a tuple
// of the remaining `nargs + 2 - defargs` arguments.
jl_value_t *jl_fptr_va_opaque_closure(jl_opaque_closure_t *oc, jl_value_t **args, size_t nargs)
{
    size_t defargs = oc->source->nargs;
    jl_value_t **newargs;
    // Root the repacked argument array while the tuple is allocated and the
    // callee runs.
    JL_GC_PUSHARGS(newargs, defargs - 1);
    for (size_t i = 0; i < defargs - 2; i++)
        newargs[i] = args[i];
    newargs[defargs - 2] = jl_f_tuple(NULL, &args[defargs - 2], nargs + 2 - defargs);
    jl_value_t *ans = ((jl_fptr_args_t)oc->specptr)((jl_value_t*)oc, newargs, defargs - 1);
    JL_GC_POP();
    return ans;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *isva,
Expand All @@ -31,17 +36,49 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *isv
jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
JL_GC_PROMISE_ROOTED(oc_type);
jl_value_t *captures = NULL;
JL_GC_PUSH1(&captures);
jl_value_t *captures = NULL, *sigtype = NULL;
jl_svec_t *sig_args = NULL;
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
captures = jl_f_tuple(NULL, env, nenv);

size_t nsig = 1 + jl_svec_len(argt->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, jl_typeof(captures));
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
}
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)source, sigtype, jl_emptysvec);
size_t world = jl_atomic_load_acquire(&jl_world_counter);
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);

jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
JL_GC_POP();
oc->source = (jl_method_t*)source;
oc->isva = jl_unbox_bool(isva);
oc->invoke = (jl_fptr_args_t)jl_invoke_opaque_closure;
oc->specptr = NULL;
oc->captures = captures;
oc->world = jl_atomic_load_acquire(&jl_world_counter);
oc->specptr = NULL;
int compiled = 0;
if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_interpret_call) {
oc->invoke = (jl_fptr_args_t)jl_interpret_opaque_closure;
}
else if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_args) {
oc->invoke = jl_atomic_load_relaxed(&ci->specptr.fptr1);
compiled = 1;
}
else if (jl_atomic_load_relaxed(&ci->invoke) == jl_fptr_const_return) {
oc->invoke = (jl_fptr_args_t)jl_fptr_const_opaque_closure;
oc->captures = ci->rettype_const;
}
else {
oc->invoke = (jl_fptr_args_t)jl_atomic_load_relaxed(&ci->invoke);
compiled = 1;
}
if (oc->isva && compiled) {
oc->specptr = (jl_fptr_args_t)oc->invoke;
oc->invoke = (jl_fptr_args_t)jl_fptr_va_opaque_closure;
}
oc->world = world;
return oc;
}

Expand Down

0 comments on commit 88edb11

Please sign in to comment.