From b449ea5e2ce5b90171849ec546453aecb721f3a0 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Fri, 25 Mar 2022 11:26:41 +0900 Subject: [PATCH] optimizer: switch to more general data structure of `SSADefUse` (#44730) Now `(du::SSADefUse).uses` field can contain arbitrary "usages" to be eliminated. This structure might be helpful for implementing array SROA, for example. Also slightly tweaks the implementation of `ccall` preserve elimination. --- base/compiler/ssair/passes.jl | 110 ++++++++++++++++++++++------------ test/compiler/irpasses.jl | 30 ++++++++++ 2 files changed, 102 insertions(+), 38 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c6760d01a61e8..bf2e6d3aae17f 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,30 +6,43 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end +struct SSAUse + kind::Symbol + idx::Int +end +GetfieldUse(idx::Int) = SSAUse(:getfield, idx) +PreserveUse(idx::Int) = SSAUse(:preserve, idx) +NoPreserve() = SSAUse(:nopreserve, 0) +IsdefinedUse(idx::Int) = SSAUse(:isdefined, idx) + """ du::SSADefUse This struct keeps track of all uses of some mutable struct allocated in the current function: -- `du.uses::Vector{Int}` are all instances of `getfield` / `isdefined` on the struct +- `du.uses::Vector{SSAUse}` are some "usages" (like `getfield`) of the struct - `du.defs::Vector{Int}` are all instances of `setfield!` on the struct The terminology refers to the uses/defs of the "slot bundle" that the mutable struct represents. -In addition we keep track of all instances of a `:foreigncall` that preserves of this mutable -struct in `du.ccall_preserve_uses`. Somewhat counterintuitively, we don't actually need to -make sure that the struct itself is live (or even allocated) at a `ccall` site. -If there are no other places where the struct escapes (and thus e.g. where its address is taken), -it need not be allocated. We do however, need to make sure to preserve any elements of this struct. +`du.uses` tracks all instances of `getfield` and `isdefined` calls on the struct. +Additionally it also tracks all instances of a `:foreigncall` that preserves of this mutable +struct. Somewhat counterintuitively, we don't actually need to make sure that the struct +itself is live (or even allocated) at a `ccall` site. If there are no other places where +the struct escapes (and thus e.g. where its address is taken), it need not be allocated. +We do however, need to make sure to preserve any elements of this struct. """ struct SSADefUse - uses::Vector{Int} + uses::Vector{SSAUse} defs::Vector{Int} - ccall_preserve_uses::Vector{Int} end -SSADefUse() = SSADefUse(Int[], Int[], Int[]) +SSADefUse() = SSADefUse(SSAUse[], Int[]) function compute_live_ins(cfg::CFG, du::SSADefUse) - # filter out `isdefined` usages - return compute_live_ins(cfg, du.defs, filter(>(0), du.uses)) + uses = Int[] + for use in du.uses + use.kind === :isdefined && continue # filter out `isdefined` usages + push!(uses, use.idx) + end + compute_live_ins(cfg, du.defs, uses) end # assume `stmt == getfield(obj, field, ...)` or `stmt == setfield!(obj, field, val, ...)` @@ -89,7 +102,8 @@ function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) end -function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) +function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) def, useblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use) if def == 0 if !haskey(phinodes, curblock) @@ -787,8 +801,8 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) - push!(defuse.ccall_preserve_uses, idx) + mid, defuse = get!(()->(SPCSet(),SSADefUse()), defuses, defidx) + push!(defuse.uses, PreserveUse(idx)) union!(mid, intermediaries) end continue @@ -846,13 +860,13 @@ function sroa_pass!(ir::IRCode) if defuses === nothing defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) + mid, defuse = get!(()->(SPCSet(),SSADefUse()), defuses, def.id) if is_setfield push!(defuse.defs, idx) elseif is_isdefined - push!(defuse.uses, -idx) + push!(defuse.uses, IsdefinedUse(idx)) else - push!(defuse.uses, idx) + push!(defuse.uses, GetfieldUse(idx)) end union!(mid, intermediaries) end @@ -923,7 +937,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # escapes and we cannot eliminate the allocation. This works, because we're guaranteed # not to include any intermediaries that have dead uses. As a result, missing uses will only ever # show up in the nuses_total count. - nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + nleaves = length(defuse.uses) + length(defuse.defs) nuses = 0 for idx in intermediaries nuses += used_ssas[idx] @@ -946,7 +960,13 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] all_eliminated = all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(abs(use))][:inst] # == `getfield`/`isdefined` call + if use.kind === :preserve + for du in fielddefuse + push!(du.uses, use) + end + continue + end + stmt = ir[SSAValue(use.idx)][:inst] # == `getfield`/`isdefined` call # We may have discovered above that this use is dead # after the getfield elim of immutables. In that case, # it would have been deleted. That's fine, just ignore @@ -985,16 +1005,23 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks if fidx + 1 > length(defexpr.args) - for use in du.uses - if use > 0 # == `getfield` use - has_safe_def(ir, get_domtree(), allblocks, du, newidx, use) || @goto skip - else # == `isdefined` use - if has_safe_def(ir, get_domtree(), allblocks, du, newidx, -use) - ir[SSAValue(-use)][:inst] = true + for i = 1:length(du.uses) + use = du.uses[i] + if use.kind === :isdefined + if has_safe_def(ir, get_domtree(), allblocks, du, newidx, use.idx) + ir[SSAValue(use.idx)][:inst] = true else all_eliminated = false end + continue + elseif use.kind === :preserve + if length(du.defs) == 1 # allocation with this field unintialized + # there is nothing to preserve, just ignore this use + du.uses[i] = NoPreserve() + continue + end end + has_safe_def(ir, get_domtree(), allblocks, du, newidx, use.idx) || @goto skip end end end @@ -1003,8 +1030,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse # This needs to be after we iterate through the IR with `IncrementalCompact` # because removing dead blocks can invalidate the domtree. domtree = get_domtree() - preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing : - IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses))) + local preserve_uses = nothing for fidx in 1:ndefuse du = fielddefuse[fidx] ftyp = fieldtype(typ, fidx) @@ -1017,17 +1043,24 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end # Now go through all uses and rewrite them for use in du.uses - if use > 0 # == `getfield` use - ir[SSAValue(use)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use) - else # == `isdefined` use + if use.kind === :getfield + ir[SSAValue(use.idx)][:inst] = compute_value_for_use(ir, domtree, allblocks, + du, phinodes, fidx, use.idx) + elseif use.kind === :isdefined continue # already rewritten if possible - end - end - if !isbitstype(ftyp) - if preserve_uses !== nothing - for (use, list) in preserve_uses - push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use)) + elseif use.kind === :nopreserve + continue # nothing to preserve (may happen when there are unintialized fields) + elseif use.kind === :preserve + newval = compute_value_for_use(ir, domtree, allblocks, + du, phinodes, fidx, use.idx) + if !isbitstype(widenconst(argextype(newval, ir))) + if preserve_uses === nothing + preserve_uses = IdDict{Int, Vector{Any}}() + end + push!(get!(()->Any[], preserve_uses, use.idx), newval) end + else + @assert false "sroa_mutables!: unexpected use" end end for b in phiblocks @@ -1056,8 +1089,9 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse push!(intermediaries, newidx) end # Insert the new preserves - for (use, new_preserves) in preserve_uses - ir[SSAValue(use)][:inst] = form_new_preserves(ir[SSAValue(use)][:inst]::Expr, intermediaries, new_preserves) + for (useidx, new_preserves) in preserve_uses + ir[SSAValue(useidx)][:inst] = form_new_preserves(ir[SSAValue(useidx)][:inst]::Expr, + intermediaries, new_preserves) end @label skip diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index d441c7ebc4889..fedc622b91148 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -409,10 +409,25 @@ end # preserve elimination # -------------------- +function ispreserved(@nospecialize(x)) + return function (@nospecialize(stmt),) + if Meta.isexpr(stmt, :foreigncall) + nccallargs = length(stmt.args[3]::Core.SimpleVector) + for pidx = (6+nccallargs):length(stmt.args) + if stmt.args[pidx] === x + return true + end + end + end + return false + end +end + let src = code_typed1((String,)) do s ccall(:some_ccall, Cint, (Ptr{String},), Ref(s)) end @test count(isnew, src.code) == 0 + @test any(ispreserved(#=s=#Core.Argument(2)), src.code) end # if the mutable struct is directly used, we shouldn't eliminate it @@ -425,6 +440,21 @@ let src = code_typed1() do @test count(isnew, src.code) == 1 end +# should eliminate allocation whose address isn't taked even if it has unintialized field(s) +mutable struct BadRef + x::String + y::String + BadRef(x) = new(x) +end +Base.cconvert(::Type{Ptr{BadRef}}, a::String) = BadRef(a) +Base.unsafe_convert(::Type{Ptr{BadRef}}, ar::BadRef) = Ptr{BadRef}(pointer_from_objref(ar.x)) +let src = code_typed1((String,)) do s + ccall(:jl_breakpoint, Cvoid, (Ptr{BadRef},), s) + end + @test count(isnew, src.code) == 0 + @test any(ispreserved(#=s=#Core.Argument(2)), src.code) +end + # isdefined elimination # ---------------------