Skip to content

Commit

Permalink
Add support for external method tables (JuliaLang#39697)
Browse files Browse the repository at this point in the history
This PR implements a way to keep tables of methods that are
not part of the internal method table, but still participate
in the special support we have for keeping tables of methods,
in particular unification through precompilation and efficient
lookup. The intended design use case is to allow for method overlay
tables for various non-CPU backends (e.g. GPU and TPU). These
backends would like to modify basic function like `sin` to
perform better on the device in question (or in the case of TPU
to define them over non-LLVM intrinsics). To date, the best
available mechanism of achieving this result was to use a
Cassette-like pass rewriting every method and injecting
an overlay if necessary. However, this approach is somewhat
unsatisfying for two reasons:

1. It requires rewriting every function, which has non-trivial
   performance cost.
2. It is (currently) expensive because of the repeated calls to
   generated functions.
3. It confuses inference, because suddenly everything is one method.
   We have hooks to work around this, but support is incomplete.

It is also not clear that Cassette it is the best conceptual model,
because these methods *are* methods of the same generic function,
they just happen to only be applicable for a particular backend.

It is worth noting that this PR only gives the ability to keep
these tables of methods. It assigns no particular meaning to them
and the runtime (and regular inference) do not look at them.
They are designed as an implementation detail for external
compilers and similar tools.

This feature does not replace Cassette for the method-interception
use case in the absence of such a compiler, though it could in
the future form part of a solution (I'm hoping the AD work will
in due course lead to abstractions that would enable a "faster
Cassette" which may use part of these fetaures). As such,
I'm not sure we'll ever move this out of Experimental, but
until such a time that we have a better solution, I think this'll
be a useful feature for the GPU stack.

With all those disclaimers out of the way, here is a description
of the various parts of the current design that deserve
discussion:

 # Demo

```julia
julia> using Base.Experimental: @overlay, @MethodTable

julia> @MethodTable(mt)
 # 0 methods:

julia> mt
 # 0 methods:

julia> @overlay mt function sin(x::Float64)
           1
       end

julia> @overlay mt function cos(x::Float64)
           1
       end

julia> mt
 # 2 methods:
[1] cos(x::Float64) in Main at REPL[5]:1
[2] sin(x::Float64) in Main at REPL[4]:1

julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, mt, 1, typemax(UInt))
1-element Vector{Any}:
 Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(), sin(x::Float64) in Main at REPL[4]:1, true)

julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, 1, typemax(UInt))
1-element Vector{Any}:
 Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(Float64), sin(x::T) where T<:Union{Float32, Float64} in Base.Math at special/trig.jl:29, true)
```

 # The `@overlay` macro

The macro replaces the function name by an `Expr(:overlay, mt, name)`,
which then gets piped through to Method def. One particular design
aspect here is that I've stopped letting the 4-argument :method
Expr introduce new generic functions, reserving this functionality
entirely to the 2-argument :method Expr. We already started going
this way when we began omitting method names from the 4-argument
version. This PR re-uses that name field of the 4-argument version
to specify a method table instead.

 # Identity of methods

I think one of the biggest questions of this design is what happens
to the identity of methods. Until OpaqueClosure, all methods were uniquely
identified by their signatures, with the applicable method table
computed from the first argument of the signature. This is important
so that incremental compilation can properly merge method tables coming
from different .ji files. For these methods, that is of course not the
correct method table to use for these methods, so methods
that are not part of the internal method table will instead have a
backreference to the applicable method table.

 # Identity of method tables

Method tables are identified by the name of their binding in the
containing module. To ensure consistency of this mapping, these
MethodTables may only be constructed using the `@MethodTable(name)`
macro, which simultaneously establishes a const binding in
the declaring module.

Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
Keno and maleadt committed May 18, 2021
1 parent b054be5 commit 39caf28
Show file tree
Hide file tree
Showing 16 changed files with 319 additions and 97 deletions.
32 changes: 31 additions & 1 deletion base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ struct InternalMethodTable <: MethodTableView
world::UInt
end

"""
struct OverlayMethodTable <: MethodTableView
Overlays the internal method table such that specific queries can be redirected to an
external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
end

"""
struct CachedMethodTable <: MethodTableView
Expand All @@ -54,7 +65,26 @@ function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable;
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
end
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type{<:Tuple}), table::OverlayMethodTable; limit::Int=typemax(Int))
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig)
if ms === false
return missing
elseif isempty(ms)
# fall back to the internal method table
_min_val[] = typemin(UInt)
_max_val[] = typemax(UInt)
ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig)
end
if ms === false
return missing
end
Expand Down
33 changes: 33 additions & 0 deletions base/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
module Experimental

using Base: Threads, sync_varname
using Base.Meta

"""
Const(A::Array)
Expand Down Expand Up @@ -256,4 +257,36 @@ end
# OpaqueClosure
include("opaque_closure.jl")

"""
Experimental.@overlay mt [function def]
Define a method and add it to the method table `mt` instead of to the global method table.
This can be used to implement a method override mechanism. Regular compilation will not
consider these methods, and you should customize the compilation flow to look in these
method tables (e.g., using [`Core.Compiler.OverlayMethodTable`](@ref)).
"""
macro overlay(mt, def)
def = macroexpand(__module__, def) # to expand @inline, @generated, etc
if !isexpr(def, [:function, :(=)]) || !isexpr(def.args[1], :call)
error("@overlay requires a function Expr")
end
def.args[1].args[1] = Expr(:overlay, mt, def.args[1].args[1])
esc(def)
end

"""
Experimental.@MethodTable(name)
Create a new MethodTable in the current module, bound to `name`. This method table can be
used with the [`Experimental.@overlay`](@ref) macro to define methods for a function without
adding them to the global method table.
"""
macro MethodTable(name)
isa(name, Symbol) || error("name must be a symbol")
esc(quote
const $name = ccall(:jl_new_method_table, Any, (Any, Any), $(quot(name)), $(__module__))
end)
end

end
17 changes: 10 additions & 7 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -895,13 +895,16 @@ function _methods(@nospecialize(f), @nospecialize(t), lim::Int, world::UInt)
end

function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
return _methods_by_ftype(t, nothing, lim, world)
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt)
return _methods_by_ftype(t, mt, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL))
end
function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end
function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32})
return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool}
end

function _method_by_ftype(args...)
Expand Down Expand Up @@ -971,7 +974,7 @@ function methods_including_ambiguous(@nospecialize(f), @nospecialize(t))
world = typemax(UInt)
min = RefValue{UInt}(typemin(UInt))
max = RefValue{UInt}(typemax(UInt))
ms = _methods_by_ftype(tt, -1, world, true, min, max, Ptr{Int32}(C_NULL))
ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))
isa(ms, Bool) && return ms
return MethodList(Method[(m::Core.MethodMatch).method for m in ms], typeof(f).name.mt)
end
Expand Down Expand Up @@ -1533,7 +1536,7 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false)
min = UInt[typemin(UInt)]
max = UInt[typemax(UInt)]
has_ambig = Int32[0]
ms = _methods_by_ftype(ti, -1, typemax(UInt), true, min, max, has_ambig)::Vector
ms = _methods_by_ftype(ti, nothing, -1, typemax(UInt), true, min, max, has_ambig)::Vector
has_ambig[] == 0 && return false
if !ambiguous_bottom
filter!(ms) do m::Core.MethodMatch
Expand Down
98 changes: 51 additions & 47 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ static const auto jlinvoke_func = new JuliaFunction{
static const auto jlmethod_func = new JuliaFunction{
"jl_method_def",
[](LLVMContext &C) { return FunctionType::get(T_prjlvalue,
{T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
{T_prjlvalue, T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); },
nullptr,
};
static const auto jlgenericfunction_func = new JuliaFunction{
Expand Down Expand Up @@ -4398,58 +4398,62 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
return emit_sparam(ctx, jl_unbox_long(args[0]) - 1);
}
else if (head == method_sym) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));
if (jl_expr_nargs(ex) == 1) {
jl_value_t *mn = args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn));

Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
if (jl_expr_nargs(ex) == 1)
Value *bp = NULL, *name, *bp_owner = V_null;
jl_binding_t *bnd = NULL;
bool issym = jl_is_symbol(mn);
bool isglobalref = !issym && jl_is_globalref(mn);
jl_module_t *mod = ctx.module;
if (issym || isglobalref) {
if (isglobalref) {
mod = jl_globalref_mod(mn);
mn = (jl_value_t*)jl_globalref_name(mn);
}
JL_TRY {
if (jl_symbol_name((jl_sym_t*)mn)[0] == '@')
jl_errorf("macro definition not allowed inside a local scope");
name = literal_pointer_val(ctx, mn);
bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn);
}
JL_CATCH {
jl_value_t *e = jl_current_exception();
// errors. boo. root it somehow :(
bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1);
bnd->value = e;
bnd->constp = 1;
raise_exception(ctx, literal_pointer_val(ctx, e));
return ghostValue(jl_nothing_type);
}
bp = julia_binding_gv(ctx, bnd);
bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod);
}
else if (jl_is_slot(mn) || jl_is_argument(mn)) {
int sl = jl_slot_number(mn)-1;
jl_varinfo_t &vi = ctx.slots[sl];
bp = vi.boxroot;
name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl));
}
if (bp) {
Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp,
bp_owner, literal_pointer_val(ctx, bnd) };
jl_cgval_t gf = mark_julia_type(
ctx,
ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)),
true,
jl_function_type);
return gf;
}
emit_error(ctx, "method: invalid declaration");
return jl_cgval_t();
}
Value *a1 = boxed(ctx, emit_expr(ctx, args[1]));
Value *a2 = boxed(ctx, emit_expr(ctx, args[2]));
Value *mdargs[3] = {
Value *mdargs[4] = {
/*argdata*/a1,
ConstantPointerNull::get(cast<PointerType>(T_prjlvalue)),
/*code*/a2,
/*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module)
};
Expand Down
57 changes: 46 additions & 11 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_
jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque);
}

enum METHOD_SERIALIZATION_MODE {
METHOD_INTERNAL = 1,
METHOD_EXTERNAL_MT = 2,
};

static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED
{
if (jl_serialize_generic(s, v)) {
Expand Down Expand Up @@ -627,18 +632,34 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
else if (jl_is_method(v)) {
write_uint8(s->s, TAG_METHOD);
jl_method_t *m = (jl_method_t*)v;
int internal = 1;
internal = m->is_for_opaque_closure || module_in_worklist(m->module);
if (!internal) {
int serialization_mode = 0;
if (m->is_for_opaque_closure || module_in_worklist(m->module))
serialization_mode |= METHOD_INTERNAL;
if (!(serialization_mode & METHOD_INTERNAL)) {
// flag this in the backref table as special
uintptr_t *bp = (uintptr_t*)ptrhash_bp(&backref_table, v);
assert(*bp != (uintptr_t)HT_NOTFOUND);
*bp |= 1;
}
jl_serialize_value(s, (jl_value_t*)m->sig);
jl_serialize_value(s, (jl_value_t*)m->module);
write_uint8(s->s, internal);
if (!internal)
if (m->external_mt != NULL) {
assert(jl_typeis(m->external_mt, jl_methtable_type));
jl_methtable_t *mt = (jl_methtable_t*)m->external_mt;
if (!module_in_worklist(mt->module)) {
serialization_mode |= METHOD_EXTERNAL_MT;
}
}
write_uint8(s->s, serialization_mode);
if (serialization_mode & METHOD_EXTERNAL_MT) {
// We reference this method table by module and binding
jl_methtable_t *mt = (jl_methtable_t*)m->external_mt;
jl_serialize_value(s, mt->module);
jl_serialize_value(s, mt->name);
} else {
jl_serialize_value(s, (jl_value_t*)m->external_mt);
}
if (!(serialization_mode & METHOD_INTERNAL))
return;
jl_serialize_value(s, m->specializations);
jl_serialize_value(s, m->speckeyset);
Expand Down Expand Up @@ -952,6 +973,10 @@ static void jl_collect_lambdas_from_mod(jl_array_t *s, jl_module_t *m) JL_GC_DIS
jl_collect_lambdas_from_mod(s, (jl_module_t*)b->value);
}
}
else if (jl_is_mtable(bv)) {
// a module containing an external method table
jl_collect_methtable_from_mod(s, (jl_methtable_t*)bv);
}
}
}
}
Expand Down Expand Up @@ -1015,7 +1040,7 @@ static void jl_collect_backedges(jl_array_t *s, jl_array_t *t)
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
int ambig = 0;
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false) {
valid = 0;
break;
Expand Down Expand Up @@ -1458,8 +1483,18 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
jl_gc_wb(m, m->sig);
m->module = (jl_module_t*)jl_deserialize_value(s, (jl_value_t**)&m->module);
jl_gc_wb(m, m->module);
int internal = read_uint8(s->s);
if (!internal) {
int serialization_mode = read_uint8(s->s);
if (serialization_mode & METHOD_EXTERNAL_MT) {
jl_module_t *mt_mod = (jl_module_t*)jl_deserialize_value(s, NULL);
jl_sym_t *mt_name = (jl_sym_t*)jl_deserialize_value(s, NULL);
m->external_mt = jl_get_global(mt_mod, mt_name);
jl_gc_wb(m, m->external_mt);
assert(jl_typeis(m->external_mt, jl_methtable_type));
} else {
m->external_mt = jl_deserialize_value(s, &m->external_mt);
jl_gc_wb(m, m->external_mt);
}
if (!(serialization_mode & METHOD_INTERNAL)) {
assert(loc != NULL && loc != HT_NOTFOUND);
arraylist_push(&flagref_list, loc);
arraylist_push(&flagref_list, (void*)pos);
Expand Down Expand Up @@ -1897,7 +1932,7 @@ static void jl_insert_methods(jl_array_t *list)
assert(!meth->is_for_opaque_closure);
jl_tupletype_t *simpletype = (jl_tupletype_t*)jl_array_ptr_ref(list, i + 1);
assert(jl_is_method(meth));
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)meth->sig);
jl_methtable_t *mt = jl_method_get_table(meth);
assert((jl_value_t*)mt != jl_nothing);
jl_method_table_insert(mt, meth, simpletype);
}
Expand Down Expand Up @@ -1927,7 +1962,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
size_t max_valid = ~(size_t)0;
int ambig = 0;
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig);
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
valid = 0;
}
Expand Down Expand Up @@ -2465,7 +2500,7 @@ static jl_method_t *jl_recache_method(jl_method_t *m)
{
assert(!m->is_for_opaque_closure);
jl_datatype_t *sig = (jl_datatype_t*)m->sig;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
jl_methtable_t *mt = jl_method_get_table(m);
assert((jl_value_t*)mt != jl_nothing);
jl_set_typeof(m, (void*)(intptr_t)0x30); // invalidate the old value to help catch errors
jl_method_t *_new = jl_lookup_method(mt, sig, m->module->primary_world);
Expand Down
Loading

0 comments on commit 39caf28

Please sign in to comment.