diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index a0ff86a218412..8b91b2b4a2947 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1060,6 +1060,11 @@ function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::I end end +function is_union_phi(compact::IncrementalCompact, idx::Int) + inst = compact.result[idx] + return isa(inst[:inst], PhiNode) && isa(inst[:type], Union) +end + """ adce_pass!(ir::IRCode) -> newir::IRCode @@ -1082,15 +1087,42 @@ within `sroa_pass!` which redirects references of `typeassert`ed value to the co function adce_pass!(ir::IRCode) phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) all_phis = Int[] + unionphis = Int[] # sorted + unionphi_types = Any[] compact = IncrementalCompact(ir) for ((_, idx), stmt) in compact if isa(stmt, PhiNode) push!(all_phis, idx) - elseif is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3 - # nullify safe `typeassert` calls - ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact)) - if isexact && argextype(stmt.args[2], compact) ⊑ ty - compact[idx] = nothing + if isa(compact.result[idx][:type], Union) + push!(unionphis, idx) + push!(unionphi_types, Union{}) + end + elseif isa(stmt, PiNode) + val = stmt.val + if isa(val, SSAValue) && is_union_phi(compact, val.id) + r = searchsorted(unionphis, val.id) + if !isempty(r) + unionphi_types[first(r)] = Union{unionphi_types[first(r)], widenconst(stmt.typ)} + end + end + else + if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3 + # nullify safe `typeassert` calls + ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact)) + if isexact && argextype(stmt.args[2], compact) ⊑ ty + compact[idx] = nothing + continue + end + end + for ur in userefs(stmt) + use = ur[] + if isa(use, SSAValue) && is_union_phi(compact, use.id) + r = searchsorted(unionphis, use.id) + if !isempty(r) + deleteat!(unionphis, first(r)) + deleteat!(unionphi_types, first(r)) + end + end end end end @@ -1098,6 +1130,34 @@ function adce_pass!(ir::IRCode) for phi in all_phis count_uses(compact.result[phi][:inst]::PhiNode, phi_uses) end + # Narrow any union phi nodes that have unused branches + @assert length(unionphis) == length(unionphi_types) + for i = 1:length(unionphis) + phi = unionphis[i] + t = unionphi_types[i] + if phi_uses[phi] != 0 + continue + end + if t === Union{} + compact.result[phi][:inst] = nothing + continue + end + to_drop = Int[] + stmt = compact[phi] + stmt === nothing && continue + for i = 1:length(stmt.values) + if !isassigned(stmt.values, i) + # Should be impossible to have something used only by PiNodes that's undef + push!(to_drop, i) + elseif !hasintersect(widenconst(argextype(stmt.values[i], compact)), t) + push!(to_drop, i) + end + end + isempty(to_drop) && continue + deleteat!(stmt.values, to_drop) + deleteat!(stmt.edges, to_drop) + compact.result[phi][:type] = t + end # Perform simple DCE for unused values extra_worklist = Int[] for (idx, nused) in Iterators.enumerate(compact.used_ssas) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index c1089443cd62d..f768109614657 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -838,3 +838,20 @@ let ci = code_typed1(optimize=false) do ir = Core.Compiler.compact!(ir, true) @test count(@nospecialize(stmt)->isa(stmt, Core.GotoIfNot), ir.stmts.inst) == 0 end + +# Test that adce_pass! can drop phi node uses that can be concluded unused +# from PiNode analysis. +let src = @eval Module() begin + @noinline mkfloat() = rand(Float64) + @noinline use(a::Float64) = ccall(:jl_, Cvoid, (Any,), a) + dispatch(a::Float64) = use(a) + dispatch(a::Tuple) = nothing + function foo(b) + a = mkfloat() + a = b ? (a, 2.0) : a + dispatch(a) + end + code_typed(foo, Tuple{Bool})[1][1] + end + @test count(iscall((src, Core.tuple)), src.code) == 0 +end