Skip to content

Commit

Permalink
Optimizer: Re-use CFG from type inference (#50924)
Browse files Browse the repository at this point in the history
This change will allow us to re-use Inference-collected information at
the basic block level, such as the `bb_vartables`.

There were a couple of spots where the `unreachable` insertion pass at
IRCode conversion (i.e. optimizer entry) was ignoring statically
divergent code that inference had discovered:
- `%x = SlotNumber(3)` can throw and cause the following statements to
be statically unreachable
- `goto #b if not %cond` can be statically throwing if %cond is known to
never be Bool (or to always throw during its own evaluation)

CFG re-computation was hiding these bugs by flowing through the
"Core.Const(...)"-wrapped statements that would follow, inserting
unnecessary but harmless extra branches in the CFG.
  • Loading branch information
Keno committed Aug 18, 2023
2 parents c239e99 + ce6b332 commit 49e6ff8
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 54 deletions.
125 changes: 90 additions & 35 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ end
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)

include("compiler/ssair/driver.jl")

mutable struct OptimizationState{Interp<:AbstractInterpreter}
linfo::MethodInstance
src::CodeInfo
Expand All @@ -131,15 +129,13 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
sptypes::Vector{VarState}
slottypes::Vector{Any}
inlining::InliningState{Interp}
cfg::Union{Nothing,CFG}
cfg::CFG
insert_coverage::Bool
end
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
recompute_cfg::Bool=true)
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
inlining = InliningState(sv, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
sv.sptypes, sv.slottypes, inlining, sv.cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
# prepare src for running optimization passes if it isn't already
Expand All @@ -162,7 +158,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::Abstrac
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(interp)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
cfg = compute_basic_blocks(src.code)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, cfg, false)
end
function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
world = get_world_counter(interp)
Expand All @@ -171,6 +168,9 @@ function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
return OptimizationState(linfo, src, interp)
end


include("compiler/ssair/driver.jl")

function ir_to_codeinf!(opt::OptimizationState)
(; linfo, src) = opt
src = ir_to_codeinf!(src, opt.ir::IRCode)
Expand Down Expand Up @@ -534,7 +534,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
idx = 1
oldidx = 1
nstmts = length(code)
ssachangemap = labelchangemap = nothing
ssachangemap = labelchangemap = blockchangemap = nothing
prevloc = zero(eltype(ci.codelocs))
while idx <= length(code)
codeloc = codelocs[idx]
Expand All @@ -555,54 +555,93 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
if oldidx < length(labelchangemap)
labelchangemap[oldidx + 1] += 1
end
if blockchangemap === nothing
blockchangemap = fill(0, length(sv.cfg.blocks))
end
blockchangemap[block_for_inst(sv.cfg, oldidx)] += 1
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
if ssavaluetypes[idx] === Union{} && !(code[idx] isa Core.Const)
# Type inference should have converted any must-throw terminators to an equivalent w/o control-flow edges
@assert !isterminator(code[idx])

block = block_for_inst(sv.cfg, oldidx)
block_end = last(sv.cfg.blocks[block].stmts) + (idx - oldidx)

# Delete all successors to this basic block
for succ in sv.cfg.blocks[block].succs
preds = sv.cfg.blocks[succ].preds
deleteat!(preds, findfirst(x::Int->x==block, preds)::Int)
end
empty!(sv.cfg.blocks[block].succs)

if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
insert!(codelocs, idx + 1, codelocs[idx])
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, NoCallInfo())
insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
sv.insert_coverage && (labelchangemap[oldidx + 1] += 1)
# Any statements from here to the end of the block have been wrapped in Core.Const(...)
# by type inference (effectively deleting them). Only task left is to replace the block
# terminator with an explicit `unreachable` marker.
if block_end > idx
code[block_end] = ReturnNode()
codelocs[block_end] = codelocs[idx]
ssavaluetypes[block_end] = Union{}
stmtinfo[block_end] = NoCallInfo()
ssaflags[block_end] = IR_FLAG_NOTHROW

# Verify that type-inference did its job
if JLOptions().debug_level == 2
for i = (idx + 1):(block_end - 1)
@assert (code[i] isa Core.Const) || is_meta_expr(code[i])
end
end

idx += block_end - idx
else
insert!(code, idx + 1, ReturnNode())
insert!(codelocs, idx + 1, codelocs[idx])
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, NoCallInfo())
insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
sv.insert_coverage && (labelchangemap[oldidx + 1] += 1)
end
if blockchangemap === nothing
blockchangemap = fill(0, length(sv.cfg.blocks))
end
blockchangemap[block] += 1
idx += 1
end
idx += 1
oldidx = last(sv.cfg.blocks[block].stmts)
end
end
idx += 1
oldidx += 1
end

cfg = sv.cfg
if ssachangemap !== nothing && labelchangemap !== nothing
renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = nothing # recompute CFG
end
if blockchangemap !== nothing
renumber_cfg_stmts!(sv.cfg, blockchangemap)
end

for i = 1:length(code)
code[i] = process_meta!(meta, code[i])
end
strip_trailing_junk!(ci, code, stmtinfo)
strip_trailing_junk!(ci, sv.cfg, code, stmtinfo)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
if cfg === nothing
cfg = compute_basic_blocks(code)
end
# NOTE this `argtypes` contains types of slots yet: it will be modified to contain the
# types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form
# and eliminates slots (see below)
argtypes = sv.slottypes
return IRCode(stmts, cfg, linetable, argtypes, meta, sv.sptypes)
return IRCode(stmts, sv.cfg, linetable, argtypes, meta, sv.sptypes)
end

function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
Expand Down Expand Up @@ -763,8 +802,8 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
return maxcost
end

function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int})
return renumber_ir_elements!(body, ssachangemap, ssachangemap)
function renumber_ir_elements!(body::Vector{Any}, cfg::Union{CFG,Nothing}, ssachangemap::Vector{Int})
return renumber_ir_elements!(body, cfg, ssachangemap, ssachangemap)
end

function cumsum_ssamap!(ssachangemap::Vector{Int})
Expand Down Expand Up @@ -847,3 +886,19 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
end
end
end

"""
    renumber_cfg_stmts!(cfg::CFG, blockchangemap::Vector{Int})

Shift the statement ranges stored in `cfg.blocks` and the entries of `cfg.index`
to account for statements that were added to the instruction stream.

On entry `blockchangemap[i]` is the change in the number of statements of basic
block `i`. The vector is handed to `cumsum_ssamap!`, which is expected to turn it
into a running total in place and to report whether any entry is non-zero
(NOTE(review): inferred from its name and from how its result is used here --
confirm against its definition). `blockchangemap` is therefore clobbered.

`cfg` is updated in place; the function returns `nothing`.
"""
function renumber_cfg_stmts!(cfg::CFG, blockchangemap::Vector{Int})
    # Nothing to do when no block changed size.
    any_change = cumsum_ssamap!(blockchangemap)
    any_change || return nothing

    for i = 1:length(cfg.blocks)
        old_range = cfg.blocks[i].stmts
        # The start of block `i` moves by everything added to the blocks before it,
        # its end additionally moves by whatever was added to block `i` itself.
        start_shift = (i > 1) ? blockchangemap[i - 1] : 0
        new_range = StmtRange(first(old_range) + start_shift,
                              last(old_range) + blockchangemap[i])
        cfg.blocks[i] = BasicBlock(cfg.blocks[i], new_range)
        # `cfg.index` may have one entry fewer than there are blocks, hence the guard.
        if i <= length(cfg.index)
            cfg.index[i] = cfg.index[i] + blockchangemap[i]
        end
    end
    return nothing
end
8 changes: 7 additions & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function rename_uses!(ir::IRCode, ci::CodeInfo, idx::Int, @nospecialize(stmt), r
return fixemup!(stmt::UnoptSlot->true, stmt::UnoptSlot->renames[slot_id(stmt)], ir, ci, idx, stmt)
end

function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{CallInfo})
function strip_trailing_junk!(ci::CodeInfo, cfg::CFG, code::Vector{Any}, info::Vector{CallInfo})
# Remove `nothing`s at the end, we don't handle them well
# (we expect the last instruction to be a terminator)
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
Expand All @@ -207,6 +207,12 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{Call
push!(codelocs, 0)
push!(info, NoCallInfo())
push!(ssaflags, IR_FLAG_NOTHROW)

# Update CFG to include appended terminator
old_range = cfg.blocks[end].stmts
new_range = StmtRange(first(old_range), last(old_range) + 1)
cfg.blocks[end] = BasicBlock(cfg.blocks[end], new_range)
(length(cfg.index) == length(cfg.blocks)) && (cfg.index[end] += 1)
end
nothing
end
Expand Down
47 changes: 44 additions & 3 deletions base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ end
function verify_ir(ir::IRCode, print::Bool=true,
allow_frontend_forms::Bool=false,
𝕃ₒ::AbstractLattice = SimpleInferenceLattice.instance)
# Verify CFG graph. Must be well formed to construct domtree
if !(length(ir.cfg.blocks) - 1 <= length(ir.cfg.index) <= length(ir.cfg.blocks))
@verify_error "CFG index length ($(length(ir.cfg.index))) does not correspond to # of blocks $(length(ir.cfg.blocks))"
error("")
end
if length(ir.stmts.stmt) != length(ir.stmts)
@verify_error "IR stmt length is invalid $(length(ir.stmts.stmt)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.type) != length(ir.stmts)
@verify_error "IR type length is invalid $(length(ir.stmts.type)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.info) != length(ir.stmts)
@verify_error "IR info length is invalid $(length(ir.stmts.info)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.line) != length(ir.stmts)
@verify_error "IR line length is invalid $(length(ir.stmts.line)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.flag) != length(ir.stmts)
@verify_error "IR flag length is invalid $(length(ir.stmts.flag)) / $(length(ir.stmts))"
error("")
end
# For now require compact IR
# @assert isempty(ir.new_nodes)
# Verify CFG
Expand Down Expand Up @@ -125,6 +150,18 @@ function verify_ir(ir::IRCode, print::Bool=true,
error("")
end
end
if !(1 <= first(block.stmts) <= length(ir.stmts))
@verify_error "First statement of BB $idx ($(first(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))"
error("")
end
if !(1 <= last(block.stmts) <= length(ir.stmts))
@verify_error "Last statement of BB $idx ($(last(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))"
error("")
end
if idx <= length(ir.cfg.index) && last(block.stmts) + 1 != ir.cfg.index[idx]
@verify_error "End of BB $idx ($(last(block.stmts))) is not one less than CFG index ($(ir.cfg.index[idx]))"
error("")
end
end
# Verify statements
domtree = construct_domtree(ir.cfg.blocks)
Expand All @@ -145,7 +182,7 @@ function verify_ir(ir::IRCode, print::Bool=true,
end
elseif isa(terminator, GotoNode)
if length(block.succs) != 1 || block.succs[1] != terminator.label
@verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator"
@verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator ($(terminator.label))"
error("")
end
elseif isa(terminator, GotoIfNot)
Expand All @@ -167,8 +204,8 @@ function verify_ir(ir::IRCode, print::Bool=true,
if length(block.succs) != 1 || block.succs[1] != idx + 1
# As a special case, we allow extra statements in the BB of an :enter
# statement, until we can do proper CFG manipulations during compaction.
for idx in first(block.stmts):last(block.stmts)
stmt = ir[SSAValue(idx)][:stmt]
for stmt_idx in first(block.stmts):last(block.stmts)
stmt = ir[SSAValue(stmt_idx)][:stmt]
if isexpr(stmt, :enter)
terminator = stmt
@goto enter_check
Expand All @@ -188,6 +225,10 @@ function verify_ir(ir::IRCode, print::Bool=true,
end
end
end
if length(ir.stmts) != last(ir.cfg.blocks[end].stmts)
@verify_error "End of last BB $(last(ir.cfg.blocks[end].stmts)) does not match last IR statement $(length(ir.stmts))"
error("")
end
lastbb = 0
is_phinode_block = false
firstidx = 1
Expand Down
31 changes: 21 additions & 10 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,9 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
doopt = (me.cached || me.parent !== nothing)
recompute_cfg = type_annotate!(interp, me, doopt)
type_annotate!(interp, me, doopt)
if doopt && may_optimize(interp)
me.result.src = OptimizationState(me, interp, recompute_cfg)
me.result.src = OptimizationState(me, interp)
else
me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection)
end
Expand Down Expand Up @@ -713,20 +713,31 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState, run_opt
# 3. mark unreached statements for a bulk code deletion (see issue #7836)
# 4. widen slot wrappers (`Conditional` and `MustAlias`) and remove `NOT_FOUND` from `ssavaluetypes`
# NOTE because of this, `was_reached` will no longer be available after this point
# 5. eliminate GotoIfNot if either branch target is unreachable
# 5. eliminate GotoIfNot if either or both branches are statically unreachable
changemap = nothing # initialized if there is any dead region
for i = 1:nstmt
expr = stmts[i]
if was_reached(sv, i)
if run_optimizer
if isa(expr, GotoIfNot) && widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool
if isa(expr, GotoIfNot)
# 5: replace this live GotoIfNot with:
# - GotoNode if the fallthrough target is unreachable
# - no-op if the branch target is unreachable
if !was_reached(sv, i+1)
expr = GotoNode(expr.dest)
elseif !was_reached(sv, expr.dest)
expr = nothing
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
if widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool
block = block_for_inst(sv.cfg, i)
if !was_reached(sv, i+1)
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif !was_reached(sv, expr.dest)
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
end
elseif ssavaluetypes[i] === Bottom
block = block_for_inst(sv.cfg, i)
cfg_delete_edge!(sv.cfg, block, block + 1)
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
end
end
end
Expand Down
3 changes: 3 additions & 0 deletions test/compiler/interpreter_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down Expand Up @@ -61,6 +62,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down Expand Up @@ -98,6 +100,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down
1 change: 1 addition & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ function each_stmt_a_bb(stmts, preds, succs)
empty!(ir.stmts.line); append!(ir.stmts.line, [Int32(0) for _ = 1:length(stmts)])
empty!(ir.stmts.info); append!(ir.stmts.info, [NoCallInfo() for _ = 1:length(stmts)])
empty!(ir.cfg.blocks); append!(ir.cfg.blocks, [BasicBlock(StmtRange(i, i), preds[i], succs[i]) for i = 1:length(stmts)])
empty!(ir.cfg.index); append!(ir.cfg.index, [i for i = 2:length(stmts)])
Core.Compiler.verify_ir(ir)
return ir
end
Expand Down
Loading

0 comments on commit 49e6ff8

Please sign in to comment.