Skip to content

Commit

Permalink
Add support for per-module max_methods. (JuliaLang#43370)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Holy <[email protected]>
Co-authored-by: Shuhei Kadowaki <[email protected]>
  • Loading branch information
3 people committed Dec 17, 2021
1 parent 55ad2f0 commit a957953
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 8 deletions.
19 changes: 12 additions & 7 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ const _REF_NAME = Ref.body.name
call_result_unused(frame::InferenceState) =
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[frame.currpc])

function get_max_methods(mod::Module, interp::AbstractInterpreter)
max_methods = ccall(:jl_get_module_max_methods, Cint, (Any,), mod) % Int
max_methods < 0 ? InferenceParams(interp).MAX_METHODS : max_methods
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, @nospecialize(atype),
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
sv::InferenceState, max_methods::Int = get_max_methods(sv.mod, interp))
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
add_remark!(interp, sv, "Skipped call in throw block")
return CallMeta(Any, false)
Expand Down Expand Up @@ -1011,7 +1016,7 @@ end

# do apply(af, fargs...), where af is a function value
function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)
max_methods::Int = get_max_methods(sv.mod, interp))
itft = argtype_by_index(argtypes, 2)
aft = argtype_by_index(argtypes, 3)
(itft === Bottom || aft === Bottom) && return CallMeta(Bottom, false)
Expand Down Expand Up @@ -1364,7 +1369,7 @@ end
# call where the function is known exactly
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)
max_methods::Int = get_max_methods(sv.mod, interp))
(; fargs, argtypes) = arginfo
la = length(argtypes)

Expand Down Expand Up @@ -1419,12 +1424,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv, max_methods) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
rty = abstract_call_known(interp, (===), arginfo, sv).rt
rty = abstract_call_known(interp, (===), arginfo, sv, max_methods).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
elseif isa(rty, Const)
Expand All @@ -1440,7 +1445,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs = nothing
end
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false)
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv, max_methods).rt, false)
elseif la == 2 &&
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
istopfunction(f, :length)
Expand Down Expand Up @@ -1500,7 +1505,7 @@ end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
sv::InferenceState, max_methods::Int = get_max_methods(sv.mod, interp))
argtypes = arginfo.argtypes
ft = argtypes[1]
f = singleton_type(ft)
Expand Down
23 changes: 22 additions & 1 deletion base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,21 @@ macro optlevel(n::Int)
end

"""
Experimental.@compiler_options optimize={0,1,2,3} compile={yes,no,all,min} infer={yes,no}
Experimental.@max_methods n::Int
Set the maximum number of potentially-matching methods considered when running inference
for methods defined in the current module. This setting affects inference of calls with
incomplete knowledge of the argument types.
Supported values are `1`, `2`, `3`, `4`, and `default` (currently equivalent to `3`).
"""
macro max_methods(n::Int)
0 < n < 5 || error("We must have that `1 <= max_methods <= 4`, but `max_methods = $n`.")
return Expr(:meta, :max_methods, n)
end

"""
Experimental.@compiler_options optimize={0,1,2,3} compile={yes,no,all,min} infer={yes,no} max_methods={default,1,2,3,...}
Set compiler options for code in the enclosing module. Options correspond directly to
command-line options with the same name, where applicable. The following options
Expand All @@ -133,6 +147,7 @@ are currently supported:
* `compile`: Toggle native code compilation. Currently only `min` is supported, which
requests the minimum possible amount of compilation.
* `infer`: Enable or disable type inference. If disabled, implies [`@nospecialize`](@ref).
* `max_methods`: Maximum number of matching methods considered when running type inference.
"""
macro compiler_options(args...)
opts = Expr(:block)
Expand All @@ -152,6 +167,12 @@ macro compiler_options(args...)
a = a === false || a === :no ? 0 :
a === true || a === :yes ? 1 : error("invalid argument to \"infer\" option")
push!(opts.args, Expr(:meta, :infer, a))
elseif ex.args[1] === :max_methods
a = ex.args[2]
a = a === :default ? 3 :
a isa Int ? ((0 < a < 5) ? a : error("We must have that `1 <= max_methods <= 4`, but `max_methods = $a`.")) :
error("invalid argument to \"max_methods\" option")
push!(opts.args, Expr(:meta, :max_methods, a))
else
error("unknown option \"$(ex.args[1])\"")
end
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ JL_DLLEXPORT jl_sym_t *jl_all_sym;
JL_DLLEXPORT jl_sym_t *jl_compile_sym;
JL_DLLEXPORT jl_sym_t *jl_force_compile_sym;
JL_DLLEXPORT jl_sym_t *jl_infer_sym;
JL_DLLEXPORT jl_sym_t *jl_max_methods_sym;
JL_DLLEXPORT jl_sym_t *jl_atomic_sym;
JL_DLLEXPORT jl_sym_t *jl_not_atomic_sym;
JL_DLLEXPORT jl_sym_t *jl_unordered_sym;
Expand Down Expand Up @@ -336,6 +337,7 @@ void jl_init_common_symbols(void)
jl_compile_sym = jl_symbol("compile");
jl_force_compile_sym = jl_symbol("force_compile");
jl_infer_sym = jl_symbol("infer");
jl_max_methods_sym = jl_symbol("max_methods");
jl_macrocall_sym = jl_symbol("macrocall");
jl_escape_sym = jl_symbol("escape");
jl_hygienicscope_sym = jl_symbol("hygienic-scope");
Expand Down
2 changes: 2 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ static void jl_serialize_module(jl_serializer_state *s, jl_module_t *m)
write_uint8(s->s, m->optlevel);
write_uint8(s->s, m->compile);
write_uint8(s->s, m->infer);
write_uint8(s->s, m->max_methods);
}

static int jl_serialize_generic(jl_serializer_state *s, jl_value_t *v) JL_GC_DISABLED
Expand Down Expand Up @@ -1678,6 +1679,7 @@ static jl_value_t *jl_deserialize_value_module(jl_serializer_state *s) JL_GC_DIS
m->optlevel = read_int8(s->s);
m->compile = read_int8(s->s);
m->infer = read_int8(s->s);
m->max_methods = read_int8(s->s);
m->primary_world = jl_atomic_load_acquire(&jl_world_counter);
return (jl_value_t*)m;
}
Expand Down
5 changes: 5 additions & 0 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,11 @@ static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip,
jl_set_module_infer(s->module, jl_unbox_long(jl_exprarg(stmt, 1)));
}
}
else if (jl_exprarg(stmt, 0) == (jl_value_t*)jl_max_methods_sym) {
if (jl_is_long(jl_exprarg(stmt, 1))) {
jl_set_module_max_methods(s->module, jl_unbox_long(jl_exprarg(stmt, 1)));
}
}
}
}
else {
Expand Down
3 changes: 3 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ typedef struct _jl_module_t {
int8_t compile;
int8_t infer;
uint8_t istopmod;
int8_t max_methods;
jl_mutex_t lock;
} jl_module_t;

Expand Down Expand Up @@ -1538,6 +1539,8 @@ JL_DLLEXPORT void jl_set_module_compile(jl_module_t *self, int value);
JL_DLLEXPORT int jl_get_module_compile(jl_module_t *m);
JL_DLLEXPORT void jl_set_module_infer(jl_module_t *self, int value);
JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m);
JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value);
JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
// get binding for reading
JL_DLLEXPORT jl_binding_t *jl_get_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var);
JL_DLLEXPORT jl_binding_t *jl_get_binding_or_error(jl_module_t *m, jl_sym_t *var);
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,7 @@ extern JL_DLLEXPORT jl_sym_t *jl_all_sym;
extern JL_DLLEXPORT jl_sym_t *jl_compile_sym;
extern JL_DLLEXPORT jl_sym_t *jl_force_compile_sym;
extern JL_DLLEXPORT jl_sym_t *jl_infer_sym;
extern JL_DLLEXPORT jl_sym_t *jl_max_methods_sym;
extern JL_DLLEXPORT jl_sym_t *jl_atomic_sym;
extern JL_DLLEXPORT jl_sym_t *jl_not_atomic_sym;
extern JL_DLLEXPORT jl_sym_t *jl_unordered_sym;
Expand Down
17 changes: 17 additions & 0 deletions src/module.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// This file is a part of Julia. License is MIT: https://julialang.org/license

/*
Expand Down Expand Up @@ -32,6 +33,7 @@ JL_DLLEXPORT jl_module_t *jl_new_module_(jl_sym_t *name, uint8_t default_names)
m->optlevel = -1;
m->compile = -1;
m->infer = -1;
m->max_methods = -1;
JL_MUTEX_INIT(&m->lock);
htable_new(&m->bindings, 0);
arraylist_new(&m->usings, 0);
Expand Down Expand Up @@ -125,6 +127,21 @@ JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m)
return value;
}

JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value)
{
self->max_methods = value;
}

JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m)
{
int value = m->max_methods;
while (value == -1 && m->parent != m && m != jl_base_module) {
m = m->parent;
value = m->max_methods;
}
return value;
}

JL_DLLEXPORT void jl_set_istopmod(jl_module_t *self, uint8_t isprimary)
{
self->istopmod = 1;
Expand Down

0 comments on commit a957953

Please sign in to comment.