Skip to content

Commit

Permalink
optimize cases when there are no handles in the method body (#54415)
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed May 9, 2024
1 parent 3c966a5 commit d543508
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 98 deletions.
31 changes: 16 additions & 15 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2651,8 +2651,10 @@ function abstract_eval_throw_undef_if_not(interp::AbstractInterpreter, e::Expr,
return RTEffects(rt, exct, effects)
end

abstract_eval_the_exception(::AbstractInterpreter, sv::InferenceState) =
the_exception_info(sv.handlers[sv.handler_at[sv.currpc][2]].exct)
function abstract_eval_the_exception(::AbstractInterpreter, sv::InferenceState)
(;handlers, handler_at) = sv.handler_info::HandlerInfo
return the_exception_info(handlers[handler_at[sv.currpc][2]].exct)
end
abstract_eval_the_exception(::AbstractInterpreter, ::IRInterpretationState) = the_exception_info(Any)
the_exception_info(@nospecialize t) = RTEffects(t, Union{}, Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE))

Expand Down Expand Up @@ -3159,22 +3161,21 @@ end

function update_exc_bestguess!(interp::AbstractInterpreter, @nospecialize(exct), frame::InferenceState)
𝕃ₚ = ipo_lattice(interp)
cur_hand = frame.handler_at[frame.currpc][1]
if cur_hand == 0
handler = gethandler(frame)
if handler === nothing
if !(𝕃ₚ, exct, frame.exc_bestguess)
frame.exc_bestguess = tmerge(𝕃ₚ, frame.exc_bestguess, exct)
update_cycle_worklists!(frame) do caller::InferenceState, caller_pc::Int
caller_handler = caller.handler_at[caller_pc][1]
caller_exct = caller_handler == 0 ?
caller.exc_bestguess : caller.handlers[caller_handler].exct
caller_handler = gethandler(caller, caller_pc)
caller_exct = caller_handler === nothing ?
caller.exc_bestguess : caller_handler.exct
return caller_exct !== Any
end
end
else
handler_frame = frame.handlers[cur_hand]
if !(𝕃ₚ, exct, handler_frame.exct)
handler_frame.exct = tmerge(𝕃ₚ, handler_frame.exct, exct)
enter = frame.src.code[handler_frame.enter_idx]::EnterNode
if !(𝕃ₚ, exct, handler.exct)
handler.exct = tmerge(𝕃ₚ, handler.exct, exct)
enter = frame.src.code[handler.enter_idx]::EnterNode
exceptbb = block_for_inst(frame.cfg, enter.catch_dest)
push!(frame.ip, exceptbb)
end
Expand All @@ -3184,9 +3185,9 @@ end
function propagate_to_error_handler!(currstate::VarTable, frame::InferenceState, 𝕃ᵢ::AbstractLattice)
# If this statement potentially threw, propagate the currstate to the
# exception handler, BEFORE applying any state changes.
cur_hand = frame.handler_at[frame.currpc][1]
if cur_hand != 0
enter = frame.src.code[frame.handlers[cur_hand].enter_idx]::EnterNode
curr_hand = gethandler(frame)
if curr_hand !== nothing
enter = frame.src.code[curr_hand.enter_idx]::EnterNode
exceptbb = block_for_inst(frame.cfg, enter.catch_dest)
if update_bbstate!(𝕃ᵢ, frame, exceptbb, currstate)
push!(frame.ip, exceptbb)
Expand Down Expand Up @@ -3333,7 +3334,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
if isdefined(stmt, :scope)
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
handler = frame.handlers[frame.handler_at[frame.currpc+1][1]]
handler = gethandler(frame, frame.currpc+1)::TryCatchFrame
@assert handler.scopet !== nothing
if !(𝕃ᵢ, scopet, handler.scopet)
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
Expand Down
114 changes: 99 additions & 15 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ mutable struct TryCatchFrame
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
end

struct HandlerInfo
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -237,8 +242,7 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
handler_info::Union{Nothing,HandlerInfo}
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand Down Expand Up @@ -290,7 +294,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at, handlers = compute_trycatch(code, BitSet())
handler_info = compute_trycatch(code)
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -335,12 +339,12 @@ mutable struct InferenceState
restrict_abstract_call_sites = isa(def, Module)

# some more setups
InferenceParams(interp).unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
InferenceParams(interp).unoptimize_throw_blocks && mark_throw_blocks!(src, handler_info)
!iszero(cache_mode & CACHE_MODE_LOCAL) && push!(get_inference_cache(interp), result)

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand All @@ -356,6 +360,14 @@ mutable struct InferenceState
end
end

gethandler(frame::InferenceState, pc::Int=frame.currpc) = gethandler(frame.handler_info, pc)
gethandler(::Nothing, ::Int) = nothing
function gethandler(handler_info::HandlerInfo, pc::Int)
handler_idx = handler_info.handler_at[pc][1]
handler_idx == 0 && return nothing
return handler_info.handlers[handler_idx]
end

is_nonoverlayed(m::Method) = !isdefined(m, :external_mt)
is_nonoverlayed(interp::AbstractInterpreter) = !isoverlayed(method_table(interp))
isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
Expand All @@ -368,21 +380,21 @@ is_inferred(result::InferenceResult) = result.result !== nothing

was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND

compute_trycatch(ir::IRCode, ip::BitSet) = compute_trycatch(ir.stmts.stmt, ip, ir.cfg.blocks)
compute_trycatch(ir::IRCode) = compute_trycatch(ir.stmts.stmt, ir.cfg.blocks)

"""
compute_trycatch(code, ip [, bbs]) -> (handler_at, handlers)
compute_trycatch(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo}
Given the code of a function, compute, at every statement, the current
try/catch handler, and the current exception stack top. This function returns
a tuple of:
1. `handler_at`: A statement length vector of tuples `(catch_handler, exception_stack)`,
which are indices into `handlers`
1. `handler_info.handler_at`: A statement length vector of tuples
`(catch_handler, exception_stack)`, which are indices into `handlers`
2. `handlers`: A `TryCatchFrame` vector of handlers
2. `handler_info.handlers`: A `TryCatchFrame` vector of handlers
"""
function compute_trycatch(code::Vector{Any}, ip::BitSet, bbs::Union{Vector{BasicBlock}, Nothing}=nothing)
function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
Expand All @@ -391,16 +403,17 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet, bbs::Union{Vector{Basic
# then we can find all `try`s by walking backwards from :enter statements,
# and all `catch`es by looking at the statement after the :enter
n = length(code)
empty!(ip)
ip = BitSet()
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]
handler_info = nothing

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isa(stmt, EnterNode)
(;handlers, handler_at) = handler_info =
(handler_info === nothing ? HandlerInfo(TryCatchFrame[], fill((0, 0), n)) : handler_info)
l = stmt.catch_dest
(bbs !== nothing) && (l = first(bbs[l].stmts))
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
Expand All @@ -414,7 +427,12 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet, bbs::Union{Vector{Basic
end
end

if handler_info === nothing
return nothing
end

# now forward those marks to all :leave statements
(;handlers, handler_at) = handler_info
while true
# make progress on the active ip set
pc = _bits_findnext(ip.bits, 0)::Int
Expand Down Expand Up @@ -488,7 +506,73 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet, bbs::Union{Vector{Basic
end

@assert first(ip) == n + 1
return handler_at, handlers
return handler_info
end

function is_throw_call(e::Expr, code::Vector{Any})
if e.head === :call
f = e.args[1]
if isa(f, SSAValue)
f = code[f.id]
end
if isa(f, GlobalRef)
ff = abstract_eval_globalref_type(f)
if isa(ff, Const) && ff.val === Core.throw
return true
end
end
end
return false
end

function mark_throw_blocks!(src::CodeInfo, handler_info::Union{Nothing,HandlerInfo})
for stmt in find_throw_blocks(src.code, handler_info)
src.ssaflags[stmt] |= IR_FLAG_THROW_BLOCK
end
return nothing
end

# this utility function is incomplete and won't catch every block that always throws, since:
# - it only recognizes direct calls to `throw` within the target code, so it can't mark
# blocks that deterministically call `throw` internally, like those containing `error`.
# - it just does a reverse linear traverse of statements, there's a chance it might miss
# blocks, particularly when there are reverse control edges.
function find_throw_blocks(code::Vector{Any}, handler_info::Union{Nothing,HandlerInfo})
stmts = BitSet()
n = length(code)
for i in n:-1:1
s = code[i]
if isa(s, Expr)
if s.head === :gotoifnot
if i+1 in stmts && s.args[2]::Int in stmts
push!(stmts, i)
end
elseif s.head === :return
# see `ReturnNode` handling
elseif is_throw_call(s, code)
if handler_info === nothing || handler_info.handler_at[i][1] == 0
push!(stmts, i)
end
elseif i+1 in stmts
push!(stmts, i)
end
elseif isa(s, ReturnNode)
# NOTE: it potentially makes sense to treat unreachable nodes
# (where !isdefined(s, :val)) as `throw` points, but that can cause
# worse codegen around the call site (issue #37558)
elseif isa(s, GotoNode)
if s.label in stmts
push!(stmts, i)
end
elseif isa(s, GotoIfNot)
if i+1 in stmts && s.dest in stmts
push!(stmts, i)
end
elseif i+1 in stmts
push!(stmts, i)
end
end
return stmts
end

# check if coverage mode is enabled
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
end

# Record the correct exception handler for all critical sections
handler_at, handlers = compute_trycatch(code, BitSet())
handler_info = compute_trycatch(code)

phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
Expand Down Expand Up @@ -780,8 +780,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
incoming_vals[id] = Pair{Any, Any}(thisval, thisdef)
has_pinode[id] = false
enter_idx = idx
while handler_at[enter_idx][1] != 0
(; enter_idx) = handlers[handler_at[enter_idx][1]]
while (handler = gethandler(handler_info, enter_idx)) !== nothing
(; enter_idx) = handler
leave_block = block_for_inst(cfg, (code[enter_idx]::EnterNode).catch_dest)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
if cidx !== nothing
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2576,13 +2576,13 @@ end

function current_scope_tfunc(interp::AbstractInterpreter, sv::InferenceState)
pc = sv.currpc
handler_info = sv.handler_info
while true
handleridx = sv.handler_at[pc][1]
if handleridx == 0
pchandler = gethandler(sv, pc)
if pchandler === nothing
# No local scope available - inherited from the outside
return Any
end
pchandler = sv.handlers[handleridx]
# Remember that we looked at this handler, so we get re-scheduled
# if the scope information changes
isdefined(pchandler, :scope_uses) || (pchandler.scope_uses = Int[])
Expand Down
61 changes: 0 additions & 61 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,67 +459,6 @@ function find_ssavalue_uses!(uses::Vector{BitSet}, e::PhiNode, line::Int)
end
end

function is_throw_call(e::Expr, code::Vector{Any})
if e.head === :call
f = e.args[1]
if isa(f, SSAValue)
f = code[f.id]
end
if isa(f, GlobalRef)
ff = abstract_eval_globalref_type(f)
if isa(ff, Const) && ff.val === Core.throw
return true
end
end
end
return false
end

function mark_throw_blocks!(src::CodeInfo, handler_at::Vector{Tuple{Int, Int}})
for stmt in find_throw_blocks(src.code, handler_at)
src.ssaflags[stmt] |= IR_FLAG_THROW_BLOCK
end
return nothing
end

function find_throw_blocks(code::Vector{Any}, handler_at::Vector{Tuple{Int, Int}})
stmts = BitSet()
n = length(code)
for i in n:-1:1
s = code[i]
if isa(s, Expr)
if s.head === :gotoifnot
if i+1 in stmts && s.args[2]::Int in stmts
push!(stmts, i)
end
elseif s.head === :return
# see `ReturnNode` handling
elseif is_throw_call(s, code)
if handler_at[i][1] == 0
push!(stmts, i)
end
elseif i+1 in stmts
push!(stmts, i)
end
elseif isa(s, ReturnNode)
# NOTE: it potentially makes sense to treat unreachable nodes
# (where !isdefined(s, :val)) as `throw` points, but that can cause
# worse codegen around the call site (issue #37558)
elseif isa(s, GotoNode)
if s.label in stmts
push!(stmts, i)
end
elseif isa(s, GotoIfNot)
if i+1 in stmts && s.dest in stmts
push!(stmts, i)
end
elseif i+1 in stmts
push!(stmts, i)
end
end
return stmts
end

# using a function to ensure we can infer this
@inline function slot_id(s)
isa(s, SlotNumber) && return s.id
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4435,7 +4435,7 @@ let x = Tuple{Int,Any}[
#=19=# (0, Expr(:pop_exception, Core.SSAValue(2)))
#=20=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
handler_at, handlers = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
(;handler_at, handlers) = Core.Compiler.compute_trycatch(last.(x))
@test map(x->x[1] == 0 ? 0 : handlers[x[1]].enter_idx, handler_at) == first.(x)
end

Expand Down

0 comments on commit d543508

Please sign in to comment.