Skip to content

Commit

Permalink
Merge pull request #31533 from JuliaLang/vc/cfg_simplify
Browse files Browse the repository at this point in the history
Backport CFG simplification pass from XLA backend
  • Loading branch information
vchuravy committed Apr 2, 2019
2 parents 11844a2 + f62a4b3 commit c5be814
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 32 deletions.
76 changes: 45 additions & 31 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ mutable struct IncrementalCompact
result_flags::Vector{UInt8}
result_bbs::Vector{BasicBlock}
ssa_rename::Vector{Any}
bb_rename::Vector{Int}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}
used_ssas::Vector{Int}
late_fixup::Vector{Int}
# This could be Stateful, but bootstrapping doesn't like that
Expand All @@ -481,7 +482,8 @@ mutable struct IncrementalCompact
result_idx::Int
active_result_bb::Int
renamed_new_nodes::Bool
allow_cfg_transforms::Bool
cfg_transforms_enabled::Bool
fold_constant_branches::Bool
function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false)
# Sort by position, with attach-after nodes placed after regular ones
perm = my_sortperm(Int[(code.new_nodes[i].pos*2 + Int(code.new_nodes[i].attach_after)) for i in 1:length(code.new_nodes)])
Expand Down Expand Up @@ -525,9 +527,9 @@ mutable struct IncrementalCompact
new_new_nodes = NewNode[]
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, used_ssas, late_fixup, perm, 1,
return new(code, result, result_types, result_lines, result_flags, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, 1, 1, false, allow_cfg_transforms)
1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms)
end

# For inlining
Expand All @@ -542,10 +544,10 @@ mutable struct IncrementalCompact
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags,
parent.result_bbs, ssa_rename, bb_rename, parent.used_ssas,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false)
1, result_offset, parent.active_result_bb, false, false, false)
end
end

Expand Down Expand Up @@ -646,6 +648,18 @@ function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @
end
end

# Append statement `node` (with type `typ` and line-table entry `line`) to the
# very end of `ir`, extending the last basic block of the CFG so that it covers
# the new statement. The new statement gets a zero flag. Returns the `SSAValue`
# that refers to the appended statement.
function append_node!(ir, @nospecialize(typ), @nospecialize(node), line)
    # The per-statement arrays are parallel; grow all of them in lockstep.
    push!(ir.stmts, node)
    push!(ir.types, typ)
    push!(ir.lines, line)
    push!(ir.flags, 0)
    new_idx = length(ir.stmts)
    # Rebuild the final block with the same edges but a statement range that
    # now ends at the freshly appended statement.
    blocks = ir.cfg.blocks
    tail = blocks[end]
    blocks[end] = BasicBlock(first(tail.stmts):new_idx, tail.preds, tail.succs)
    return SSAValue(new_idx)
end

function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int32, reverse_affinity::Bool=false)
if compact.result_idx > length(compact.result)
@assert compact.result_idx == length(compact.result) + 1
Expand Down Expand Up @@ -823,17 +837,17 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
# Note: We recursively kill as many edges as are obviously dead. However, this
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
# worst case during codegen).
preds, succs = compact.result_bbs[compact.bb_rename[to]].preds, compact.result_bbs[compact.bb_rename[from]].succs
deleteat!(preds, findfirst(x->x === compact.bb_rename[from], preds)::Int)
deleteat!(succs, findfirst(x->x === compact.bb_rename[to], succs)::Int)
preds, succs = compact.result_bbs[compact.bb_rename_succ[to]].preds, compact.result_bbs[compact.bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x->x === compact.bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x->x === compact.bb_rename_succ[to], succs)::Int)
# Check if the block is now dead
if length(preds) == 0
for succ in copy(compact.result_bbs[compact.bb_rename[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename))
for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred))
end
if to < active_bb
# Kill all statements in the block
stmts = compact.result_bbs[compact.bb_rename[to]].stmts
stmts = compact.result_bbs[compact.bb_rename_succ[to]].stmts
for stmt in stmts
compact.result[stmt] = nothing
end
Expand All @@ -842,12 +856,12 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
else
# We need to remove this edge from any phi nodes
if to < active_bb
idx = first(compact.result_bbs[compact.bb_rename[to]].stmts)
idx = first(compact.result_bbs[compact.bb_rename_succ[to]].stmts)
while idx < length(compact.result)
stmt = compact.result[idx]
stmt === nothing && continue
isa(stmt, PhiNode) || break
i = findfirst(x-> x === compact.bb_rename[from], stmt.edges)
i = findfirst(x-> x === compact.bb_rename_pred[from], stmt.edges)
if i !== nothing
deleteat!(stmt.edges, i)
deleteat!(stmt.values, i)
Expand Down Expand Up @@ -879,34 +893,34 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any},
ssa_rename[idx] = stmt
elseif isa(stmt, OldSSAValue)
ssa_rename[idx] = ssa_rename[stmt.id]
elseif isa(stmt, GotoNode) && compact.allow_cfg_transforms
result[result_idx] = GotoNode(compact.bb_rename[stmt.label])
elseif isa(stmt, GotoNode) && compact.cfg_transforms_enabled
result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.label])
result_idx += 1
elseif isa(stmt, GlobalRef) || isa(stmt, GotoNode)
result[result_idx] = stmt
result_idx += 1
elseif isa(stmt, GotoIfNot) && compact.allow_cfg_transforms
elseif isa(stmt, GotoIfNot) && compact.cfg_transforms_enabled
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::GotoIfNot
result[result_idx] = stmt
cond = stmt.cond
if isa(cond, Bool)
if isa(cond, Bool) && compact.fold_constant_branches
if cond
result[result_idx] = nothing
kill_edge!(compact, active_bb, active_bb, stmt.dest)
# Don't increment result_idx => Drop this statement
else
result[result_idx] = GotoNode(compact.bb_rename[stmt.dest])
result[result_idx] = GotoNode(compact.bb_rename_succ[stmt.dest])
kill_edge!(compact, active_bb, active_bb, active_bb+1)
result_idx += 1
end
else
result[result_idx] = GotoIfNot(cond, compact.bb_rename[stmt.dest])
result[result_idx] = GotoIfNot(cond, compact.bb_rename_succ[stmt.dest])
result_idx += 1
end
elseif isa(stmt, Expr)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr
if compact.allow_cfg_transforms && isexpr(stmt, :enter)
stmt.args[1] = compact.bb_rename[stmt.args[1]::Int]
if compact.cfg_transforms_enabled && isexpr(stmt, :enter)
stmt.args[1] = compact.bb_rename_succ[stmt.args[1]::Int]
end
result[result_idx] = stmt
result_idx += 1
Expand Down Expand Up @@ -936,13 +950,13 @@ function process_node!(compact::IncrementalCompact, result::Vector{Any},
elseif isa(stmt, PhiNode)
values = process_phinode_values(stmt.values, late_fixup, processed_idx, result_idx, ssa_rename, used_ssas, do_rename_ssa)
if length(stmt.edges) == 1 && isassigned(values, 1) &&
length(compact.allow_cfg_transforms ?
compact.result_bbs[compact.bb_rename[active_bb]].preds :
length(compact.cfg_transforms_enabled ?
compact.result_bbs[compact.bb_rename_succ[active_bb]].preds :
compact.ir.cfg.blocks[active_bb].preds) == 1
# There's only one predecessor left - just replace it
ssa_rename[idx] = values[1]
else
edges = compact.allow_cfg_transforms ? map!(i->compact.bb_rename[i], stmt.edges, stmt.edges) : stmt.edges
edges = compact.cfg_transforms_enabled ? map!(i->compact.bb_rename_pred[i], stmt.edges, stmt.edges) : stmt.edges
result[result_idx] = PhiNode(edges, values)
result_idx += 1
end
Expand Down Expand Up @@ -983,14 +997,14 @@ end

function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_idx, unreachable=false)
if compact.active_result_bb > length(compact.result_bbs)
@assert compact.bb_rename[active_bb] == 0
#@assert compact.bb_rename[active_bb] == 0
return true
end
bb = compact.result_bbs[compact.active_result_bb]
# If this was the last statement in the BB and we decided to skip it, insert a
# dummy `nothing` node, to prevent changing the structure of the CFG
skipped = false
if !compact.allow_cfg_transforms || active_bb == 0 || active_bb > length(compact.bb_rename) || compact.bb_rename[active_bb] != 0
if !compact.cfg_transforms_enabled || active_bb == 0 || active_bb > length(compact.bb_rename_succ) || compact.bb_rename_succ[active_bb] != 0
if compact.result_idx == first(bb.stmts)
length(compact.result) < old_result_idx && resize!(compact, old_result_idx)
if unreachable
Expand All @@ -1003,7 +1017,7 @@ function finish_current_bb!(compact, active_bb, old_result_idx=compact.result_id
compact.result_lines[old_result_idx] = 0
compact.result_flags[old_result_idx] = 0x00
compact.result_idx = old_result_idx + 1
elseif compact.allow_cfg_transforms && compact.result_idx - 1 == first(bb.stmts)
elseif compact.cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts)
# Optimization: If this BB consists of only a branch, eliminate this bb
end
compact.result_bbs[compact.active_result_bb] = BasicBlock(bb, StmtRange(first(bb.stmts), compact.result_idx-1))
Expand Down Expand Up @@ -1083,7 +1097,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
resize!(compact, old_result_idx)
end
bb = compact.ir.cfg.blocks[active_bb]
if compact.allow_cfg_transforms && active_bb > 1 && active_bb <= length(compact.bb_rename) && length(bb.preds) == 0
if compact.cfg_transforms_enabled && active_bb > 1 && active_bb <= length(compact.bb_rename_succ) && length(bb.preds) == 0
# No predecessors, kill the entire block.
compact.idx = last(bb.stmts)
# Pop any remaining insertion nodes
Expand Down Expand Up @@ -1274,8 +1288,8 @@ function complete(compact::IncrementalCompact)
return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes)
end

function compact!(code::IRCode)
compact = IncrementalCompact(code)
function compact!(code::IRCode, allow_cfg_transforms=false)
compact = IncrementalCompact(code, allow_cfg_transforms)
# Just run through the iterator without any processing
foreach(x -> nothing, compact) # x isa Pair{Int, Any}
return finish(compact)
Expand Down
128 changes: 128 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,3 +1012,131 @@ function type_lift_pass!(ir::IRCode)
end
ir
end

# Simplify the control-flow graph of `ir`.
#
# Two things happen here:
#   * A block `b` is merged into block `a` when `a`'s only successor is `b`
#     and `b`'s only predecessor is `a`. Chains of such blocks collapse into
#     a single result block.
#   * Blocks other than the entry block that have no predecessors are dropped.
#
# The surviving blocks are renumbered, a new list of `BasicBlock`s is computed
# up front, and the statements are then copied through an `IncrementalCompact`
# (with constant-branch folding disabled) so that branch targets and phi edges
# are rewritten to the new numbering. Returns the result of `finish(compact)`.
function cfg_simplify!(ir::IRCode)
    bbs = ir.cfg.blocks
    # merge_into[b] == a  => block `b` is merged into block `a` (0: not merged)
    # merged_succ[a] == b => block `a` absorbs its successor `b` (0: none)
    merge_into = zeros(Int, length(bbs))
    merged_succ = zeros(Int, length(bbs))

    # Walk the CFG from the entry block and aggressively combine blocks
    for (idx, bb) in enumerate(bbs)
        if length(bb.succs) == 1
            succ = bb.succs[1]
            if length(bbs[succ].preds) == 1
                merge_into[succ] = idx
                merged_succ[idx] = succ
            end
        end
    end
    # Next free number for a result block; after layout this is one more than
    # the number of result blocks.
    max_bb_num = 1
    # bb_rename_succ[i]: new number of original block `i` when it is referenced
    # as a branch target / successor. -1 marks blocks that do not start a
    # result block (merged into a predecessor, or dropped as unreachable).
    bb_rename_succ = zeros(Int, length(bbs))
    # Lay out the basic blocks
    for i = 1:length(bbs)
        if merge_into[i] != 0
            bb_rename_succ[i] = -1
            continue
        end
        # Drop unreachable blocks
        if i != 1 && length(ir.cfg.blocks[i].preds) == 0
            bb_rename_succ[i] = -1
        end
        # Already numbered (via fall-through scheduling below) or dropped.
        bb_rename_succ[i] != 0 && continue
        curr = i
        while true
            bb_rename_succ[curr] = max_bb_num
            max_bb_num += 1
            # Now walk the chain of blocks we merged.
            # If we end in something that may fall through,
            # we have to schedule that block next
            while merged_succ[curr] != 0
                curr = merged_succ[curr]
            end
            terminator = ir.stmts[ir.cfg.blocks[curr].stmts[end]]
            if isa(terminator, GotoNode) || isa(terminator, ReturnNode)
                break
            end
            curr += 1
        end
    end
    # bb_rename_pred[i]: new number of original block `i` when it is referenced
    # as a predecessor (e.g. a phi edge). That is the number of the result
    # block at the head of the merge chain containing `i`. -1 marks blocks
    # that absorbed a successor and are therefore not the tail of their chain.
    bb_rename_pred = zeros(Int, length(bbs))
    for i = 1:length(bbs)
        if merged_succ[i] != 0
            bb_rename_pred[i] = -1
            continue
        end
        bbnum = i
        while merge_into[bbnum] != 0
            bbnum = merge_into[bbnum]
        end
        bb_rename_pred[i] = bb_rename_succ[bbnum]
    end
    # result_bbs[k]: original block that heads result block `k`.
    result_bbs = Int[findfirst(j->i==j, bb_rename_succ) for i = 1:max_bb_num-1]
    # Number of statements in each result block: the sum over its merge chain.
    result_bbs_lengths = zeros(Int, max_bb_num-1)
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            result_bbs_lengths[idx] += length(bbs[ms].stmts)
            ms = merged_succ[ms]
        end
    end
    # bb_starts[k] is the first statement index of result block `k`; the extra
    # trailing entry is one past the last statement of the final block.
    bb_starts = Vector{Int}(undef, 1+length(result_bbs_lengths))
    bb_starts[1] = 1
    for i = 1:length(result_bbs_lengths)
        bb_starts[i+1] = bb_starts[i] + result_bbs_lengths[i]
    end
    # Look at the original successor
    # (the successors of a result block are those of the last block in its
    # merge chain, renamed to the new numbering).
    function compute_succs(i)
        orig_bb = result_bbs[i]
        while merged_succ[orig_bb] != 0
            orig_bb = merged_succ[orig_bb]
        end
        map(i->bb_rename_succ[i], bbs[orig_bb].succs)
    end

    # The predecessors of a result block are those of the block heading its
    # merge chain; each one is mapped to the head of its own chain and renamed.
    function compute_preds(i)
        orig_bb = result_bbs[i]
        preds = bbs[orig_bb].preds
        map(preds) do pred
            while merge_into[pred] != 0
                pred = merge_into[pred]
            end
            bb_rename_succ[pred]
        end
    end
    # `bb_starts` has exactly one more entry than `result_bbs`, so
    # `bb_starts[i+1]` exists for every `i` here. No fallback on `compact`
    # (which is only created below) is needed to compute the end of a range.
    cresult_bbs = BasicBlock[BasicBlock(
        StmtRange(bb_starts[i], bb_starts[i+1]-1),
        compute_preds(i), compute_succs(i)) for i = 1:length(result_bbs)]
    compact = IncrementalCompact(ir, true)
    # We're messing with the CFG. We don't want compaction to do
    # so independently
    compact.fold_constant_branches = false
    compact.bb_rename_succ = bb_rename_succ
    compact.bb_rename_pred = bb_rename_pred
    compact.result_bbs = cresult_bbs
    result_idx = 1
    # Emit the statements of every result block, following its merge chain in
    # order so that merged blocks end up contiguous.
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            for i in bbs[ms].stmts
                stmt = ir.stmts[i]
                # Pre-fill the slot; it stays `nothing` if the statement is
                # dropped below.
                compact.result[compact.result_idx] = nothing
                compact.result_types[compact.result_idx] = ir.types[i]
                compact.result_lines[compact.result_idx] = ir.lines[i]
                compact.result_flags[compact.result_idx] = ir.flags[i]
                # If we merged a basic block, we need remove the trailing GotoNode (if any)
                if isa(stmt, GotoNode) && merged_succ[ms] != 0
                    # Do nothing
                else
                    process_node!(compact, compact.result_idx, stmt, i, i, ms, true)
                end
                # We always increase the result index to ensure a predictable
                # placement of the resulting nodes.
                compact.result_idx += 1
            end
            ms = merged_succ[ms]
        end
    end

    # One past the last result block: every block has already been laid out.
    compact.active_result_bb = length(bb_starts)
    return finish(compact)
end
2 changes: 1 addition & 1 deletion base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ function show_ir(io::IO, code::IRCode, expr_type_printer=default_expr_type_print
# Compute BB guard rail
if bb_idx > length(cfg.blocks)
# Even if invariants are violated, try our best to still print
bbrange = (last(cfg.blocks[end].stmts) + 1):typemax(Int)
bbrange = (length(cfg.blocks) == 0 ? 1 : last(cfg.blocks[end].stmts) + 1):typemax(Int)
bb_idx_str = "!"
bb_type = ""
else
Expand Down
1 change: 1 addition & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,7 @@ module IRShow
using Core.IR
import ..Base
import .Compiler: IRCode, ReturnNode, GotoIfNot, CFG, scan_ssa_use!, Argument, isexpr, compute_basic_blocks, block_for_inst
Base.getindex(r::Compiler.StmtRange, ind::Integer) = Compiler.getindex(r, ind)
Base.size(r::Compiler.StmtRange) = Compiler.size(r)
Base.first(r::Compiler.StmtRange) = Compiler.first(r)
Base.last(r::Compiler.StmtRange) = Compiler.last(r)
Expand Down
Loading

0 comments on commit c5be814

Please sign in to comment.