diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index c4ce74fff9248..9461effc933f6 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -16,31 +16,13 @@ const _REF_NAME = Ref.body.name call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) = isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc]) -function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt) - box = Core.Box(atype) - return get!(cache, atype) do - _min_val = UInt[typemin(UInt)] - _max_val = UInt[typemax(UInt)] - _ambig = Int32[0] - ms = _methods_by_ftype(box.contents, max_methods, world, false, _min_val, _max_val, _ambig) - return ms, _min_val[1], _max_val[1], _ambig[1] != 0 - end -end - -function matching_methods(@nospecialize(atype), cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}}, max_methods::Int, world::UInt, min_valid::Vector{UInt}, max_valid::Vector{UInt}) - ms, minvalid, maxvalid, ambig = matching_methods(atype, cache, max_methods, world) - min_valid[1] = max(min_valid[1], minvalid) - max_valid[1] = min(max_valid[1], maxvalid) - return ms, ambig -end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) if sv.currpc in sv.throw_blocks return CallMeta(Any, false) end - min_valid = UInt[typemin(UInt)] - max_valid = UInt[typemax(UInt)] + valid_worlds = WorldRange() atype_params = unwrap_unionall(atype).parameters splitunions = 1 < countunionsplit(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING mts = Core.MethodTable[] @@ -56,15 +38,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(Any, false) end mt = mt::Core.MethodTable - xapplicable, ambig = matching_methods(sig_n, sv.matching_methods_cache, max_methods, - get_world_counter(interp), min_valid, max_valid) - if xapplicable === false + matches = findall(sig_n, method_table(interp); limit=max_methods) + if matches === missing add_remark!(interp, sv, "For one of the union split cases, too many methods matched") return CallMeta(Any, false) end - push!(infos, MethodMatchInfo(xapplicable, ambig)) - append!(applicable, xapplicable) - thisfullmatch = _any(match->(match::MethodMatch).fully_covers, xapplicable) + push!(infos, MethodMatchInfo(matches)) + append!(applicable, matches) + valid_worlds = intersect(valid_worlds, matches.valid_worlds) + thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) found = false for (i, mt′) in enumerate(mts) if mt′ === mt @@ -86,19 +68,20 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(Any, false) end mt = mt::Core.MethodTable - applicable, ambig = matching_methods(atype, sv.matching_methods_cache, max_methods, - get_world_counter(interp), min_valid, max_valid) - if applicable === false + matches = findall(atype, method_table(interp, sv); limit=max_methods) + if matches === missing # this means too many methods matched # (assume this will always be true, so we don't compute / update valid age in this case) add_remark!(interp, sv, "Too many methods matched") return CallMeta(Any, false) end push!(mts, mt) - push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, applicable)) - info = MethodMatchInfo(applicable, ambig) + push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches)) + info = MethodMatchInfo(matches) + applicable = matches.matches + valid_worlds = matches.valid_worlds end - update_valid_age!(min_valid[1], max_valid[1], sv) + update_valid_age!(sv, valid_worlds) applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom @@ -1460,12 +1443,7 @@ function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState) typeinf_local(interp, caller) no_active_ips_in_callers = false end - if caller.min_valid < frame.min_valid - caller.min_valid = frame.min_valid - end - if caller.max_valid > frame.max_valid - caller.max_valid = frame.max_valid - end + caller.valid_worlds = intersect(caller.valid_worlds, frame.valid_worlds) end end return true diff --git a/base/compiler/cicache.jl b/base/compiler/cicache.jl index d82665b61fc9a..9adaf6ded0b0f 100644 --- a/base/compiler/cicache.jl +++ b/base/compiler/cicache.jl @@ -14,6 +14,23 @@ end const GLOBAL_CI_CACHE = InternalCodeCache() +struct WorldRange + min_world::UInt + max_world::UInt +end +WorldRange() = WorldRange(typemin(UInt), typemax(UInt)) +WorldRange(w::UInt) = WorldRange(w, w) +WorldRange(r::UnitRange) = WorldRange(first(r), last(r)) +first(wr::WorldRange) = wr.min_world +last(wr::WorldRange) = wr.max_world +in(world::UInt, wr::WorldRange) = wr.min_world <= world <= wr.max_world + +function intersect(a::WorldRange, b::WorldRange) + ret = WorldRange(max(a.min_world, b.min_world), min(a.max_world, b.max_world)) + @assert ret.min_world <= ret.max_world + return ret +end + """ struct WorldView @@ -22,20 +39,19 @@ range of world ages, rather than defaulting to the current active world age. """ struct WorldView{Cache} cache::Cache - min_world::UInt - max_world::UInt + worlds::WorldRange + WorldView(cache::Cache, range::WorldRange) where Cache = new{Cache}(cache, range) end -WorldView(cache, r::UnitRange) = WorldView(cache, first(r), last(r)) -WorldView(cache, world::UInt) = WorldView(cache, world, world) -WorldView(wvc::WorldView, min_world::UInt, max_world::UInt) = - WorldView(wvc.cache, min_world, max_world) +WorldView(cache, args...) = WorldView(cache, WorldRange(args...)) +WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr) +WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...) function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance) - ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} !== nothing + ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing end function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default) - r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, wvc.min_world, wvc.max_world)::Union{Nothing, CodeInstance} + r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} if r === nothing return default end diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 29573312d4e5a..516fb0f2cdcb9 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -100,9 +100,11 @@ include("compiler/types.jl") include("compiler/utilities.jl") include("compiler/validation.jl") +include("compiler/cicache.jl") +include("compiler/methodtable.jl") + include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") -include("compiler/cicache.jl") include("compiler/typeutils.jl") include("compiler/typelimits.jl") diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index b4d6a7ca4f98b..56b766592787e 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -14,8 +14,7 @@ mutable struct InferenceState # info on the state of inference and the linfo src::CodeInfo world::UInt - min_valid::UInt - max_valid::UInt + valid_worlds::WorldRange nargs::Int stmt_types::Vector{Any} stmt_edges::Vector{Any} @@ -44,9 +43,10 @@ mutable struct InferenceState inferred::Bool dont_work_on_me::Bool - # cached results of calling `_methods_by_ftype`, including `min_valid` and - # `max_valid`, to be used in inlining - matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}} + # The place to look up methods while working on this function. + # In particular, we cache method lookup results for the same function to + # fast path repeated queries. + method_table::CachedMethodTable{InternalMethodTable} # The interpreter that created this inference state. Not looked at by # NativeInterpreter. But other interpreters may use this to detect cycles @@ -100,13 +100,12 @@ mutable struct InferenceState inmodule = linfo.def::Module end - min_valid = src.min_world - max_valid = src.max_world == typemax(UInt) ? - get_world_counter() : src.max_world + valid_worlds = WorldRange(src.min_world, + src.max_world == typemax(UInt) ? get_world_counter() : src.max_world) frame = new( InferenceParams(interp), result, linfo, sp, slottypes, inmodule, 0, - src, get_world_counter(interp), min_valid, max_valid, + src, get_world_counter(interp), valid_worlds, nargs, s_types, s_edges, stmt_info, Union{}, W, 1, n, cur_hand, handler_at, n_handlers, @@ -115,7 +114,7 @@ mutable struct InferenceState Vector{InferenceState}(), # callers_in_cycle #=parent=#nothing, cached, false, false, false, - IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), + CachedMethodTable(method_table(interp)), interp) result.result = frame cached && push!(get_inference_cache(interp), result) @@ -123,6 +122,8 @@ mutable struct InferenceState end end +method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table + function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter) # prepare an InferenceState object for inferring lambda src = retrieve_code_info(result.linfo) @@ -202,14 +203,13 @@ end _topmod(sv::InferenceState) = _topmod(sv.mod) # work towards converging the valid age range for sv -function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState) - sv.min_valid = max(sv.min_valid, min_valid) - sv.max_valid = min(sv.max_valid, max_valid) - @assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update") +function update_valid_age!(sv::InferenceState, worlds::WorldRange) + sv.valid_worlds = intersect(worlds, sv.valid_worlds) + @assert(sv.world in sv.valid_worlds, "invalid age range update") nothing end -update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv) +update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds) function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState) old = frame.src.ssavaluetypes[ssa_id] diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl new file mode 100644 index 0000000000000..9d2d8c9fa351a --- /dev/null +++ b/base/compiler/methodtable.jl @@ -0,0 +1,93 @@ +abstract type MethodTableView; end + +struct MethodLookupResult + # Really Vector{Core.MethodMatch}, but it's easier to represent this as + # and work with Vector{Any} on the C side. + matches::Vector{Any} + valid_worlds::WorldRange + ambig::Bool +end +length(result::MethodLookupResult) = length(result.matches) +function iterate(result::MethodLookupResult, args...) + r = iterate(result.matches, args...) + r === nothing && return nothing + match, state = r + return (match::MethodMatch, state) +end +getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch + +""" + struct InternalMethodTable <: MethodTableView + +A singleton struct representing the state of the internal method table at a +particular world age. +""" +struct InternalMethodTable <: MethodTableView + world::UInt +end + +""" + struct InternalMethodTable <: MethodTableView + +Overlays another method table view with an additional local fast path cache that +can respond to repeated, identical queries faster than the original method table. +""" +struct CachedMethodTable{T} <: MethodTableView + cache::IdDict{Any, Union{Missing, MethodLookupResult}} + table::T +end +CachedMethodTable(table::T) where T = + CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodLookupResult}}(), + table) + +""" + findall(sig::Type{<:Tuple}, view::MethodTableView; limit=typemax(Int)) + +Find all methods in the given method table `view` that are applicable to the +given signature `sig`. If no applicable methods are found, an empty result is +returned. If the number of applicable methods exeeded the specified limit, +`missing` is returned. +""" +function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; limit::Int=typemax(Int)) + _min_val = RefValue{UInt}(typemin(UInt)) + _max_val = RefValue{UInt}(typemax(UInt)) + _ambig = RefValue{Int32}(0) + ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig) + if ms === false + return missing + end + return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) +end + +function findall(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable; limit::Int=typemax(Int)) + box = Core.Box(sig) + return get!(table.cache, sig) do + findall(box.contents, table.table; limit=limit) + end +end + +""" + findsup(sig::Type{<:Tuple}, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing} + +Find the (unique) method `m` such that `sig <: m.sig`, while being more +specific than any other method with the same property. In other words, find +the method which is the least upper bound (supremum) under the specificity/subtype +relation of the queried `signature`. If `sig` is concrete, this is equivalent to +asking for the method that will be called given arguments whose types match the +given signature. This query is also used to implement `invoke`. + +Such a method `m` need not exist. It is possible that no method is an +upper bound of `sig`, or it is possible that among the upper bounds, there +is no least element. In both cases `nothing` is returned. +""" +function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable) + min_valid = RefValue{UInt}(typemin(UInt)) + max_valid = RefValue{UInt}(typemax(UInt)) + result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}), + sig, table.world, min_valid, max_valid)::Union{Method, Nothing} + result === nothing && return nothing + (result, WorldRange(min_valid[], max_valid[])) +end + +# This query is not cached +findsup(sig::Type{<:Tuple}, table::CachedMethodTable) = findsup(sig, table.table) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 2bfb15ddb0fc4..4e3b6e0723b19 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -13,14 +13,10 @@ mutable struct OptimizationState mod::Module nargs::Int world::UInt - min_valid::UInt - max_valid::UInt + valid_worlds::WorldRange sptypes::Vector{Any} # static parameters slottypes::Vector{Any} const_api::Bool - # cached results of calling `_methods_by_ftype` from inference, including - # `min_valid` and `max_valid` - matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt, Bool}} # TODO: This will be eliminated once optimization no longer needs to do method lookups interp::AbstractInterpreter function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter) @@ -33,9 +29,9 @@ mutable struct OptimizationState return new(params, frame.linfo, s_edges::Vector{Any}, src, frame.stmt_info, frame.mod, frame.nargs, - frame.world, frame.min_valid, frame.max_valid, + frame.world, frame.valid_worlds, frame.sptypes, frame.slottypes, false, - frame.matching_methods_cache, interp) + interp) end function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter) # prepare src for running optimization passes @@ -64,9 +60,9 @@ mutable struct OptimizationState return new(params, linfo, s_edges::Vector{Any}, src, stmt_info, inmodule, nargs, - get_world_counter(), UInt(1), get_world_counter(), + get_world_counter(), WorldRange(UInt(1), get_world_counter()), sptypes_from_meth_instance(linfo), slottypes, false, - IdDict{Any, Tuple{Any, UInt, UInt, Bool}}(), interp) + interp) end end @@ -110,11 +106,9 @@ const TOP_TUPLE = GlobalRef(Core, :tuple) _topmod(sv::OptimizationState) = _topmod(sv.mod) -function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState) - sv.min_valid = max(sv.min_valid, min_valid) - sv.max_valid = min(sv.max_valid, max_valid) - @assert(sv.min_valid <= sv.world <= sv.max_valid, - "invalid age range update") +function update_valid_age!(sv::OptimizationState, valid_worlds::WorldRange) + sv.valid_worlds = intersect(sv.valid_worlds, valid_worlds) + @assert(sv.world in sv.valid_worlds, "invalid age range update") nothing end @@ -126,7 +120,7 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState) end function add_backedge!(li::CodeInstance, caller::OptimizationState) - update_valid_age!(min_world(li), max_world(li), caller) + update_valid_age!(caller, WorldRange(min_world(li), max_world(li))) add_backedge!(li.def, caller) nothing end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 8553baa2922f7..926491223795c 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -579,19 +579,19 @@ function batch_inline!(todo::Vector{Any}, ir::IRCode, linetable::Vector{LineInfo end function spec_lambda(@nospecialize(atype), sv::OptimizationState, @nospecialize(invoke_data)) - min_valid = UInt[typemin(UInt)] - max_valid = UInt[typemax(UInt)] + min_valid = RefValue{UInt}(typemin(UInt)) + max_valid = RefValue{UInt}(typemax(UInt)) if invoke_data === nothing mi = ccall(:jl_get_spec_lambda, Any, (Any, UInt, Ptr{UInt}, Ptr{UInt}), atype, sv.world, min_valid, max_valid) else invoke_data = invoke_data::InvokeData atype <: invoke_data.types0 || return nothing mi = ccall(:jl_get_invoke_lambda, Any, (Any, Any), invoke_data.entry, atype) - min_valid[1] = invoke_data.min_valid - max_valid[1] = invoke_data.max_valid + min_valid[] = invoke_data.min_valid + max_valid[] = invoke_data.max_valid end mi !== nothing && add_backedge!(mi::MethodInstance, sv) - update_valid_age!(min_valid[1], max_valid[1], sv) + update_valid_age!(sv, WorldRange(min_valid[], max_valid[])) return mi end @@ -987,7 +987,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok methsp = methsp::SimpleVector result = analyze_method!(idx, sig, metharg, methsp, method, stmt, sv, true, invoke_data, calltype) handle_single_case!(ir, stmt, idx, result, true, todo) - update_valid_age!(invoke_data.min_valid, invoke_data.max_valid, sv) + update_valid_age!(sv, WorldRange(invoke_data.min_valid, invoke_data.max_valid)) return nothing end @@ -1053,10 +1053,9 @@ function recompute_method_matches(@nospecialize(atype), sv::OptimizationState) # in the case that the cache is nonempty, so it should be unchanged # The max number of methods should be the same as in inference most # of the time, and should not affect correctness otherwise. - (meth, min_valid, max_valid, ambig) = - matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world) - update_valid_age!(min_valid, max_valid, sv) - MethodMatchInfo(meth, ambig) + results = findall(atype, InternalMethodTable(sv.world); limit=sv.params.MAX_METHODS) + results !== missing && update_valid_age!(sv, results.valid_worlds) + MethodMatchInfo(results) end function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecialize(stmt), @@ -1069,8 +1068,8 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecia local fully_covered = true for i in 1:length(infos) info = infos[i] - meth = info.applicable - if meth === false || info.ambig + meth = info.results + if meth === missing || meth.ambig # Too many applicable methods # Or there is a (partial?) ambiguity too_many = true @@ -1087,7 +1086,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Any}, idx::Int, @nospecia else only_method = false end - for match in meth::Vector{Any} + for match in meth (metharg, methsp, method) = (match.spec_types, match.sparams, match.method) signature_union = Union{signature_union, metharg} if !isdispatchtuple(metharg) @@ -1179,7 +1178,7 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState) if !isa(info, UnionSplitInfo) infos = MethodMatchInfo[] for union_sig in UnionSplitSignature(sig.atypes) - push!(infos, recompute_method_matches(union_sig, sv)) + push!(infos, recompute_method_matches(argtypes_to_type(union_sig), sv)) end else infos = info.matches @@ -1221,12 +1220,10 @@ function compute_invoke_data(@nospecialize(atypes), world::UInt) return nothing end invoke_types = rewrap_unionall(Tuple{ft, unwrap_unionall(invoke_tt).parameters...}, invoke_tt) - min_valid = UInt[typemin(UInt)] - max_valid = UInt[typemax(UInt)] - invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), - invoke_types, world) # XXX: min_valid, max_valid + invoke_entry = findsup(invoke_types, InternalMethodTable(world)) invoke_entry === nothing && return nothing - invoke_data = InvokeData(invoke_entry::Method, invoke_types, min_valid[1], max_valid[1]) + method, valid_worlds = invoke_entry + invoke_data = InvokeData(method, invoke_types, first(valid_worlds), last(valid_worlds)) atype0 = atypes[2] atypes = atypes[4:end] pushfirst!(atypes, atype0) diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 39952dff75c1b..8f785e1336fca 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -7,8 +7,7 @@ to re-consult the method table. This info is illegal on any statement that is not a call to a generic function. """ struct MethodMatchInfo - applicable::Any - ambig::Bool + results::Union{Missing, MethodLookupResult} end """ diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index ec41e46ffa96f..da22fe78fd41f 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1234,14 +1234,14 @@ function invoke_tfunc(interp::AbstractInterpreter, @nospecialize(ft), @nospecial isdispatchelem(ft) || return Any # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types) argtype = Tuple{ft, argtype.parameters...} - meth = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, get_world_counter(interp)) - if meth === nothing + result = findsup(types, method_table(interp)) + if result === nothing return Any end - # XXX: update_valid_age!(min_valid[1], max_valid[1], sv) - meth = meth::Method - (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), argtype, meth.sig)::SimpleVector - rt, edge = typeinf_edge(interp, meth, ti, env, sv) + method, valid_worlds = result + update_valid_age!(sv, valid_worlds) + (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), argtype, method.sig)::SimpleVector + rt, edge = typeinf_edge(interp, method, ti, env, sv) edge !== nothing && add_backedge!(edge::MethodInstance, sv) return rt end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index bf61e574e723f..49125f19aedf6 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -23,8 +23,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState) # collect results for the new expanded frame results = InferenceResult[ frames[i].result for i in 1:length(frames) ] # empty!(frames) - min_valid = frame.min_valid - max_valid = frame.max_valid + valid_worlds = frame.valid_worlds cached = frame.cached if cached || frame.parent !== nothing for caller in results @@ -46,27 +45,21 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState) else caller.src = nothing end - if min_valid < opt.min_valid - min_valid = opt.min_valid - end - if max_valid > opt.max_valid - max_valid = opt.max_valid - end + valid_worlds = intersect(valid_worlds, opt.valid_worlds) end end end - if max_valid == get_world_counter() - max_valid = typemax(UInt) + if last(valid_worlds) == get_world_counter() + valid_worlds = WorldRange(first(valid_worlds), typemax(UInt)) end for caller in frames - caller.min_valid = min_valid - caller.max_valid = max_valid - caller.src.min_world = min_valid - caller.src.max_world = max_valid + caller.valid_worlds = valid_worlds + caller.src.min_world = first(valid_worlds) + caller.src.max_world = last(valid_worlds) if cached - cache_result!(interp, caller.result, min_valid, max_valid) + cache_result!(interp, caller.result, valid_worlds) end - if max_valid == typemax(UInt) + if last(valid_worlds) == typemax(UInt) # if we aren't cached, we don't need this edge # but our caller might, so let's just make it anyways for caller in frames @@ -79,7 +72,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState) return true end -function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt, +function CodeInstance(result::InferenceResult, valid_worlds::WorldRange, may_compress=true, allow_discard_tree=true) inferred_result = result.src local const_flags::Int32 @@ -126,7 +119,7 @@ function CodeInstance(result::InferenceResult, min_valid::UInt, max_valid::UInt, end return CodeInstance(result.linfo, widenconst(result.result), rettype_const, inferred_result, - const_flags, min_valid, max_valid) + const_flags, first(valid_worlds), last(valid_worlds)) end # For the NativeInterpreter, we don't need to do an actual cache query to know @@ -140,17 +133,17 @@ already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) = # inference completed on `me` # update the MethodInstance -function cache_result!(interp::AbstractInterpreter, result::InferenceResult, min_valid::UInt, max_valid::UInt) +function cache_result!(interp::AbstractInterpreter, result::InferenceResult, valid_worlds::WorldRange) # check if the existing linfo metadata is also sufficient to describe the current inference result # to decide if it is worth caching this already_inferred = already_inferred_quick_test(interp, result.linfo) - if !already_inferred && haskey(WorldView(code_cache(interp), min_valid, max_valid), result.linfo) + if !already_inferred && haskey(WorldView(code_cache(interp), valid_worlds), result.linfo) already_inferred = true end # TODO: also don't store inferred code if we've previously decided to interpret this function if !already_inferred - code_cache(interp)[result.linfo] = CodeInstance(result, min_valid, max_valid, + code_cache(interp)[result.linfo] = CodeInstance(result, valid_worlds, may_compress(interp), may_discard_trees(interp)) end unlock_mi_inference(interp, result.linfo) @@ -495,7 +488,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize mi = specialize_method(method, atypes, sparams)::MethodInstance code = get(code_cache(interp), mi, nothing) if code isa CodeInstance # return existing rettype if the code is already inferred - update_valid_age!(min_world(code), max_world(code), caller) + update_valid_age!(caller, WorldRange(min_world(code), max_world(code))) if isdefined(code, :rettype_const) if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype) return PartialStruct(code.rettype, code.rettype_const), mi diff --git a/base/compiler/types.jl b/base/compiler/types.jl index a5177afb70543..fce0b5b30ba49 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -202,3 +202,5 @@ add_remark!(ni::NativeInterpreter, sv, s) = nothing may_optimize(ni::NativeInterpreter) = true may_compress(ni::NativeInterpreter) = true may_discard_trees(ni::NativeInterpreter) = true + +method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai)) diff --git a/base/reflection.jl b/base/reflection.jl index 927a7d229cf55..660324cca72e8 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -866,6 +866,9 @@ end function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1}) return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} end +function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32}) + return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} +end # high-level, more convenient method lookup functions diff --git a/src/gf.c b/src/gf.c index 64b04db1386a0..837ac70e1d6e7 100644 --- a/src/gf.c +++ b/src/gf.c @@ -2335,7 +2335,7 @@ static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world) { - // XXX: return min/max world + // Deprecated: Use jl_gf_invoke_lookup_worlds for future development size_t min_valid = 0; size_t max_valid = ~(size_t)0; jl_method_match_t *matc = _gf_invoke_lookup(types, world, &min_valid, &max_valid); @@ -2344,6 +2344,13 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world) return (jl_value_t*)matc->method; } + +JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, size_t world, size_t *min_world, size_t *max_world) +{ + jl_method_match_t *matc = _gf_invoke_lookup(types, world, min_world, max_world); + return (jl_value_t*)matc->method; +} + static jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs); // invoke() diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 285f944d4bd8c..ca57396330478 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1486,8 +1486,8 @@ let linfo = get_linfo(Base.convert, Tuple{Type{Int64}, Int32}), @test opt.src.ssavaluetypes isa Vector{Any} @test !opt.src.inferred @test opt.mod === Base - @test opt.max_valid === Core.Compiler.get_world_counter() - @test opt.min_valid === Core.Compiler.min_world(opt.src) === UInt(1) + @test opt.valid_worlds.max_world === Core.Compiler.get_world_counter() + @test opt.valid_worlds.min_world === Core.Compiler.min_world(opt.src) === UInt(1) @test opt.nargs == 3 end