Skip to content

Commit

Permalink
Merge 7a76fdd into a2912e2
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 8, 2023
2 parents a2912e2 + 7a76fdd commit 850231b
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 49 deletions.
34 changes: 16 additions & 18 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,16 @@ function inlining_policy(interp::AbstractInterpreter,
end

struct InliningState{Interp<:AbstractInterpreter}
params::OptimizationParams
et::Union{EdgeTracker,Nothing}
world::UInt
interp::Interp
end
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
function InliningState(frame::InferenceState, interp::AbstractInterpreter)
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
return InliningState(params, et, frame.world, interp)
return InliningState(et, frame.world, interp)
end
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
return InliningState(params, nothing, get_world_counter(interp), interp)
function InliningState(interp::AbstractInterpreter)
return InliningState(nothing, get_world_counter(interp), interp)
end

# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
Expand All @@ -151,15 +150,14 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
cfg::Union{Nothing,CFG}
insert_coverage::Bool
end
function OptimizationState(frame::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
inlining = InliningState(frame, params, interp)
function OptimizationState(frame::InferenceState, interp::AbstractInterpreter,
recompute_cfg::Bool=true)
inlining = InliningState(frame, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
# prepare src for running optimization passes if it isn't already
nssavalues = src.ssavaluetypes
if nssavalues isa Int
Expand All @@ -179,13 +177,13 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
mod = isa(def, Method) ? def.module : def
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(params, interp)
inlining = InliningState(interp)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
end
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
src = retrieve_code_info(linfo)
src === nothing && return nothing
return OptimizationState(linfo, src, params, interp)
return OptimizationState(linfo, src, interp)
end

function ir_to_codeinf!(opt::OptimizationState)
Expand Down Expand Up @@ -392,13 +390,13 @@ abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = typ

"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
ir::IRCode, caller::InferenceResult)
Post-process information derived by Julia-level optimizations for later use.
In particular, this function determines the inlineability of the optimized code.
"""
function finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
ir::IRCode, caller::InferenceResult)
(; src, linfo) = opt
(; def, specTypes) = linfo

Expand Down Expand Up @@ -438,6 +436,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
set_inlineable!(src, true)
else
# compute the cost (size) of inlining this code
params = OptimizationParams(interp)
cost_threshold = default = params.inline_cost_threshold
if (optimizer_lattice(interp), result, Tuple) && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
Expand All @@ -460,10 +459,9 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, caller::InferenceResult)
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
return finish(interp, opt, params, ir, caller)
return finish(interp, opt, ir, caller)
end

using .EscapeAnalysis
Expand Down
34 changes: 17 additions & 17 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function ssa_inlining_pass!(ir::IRCode, state::InliningState, propagate_inbounds
@timeit "analysis" todo = assemble_inline_todo!(ir, state)
isempty(todo) && return ir
# Do the actual inlining for every call we identified
@timeit "execution" ir = batch_inline!(ir, todo, propagate_inbounds, state.params)
@timeit "execution" ir = batch_inline!(ir, todo, propagate_inbounds, OptimizationParams(state.interp))
return ir
end

Expand Down Expand Up @@ -872,14 +872,14 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes

# the duplicated check might have been done already within `analyze_method!`, but still
# we need it here too since we may come here directly using a constant-prop' result
if !state.params.inlining || is_stmt_noinline(flag)
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
return compileable_specialization(result, effects, et, info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
end

src = inlining_policy(state.interp, src, info, flag, mi, argtypes)
src === nothing && return compileable_specialization(result, effects, et, info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
Expand All @@ -888,7 +888,7 @@ end
# the special resolver for :invoke-d call
function resolve_todo(mi::MethodInstance, argtypes::Vector{Any},
@nospecialize(info::CallInfo), flag::UInt8, state::InliningState)
if !state.params.inlining || is_stmt_noinline(flag)
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
return nothing
end

Expand Down Expand Up @@ -958,7 +958,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
et = InliningEdgeTracker(state.et, invokesig)
effects = info_effects(nothing, match, state)
return compileable_specialization(match, effects, et, info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
end

return resolve_todo(mi, match, argtypes, info, flag, state; invokesig)
Expand Down Expand Up @@ -1124,7 +1124,7 @@ function inline_apply!(todo::Vector{Pair{Int,Any}},
arginfos = MaybeAbstractIterationInfo[]
for i = (arg_start + 1):length(argtypes)
thisarginfo = nothing
if !is_valid_type_for_apply_rewrite(argtypes[i], state.params)
if !is_valid_type_for_apply_rewrite(argtypes[i], OptimizationParams(state.interp))
isa(info, ApplyCallInfo) || return nothing
thisarginfo = info.arginfo[i-arg_start]
if thisarginfo === nothing || !thisarginfo.complete
Expand Down Expand Up @@ -1173,13 +1173,13 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
item = resolve_todo(mi, result.result, argtypes, info, flag, state; invokesig)
handle_single_case!(todo, ir, idx, stmt, item, state.params, true)
handle_single_case!(todo, ir, idx, stmt, item, OptimizationParams(state.interp), true)
return nothing
end
end
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig)
end
handle_single_case!(todo, ir, idx, stmt, item, state.params, true)
handle_single_case!(todo, ir, idx, stmt, item, OptimizationParams(state.interp), true)
return nothing
end

Expand Down Expand Up @@ -1433,7 +1433,7 @@ function handle_call!(todo::Vector{Pair{Int,Any}},
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
handle_cases!(todo, ir, idx, stmt, argtypes_to_type(sig.argtypes), cases,
all_covered, joint_effects, state.params)
all_covered, joint_effects, OptimizationParams(state.interp))
end

function handle_match!(cases::Vector{InliningCase},
Expand Down Expand Up @@ -1471,10 +1471,10 @@ end
function semiconcrete_result_item(result::SemiConcreteResult,
@nospecialize(info::CallInfo), flag::UInt8, state::InliningState)
mi = result.mi
if !state.params.inlining || is_stmt_noinline(flag)
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
et = InliningEdgeTracker(state.et, nothing)
return compileable_specialization(mi, result.effects, et, info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
else
return InliningTodo(mi, result.ir, result.effects)
end
Expand Down Expand Up @@ -1507,7 +1507,7 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
if !may_inline_concrete_result(result)
et = InliningEdgeTracker(state.et, invokesig)
case = compileable_specialization(result.mi, result.effects, et, info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
@assert case !== nothing "concrete evaluation should never happen for uncompileable callsite"
return case
end
Expand Down Expand Up @@ -1555,7 +1555,7 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}},
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false)
end
end
handle_single_case!(todo, ir, idx, stmt, item, state.params)
handle_single_case!(todo, ir, idx, stmt, item, OptimizationParams(state.interp))
return nothing
end

Expand All @@ -1568,7 +1568,7 @@ function handle_modifyfield!_call!(ir::IRCode, idx::Int, stmt::Expr, info::Modif
match = info.results[1]::MethodMatch
match.fully_covers || return nothing
case = compileable_specialization(match, Effects(), InliningEdgeTracker(state.et), info;
compilesig_invokes=state.params.compilesig_invokes)
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
case === nothing && return nothing
stmt.head = :invoke_modify
pushfirst!(stmt.args, case.invoke)
Expand Down Expand Up @@ -1696,7 +1696,7 @@ end
function early_inline_special_case(
ir::IRCode, stmt::Expr, @nospecialize(type), sig::Signature,
state::InliningState)
state.params.inlining || return nothing
OptimizationParams(state.interp).inlining || return nothing
(; f, ft, argtypes) = sig

if isa(type, Const) # || isconstType(type)
Expand Down Expand Up @@ -1749,7 +1749,7 @@ end
function late_inline_special_case!(
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(type), sig::Signature,
state::InliningState)
state.params.inlining || return nothing
OptimizationParams(state.interp).inlining || return nothing
(; f, ft, argtypes) = sig
if length(argtypes) == 3 && istopfunction(f, :!==)
# special-case inliner for !== that precedes _methods_by_ftype union splitting
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse
# Check #3
dominates(domtree, finalizer_bb, bb_insert_block) || return nothing

if !inlining.params.assume_fatal_throw
if !OptimizationParams(inlining.interp).assume_fatal_throw
# Collect all reachable blocks between the finalizer registration and the
# insertion point
blocks = finalizer_bb == bb_insert_block ? Int[finalizer_bb] :
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
for (caller, _, _) in results
opt = caller.src
if opt isa OptimizationState{typeof(interp)} # implies `may_optimize(interp) === true`
analyzed = optimize(interp, opt, OptimizationParams(interp), caller)
analyzed = optimize(interp, opt, caller)
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
end
end
Expand Down Expand Up @@ -551,7 +551,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
doopt = (me.cached || me.parent !== nothing)
recompute_cfg = type_annotate!(interp, me, doopt)
if doopt && may_optimize(interp)
me.result.src = OptimizationState(me, OptimizationParams(interp), interp, recompute_cfg)
me.result.src = OptimizationState(me, interp, recompute_cfg)
else
me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection)
end
Expand Down Expand Up @@ -966,8 +966,7 @@ function typeinf_ircode(
return nothing, Any
end
(; result) = frame
opt_params = OptimizationParams(interp)
opt = OptimizationState(frame, opt_params, interp)
opt = OptimizationState(frame, interp)
ir = run_passes(opt.src, opt, result, optimize_until)
rt = widenconst(ignorelimited(result.result))
ccall(:jl_typeinf_timing_end, Cvoid, ())
Expand Down
4 changes: 2 additions & 2 deletions test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ function invalidate_cache!(replaced, max_world, depth = 0)
end

function CC.optimize(interp::EscapeAnalyzer,
opt::OptimizationState, params::OptimizationParams, caller::InferenceResult)
opt::OptimizationState, caller::InferenceResult)
ir = run_passes_with_ea(interp, opt.src, opt, caller)
return CC.finish(interp, opt, params, ir, caller)
return CC.finish(interp, opt, ir, caller)
end

function CC.cache_result!(interp::EscapeAnalyzer, caller::InferenceResult)
Expand Down
8 changes: 3 additions & 5 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
let linfo = get_linfo(Base.convert, Tuple{Type{Int64}, Int32}),
world = UInt(23) # some small-numbered world that should be valid
interp = Core.Compiler.NativeInterpreter()
opt = Core.Compiler.OptimizationState(linfo, Core.Compiler.OptimizationParams(interp), interp)
opt = Core.Compiler.OptimizationState(linfo, interp)
# make sure the state of the properties look reasonable
@test opt.src !== linfo.def.source
@test length(opt.src.slotflags) == linfo.def.nargs <= length(opt.src.slotnames)
Expand Down Expand Up @@ -4125,16 +4125,14 @@ function f_convert_me_to_ir(b, x)
return a
end

let
# Test the presence of PhiNodes in lowered IR by taking the above function,
let # Test the presence of PhiNodes in lowered IR by taking the above function,
# running it through SSA conversion and then putting it into an opaque
# closure.
mi = Core.Compiler.specialize_method(first(methods(f_convert_me_to_ir)),
Tuple{Bool, Float64}, Core.svec())
ci = Base.uncompressed_ast(mi.def)
ci.ssavaluetypes = Any[Any for i = 1:ci.ssavaluetypes]
sv = Core.Compiler.OptimizationState(mi, Core.Compiler.OptimizationParams(),
Core.Compiler.NativeInterpreter())
sv = Core.Compiler.OptimizationState(mi, Core.Compiler.NativeInterpreter())
ir = Core.Compiler.convert_to_ircode(ci, sv)
ir = Core.Compiler.slot2reg(ir, ci, sv)
ir = Core.Compiler.compact!(ir)
Expand Down
3 changes: 1 addition & 2 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1754,8 +1754,7 @@ let interp = Core.Compiler.NativeInterpreter()

# ok, now delete the callsite flag, and see the second inlining pass can inline the call
@eval Core.Compiler $ir.stmts[$i][:flag] &= ~IR_FLAG_NOINLINE
inlining = Core.Compiler.InliningState(Core.Compiler.OptimizationParams(interp), nothing,
Core.Compiler.get_world_counter(interp), interp)
inlining = Core.Compiler.InliningState(interp)
ir = Core.Compiler.ssa_inlining_pass!(ir, inlining, false)
@test count(isinvoke(:*), ir.stmts.inst) == 0
@test count(iscall((ir, Core.Intrinsics.mul_int)), ir.stmts.inst) == 1
Expand Down

0 comments on commit 850231b

Please sign in to comment.