diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 738da73d93629..e488fc685277a 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -465,7 +465,8 @@ mutable struct IncrementalCompact result_flags::Vector{UInt8} result_bbs::Vector{BasicBlock} ssa_rename::Vector{Any} - bb_rename::Vector{Int} + bb_rename_pred::Vector{Int} + bb_rename_succ::Vector{Int} used_ssas::Vector{Int} late_fixup::Vector{Int} # This could be Stateful, but bootstrapping doesn't like that @@ -481,7 +482,8 @@ mutable struct IncrementalCompact result_idx::Int active_result_bb::Int renamed_new_nodes::Bool - allow_cfg_transforms::Bool + cfg_transforms_enabled::Bool + fold_constant_branches::Bool function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false) # Sort by position with attach after nodes affter regular ones perm = my_sortperm(Int[(code.new_nodes[i].pos*2 + Int(code.new_nodes[i].attach_after)) for i in 1:length(code.new_nodes)]) @@ -525,9 +527,9 @@ mutable struct IncrementalCompact new_new_nodes = NewNode[] pending_nodes = NewNode[] pending_perm = Int[] - return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, used_ssas, late_fixup, perm, 1, + return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1, new_new_nodes, pending_nodes, pending_perm, - 1, 1, 1, false, allow_cfg_transforms) + 1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms) end # For inlining @@ -542,10 +544,10 @@ mutable struct IncrementalCompact pending_nodes = NewNode[] pending_perm = Int[] return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, - parent.result_bbs, ssa_rename, bb_rename, parent.used_ssas, + parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas, late_fixup, perm, 1, new_new_nodes, pending_nodes, pending_perm, - 1, result_offset, parent.active_result_bb, false, false) + 1, result_offset, parent.active_result_bb, false, false, false) end end @@ -646,6 +648,18 @@ function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @ end end +function append_node!(ir, @nospecialize(typ), @nospecialize(node), line) + push!(ir.stmts, node) + push!(ir.types, typ) + push!(ir.lines, line) + push!(ir.flags, 0) + last_bb = ir.cfg.blocks[end] + ir.cfg.blocks[end] = BasicBlock(first(last_bb.stmts):length(ir.stmts), + last_bb.preds, + last_bb.succs) + return SSAValue(length(ir.stmts)) +end + function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int32, reverse_affinity::Bool=false) if compact.result_idx > length(compact.result) @assert compact.result_idx == length(compact.result) + 1 @@ -823,17 +837,17 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to:: # Note: We recursively kill as many edges as are obviously dead. However, this # may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or # worstcase during codegen). - preds, succs = compact.result_bbs[compact.bb_rename[to]].preds, compact.result_bbs[compact.bb_rename[from]].succs - deleteat!(preds, findfirst(x->x === compact.bb_rename[from], preds)::Int) - deleteat!(succs, findfirst(x->x === compact.bb_rename[to], succs)::Int) + preds, succs = compact.result_bbs[compact.bb_rename_succ[to]].preds, compact.result_bbs[compact.bb_rename_pred[from]].succs + deleteat!(preds, findfirst(x->x === compact.bb_rename_pred[from], preds)::Int) + deleteat!(succs, findfirst(x->x === compact.bb_rename_succ[to], succs)::Int) # Check if the block is now dead if length(preds) == 0 - for succ in copy(compact.result_bbs[compact.bb_rename[to]].succs) - kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename)) + for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs) + kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)) end if to < active_bb # Kill all statements in the block - stmts = compact.result_bbs[compact.bb_rename[to]].stmts + stmts = compact.result_bbs[compact.bb_rename_succ[to]].stmts for stmt in stmts compact.result[stmt] = nothing end @@ -842,12 +856,12 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to:: else # We need to remove this edge from any phi nodes if to < active_bb - idx = first(compact.result_bbs[compact.bb_rename[to]].stmts) + idx = first(compact.result_bbs[compact.bb_rename_succ[to]].stmts) while idx < length(compact.result) stmt = compact.result[idx] stmt === nothing && continue isa(stmt, PhiNode) || break - i = findfirst(x-> x === compact.bb_rename[from], stmt.edges) + i = findfirst(x-> x === compact.bb_rename_pred[from], stmt.edges) if i !== nothing deleteat!(stmt.edges, i) deleteat!(stmt.values, i) @@ -879,34 +893,34 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any}, ssa_rename[idx] = stmt elseif isa(stmt, OldSSAValue) ssa_rename[idx] = ssa_rename[stmt.id] - elseif isa(stmt, GotoNode) && compact.allow_cfg_transforms - result[result_idx] = GotoNode(compact.bb_rename[stmt.label]) + elseif isa(stmt, GotoNode) && compact.cfg_transforms_enabled + result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.label]) result_idx += 1 elseif isa(stmt, GlobalRef) || isa(stmt, GotoNode) result[result_idx] = stmt result_idx += 1 - elseif isa(stmt, GotoIfNot) && compact.allow_cfg_transforms + elseif isa(stmt, GotoIfNot) && compact.cfg_transforms_enabled stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot result[result_idx] = stmt cond = stmt.cond - if isa(cond, Bool) + if isa(cond, Bool) && compact.fold_constant_branches if cond result[result_idx] = nothing kill_edge!(compact, active_bb, active_bb, stmt.dest) # Don't increment result_idx => Drop this statement else - result[result_idx] = GotoNode(compact.bb_rename[stmt.dest]) + result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.dest]) kill_edge!(compact, active_bb, active_bb, active_bb+1) result_idx += 1 end else - result[result_idx] = GotoIfNot(cond, compact.bb_rename[stmt.dest]) + result[result_idx] = GotoIfNot(cond, compact.bb_rename_succ[stmt.dest]) result_idx += 1 end elseif isa(stmt, Expr) stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr - if compact.allow_cfg_transforms && isexpr(stmt, :enter) - stmt.args[1] = compact.bb_rename[stmt.args[1]::Int] + if compact.cfg_transforms_enabled && isexpr(stmt, :enter) + stmt.args[1] = compact.bb_rename_succ[stmt.args[1]::Int] end result[result_idx] = stmt result_idx += 1 @@ -936,13 +950,13 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any}, elseif isa(stmt, PhiNode) values = process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa) if length(stmt.edges) == 1 && isassigned(values, 1) && - length(compact.allow_cfg_transforms ? - compact.result_bbs[compact.bb_rename[active_bb]].preds : + length(compact.cfg_transforms_enabled ? + compact.result_bbs[compact.bb_rename_succ[active_bb]].preds : compact.ir.cfg.blocks[active_bb].preds) == 1 # There's only one predecessor left - just replace it ssa_rename[idx] = values[1] else - edges = compact.allow_cfg_transforms ? map!(i->compact.bb_rename[i], stmt.edges, stmt.edges) : stmt.edges + edges = compact.cfg_transforms_enabled ? map!(i->compact.bb_rename_pred[i], stmt.edges, stmt.edges) : stmt.edges result[result_idx] = PhiNode(edges, values) result_idx += 1 end @@ -983,14 +997,14 @@ end function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_idx, unreachable=false) if compact.active_result_bb > length(compact.result_bbs) - @assert compact.bb_rename[active_bb] == 0 + #@assert compact.bb_rename[active_bb] == 0 return true end bb = compact.result_bbs[compact.active_result_bb] # If this was the last statement in the BB and we decided to skip it, insert a # dummy `nothing` node, to prevent changing the structure of the CFG skipped = false - if !compact.allow_cfg_transforms || active_bb == 0 || active_bb > length(compact.bb_rename) || compact.bb_rename[active_bb] != 0 + if !compact.cfg_transforms_enabled || active_bb == 0 || active_bb > length(compact.bb_rename_succ) || compact.bb_rename_succ[active_bb] != 0 if compact.result_idx == first(bb.stmts) length(compact.result) < old_result_idx && resize!(compact, old_result_idx) if unreachable @@ -1003,7 +1017,7 @@ function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_id compact.result_lines[old_result_idx] = 0 compact.result_flags[old_result_idx] = 0x00 compact.result_idx = old_result_idx + 1 - elseif compact.allow_cfg_transforms && compact.result_idx - 1 == first(bb.stmts) + elseif compact.cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts) # Optimization: If this BB consists of only a branch, eliminate this bb end compact.result_bbs[compact.active_result_bb] = BasicBlock(bb, StmtRange(first(bb.stmts), compact.result_idx-1)) @@ -1083,7 +1097,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= resize!(compact, old_result_idx) end bb = compact.ir.cfg.blocks[active_bb] - if compact.allow_cfg_transforms && active_bb > 1 && active_bb <= length(compact.bb_rename) && length(bb.preds) == 0 + if compact.cfg_transforms_enabled && active_bb > 1 && active_bb <= length(compact.bb_rename_succ) && length(bb.preds) == 0 # No predecessors, kill the entire block. compact.idx = last(bb.stmts) # Pop any remaining insertion nodes @@ -1274,8 +1288,8 @@ function complete(compact::IncrementalCompact) return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes) end -function compact!(code::IRCode) - compact = IncrementalCompact(code) +function compact!(code::IRCode, allow_cfg_transforms=false) + compact = IncrementalCompact(code, allow_cfg_transforms) # Just run through the iterator without any processing foreach(x -> nothing, compact) # x isa Pair{Int, Any} return finish(compact) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index f93e490f56613..5fd9bbaa25044 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1012,3 +1012,131 @@ function type_lift_pass!(ir::IRCode) end ir end + +function cfg_simplify!(ir::IRCode) + bbs = ir.cfg.blocks + merge_into = zeros(Int, length(bbs)) + merged_succ = zeros(Int, length(bbs)) + + # Walk the CFG at from the entry block and aggressively combine blocks + for (idx, bb) in enumerate(bbs) + if length(bb.succs) == 1 + succ = bb.succs[1] + if length(bbs[succ].preds) == 1 + merge_into[succ] = idx + merged_succ[idx] = succ + end + end + end + max_bb_num = 1 + bb_rename_succ = zeros(Int, length(bbs)) + # Lay out the basic blocks + for i = 1:length(bbs) + if merge_into[i] != 0 + bb_rename_succ[i] = -1 + continue + end + # Drop unreachable blocks + if i != 1 && length(ir.cfg.blocks[i].preds) == 0 + bb_rename_succ[i] = -1 + end + bb_rename_succ[i] != 0 && continue + curr = i + while true + bb_rename_succ[curr] = max_bb_num + max_bb_num += 1 + # Now walk the chain of blocks we merged. + # If we end in something that may fall through, + # we have to schedule that block next + while merged_succ[curr] != 0 + curr = merged_succ[curr] + end + terminator = ir.stmts[ir.cfg.blocks[curr].stmts[end]] + if isa(terminator, GotoNode) || isa(terminator, ReturnNode) + break + end + curr += 1 + end + end + bb_rename_pred = zeros(Int, length(bbs)) + for i = 1:length(bbs) + if merged_succ[i] != 0 + bb_rename_pred[i] = -1 + continue + end + bbnum = i + while merge_into[bbnum] != 0 + bbnum = merge_into[bbnum] + end + bb_rename_pred[i] = bb_rename_succ[bbnum] + end + result_bbs = Int[findfirst(j->i==j, bb_rename_succ) for i = 1:max_bb_num-1] + result_bbs_lengths = zeros(Int, max_bb_num-1) + for (idx, orig_bb) in enumerate(result_bbs) + ms = orig_bb + while ms != 0 + result_bbs_lengths[idx] += length(bbs[ms].stmts) + ms = merged_succ[ms] + end + end + bb_starts = Vector{Int}(undef, 1+length(result_bbs_lengths)) + bb_starts[1] = 1 + for i = 1:length(result_bbs_lengths) + bb_starts[i+1] = bb_starts[i] + result_bbs_lengths[i] + end + # Look at the original successor + function compute_succs(i) + orig_bb = result_bbs[i] + while merged_succ[orig_bb] != 0 + orig_bb = merged_succ[orig_bb] + end + map(i->bb_rename_succ[i], bbs[orig_bb].succs) + end + + function compute_preds(i) + orig_bb = result_bbs[i] + preds = bbs[orig_bb].preds + map(preds) do pred + while merge_into[pred] != 0 + pred = merge_into[pred] + end + bb_rename_succ[pred] + end + end + cresult_bbs = BasicBlock[BasicBlock( + StmtRange(bb_starts[i], i+1 > length(bb_starts) ? length(compact.result) : bb_starts[i+1]-1), + compute_preds(i), compute_succs(i)) for i = 1:length(result_bbs)] + compact = IncrementalCompact(ir, true) + # We're messing with the CFG. We don't want compaction to do + # so independently + compact.fold_constant_branches = false + compact.bb_rename_succ = bb_rename_succ + compact.bb_rename_pred = bb_rename_pred + compact.result_bbs = cresult_bbs + result_idx = 1 + for (idx, orig_bb) in enumerate(result_bbs) + ms = orig_bb + while ms != 0 + for i in bbs[ms].stmts + stmt = ir.stmts[i] + compact.result[compact.result_idx] = nothing + compact.result_types[compact.result_idx] = ir.types[i] + compact.result_lines[compact.result_idx] = ir.lines[i] + compact.result_flags[compact.result_idx] = ir.flags[i] + # If we merged a basic block, we need remove the trailing GotoNode (if any) + if isa(stmt, GotoNode) && merged_succ[ms] != 0 + # Do nothing + else + process_node!(compact, compact.result_idx, stmt, i, i, ms, true) + end + # We always increase the result index to ensure a predicatable + # placement of the resulting nodes. + compact.result_idx += 1 + end + ms = merged_succ[ms] + end + end + + compact.active_result_bb = length(bb_starts) + return finish(compact) +end diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index fe2e244a8821c..bd9ed0de6f73e 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -551,7 +551,7 @@ function show_ir(io::IO, code::IRCode, expr_type_printer=default_expr_type_print # Compute BB guard rail if bb_idx > length(cfg.blocks) # Even if invariants are violated, try our best to still print - bbrange = (last(cfg.blocks[end].stmts) + 1):typemax(Int) + bbrange = (length(cfg.blocks) == 0 ? 1 : last(cfg.blocks[end].stmts) + 1):typemax(Int) bb_idx_str = "!" bb_type = "─" else diff --git a/base/show.jl b/base/show.jl index 378ded97aa1d0..d42e09296d2bb 100644 --- a/base/show.jl +++ b/base/show.jl @@ -1578,6 +1578,7 @@ module IRShow using Core.IR import ..Base import .Compiler: IRCode, ReturnNode, GotoIfNot, CFG, scan_ssa_use!, Argument, isexpr, compute_basic_blocks, block_for_inst + Base.getindex(r::Compiler.StmtRange, ind::Integer) = Compiler.getindex(r, ind) Base.size(r::Compiler.StmtRange) = Compiler.size(r) Base.first(r::Compiler.StmtRange) = Compiler.first(r) Base.last(r::Compiler.StmtRange) = Compiler.last(r) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index bf655a8feac09..b01bc679164f5 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -183,3 +183,66 @@ let m = Meta.@lower 1 + 1 ir = @test_nowarn Core.Compiler.getfield_elim_pass!(ir, domtree) @test Core.Compiler.verify_ir(ir) === nothing end + +# Tests for cfg simplification +let src = code_typed(gcd, Tuple{Int, Int})[1].first + # Test that cfg_simplify doesn't mangle IR on code with loops + ir = Core.Compiler.inflate_ir(src) + Core.Compiler.verify_ir(ir) + ir = Core.Compiler.cfg_simplify!(ir) + Core.Compiler.verify_ir(ir) +end + +let m = Meta.@lower 1 + 1 + # Test that CFG simplify combines redundant basic blocks + @assert Meta.isexpr(m, :thunk) + src = m.args[1]::Core.CodeInfo + src.code = Any[ + Core.Compiler.GotoNode(2), + Core.Compiler.GotoNode(3), + Core.Compiler.GotoNode(4), + Core.Compiler.GotoNode(5), + Core.Compiler.GotoNode(6), + Core.Compiler.GotoNode(7), + Expr(:return, 2) + ] + nstmts = length(src.code) + src.ssavaluetypes = nstmts + src.codelocs = fill(Int32(1), nstmts) + src.ssaflags = fill(Int32(0), nstmts) + ir = Core.Compiler.inflate_ir(src) + Core.Compiler.verify_ir(ir) + ir = Core.Compiler.cfg_simplify!(ir) + Core.Compiler.verify_ir(ir) + ir = Core.Compiler.compact!(ir) + @test length(ir.cfg.blocks) == 1 && length(ir.stmts) == 1 +end + +let m = Meta.@lower 1 + 1 + # Test that CFG simplify doesn't mess up when chaining past return blocks + @assert Meta.isexpr(m, :thunk) + src = m.args[1]::Core.CodeInfo + src.code = Any[ + Core.Compiler.GotoIfNot(Core.Compiler.Argument(2), 3), + Core.Compiler.GotoNode(4), + Expr(:return, 1), + Core.Compiler.GotoNode(5), + Core.Compiler.GotoIfNot(Core.Compiler.Argument(2), 7), + # This fall through block of the previous GotoIfNot + # must be moved up along with it, when we merge it + # into the goto 4 block. + Expr(:return, 2), + Expr(:return, 3) + ] + nstmts = length(src.code) + src.ssavaluetypes = nstmts + src.codelocs = fill(Int32(1), nstmts) + src.ssaflags = fill(Int32(0), nstmts) + ir = Core.Compiler.inflate_ir(src) + Core.Compiler.verify_ir(ir) + ir = Core.Compiler.cfg_simplify!(ir) + Core.Compiler.verify_ir(ir) + @test length(ir.cfg.blocks) == 5 + ret_2 = ir.stmts[ir.cfg.blocks[3].stmts[end]] + @test isa(ret_2, Core.Compiler.ReturnNode) && ret_2.val == 2 +end