Add AbstractInterpreter to parameterize compilation pipeline
This allows selective overriding of the compilation pipeline through
multiple dispatch, enabling projects like `XLA.jl` to maintain separate
inference caches, inference algorithms, or heuristics while
inferring and lowering code.  In particular, it defines a new type,
`AbstractInterpreter`, that represents an abstract interpretation
pipeline.  This `AbstractInterpreter` has a single defined concrete
subtype, `NativeInterpreter`, that represents the native Julia
compilation pipeline.  The `NativeInterpreter` contains within it all
the compiler parameters previously contained within `Params`, split into
two pieces: `InferenceParams` and `OptimizationParams`, used within type
inference and optimization, respectively.  The interpreter object is
then threaded throughout most of the type inference pipeline, and allows
for straightforward prototyping and replacement of the compiler
internals.
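
To make the new API concrete, here is a minimal sketch, using only the
constructor and accessor functions that appear later in this message, of how a
`NativeInterpreter` bundles what `Params` previously held:

```julia
using Core.Compiler: NativeInterpreter, InferenceParams, OptimizationParams,
                     get_world_counter, get_inference_cache

# Construct the native compilation pipeline object
interp = NativeInterpreter()

InferenceParams(interp)      # parameters consulted during type inference
OptimizationParams(interp)   # parameters consulted during optimization
get_world_counter(interp)    # world age the interpreter operates in
get_inference_cache(interp)  # inference cache associated with this interpreter
```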

As a simple example of the kind of workflow this enables, I include here
a small testing script showing how to count the number of methods inferred
during type inference by overriding just two functions within the compiler.
First, I define some helper methods to make working with inference a bit
easier:

```julia
using Core.Compiler
import Core.Compiler: InferenceParams, OptimizationParams, get_world_counter, get_inference_cache

"""
    @infer_function interp foo(1, 2) [show_steps=true] [show_ir=false]

Infer a function call using the given interpreter object and return
the resulting `InferenceResult`.  Set keyword arguments to modify verbosity:

* Set `show_steps` to `true` to see the `InferenceResult` step by step.
* Set `show_ir` to `true` to see the final type-inferred Julia IR.
"""
macro infer_function(interp, func_call, kwarg_exs...)
    if !isa(func_call, Expr) || func_call.head != :call
        error("@infer_function requires a function call")
    end

    local func = func_call.args[1]
    local args = func_call.args[2:end]
    # Collect keyword arguments as `Expr(:kw, name, value)` pairs, escaping the
    # values so they are evaluated in the caller's scope
    kwargs = []
    for ex in kwarg_exs
        if ex isa Expr && ex.head === :(=) && ex.args[1] isa Symbol
            push!(kwargs, Expr(:kw, ex.args[1], esc(ex.args[2])))
        else
            error("Invalid @infer_function kwarg $(ex)")
        end
    end
    return quote
        infer_function($(esc(interp)), $(esc(func)), typeof.(($(map(esc, args)...),)); $(kwargs...))
    end
end

function infer_function(interp, f, tt; show_steps::Bool=false, show_ir::Bool=false)
    # Find all methods that are applicable to these types
    fms = methods(f, tt)
    if length(fms) != 1
        error("Unable to find single applicable method for $f with types $tt")
    end

    # Take the first applicable method
    method = first(fms)

    # Build argument tuple
    method_args = Tuple{typeof(f), tt...}

    # Grab the appropriate method instance for these types
    mi = Core.Compiler.specialize_method(method, method_args, Core.svec())

    # Construct an InferenceResult to hold the result
    result = Core.Compiler.InferenceResult(mi)
    if show_steps
        @info("Initial result, before inference: ", result)
    end

    # Create an InferenceState to begin inference; the interpreter carries the
    # world age and inference cache with it
    frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)

    # Run type inference on this frame, passing in the interpreter that
    # parameterizes the pipeline.
    Core.Compiler.typeinf_local(interp, frame)
    if show_steps
        @info("Ending result, post-inference: ", result)
    end
    if show_ir
        @info("Inferred source: ", result.result.src)
    end

    # Give the result back
    return result
end
```

Next, we define a simple function and pass it through:
```julia
function foo(x, y)
    return x + y * x
end

native_interpreter = Core.Compiler.NativeInterpreter()
inferred = @infer_function native_interpreter foo(1.0, 2.0) show_steps=true show_ir=true
```

This gives a nice output such as the following:
```julia-repl
┌ Info: Initial result, before inference:
└   result = foo(::Float64, ::Float64) => Any
┌ Info: Ending result, post-inference:
└   result = foo(::Float64, ::Float64) => Float64
┌ Info: Inferred source:
│   result.result.src =
│    CodeInfo(
│        @ REPL[1]:3 within `foo'
│    1 ─ %1 = (y * x)::Float64
│    │   %2 = (x + %1)::Float64
│    └──      return %2
└    )
```

We can then define a custom `AbstractInterpreter` subtype that overrides two
specific pieces of the compilation process: the hooks that manage the runtime
inference cache.  While it transparently passes all other information through
to a bundled `NativeInterpreter`, it can force cache misses so that code is
re-inferred, letting us easily see how many methods (and which ones) would be
inferred to compile a given method:

```julia
struct CountingInterpreter <: Compiler.AbstractInterpreter
    visited_methods::Set{Core.Compiler.MethodInstance}
    methods_inferred::Ref{UInt64}

    # Keep a native interpreter around so that we can delegate to the "super" (native) behavior
    native_interpreter::Core.Compiler.NativeInterpreter
end
CountingInterpreter() = CountingInterpreter(
    Set{Core.Compiler.MethodInstance}(),
    Ref(UInt64(0)),
    Core.Compiler.NativeInterpreter(),
)

InferenceParams(ci::CountingInterpreter) = InferenceParams(ci.native_interpreter)
OptimizationParams(ci::CountingInterpreter) = OptimizationParams(ci.native_interpreter)
get_world_counter(ci::CountingInterpreter) = get_world_counter(ci.native_interpreter)
get_inference_cache(ci::CountingInterpreter) = get_inference_cache(ci.native_interpreter)

function Core.Compiler.inf_for_methodinstance(interp::CountingInterpreter, mi::Core.Compiler.MethodInstance, min_world::UInt, max_world::UInt=min_world)
    # Consult our own cache first; if we have already seen this method instance, defer to the native cache
    if mi in interp.visited_methods
        return Core.Compiler.inf_for_methodinstance(interp.native_interpreter, mi, min_world, max_world)
    end

    # Otherwise, we return `nothing`, forcing a cache miss
    return nothing
end

function Core.Compiler.cache_result(interp::CountingInterpreter, result::Core.Compiler.InferenceResult, min_valid::UInt, max_valid::UInt)
    push!(interp.visited_methods, result.linfo)
    interp.methods_inferred[] += 1
    return Core.Compiler.cache_result(interp.native_interpreter, result, min_valid, max_valid)
end

function reset!(interp::CountingInterpreter)
    empty!(interp.visited_methods)
    interp.methods_inferred[] = 0
    return nothing
end
```

Running it on our testing function:
```julia
counting_interpreter = CountingInterpreter()
inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
inferred = @infer_function counting_interpreter foo(1, 2) show_ir=true
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")

inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
reset!(counting_interpreter)

@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
inferred = @infer_function counting_interpreter foo(1.0, 2.0)
@info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
```

This also gives us a nice result:
```
[ Info: Cumulative number of methods inferred: 2
┌ Info: Inferred source:
│   result.result.src =
│    CodeInfo(
│        @ /Users/sabae/src/julia-compilerhack/AbstractInterpreterTest.jl:81 within `foo'
│    1 ─ %1 = (y * x)::Int64
│    │   %2 = (x + %1)::Int64
│    └──      return %2
└    )
[ Info: Cumulative number of methods inferred: 4
[ Info: Cumulative number of methods inferred: 4
[ Info: Cumulative number of methods inferred: 0
[ Info: Cumulative number of methods inferred: 2
```
staticfloat authored and Keno committed May 10, 2020
1 parent 8f512f3 commit c5fcd73
Showing 18 changed files with 392 additions and 269 deletions.
147 changes: 75 additions & 72 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions base/compiler/bootstrap.jl
@@ -6,7 +6,9 @@
# since we won't be able to specialize & infer them at runtime

let fs = Any[typeinf_ext, typeinf, typeinf_edge, pure_eval_call, run_passes],
world = get_world_counter()
world = get_world_counter(),
interp = NativeInterpreter(world)

for x in T_FFUNC_VAL
push!(fs, x[3])
end
@@ -27,7 +29,7 @@ let fs = Any[typeinf_ext, typeinf, typeinf_edge, pure_eval_call, run_passes],
typ[i] = typ[i].ub
end
end
typeinf_type(m[3], Tuple{typ...}, m[2], Params(world))
typeinf_type(interp, m[3], Tuple{typ...}, m[2])
end
end
end
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
@@ -94,11 +94,11 @@ using .Sort
# compiler #
############

include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")

include("compiler/inferenceresult.jl")
include("compiler/params.jl")
include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
12 changes: 0 additions & 12 deletions base/compiler/inferenceresult.jl
@@ -2,18 +2,6 @@

const EMPTY_VECTOR = Vector{Any}()

mutable struct InferenceResult
linfo::MethodInstance
argtypes::Vector{Any}
overridden_by_const::BitVector
result # ::Type, or InferenceState if WIP
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes)
return new(linfo, argtypes, overridden_by_const, Any, nothing)
end
end

function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
overridden_by_const::Bool)
18 changes: 9 additions & 9 deletions base/compiler/inferencestate.jl
@@ -3,7 +3,7 @@
const LineNum = Int

mutable struct InferenceState
params::Params # describes how to compute the result
params::InferenceParams
result::InferenceResult # remember where to put the result
linfo::MethodInstance
sptypes::Vector{Any} # types of static parameter
@@ -13,6 +13,7 @@ mutable struct InferenceState

# info on the state of inference and the linfo
src::CodeInfo
world::UInt
min_valid::UInt
max_valid::UInt
nargs::Int
@@ -47,7 +48,7 @@

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult, src::CodeInfo,
cached::Bool, params::Params)
cached::Bool, interp::AbstractInterpreter)
linfo = result.linfo
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)
@@ -95,9 +96,9 @@ mutable struct InferenceState
max_valid = src.max_world == typemax(UInt) ?
get_world_counter() : src.max_world
frame = new(
params, result, linfo,
InferenceParams(interp), result, linfo,
sp, slottypes, inmodule, 0,
src, min_valid, max_valid,
src, get_world_counter(interp), min_valid, max_valid,
nargs, s_types, s_edges,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
@@ -108,17 +109,17 @@
cached, false, false, false,
IdDict{Any, Tuple{Any, UInt, UInt}}())
result.result = frame
cached && push!(params.cache, result)
cached && push!(get_inference_cache(interp), result)
return frame
end
end

function InferenceState(result::InferenceResult, cached::Bool, params::Params)
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
# prepare an InferenceState object for inferring lambda
src = retrieve_code_info(result.linfo)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
return InferenceState(result, src, cached, params)
return InferenceState(result, src, cached, interp)
end

function sptypes_from_meth_instance(linfo::MethodInstance)
@@ -195,8 +196,7 @@ _topmod(sv::InferenceState) = _topmod(sv.mod)
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.params.world <= sv.max_valid,
"invalid age range update")
@assert(sv.min_valid <= sv.world <= sv.max_valid, "invalid age range update")
nothing
end

50 changes: 26 additions & 24 deletions base/compiler/optimize.jl
@@ -5,36 +5,38 @@
#####################

mutable struct OptimizationState
params::OptimizationParams
linfo::MethodInstance
calledges::Vector{Any}
src::CodeInfo
mod::Module
nargs::Int
world::UInt
min_valid::UInt
max_valid::UInt
params::Params
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
# cached results of calling `_methods_by_ftype` from inference, including
# `min_valid` and `max_valid`
matching_methods_cache::IdDict{Any, Tuple{Any, UInt, UInt}}
function OptimizationState(frame::InferenceState)
# TODO: This will be eliminated once optimization no longer needs to do method lookups
interp::AbstractInterpreter
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
s_edges = frame.stmt_edges[1]
if s_edges === nothing
s_edges = []
frame.stmt_edges[1] = s_edges
end
src = frame.src
return new(frame.linfo,
return new(params, frame.linfo,
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
frame.min_valid, frame.max_valid,
frame.params, frame.sptypes, frame.slottypes, false,
frame.matching_methods_cache)
frame.world, frame.min_valid, frame.max_valid,
frame.sptypes, frame.slottypes, false,
frame.matching_methods_cache, interp)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
params::Params)
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
# prepare src for running optimization passes
# if it isn't already
nssavalues = src.ssavaluetypes
@@ -57,19 +59,19 @@ mutable struct OptimizationState
inmodule = linfo.def::Module
nargs = 0
end
return new(linfo,
return new(params, linfo,
s_edges::Vector{Any},
src, inmodule, nargs,
UInt(1), get_world_counter(),
params, sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt}}())
get_world_counter(), UInt(1), get_world_counter(),
sptypes_from_meth_instance(linfo), slottypes, false,
IdDict{Any, Tuple{Any, UInt, UInt}}(), interp)
end
end

function OptimizationState(linfo::MethodInstance, params::Params)
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
src = retrieve_code_info(linfo)
src === nothing && return nothing
return OptimizationState(linfo, src, params)
return OptimizationState(linfo, src, params, interp)
end


@@ -109,7 +111,7 @@ _topmod(sv::OptimizationState) = _topmod(sv.mod)
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
@assert(sv.min_valid <= sv.params.world <= sv.max_valid,
@assert(sv.min_valid <= sv.world <= sv.max_valid,
"invalid age range update")
nothing
end
@@ -127,10 +129,10 @@ function add_backedge!(li::CodeInstance, caller::OptimizationState)
nothing
end

function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
function isinlineable(m::Method, me::OptimizationState, params::OptimizationParams, bonus::Int=0)
# compute the cost (size) of inlining this code
inlineable = false
cost_threshold = me.params.inline_cost_threshold
cost_threshold = params.inline_cost_threshold
if m.module === _topmod(m.module)
# a few functions get special treatment
name = m.name
@@ -145,7 +147,7 @@ function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
end
end
if !inlineable
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, me.params, cost_threshold + bonus)
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, params, cost_threshold + bonus)
end
return inlineable
end
@@ -168,7 +170,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
end

# run the optimization work
function optimize(opt::OptimizationState, @nospecialize(result))
function optimize(opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
def = opt.linfo.def
nargs = Int(opt.nargs) - 1
@timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
@@ -247,13 +249,13 @@ function optimize(opt::OptimizationState, @nospecialize(result))
else
bonus = 0
if result ⊑ Tuple && !isbitstype(widenconst(result))
bonus = opt.params.inline_tupleret_bonus
bonus = params.inline_tupleret_bonus
end
if opt.src.inlineable
# For functions declared @inline, increase the cost threshold 20x
bonus += opt.params.inline_cost_threshold*19
bonus += params.inline_cost_threshold*19
end
opt.src.inlineable = isinlineable(def, opt, bonus)
opt.src.inlineable = isinlineable(def, opt, params, bonus)
end
end
nothing
@@ -282,7 +284,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
# known return type
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))

function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::Params)
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams)
head = ex.head
if is_meta_expr_head(head)
return 0
@@ -372,7 +374,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
end

function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
params::Params, cost_threshold::Integer=params.inline_cost_threshold)
params::OptimizationParams, cost_threshold::Integer=params.inline_cost_threshold)
bodycost::Int = 0
for line = 1:length(body)
stmt = body[line]
72 changes: 0 additions & 72 deletions base/compiler/params.jl

This file was deleted.

