Skip to content

Commit

Permalink
Add support for external method tables (#39697)
Browse files Browse the repository at this point in the history
This PR implements a way to keep tables of methods that are
not part of the internal method table, but still participate
in the special support we have for keeping tables of methods,
in particular unification through precompilation and efficient
lookup. The intended design use case is to allow for method overlay
tables for various non-CPU backends (e.g. GPU and TPU). These
backends would like to modify basic function like `sin` to
perform better on the device in question (or in the case of TPU
to define them over non-LLVM intrinsics).

It is worth noting that this PR only gives the ability to keep
these tables of methods. It assigns no particular meaning to them
and the runtime (and regular inference) do not look at them.
They are designed as an implementation detail for external
compilers and similar tools.

 # Demo

```julia
julia> using Base.Experimental: @overlay, @MethodTable

julia> @MethodTable(mt)
 # 0 methods:

julia> @overlay mt function sin(x::Float64)
           1
       end

julia> @overlay mt function cos(x::Float64)
           1
       end

julia> mt
 # 2 methods:
[1] cos(x::Float64) in Main at REPL[5]:1
[2] sin(x::Float64) in Main at REPL[4]:1

julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, mt, 1, typemax(UInt))
1-element Vector{Any}:
 Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(), sin(x::Float64) in Main at REPL[4]:1, true)

julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, 1, typemax(UInt))
1-element Vector{Any}:
 Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(Float64), sin(x::T) where T<:Union{Float32, Float64} in Base.Math at special/trig.jl:29, true)
```

Co-authored-by: Tim Besard <[email protected]>
Co-authored-by: Julian P Samaroo <[email protected]>
Co-authored-by: Keno Fischer <[email protected]>
  • Loading branch information
4 people committed Jun 9, 2021
1 parent 0e3276c commit f442e42
Show file tree
Hide file tree
Showing 17 changed files with 338 additions and 105 deletions.
44 changes: 37 additions & 7 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ struct InternalMethodTable <: MethodTableView
world::UInt
end

"""
struct OverlayMethodTable <: MethodTableView
Overlays the internal method table such that specific queries can be redirected to an
external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
end

"""
struct CachedMethodTable <: MethodTableView
Expand All @@ -43,33 +54,52 @@ CachedMethodTable(table::T) where T =
table)

"""
findall(sig::Type{<:Tuple}, view::MethodTableView; limit=typemax(Int))
findall(sig::Type, view::MethodTableView; limit=typemax(Int))
Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
"""
function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; limit::Int=typemax(Int))
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig)
ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
elseif isempty(ms)
# fall back to the internal method table
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
end
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type{<:Tuple}), table::CachedMethodTable; limit::Int=typemax(Int))
function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
box = Core.Box(sig)
return get!(table.cache, sig) do
findall(box.contents, table.table; limit=limit)
end
end

"""
findsup(sig::Type{<:Tuple}, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing}
findsup(sig::Type, view::MethodTableView)::Union{Tuple{MethodMatch, WorldRange}, Nothing}
Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
Expand All @@ -82,7 +112,7 @@ Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
"""
function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable)
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
result = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
Expand All @@ -92,4 +122,4 @@ function findsup(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable)
end

# This query is not cached
findsup(sig::Type{<:Tuple}, table::CachedMethodTable) = findsup(sig, table.table)
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)
37 changes: 37 additions & 0 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
module Experimental

using Base: Threads, sync_varname
using Base.Meta

"""
Const(A::Array)
Expand Down Expand Up @@ -256,4 +257,40 @@ end
# OpaqueClosure
include("opaque_closure.jl")

"""
Experimental.@overlay mt [function def]
Define a method and add it to the method table `mt` instead of to the global method table.
This can be used to implement a method override mechanism. Regular compilation will not
consider these methods, and you should customize the compilation flow to look in these
method tables (e.g., using [`Core.Compiler.OverlayMethodTable`](@ref)).
"""
macro overlay(mt, def)
def = macroexpand(__module__, def) # to expand @inline, @generated, etc
if !isexpr(def, [:function, :(=)]) || !isexpr(def.args[1], :call)
error("@overlay requires a function Expr")
end
def.args[1].args[1] = Expr(:overlay, mt, def.args[1].args[1])
esc(def)
end

let new_mt(name::Symbol, mod::Module) = begin
ccall(:jl_check_top_level_effect, Cvoid, (Any, Cstring), mod, name)
ccall(:jl_new_method_table, Any, (Any, Any), name, mod)
end
@eval macro MethodTable(name::Symbol)
esc(:(const $name = $$new_mt($(quot(name)), $(__module__))))
end
end

"""
Experimental.@MethodTable(name)
Create a new MethodTable in the current module, bound to `name`. This method table can be
used with the [`Experimental.@overlay`](@ref) macro to define methods for a function without
adding them to the global method table.
"""
:@MethodTable

end
17 changes: 10 additions & 7 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -896,13 +896,16 @@ function _methods(@nospecialize(f), @nospecialize(t), lim::Int, world::UInt)
end

function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
return _methods_by_ftype(t, nothing, lim, world)
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt)
return _methods_by_ftype(t, mt, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end

function _method_by_ftype(args...)
Expand Down Expand Up @@ -970,7 +973,7 @@ function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
world = typemax(UInt)
min = RefValue{UInt}(typemin(UInt))
max = RefValue{UInt}(typemax(UInt))
ms = _methods_by_ftype(tt, -1, world, true, min, max, Ptr{Int32}(C_NULL))
ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))
isa(ms, Bool) && return ms
return MethodList(Method[(m::Core.MethodMatch).method for m in ms], typeof(f).name.mt)
end
Expand Down Expand Up @@ -1532,7 +1535,7 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false)
min = UInt[typemin(UInt)]
max = UInt[typemax(UInt)]
has_ambig = Int32[0]
ms = _methods_by_ftype(ti, -1, typemax(UInt), true, min, max, has_ambig)::Vector
ms = _methods_by_ftype(ti, nothing, -1, typemax(UInt), true, min, max, has_ambig)::Vector
has_ambig[] == 0 && return false
if !ambiguous_bottom
filter!(ms) do m::Core.MethodMatch
Expand Down
98 changes: 51 additions & 47 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ static const auto jlinvoke_func = new JuliaFunction{
static const auto jlmethod_func = new JuliaFunction{
"jl_method_def",
[](LLVMContext &C) { return FunctionType::get(T_prjlvalue,
{T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
{T_prjlvalue, T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
nullptr,
};
static const auto jlgenericfunction_func = new JuliaFunction{
Expand Down Expand Up @@ -4530,58 +4530,62 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
return emit_sparam(ctx, jl_unbox_long(args[0]) - 1);
}
else if (head == method_sym) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));
if (jl_expr_nargs(ex) == 1) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));

Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
if (jl_expr_nargs(ex) == 1)
Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
return gf;
}
emit_error(ctx, "method: invalid declaration");
return jl_cgval_t();
}
Value *a1 = boxed(ctx, emit_expr(ctx, args[1]));
Value *a2 = boxed(ctx, emit_expr(ctx, args[2]));
Value *mdargs[3] = {
Value *mdargs[4] = {
/*argdata*/a1,
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)),
/*code*/a2,
/*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module)
};
Expand Down
Loading

2 comments on commit f442e42

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here. cc @maleadt

Please sign in to comment.