From 39caf28d64ec17e4957f0c437e7970b0b11d98cc Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 18 May 2021 11:59:40 -0400 Subject: [PATCH] Add support for external method tables (#39697) This PR implements a way to keep tables of methods that are not part of the internal method table, but still participate in the special support we have for keeping tables of methods, in particular unification through precompilation and efficient lookup. The intended design use case is to allow for method overlay tables for various non-CPU backends (e.g. GPU and TPU). These backends would like to modify basic function like `sin` to perform better on the device in question (or in the case of TPU to define them over non-LLVM intrinsics). To date, the best available mechanism of achieving this result was to use a Cassette-like pass rewriting every method and injecting an overlay if necessary. However, this approach is somewhat unsatisfying for two reasons: 1. It requires rewriting every function, which has non-trivial performance cost. 2. It is (currently) expensive because of the repeated calls to generated functions. 3. It confuses inference, because suddenly everything is one method. We have hooks to work around this, but support is incomplete. It is also not clear that Cassette it is the best conceptual model, because these methods *are* methods of the same generic function, they just happen to only be applicable for a particular backend. It is worth noting that this PR only gives the ability to keep these tables of methods. It assigns no particular meaning to them and the runtime (and regular inference) do not look at them. They are designed as an implementation detail for external compilers and similar tools. This feature does not replace Cassette for the method-interception use case in the absence of such a compiler, though it could in the future form part of a solution (I'm hoping the AD work will in due course lead to abstractions that would enable a "faster Cassette" which may use part of these fetaures). As such, I'm not sure we'll ever move this out of Experimental, but until such a time that we have a better solution, I think this'll be a useful feature for the GPU stack. With all those disclaimers out of the way, here is a description of the various parts of the current design that deserve discussion: # Demo ```julia julia> using Base.Experimental: @overlay, @MethodTable julia> @MethodTable(mt) # 0 methods: julia> mt # 0 methods: julia> @overlay mt function sin(x::Float64) 1 end julia> @overlay mt function cos(x::Float64) 1 end julia> mt # 2 methods: [1] cos(x::Float64) in Main at REPL[5]:1 [2] sin(x::Float64) in Main at REPL[4]:1 julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, mt, 1, typemax(UInt)) 1-element Vector{Any}: Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(), sin(x::Float64) in Main at REPL[4]:1, true) julia> Base._methods_by_ftype(Tuple{typeof(sin), Float64}, 1, typemax(UInt)) 1-element Vector{Any}: Core.MethodMatch(Tuple{typeof(sin), Float64}, svec(Float64), sin(x::T) where T<:Union{Float32, Float64} in Base.Math at special/trig.jl:29, true) ``` # The `@overlay` macro The macro replaces the function name by an `Expr(:overlay, mt, name)`, which then gets piped through to Method def. One particular design aspect here is that I've stopped letting the 4-argument :method Expr introduce new generic functions, reserving this functionality entirely to the 2-argument :method Expr. We already started going this way when we began omitting method names from the 4-argument version. This PR re-uses that name field of the 4-argument version to specify a method table instead. # Identity of methods I think one of the biggest questions of this design is what happens to the identity of methods. Until OpaqueClosure, all methods were uniquely identified by their signatures, with the applicable method table computed from the first argument of the signature. This is important so that incremental compilation can properly merge method tables coming from different .ji files. For these methods, that is of course not the correct method table to use for these methods, so methods that are not part of the internal method table will instead have a backreference to the applicable method table. # Identity of method tables Method tables are identified by the name of their binding in the containing module. To ensure consistency of this mapping, these MethodTables may only be constructed using the `@MethodTable(name)` macro, which simultaneously establishes a const binding in the declaring module. Co-authored-by: Tim Besard --- base/compiler/methodtable.jl | 32 +++++++- base/experimental.jl | 33 ++++++++ base/reflection.jl | 17 ++-- src/codegen.cpp | 98 ++++++++++++----------- src/dump.c | 57 ++++++++++--- src/gf.c | 11 ++- src/interpreter.c | 34 +++++--- src/jltypes.c | 4 +- src/julia-syntax.scm | 22 +++-- src/julia.h | 3 +- src/julia_internal.h | 4 +- src/method.c | 14 +++- stdlib/Serialization/src/Serialization.jl | 3 + stdlib/Test/src/Test.jl | 2 +- test/ambiguous.jl | 6 +- test/compiler/contextual.jl | 76 ++++++++++++++++++ 16 files changed, 319 insertions(+), 97 deletions(-) diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index cff9e21fccb41..9b2010afcb05b 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -28,6 +28,17 @@ struct InternalMethodTable <: MethodTableView world::UInt end +""" + struct OverlayMethodTable <: MethodTableView + +Overlays the internal method table such that specific queries can be redirected to an +external table, e.g., to override existing method. +""" +struct OverlayMethodTable <: MethodTableView + world::UInt + mt::Core.MethodTable +end + """ struct CachedMethodTable <: MethodTableView @@ -54,7 +65,26 @@ function findall(@nospecialize(sig::Type{<:Tuple}), table::InternalMethodTable; _min_val = RefValue{UInt}(typemin(UInt)) _max_val = RefValue{UInt}(typemax(UInt)) _ambig = RefValue{Int32}(0) - ms = _methods_by_ftype(sig, limit, table.world, false, _min_val, _max_val, _ambig) + ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig) + if ms === false + return missing + end + return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) +end + +function findall(@nospecialize(sig::Type{<:Tuple}), table::OverlayMethodTable; limit::Int=typemax(Int)) + _min_val = RefValue{UInt}(typemin(UInt)) + _max_val = RefValue{UInt}(typemax(UInt)) + _ambig = RefValue{Int32}(0) + ms = _methods_by_ftype(sig, table.mt, limit, table.world, false, _min_val, _max_val, _ambig) + if ms === false + return missing + elseif isempty(ms) + # fall back to the internal method table + _min_val[] = typemin(UInt) + _max_val[] = typemax(UInt) + ms = _methods_by_ftype(sig, nothing, limit, table.world, false, _min_val, _max_val, _ambig) + end if ms === false return missing end diff --git a/base/experimental.jl b/base/experimental.jl index 7e30792dda4e8..a2a40294bfb14 100644 --- a/base/experimental.jl +++ b/base/experimental.jl @@ -10,6 +10,7 @@ module Experimental using Base: Threads, sync_varname +using Base.Meta """ Const(A::Array) @@ -256,4 +257,36 @@ end # OpaqueClosure include("opaque_closure.jl") +""" + Experimental.@overlay mt [function def] + +Define a method and add it to the method table `mt` instead of to the global method table. +This can be used to implement a method override mechanism. Regular compilation will not +consider these methods, and you should customize the compilation flow to look in these +method tables (e.g., using [`Core.Compiler.OverlayMethodTable`](@ref)). + +""" +macro overlay(mt, def) + def = macroexpand(__module__, def) # to expand @inline, @generated, etc + if !isexpr(def, [:function, :(=)]) || !isexpr(def.args[1], :call) + error("@overlay requires a function Expr") + end + def.args[1].args[1] = Expr(:overlay, mt, def.args[1].args[1]) + esc(def) +end + +""" + Experimental.@MethodTable(name) + +Create a new MethodTable in the current module, bound to `name`. This method table can be +used with the [`Experimental.@overlay`](@ref) macro to define methods for a function without +adding them to the global method table. +""" +macro MethodTable(name) + isa(name, Symbol) || error("name must be a symbol") + esc(quote + const $name = ccall(:jl_new_method_table, Any, (Any, Any), $(quot(name)), $(__module__)) + end) +end + end diff --git a/base/reflection.jl b/base/reflection.jl index 8fe04defe1d3f..7d7318fadce9e 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -895,13 +895,16 @@ function _methods(@nospecialize(f), @nospecialize(t), lim::Int, world::UInt) end function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt) - return _methods_by_ftype(t, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL)) + return _methods_by_ftype(t, nothing, lim, world) end -function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1}) - return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} +function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt) + return _methods_by_ftype(t, mt, lim, world, false, RefValue{UInt}(typemin(UInt)), RefValue{UInt}(typemax(UInt)), Ptr{Int32}(C_NULL)) end -function _methods_by_ftype(@nospecialize(t), lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32}) - return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} +function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Array{UInt,1}, max::Array{UInt,1}, has_ambig::Array{Int32,1}) + return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} +end +function _methods_by_ftype(@nospecialize(t), mt::Union{Core.MethodTable, Nothing}, lim::Int, world::UInt, ambig::Bool, min::Ref{UInt}, max::Ref{UInt}, has_ambig::Ref{Int32}) + return ccall(:jl_matching_methods, Any, (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ptr{Int32}), t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} end function _method_by_ftype(args...) @@ -971,7 +974,7 @@ function methods_including_ambiguous(@nospecialize(f), @nospecialize(t)) world = typemax(UInt) min = RefValue{UInt}(typemin(UInt)) max = RefValue{UInt}(typemax(UInt)) - ms = _methods_by_ftype(tt, -1, world, true, min, max, Ptr{Int32}(C_NULL)) + ms = _methods_by_ftype(tt, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL)) isa(ms, Bool) && return ms return MethodList(Method[(m::Core.MethodMatch).method for m in ms], typeof(f).name.mt) end @@ -1533,7 +1536,7 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false) min = UInt[typemin(UInt)] max = UInt[typemax(UInt)] has_ambig = Int32[0] - ms = _methods_by_ftype(ti, -1, typemax(UInt), true, min, max, has_ambig)::Vector + ms = _methods_by_ftype(ti, nothing, -1, typemax(UInt), true, min, max, has_ambig)::Vector has_ambig[] == 0 && return false if !ambiguous_bottom filter!(ms) do m::Core.MethodMatch diff --git a/src/codegen.cpp b/src/codegen.cpp index bc60798220d52..b8f86e001e985 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -566,7 +566,7 @@ static const auto jlinvoke_func = new JuliaFunction{ static const auto jlmethod_func = new JuliaFunction{ "jl_method_def", [](LLVMContext &C) { return FunctionType::get(T_prjlvalue, - {T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); }, + {T_prjlvalue, T_prjlvalue, T_prjlvalue, T_pjlvalue}, false); }, nullptr, }; static const auto jlgenericfunction_func = new JuliaFunction{ @@ -4398,58 +4398,62 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval) return emit_sparam(ctx, jl_unbox_long(args[0]) - 1); } else if (head == method_sym) { - jl_value_t *mn = args[0]; - assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn)); + if (jl_expr_nargs(ex) == 1) { + jl_value_t *mn = args[0]; + assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(mn) || jl_is_slot(mn)); - Value *bp = NULL, *name, *bp_owner = V_null; - jl_binding_t *bnd = NULL; - bool issym = jl_is_symbol(mn); - bool isglobalref = !issym && jl_is_globalref(mn); - jl_module_t *mod = ctx.module; - if (issym || isglobalref) { - if (isglobalref) { - mod = jl_globalref_mod(mn); - mn = (jl_value_t*)jl_globalref_name(mn); - } - JL_TRY { - if (jl_symbol_name((jl_sym_t*)mn)[0] == '@') - jl_errorf("macro definition not allowed inside a local scope"); - name = literal_pointer_val(ctx, mn); - bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn); - } - JL_CATCH { - jl_value_t *e = jl_current_exception(); - // errors. boo. root it somehow :( - bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1); - bnd->value = e; - bnd->constp = 1; - raise_exception(ctx, literal_pointer_val(ctx, e)); - return ghostValue(jl_nothing_type); - } - bp = julia_binding_gv(ctx, bnd); - bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod); - } - else if (jl_is_slot(mn) || jl_is_argument(mn)) { - int sl = jl_slot_number(mn)-1; - jl_varinfo_t &vi = ctx.slots[sl]; - bp = vi.boxroot; - name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl)); - } - if (bp) { - Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp, - bp_owner, literal_pointer_val(ctx, bnd) }; - jl_cgval_t gf = mark_julia_type( - ctx, - ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)), - true, - jl_function_type); - if (jl_expr_nargs(ex) == 1) + Value *bp = NULL, *name, *bp_owner = V_null; + jl_binding_t *bnd = NULL; + bool issym = jl_is_symbol(mn); + bool isglobalref = !issym && jl_is_globalref(mn); + jl_module_t *mod = ctx.module; + if (issym || isglobalref) { + if (isglobalref) { + mod = jl_globalref_mod(mn); + mn = (jl_value_t*)jl_globalref_name(mn); + } + JL_TRY { + if (jl_symbol_name((jl_sym_t*)mn)[0] == '@') + jl_errorf("macro definition not allowed inside a local scope"); + name = literal_pointer_val(ctx, mn); + bnd = jl_get_binding_for_method_def(mod, (jl_sym_t*)mn); + } + JL_CATCH { + jl_value_t *e = jl_current_exception(); + // errors. boo. root it somehow :( + bnd = jl_get_binding_wr(ctx.module, (jl_sym_t*)jl_gensym(), 1); + bnd->value = e; + bnd->constp = 1; + raise_exception(ctx, literal_pointer_val(ctx, e)); + return ghostValue(jl_nothing_type); + } + bp = julia_binding_gv(ctx, bnd); + bp_owner = literal_pointer_val(ctx, (jl_value_t*)mod); + } + else if (jl_is_slot(mn) || jl_is_argument(mn)) { + int sl = jl_slot_number(mn)-1; + jl_varinfo_t &vi = ctx.slots[sl]; + bp = vi.boxroot; + name = literal_pointer_val(ctx, (jl_value_t*)slot_symbol(ctx, sl)); + } + if (bp) { + Value *mdargs[5] = { name, literal_pointer_val(ctx, (jl_value_t*)mod), bp, + bp_owner, literal_pointer_val(ctx, bnd) }; + jl_cgval_t gf = mark_julia_type( + ctx, + ctx.builder.CreateCall(prepare_call(jlgenericfunction_func), makeArrayRef(mdargs)), + true, + jl_function_type); return gf; + } + emit_error(ctx, "method: invalid declaration"); + return jl_cgval_t(); } Value *a1 = boxed(ctx, emit_expr(ctx, args[1])); Value *a2 = boxed(ctx, emit_expr(ctx, args[2])); - Value *mdargs[3] = { + Value *mdargs[4] = { /*argdata*/a1, + ConstantPointerNull::get(cast(T_prjlvalue)), /*code*/a2, /*module*/literal_pointer_val(ctx, (jl_value_t*)ctx.module) }; diff --git a/src/dump.c b/src/dump.c index 2a6c8318c2095..5771776a8658d 100644 --- a/src/dump.c +++ b/src/dump.c @@ -504,6 +504,11 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_ jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque); } +enum METHOD_SERIALIZATION_MODE { + METHOD_INTERNAL = 1, + METHOD_EXTERNAL_MT = 2, +}; + static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED { if (jl_serialize_generic(s, v)) { @@ -627,9 +632,10 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li else if (jl_is_method(v)) { write_uint8(s->s, TAG_METHOD); jl_method_t *m = (jl_method_t*)v; - int internal = 1; - internal = m->is_for_opaque_closure || module_in_worklist(m->module); - if (!internal) { + int serialization_mode = 0; + if (m->is_for_opaque_closure || module_in_worklist(m->module)) + serialization_mode |= METHOD_INTERNAL; + if (!(serialization_mode & METHOD_INTERNAL)) { // flag this in the backref table as special uintptr_t *bp = (uintptr_t*)ptrhash_bp(&backref_table, v); assert(*bp != (uintptr_t)HT_NOTFOUND); @@ -637,8 +643,23 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li } jl_serialize_value(s, (jl_value_t*)m->sig); jl_serialize_value(s, (jl_value_t*)m->module); - write_uint8(s->s, internal); - if (!internal) + if (m->external_mt != NULL) { + assert(jl_typeis(m->external_mt, jl_methtable_type)); + jl_methtable_t *mt = (jl_methtable_t*)m->external_mt; + if (!module_in_worklist(mt->module)) { + serialization_mode |= METHOD_EXTERNAL_MT; + } + } + write_uint8(s->s, serialization_mode); + if (serialization_mode & METHOD_EXTERNAL_MT) { + // We reference this method table by module and binding + jl_methtable_t *mt = (jl_methtable_t*)m->external_mt; + jl_serialize_value(s, mt->module); + jl_serialize_value(s, mt->name); + } else { + jl_serialize_value(s, (jl_value_t*)m->external_mt); + } + if (!(serialization_mode & METHOD_INTERNAL)) return; jl_serialize_value(s, m->specializations); jl_serialize_value(s, m->speckeyset); @@ -952,6 +973,10 @@ static void jl_collect_lambdas_from_mod(jl_array_t *s, jl_module_t *m) JL_GC_DIS jl_collect_lambdas_from_mod(s, (jl_module_t*)b->value); } } + else if (jl_is_mtable(bv)) { + // a module containing an external method table + jl_collect_methtable_from_mod(s, (jl_methtable_t*)bv); + } } } } @@ -1015,7 +1040,7 @@ static void jl_collect_backedges(jl_array_t *s, jl_array_t *t) size_t min_valid = 0; size_t max_valid = ~(size_t)0; int ambig = 0; - jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig); + jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig); if (matches == jl_false) { valid = 0; break; @@ -1458,8 +1483,18 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_ jl_gc_wb(m, m->sig); m->module = (jl_module_t*)jl_deserialize_value(s, (jl_value_t**)&m->module); jl_gc_wb(m, m->module); - int internal = read_uint8(s->s); - if (!internal) { + int serialization_mode = read_uint8(s->s); + if (serialization_mode & METHOD_EXTERNAL_MT) { + jl_module_t *mt_mod = (jl_module_t*)jl_deserialize_value(s, NULL); + jl_sym_t *mt_name = (jl_sym_t*)jl_deserialize_value(s, NULL); + m->external_mt = jl_get_global(mt_mod, mt_name); + jl_gc_wb(m, m->external_mt); + assert(jl_typeis(m->external_mt, jl_methtable_type)); + } else { + m->external_mt = jl_deserialize_value(s, &m->external_mt); + jl_gc_wb(m, m->external_mt); + } + if (!(serialization_mode & METHOD_INTERNAL)) { assert(loc != NULL && loc != HT_NOTFOUND); arraylist_push(&flagref_list, loc); arraylist_push(&flagref_list, (void*)pos); @@ -1897,7 +1932,7 @@ static void jl_insert_methods(jl_array_t *list) assert(!meth->is_for_opaque_closure); jl_tupletype_t *simpletype = (jl_tupletype_t*)jl_array_ptr_ref(list, i + 1); assert(jl_is_method(meth)); - jl_methtable_t *mt = jl_method_table_for((jl_value_t*)meth->sig); + jl_methtable_t *mt = jl_method_get_table(meth); assert((jl_value_t*)mt != jl_nothing); jl_method_table_insert(mt, meth, simpletype); } @@ -1927,7 +1962,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) size_t max_valid = ~(size_t)0; int ambig = 0; // TODO: possibly need to included ambiguities too (for the optimizer correctness)? - jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig); + jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_world_counter, &min_valid, &max_valid, &ambig); if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) { valid = 0; } @@ -2465,7 +2500,7 @@ static jl_method_t *jl_recache_method(jl_method_t *m) { assert(!m->is_for_opaque_closure); jl_datatype_t *sig = (jl_datatype_t*)m->sig; - jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig); + jl_methtable_t *mt = jl_method_get_table(m); assert((jl_value_t*)mt != jl_nothing); jl_set_typeof(m, (void*)(intptr_t)0x30); // invalidate the old value to help catch errors jl_method_t *_new = jl_lookup_method(mt, sig, m->module->primary_world); diff --git a/src/gf.c b/src/gf.c index 5bdf7c8ec29b0..3cda8ae2a733c 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1876,7 +1876,7 @@ jl_method_instance_t *jl_method_lookup(jl_value_t **args, size_t nargs, size_t w // // lim is the max # of methods to return. if there are more, returns jl_false. // -1 for no limit. -JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, int lim, int include_ambiguous, +JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *mt, int lim, int include_ambiguous, size_t world, size_t *min_valid, size_t *max_valid, int *ambig) { JL_TIMING(METHOD_MATCH); @@ -1885,10 +1885,13 @@ JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, int lim, int jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types); if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type) return (jl_value_t*)jl_an_empty_vec_any; - jl_methtable_t *mt = jl_method_table_for(unw); + if (mt == jl_nothing) + mt = (jl_value_t*)jl_method_table_for(unw); + else if (!jl_typeis(mt, jl_methtable_type)) + jl_error("matching_method: `mt` is not a method table"); if ((jl_value_t*)mt == jl_nothing) return jl_false; // indeterminate - ml_matches can't deal with this case - return ml_matches(mt, 0, types, lim, include_ambiguous, 1, world, 1, min_valid, max_valid, ambig); + return ml_matches((jl_methtable_t*)mt, 0, types, lim, include_ambiguous, 1, world, 1, min_valid, max_valid, ambig); } jl_method_instance_t *jl_get_unspecialized(jl_method_instance_t *method JL_PROPAGATES_ROOT) @@ -2067,7 +2070,7 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES size_t min_valid2 = 1; size_t max_valid2 = ~(size_t)0; int ambig = 0; - jl_value_t *matches = jl_matching_methods(types, 1, 1, world, &min_valid2, &max_valid2, &ambig); + jl_value_t *matches = jl_matching_methods(types, jl_nothing, 1, 1, world, &min_valid2, &max_valid2, &ambig); if (*min_valid < min_valid2) *min_valid = min_valid2; if (*max_valid > max_valid2) diff --git a/src/interpreter.c b/src/interpreter.c index 008886f1c99c9..f130c311c1a45 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -78,28 +78,36 @@ static jl_value_t *eval_body(jl_array_t *stmts, interpreter_state *s, size_t ip, static jl_value_t *eval_methoddef(jl_expr_t *ex, interpreter_state *s) { jl_value_t **args = jl_array_ptr_data(ex->args); - jl_sym_t *fname = (jl_sym_t*)args[0]; - jl_module_t *modu = s->module; - if (jl_is_globalref(fname)) { - modu = jl_globalref_mod(fname); - fname = jl_globalref_name(fname); - } - assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(fname)); - if (jl_is_symbol(fname)) { + if (jl_expr_nargs(ex) == 1) { + jl_value_t **args = jl_array_ptr_data(ex->args); + jl_sym_t *fname = (jl_sym_t*)args[0]; + jl_module_t *modu = s->module; + if (jl_is_globalref(fname)) { + modu = jl_globalref_mod(fname); + fname = jl_globalref_name(fname); + } + if (!jl_is_symbol(fname)) { + jl_error("method: invalid declaration"); + } jl_value_t *bp_owner = (jl_value_t*)modu; jl_binding_t *b = jl_get_binding_for_method_def(modu, fname); jl_value_t **bp = &b->value; jl_value_t *gf = jl_generic_function_def(b->name, b->owner, bp, bp_owner, b); - if (jl_expr_nargs(ex) == 1) - return gf; + return gf; } - jl_value_t *atypes = NULL, *meth = NULL; - JL_GC_PUSH2(&atypes, &meth); + jl_value_t *atypes = NULL, *meth = NULL, *fname = NULL; + JL_GC_PUSH3(&atypes, &meth, &fname); + + fname = eval_value(args[0], s); + jl_methtable_t *mt = NULL; + if (jl_typeis(fname, jl_methtable_type)) { + mt = (jl_methtable_t*)fname; + } atypes = eval_value(args[1], s); meth = eval_value(args[2], s); - jl_method_def((jl_svec_t*)atypes, (jl_code_info_t*)meth, s->module); + jl_method_def((jl_svec_t*)atypes, mt, (jl_code_info_t*)meth, s->module); JL_GC_POP(); return jl_nothing; } diff --git a/src/jltypes.c b/src/jltypes.c index 2ba69caf1991e..73015d2cfe6c1 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2302,6 +2302,7 @@ void jl_init_types(void) JL_GC_DISABLED "specializations", "speckeyset", "slot_syms", + "external_mt", "source", "unspecialized", "generator", @@ -2329,6 +2330,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_array_type, jl_string_type, jl_any_type, + jl_any_type, jl_any_type, // jl_method_instance_type jl_any_type, jl_array_any_type, @@ -2535,7 +2537,7 @@ void jl_init_types(void) JL_GC_DISABLED #endif jl_svecset(jl_methtable_type->types, 10, jl_uint8_type); jl_svecset(jl_methtable_type->types, 11, jl_uint8_type); - jl_svecset(jl_method_type->types, 11, jl_method_instance_type); + jl_svecset(jl_method_type->types, 12, jl_method_instance_type); jl_svecset(jl_method_instance_type->types, 6, jl_code_instance_type); jl_svecset(jl_code_instance_type->types, 9, jl_voidpointer_type); jl_svecset(jl_code_instance_type->types, 10, jl_voidpointer_type); diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index ca274ef552f5b..dfec6ca980020 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -221,10 +221,11 @@ (define (method-expr-name m) (let ((name (cadr m))) + (let ((name (if (or (length= m 2) (not (pair? name)) (not (quoted? name))) name (cadr name)))) (cond ((not (pair? name)) name) ((eq? (car name) 'outerref) (cadr name)) ;((eq? (car name) 'globalref) (caddr name)) - (else name)))) + (else name))))) ;; extract static parameter names from a (method ...) expression (define (method-expr-static-parameters m) @@ -252,6 +253,13 @@ (pair? (caddr e)) (memq (car (caddr e)) '(quote inert)) (symbol? (cadr (caddr e)))))) +(define (overlay? e) + (and (pair? e) (eq? (car e) 'overlay))) + +(define (sym-ref-or-overlay? e) + (or (overlay? e) + (sym-ref? e))) + ;; convert final (... x) to (curly Vararg x) (define (dots->vararg a) (if (null? a) a @@ -341,14 +349,15 @@ (let* ((names (map car sparams)) (anames (map (lambda (x) (if (underscore-symbol? x) UNUSED x)) (llist-vars argl))) (unused_anames (filter (lambda (x) (not (eq? x UNUSED))) anames)) - (ename (if (nodot-sym-ref? name) name `(null)))) + (ename (if (nodot-sym-ref? name) name + (if (overlay? name) (cadr name) `(null))))) (if (has-dups unused_anames) (error (string "function argument name not unique: \"" (car (has-dups unused_anames)) "\""))) (if (has-dups names) (error "function static parameter names not unique")) (if (any (lambda (x) (and (not (eq? x UNUSED)) (memq x names))) anames) (error "function argument and static parameter names must be distinct")) - (if (or (and name (not (sym-ref? name))) (not (valid-name? name))) + (if (or (and name (not (sym-ref-or-overlay? name))) (not (valid-name? name))) (error (string "invalid function name \"" (deparse name) "\""))) (let* ((loc (maybe-remove-functionloc! body)) (generator (if (expr-contains-p if-generated? body (lambda (x) (not (function-def? x)))) @@ -1130,13 +1139,14 @@ (argl-stmts (lower-destructuring-args argl)) (argl (car argl-stmts)) (name (check-dotop (car argl))) + (argname (if (overlay? name) (caddr name) name)) ;; fill in first (closure) argument (adj-decl (lambda (n) (if (and (decl? n) (length= n 2)) `(|::| |#self#| ,(cadr n)) n))) - (farg (if (decl? name) - (adj-decl name) - `(|::| |#self#| (call (core Typeof) ,name)))) + (farg (if (decl? argname) + (adj-decl argname) + `(|::| |#self#| (call (core Typeof) ,argname)))) (body (insert-after-meta body (cdr argl-stmts))) (argl (cdr argl)) (argl (fix-arglist diff --git a/src/julia.h b/src/julia.h index 8ec461e77126b..ed1cb5156285a 100644 --- a/src/julia.h +++ b/src/julia.h @@ -311,6 +311,7 @@ typedef struct _jl_method_t { jl_array_t *speckeyset; // index lookup by hash into specializations jl_value_t *slot_syms; // compacted list of slot names (String) + jl_value_t *external_mt; // reference to the method table this method is part of, null if part of the internal table jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null jl_value_t *generator; // executable code-generating function if available @@ -1400,7 +1401,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, jl_module_t *module, jl_value_t **bp, jl_value_t *bp_owner, jl_binding_t *bnd); -JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata, jl_code_info_t *f, jl_module_t *module); +JL_DLLEXPORT jl_method_t *jl_method_def(jl_svec_t *argdata, jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module); JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo); JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src); JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT; diff --git a/src/julia_internal.h b/src/julia_internal.h index d6170a9d56719..ec45080bcf908 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -548,13 +548,15 @@ jl_method_instance_t *jl_method_lookup(jl_value_t **args, size_t nargs, size_t w jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value_t **args, size_t nargs); jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t *f, jl_value_t **args, size_t nargs); -JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, int lim, int include_ambiguous, +JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *mt, int lim, int include_ambiguous, size_t world, size_t *min_valid, size_t *max_valid, int *ambig); JL_DLLEXPORT jl_datatype_t *jl_first_argument_datatype(jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_argument_datatype(jl_value_t *argt JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_methtable_t *jl_method_table_for( jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; +JL_DLLEXPORT jl_methtable_t *jl_method_get_table( + jl_method_t *method) JL_NOTSAFEPOINT; jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAGATES_ROOT); int jl_pointer_egal(jl_value_t *t); diff --git a/src/method.c b/src/method.c index 1d3a593e638ed..f771af9f79603 100644 --- a/src/method.c +++ b/src/method.c @@ -652,6 +652,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module) m->roots = NULL; m->ccallable = NULL; m->module = module; + m->external_mt = NULL; m->source = NULL; m->unspecialized = NULL; m->generator = NULL; @@ -764,6 +765,11 @@ JL_DLLEXPORT jl_methtable_t *jl_method_table_for(jl_value_t *argtypes JL_PROPAGA return first_methtable(argtypes, 0); } +JL_DLLEXPORT jl_methtable_t *jl_method_get_table(jl_method_t *method JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT +{ + return method->external_mt ? (jl_methtable_t*)method->external_mt : jl_method_table_for(method->sig); +} + // get the MethodTable implied by a single given type, or `nothing` JL_DLLEXPORT jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT { @@ -773,6 +779,7 @@ JL_DLLEXPORT jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAG jl_array_t *jl_all_methods JL_GLOBALLY_ROOTED; JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata, + jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module) { @@ -801,7 +808,9 @@ JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata, argtype = jl_new_struct(jl_unionall_type, tv, argtype); } - jl_methtable_t *mt = jl_method_table_for(argtype); + jl_methtable_t *external_mt = mt; + if (!mt) + mt = jl_method_table_for(argtype); if ((jl_value_t*)mt == jl_nothing) jl_error("Method dispatch is unimplemented currently for this method signature"); if (mt->frozen) @@ -830,6 +839,9 @@ JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata, f = jl_new_code_info_from_ir((jl_expr_t*)f); } m = jl_new_method_uninit(module); + m->external_mt = (jl_value_t*)external_mt; + if (external_mt) + jl_gc_wb(m, external_mt); m->sig = argtype; m->name = name; m->isva = isva; diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 060dbb2bc011f..14d3336b34caf 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -434,6 +434,9 @@ function serialize(s::AbstractSerializer, meth::Method) else serialize(s, nothing) end + if isdefined(meth, :external_mt) + error("cannot serialize Method objects with external method tables") + end nothing end diff --git a/stdlib/Test/src/Test.jl b/stdlib/Test/src/Test.jl index 9a16d3d25c9be..674893051fdac 100644 --- a/stdlib/Test/src/Test.jl +++ b/stdlib/Test/src/Test.jl @@ -1594,7 +1594,7 @@ function detect_ambiguities(mods::Module...; for m in Base.MethodList(mt) is_in_mods(m.module, recursive, mods) || continue ambig = Int32[0] - ms = Base._methods_by_ftype(m.sig, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) + ms = Base._methods_by_ftype(m.sig, nothing, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) ambig[1] == 0 && continue isa(ms, Bool) && continue for match2 in ms diff --git a/test/ambiguous.jl b/test/ambiguous.jl index 8d0c2092f21c2..265d97776c053 100644 --- a/test/ambiguous.jl +++ b/test/ambiguous.jl @@ -348,7 +348,7 @@ f35983(::Type, ::Type) = 2 @test length(Base.methods(f35983, (Any, Any))) == 2 @test first(Base.methods(f35983, (Any, Any))).sig == Tuple{typeof(f35983), Type, Type} let ambig = Int32[0] - ms = Base._methods_by_ftype(Tuple{typeof(f35983), Type, Type}, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) + ms = Base._methods_by_ftype(Tuple{typeof(f35983), Type, Type}, nothing, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) @test length(ms) == 1 @test ambig[1] == 0 end @@ -356,14 +356,14 @@ f35983(::Type{Int16}, ::Any) = 3 @test length(Base.methods_including_ambiguous(f35983, (Type, Type))) == 2 @test length(Base.methods(f35983, (Type, Type))) == 2 let ambig = Int32[0] - ms = Base._methods_by_ftype(Tuple{typeof(f35983), Type, Type}, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) + ms = Base._methods_by_ftype(Tuple{typeof(f35983), Type, Type}, nothing, -1, typemax(UInt), true, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) @test length(ms) == 2 @test ambig[1] == 1 end struct B38280 <: Real; val; end let ambig = Int32[0] - ms = Base._methods_by_ftype(Tuple{Type{B38280}, Any}, 1, typemax(UInt), false, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) + ms = Base._methods_by_ftype(Tuple{Type{B38280}, Any}, nothing, 1, typemax(UInt), false, UInt[typemin(UInt)], UInt[typemax(UInt)], ambig) @test ms isa Vector @test length(ms) == 1 @test ambig[1] == 1 diff --git a/test/compiler/contextual.jl b/test/compiler/contextual.jl index bce782d73df87..afbfa1db601f6 100644 --- a/test/compiler/contextual.jl +++ b/test/compiler/contextual.jl @@ -135,3 +135,79 @@ let method = which(func2, ()) end func3() = func2() @test_throws UndefVarError func3() + + + +## overlay method tables + +module OverlayModule + +using Base.Experimental: @MethodTable, @overlay + +@MethodTable(mt) + +@overlay mt function sin(x::Float64) + 1 +end + +# short function def +@overlay mt cos(x::Float64) = 2 + +end + +methods = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, nothing, 1, typemax(UInt)) +@test only(methods).method.module === Base.Math + +methods = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, OverlayModule.mt, 1, typemax(UInt)) +@test only(methods).method.module === OverlayModule + +methods = Base._methods_by_ftype(Tuple{typeof(sin), Int}, OverlayModule.mt, 1, typemax(UInt)) +@test isempty(methods) + +# precompilation + +load_path = mktempdir() +depot_path = mktempdir() +try + pushfirst!(LOAD_PATH, load_path) + pushfirst!(DEPOT_PATH, depot_path) + + write(joinpath(load_path, "Foo.jl"), + """ + module Foo + Base.Experimental.@MethodTable(mt) + Base.Experimental.@overlay mt sin(x::Int) = 1 + end + """) + + # precompiling Foo serializes the overlay method through the `mt` binding in the module + Foo = Base.require(Main, :Foo) + @test length(Foo.mt) == 1 + + write(joinpath(load_path, "Bar.jl"), + """ + module Bar + Base.Experimental.@MethodTable(mt) + end + """) + + write(joinpath(load_path, "Baz.jl"), + """ + module Baz + using Bar + Base.Experimental.@overlay Bar.mt sin(x::Int) = 1 + end + """) + + # when referring an method table in another module, + # the overlay method needs to be discovered explicitly + Bar = Base.require(Main, :Bar) + @test length(Bar.mt) == 0 + Baz = Base.require(Main, :Baz) + @test length(Bar.mt) == 1 +finally + rm(load_path, recursive=true, force=true) + rm(depot_path, recursive=true, force=true) + filter!((≠)(load_path), LOAD_PATH) + filter!((≠)(depot_path), DEPOT_PATH) +end