Skip to content

Commit

Permalink
generators: expose caller world to GeneratedFunctionStub
Browse files Browse the repository at this point in the history
Expose the demanded world to the GeneratedFunctionStub caller, for users
such as Cassette. If this argument is used, the uesr must return a
CodeInfo with the min/max world field set correctly.

Make the internal representation a tiny bit more compact also, removing
a little bit of unnecessary metadata.

Remove support for returning `body isa CodeInfo` via this wrapper, since
it is impossible to return a correct object via the
GeneratedFunctionStub since it strips off the world argument, which is
required for it to do so. This also removes support for not inferring
these fully (expand_early=false).

Also answer method lookup queries about the future correctly, by
refusing to answer them. This helps keeps execution correct as methods
get added to the system asynchronously.

This reverts "fix #25678: return matters for generated functions
(#40778)" (commit 92c84bf), since this
is no longer sensible to return here anyways, so it is no longer
permitted or supported by this macro.

Fixes various issues where we failed to specify the correct world.
  • Loading branch information
vtjnash authored and maleadt committed Feb 21, 2023
1 parent 0975906 commit a918e93
Show file tree
Hide file tree
Showing 31 changed files with 203 additions and 203 deletions.
2 changes: 1 addition & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules
for match = _methods(+, (Int, Int), -1, get_world_counter())
m = match.method
delete!(push!(Set{Method}(), m), m)
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match)))
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match), typemax(UInt)))

empty!(Set())
push!(push!(Set{Union{GlobalRef,Symbol}}(), :two), GlobalRef(Base, :two))
Expand Down
27 changes: 12 additions & 15 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,28 +590,25 @@ println(@nospecialize a...) = println(stdout, a...)

struct GeneratedFunctionStub
gen
argnames::Array{Any,1}
spnames::Union{Nothing, Array{Any,1}}
line::Int
file::Symbol
expand_early::Bool
argnames::SimpleVector
spnames::SimpleVector
end

# invoke and wrap the results of @generated
function (g::GeneratedFunctionStub)(@nospecialize args...)
# invoke and wrap the results of @generated expression
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
# args is (spvals..., argtypes...)
body = g.gen(args...)
if body isa CodeInfo
return body
end
lam = Expr(:lambda, g.argnames,
Expr(Symbol("scope-block"),
file = source.file
file isa Symbol || (file = :none)
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
Expr(:var"scope-block",
Expr(:block,
LineNumberNode(g.line, g.file),
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
source,
Expr(:meta, :push_loc, file, :var"@generated body"),
Expr(:return, body),
Expr(:meta, :pop_loc))))
spnames = g.spnames
if spnames === nothing
if spnames === svec()
return lam
else
return Expr(Symbol("with-static-parameters"), lam, spnames...)
Expand Down
13 changes: 7 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
break
end
topmost === nothing || continue
if edge_matches_sv(infstate, method, sig, sparams, hardlimit, sv)
if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv)
topmost = infstate
edgecycle = true
end
Expand Down Expand Up @@ -677,12 +677,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
end

function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
# The `method_for_inference_heuristics` will expand the given method's generator if
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
# access it directly instead (to avoid regeneration).
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
world = get_world_counter(interp)
callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing}

inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing)
Expand Down Expand Up @@ -719,11 +720,11 @@ function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(si
end

# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
if isdefined(method, :generator) && !(method.generator isa Core.GeneratedFunctionStub) && may_invoke_generator(method, sig, sparams)
method_instance = specialize_method(method, sig, sparams)
if isa(method_instance, MethodInstance)
cinfo = get_staged(method_instance)
cinfo = get_staged(method_instance, world)
if isa(cinfo, CodeInfo)
method2 = cinfo.method_for_inference_limit_heuristics
if method2 isa Method
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let interp = NativeInterpreter()
else
tt = Tuple{typeof(f), Vararg{Any}}
end
for m in _methods_by_ftype(tt, 10, typemax(UInt))::Vector
for m in _methods_by_ftype(tt, 10, get_world_counter())::Vector
# remove any TypeVars from the intersection
m = m::MethodMatch
typ = Any[m.spec_types.parameters...]
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
src = retrieve_code_info(result.linfo)
world = get_world_counter(interp)
src = retrieve_code_info(result.linfo, world)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
return InferenceState(result, src, cache, interp)
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
end
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
src = retrieve_code_info(linfo)
world = get_world_counter(interp)
src = retrieve_code_info(linfo, world)
src === nothing && return nothing
return OptimizationState(linfo, src, params, interp)
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
end
end
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_sysimg()
return retrieve_code_info(mi)
return retrieve_code_info(mi, get_world_counter(interp))
end
lock_mi_inference(interp, mi)
result = InferenceResult(mi, typeinf_lattice(interp))
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ struct NativeInterpreter <: AbstractInterpreter
cache = Vector{InferenceResult}() # Initially empty cache

# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age
# we cap it to the current world age for correctness
if world == typemax(UInt)
world = get_world_counter()
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,23 @@ end
invoke_api(li::CodeInstance) = ccall(:jl_invoke_api, Cint, (Any,), li)
use_const_api(li::CodeInstance) = invoke_api(li) == 2

function get_staged(mi::MethodInstance)
function get_staged(mi::MethodInstance, world::UInt)
may_invoke_generator(mi) || return nothing
try
# user code might throw errors – ignore them
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
return ci
catch
return nothing
end
end

function retrieve_code_info(linfo::MethodInstance)
function retrieve_code_info(linfo::MethodInstance, world::UInt)
m = linfo.def::Method
c = nothing
if isdefined(m, :generator)
# user code might throw errors – ignore them
c = get_staged(linfo)
c = get_staged(linfo, world)
end
if c === nothing && isdefined(m, :source)
src = m.source
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,14 @@ end

"""
validate_code!(errors::Vector{InvalidCodeError}, mi::MethodInstance,
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
c::Union{Nothing,CodeInfo})
Validate `mi`, logging any violation by pushing an `InvalidCodeError` into `errors`.
If `isa(c, CodeInfo)`, also call `validate_code!(errors, c)`. It is assumed that `c` is
the `CodeInfo` instance associated with `mi`.
a `CodeInfo` instance associated with `mi`.
"""
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance,
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance, c::Union{Nothing,CodeInfo})
is_top_level = mi.def isa Module
if is_top_level
mnargs = 0
Expand Down
5 changes: 1 addition & 4 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -962,10 +962,7 @@ macro generated(f)
Expr(:block,
lno,
Expr(:if, Expr(:generated),
# https://github.com/JuliaLang/julia/issues/25678
Expr(:block,
:(local $tmp = $body),
:(if $tmp isa $(GlobalRef(Core, :CodeInfo)); return $tmp; else $tmp; end)),
body,
Expr(:block,
Expr(:meta, :generated_only),
Expr(:return, nothing))))))
Expand Down
38 changes: 24 additions & 14 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -961,10 +961,11 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
if debuginfo !== :source && debuginfo !== :none
throw(ArgumentError("'debuginfo' must be either :source or :none"))
end
return map(method_instances(f, t)) do m
world = get_world_counter()
return map(method_instances(f, t, world)) do m
if generated && hasgenerator(m)
if may_invoke_generator(m)
return ccall(:jl_code_for_staged, Any, (Any,), m)::CodeInfo
return ccall(:jl_code_for_staged, Any, (Any, UInt), m, world)::CodeInfo
else
error("Could not expand generator for `@generated` method ", m, ". ",
"This can happen if the provided argument types (", t, ") are ",
Expand Down Expand Up @@ -1053,6 +1054,8 @@ methods(@nospecialize(f), @nospecialize(t), mod::Module) = methods(f, t, (mod,))
function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
tt = signature_type(f, t)
world = get_world_counter()
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
min = RefValue{UInt}(typemin(UInt))
max = RefValue{UInt}(typemax(UInt))
ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector
Expand Down Expand Up @@ -1125,9 +1128,11 @@ _uncompressed_ir(ci::Core.CodeInstance, s::Array{UInt8,1}) = ccall(:jl_uncompres
const uncompressed_ast = uncompressed_ir
const _uncompressed_ast = _uncompressed_ir

function method_instances(@nospecialize(f), @nospecialize(t), world::UInt=get_world_counter())
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
tt = signature_type(f, t)
results = Core.MethodInstance[]
# this make a better error message than the typeassert that follows
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
for match in _methods_by_ftype(tt, -1, world)::Vector
instance = Core.Compiler.specialize_method(match)
push!(results, instance)
Expand Down Expand Up @@ -1198,20 +1203,22 @@ function may_invoke_generator(method::Method, @nospecialize(atype), sparams::Sim
# generator only has one method
generator = method.generator
isa(generator, Core.GeneratedFunctionStub) || return false
gen_mthds = methods(generator.gen)::MethodList
length(gen_mthds) == 1 || return false
gen_mthds = _methods_by_ftype(Tuple{typeof(generator.gen), Vararg{Any}}, 1, method.primary_world)
(gen_mthds isa Vector && length(gen_mthds) == 1) || return false

generator_method = first(gen_mthds)
generator_method = first(gen_mthds).method
nsparams = length(sparams)
isdefined(generator_method, :source) || return false
code = generator_method.source
nslots = ccall(:jl_ir_nslots, Int, (Any,), code)
at = unwrap_unionall(atype)::DataType
at = unwrap_unionall(atype)
at isa DataType || return false
(nslots >= 1 + length(sparams) + length(at.parameters)) || return false

firstarg = 1
for i = 1:nsparams
if isa(sparams[i], TypeVar)
if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0
if (ast_slotflag(code, firstarg + i) & SLOT_USED) != 0
return false
end
end
Expand All @@ -1220,15 +1227,15 @@ function may_invoke_generator(method::Method, @nospecialize(atype), sparams::Sim
non_va_args = method.isva ? nargs - 1 : nargs
for i = 1:non_va_args
if !isdispatchelem(at.parameters[i])
if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0
if (ast_slotflag(code, firstarg + i + nsparams) & SLOT_USED) != 0
return false
end
end
end
if method.isva
# If the va argument is used, we need to ensure that all arguments that
# contribute to the va tuple are dispatchelemes
if (ast_slotflag(code, 1 + nargs + nsparams) & SLOT_USED) != 0
if (ast_slotflag(code, firstarg + nargs + nsparams) & SLOT_USED) != 0
for i = (non_va_args+1):length(at.parameters)
if !isdispatchelem(at.parameters[i])
return false
Expand Down Expand Up @@ -1318,7 +1325,8 @@ function code_typed_by_type(@nospecialize(tt::Type);
debuginfo::Symbol=:default,
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
if @isdefined(IRShow)
debuginfo = IRShow.debuginfo(debuginfo)
elseif debuginfo === :default
Expand Down Expand Up @@ -1427,7 +1435,7 @@ function code_ircode_by_type(
interp = Core.Compiler.NativeInterpreter(world),
optimize_until::Union{Integer,AbstractString,Nothing} = nothing,
)
ccall(:jl_is_in_pure_context, Bool, ()) &&
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
tt = to_tuple_type(tt)
matches = _methods_by_ftype(tt, -1, world)::Vector
Expand All @@ -1454,7 +1462,8 @@ end
function return_types(@nospecialize(f), @nospecialize(types=default_tt(f));
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
if isa(f, Core.OpaqueClosure)
_, rt = only(code_typed_opaque_closure(f))
return Any[rt]
Expand All @@ -1478,7 +1487,8 @@ end
function infer_effects(@nospecialize(f), @nospecialize(types=default_tt(f));
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")
if isa(f, Core.Builtin)
types = to_tuple_type(types)
argtypes = Any[Core.Compiler.Const(f), types.parameters...]
Expand Down
8 changes: 4 additions & 4 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,10 @@ A (usually temporary) container for holding lowered source code.

A `UInt8` array of slot properties, represented as bit flags:

* 2 - assigned (only false if there are *no* assignment statements with this var on the left)
* 8 - const (currently unused for local variables)
* 16 - statically assigned once
* 32 - might be used before assigned. This flag is only valid after type inference.
* 0x02 - assigned (only false if there are *no* assignment statements with this var on the left)
* 0x08 - used (if there is any read or write of the slot)
* 0x10 - statically assigned once
* 0x20 - might be used before assigned. This flag is only valid after type inference.

* `ssavaluetypes`

Expand Down
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
if (src)
jlrettype = src->rettype;
else if (jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
src = mi->def.method->generator ? jl_code_for_staged(mi, world) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, NULL, (jl_array_t*)src);
}
Expand Down
4 changes: 2 additions & 2 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,10 @@ static jl_value_t *jl_invoke_julia_macro(jl_array_t *args, jl_module_t *inmodule
jl_value_t *result;
JL_TRY {
margs[0] = jl_toplevel_eval(*ctx, margs[0]);
jl_method_instance_t *mfunc = jl_method_lookup(margs, nargs, world);
jl_method_instance_t *mfunc = jl_method_lookup(margs, nargs, ct->world_age);
JL_GC_PROMISE_ROOTED(mfunc);
if (mfunc == NULL) {
jl_method_error(margs[0], &margs[1], nargs, world);
jl_method_error(margs[0], &margs[1], nargs, ct->world_age);
// unreachable
}
*ctx = mfunc->def.method->module;
Expand Down
7 changes: 6 additions & 1 deletion src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ extern "C" {
JL_DLLEXPORT _Atomic(size_t) jl_world_counter = 1; // uses atomic acquire/release
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT
{
jl_task_t *ct = jl_current_task;
if (ct->ptls->in_pure_callback)
return ~(size_t)0;
return jl_atomic_load_acquire(&jl_world_counter);
}

Expand Down Expand Up @@ -2267,7 +2270,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
// if that didn't work and compilation is off, try running in the interpreter
if (compile_option == JL_OPTIONS_COMPILE_OFF ||
compile_option == JL_OPTIONS_COMPILE_MIN) {
jl_code_info_t *src = jl_code_for_interpreter(mi);
jl_code_info_t *src = jl_code_for_interpreter(mi, world);
if (!jl_code_requires_compiler(src, 0)) {
jl_code_instance_t *codeinst = jl_new_codeinst(mi,
(jl_value_t*)jl_any_type, NULL, NULL,
Expand Down Expand Up @@ -3105,6 +3108,8 @@ static jl_value_t *ml_matches(jl_methtable_t *mt,
int intersections, size_t world, int cache_result,
size_t *min_valid, size_t *max_valid, int *ambig)
{
if (world > jl_atomic_load_acquire(&jl_world_counter))
return jl_nothing; // the future is not enumerable
int has_ambiguity = 0;
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)type);
assert(jl_is_datatype(unw));
Expand Down
Loading

0 comments on commit a918e93

Please sign in to comment.