Skip to content

Commit

Permalink
Turn on inference for OpaqueClosure (JuliaLang#39681)
Browse files Browse the repository at this point in the history
This turns on inference for `PartialOpaque` callees (but no
optimization/inlining yet and also no dynamic dispatch
to the optimized implementations). Because of the current design
and some fixes getting pulled into previous PRs, I believe this
is all that remains to be done on the inference front.

In particular, we specialize the OpaqueClosure methods on
the tuple formed by the tuple type of the environment
(at inference time) and the argument tuples. This is a bit of
an odd method specialization, but it seems like inference
is just fine with it in general. In the fullness of time,
we may want to store the specializations differently
to give more freedom to partial optimizations, but that
would require being able to re-enter inference later, which
is currently not possible.
  • Loading branch information
Keno committed Mar 5, 2021
1 parent d0d378e commit e49567f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 19 deletions.
37 changes: 32 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,13 @@ end
# where we would spend a lot of time, but are probably unliekly to get an improved
# result anyway.
function const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
# with the const-prop-ability. It is quite possible that we can't infer
# anything at all without const-propping, so the inlining check below
# isn't particularly helpful here.
return true
end
# Peek at the inferred result for the function to determine if the optimizer
# was able to cut it down to something simple (inlineable in particular).
# If so, there's a good chance we might be able to const prop all the way
Expand All @@ -371,7 +378,9 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
method = match.method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
length(argtypes) >= nargs || return Any, nothing
if length(argtypes) < nargs
return Any, nothing
end
haveconst = false
allconst = true
# see if any or all of the arguments are constant and propagating constants may be worthwhile
Expand Down Expand Up @@ -428,10 +437,14 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
end
force_inference |= allconst
mi = specialize_method(match, !force_inference)
mi === nothing && return Any, nothing
if mi === nothing
add_remark!(interp, sv, "[constprop] Failed to specialize")
return Any, nothing
end
mi = mi::MethodInstance
# decide if it's likely to be worthwhile
if !force_inference && !const_prop_heuristic(interp, method, mi)
add_remark!(interp, sv, "[constprop] Disabled by heuristic")
return Any, nothing
end
inf_cache = get_inference_cache(interp)
Expand All @@ -444,6 +457,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
cyclei = 0
while !(infstate === nothing)
if method === infstate.linfo.def && any(infstate.result.overridden_by_const)
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
return Any, nothing
end
if cyclei < length(infstate.callers_in_cycle)
Expand Down Expand Up @@ -1199,7 +1213,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
return CallMeta(Any, nothing)
pushfirst!(argtypes, closure.env)
sig = argtypes_to_type(argtypes)
rt, edgecycle, edge = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv)
info = OpaqueClosureCallInfo(edge)
if !edgecycle
const_rettype, result = abstract_call_method_with_const_args(interp, rt, closure, argtypes, MethodMatch(sig, Core.svec(), closure.source::Method, false), sv, edgecycle)
if const_rettype rt
rt = const_rettype
end
if result !== nothing
info = ConstCallInfo(info, result)
end
end
return CallMeta(rt, info)
end

function most_general_argtypes(closure::PartialOpaque)
Expand All @@ -1209,7 +1236,7 @@ function most_general_argtypes(closure::PartialOpaque)
if !isa(argt, DataType) || argt.name !== typename(Tuple)
argt = Tuple
end
return most_general_argtypes(closure.source, argt, closure.isva)
return most_general_argtypes(closure.source, argt, closure.isva, false)
end

# call where the function is any lattice element
Expand All @@ -1224,7 +1251,7 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes, sv)
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
else
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
overridden_by_const::Bool)
if isa(given_argtype, Const) || isa(given_argtype, PartialStruct)
if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || isa(given_argtype, PartialOpaque)
return is_lattice_equal(given_argtype, cache_argtype)
end
return !overridden_by_const
Expand Down Expand Up @@ -46,11 +46,11 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector)
end

function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes),
isva::Bool)
isva::Bool, withfirst::Bool = true)
toplevel = method === nothing
linfo_argtypes = Any[unwrap_unionall(specTypes).parameters...]
nargs::Int = toplevel ? 0 : method.nargs
if !toplevel && method.is_for_opaque_closure
if !withfirst
# For opaque closure, the closure environment is processed elsewhere
nargs -= 1
end
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ struct InvokeCallInfo
match::MethodMatch
end

struct OpaqueClosureCallInfo
mi::MethodInstance
end

# Stmt infos that are used by external consumers, but not by optimization.
# These are not produced by default and must be explicitly opted into by
# the AbstractInterpreter.
Expand Down
8 changes: 0 additions & 8 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,6 @@ function invoke_api(li::CodeInstance)
return ccall(:jl_invoke_api, Cint, (Any,), li)
end

function has_opaque_closure(c::CodeInfo)
for i = 1:length(c.code)
stmt = c.code[i]
(isa(stmt, Expr) && stmt.head === :new_opaque_closure) && return true
end
return false
end

function get_staged(mi::MethodInstance)
may_invoke_generator(mi) || return nothing
try
Expand Down
6 changes: 3 additions & 3 deletions test/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ let ci = @code_lowered OcClos2Int(1, 2)();
1, 2))()
end
end
@test oc_self_call_clos() == 3
@test @inferred(oc_self_call_clos()) == 3
let opt = @code_typed oc_self_call_clos()
@test_broken length(opt[1].code) == 1
@test_broken isa(opt[1].code[1], Core.ReturnNode)
Expand Down Expand Up @@ -85,8 +85,8 @@ end
function complicated_identity(x)
oc_infer_pass_id()(x)
end
@test_broken @inferred(complicated_identity(1)) == 1
@test_broken @inferred(complicated_identity("a")) == "a"
@test @inferred(complicated_identity(1)) == 1
@test @inferred(complicated_identity("a")) == "a"
let ci = (@code_typed complicated_identity(1))[1]
@test_broken length(ci.code) == 1
@test_broken isa(ci.code[1], Core.ReturnNode)
Expand Down

0 comments on commit e49567f

Please sign in to comment.