Skip to content

Commit

Permalink
EA: update EAUtils.jl for latest compiler changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Sep 13, 2023
1 parent 572fa50 commit 8fea7c5
Showing 1 changed file with 63 additions and 66 deletions.
129 changes: 63 additions & 66 deletions test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ const EA = CC.EscapeAnalysis
# entries
# -------

import Base: unwrap_unionall, rewrap_unionall
import InteractiveUtils: gen_call_with_extracted_types_and_kwargs
using Base: unwrap_unionall, rewrap_unionall
using InteractiveUtils: gen_call_with_extracted_types_and_kwargs

"""
@code_escapes [options...] f(args...)
Expand Down Expand Up @@ -38,18 +38,17 @@ Runs the escape analysis on optimized IR of a generic function call with the giv
"""
function code_escapes(@nospecialize(f), @nospecialize(types=Base.default_tt(f));
world::UInt = get_world_counter(),
interp::Core.Compiler.AbstractInterpreter = Core.Compiler.NativeInterpreter(world),
debuginfo::Symbol = :none,
optimize::Bool = true)
tt = Base.signature_type(f, types)
interp = EscapeAnalyzer(interp, tt, optimize)
interp = EscapeAnalyzer(world, tt, optimize)
results = Base.code_typed_by_type(tt; optimize=true, world, interp)
isone(length(results)) || throw(ArgumentError("`code_escapes` only supports single analysis result"))
return EscapeResult(interp.ir, interp.state, interp.linfo, debuginfo === :source)
return EscapeResult(interp.ir, interp.state, interp.mi, debuginfo === :source)
end

# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.)
__clear_cache!() = empty!(GLOBAL_CODE_CACHE)
__clear_cache!() = empty!(GLOBAL_EA_CODE_CACHE)

# AbstractInterpreter
# -------------------
Expand All @@ -59,13 +58,13 @@ import .CC:
AbstractInterpreter, NativeInterpreter, WorldView, WorldRange,
InferenceParams, OptimizationParams, get_world_counter, get_inference_cache, code_cache
# usings
import Core:
using Core:
CodeInstance, MethodInstance, CodeInfo
import .CC:
using .CC:
InferenceResult, OptimizationState, IRCode, copy as cccopy,
@timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, sroa_pass!,
adce_pass!, JLOptions, verify_ir, verify_linetable
import .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable
using .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable

# when working outside of Core.Compiler,
# cache entire escape state for later inspection and debugging
Expand All @@ -75,25 +74,31 @@ struct EscapeCache
ir::IRCode # preserved just for debugging purpose
end

mutable struct EscapeAnalyzer{State} <: AbstractInterpreter
native::NativeInterpreter
cache::IdDict{InferenceResult,EscapeCache}
entry_tt
optimize::Bool
mutable struct EscapeAnalyzer <: AbstractInterpreter
const world::UInt
const inf_params::InferenceParams
const opt_params::OptimizationParams
const inf_cache::Vector{InferenceResult}
const cache::IdDict{InferenceResult,EscapeCache}
const entry_tt
const optimize::Bool
ir::IRCode
state::State
linfo::MethodInstance
EscapeAnalyzer(native::NativeInterpreter, @nospecialize(tt), optimize::Bool) =
new{EscapeState}(native, IdDict{InferenceResult,EscapeCache}(), tt, optimize)
state::EscapeState
mi::MethodInstance
function EscapeAnalyzer(world::UInt, @nospecialize(tt), optimize::Bool)
inf_params = InferenceParams()
opt_params = OptimizationParams()
inf_cache = InferenceResult[]
return new(world, inf_params, opt_params, inf_cache, IdDict{InferenceResult,EscapeCache}(), tt, optimize)
end
end

CC.InferenceParams(interp::EscapeAnalyzer) = InferenceParams(interp.native)
CC.OptimizationParams(interp::EscapeAnalyzer) = OptimizationParams(interp.native)
CC.get_world_counter(interp::EscapeAnalyzer) = get_world_counter(interp.native)
CC.InferenceParams(interp::EscapeAnalyzer) = interp.inf_params
CC.OptimizationParams(interp::EscapeAnalyzer) = interp.opt_params
CC.get_world_counter(interp::EscapeAnalyzer) = interp.world
CC.get_inference_cache(interp::EscapeAnalyzer) = interp.inf_cache

CC.get_inference_cache(interp::EscapeAnalyzer) = get_inference_cache(interp.native)

const GLOBAL_CODE_CACHE = IdDict{MethodInstance,CodeInstance}()
const GLOBAL_EA_CODE_CACHE = IdDict{MethodInstance,CodeInstance}()

function CC.code_cache(interp::EscapeAnalyzer)
worlds = WorldRange(get_world_counter(interp))
Expand All @@ -102,36 +107,36 @@ end

struct GlobalCache end

CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_CODE_CACHE, mi)
CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_EA_CODE_CACHE, mi)

CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_CODE_CACHE, mi, default)
CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_EA_CODE_CACHE, mi, default)

CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_CODE_CACHE, mi)
CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_EA_CODE_CACHE, mi)

function CC.setindex!(wvc::WorldView{GlobalCache}, ci::CodeInstance, mi::MethodInstance)
GLOBAL_CODE_CACHE[mi] = ci
GLOBAL_EA_CODE_CACHE[mi] = ci
add_callback!(mi) # register the callback on invalidation
return nothing
end

function add_callback!(linfo)
if !isdefined(linfo, :callbacks)
linfo.callbacks = Any[invalidate_cache!]
function add_callback!(mi)
if !isdefined(mi, :callbacks)
mi.callbacks = Any[invalidate_cache!]
else
if !any(@nospecialize(cb)->cb===invalidate_cache!, linfo.callbacks)
push!(linfo.callbacks, invalidate_cache!)
if !any(@nospecialize(cb)->cb===invalidate_cache!, mi.callbacks)
push!(mi.callbacks, invalidate_cache!)
end
end
return nothing
end

function invalidate_cache!(replaced, max_world, depth = 0)
delete!(GLOBAL_CODE_CACHE, replaced)
delete!(GLOBAL_EA_CODE_CACHE, replaced)

if isdefined(replaced, :backedges)
for mi in replaced.backedges
mi = mi::MethodInstance
if !haskey(GLOBAL_CODE_CACHE, mi)
if !haskey(GLOBAL_EA_CODE_CACHE, mi)
continue # otherwise fall into infinite loop
end
invalidate_cache!(mi, max_world, depth+1)
Expand All @@ -140,9 +145,9 @@ function invalidate_cache!(replaced, max_world, depth = 0)
return nothing
end

function CC.optimize(interp::EscapeAnalyzer,
opt::OptimizationState, caller::InferenceResult)
ir = run_passes_with_ea(interp, opt.src, opt, caller)
function CC.optimize(interp::EscapeAnalyzer, opt::OptimizationState, caller::InferenceResult)
ir = run_passes_ipo_safe_with_ea(interp, opt.src, opt, caller)
CC.ipo_dataflow_analysis!(interp, ir, caller)
return CC.finish(interp, opt, ir, caller)
end

Expand Down Expand Up @@ -170,18 +175,18 @@ function cache_escapes!(interp::EscapeAnalyzer,
end

function get_escape_cache(interp::EscapeAnalyzer)
return function (linfo::Union{InferenceResult,MethodInstance})
if isa(linfo, InferenceResult)
ecache = get(interp.cache, linfo, nothing)
return function (mi::Union{InferenceResult,MethodInstance})
if isa(mi, InferenceResult)
ecache = get(interp.cache, mi, nothing)
else
ecache = get(GLOBAL_ESCAPE_CACHE, linfo, nothing)
ecache = get(GLOBAL_ESCAPE_CACHE, mi, nothing)
end
return ecache !== nothing ? ecache.cache : nothing
end
end

function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::OptimizationState,
caller::InferenceResult)
function run_passes_ipo_safe_with_ea(interp::EscapeAnalyzer,
ci::CodeInfo, sv::OptimizationState, caller::InferenceResult)
@timeit "convert" ir = convert_to_ircode(ci, sv)
@timeit "slot2reg" ir = slot2reg(ir, ci, sv)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
Expand All @@ -204,7 +209,7 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati
# return back the result
interp.ir = cccopy(ir)
interp.state = state
interp.linfo = sv.linfo
interp.mi = sv.linfo
end
@timeit "Inlining" ir = ssa_inlining_pass!(ir, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
Expand All @@ -220,11 +225,11 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati
# return back the result
interp.ir = cccopy(ir)
interp.state = state
interp.linfo = sv.linfo
interp.mi = sv.linfo
end
@timeit "SROA" ir = sroa_pass!(ir)
@timeit "ADCE" ir = adce_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
@timeit "SROA" ir = sroa_pass!(ir, sv.inlining)
@timeit "ADCE" ir = adce_pass!(ir, sv.inlining)
@timeit "compact 3" ir = compact!(ir, true)
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
end
Expand All @@ -234,8 +239,8 @@ end
# printing
# --------

import Core: Argument, SSAValue
import .CC: widenconst, singleton_type
using Core: Argument, SSAValue
using .CC: widenconst, singleton_type

Base.getindex(estate::EscapeState, @nospecialize(x)) = CC.getindex(estate, x)

Expand Down Expand Up @@ -274,24 +279,16 @@ function Base.show(io::IO, x::EscapeInfo)
printstyled(io, name; color)
end
end
function Base.show(io::IO, ::MIME"application/prs.juno.inline", x::EscapeInfo)
name, color = get_name_color(x)
if isnothing(name)
return x # use fancy tree-view
else
printstyled(io, name; color)
end
end

struct EscapeResult
ir::IRCode
state::EscapeState
linfo::Union{Nothing,MethodInstance}
mi::Union{Nothing,MethodInstance}
source::Bool
function EscapeResult(ir::IRCode, state::EscapeState,
linfo::Union{Nothing,MethodInstance} = nothing,
source::Bool=false)
return new(ir, state, linfo, source)
mi::Union{Nothing,MethodInstance} = nothing,
source::Bool=false)
return new(ir, state, mi, source)
end
end
Base.show(io::IO, result::EscapeResult) = print_with_info(io, result)
Expand All @@ -301,7 +298,7 @@ Base.show(io::IO, result::EscapeResult) = print_with_info(io, result)
Base.show(io::IO, cached::EscapeCache) = show(io, EscapeResult(cached.ir, cached.state, nothing))

# adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897
function print_with_info(io::IO, (; ir, state, linfo, source)::EscapeResult)
function print_with_info(io::IO, (; ir, state, mi, source)::EscapeResult)
# print escape information on SSA values
function preprint(io::IO)
ft = ir.argtypes[1]
Expand All @@ -318,8 +315,8 @@ function print_with_info(io::IO, (; ir, state, linfo, source)::EscapeResult)
i state.nargs && print(io, ", ")
end
print(io, ')')
if !isnothing(linfo)
def = linfo.def
if !isnothing(mi)
def = mi.def
printstyled(io, " in ", (isa(def, Module) ? (def,) : (def.module, " at ", def.file, ':', def.line))...; color=:bold)
end
println(io)
Expand Down

0 comments on commit 8fea7c5

Please sign in to comment.