Skip to content

Commit

Permalink
show: reduce code duplication in printing IR code
Browse files Browse the repository at this point in the history
fix JuliaLang#27644
fix JuliaLang#27965
(test coverage provided now by using existing methods for showing IR)
  • Loading branch information
vtjnash committed Jul 21, 2018
1 parent 1d3c9c8 commit 6cbb52b
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 265 deletions.
139 changes: 81 additions & 58 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,105 +37,125 @@ StmtRange(range::UnitRange{Int}) = StmtRange(first(range), last(range))

struct BasicBlock
stmts::StmtRange
#error_handler::Bool
preds::Vector{Int}
succs::Vector{Int}
end
function BasicBlock(stmts::StmtRange)
BasicBlock(stmts, Int[], Int[])
return BasicBlock(stmts, Int[], Int[])
end
function BasicBlock(old_bb, stmts)
BasicBlock(stmts, #= old_bb.error_handler, =# old_bb.preds, old_bb.succs)
return BasicBlock(stmts, old_bb.preds, old_bb.succs)
end
copy(bb::BasicBlock) = BasicBlock(bb.stmts, #= bb.error_handler, =# copy(bb.preds), copy(bb.succs))
copy(bb::BasicBlock) = BasicBlock(bb.stmts, copy(bb.preds), copy(bb.succs))

struct CFG
blocks::Vector{BasicBlock}
index::Vector{Int}
index::Vector{Int} # map from instruction => basic-block number
# TODO: make this O(1) instead of O(log(n_blocks))?
end
copy(c::CFG) = CFG(copy(c.blocks), copy(c.index))

function block_for_inst(index, inst)
searchsortedfirst(index, inst, lt=(<=))
return searchsortedfirst(index, inst, lt=(<=))
end
block_for_inst(cfg::CFG, inst) = block_for_inst(cfg.index, inst)

function compute_basic_blocks(stmts::Vector{Any})
jump_dests = BitSet(1)
function basic_blocks_starts(stmts::Vector{Any})
jump_dests = BitSet()
push!(jump_dests, 1) # function entry point
# First go through and compute jump destinations
for (idx, stmt) in pairs(stmts)
for idx in 1:length(stmts)
stmt = stmts[idx]
# Terminators
if isa(stmt, GotoIfNot) || isa(stmt, GotoNode) || isa(stmt, ReturnNode)
if isa(stmt, GotoIfNot)
if isa(stmt, GotoIfNot)
push!(jump_dests, idx+1)
push!(jump_dests, stmt.dest)
elseif isa(stmt, ReturnNode)
idx < length(stmts) && push!(jump_dests, idx+1)
elseif isa(stmt, GotoNode)
# This is a fake dest to force the next stmt to start a bb
idx < length(stmts) && push!(jump_dests, idx+1)
push!(jump_dests, stmt.label)
elseif isa(stmt, Expr)
if stmt.head === :leave
# :leave terminates a BB
push!(jump_dests, idx+1)
push!(jump_dests, stmt.dest)
else
elseif stmt.head == :enter
# :enter starts/ends a BB
push!(jump_dests, idx)
push!(jump_dests, idx+1)
# The catch block is a jump dest
push!(jump_dests, stmt.args[1])
elseif stmt.head === :gotoifnot
# also tolerate expr form of IR
push!(jump_dests, idx+1)
push!(jump_dests, stmt.args[2])
elseif stmt.head === :return
# also tolerate expr form of IR
# This is a fake dest to force the next stmt to start a bb
idx < length(stmts) && push!(jump_dests, idx+1)
if isa(stmt, GotoNode)
push!(jump_dests, stmt.label)
end
end
elseif isa(stmt, Expr) && stmt.head === :leave
# :leave terminates a BB
push!(jump_dests, idx+1)
elseif isa(stmt, Expr) && stmt.head == :enter
# :enter starts/ends a BB
push!(jump_dests, idx)
push!(jump_dests, idx+1)
# The catch block is a jump dest
push!(jump_dests, stmt.args[1])
end
end
bb_starts = collect(jump_dests)
# and add add one more basic block start after the last statement
for i = length(stmts):-1:1
if stmts[i] != nothing
push!(bb_starts, i+1)
push!(jump_dests, i+1)
break
end
end
return jump_dests
end

function compute_basic_blocks(stmts::Vector{Any})
bb_starts = basic_blocks_starts(stmts)
# Compute ranges
basic_block_index = Int[]
pop!(bb_starts, 1)
basic_block_index = collect(bb_starts)
blocks = BasicBlock[]
sizehint!(blocks, length(bb_starts)-1)
for (first, last) in Iterators.zip(bb_starts, Iterators.drop(bb_starts, 1))
push!(basic_block_index, first)
push!(blocks, BasicBlock(StmtRange(first, last-1)))
sizehint!(blocks, length(basic_block_index))
let first = 1
for last in basic_block_index
push!(blocks, BasicBlock(StmtRange(first, last - 1)))
first = last
end
end
popfirst!(basic_block_index)
# Compute successors/predecessors
for (num, b) in pairs(blocks)
for (num, b) in enumerate(blocks)
terminator = stmts[last(b.stmts)]
if isa(terminator, ReturnNode)
continue
end
if isa(terminator, GotoNode)
block′ = block_for_inst(basic_block_index, terminator.label)
push!(blocks[block′].preds, num)
push!(b.succs, block′)
continue
end
# Conditional Branch
if isa(terminator, GotoIfNot)
block′ = block_for_inst(basic_block_index, terminator.dest)
push!(blocks[block′].preds, num)
push!(b.succs, block′)
end
if isa(terminator, GotoNode)
block′ = block_for_inst(basic_block_index, terminator.label)
elseif isa(terminator, Expr) && terminator.head == :enter
# :enter gets a virtual edge to the exception handler and
# the exception handler gets a virtual edge from outside
# the function.
# See the devdocs on exception handling in SSA form (or
# bug Keno to write them, if you're reading this and they
# don't exist)
block′ = block_for_inst(basic_block_index, terminator.args[1])
push!(blocks[block′].preds, num)
push!(blocks[block′].preds, 0)
push!(b.succs, block′)
elseif !isa(terminator, ReturnNode)
if isa(terminator, Expr) && terminator.head == :enter
# :enter gets a virtual edge to the exception handler and
# the exception handler gets a virtual edge from outside
# the function.
# See the devdocs on exception handling in SSA form (or
# bug Keno to write them, if you're reading this and they
# don't exist)
block′ = block_for_inst(basic_block_index, terminator.args[1])
push!(blocks[block′].preds, num)
push!(blocks[block′].preds, 0)
push!(b.succs, block′)
end
if num + 1 <= length(blocks)
push!(blocks[num+1].preds, num)
push!(b.succs, num+1)
end
end
# statement fall-through
if num + 1 <= length(blocks)
push!(blocks[num + 1].preds, num)
push!(b.succs, num + 1)
end
end
CFG(blocks, basic_block_index)
return CFG(blocks, basic_block_index)
end

function first_insert_for_bb(code, cfg::CFG, block::Int)
Expand All @@ -157,9 +177,11 @@ struct NewNode
# The node itself
node::Any
# The index into the line number table of this entry
line::Int
line::Int32

NewNode(pos::Int, attach_after::Bool, @nospecialize(typ), @nospecialize(node), line::Int32) =
new(pos, attach_after, typ, node, line)
end
copy(n::NewNode) = copy(n.pos, n.attach_after, n.typ, copy(n.node), n.line)

struct IRCode
stmts::Vector{Any}
Expand Down Expand Up @@ -283,7 +305,8 @@ function is_relevant_expr(e::Expr)
:gc_preserve_begin, :gc_preserve_end,
:foreigncall, :isdefined, :copyast,
:undefcheck, :throw_undef_if_not,
:cfunction, :method)
:cfunction, :method,
#=legacy IR format support=# :gotoifnot, :return)
end

function setindex!(x::UseRef, @nospecialize(v))
Expand Down
Loading

0 comments on commit 6cbb52b

Please sign in to comment.