Skip to content

Commit

Permalink
Merge pull request JuliaLang#26306 from JuliaLang/kf/domsort
Browse files Browse the repository at this point in the history
[NewOptimizer] Domsort basic blocks
  • Loading branch information
Keno committed Mar 5, 2018
2 parents 15f84b3 + 2dd211c commit 3e94518
Show file tree
Hide file tree
Showing 17 changed files with 310 additions and 32 deletions.
8 changes: 6 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4194,8 +4194,7 @@ function copy_duplicated_expr_pass!(sv::OptimizationState)
end

# fix label numbers to always equal the statement index of the label
function reindex_labels!(sv::OptimizationState)
body = sv.src.code
function reindex_labels!(body::Vector{Any})
mapping = get_label_map(body)
for i = 1:length(body)
el = body[i]
Expand Down Expand Up @@ -4235,6 +4234,11 @@ function reindex_labels!(sv::OptimizationState)
end
end


function reindex_labels!(sv::OptimizationState)
reindex_labels!(sv.src.code)
end

function return_type(@nospecialize(f), @nospecialize(t))
params = Params(ccall(:jl_get_tls_world_age, UInt, ()))
rt = Union{}
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/domtree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ function construct_domtree(cfg)
# Recursively set level
update_level!(domtree, 1, 1)
DomTree(idoms, domtree)
end
end
22 changes: 21 additions & 1 deletion base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ function normalize(@nospecialize(stmt), meta::Vector{Any}, table::Vector{LineInf
elseif stmt.head === :gotoifnot
return GotoIfNot(stmt.args...)
elseif stmt.head === :return
return ReturnNode{Any}(stmt.args...)
return ReturnNode((length(stmt.args) == 0 ? (nothing,) : stmt.args)...)
elseif stmt.head === :unreachable
return ReturnNode()
end
elseif isa(stmt, LabelNode)
return nothing
Expand All @@ -89,6 +91,23 @@ end
function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
mod = linetable[1].mod
ci.code = copy(ci.code)
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels.
idx = 1
while idx <= length(ci.code)
stmt = ci.code[idx]
if isexpr(stmt, :(=))
stmt = stmt.args[2]
end
if isa(stmt, Expr) && stmt.typ === Union{}
if !(idx < length(ci.code) && isexpr(ci.code[idx+1], :unreachable))
insert!(ci.code, idx + 1, ReturnNode())
idx += 1
end
end
idx += 1
end
reindex_labels!(ci.code)
meta = Any[]
lines = fill(0, length(ci.code))
let loc = RefValue(1)
Expand All @@ -110,6 +129,7 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
IRCode(code, lines, cfg, argtypes, mod, meta)
end
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs)
domtree = construct_domtree(ir.cfg)
ir = compact!(ir)
verify_ir(ir)
ir = type_lift_pass!(ir)
Expand Down
29 changes: 19 additions & 10 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ struct GotoIfNot
GotoIfNot(@nospecialize(cond), dest::Int) = new(cond, dest)
end

struct ReturnNode{T}
val::T
ReturnNode{T}(@nospecialize(val)) where {T} = new{T}(val::T)
ReturnNode{T}() where {T} = new{T}()
struct ReturnNode
val
ReturnNode(@nospecialize(val)) = new(val)
# unassigned val indicates unreachable
ReturnNode() = new()
end

"""
Expand All @@ -31,6 +32,8 @@ start(r::StmtRange) = 0
done(r::StmtRange, state) = r.last - r.first < state
next(r::StmtRange, state) = (r.first + state, state + 1)

StmtRange(range::UnitRange{Int}) = StmtRange(first(range), last(range))

struct BasicBlock
stmts::StmtRange
preds::Vector{Int}
Expand Down Expand Up @@ -264,7 +267,7 @@ function done(it::UseRefIterator, use)
false
end

function scan_ssa_use!(used::IdSet{Int64}, @nospecialize(stmt))
function scan_ssa_use!(used, @nospecialize(stmt))
if isa(stmt, SSAValue)
push!(used, stmt.id)
end
Expand Down Expand Up @@ -340,9 +343,9 @@ mutable struct IncrementalCompact
end

struct TypesView
compact::IncrementalCompact
ir::Union{IRCode, IncrementalCompact}
end
types(compact::IncrementalCompact) = TypesView(compact)
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)

function getindex(compact::IncrementalCompact, idx)
if idx < compact.result_idx
Expand All @@ -368,10 +371,16 @@ function setindex!(compact::IncrementalCompact, v, idx)
end

function getindex(view::TypesView, idx)
if idx < view.compact.result_idx
isa(idx, SSAValue) && (idx = idx.id)
if isa(view.ir, IncrementalCompact) && idx < view.compact.result_idx
return view.compact.result_types[idx]
else
return view.compact.ir.types[idx]
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
if idx <= length(ir.types)
return ir.types[idx]
else
return ir.new_nodes[idx - length(ir.types)][2]
end
end
end

Expand Down Expand Up @@ -457,7 +466,7 @@ function next(compact::IncrementalCompact, (idx, active_bb, old_result_idx)::Tup
compact.result_types[old_result_idx] = typ
compact.result_lines[old_result_idx] = new_line
result_idx = process_node!(compact, old_result_idx, new_node, new_idx, idx)
(old_result_idx == result_idx) && return next(compact, (idx, result_idx))
(old_result_idx == result_idx) && return next(compact, (idx, active_bb, result_idx))
compact.result_idx = result_idx
return (old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb, compact.result_idx)
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function replace_code!(ci::CodeInfo, code::IRCode, nargs::Int, linetable::Vector
new_stmt = Expr(:return, rename(stmt.val))
else
# Unreachable, so no issue with this
new_stmt = nothing
new_stmt = Expr(:unreachable)
end
elseif isa(stmt, SSAValue)
new_stmt = rename(stmt)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function type_lift_pass!(ir::IRCode)
def = ir.stmts[item]
edges = copy(def.edges)
values = Vector{Any}(uninitialized, length(edges))
new_phi = insert_node!(ir, item, Bool, PhiNode(edges, values))
new_phi = length(values) == 0 ? false : insert_node!(ir, item, Bool, PhiNode(edges, values))
processed[item] = new_phi
if first
lifted_undef[stmt_id] = new_phi
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ function Base.show(io::IO, code::IRCode)
bbrange = cfg.blocks[bb_idx].stmts
bbrange = bbrange.first:bbrange.last
bb_pad = max_bb_idx_size - length(string(bb_idx))
bb_start_str = string("$(bb_idx) ",length(cfg.blocks[bb_idx].preds) <= 1 ? "" : "", ""^(bb_pad)," ")
if idx != last(bbrange)
if idx == first(bbrange)
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
print(io, bb_start_str)
else
print(io, ""," "^max_bb_idx_size)
end
Expand All @@ -98,7 +99,7 @@ function Base.show(io::IO, code::IRCode)
node_idx += length(code.stmts)
if print_sep
if floop
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
print(io, bb_start_str)
else
print(io, ""," "^max_bb_idx_size)
end
Expand All @@ -117,7 +118,7 @@ function Base.show(io::IO, code::IRCode)
end
if print_sep
if idx == first(bbrange) && floop
print(io, "$(bb_idx) ",""^(1+bb_pad)," ")
print(io, bb_start_str)
else
print(io, idx == last(bbrange) ? string("", ""^(1+max_bb_idx_size), " ") :
string("", " "^max_bb_idx_size))
Expand Down
Loading

0 comments on commit 3e94518

Please sign in to comment.