Skip to content

Commit

Permalink
lattice: Thread through lattice argument for getfield_tfunc (#47097)
Browse files Browse the repository at this point in the history
Like `tuple`, `getfield` needs some lattice awareness to give
the correct answer in the presence of extended lattices. Refactor
to split and thread through the lattice argument through
_getfield_tfunc so external lattices can provide `getfield` tfuncs
for their custom elements.
  • Loading branch information
Keno committed Oct 10, 2022
1 parent 25e3809 commit d0b15c2
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 35 deletions.
4 changes: 2 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1326,13 +1326,13 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
if !isa(stateordonet_widened, DataType) || !(stateordonet_widened <: Tuple) || isvatuple(stateordonet_widened) || length(stateordonet_widened.parameters) != 2
break
end
nstatetype = getfield_tfunc(stateordonet, Const(2))
nstatetype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(2))
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if (typeinf_lattice(interp), nstatetype, statetype)
return Any[Bottom], nothing
end
valtype = getfield_tfunc(stateordonet, Const(1))
valtype = getfield_tfunc(typeinf_lattice(interp), stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), StmtInfo(true), sv)
Expand Down
99 changes: 70 additions & 29 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -854,25 +854,33 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), boundscheck::
return false
end

function getfield_tfunc(s00, name, boundscheck_or_order)
@nospecialize
function getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00),
@nospecialize(name), @nospecialize(boundscheck_or_order))
t = isvarargtype(boundscheck_or_order) ? unwrapva(boundscheck_or_order) :
widenconst(boundscheck_or_order)
hasintersect(t, Symbol) || hasintersect(t, Bool) || return Bottom
return getfield_tfunc(s00, name)
return getfield_tfunc(lattice, s00, name)
end
function getfield_tfunc(s00, name, order, boundscheck)
@nospecialize
function getfield_tfunc(@nospecialize(s00), name, boundscheck_or_order)
return getfield_tfunc(fallback_lattice, s00, name, boundscheck_or_order)
end
function getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00),
@nospecialize(name), @nospecialize(order), @nospecialize(boundscheck))
hasintersect(widenconst(order), Symbol) || return Bottom
if isvarargtype(boundscheck)
t = unwrapva(boundscheck)
hasintersect(t, Symbol) || hasintersect(t, Bool) || return Bottom
else
hasintersect(widenconst(boundscheck), Bool) || return Bottom
end
return getfield_tfunc(s00, name)
return getfield_tfunc(lattice, s00, name)
end
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(s00, name, false)
function getfield_tfunc(@nospecialize(s00), @nospecialize(name), @nospecialize(order), @nospecialize(boundscheck))
return getfield_tfunc(fallback_lattice, s00, name, order, boundscheck)
end
getfield_tfunc(@nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(fallback_lattice, s00, name, false)
getfield_tfunc(@specialize(lattice::AbstractLattice), @nospecialize(s00), @nospecialize(name)) = _getfield_tfunc(lattice, s00, name, false)


function _getfield_fieldindex(@nospecialize(s), name::Const)
nv = name.val
Expand Down Expand Up @@ -902,10 +910,46 @@ function _getfield_tfunc_const(@nospecialize(sv), name::Const, setfield::Bool)
return nothing
end

function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, Conditional)
function _getfield_tfunc(@specialize(lattice::InferenceLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, LimitedAccuracy)
# This will error, but it's better than duplicating the error here
s00 = widenconst(s00)
end
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

function _getfield_tfunc(@specialize(lattice::OptimizerLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
# If undef, that's a Union, but that doesn't affect the rt when tmerged
# into the unwrapped result.
isa(s00, MaybeUndef) && (s00 = s00.typ)
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

function _getfield_tfunc(@specialize(lattice::AnyConditionalsLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, AnyConditional)
return Bottom # Bool has no fields
elseif isa(s00, Const)
end
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

function _getfield_tfunc(@specialize(lattice::PartialsLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, PartialStruct)
s = widenconst(s00)
sty = unwrap_unionall(s)::DataType
if isa(name, Const)
nv = _getfield_fieldindex(sty, name)
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
return unwrapva(s00.fields[nv])
end
end
s00 = s
end

return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

function _getfield_tfunc(lattice::ConstsLattice, @nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, Const)
sv = s00.val
if isa(name, Const)
nv = name.val
Expand All @@ -919,30 +963,24 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
r = _getfield_tfunc_const(sv, name, setfield)
r !== nothing && return r
end
s = typeof(sv)
elseif isa(s00, PartialStruct)
s = widenconst(s00)
sty = unwrap_unionall(s)::DataType
if isa(name, Const)
nv = _getfield_fieldindex(sty, name)
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
return unwrapva(s00.fields[nv])
end
end
else
s = unwrap_unionall(s00)
s00 = widenconst(s00)
end
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

function _getfield_tfunc(lattice::JLTypeLattice, @nospecialize(s00), @nospecialize(name), setfield::Bool)
s = unwrap_unionall(s00)
if isa(s, Union)
return tmerge(_getfield_tfunc(rewrap_unionall(s.a, s00), name, setfield),
_getfield_tfunc(rewrap_unionall(s.b, s00), name, setfield))
return tmerge(_getfield_tfunc(lattice, rewrap_unionall(s.a, s00), name, setfield),
_getfield_tfunc(lattice, rewrap_unionall(s.b, s00), name, setfield))
end
if isType(s)
if isconstType(s)
sv = s00.parameters[1]
if isa(name, Const)
if isa(name, Const)
r = _getfield_tfunc_const(sv, name, setfield)
r !== nothing && return r
end
end
s = typeof(sv)
else
sv = s.parameters[1]
Expand Down Expand Up @@ -982,7 +1020,7 @@ function _getfield_tfunc(@nospecialize(s00), @nospecialize(name), setfield::Bool
if !(_ts <: Tuple)
return Any
end
return _getfield_tfunc(_ts, name, setfield)
return _getfield_tfunc(lattice, _ts, name, setfield)
end
ftypes = datatype_fieldtypes(s)
nf = length(ftypes)
Expand Down Expand Up @@ -1090,7 +1128,7 @@ end
function setfield!_tfunc(o, f, v)
@nospecialize
mutability_errorcheck(o) || return Bottom
ft = _getfield_tfunc(o, f, true)
ft = _getfield_tfunc(fallback_lattice, o, f, true)
ft === Bottom && return Bottom
hasintersect(widenconst(v), widenconst(ft)) || return Bottom
return v
Expand Down Expand Up @@ -1168,7 +1206,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any
# as well as compute the info for the method matches
op = unwrapva(argtypes[4])
v = unwrapva(argtypes[5])
TF = getfield_tfunc(o, f)
TF = getfield_tfunc(typeinf_lattice(interp), o, f)
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), sv, #=max_methods=# 1)
TF2 = tmeet(callinfo.rt, widenconst(TF))
if TF2 === Bottom
Expand Down Expand Up @@ -2118,6 +2156,9 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
# wrong # of args
return Bottom
end
if f === getfield
return getfield_tfunc(typeinf_lattice(interp), argtypes...)
end
return tf[3](argtypes...)
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ function tmeet(lattice::PartialsLattice, @nospecialize(v), @nospecialize(t::Type
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(lattice, vfi, widenconst(getfield_tfunc(t, Const(i))))
new_fields[i] = tmeet(lattice, vfi, widenconst(getfield_tfunc(lattice, t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function issimplertype(lattice::AbstractLattice, @nospecialize(typea), @nospecia
bi = (tni.val::Core.TypeName).wrapper
is_lattice_equal(lattice, ai, bi) && continue
end
bi = getfield_tfunc(typeb, Const(i))
bi = getfield_tfunc(lattice, typeb, Const(i))
is_lattice_equal(lattice, ai, bi) && continue
# It is not enough for ai to be simpler than bi: it must exactly equal
# (for this, an invariant struct field, by contrast to
Expand Down Expand Up @@ -490,8 +490,8 @@ function tmerge(lattice::PartialsLattice, @nospecialize(typea), @nospecialize(ty
fields = Vector{Any}(undef, type_nfields)
anyrefine = false
for i = 1:type_nfields
ai = getfield_tfunc(typea, Const(i))
bi = getfield_tfunc(typeb, Const(i))
ai = getfield_tfunc(lattice, typea, Const(i))
bi = getfield_tfunc(lattice, typeb, Const(i))
ft = fieldtype(aty, i)
if is_lattice_equal(lattice, ai, bi) || is_lattice_equal(lattice, ai, ft)
# Since ai===bi, the given type has no restrictions on complexity.
Expand Down

0 comments on commit d0b15c2

Please sign in to comment.