Skip to content

Commit

Permalink
Improve codegen, accuracy of inlining cost for unknown intrinsics (Ju…
Browse files Browse the repository at this point in the history
  • Loading branch information
ianatol committed Oct 28, 2023
1 parent f573c4a commit 4975c02
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 39 deletions.
11 changes: 6 additions & 5 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,7 @@ isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetyp

function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
params::OptimizationParams, error_path::Bool = false)
#=const=# UNKNOWN_CALL_COST = 20
head = ex.head
if is_meta_expr_head(head)
return 0
Expand Down Expand Up @@ -1067,8 +1068,8 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
return cost
end
end
# unknown/unhandled intrinsic
return params.inline_nonleaf_penalty
# unknown/unhandled intrinsic: hopefully the caller gets a slightly better answer after the inlining
return UNKNOWN_CALL_COST
end
if isa(f, Builtin) && f !== invoke
# The efficiency of operations like a[i] and s.b
Expand All @@ -1092,7 +1093,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
if fidx === nothing
# unknown/unhandled builtin
# Use the generic cost of a direct function call
return 20
return UNKNOWN_CALL_COST
end
return T_FFUNC_COST[fidx]
end
Expand All @@ -1114,10 +1115,10 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
# consideration. This way, non-inlined error branches do not
# prevent inlining.
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
return extyp === Union{} ? 0 : 20
return extyp === Union{} ? 0 : UNKNOWN_CALL_COST
elseif head === :(=)
if ex.args[1] isa GlobalRef
cost = 20
cost = UNKNOWN_CALL_COST
else
cost = 0
end
Expand Down
12 changes: 11 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,
try
appl = apply_type(headtype, tparams...)
catch ex
ex isa InterruptException && rethrow()
# type instantiation might fail if one of the type parameters doesn't
# match, which could happen only if a type estimate is too coarse
# and might guess a concrete value while the actual type for it is Bottom
Expand Down Expand Up @@ -2528,8 +2529,17 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes)
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
try
# unroll a few cases which have specialized codegen
if length(argvals) == 1
return Const(f(argvals[1]))
elseif length(argvals) == 2
return Const(f(argvals[1], argvals[2]))
elseif length(argvals) == 3
return Const(f(argvals[1], argvals[2], argvals[3]))
end
return Const(f(argvals...))
catch
catch ex # expected ErrorException, TypeError, ConcurrencyViolationError, DivideError etc.
ex isa InterruptException && rethrow()
return Bottom
end
end
Expand Down
8 changes: 4 additions & 4 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,6 @@ static unsigned intrinsic_nargs[num_intrinsics];

JL_CALLABLE(jl_f_intrinsic_call)
{
JL_TYPECHK(intrinsic_call, intrinsic, F);
enum intrinsic f = (enum intrinsic)*(uint32_t*)jl_data_ptr(F);
if (f == cglobal && nargs == 1)
f = cglobal_auto;
Expand Down Expand Up @@ -2022,6 +2021,7 @@ unsigned jl_intrinsic_nargs(int f)

// Register one intrinsic: store its argument count in `intrinsic_nargs[f]`
// and its runtime implementation pointer in `runtime_fp[f]`.
//   f     - intrinsic id, used as the index into both tables
//   nargs - number of arguments the intrinsic takes
//   pfunc - type-erased pointer to the runtime implementation
static void add_intrinsic_properties(enum intrinsic f, unsigned nargs, void (*pfunc)(void))
{
// Reject registrations that the generic dispatcher could not serve: per the
// assertion message, jl_f_intrinsic_call only implements up to 5 arguments.
assert(nargs <= 5 && "jl_f_intrinsic_call only implements up to 5 args");
intrinsic_nargs[f] = nargs;
runtime_fp[f] = pfunc;
}
Expand Down Expand Up @@ -2074,10 +2074,10 @@ static void add_builtin(const char *name, jl_value_t *v)
jl_set_const(jl_core_module, jl_symbol(name), v);
}

jl_fptr_args_t jl_get_builtin_fptr(jl_value_t *b)
jl_fptr_args_t jl_get_builtin_fptr(jl_datatype_t *dt)
{
assert(jl_isa(b, (jl_value_t*)jl_builtin_type));
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_atomic_load_relaxed(&jl_gf_mtable(b)->defs);
assert(jl_subtype((jl_value_t*)dt, (jl_value_t*)jl_builtin_type));
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_atomic_load_relaxed(&dt->name->mt->defs);
jl_method_instance_t *mi = jl_atomic_load_relaxed(&entry->func.method->unspecialized);
jl_code_instance_t *ci = jl_atomic_load_relaxed(&mi->cache);
return jl_atomic_load_relaxed(&ci->specptr.fptr1);
Expand Down
72 changes: 57 additions & 15 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ static Type *get_pjlvalue(LLVMContext &C) { return JuliaType::get_pjlvalue_ty(C)

// Thin wrappers that fetch the LLVM function types for the Julia calling
// conventions from the shared JuliaType helpers. Each one simply forwards the
// LLVMContext to the matching JuliaType::get_jlfunc*_ty builder.
static FunctionType *get_func_sig(LLVMContext &C) { return JuliaType::get_jlfunc_ty(C); }
static FunctionType *get_func2_sig(LLVMContext &C) { return JuliaType::get_jlfunc2_ty(C); }
// Signature used by jlintrinsic_func and by the first fixed parameter of julia.call3.
static FunctionType *get_func3_sig(LLVMContext &C) { return JuliaType::get_jlfunc3_ty(C); }

static FunctionType *get_donotdelete_sig(LLVMContext &C) {
return FunctionType::get(getVoidTy(C), true);
Expand Down Expand Up @@ -1225,6 +1226,12 @@ static const auto box_ssavalue_func = new JuliaFunction<TypeFnContextAndSizeT>{
},
get_attrs_basic,
};
// Declaration of the runtime helper jl_get_builtin_fptr for use from generated
// code: it takes a single prjlvalue argument and returns a pointer to a
// function of the get_func_sig type. The trailing `false` makes it non-variadic.
static const auto jlgetbuiltinfptr_func = new JuliaFunction<>{
XSTR(jl_get_builtin_fptr),
[](LLVMContext &C) { return FunctionType::get(get_func_sig(C)->getPointerTo(),
{JuliaType::get_prjlvalue_ty(C)}, false); },
// NOTE(review): nullptr sits where sibling declarations pass an attribute
// builder such as get_attrs_basic, so presumably no attributes are attached;
// confirm against the JuliaFunction definition.
nullptr,
};


// placeholder functions
Expand Down Expand Up @@ -1277,7 +1284,6 @@ static const auto gc_loaded_func = new JuliaFunction<>{
Attributes(C, {Attribute::NonNull, Attribute::NoUndef, Attribute::ReadNone}) }); },
};


// julia.call represents a call with julia calling convention, it is used as
//
// ptr julia.call(ptr %fptr, ptr %f, ptr %arg1, ptr %arg2, ...)
Expand Down Expand Up @@ -1314,7 +1320,23 @@ static const auto julia_call2 = new JuliaFunction<>{
get_attrs_basic,
};

// julia.call3 is like julia.call, except that %fptr is derived rather than tracked
// Placeholder callee "julia.call3". Its LLVM type is variadic and returns a
// prjlvalue; the two fixed parameters are the callee pointer (a pointer to a
// get_func3_sig function) and %f, followed by the variadic call arguments.
static const auto julia_call3 = new JuliaFunction<>{
"julia.call3",
[](LLVMContext &C) {
auto T_prjlvalue = JuliaType::get_prjlvalue_ty(C);
// %f is passed as a jlvalue pointer in the Derived address space.
Type *T = PointerType::get(JuliaType::get_jlvalue_ty(C), AddressSpace::Derived);
return FunctionType::get(T_prjlvalue,
{get_func3_sig(C)->getPointerTo(),
T}, // %f
// isVarArg = true: the remaining operands are the call's arguments
true); }, // %args
get_attrs_basic,
};


// Declarations of runtime entry points that are called directly by name.
// jl_f_tuple uses the standard get_func_sig signature.
static const auto jltuple_func = new JuliaFunction<>{XSTR(jl_f_tuple), get_func_sig, get_func_attrs};
// jl_f_intrinsic_call is declared with the get_func3_sig signature instead.
static const auto jlintrinsic_func = new JuliaFunction<>{XSTR(jl_f_intrinsic_call), get_func3_sig, get_func_attrs};

static const auto &builtin_func_map() {
static std::map<jl_fptr_args_t, JuliaFunction<>*> builtins = {
{ jl_f_is_addr, new JuliaFunction<>{XSTR(jl_f_is), get_func_sig, get_func_attrs} },
Expand Down Expand Up @@ -4659,22 +4681,40 @@ static jl_cgval_t emit_call(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_t *rt, bo
return jl_cgval_t(); // anything past here is unreachable
}

if (f.constant && jl_isa(f.constant, (jl_value_t*)jl_builtin_type)) {
if (f.constant == jl_builtin_ifelse && nargs == 4)
return emit_ifelse(ctx, argv[1], argv[2], argv[3], rt);
jl_cgval_t result;
bool handled = emit_builtin_call(ctx, &result, f.constant, argv, nargs - 1, rt, ex, is_promotable);
if (handled) {
return result;
if (jl_subtype(f.typ, (jl_value_t*)jl_builtin_type)) {
if (f.constant) {
if (f.constant == jl_builtin_ifelse && nargs == 4)
return emit_ifelse(ctx, argv[1], argv[2], argv[3], rt);
jl_cgval_t result;
bool handled = emit_builtin_call(ctx, &result, f.constant, argv, nargs - 1, rt, ex, is_promotable);
if (handled)
return result;

// special case for some known builtin not handled by emit_builtin_call
auto it = builtin_func_map().find(jl_get_builtin_fptr((jl_datatype_t*)jl_typeof(f.constant)));
if (it != builtin_func_map().end()) {
Value *ret = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, julia_call);
setName(ctx.emission_context, ret, it->second->name + "_ret");
return mark_julia_type(ctx, ret, true, rt);
}
}
FunctionCallee fptr;
Value *F;
JuliaFunction<> *cc;
if (f.typ == (jl_value_t*)jl_intrinsic_type) {
fptr = prepare_call(jlintrinsic_func);
F = f.ispointer() ? data_pointer(ctx, f) : value_to_pointer(ctx, f).V;
F = decay_derived(ctx, maybe_bitcast(ctx, F, ctx.types().T_pjlvalue));
cc = julia_call3;
}

// special case for known builtin not handled by emit_builtin_call
auto it = builtin_func_map().find(jl_get_builtin_fptr(f.constant));
if (it != builtin_func_map().end()) {
Value *ret = emit_jlcall(ctx, it->second, Constant::getNullValue(ctx.types().T_prjlvalue), &argv[1], nargs - 1, julia_call);
setName(ctx.emission_context, ret, it->second->name + "_ret");
return mark_julia_type(ctx, ret, true, rt);
else {
fptr = FunctionCallee(get_func_sig(ctx.builder.getContext()), ctx.builder.CreateCall(prepare_call(jlgetbuiltinfptr_func), {emit_typeof(ctx, f)}));
F = boxed(ctx, f);
cc = julia_call;
}
Value *ret = emit_jlcall(ctx, fptr, F, &argv[1], nargs - 1, cc);
setName(ctx.emission_context, ret, "Builtin_ret");
return mark_julia_type(ctx, ret, true, rt);
}

// handle calling an OpaqueClosure
Expand Down Expand Up @@ -9195,6 +9235,8 @@ static void init_jit_functions(void)
add_named_global(jlboundp_func, &jl_boundp);
for (auto it : builtin_func_map())
add_named_global(it.second, it.first);
add_named_global(jlintrinsic_func, &jl_f_intrinsic_call);
add_named_global(jlgetbuiltinfptr_func, &jl_get_builtin_fptr);
add_named_global(jlapplygeneric_func, &jl_apply_generic);
add_named_global(jlinvoke_func, &jl_invoke);
add_named_global(jltopeval_func, &jl_toplevel_eval);
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ void jl_install_default_signal_handlers(void);
void restore_signals(void);
void jl_install_thread_signal_handler(jl_ptls_t ptls);

JL_DLLEXPORT jl_fptr_args_t jl_get_builtin_fptr(jl_value_t *b);
JL_DLLEXPORT jl_fptr_args_t jl_get_builtin_fptr(jl_datatype_t *dt);

extern uv_loop_t *jl_io_loop;
JL_DLLEXPORT void jl_uv_flush(uv_stream_t *stream);
Expand Down
16 changes: 13 additions & 3 deletions src/llvm-codegen-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,19 @@ namespace JuliaType {
return llvm::FunctionType::get(T_prjlvalue, {
T_prjlvalue, // function
T_pprjlvalue, // args[]
llvm::Type::getInt32Ty(C),
T_prjlvalue, // linfo
}, // nargs
llvm::Type::getInt32Ty(C), // nargs
T_prjlvalue}, // linfo
false);
}

// Build the LLVM function type for the "func3" calling convention:
// a non-variadic function returning a prjlvalue with three parameters:
// the function object, the argument array, and the argument count.
static inline auto get_jlfunc3_ty(llvm::LLVMContext &C) {
auto T_prjlvalue = get_prjlvalue_ty(C);
// Pointer (address space 0) to an array of prjlvalue arguments.
auto T_pprjlvalue = llvm::PointerType::get(T_prjlvalue, 0);
// The function object is a jlvalue pointer in the Derived address space.
auto T = get_pjlvalue_ty(C, Derived);
return llvm::FunctionType::get(T_prjlvalue, {
T, // function
T_pprjlvalue, // args[]
llvm::Type::getInt32Ty(C)}, // nargs
false);
}

Expand Down
12 changes: 8 additions & 4 deletions src/llvm-gc-invariant-verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,17 @@ void GCInvariantVerifier::visitGetElementPtrInst(GetElementPtrInst &GEP) {
void GCInvariantVerifier::visitCallInst(CallInst &CI) {
Function *Callee = CI.getCalledFunction();
if (Callee && (Callee->getName() == "julia.call" ||
Callee->getName() == "julia.call2")) {
bool First = true;
Callee->getName() == "julia.call2" ||
Callee->getName() == "julia.call3")) {
unsigned Fixed = CI.getFunctionType()->getNumParams();
for (Value *Arg : CI.args()) {
if (Fixed) {
Fixed--;
continue;
}
Type *Ty = Arg->getType();
Check(Ty->isPointerTy() && cast<PointerType>(Ty)->getAddressSpace() == (First ? 0 : AddressSpace::Tracked),
Check(Ty->isPointerTy() && cast<PointerType>(Ty)->getAddressSpace() == AddressSpace::Tracked,
"Invalid derived pointer in jlcall", &CI);
First = false;
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2459,14 +2459,15 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
++it;
continue;
} else if ((call_func && callee == call_func) ||
(call2_func && callee == call2_func)) {
(call2_func && callee == call2_func) ||
(call3_func && callee == call3_func)) {
assert(T_prjlvalue);
size_t nargs = CI->arg_size();
size_t nframeargs = nargs-1;
if (callee == call_func)
nframeargs -= 1;
else if (callee == call2_func)
if (callee == call2_func)
nframeargs -= 2;
else
nframeargs -= 1;
SmallVector<Value*, 4> ReplacementArgs;
auto arg_it = CI->arg_begin();
assert(arg_it != CI->arg_end());
Expand Down Expand Up @@ -2499,7 +2500,9 @@ bool LateLowerGCFrame::CleanupIR(Function &F, State *S, bool *CFGModified) {
ReplacementArgs.erase(ReplacementArgs.begin());
ReplacementArgs.push_back(front);
}
FunctionType *FTy = callee == call2_func ? JuliaType::get_jlfunc2_ty(CI->getContext()) : JuliaType::get_jlfunc_ty(CI->getContext());
FunctionType *FTy = callee == call3_func ? JuliaType::get_jlfunc3_ty(CI->getContext()) :
callee == call2_func ? JuliaType::get_jlfunc2_ty(CI->getContext()) :
JuliaType::get_jlfunc_ty(CI->getContext());
CallInst *NewCall = CallInst::Create(FTy, new_callee, ReplacementArgs, "", CI);
NewCall->setTailCallKind(CI->getTailCallKind());
auto callattrs = CI->getAttributes();
Expand Down
3 changes: 2 additions & 1 deletion src/llvm-pass-helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ JuliaPassContext::JuliaPassContext()
gc_preserve_begin_func(nullptr), gc_preserve_end_func(nullptr),
pointer_from_objref_func(nullptr), gc_loaded_func(nullptr), alloc_obj_func(nullptr),
typeof_func(nullptr), write_barrier_func(nullptr),
call_func(nullptr), call2_func(nullptr), module(nullptr)
call_func(nullptr), call2_func(nullptr), call3_func(nullptr), module(nullptr)
{
}

Expand All @@ -54,6 +54,7 @@ void JuliaPassContext::initFunctions(Module &M)
alloc_obj_func = M.getFunction("julia.gc_alloc_obj");
call_func = M.getFunction("julia.call");
call2_func = M.getFunction("julia.call2");
call3_func = M.getFunction("julia.call3");
}

void JuliaPassContext::initAll(Module &M)
Expand Down
1 change: 1 addition & 0 deletions src/llvm-pass-helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct JuliaPassContext {
llvm::Function *write_barrier_func;
llvm::Function *call_func;
llvm::Function *call2_func;
llvm::Function *call3_func;

// Creates a pass context. Type and function pointers
// are set to `nullptr`. Metadata nodes are initialized.
Expand Down

0 comments on commit 4975c02

Please sign in to comment.