diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 3b5d01d8b1556..f9e3522eba8f6 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -479,7 +479,12 @@ function conditional_argtype(๐•ƒแตข::AbstractLattice, @nospecialize(rt), @nospe if isa(rt, InterConditional) && rt.slot == i return rt else - thentype = elsetype = tmeet(๐•ƒแตข, widenslotwrapper(argtypes[i]), fieldtype(sig, i)) + argt = widenslotwrapper(argtypes[i]) + if isvarargtype(argt) + @assert fieldcount(sig) == i + argt = unwrapva(argt) + end + thentype = elsetype = tmeet(๐•ƒแตข, argt, fieldtype(sig, i)) condval = maybe_extract_const_bool(rt) condval === true && (elsetype = Bottom) condval === false && (thentype = Bottom) @@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, # N.B. remarks are emitted within `const_prop_entry_heuristic` return nothing end - nargs::Int = method.nargs - method.isva && (nargs -= 1) - length(arginfo.argtypes) < nargs && return nothing if !const_prop_argument_heuristic(interp, arginfo, sv) add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics") return nothing end all_overridden = is_all_overridden(interp, arginfo, sv) - if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv) + if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv) add_remark!(interp, sv, "[constprop] Disabled by function heuristic") return nothing end @@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: end function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), - arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState) + arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState) argtypes = arginfo.argtypes - if nargs > 1 + if length(argtypes) > 1 ๐•ƒแตข = typeinf_lattice(interp) if istopfunction(f, :getindex) || istopfunction(f, :setindex!) arrty = argtypes[2] @@ -1274,7 +1276,7 @@ function const_prop_call(interp::AbstractInterpreter, end overridden_by_const = falses(length(argtypes)) for i = 1:length(argtypes) - if argtypes[i] !== cache_argtypes[i] + if argtypes[i] !== argtype_by_index(cache_argtypes, i) overridden_by_const[i] = true end end @@ -1349,20 +1351,6 @@ function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, end given_argtypes[i] = widenslotwrapper(argtype) end - if condargs !== nothing - given_argtypes = let condargs=condargs - va_process_argtypes(๐•ƒ, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int - # invalidate `Conditional` imposed on varargs - for (slotid, i) in condargs - if slotid โ‰ฅ last && (1 โ‰ค i โ‰ค length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise - isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i]) - end - end - end - end - else - given_argtypes = va_process_argtypes(๐•ƒ, given_argtypes, mi) - end return pick_const_args!(๐•ƒ, given_argtypes, cache_argtypes) end @@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: return CallMeta(res, exct, effects, retinfo) end -function argtype_by_index(argtypes::Vector{Any}, i::Int) +function argtype_by_index(argtypes::Vector{Any}, i::Integer) n = length(argtypes) na = argtypes[n] if isvarargtype(na) @@ -2890,12 +2878,12 @@ end struct BestguessInfo{Interp<:AbstractInterpreter} interp::Interp bestguess - nargs::Int + nargs::UInt slottypes::Vector{Any} changes::VarTable - function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int, + function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt, slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter - new{Interp}(interp, bestguess, nargs, slottypes, changes) + new{Interp}(interp, bestguess, Int(nargs), slottypes, changes) end end @@ -2970,7 +2958,7 @@ end # pick up the first "interesting" slot, convert `rt` to its `Conditional` # TODO: ideally we want `Conditional` and `InterConditional` to convey # constraints on multiple slots - for slot_id = 1:info.nargs + for slot_id = 1:Int(info.nargs) rt = bool_rt_to_conditional(rt, slot_id, info) rt isa InterConditional && break end @@ -2981,6 +2969,9 @@ end โŠ‘แตข = โŠ‘(typeinf_lattice(info.interp)) old = info.slottypes[slot_id] new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional + if isvarargtype(old) || isvarargtype(new) + return rt + end if new โŠ‘แตข old && !(old โŠ‘แตข new) if isa(rt, Const) val = rt.val diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index 2575429fbf924..8d8aec0462853 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -24,27 +24,56 @@ function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, for i = 1:length(argtypes) given_argtypes[i] = widenslotwrapper(argtypes[i]) end - given_argtypes = va_process_argtypes(๐•ƒ, given_argtypes, mi) return pick_const_args!(๐•ƒ, given_argtypes, cache_argtypes) end +function pick_const_arg(๐•ƒ::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype)) + if !is_argtype_match(๐•ƒ, given_argtype, cache_argtype, false) + # prefer the argtype we were given over the one computed from `mi` + if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) && + !โŠ(๐•ƒ, given_argtype, cache_argtype)) + # if the type information of this `PartialStruct` is less strict than + # declared method signature, narrow it down using `tmeet` + given_argtype = tmeet(๐•ƒ, given_argtype, cache_argtype) + end + else + given_argtype = cache_argtype + end + return given_argtype +end + function pick_const_args!(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any}) - nargtypes = length(given_argtypes) - @assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`" - for i = 1:nargtypes - given_argtype = given_argtypes[i] - cache_argtype = cache_argtypes[i] - if !is_argtype_match(๐•ƒ, given_argtype, cache_argtype, false) - # prefer the argtype we were given over the one computed from `mi` - if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) && - !โŠ(๐•ƒ, given_argtype, cache_argtype)) - # if the type information of this `PartialStruct` is less strict than - # declared method signature, narrow it down using `tmeet` - given_argtypes[i] = tmeet(๐•ƒ, given_argtype, cache_argtype) - end + if length(given_argtypes) == 0 || length(cache_argtypes) == 0 + return Any[] + end + given_va = given_argtypes[end] + cache_va = cache_argtypes[end] + if isvarargtype(given_va) + ngiven = length(given_argtypes) + va = unwrapva(given_va) + if isvarargtype(cache_va) + # Process the common prefix, then join + nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1) + resize!(given_argtypes, nprocessargs+1) + given_argtypes[end] = Vararg{pick_const_arg(๐•ƒ, unwrapva(given_va), unwrapva(cache_va))} else - given_argtypes[i] = cache_argtype + nprocessargs = length(cache_argtypes) + resize!(given_argtypes, nprocessargs) + end + for i = ngiven:nprocessargs + given_argtypes[i] = va end + elseif isvarargtype(cache_va) + nprocessargs = length(given_argtypes) + else + @assert length(given_argtypes) == length(cache_argtypes) + nprocessargs = length(given_argtypes) + end + for i = 1:nprocessargs + given_argtype = given_argtypes[i] + cache_argtype = argtype_by_index(cache_argtypes, i) + given_argtype = pick_const_arg(๐•ƒ, given_argtype, cache_argtype) + given_argtypes[i] = given_argtype end return given_argtypes end @@ -60,25 +89,33 @@ function is_argtype_match(๐•ƒ::AbstractLattice, end end -va_process_argtypes(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) = - va_process_argtypes(Returns(nothing), ๐•ƒ, given_argtypes, mi) -function va_process_argtypes(@specialize(va_handler!), ๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) - def = mi.def::Method - isva = def.isva - nargs = Int(def.nargs) - if isva || isvarargtype(given_argtypes[end]) - isva_given_argtypes = Vector{Any}(undef, nargs) +function va_process_argtypes(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool) + if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end])) + isva_given_argtypes = Vector{Any}(undef, Int(nargs)) for i = 1:(nargs-isva) - isva_given_argtypes[i] = argtype_by_index(given_argtypes, i) + newarg = argtype_by_index(given_argtypes, i) + if isva && has_conditional(๐•ƒ) && isa(newarg, Conditional) + if newarg.slot > (nargs-isva) + newarg = widenconditional(newarg) + end + end + isva_given_argtypes[i] = newarg end if isva if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end]) last = length(given_argtypes) else last = nargs + if has_conditional(๐•ƒ) + for i = last:length(given_argtypes) + newarg = given_argtypes[i] + if isa(newarg, Conditional) && newarg.slot > (nargs-isva) + given_argtypes[i] = widenconditional(newarg) + end + end + end end isva_given_argtypes[nargs] = tuple_tfunc(๐•ƒ, given_argtypes[last:end]) - va_handler!(isva_given_argtypes, last) end return isva_given_argtypes end @@ -87,84 +124,44 @@ function va_process_argtypes(@specialize(va_handler!), ๐•ƒ::AbstractLattice, gi end function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes)) - toplevel = method === nothing - isva = !toplevel && method.isva mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...] - nargs::Int = toplevel ? 0 : method.nargs - cache_argtypes = Vector{Any}(undef, nargs) - # First, if we're dealing with a varargs method, then we set the last element of `args` - # to the appropriate `Tuple` type or `PartialStruct` instance. - mi_argtypes_length = length(mi_argtypes) - if !toplevel && isva - if specTypes::Type == Tuple - mi_argtypes = Any[Any for i = 1:nargs] - if nargs > 1 - mi_argtypes[end] = Tuple - end - vargtype = Tuple - else - if nargs > mi_argtypes_length - va = mi_argtypes[mi_argtypes_length] - if isvarargtype(va) - new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes) - vargtype = Tuple{new_va} - else - vargtype = Tuple{} - end - else - vargtype_elements = Any[] - for i in nargs:mi_argtypes_length - p = mi_argtypes[i] - p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p) - push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes))) - end - for i in 1:length(vargtype_elements) - atyp = vargtype_elements[i] - if issingletontype(atyp) - # replace singleton types with their equivalent Const object - vargtype_elements[i] = Const(atyp.instance) - elseif isconstType(atyp) - vargtype_elements[i] = Const(atyp.parameters[1]) - end - end - vargtype = tuple_tfunc(fallback_lattice, vargtype_elements) - end - end - cache_argtypes[nargs] = vargtype - nargs -= 1 + nargtypes = length(mi_argtypes) + nargs = isa(method, Method) ? method.nargs : 0 + if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end]) + resize!(mi_argtypes, nargs) end # Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some # type info as we go (where possible). Note that if we're dealing with a varargs method, # we already handled the last element of `cache_argtypes` (and decremented `nargs` so that # we don't overwrite the result of that work here). - if mi_argtypes_length > 0 - tail_index = nargtypes = min(mi_argtypes_length, nargs) - local lastatype - for i = 1:nargtypes - atyp = mi_argtypes[i] - if i == nargtypes && isvarargtype(atyp) - atyp = unwrapva(atyp) - tail_index -= 1 - end - atyp = unwraptv(atyp) - if issingletontype(atyp) - # replace singleton types with their equivalent Const object - atyp = Const(atyp.instance) - elseif isconstType(atyp) - atyp = Const(atyp.parameters[1]) - else - atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes)) - end - i == nargtypes && (lastatype = atyp) - cache_argtypes[i] = atyp + tail_index = min(nargtypes, nargs) + local lastatype + for i = 1:nargtypes + atyp = mi_argtypes[i] + wasva = false + if i == nargtypes && isvarargtype(atyp) + wasva = true + atyp = unwrapva(atyp) end - for i = (tail_index+1):nargs - cache_argtypes[i] = lastatype + atyp = unwraptv(atyp) + if issingletontype(atyp) + # replace singleton types with their equivalent Const object + atyp = Const(atyp.instance) + elseif isconstType(atyp) + atyp = Const(atyp.parameters[1]) + else + atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes)) end - else - @assert nargs == 0 "invalid specialization of method" # wrong number of arguments + mi_argtypes[i] = atyp + if wasva + lastatype = atyp + mi_argtypes[end] = Vararg{widenconst(atyp)} + end + end + for i = (tail_index+1):(nargs-1) + mi_argtypes[i] = lastatype end - return cache_argtypes + return mi_argtypes end # eliminate free `TypeVar`s in order to make the life much easier down the road: @@ -184,7 +181,6 @@ function cache_lookup(๐•ƒ::AbstractLattice, mi::MethodInstance, given_argtypes: cache::Vector{InferenceResult}) method = mi.def::Method nargtypes = length(given_argtypes) - @assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`" for cached_result in cache cached_result.linfo === mi || @goto next_cache cache_argtypes = cached_result.argtypes diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 169d543f3249c..4e72d5036464e 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -302,6 +302,9 @@ mutable struct InferenceState bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ] bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots) argtypes = result.argtypes + + argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva) + nargtypes = length(argtypes) for i = 1:nslots argtyp = (i > nargtypes) ? Bottom : argtypes[i] @@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState) end function narguments(sv::InferenceState, include_va::Bool=true) - def = sv.linfo.def - nargs = length(sv.result.argtypes) + nargs = sv.src.nargs if !include_va - nargs -= isa(def, Method) && def.isva + nargs -= sv.src.isva end return nargs end @@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter, end method_info = MethodInfo(src) ir = inflate_ir(src, mi) - argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi) + argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva) return IRInterpretationState(interp, method_info, ir, mi, argtypes, world, codeinst.min_world, codeinst.max_world) end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 85942d3ca83b3..6c956e56b89ad 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1264,14 +1264,13 @@ end function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState) # need `ci` for the slot metadata, IR for the code svdef = sv.linfo.def - nargs = isa(svdef, Method) ? Int(svdef.nargs) : 0 @timeit "domtree 1" domtree = construct_domtree(ir) - defuse_insts = scan_slot_def_use(nargs, ci, ir.stmts.stmt) + defuse_insts = scan_slot_def_use(ci.nargs, ci, ir.stmts.stmt) ๐•ƒโ‚’ = optimizer_lattice(sv.inlining.interp) @timeit "construct_ssa" ir = construct_ssa!(ci, ir, sv, domtree, defuse_insts, ๐•ƒโ‚’) # consumes `ir` # NOTE now we have converted `ir` to the SSA form and eliminated slots # let's resize `argtypes` now and remove unnecessary types for the eliminated slots - resize!(ir.argtypes, nargs) + resize!(ir.argtypes, ci.nargs) return ir end diff --git a/base/compiler/ssair/legacy.jl b/base/compiler/ssair/legacy.jl index b45db03875801..2b0721b8d2408 100644 --- a/base/compiler/ssair/legacy.jl +++ b/base/compiler/ssair/legacy.jl @@ -10,7 +10,13 @@ the original `ci::CodeInfo` are modified. """ function inflate_ir!(ci::CodeInfo, mi::MethodInstance) sptypes = sptypes_from_meth_instance(mi) - argtypes = matching_cache_argtypes(fallback_lattice, mi) + if ci.slottypes === nothing + argtypes = va_process_argtypes(fallback_lattice, + matching_cache_argtypes(fallback_lattice, mi), + ci.nargs, ci.isva) + else + argtypes = ci.slottypes[1:ci.nargs] + end return inflate_ir!(ci, sptypes, argtypes) end function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any}) diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index f2bfa0e4c5476..90616c3d3cff1 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -33,7 +33,7 @@ function scan_entry!(result::Vector{SlotInfo}, idx::Int, @nospecialize(stmt)) end end -function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any}) +function scan_slot_def_use(nargs::Integer, ci::CodeInfo, code::Vector{Any}) nslots = length(ci.slotflags) result = SlotInfo[SlotInfo() for i = 1:nslots] # Set defs for arguments diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index d091fb8d2f5f8..ee3e93806f853 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -468,7 +468,7 @@ function adjust_effects(sv::InferenceState) # this frame is known to be safe ipo_effects = Effects(ipo_effects; nothrow=true) end - if is_inaccessiblemem_or_argmemonly(ipo_effects) && all(1:narguments(sv, #=include_va=#true)) do i::Int + if is_inaccessiblemem_or_argmemonly(ipo_effects) && all(1:narguments(sv, #=include_va=#true)) do i::UInt return is_mutation_free_argtype(sv.slottypes[i]) end ipo_effects = Effects(ipo_effects; inaccessiblememonly=ALWAYS_TRUE) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 30cb0fb0f39c5..a6f5488ef6703 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -91,9 +91,6 @@ mutable struct InferenceResult is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively ci::CodeInstance # CodeInstance if this result has been added to the cache function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector}) - def = mi.def - nargs = def isa Method ? Int(def.nargs) : 0 - @assert length(argtypes) == nargs "invalid `argtypes` for `mi`" return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing, WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false) end diff --git a/base/deprecated.jl b/base/deprecated.jl index 28382ade26161..4de675028f6cc 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -24,6 +24,7 @@ const __internal_changes_list = ( :invertedlinetables, :codeinforefactor, :miuninferredrm, + :codeinfonargs # #54341 # Add new change names above this line ) diff --git a/base/opaque_closure.jl b/base/opaque_closure.jl index b00955e7a0ca0..bd1cf4d5ae3fd 100644 --- a/base/opaque_closure.jl +++ b/base/opaque_closure.jl @@ -77,6 +77,8 @@ function Core.OpaqueClosure(ir::IRCode, @nospecialize env...; end src.slotflags = fill(zero(UInt8), nargtypes) src.slottypes = copy(ir.argtypes) + src.isva = isva + src.nargs = nargtypes src = Core.Compiler.ir_to_codeinf!(src, ir) return generate_opaque_closure(sig, Union{}, rt, src, nargs, isva, env...; kwargs...) end diff --git a/src/codegen.cpp b/src/codegen.cpp index 12594ea851829..01f147d280c1d 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2759,7 +2759,7 @@ static void jl_init_function(Function *F, const Triple &TT) F->addFnAttrs(attr); } -static bool uses_specsig(jl_value_t *sig, bool needsparams, bool va, jl_value_t *rettype, bool prefer_specsig) +static bool uses_specsig(jl_value_t *sig, bool needsparams, jl_value_t *rettype, bool prefer_specsig) { if (needsparams) return false; @@ -2769,10 +2769,8 @@ static bool uses_specsig(jl_value_t *sig, bool needsparams, bool va, jl_value_t return false; if (jl_nparams(sig) == 0) return false; - if (va) { - if (jl_is_vararg(jl_tparam(sig, jl_nparams(sig) - 1))) - return false; - } + if (jl_vararg_kind(jl_tparam(sig, jl_nparams(sig) - 1)) == JL_VARARG_UNBOUND) + return false; // not invalid, consider if specialized signature is worthwhile if (prefer_specsig) return true; @@ -2803,7 +2801,6 @@ static bool uses_specsig(jl_value_t *sig, bool needsparams, bool va, jl_value_t static std::pair uses_specsig(jl_method_instance_t *lam, jl_value_t *rettype, bool prefer_specsig) { - int va = lam->def.method->isva; jl_value_t *sig = lam->specTypes; bool needsparams = false; if (jl_is_method(lam->def.method)) { @@ -2814,7 +2811,7 @@ static std::pair uses_specsig(jl_method_instance_t *lam, jl_value_t needsparams = true; } } - return std::make_pair(uses_specsig(sig, needsparams, va, rettype, prefer_specsig), needsparams); + return std::make_pair(uses_specsig(sig, needsparams, rettype, prefer_specsig), needsparams); } @@ -5407,7 +5404,7 @@ static jl_cgval_t emit_call(jl_codectx_t &ctx, jl_expr_t *ex, jl_value_t *rt, bo jl_value_t *oc_rett = jl_tparam1(f.typ); if (jl_is_datatype(oc_argt) && jl_tupletype_length_compat(oc_argt, nargs-1)) { jl_value_t *sigtype = jl_argtype_with_function_type((jl_value_t*)f.typ, (jl_value_t*)oc_argt); - if (uses_specsig(sigtype, false, true, oc_rett, true)) { + if (uses_specsig(sigtype, false, oc_rett, true)) { JL_GC_PUSH1(&sigtype); jl_cgval_t r = emit_specsig_oc_call(ctx, f.typ, sigtype, argv, nargs); JL_GC_POP(); @@ -8094,19 +8091,15 @@ static jl_llvm_functions_t ctx.module = jl_is_method(lam->def.method) ? lam->def.method->module : lam->def.module; ctx.linfo = lam; ctx.name = TSM.getModuleUnlocked()->getModuleIdentifier().data(); - size_t nreq = 0; - int va = 0; - if (jl_is_method(lam->def.method)) { - ctx.nargs = nreq = lam->def.method->nargs; - ctx.is_opaque_closure = lam->def.method->is_for_opaque_closure; - if ((nreq > 0 && jl_is_method(lam->def.value) && lam->def.method->isva)) { - assert(nreq > 0); - nreq--; - va = 1; - } + size_t nreq = src->nargs; + int va = src->isva; + ctx.nargs = nreq; + if (va) { + assert(nreq > 0); + nreq--; } - else { - ctx.nargs = 0; + if (jl_is_method(lam->def.value)) { + ctx.is_opaque_closure = lam->def.method->is_for_opaque_closure; } ctx.nReqArgs = nreq; if (va) { @@ -8170,7 +8163,7 @@ static jl_llvm_functions_t // step 3. some variable analysis size_t i; - for (i = 0; i < nreq; i++) { + for (i = 0; i < nreq && i < vinfoslen; i++) { jl_varinfo_t &varinfo = ctx.slots[i]; varinfo.isArgument = true; jl_sym_t *argname = slot_symbol(ctx, i); @@ -8684,7 +8677,7 @@ static jl_llvm_functions_t AttrBuilder param(ctx.builder.getContext()); attrs[Arg->getArgNo()] = AttributeSet::get(Arg->getContext(), param); } - for (i = 0; i < nreq; i++) { + for (i = 0; i < nreq && i < vinfoslen; i++) { jl_sym_t *s = slot_symbol(ctx, i); jl_value_t *argType = jl_nth_slot_type(lam->specTypes, i); // TODO: jl_nth_slot_type should call jl_rewrap_unionall? @@ -9770,7 +9763,7 @@ static jl_llvm_functions_t jl_emit_oc_wrapper(orc::ThreadSafeModule &m, jl_codeg std::string funcName = get_function_name(true, false, ctx.name, ctx.emission_context.TargetTriple); jl_llvm_functions_t declarations; declarations.functionObject = "jl_f_opaque_closure_call"; - if (uses_specsig(mi->specTypes, false, true, rettype, true)) { + if (uses_specsig(mi->specTypes, false, rettype, true)) { jl_returninfo_t returninfo = get_specsig_function(ctx, M, NULL, funcName, mi->specTypes, rettype, true, JL_FEAT_TEST(ctx,gcstack_arg)); Function *gf_thunk = cast(returninfo.decl.getCallee()); jl_init_function(gf_thunk, ctx.emission_context.TargetTriple); diff --git a/src/interpreter.c b/src/interpreter.c index 87dbc0061ee23..5760386324527 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -766,8 +766,8 @@ jl_value_t *NOINLINE jl_fptr_interpret_call(jl_value_t *f, jl_value_t **args, ui } else { s->module = mi->def.method->module; - size_t defargs = mi->def.method->nargs; - int isva = mi->def.method->isva ? 1 : 0; + size_t defargs = src->nargs; + int isva = src->isva; size_t i; s->locals[0] = f; assert(isva ? nargs + 2 >= defargs : nargs + 1 == defargs); diff --git a/src/ircode.c b/src/ircode.c index 2e16d1b5b2420..50ab603ebe94b 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -458,14 +458,17 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal) } static jl_code_info_flags_t code_info_flags(uint8_t propagate_inbounds, uint8_t has_fcall, - uint8_t nospecializeinfer, uint8_t inlining, uint8_t constprop) + uint8_t nospecializeinfer, uint8_t isva, + uint8_t inlining, uint8_t constprop, uint8_t nargsmatchesmethod) { jl_code_info_flags_t flags; flags.bits.propagate_inbounds = propagate_inbounds; flags.bits.has_fcall = has_fcall; flags.bits.nospecializeinfer = nospecializeinfer; + flags.bits.isva = isva; flags.bits.inlining = inlining; flags.bits.constprop = constprop; + flags.bits.nargsmatchesmethod = nargsmatchesmethod; return flags; } @@ -821,7 +824,7 @@ static int codelocs_nstmts(jl_string_t *cl) JL_NOTSAFEPOINT } #endif -#define IR_DATASIZE_FLAGS sizeof(uint8_t) +#define IR_DATASIZE_FLAGS sizeof(uint16_t) #define IR_DATASIZE_PURITY sizeof(uint16_t) #define IR_DATASIZE_INLINING_COST sizeof(uint16_t) #define IR_DATASIZE_NSLOTS sizeof(int32_t) @@ -833,6 +836,16 @@ typedef enum { ir_offset_slotflags = 0 + IR_DATASIZE_FLAGS + IR_DATASIZE_PURITY + IR_DATASIZE_INLINING_COST + IR_DATASIZE_NSLOTS } ir_offset; +// static_assert is technically a declaration, so shenanigans are required to +// open an inline declaration context. `sizeof` is the traditional way to do this, +// but this pattern is illegal in C++, which some compilers warn about, so use +// `offsetof` instead. +#define declaration_context(what) (void)offsetof(struct{what; int dummy_;}, dummy_) + +// Checks (at compile time) that sizeof(data) == macro_size +#define checked_size(data, macro_size) \ + (declaration_context(static_assert(sizeof(data) == macro_size, #macro_size " does not match written size")), data) + JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) { JL_TIMING(AST_COMPRESS, AST_COMPRESS); @@ -859,24 +872,28 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) 1 }; + uint8_t nargsmatchesmethod = code->nargs == m->nargs; jl_code_info_flags_t flags = code_info_flags(code->propagate_inbounds, code->has_fcall, - code->nospecializeinfer, code->inlining, code->constprop); - write_uint8(s.s, flags.packed); - static_assert(sizeof(flags.packed) == IR_DATASIZE_FLAGS, "ir_datasize_flags is mismatched with the actual size"); - write_uint16(s.s, code->purity.bits); - static_assert(sizeof(code->purity.bits) == IR_DATASIZE_PURITY, "ir_datasize_purity is mismatched with the actual size"); - write_uint16(s.s, code->inlining_cost); - static_assert(sizeof(code->inlining_cost) == IR_DATASIZE_INLINING_COST, "ir_datasize_inlining_cost is mismatched with the actual size"); - - int32_t nslots = jl_array_nrows(code->slotflags); + code->nospecializeinfer, code->isva, + code->inlining, code->constprop, + nargsmatchesmethod); + write_uint16(s.s, checked_size(flags.packed, IR_DATASIZE_FLAGS)); + write_uint16(s.s, checked_size(code->purity.bits, IR_DATASIZE_PURITY)); + write_uint16(s.s, checked_size(code->inlining_cost, IR_DATASIZE_INLINING_COST)); + + size_t nslots = jl_array_nrows(code->slotflags); assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions - write_int32(s.s, nslots); - static_assert(sizeof(nslots) == IR_DATASIZE_NSLOTS, "ir_datasize_nslots is mismatched with the actual size"); + write_int32(s.s, checked_size((int32_t)nslots, IR_DATASIZE_NSLOTS)); ios_write(s.s, jl_array_data(code->slotflags, const char), nslots); // N.B.: The layout of everything before this point is explicitly referenced // by the various jl_ir_ accessors. Make sure to adjust those if you change // the data layout. + if (!nargsmatchesmethod) { + size_t nargs = code->nargs; + assert(nargs < INT32_MAX); + write_int32(s.s, (int32_t)nargs); + } for (i = 0; i < 5; i++) { int copy = 1; @@ -892,6 +909,8 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) if (m->is_for_opaque_closure) jl_encode_value_(&s, code->slottypes, 1); + // Slotnames. For regular methods, we require that m->slot_syms matches the + // CodeInfo's slotnames, so we do not need to save it here. if (m->generator) // can't optimize generated functions jl_encode_value_(&s, (jl_value_t*)jl_compress_argnames(code->slotnames), 1); @@ -937,19 +956,27 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t jl_code_info_t *code = jl_new_code_info_uninit(); jl_code_info_flags_t flags; - flags.packed = read_uint8(s.s); + flags.packed = read_uint16(s.s); code->inlining = flags.bits.inlining; code->constprop = flags.bits.constprop; code->propagate_inbounds = flags.bits.propagate_inbounds; code->has_fcall = flags.bits.has_fcall; code->nospecializeinfer = flags.bits.nospecializeinfer; + code->isva = flags.bits.isva; code->purity.bits = read_uint16(s.s); code->inlining_cost = read_uint16(s.s); - size_t nslots = read_int32(&src); + + size_t nslots = read_int32(s.s); code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); ios_readall(s.s, jl_array_data(code->slotflags, char), nslots); + if (flags.bits.nargsmatchesmethod) { + code->nargs = m->nargs; + } else { + code->nargs = read_int32(s.s); + } + for (i = 0; i < 5; i++) { if (i == 1) // skip debuginfo continue; diff --git a/src/jltypes.c b/src/jltypes.c index a3369c42be834..da69686f60695 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3163,7 +3163,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_info_type = jl_new_datatype(jl_symbol("CodeInfo"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(19, + jl_perm_symsvec(21, "code", "debuginfo", "ssavaluetypes", @@ -3176,14 +3176,16 @@ void jl_init_types(void) JL_GC_DISABLED "edges", "min_world", "max_world", + "nargs", "propagate_inbounds", "has_fcall", "nospecializeinfer", + "isva", "inlining", "constprop", "purity", "inlining_cost"), - jl_svec(19, + jl_svec(21, jl_array_any_type, jl_debuginfo_type, jl_any_type, @@ -3196,6 +3198,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_any_type, jl_ulong_type, jl_ulong_type, + jl_ulong_type, + jl_bool_type, jl_bool_type, jl_bool_type, jl_bool_type, @@ -3204,7 +3208,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_uint16_type, jl_uint16_type), jl_emptysvec, - 0, 1, 19); + 0, 1, 21); jl_method_type = jl_new_datatype(jl_symbol("Method"), core, diff --git a/src/julia.h b/src/julia.h index fa05c12d77d5a..0d46f15776610 100644 --- a/src/julia.h +++ b/src/julia.h @@ -320,11 +320,13 @@ typedef struct _jl_code_info_t { jl_value_t *edges; // forward edges to method instances that must be invalidated (for copying to debuginfo) size_t min_world; size_t max_world; + size_t nargs; // various boolean properties: uint8_t propagate_inbounds; uint8_t has_fcall; uint8_t nospecializeinfer; + uint8_t isva; // uint8 settings uint8_t inlining; // 0 = default; 1 = @inline; 2 = @noinline uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none diff --git a/src/julia_internal.h b/src/julia_internal.h index 1b1ac4aea2b7b..0e87eb5f07fe7 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -631,16 +631,18 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO // -- helper types -- // typedef struct { - uint8_t propagate_inbounds:1; - uint8_t has_fcall:1; - uint8_t nospecializeinfer:1; - uint8_t inlining:2; // 0 = use heuristic; 1 = aggressive; 2 = none - uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none + uint16_t propagate_inbounds:1; + uint16_t has_fcall:1; + uint16_t nospecializeinfer:1; + uint16_t isva:1; + uint16_t nargsmatchesmethod:1; + uint16_t inlining:2; // 0 = use heuristic; 1 = aggressive; 2 = none + uint16_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none } jl_code_info_flags_bitfield_t; typedef union { jl_code_info_flags_bitfield_t bits; - uint8_t packed; + uint16_t packed; } jl_code_info_flags_t; // -- functions -- // diff --git a/src/method.c b/src/method.c index b74a82e1aa505..59c24671f46f3 100644 --- a/src/method.c +++ b/src/method.c @@ -420,6 +420,10 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ir) jl_code_info_t *li = NULL; JL_GC_PUSH1(&li); li = jl_new_code_info_uninit(); + + jl_expr_t *arglist = (jl_expr_t*)jl_exprarg(ir, 0); + li->nargs = jl_array_len(arglist); + assert(jl_is_expr(ir)); jl_expr_t *bodyex = (jl_expr_t*)jl_exprarg(ir, 2); @@ -642,6 +646,8 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) src->constprop = 0; src->inlining = 0; src->purity.bits = 0; + src->nargs = 0; + src->isva = 0; src->inlining_cost = UINT16_MAX; return src; } @@ -770,6 +776,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t } jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator."); } + // TODO: This should ideally be in the lambda expression, + // but currently our isva determination is non-syntactic + func->isva = def->isva; } // If this generated function has an opaque closure, cache it for @@ -953,6 +962,8 @@ JL_DLLEXPORT void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) jl_array_ptr_set(copy, i, st); } src = jl_copy_code_info(src); + src->isva = m->isva; // TODO: It would be nice to reverse this + assert(m->nargs == src->nargs); src->code = copy; jl_gc_wb(src, copy); m->slot_syms = jl_compress_argnames(src->slotnames); diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index ef8048d925cf8..c8b5314fe719d 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -80,7 +80,7 @@ const TAGS = Any[ const NTAGS = length(TAGS) @assert NTAGS == 255 -const ser_version = 28 # do not make changes without bumping the version #! +const ser_version = 29 # do not make changes without bumping the version #! format_version(::AbstractSerializer) = ser_version format_version(s::Serializer) = s.version @@ -1094,6 +1094,10 @@ function deserialize(s::AbstractSerializer, ::Type{Method}) if template !== nothing # TODO: compress template template = template::CodeInfo + if format_version(s) < 29 + template.nargs = nargs + template.isva = isva + end meth.source = template meth.debuginfo = template.debuginfo if !@isdefined(slot_syms) @@ -1259,6 +1263,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) ci.inlining_cost = inlining_cost end end + if format_version(s) >= 29 + ci.nargs = deserialize(s) + end ci.propagate_inbounds = deserialize(s) if format_version(s) < 23 deserialize(s) # `pure` field has been removed @@ -1269,6 +1276,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) if format_version(s) >= 24 ci.nospecializeinfer = deserialize(s)::Bool end + if format_version(s) >= 29 + ci.isva = deserialize(s)::Bool + end if format_version(s) >= 21 ci.inlining = deserialize(s)::UInt8 end diff --git a/stdlib/Test/src/precompile.jl b/stdlib/Test/src/precompile.jl index 1e53033a09143..04907f8425440 100644 --- a/stdlib/Test/src/precompile.jl +++ b/stdlib/Test/src/precompile.jl @@ -1,5 +1,6 @@ if Base.generating_output() - redirect_stdout(devnull) do +let + function example_payload() @testset "example" begin @test 1 == 1 @test_throws ErrorException error() @@ -8,4 +9,7 @@ if Base.generating_output() @test 1 โ‰ˆ 1.0000000000000001 end end + + redirect_stdout(example_payload, devnull) +end end diff --git a/test/compiler/contextual.jl b/test/compiler/contextual.jl index e7b1d9a9355f7..fca09558fd09d 100644 --- a/test/compiler/contextual.jl +++ b/test/compiler/contextual.jl @@ -96,6 +96,9 @@ module MiniCassette transform!(mi, code_info, length(args), match.sparams) # TODO: this is mandatory: code_info.min_world = max(code_info.min_world, min_world[]) # TODO: this is mandatory: code_info.max_world = min(code_info.max_world, max_world[]) + # Match the generator, since that's what our transform! does + code_info.nargs = 4 + code_info.isva = true return code_info end @@ -225,3 +228,25 @@ end end @test_throws "oh no" doit49715(sin, Tuple{Int}) + +# Test that the CodeInfo returned from generated function need not match the +# generator. +function overdubbee54341(a, b) + a + b +end +const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1] + +function overdub_generator54341(world::UInt, source::LineNumberNode, args...) + if length(args) != 2 + :(error("Wrong number of arguments")) + else + return copy(overdubee_codeinfo54341) + end +end + +@eval function overdub54341(args...) + $(Expr(:meta, :generated, overdub_generator54341)) + $(Expr(:meta, :generated_only)) +end + +@test overdub54341(1, 2) == 3 diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 870eeaaf0687b..3b443801af5d6 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1937,6 +1937,8 @@ function f24852_kernel_cinfo(world::UInt, source, fsig::Type) end pushfirst!(code_info.slotnames, Symbol("#self#")) pushfirst!(code_info.slotflags, 0x00) + code_info.nargs = 4 + code_info.isva = false # TODO: this is mandatory: code_info.min_world = max(code_info.min_world, min_world[]) # TODO: this is mandatory: code_info.max_world = min(code_info.max_world, max_world[]) return match.method, code_info @@ -4487,6 +4489,8 @@ let } where Bound<:Integer argtypes = Core.Compiler.most_general_argtypes(method, specTypes) popfirst!(argtypes) + # N.B.: `argtypes` do not have va processing applied yet + @test length(argtypes) == 12 @test argtypes[1] == Integer @test argtypes[2] == Integer @test argtypes[3] == Type{Bound} where Bound<:Integer @@ -4497,7 +4501,8 @@ let @test argtypes[8] == Any @test argtypes[9] == Union{Nothing,Bound} where Bound<:Integer @test argtypes[10] == Any - @test argtypes[11] == Tuple{Integer,Integer} + @test argtypes[11] == Integer + @test argtypes[12] == Integer end # make sure not to call `widenconst` on `TypeofVararg` objects @@ -5660,6 +5665,8 @@ function gen_tuin_from_arg(world::UInt, source, _, _) ReturnNode(true), ]; slottypes=Any[Any, Bool]) ci.slotnames = Symbol[:var"#self#", :def] + ci.nargs = 2 + ci.isva = false ci end @@ -5691,6 +5698,8 @@ function gen_infinite_loop_ssa_generator(world::UInt, source, _) ReturnNode(SSAValue(2)) ]; slottypes=Any[Any]) ci.slotnames = Symbol[:var"#self#"] + ci.nargs = 1 + ci.isva = false ci end @@ -5728,3 +5737,14 @@ end # fieldcount on `Tuple` should constant fold, even though `.fields` not const @test fully_eliminated(Base.fieldcount, Tuple{Type{Tuple{Nothing, Int, Int}}}) + +# Vararg-constprop regression from MutableArithmetics (#54341) +global SIDE_EFFECT54341::Int +function foo54341(a, b, c, d, args...) + # Side effect to force constprop rather than semi-concrete + global SIDE_EFFECT54341 = a + b + c + d + return SIDE_EFFECT54341 +end +bar54341(args...) = foo54341(4, args...) + +@test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int diff --git a/test/compiler/ssair.jl b/test/compiler/ssair.jl index 9134abd7d08d9..4e789da7815c9 100644 --- a/test/compiler/ssair.jl +++ b/test/compiler/ssair.jl @@ -747,6 +747,8 @@ function gen_unreachable_phinode_edge1(world::UInt, source, args...) ReturnNode(SSAValue(4)) ]; slottypes=Any[Any,Int,Int]) ci.slotnames = Symbol[:var"#self#", :x, :y] + ci.nargs = 3 + ci.isva = false return ci end @eval function f_unreachable_phinode_edge1(x, y) @@ -769,6 +771,8 @@ function gen_unreachable_phinode_edge2(world::UInt, source, args...) ReturnNode(SSAValue(4)) ]; slottypes=Any[Any,Int,Int]) ci.slotnames = Symbol[:var"#self#", :x, :y] + ci.nargs = 3 + ci.isva = false return ci end @eval function f_unreachable_phinode_edge2(x, y) @@ -791,6 +795,8 @@ function gen_must_throw_phinode_edge(world::UInt, source, _) ReturnNode(SSAValue(4)) ]; slottypes=Any[Any]) ci.slotnames = Symbol[:var"#self#"] + ci.nargs = 1 + ci.isva = false return ci end @eval function f_must_throw_phinode_edge() diff --git a/test/opaque_closure.jl b/test/opaque_closure.jl index 490ff282b7a53..59f5796504d24 100644 --- a/test/opaque_closure.jl +++ b/test/opaque_closure.jl @@ -189,6 +189,8 @@ let ci = @code_lowered const_int() cig.slotnames = Symbol[Symbol("#self#")] cig.slottypes = Any[Any] cig.slotflags = UInt8[0x00] + cig.nargs = 1 + cig.isva = false return cig end end