Skip to content

Commit

Permalink
Merge pull request JuliaLang#39221 from JuliaLang/jn/conditional-bugs
Browse files Browse the repository at this point in the history
fix some issues with inference Conditionals
  • Loading branch information
vtjnash committed Jan 13, 2021
2 parents 7e85405 + 95d03f9 commit 7896e77
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 61 deletions.
42 changes: 25 additions & 17 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,20 +783,28 @@ end
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
la = length(argtypes)
if f === ifelse && fargs isa Vector{Any} && la == 4 && argtypes[2] isa Conditional
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
cnd = argtypes[2]::Conditional
tx = argtypes[3]
ty = argtypes[4]
a = ssa_def_slot(fargs[3], sv)
b = ssa_def_slot(fargs[4], sv)
if isa(a, Slot) && slot_id(cnd.var) == slot_id(a)
tx = typeintersect(tx, cnd.vtype)
end
if isa(b, Slot) && slot_id(cnd.var) == slot_id(b)
ty = typeintersect(ty, cnd.elsetype)
end
return tmerge(tx, ty)
if f === ifelse && fargs isa Vector{Any} && la == 4
cnd = argtypes[2]
if isa(cnd, Conditional)
newcnd = widenconditional(cnd)
if isa(newcnd, Const)
# if `cnd` is constant, we should just respect its constantness to keep inference accuracy
return newcnd.val ? tx : ty
else
# try to simulate this as a real conditional (`cnd ? x : y`), so that the penalty for using `ifelse` instead isn't too high
tx = argtypes[3]
ty = argtypes[4]
a = ssa_def_slot(fargs[3], sv)
b = ssa_def_slot(fargs[4], sv)
if isa(a, Slot) && slot_id(cnd.var) == slot_id(a)
tx = (cnd.vtype tx ? cnd.vtype : tmeet(tx, widenconst(cnd.vtype)))
end
if isa(b, Slot) && slot_id(cnd.var) == slot_id(b)
ty = (cnd.elsetype ty ? cnd.elsetype : tmeet(ty, widenconst(cnd.elsetype)))
end
return tmerge(tx, ty)
end
end
end
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
Expand Down Expand Up @@ -1179,9 +1187,9 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if length(e.args) == 2 && isconcretetype(t) && !t.mutable
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
n = fieldcount(t)
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val) &&
let t = t, at = at; _all(i->at.val[i] isa fieldtype(t, i), 1:n); end
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
if isa(at, Const) && (val = at.val; isa(val, Tuple)) && n == length(val) &&
let t = t, val = val; _all(i->val[i] isa fieldtype(t, i), 1:n); end
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, val))
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields) &&
let t = t, at = at; _all(i->at.fields[i] fieldtype(t, i), 1:n); end
t = PartialStruct(t, at.fields)
Expand Down
41 changes: 2 additions & 39 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,47 +543,10 @@ function typeof_tfunc(@nospecialize(t))
end
add_tfunc(typeof, 1, 1, typeof_tfunc, 0)

function typeassert_type_instance(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
if ti === Bottom
return Bottom
end
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
if isa(v.fields[i], Core.TypeofVararg)
new_fields[i] = v.fields[i]
else
new_fields[i] = typeassert_type_instance(v.fields[i], getfield_tfunc(t, Const(i)))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
return typeintersect(widenconst(v), t)
end
function typeassert_tfunc(@nospecialize(v), @nospecialize(t))
t = instanceof_tfunc(t)[1]
t === Any && return v
return typeassert_type_instance(v, t)
return tmeet(v, t)
end
add_tfunc(typeassert, 2, 2, typeassert_tfunc, 4)

Expand Down Expand Up @@ -1652,7 +1615,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
if contains_is(argtypes_vec, Union{})
return Const(Union{})
end
rt = abstract_call(interp, nothing, argtypes_vec, sv, -1).rt
rt = widenconditional(abstract_call(interp, nothing, argtypes_vec, sv, -1).rt)
if isa(rt, Const)
# output was computed to be constant
return Const(typeof(rt.val))
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
if invoke_api(code) == 2
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ ReturnNode(quoted(code.rettype_const)) ]
rettype_const = code.rettype_const
tree.code = Any[ ReturnNode(quoted(rettype_const)) ]
nargs = Int(method.nargs)
tree.slotnames = ccall(:jl_uncompress_argnames, Vector{Symbol}, (Any,), method.slot_syms)
tree.slotflags = fill(0x00, nargs)
Expand All @@ -814,7 +815,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
tree.pure = true
tree.inlineable = true
tree.parent = mi
tree.rettype = Core.Typeof(code.rettype_const)
tree.rettype = Core.Typeof(rettype_const)
tree.min_world = code.min_world
tree.max_world = code.max_world
return tree
Expand Down
11 changes: 8 additions & 3 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,20 @@ function ⊑(@nospecialize(a), @nospecialize(b))
end
isa(a, MaybeUndef) && (a = a.typ)
isa(b, MaybeUndef) && (b = b.typ)
(a === NOT_FOUND || b === Any) && return true
(a === Any || b === NOT_FOUND) && return false
b === Any && return true
a === Any && return false
a === Union{} && return true
b === Union{} && return false
@assert !isa(a, TypeVar) "invalid lattice item"
@assert !isa(b, TypeVar) "invalid lattice item"
if isa(a, Conditional)
if isa(b, Conditional)
return issubconditional(a, b)
issubconditional(a, b) && return true
b = maybe_extract_const_bool(b)
if b isa Bool && maybe_extract_const_bool(a) === b
return true
end
return false
elseif isa(b, Const) && isa(b.val, Bool)
return maybe_extract_const_bool(a) === b.val
end
Expand Down
44 changes: 44 additions & 0 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
isa(typeb, MaybeUndef) ? typeb.typ : typeb))
end
# type-lattice for Conditional wrapper
if isa(typea, Conditional) && isa(typeb, Conditional) && typea.var !== typeb.var
widenconditional(typea) isa Const && (typea = widenconditional(typea))
widenconditional(typeb) isa Const && (typeb = widenconditional(typeb))
end
if isa(typea, Conditional) && isa(typeb, Const)
if typeb.val === true
typeb = Conditional(typea.var, Any, Union{})
Expand Down Expand Up @@ -520,3 +524,43 @@ function tuplemerge(a::DataType, b::DataType)
end
return Tuple{p...}
end

# compute typeintersect over the extended inference lattice
# where v is in the extended lattice, and t is a Type
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
if ti === Bottom
return Bottom
end
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
if isa(v.fields[i], Core.TypeofVararg)
new_fields[i] = v.fields[i]
else
new_fields[i] = tmeet(v.fields[i], widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
return typeintersect(widenconst(v), t)
end

0 comments on commit 7896e77

Please sign in to comment.