Skip to content

Commit

Permalink
refactoring on SROA passes (JuliaLang#55262)
Browse files Browse the repository at this point in the history
All changes are cosmetic and do not change the basic functionality:
- Added the interface type to the callbacks received by `simple_walk`
to clarify which objects are passed as callbacks to `simple_walk`.
- Replaced ambiguous names like `idx` with more descriptive ones like
`defidx` to make the algorithm easier to understand.
  • Loading branch information
aviatesk authored Jul 29, 2024
1 parent 4dfce5d commit 427824e
Showing 1 changed file with 88 additions and 69 deletions.
157 changes: 88 additions & 69 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
function find_curblock(domtree::DomTree, allblocks::BitSet, curblock::Int)
# TODO: This can be much faster by looking at current level and only
# searching for those blocks in a sorted order
while !(curblock in allblocks) && curblock !== 0
while curblock ∉ allblocks && curblock ≠ 0
curblock = domtree.idoms_bb[curblock]
end
return curblock
Expand Down Expand Up @@ -190,18 +190,21 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
return walk_to_defs(compact, val, typeconstraint, predecessors, 𝕃ₒ)
end

function trivial_walker(@nospecialize(pi), @nospecialize(idx))
return nothing
end
# Supertype of the callable objects that `simple_walk` accepts as its `walker_callback`.
# A subtype is invoked as `walker_callback(def, defssa)`; returning `nothing` for a
# non-SSA definition makes `simple_walk` stop and return `defssa`. The implementations
# in this file otherwise return a `LiftedValue`.
abstract type WalkerCallback end

function pi_walker(@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
return LiftedValue(pi.val)
# Default walker callback: it always returns `nothing`, i.e. it never offers a
# replacement value for the definition it is shown.
struct TrivialWalker <: WalkerCallback end
function (::TrivialWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
    return nothing
end

# Walker callback that looks through `PiNode`s: for a `PiNode` definition it returns the
# wrapped value as a `LiftedValue`; for any other definition it returns `nothing`.
struct PiWalker <: WalkerCallback end
function (::PiWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
    def isa PiNode || return nothing
    return LiftedValue(def.val)
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), callback=trivial_walker)
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa::AnySSAValue),
walker_callback::WalkerCallback=TrivialWalker())
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand All @@ -218,15 +221,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
def = compact[defssa][:stmt]
if isa(def, AnySSAValue)
callback(def, defssa)
walker_callback(def, defssa)
if isa(def, SSAValue)
is_old(compact, defssa) && (def = OldSSAValue(def.id))
end
defssa = def
elseif isa(def, Union{PhiNode, PhiCNode, GlobalRef})
return defssa
else
new_def = callback(def, defssa)
new_def = walker_callback(def, defssa)
if new_def === nothing
return defssa
end
Expand All @@ -241,16 +244,21 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
return LiftedValue(pi.val)
end
return nothing
# Walker callback that carries a type constraint along the walk. The struct is mutable
# because its callable overwrites `typeconstraint` with the intersection of the current
# constraint and the widened type of every `PiNode` it is shown.
mutable struct TypeConstrainingWalker <: WalkerCallback
typeconstraint::Any
TypeConstrainingWalker(@nospecialize(typeconstraint::Any)) = new(typeconstraint)
end
function (walker_callback::TypeConstrainingWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
if isa(def, PiNode)
walker_callback.typeconstraint =
typeintersect(walker_callback.typeconstraint, widenconst(def.typ))
return LiftedValue(def.val)
end
def = simple_walk(compact, defssa, callback)
return nothing
end
# Walk `val` back towards its definition with `simple_walk`, using a
# `TypeConstrainingWalker` seeded with `typeconstraint` so that the constraint is
# intersected with the type of every `PiNode` shown to the walker.
#
# Returns `Pair{Any, Any}(def, constraint)`, where `def` is the result of `simple_walk`
# and `constraint` is the constraint as narrowed during the walk.
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(val::AnySSAValue),
                                @nospecialize(typeconstraint))
    walker_callback = TypeConstrainingWalker(typeconstraint)
    def = simple_walk(compact, val, walker_callback)
    # The narrowed constraint lives in the walker's field; the `typeconstraint` argument
    # is never reassigned, so returning it would silently drop all `PiNode` narrowing.
    return Pair{Any, Any}(def, walker_callback.typeconstraint)
end

Expand Down Expand Up @@ -638,15 +646,17 @@ end

struct SkipToken end; const SKIP_TOKEN = SkipToken()

function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=::AnySSAValue=#), @nospecialize(old_value),
lifted_philikes::Vector{LiftedPhilike}, lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int},
walker_callback)
function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa::AnySSAValue),
@nospecialize(old_value), lifted_philikes::Vector{LiftedPhilike},
lifted_leaves::Union{LiftedLeaves, LiftedDefs},
reverse_mapping::IdDict{AnySSAValue, Int},
walker_callback::WalkerCallback)
val = old_value
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
val = simple_walk(compact, val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
val = simple_walk(compact, val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
end
if val in keys(lifted_leaves)
lifted_val = lifted_leaves[val]
Expand All @@ -656,7 +666,7 @@ function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=:
lifted_val === nothing && return UNDEF_TOKEN
val = lifted_val.val
if isa(val, AnySSAValue)
val = simple_walk(compact, val, pi_walker)
val = simple_walk(compact, val, PiWalker())
end
return val
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
Expand All @@ -673,7 +683,7 @@ function is_old(compact, @nospecialize(old_node_ssa))
return true
end

struct PhiNest{C}
struct PhiNest{C<:WalkerCallback}
visited_philikes::Vector{AnySSAValue}
lifted_philikes::Vector{LiftedPhilike}
lifted_leaves::Union{LiftedLeaves, LiftedDefs}
Expand Down Expand Up @@ -743,20 +753,29 @@ function finish_phi_nest!(compact::IncrementalCompact, nest::PhiNest)
end
end

function def_walker(lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int}, walker_callback)
function (@nospecialize(walk_def), @nospecialize(defssa))
if (defssa in keys(lifted_leaves)) || (isa(defssa, AnySSAValue) && defssa in keys(reverse_mapping))
return nothing
end
isa(walk_def, PiNode) && return LiftedValue(walk_def.val)
return walker_callback(walk_def, defssa)
# Walker callback that wraps another `WalkerCallback`. It returns `nothing` when `defssa`
# is a key of `lifted_leaves` or of `reverse_mapping`, returns the wrapped value of a
# `PiNode` as a `LiftedValue`, and forwards every other definition to
# `inner_walker_callback`.
struct LiftedLeaveWalker{C<:WalkerCallback} <: WalkerCallback
    lifted_leaves::Union{LiftedLeaves, LiftedDefs}
    reverse_mapping::IdDict{AnySSAValue, Int}
    inner_walker_callback::C
    function LiftedLeaveWalker(@nospecialize(lifted_leaves::Union{LiftedLeaves, LiftedDefs}),
                               @nospecialize(reverse_mapping::IdDict{AnySSAValue, Int}),
                               inner_walker_callback::C) where C<:WalkerCallback
        return new{C}(lifted_leaves, reverse_mapping, inner_walker_callback)
    end
end
function (walker_callback::LiftedLeaveWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
    # Already-known SSA values end the walk; the second lookup only runs if the first misses.
    is_leaf = defssa in keys(walker_callback.lifted_leaves)
    if is_leaf || defssa in keys(walker_callback.reverse_mapping)
        return nothing
    end
    if isa(def, PiNode)
        return LiftedValue(def.val)
    end
    return walker_callback.inner_walker_callback(def, defssa)
end

function perform_lifting!(compact::IncrementalCompact,
visited_philikes::Vector{AnySSAValue}, @nospecialize(cache_key),
@nospecialize(result_t), lifted_leaves::Union{LiftedLeaves, LiftedDefs}, @nospecialize(stmt_val),
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback = trivial_walker)
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback::WalkerCallback = TrivialWalker())
reverse_mapping = IdDict{AnySSAValue, Int}()
for id in 1:length(visited_philikes)
reverse_mapping[visited_philikes[id]] = id
Expand Down Expand Up @@ -839,7 +858,7 @@ function perform_lifting!(compact::IncrementalCompact,

# Fixup the stmt itself
if isa(stmt_val, Union{SSAValue, OldSSAValue})
stmt_val = simple_walk(compact, stmt_val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
stmt_val = simple_walk(compact, stmt_val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
end

if stmt_val in keys(lifted_leaves)
Expand Down Expand Up @@ -948,6 +967,17 @@ function keyvalue_predecessors(@nospecialize(key), 𝕃ₒ::AbstractLattice)
end
end

# Walker callback that recognises known calls/invokes of
# `Core.OptimizedGenerics.KeyValue.set` (resolved against `compact`) and returns their
# third-to-last argument as a `LiftedValue`; any other definition yields `nothing`.
# NOTE(review): `args[end-2]` is presumably the collection being updated — confirm
# against the `KeyValue.set` signature.
struct KeyValueWalker <: WalkerCallback
    compact::IncrementalCompact
end
function (walker_callback::KeyValueWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
    kvset = Core.OptimizedGenerics.KeyValue.set
    is_known_invoke_or_call(def, kvset, walker_callback.compact) || return nothing
    @assert length(def.args) in (5, 6)
    return LiftedValue(def.args[end-2])
end

function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
collection = stmt.args[end-1]
key = stmt.args[end]
Expand All @@ -964,16 +994,9 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
end

function keyvalue_walker(@nospecialize(def), _)
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
@assert length(def.args) in (5, 6)
return LiftedValue(def.args[end-2])
end
return nothing
end
(lifted_val, nest) = perform_lifting!(compact,
visited_philikes, key, result_t, lifted_leaves, collection, nothing,
keyvalue_walker)
KeyValueWalker(compact))

compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
finish_phi_nest!(compact, nest)
Expand Down Expand Up @@ -1139,13 +1162,11 @@ end
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

struct IntermediaryCollector
struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
function (this::IntermediaryCollector)(@nospecialize(pi), @nospecialize(ssa))
if !isa(pi, Expr)
push!(this.intermediaries, ssa.id)
end
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
return nothing
end

Expand Down Expand Up @@ -1242,7 +1263,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
update_scope_mapping!(scope_mapping, bb+1, bbs)
end
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false
is_setfield = is_isdefined = is_finalizer = false
field_ordering = :unspecified
if is_known_call(stmt, setfield!, compact)
4 <= length(stmt.args) <= 5 || continue
Expand Down Expand Up @@ -1371,8 +1392,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
if ismutabletypename(struct_typ_name)
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = IntermediaryCollector(intermediaries)
def = simple_walk(compact, val, callback)
def = simple_walk(compact, val, IntermediaryCollector(intermediaries))
# Mutable stuff here
isa(def, SSAValue) || continue
if defuses === nothing
Expand Down Expand Up @@ -1680,24 +1700,23 @@ end
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, lazydomtree::LazyDomtree, inlining::Union{Nothing, InliningState})
𝕃ₒ = inlining === nothing ? SimpleInferenceLattice.instance : optimizer_lattice(inlining.interp)
lazypostdomtree = LazyPostDomtree(ir)
for (idx, (intermediaries, defuse)) in defuses
for (defidx, (intermediaries, defuse)) in defuses
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
# show up in the nuses_total count.
nleaves = length(defuse.uses) + length(defuse.defs)
nuses = 0
for idx in intermediaries
nuses += used_ssas[idx]
for iidx in intermediaries
nuses += used_ssas[iidx]
end
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
nuses_total = used_ssas[defidx] + nuses - length(intermediaries)
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)][:stmt]
defexpr = ir[SSAValue(defidx)][:stmt]
isexpr(defexpr, :new) || continue
newidx = idx
typ = unwrap_unionall(ir.stmts[newidx][:type])
typ = unwrap_unionall(ir.stmts[defidx][:type])
# Could still end up here if we tried to setfield! on an immutable, which would
# error at runtime, but is not illegal to have in the IR.
typ = widenconst(typ)
Expand All @@ -1713,7 +1732,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
end
end
if finalizer_idx !== nothing && inlining !== nothing
try_resolve_finalizer!(ir, idx, finalizer_idx, defuse, inlining,
try_resolve_finalizer!(ir, defidx, finalizer_idx, defuse, inlining,
lazydomtree, lazypostdomtree, ir[SSAValue(finalizer_idx)][:info])
continue
end
Expand Down Expand Up @@ -1752,11 +1771,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# but we should come up with semantics for well defined semantics
# for uninitialized fields first.
ndefuse = length(fielddefuse)
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# BitSet}}(undef, ndefuse)
blocks = Vector{Tuple{#=phiblocks=#Vector{Int},#=allblocks=#BitSet}}(undef, ndefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, newidx)
push!(du.defs, defidx)
ldu = compute_live_ins(ir.cfg, du)
if isempty(ldu.live_in_bbs)
phiblocks = Int[]
Expand All @@ -1769,7 +1788,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
for i = 1:length(du.uses)
use = du.uses[i]
if use.kind === :isdefined
if has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx)
if has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx)
ir[SSAValue(use.idx)][:stmt] = true
else
all_eliminated = false
Expand All @@ -1782,7 +1801,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
continue
end
end
has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx) || @goto skip
has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx) || @goto skip
end
else # always have some definition at the allocation site
for i = 1:length(du.uses)
Expand Down Expand Up @@ -1849,19 +1868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# all "usages" (i.e. `getfield` and `isdefined` calls) are eliminated,
# now eliminate "definitions" (i.e. `setfield!`) calls
# (NOTE the allocation itself will be eliminated by DCE pass later)
for idx in du.defs
idx == newidx && continue # this is allocation
for didx in du.defs
didx == defidx && continue # this is allocation
# verify this statement won't throw, otherwise it can't be eliminated safely
ssa = SSAValue(idx)
if is_nothrow(ir, ssa)
ir[ssa][:stmt] = nothing
setfield_ssa = SSAValue(didx)
if is_nothrow(ir, setfield_ssa)
ir[setfield_ssa][:stmt] = nothing
else
# We can't eliminate this statement, because it might still
# throw an error, but we can mark it as effect-free since we
# know we have removed all uses of the mutable allocation.
# As a result, if we ever do prove nothrow, we can delete
# this statement then.
add_flag!(ir[ssa], IR_FLAG_EFFECT_FREE)
add_flag!(ir[setfield_ssa], IR_FLAG_EFFECT_FREE)
end
end
end
Expand All @@ -1870,7 +1889,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# this means all ccall preserves have been replaced with forwarded loads
# so we can potentially eliminate the allocation, otherwise we must preserve
# the whole allocation.
push!(intermediaries, newidx)
push!(intermediaries, defidx)
end
# Insert the new preserves
for (useidx, new_preserves) in preserve_uses
Expand Down

0 comments on commit 427824e

Please sign in to comment.