Skip to content

Commit

Permalink
Merge 6afd0e1 into 403e4e2
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Mar 17, 2023
2 parents 403e4e2 + 6afd0e1 commit 22c94a3
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 27 deletions.
13 changes: 7 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# as we may want to concrete-evaluate this frame in cases when there are
# no overlayed calls, try an additional effort now to check if this call
# isn't overlayed rather than just handling it conservatively
matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if !isa(matches, FailedMethodMatch)
nonoverlayed = matches.nonoverlayed
Expand All @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
Expand Down Expand Up @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
function find_matching_methods(𝕃::AbstractLattice,
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
max_union_splitting::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(argtypes)
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(𝕃, argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
Expand Down Expand Up @@ -1495,7 +1496,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum
splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum
ctypes = [Any[aft]]
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
effects = EFFECTS_TOTAL
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
has_mustalias(::AnyMustAliasesLattice) = true
has_mustalias(::JLTypeLattice) = false

has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃))
has_extended_unionsplit(::AnyMustAliasesLattice) = true
has_extended_unionsplit(::JLTypeLattice) = false

# Curried versions
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ end
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)

_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts)

"""
alias::InterMustAlias
Expand Down
28 changes: 15 additions & 13 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters)
if 1 < unionsplitcost(a.parameters) <= max_union_splitting
if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting
ta = switchtupleunion(a)
return typesubtract(Union{ta...}, b, 0)
elseif b isa DataType
Expand Down Expand Up @@ -227,12 +227,11 @@ end
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
# outside than inside because the representation is larger (because and it
# informs the callee whether any splitting is possible).
function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}})
function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}})
nu = 1
max = 2
for ti in argtypes
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
if has_extended_unionsplit(𝕃) && !isvarargtype(ti)
ti = widenconst(ti)
end
if isa(ti, Union)
Expand All @@ -252,12 +251,12 @@ end
# and `Union{return...} == ty`
function switchtupleunion(@nospecialize(ty))
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty)
end

switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing)
switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing)

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
if i == 0
if origt === nothing
push!(tunion, copy(t))
Expand All @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
else
origti = ti = t[i]
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
ti = widenconst(ti)
end
if isa(ti, Union)
for ty in uniontypes(ti::Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
else
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
end
return tunion
Expand Down
14 changes: 7 additions & 7 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2944,11 +2944,11 @@ end
# issue #28356
# unit test to make sure countunionsplit overflows gracefully
# we don't care what number is returned as long as it's large
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6

# make sure compiler doesn't hang in union splitting

Expand Down Expand Up @@ -3949,13 +3949,13 @@ end

# argtypes
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)])
@test length(tunion) == 2
@test Any[Int32, Core.Const(nothing)] in tunion
@test Any[Int64, Core.Const(nothing)] in tunion
end
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
@test length(tunion) == 4
@test Any[Int32, Float32, Core.Const(nothing)] in tunion
@test Any[Int32, Float64, Core.Const(nothing)] in tunion
Expand Down

0 comments on commit 22c94a3

Please sign in to comment.