diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 782a591a6d449..f419e952b6f23 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -120,8 +120,6 @@ end # get `code_cache(::AbstractInterpreter)` from `state::InliningState` code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world) -include("compiler/ssair/driver.jl") - mutable struct OptimizationState{Interp<:AbstractInterpreter} linfo::MethodInstance src::CodeInfo @@ -131,15 +129,13 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter} sptypes::Vector{VarState} slottypes::Vector{Any} inlining::InliningState{Interp} - cfg::Union{Nothing,CFG} + cfg::CFG insert_coverage::Bool end -function OptimizationState(sv::InferenceState, interp::AbstractInterpreter, - recompute_cfg::Bool=true) +function OptimizationState(sv::InferenceState, interp::AbstractInterpreter) inlining = InliningState(sv, interp) - cfg = recompute_cfg ? nothing : sv.cfg return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod, - sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage) + sv.sptypes, sv.slottypes, inlining, sv.cfg, sv.insert_coverage) end function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter) # prepare src for running optimization passes if it isn't already @@ -162,7 +158,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::Abstrac # Allow using the global MI cache, but don't track edges. # This method is mostly used for unit testing the optimizer inlining = InliningState(interp) - return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false) + cfg = compute_basic_blocks(src.code) + return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, cfg, false) end function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter) world = get_world_counter(interp) @@ -171,6 +168,9 @@ function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter) return OptimizationState(linfo, src, interp) end + +include("compiler/ssair/driver.jl") + function ir_to_codeinf!(opt::OptimizationState) (; linfo, src) = opt src = ir_to_codeinf!(src, opt.ir::IRCode) @@ -534,7 +534,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) idx = 1 oldidx = 1 nstmts = length(code) - ssachangemap = labelchangemap = nothing + ssachangemap = labelchangemap = blockchangemap = nothing prevloc = zero(eltype(ci.codelocs)) while idx <= length(code) codeloc = codelocs[idx] @@ -555,54 +555,93 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) if oldidx < length(labelchangemap) labelchangemap[oldidx + 1] += 1 end + if blockchangemap === nothing + blockchangemap = fill(0, length(sv.cfg.blocks)) + end + blockchangemap[block_for_inst(sv.cfg, oldidx)] += 1 idx += 1 prevloc = codeloc end - if code[idx] isa Expr && ssavaluetypes[idx] === Union{} + if ssavaluetypes[idx] === Union{} && !(code[idx] isa Core.Const) + # Type inference should have converted any must-throw terminators to an equivalent w/o control-flow edges + @assert !isterminator(code[idx]) + + block = block_for_inst(sv.cfg, oldidx) + block_end = last(sv.cfg.blocks[block].stmts) + (idx - oldidx) + + # Delete all successors to this basic block + for succ in sv.cfg.blocks[block].succs + preds = sv.cfg.blocks[succ].preds + deleteat!(preds, findfirst(x::Int->x==block, preds)::Int) + end + empty!(sv.cfg.blocks[block].succs) + if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val)) - # insert unreachable in the same basic block after the current instruction (splitting it) - insert!(code, idx + 1, ReturnNode()) - insert!(codelocs, idx + 1, codelocs[idx]) - insert!(ssavaluetypes, idx + 1, Union{}) - insert!(stmtinfo, idx + 1, NoCallInfo()) - insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW) - if ssachangemap === nothing - ssachangemap = fill(0, nstmts) - end - if labelchangemap === nothing - labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap - end - if oldidx < length(ssachangemap) - ssachangemap[oldidx + 1] += 1 - sv.insert_coverage && (labelchangemap[oldidx + 1] += 1) + # Any statements from here to the end of the block have been wrapped in Core.Const(...) + # by type inference (effectively deleting them). Only task left is to replace the block + # terminator with an explicit `unreachable` marker. + if block_end > idx + code[block_end] = ReturnNode() + codelocs[block_end] = codelocs[idx] + ssavaluetypes[block_end] = Union{} + stmtinfo[block_end] = NoCallInfo() + ssaflags[block_end] = IR_FLAG_NOTHROW + + # Verify that type-inference did its job + if JLOptions().debug_level == 2 + for i = (idx + 1):(block_end - 1) + @assert (code[i] isa Core.Const) || is_meta_expr(code[i]) + end + end + + idx += block_end - idx + else + insert!(code, idx + 1, ReturnNode()) + insert!(codelocs, idx + 1, codelocs[idx]) + insert!(ssavaluetypes, idx + 1, Union{}) + insert!(stmtinfo, idx + 1, NoCallInfo()) + insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW) + if ssachangemap === nothing + ssachangemap = fill(0, nstmts) + end + if labelchangemap === nothing + labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap + end + if oldidx < length(ssachangemap) + ssachangemap[oldidx + 1] += 1 + sv.insert_coverage && (labelchangemap[oldidx + 1] += 1) + end + if blockchangemap === nothing + blockchangemap = fill(0, length(sv.cfg.blocks)) + end + blockchangemap[block] += 1 + idx += 1 end - idx += 1 + oldidx = last(sv.cfg.blocks[block].stmts) end end idx += 1 oldidx += 1 end - cfg = sv.cfg if ssachangemap !== nothing && labelchangemap !== nothing renumber_ir_elements!(code, ssachangemap, labelchangemap) - cfg = nothing # recompute CFG + end + if blockchangemap !== nothing + renumber_cfg_stmts!(sv.cfg, blockchangemap) end for i = 1:length(code) code[i] = process_meta!(meta, code[i]) end - strip_trailing_junk!(ci, code, stmtinfo) + strip_trailing_junk!(ci, sv.cfg, code, stmtinfo) types = Any[] stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags) - if cfg === nothing - cfg = compute_basic_blocks(code) - end # NOTE this `argtypes` contains types of slots yet: it will be modified to contain the # types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form # and eliminates slots (see below) argtypes = sv.slottypes - return IRCode(stmts, cfg, linetable, argtypes, meta, sv.sptypes) + return IRCode(stmts, sv.cfg, linetable, argtypes, meta, sv.sptypes) end function process_meta!(meta::Vector{Expr}, @nospecialize stmt) @@ -763,8 +802,8 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI return maxcost end -function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}) - return renumber_ir_elements!(body, ssachangemap, ssachangemap) +function renumber_ir_elements!(body::Vector{Any}, cfg::Union{CFG,Nothing}, ssachangemap::Vector{Int}) + return renumber_ir_elements!(body, cfg, ssachangemap, ssachangemap) end function cumsum_ssamap!(ssachangemap::Vector{Int}) @@ -847,3 +886,19 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab end end end + +function renumber_cfg_stmts!(cfg::CFG, blockchangemap::Vector{Int}) + any_change = cumsum_ssamap!(blockchangemap) + any_change || return + + last_end = 0 + for i = 1:length(cfg.blocks) + old_range = cfg.blocks[i].stmts + new_range = StmtRange(first(old_range) + ((i > 1) ? blockchangemap[i - 1] : 0), + last(old_range) + blockchangemap[i]) + cfg.blocks[i] = BasicBlock(cfg.blocks[i], new_range) + if i <= length(cfg.index) + cfg.index[i] = cfg.index[i] + blockchangemap[i] + end + end +end diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index a8b82c4af6c33..3966dde7fbbd9 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -183,7 +183,7 @@ function rename_uses!(ir::IRCode, ci::CodeInfo, idx::Int, @nospecialize(stmt), r return fixemup!(stmt::UnoptSlot->true, stmt::UnoptSlot->renames[slot_id(stmt)], ir, ci, idx, stmt) end -function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{CallInfo}) +function strip_trailing_junk!(ci::CodeInfo, cfg::CFG, code::Vector{Any}, info::Vector{CallInfo}) # Remove `nothing`s at the end, we don't handle them well # (we expect the last instruction to be a terminator) ssavaluetypes = ci.ssavaluetypes::Vector{Any} @@ -207,6 +207,12 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{Call push!(codelocs, 0) push!(info, NoCallInfo()) push!(ssaflags, IR_FLAG_NOTHROW) + + # Update CFG to include appended terminator + old_range = cfg.blocks[end].stmts + new_range = StmtRange(first(old_range), last(old_range) + 1) + cfg.blocks[end] = BasicBlock(cfg.blocks[end], new_range) + (length(cfg.index) == length(cfg.blocks)) && (cfg.index[end] += 1) end nothing end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index f2a1782629124..c9cef58f3566d 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -89,6 +89,31 @@ end function verify_ir(ir::IRCode, print::Bool=true, allow_frontend_forms::Bool=false, 𝕃ₒ::AbstractLattice = SimpleInferenceLattice.instance) + # Verify CFG graph. Must be well formed to construct domtree + if !(length(ir.cfg.blocks) - 1 <= length(ir.cfg.index) <= length(ir.cfg.blocks)) + @verify_error "CFG index length ($(length(ir.cfg.index))) does not correspond to # of blocks $(length(ir.cfg.blocks))" + error("") + end + if length(ir.stmts.stmt) != length(ir.stmts) + @verify_error "IR stmt length is invalid $(length(ir.stmts.stmt)) / $(length(ir.stmts))" + error("") + end + if length(ir.stmts.type) != length(ir.stmts) + @verify_error "IR type length is invalid $(length(ir.stmts.type)) / $(length(ir.stmts))" + error("") + end + if length(ir.stmts.info) != length(ir.stmts) + @verify_error "IR info length is invalid $(length(ir.stmts.info)) / $(length(ir.stmts))" + error("") + end + if length(ir.stmts.line) != length(ir.stmts) + @verify_error "IR line length is invalid $(length(ir.stmts.line)) / $(length(ir.stmts))" + error("") + end + if length(ir.stmts.flag) != length(ir.stmts) + @verify_error "IR flag length is invalid $(length(ir.stmts.flag)) / $(length(ir.stmts))" + error("") + end # For now require compact IR # @assert isempty(ir.new_nodes) # Verify CFG @@ -125,6 +150,18 @@ function verify_ir(ir::IRCode, print::Bool=true, error("") end end + if !(1 <= first(block.stmts) <= length(ir.stmts)) + @verify_error "First statement of BB $idx ($(first(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))" + error("") + end + if !(1 <= last(block.stmts) <= length(ir.stmts)) + @verify_error "Last statement of BB $idx ($(last(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))" + error("") + end + if idx <= length(ir.cfg.index) && last(block.stmts) + 1 != ir.cfg.index[idx] + @verify_error "End of BB $idx ($(last(block.stmts))) is not one less than CFG index ($(ir.cfg.index[idx]))" + error("") + end end # Verify statements domtree = construct_domtree(ir.cfg.blocks) @@ -145,7 +182,7 @@ function verify_ir(ir::IRCode, print::Bool=true, end elseif isa(terminator, GotoNode) if length(block.succs) != 1 || block.succs[1] != terminator.label - @verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator" + @verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator ($(terminator.label))" error("") end elseif isa(terminator, GotoIfNot) @@ -167,8 +204,8 @@ function verify_ir(ir::IRCode, print::Bool=true, if length(block.succs) != 1 || block.succs[1] != idx + 1 # As a special case, we allow extra statements in the BB of an :enter # statement, until we can do proper CFG manipulations during compaction. - for idx in first(block.stmts):last(block.stmts) - stmt = ir[SSAValue(idx)][:stmt] + for stmt_idx in first(block.stmts):last(block.stmts) + stmt = ir[SSAValue(stmt_idx)][:stmt] if isexpr(stmt, :enter) terminator = stmt @goto enter_check @@ -188,6 +225,10 @@ function verify_ir(ir::IRCode, print::Bool=true, end end end + if length(ir.stmts) != last(ir.cfg.blocks[end].stmts) + @verify_error "End of last BB $(last(ir.cfg.blocks[end].stmts)) does not match last IR statement $(length(ir.stmts))" + error("") + end lastbb = 0 is_phinode_block = false firstidx = 1 diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index e867f5e9ad9dc..1342722f53206 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -551,9 +551,9 @@ function finish(me::InferenceState, interp::AbstractInterpreter) # annotate fulltree with type information, # either because we are the outermost code, or we might use this later doopt = (me.cached || me.parent !== nothing) - recompute_cfg = type_annotate!(interp, me, doopt) + type_annotate!(interp, me, doopt) if doopt && may_optimize(interp) - me.result.src = OptimizationState(me, interp, recompute_cfg) + me.result.src = OptimizationState(me, interp) else me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection) end @@ -713,20 +713,31 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState, run_opt # 3. mark unreached statements for a bulk code deletion (see issue #7836) # 4. widen slot wrappers (`Conditional` and `MustAlias`) and remove `NOT_FOUND` from `ssavaluetypes` # NOTE because of this, `was_reached` will no longer be available after this point - # 5. eliminate GotoIfNot if either branch target is unreachable + # 5. eliminate GotoIfNot if either or both branches are statically unreachable changemap = nothing # initialized if there is any dead region for i = 1:nstmt expr = stmts[i] if was_reached(sv, i) if run_optimizer - if isa(expr, GotoIfNot) && widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool + if isa(expr, GotoIfNot) # 5: replace this live GotoIfNot with: - # - GotoNode if the fallthrough target is unreachable - # - no-op if the branch target is unreachable - if !was_reached(sv, i+1) - expr = GotoNode(expr.dest) - elseif !was_reached(sv, expr.dest) - expr = nothing + # - no-op if :nothrow and the branch target is unreachable + # - cond if :nothrow and both targets are unreachable + # - typeassert if must-throw + if widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool + block = block_for_inst(sv.cfg, i) + if !was_reached(sv, i+1) + cfg_delete_edge!(sv.cfg, block, block + 1) + expr = GotoNode(expr.dest) + elseif !was_reached(sv, expr.dest) + cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest)) + expr = nothing + end + elseif ssavaluetypes[i] === Bottom + block = block_for_inst(sv.cfg, i) + cfg_delete_edge!(sv.cfg, block, block + 1) + cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest)) + expr = Expr(:call, Core.typeassert, expr.cond, Bool) end end end diff --git a/test/compiler/interpreter_exec.jl b/test/compiler/interpreter_exec.jl index a310a2740131d..dbadf64c97491 100644 --- a/test/compiler/interpreter_exec.jl +++ b/test/compiler/interpreter_exec.jl @@ -21,6 +21,7 @@ let m = Meta.@lower 1 + 1 ] nstmts = length(src.code) src.ssavaluetypes = Any[ Any for _ = 1:nstmts ] + src.ssaflags = fill(UInt8(0x00), nstmts) src.codelocs = fill(Int32(1), nstmts) src.inferred = true Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src)) @@ -61,6 +62,7 @@ let m = Meta.@lower 1 + 1 ] nstmts = length(src.code) src.ssavaluetypes = Any[ Any for _ = 1:nstmts ] + src.ssaflags = fill(UInt8(0x00), nstmts) src.codelocs = fill(Int32(1), nstmts) src.inferred = true Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src)) @@ -98,6 +100,7 @@ let m = Meta.@lower 1 + 1 ] nstmts = length(src.code) src.ssavaluetypes = Any[ Any for _ = 1:nstmts ] + src.ssaflags = fill(UInt8(0x00), nstmts) src.codelocs = fill(Int32(1), nstmts) src.inferred = true Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src)) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 1ff477fa22293..a8d5df6a24cdf 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -849,6 +849,7 @@ function each_stmt_a_bb(stmts, preds, succs) empty!(ir.stmts.line); append!(ir.stmts.line, [Int32(0) for _ = 1:length(stmts)]) empty!(ir.stmts.info); append!(ir.stmts.info, [NoCallInfo() for _ = 1:length(stmts)]) empty!(ir.cfg.blocks); append!(ir.cfg.blocks, [BasicBlock(StmtRange(i, i), preds[i], succs[i]) for i = 1:length(stmts)]) + empty!(ir.cfg.index); append!(ir.cfg.index, [i for i = 2:length(stmts)]) Core.Compiler.verify_ir(ir) return ir end diff --git a/test/compiler/irutils.jl b/test/compiler/irutils.jl index c98e824b1ddae..9c79f40f33280 100644 --- a/test/compiler/irutils.jl +++ b/test/compiler/irutils.jl @@ -43,16 +43,22 @@ fully_eliminated(src::CodeInfo; retval=(@__FILE__)) = fully_eliminated(src.code; fully_eliminated(ir::IRCode; retval=(@__FILE__)) = fully_eliminated(ir.stmts.stmt; retval) function fully_eliminated(code::Vector{Any}; retval=(@__FILE__), kwargs...) if retval !== (@__FILE__) - length(code) == 1 || return false - code1 = code[1] - isreturn(code1) || return false - val = code1.val + (length(code) <= 2) || return false + for i = 1:(length(code) - 1) + code[i] === nothing || return false + end + isreturn(code[end]) || return false + val = code[end].val if val isa QuoteNode val = val.value end return val == retval else - return length(code) == 1 && isreturn(code[1]) + (length(code) <= 2) || return false + for i = 1:(length(code) - 1) + code[i] === nothing || return false + end + return isreturn(code[end]) end end macro fully_eliminated(ex0...)