Skip to content

Commit

Permalink
Change inlineable field to store inlining cost (JuliaLang#45378)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol committed Jul 26, 2022
1 parent 9e22e56 commit 5d2e24f
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 60 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ function const_prop_methodinstance_heuristic(
# was able to cut it down to something simple (inlineable in particular).
# If so, there's a good chance we might be able to const prop all the way
# through and learn something new.
if isdefined(method, :source) && ccall(:jl_ir_flag_inlineable, Bool, (Any,), method.source)
if isdefined(method, :source) && is_inlineable(method.source)
return true
else
flag = get_curr_ssaflag(sv)
Expand Down
36 changes: 25 additions & 11 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ const IR_FLAG_NOTHROW = 0x01 << 5

const TOP_TUPLE = GlobalRef(Core, :tuple)

# This corresponds to the type of `CodeInfo`'s `inlining_cost` field
const InlineCostType = UInt16
const MAX_INLINE_COST = typemax(InlineCostType)
const MIN_INLINE_COST = InlineCostType(10)

is_inlineable(src::Union{CodeInfo, Vector{UInt8}}) = ccall(:jl_ir_inlining_cost, InlineCostType, (Any,), src) != MAX_INLINE_COST
set_inlineable!(src::CodeInfo, val::Bool) = src.inlining_cost = (val ? MIN_INLINE_COST : MAX_INLINE_COST)

function inline_cost_clamp(x::Int)::InlineCostType
x > MAX_INLINE_COST && return MAX_INLINE_COST
x < MIN_INLINE_COST && return MIN_INLINE_COST
return convert(InlineCostType, x)
end

#####################
# OptimizationState #
#####################
Expand Down Expand Up @@ -67,7 +81,7 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f
mi::MethodInstance, argtypes::Vector{Any})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = is_source_inferred(src)
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src)
return src_inferred && src_inlineable ? src : nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
# if this statement is forced to be inlined, make an additional effort to find the
Expand Down Expand Up @@ -461,7 +475,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
@assert isconstType(result)
analyzed = ConstAPI(result.parameters[1])
end
force_noinline || (src.inlineable = true)
force_noinline || set_inlineable!(src, true)
end
end

Expand All @@ -482,14 +496,14 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
else
force_noinline = true
end
if !src.inlineable && result === Bottom
if !is_inlineable(src) && result === Bottom
force_noinline = true
end
end
if force_noinline
src.inlineable = false
set_inlineable!(src, false)
elseif isa(def, Method)
if src.inlineable && isdispatchtuple(specTypes)
if is_inlineable(src) && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
else
# compute the cost (size) of inlining this code
Expand All @@ -498,7 +512,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
cost_threshold += params.inline_tupleret_bonus
end
# if the method is declared as `@inline`, increase the cost threshold 20x
if src.inlineable
if is_inlineable(src)
cost_threshold += 19*default
end
# a few functions get special treatment
Expand All @@ -508,7 +522,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
cost_threshold += 4*default
end
end
src.inlineable = inline_worthy(ir, params, union_penalties, cost_threshold)
src.inlining_cost = inline_cost(ir, params, union_penalties, cost_threshold)
end
end

Expand Down Expand Up @@ -820,16 +834,16 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
return thiscost
end

function inline_worthy(ir::IRCode,
params::OptimizationParams, union_penalties::Bool=false, cost_threshold::Integer=params.inline_cost_threshold)
function inline_cost(ir::IRCode, params::OptimizationParams, union_penalties::Bool=false,
cost_threshold::Integer=params.inline_cost_threshold)::InlineCostType
bodycost::Int = 0
for line = 1:length(ir.stmts)
stmt = ir.stmts[line][:inst]
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
bodycost = plus_saturate(bodycost, thiscost)
bodycost > cost_threshold && return false
bodycost > cost_threshold && return MAX_INLINE_COST
end
return true
return inline_cost_clamp(bodycost)
end

function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any}, unionpenalties::Bool, params::OptimizationParams)
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
return ci
end
if may_discard_trees(interp)
cache_the_tree = ci.inferred && (ci.inlineable || isa_compileable_sig(linfo.specTypes, def))
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, def))
else
cache_the_tree = true
end
Expand Down Expand Up @@ -499,13 +499,13 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# we can throw everything else away now
me.result.src = nothing
me.cached = false
me.src.inlineable = false
set_inlineable!(me.src, false)
unlock_mi_inference(interp, me.linfo)
elseif limited_src
# a type result will be cached still, but not this intermediate work:
# we can throw everything else away now
me.result.src = nothing
me.src.inlineable = false
set_inlineable!(me.src, false)
else
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
Expand Down Expand Up @@ -998,7 +998,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
tree.inferred = true
tree.ssaflags = UInt8[0]
tree.pure = true
tree.inlineable = true
set_inlineable!(tree, true)
tree.parent = mi
tree.rettype = Core.Typeof(rettype_const)
tree.min_world = code.min_world
Expand Down
13 changes: 13 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,17 @@ const var"@_noinline_meta" = var"@noinline"

@deprecate splat(x) Splat(x) false

# We'd generally like to avoid direct external access to internal fields
# Core.Compiler.is_inlineable and Core.Compiler.set_inlineable! move towards this direction,
# but we need to keep these around for compat
function getproperty(ci::CodeInfo, s::Symbol)
s === :inlineable && return Core.Compiler.is_inlineable(ci)
return getfield(ci, s)
end

function setproperty!(ci::CodeInfo, s::Symbol, v)
s === :inlineable && return Core.Compiler.set_inlineable!(ci, v)
return setfield!(ci, s, convert(fieldtype(CodeInfo, s), v))
end

# END 1.9 deprecations
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8276,10 +8276,10 @@ jl_llvm_functions_t jl_emit_codeinst(
}
else if (// don't delete toplevel code
jl_is_method(def) &&
// and there is something to delete (test this before calling jl_ir_flag_inlineable)
// and there is something to delete (test this before calling jl_ir_inlining_cost)
codeinst->inferred != jl_nothing &&
// don't delete inlineable code, unless it is constant
(codeinst->invoke == jl_fptr_const_return_addr || !jl_ir_flag_inlineable((jl_array_t*)codeinst->inferred)) &&
(codeinst->invoke == jl_fptr_const_return_addr || (jl_ir_inlining_cost((jl_array_t*)codeinst->inferred) == UINT16_MAX)) &&
// don't delete code when generating a precompile file
!(params.imaging || jl_options.incremental)) {
// if not inlineable, code won't be needed again
Expand Down
27 changes: 13 additions & 14 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,11 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop)
static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inferred, uint8_t constprop)
{
jl_code_info_flags_t flags;
flags.bits.pure = pure;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.inlineable = inlineable;
flags.bits.inferred = inferred;
flags.bits.constprop = constprop;
return flags;
Expand Down Expand Up @@ -729,9 +728,10 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
1
};

jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inlineable, code->inferred, code->constprop);
jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inferred, code->constprop);
write_uint8(s.s, flags.packed);
write_uint8(s.s, code->purity.bits);
write_uint16(s.s, code->inlining_cost);

size_t nslots = jl_array_len(code->slotflags);
assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions
Expand Down Expand Up @@ -823,10 +823,10 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
flags.packed = read_uint8(s.s);
code->constprop = flags.bits.constprop;
code->inferred = flags.bits.inferred;
code->inlineable = flags.bits.inlineable;
code->propagate_inbounds = flags.bits.propagate_inbounds;
code->pure = flags.bits.pure;
code->purity.bits = read_uint8(s.s);
code->inlining_cost = read_uint16(s.s);

size_t nslots = read_int32(&src);
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
Expand Down Expand Up @@ -890,24 +890,23 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data)
return flags.bits.inferred;
}

JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data)
JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data)
{
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->inlineable;
return ((jl_code_info_t*)data)->pure;
assert(jl_typeis(data, jl_array_uint8_type));
jl_code_info_flags_t flags;
flags.packed = ((uint8_t*)data->data)[0];
return flags.bits.inlineable;
return flags.bits.pure;
}

JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data)
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_array_t *data)
{
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->pure;
return ((jl_code_info_t*)data)->inlining_cost;
assert(jl_typeis(data, jl_array_uint8_type));
jl_code_info_flags_t flags;
flags.packed = ((uint8_t*)data->data)[0];
return flags.bits.pure;
uint16_t res = jl_load_unaligned_i16((char*)data->data + 2);
return res;
}

JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms)
Expand Down Expand Up @@ -944,7 +943,7 @@ JL_DLLEXPORT ssize_t jl_ir_nslots(jl_array_t *data)
}
else {
assert(jl_typeis(data, jl_array_uint8_type));
int nslots = jl_load_unaligned_i32((char*)data->data + 2);
int nslots = jl_load_unaligned_i32((char*)data->data + 2 + sizeof(uint16_t));
return nslots;
}
}
Expand All @@ -955,7 +954,7 @@ JL_DLLEXPORT uint8_t jl_ir_slotflag(jl_array_t *data, size_t i)
if (jl_is_code_info(data))
return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i];
assert(jl_typeis(data, jl_array_uint8_type));
return ((uint8_t*)data->data)[2 + sizeof(int32_t) + i];
return ((uint8_t*)data->data)[2 + sizeof(uint16_t) + sizeof(int32_t) + i];
}

JL_DLLEXPORT jl_array_t *jl_uncompress_argnames(jl_value_t *syms)
Expand Down
2 changes: 1 addition & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@
XX(jl_ios_fd) \
XX(jl_ios_get_nbyte_int) \
XX(jl_ir_flag_inferred) \
XX(jl_ir_flag_inlineable) \
XX(jl_ir_inlining_cost) \
XX(jl_ir_flag_pure) \
XX(jl_ir_nslots) \
XX(jl_ir_slotflag) \
Expand Down
4 changes: 2 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2417,7 +2417,7 @@ void jl_init_types(void) JL_GC_DISABLED
"min_world",
"max_world",
"inferred",
"inlineable",
"inlining_cost",
"propagate_inbounds",
"pure",
"constprop",
Expand All @@ -2438,7 +2438,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_ulong_type,
jl_ulong_type,
jl_bool_type,
jl_bool_type,
jl_uint16_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
Expand Down
4 changes: 2 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ typedef struct _jl_code_info_t {
size_t max_world;
// various boolean properties:
uint8_t inferred;
uint8_t inlineable;
uint16_t inlining_cost;
uint8_t propagate_inbounds;
uint8_t pure;
// uint8 settings
Expand Down Expand Up @@ -1813,8 +1813,8 @@ JL_DLLEXPORT jl_value_t *jl_copy_ast(jl_value_t *expr JL_MAYBE_UNROOTED);
JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_ir_nslots(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_slotflag(jl_array_t *data, size_t i) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms);
Expand Down
2 changes: 0 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO
typedef struct {
uint8_t pure:1;
uint8_t propagate_inbounds:1;
uint8_t inlineable:1;
uint8_t inferred:1;
uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_flags_bitfield_t;
Expand All @@ -532,7 +531,6 @@ typedef union {

// -- functions -- //

// jl_code_info_flag_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop);
JL_DLLEXPORT jl_code_info_t *jl_type_infer(jl_method_instance_t *li, size_t world, int force);
JL_DLLEXPORT jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *meth JL_PROPAGATES_ROOT, size_t world);
jl_code_instance_t *jl_generate_fptr(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world);
Expand Down
4 changes: 2 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
if (ma == (jl_value_t*)jl_pure_sym)
li->pure = 1;
else if (ma == (jl_value_t*)jl_inline_sym)
li->inlineable = 1;
li->inlining_cost = 0x10; // This corresponds to MIN_INLINE_COST
else if (ma == (jl_value_t*)jl_propagate_inbounds_sym)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)jl_aggressive_constprop_sym)
Expand Down Expand Up @@ -467,7 +467,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->min_world = 1;
src->max_world = ~(size_t)0;
src->inferred = 0;
src->inlineable = 0;
src->inlining_cost = UINT16_MAX;
src->propagate_inbounds = 0;
src->pure = 0;
src->edges = jl_nothing;
Expand Down
2 changes: 1 addition & 1 deletion src/precompile.c
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ static int precompile_enq_specialization_(jl_method_instance_t *mi, void *closur
if (jl_atomic_load_relaxed(&codeinst->invoke) != jl_fptr_const_return) {
if (codeinst->inferred && codeinst->inferred != jl_nothing &&
jl_ir_flag_inferred((jl_array_t*)codeinst->inferred) &&
!jl_ir_flag_inlineable((jl_array_t*)codeinst->inferred)) {
(jl_ir_inlining_cost((jl_array_t*)codeinst->inferred) == UINT16_MAX)) {
do_compile = 1;
}
else if (jl_atomic_load_relaxed(&codeinst->invoke) != NULL || jl_atomic_load_relaxed(&codeinst->precompile)) {
Expand Down
7 changes: 6 additions & 1 deletion stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,12 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
end
end
ci.inferred = deserialize(s)
ci.inlineable = deserialize(s)
inlining = deserialize(s)
if isa(inlining, Bool)
Core.Compiler.set_inlineable!(ci, inlining)
else
ci.inlining_cost = inlining
end
ci.propagate_inbounds = deserialize(s)
ci.pure = deserialize(s)
if format_version(s) >= 14
Expand Down
Loading

0 comments on commit 5d2e24f

Please sign in to comment.