diff --git a/base/boot.jl b/base/boot.jl
index a88555abc9183..922f5a3098b9a 100644
--- a/base/boot.jl
+++ b/base/boot.jl
@@ -456,6 +456,7 @@ struct GeneratedFunctionStub
     spnames::Union{Nothing, Array{Any,1}}
     line::Int
     file::Symbol
+    expand_early::Bool
 end
 
 # invoke and wrap the results of @generated
diff --git a/base/inference.jl b/base/inference.jl
index 2a15780329e04..169d2d9a2898f 100644
--- a/base/inference.jl
+++ b/base/inference.jl
@@ -371,10 +371,16 @@ function _validate(linfo::MethodInstance, src::CodeInfo, kind::String)
 end
 
+# Call `jl_code_for_staged` on `li` and return the resulting `CodeInfo`,
+# or `nothing` if the call (or the `CodeInfo` type assertion) throws.
 function get_staged(li::MethodInstance)
-    return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
+    try
+        # user code might throw errors – ignore them
+        return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
+    catch
+        return nothing
+    end
 end
 
-
 mutable struct OptimizationState
     linfo::MethodInstance
     vararg_type_container #::Type
@@ -472,12 +478,7 @@ end
 function retrieve_code_info(linfo::MethodInstance)
     m = linfo.def::Method
     if isdefined(m, :generator)
-        try
-            # user code might throw errors – ignore them
-            c = get_staged(linfo)
-        catch
-            return nothing
-        end
+        return get_staged(linfo)
     else
         # TODO: post-inference see if we can swap back to the original arrays?
         if isa(m.source, Array{UInt8,1})
@@ -489,6 +490,45 @@ function retrieve_code_info(linfo::MethodInstance)
     return c
 end
 
+# TODO: Use these functions instead of directly manipulating
+# the "actual" method for appropriate places in inference (see #24676)
+
+# Return the method that inference heuristics should use for `cinfo`.
+# If `cinfo` is a `CodeInfo` whose `signature_for_inference_heuristics` is a
+# well-formed `svec(ftype, argtypes, world)` that matches exactly one method,
+# that method is returned; in every other case `default` is returned.
+function method_for_inference_heuristics(cinfo, default)
+    if isa(cinfo, CodeInfo)
+        # appropriate format for `sig` is svec(ftype, argtypes, world)
+        sig = cinfo.signature_for_inference_heuristics
+        # the field can be set by generator code, so validate its shape before the
+        # lookup; NOTE(review): assumes `_methods` requires a `UInt` world – confirm
+        if isa(sig, SimpleVector) && length(sig) == 3 && isa(sig[3], UInt)
+            methods = _methods(sig[1], sig[2], -1, sig[3])
+            if length(methods) == 1
+                _, _, m = methods[]
+                if isa(m, Method)
+                    return m
+                end
+            end
+        end
+    end
+    return default
+end
+
+# Variant taking the "actual" `method`: if it has a generator with `expand_early`
+# set, obtain a `MethodInstance` via `code_for_method`, stage it with `get_staged`
+# and apply the two-argument method above; otherwise return `method` unchanged.
+function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
+    if isdefined(method, :generator) && method.generator.expand_early
+        method_instance = code_for_method(method, sig, sparams, world, false)
+        if isa(method_instance, MethodInstance)
+            return method_for_inference_heuristics(get_staged(method_instance), method)
+        end
+    end
+    return method
+end
+
 @inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id : (s::TypedSlot).id # using a function to ensure we can infer this
 
 # avoid cycle due to over-specializing `any` when used by inference
@@ -3396,6 +3436,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
                     method = linfo.def::Method
                     tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
                     tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
+                    tree.signature_for_inference_heuristics = nothing
                     tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ]
                     tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
                     tree.slottypes = nothing
diff --git a/src/jltypes.c b/src/jltypes.c
index d8830d34c7589..b0b4b0dc73e9e 100644
--- a/src/jltypes.c
+++ b/src/jltypes.c
@@ -2046,8 +2046,9 @@ void jl_init_types(void)
     jl_code_info_type =
         jl_new_datatype(jl_symbol("CodeInfo"), core,
                         jl_any_type, jl_emptysvec,
-                        jl_perm_symsvec(9,
+                        jl_perm_symsvec(10,
                             "code",
+                            "signature_for_inference_heuristics",
                             "slottypes",
                             "ssavaluetypes",
                             "slotflags",
@@ -2056,17 +2057,18 @@ void jl_init_types(void)
                             "inlineable",
                             "propagate_inbounds",
                             "pure"),
-                        jl_svec(9,
+                        jl_svec(10,
                             jl_array_any_type,
                             jl_any_type,
                             jl_any_type,
+                            jl_any_type,
                             jl_array_uint8_type,
                             jl_array_any_type,
                             jl_bool_type,
                             jl_bool_type,
                             jl_bool_type,
                             jl_bool_type),
-                        0, 1, 9);
+                        0, 1, 10);
 
     jl_method_type =
         jl_new_datatype(jl_symbol("Method"), core,
diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm
index 5596d8ae0b569..2c91d37c76f1d 100644
--- a/src/julia-syntax.scm
+++ b/src/julia-syntax.scm
@@ -362,7 +362,8 @@
                                                   'nothing
                                                   (cons 'list (map car sparams)))
                                              ,(if (null? loc) 0 (cadr loc))
-                                             (inert ,(if (null? loc) 'none (caddr loc))))))))
+                                             (inert ,(if (null? loc) 'none (caddr loc)))
+                                             false)))))
                    (list gf))
                  '()))
         (types (llist-types argl))
diff --git a/src/julia.h b/src/julia.h
index 2ff7b4eb8ccbd..3a207c61f8d5d 100644
--- a/src/julia.h
+++ b/src/julia.h
@@ -229,6 +229,7 @@ typedef struct _jl_llvm_functions_t {
 // This type describes a single function body
 typedef struct _jl_code_info_t {
     jl_array_t *code;  // Any array of statements
+    jl_value_t *signature_for_inference_heuristics; // optional svec(ftype, argtypes, world) used to pick the method for inference heuristics (or `nothing`)
     jl_value_t *slottypes; // types of variable slots (or `nothing`)
     jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
     jl_array_t *slotflags;  // local var bit flags
diff --git a/src/method.c b/src/method.c
index e956abdc76b55..77f8d497a2621 100644
--- a/src/method.c
+++ b/src/method.c
@@ -187,6 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
             jl_array_del_end(meta, na - ins);
         }
     }
+    li->signature_for_inference_heuristics = jl_nothing;
     jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
     jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
     size_t nslots = jl_array_len(vis);
@@ -255,6 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
         (jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
                                        jl_code_info_type);
     src->code = NULL;
+    src->signature_for_inference_heuristics = NULL;
     src->slotnames = NULL;
     src->slotflags = NULL;
     src->slottypes = NULL;
@@ -442,8 +444,8 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
             else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) {
                 m->generator = NULL;
                 jl_value_t *gexpr = jl_exprarg(st, 1);
-                if (jl_expr_nargs(gexpr) == 6) {
-                    // expects (new (core GeneratedFunctionStub) funcname argnames sp line file)
+                if (jl_expr_nargs(gexpr) == 7) {
+                    // expects (new (core GeneratedFunctionStub) funcname argnames sp line file expandearly)
                     jl_value_t *funcname = jl_exprarg(gexpr, 1);
                     assert(jl_is_symbol(funcname));
                     if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) {
diff --git a/src/toplevel.c b/src/toplevel.c
index 78c2e9328332d..7dbf75f93bdf6 100644
--- a/src/toplevel.c
+++ b/src/toplevel.c
@@ -576,6 +576,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
     jl_gc_wb(src, src->slotflags);
     src->ssavaluetypes = jl_box_long(0);
     jl_gc_wb(src, src->ssavaluetypes);
+    src->signature_for_inference_heuristics = jl_nothing;
 
     JL_GC_POP();
     return src;
diff --git a/test/inference.jl b/test/inference.jl
index cba504d7cf8ed..d407db81a77ee 100644
--- a/test/inference.jl
+++ b/test/inference.jl
@@ -1317,3 +1317,86 @@ bar_22708(x) = f_22708(x)
 
 @test bar_22708(1) == "x"
 
+# mechanism for spoofing work-limiting heuristics and early generator expansion (#24852)
+
+# build the `:meta` expression carrying a `GeneratedFunctionStub` for generator `gen`
+function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, line, file, expand_early)
+    stub = Expr(:new, Core.GeneratedFunctionStub, gen, args, params, line, file, expand_early)
+    return Expr(:meta, :generated, stub)
+end
+
+f24852_kernel(x, y) = x * y
+
+# return the first `f24852_kernel` method matching argument types `x` and `y`
+# together with its uncompressed `CodeInfo`
+function f24852_kernel_cinfo(x, y)
+    sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
+    code_info = Base.uncompressed_ast(method)
+    body = Expr(:block, code_info.code...)
+    Base.Core.Inference.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
+    return method, code_info
+end
+
+# generator returning the kernel's `CodeInfo` without a heuristics signature
+function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
+    _, code_info = f24852_kernel_cinfo(x, y)
+    return code_info
+end
+
+# generator returning the kernel's `CodeInfo` with a heuristics signature attached
+function f24852_gen_cinfo_inflated(X, Y, f, x, y)
+    _, code_info = f24852_kernel_cinfo(x, y)
+    code_info.signature_for_inference_heuristics = Core.Inference.svec(f, (x, y), typemax(UInt))
+    return code_info
+end
+
+# generator returning an ordinary expression instead of a `CodeInfo`
+function f24852_gen_expr(X, Y, f, x, y)
+    return :(f24852_kernel(x::$X, y::$Y))
+end
+
+# generated functions whose stubs pass `expand_early = false`
+@eval begin
+    function f24852_late_expr(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
+    end
+    function f24852_late_inflated(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
+    end
+    function f24852_late_uninflated(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
+    end
+end
+
+# the same generated functions with `expand_early = true`
+@eval begin
+    function f24852_early_expr(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
+    end
+    function f24852_early_inflated(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
+    end
+    function f24852_early_uninflated(x::X, y::Y) where {X, Y}
+        $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
+                          Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
+    end
+end
+
+x, y = rand(), rand()
+result = f24852_kernel(x, y)
+
+@test result === f24852_late_expr(x, y)
+@test result === f24852_late_uninflated(x, y)
+@test result === f24852_late_inflated(x, y)
+
+@test result === f24852_early_expr(x, y)
+@test result === f24852_early_uninflated(x, y)
+@test result === f24852_early_inflated(x, y)
+
+# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
+# can be used to tighten up some inference result.