Skip to content

Commit

Permalink
make IRInterpretationState mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 24, 2023
1 parent 99464af commit 2d79359
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 59 deletions.
3 changes: 2 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -990,8 +990,9 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
mi_cache = WorldView(code_cache(interp), world)
code = get(mi_cache, mi, nothing)
if code !== nothing
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world, sv)
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world)
if irsv !== nothing
irsv.parent = sv
irinterp = switch_to_irinterp(interp)
rt, nothrow = ir_abstract_constant_propagation(irinterp, irsv)
@assert !(rt isa Conditional || rt isa MustAlias) "invalid lattice element returned from irinterp"
Expand Down
97 changes: 46 additions & 51 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,52 +644,47 @@ end
# =====================

# TODO add `result::InferenceResult` and put the irinterp result into the inference cache?
struct IRInterpretationState
method_info::MethodInfo
ir::IRCode
mi::MethodInstance
world::UInt
curridx::RefValue{Int}
argtypes_refined::Vector{Bool}
sptypes::Vector{VarState}
tpdum::TwoPhaseDefUseMap
ssa_refined::BitSet
lazydomtree::LazyDomtree
valid_worlds::RefValue{WorldRange}
edges::Vector{Any}
parent # ::AbsIntState
end

# AbsIntState
# ===========

const AbsIntState = Union{InferenceState,IRInterpretationState}
mutable struct IRInterpretationState
const method_info::MethodInfo
const ir::IRCode
const mi::MethodInstance
const world::UInt
curridx::Int
const argtypes_refined::Vector{Bool}
const sptypes::Vector{VarState}
const tpdum::TwoPhaseDefUseMap
const ssa_refined::BitSet
const lazydomtree::LazyDomtree
valid_worlds::WorldRange
const edges::Vector{Any}
parent # ::Union{Nothing,AbsIntState}

function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
world::UInt, min_world::UInt, max_world::UInt, parent::AbsIntState)
curridx = RefValue(1)
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(given_argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
world::UInt, min_world::UInt, max_world::UInt)
curridx = 1
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(given_argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi)
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
for i = 1:length(given_argtypes)]
empty!(ir.argtypes)
append!(ir.argtypes, given_argtypes)
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
ssa_refined = BitSet()
lazydomtree = LazyDomtree(ir)
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
edges = Any[]
parent = nothing
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazydomtree, valid_worlds, edges, parent)
end
given_argtypes = va_process_argtypes(optimizer_lattice(interp), given_argtypes, mi)
argtypes_refined = Bool[!(optimizer_lattice(interp), ir.argtypes[i], given_argtypes[i])
for i = 1:length(given_argtypes)]
empty!(ir.argtypes)
append!(ir.argtypes, given_argtypes)
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
ssa_refined = BitSet()
lazydomtree = LazyDomtree(ir)
valid_worlds = RefValue(WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world))
edges = Any[]
return IRInterpretationState(method_info, ir, mi, world, curridx, argtypes_refined,
ir.sptypes, tpdum, ssa_refined, lazydomtree,
valid_worlds, edges, parent)
end

function IRInterpretationState(interp::AbstractInterpreter,
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt, parent::AbsIntState)
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt)
@assert code.def === mi
src = @atomic :monotonic code.inferred
if isa(src, Vector{UInt8})
Expand All @@ -700,9 +695,14 @@ function IRInterpretationState(interp::AbstractInterpreter,
method_info = MethodInfo(src)
ir = inflate_ir(src, mi)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
src.min_world, src.max_world, parent)
src.min_world, src.max_world)
end

# AbsIntState
# ===========

const AbsIntState = Union{InferenceState,IRInterpretationState}

frame_instance(sv::InferenceState) = sv.linfo
frame_instance(sv::IRInterpretationState) = sv.mi

Expand Down Expand Up @@ -746,16 +746,11 @@ has_conditional(𝕃::AbstractLattice, ::InferenceState) = has_conditional(𝕃)
has_conditional(::AbstractLattice, ::IRInterpretationState) = false

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert sv.world in valid_worlds "invalid age range update"
return valid_worlds
end
function update_valid_age!(irsv::IRInterpretationState, valid_worlds::WorldRange)
valid_worlds = irsv.valid_worlds[] = intersect(valid_worlds, irsv.valid_worlds[])
@assert irsv.world in valid_worlds "invalid age range update"
return valid_worlds
end

"""
AbsIntStackUnwind(sv::AbsIntState)
Expand Down Expand Up @@ -809,13 +804,13 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci
end

get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx[]][:flag]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

add_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] |= flag
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx[]][:flag] |= flag
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] |= flag

sub_curr_ssaflag!(sv::InferenceState, flag::UInt8) = sv.src.ssaflags[sv.currpc] &= ~flag
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx[]][:flag] &= ~flag
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt8) = sv.ir.stmts[sv.curridx][:flag] &= ~flag

merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects) =
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
Expand Down
15 changes: 8 additions & 7 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
if is_constprop_edge_recursed(mi, irsv)
return Pair{Any,Bool}(nothing, is_nothrow(effects))
end
newirsv = IRInterpretationState(interp, code, mi, argtypes, world, irsv)
newirsv = IRInterpretationState(interp, code, mi, argtypes, world)
if newirsv !== nothing
newirsv.parent = irsv
return _ir_abstract_constant_propagation(interp, newirsv)
end
return Pair{Any,Bool}(nothing, is_nothrow(effects))
Expand All @@ -68,7 +69,7 @@ end
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
si = StmtInfo(true) # TODO better job here?
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
irsv.ir.stmts[irsv.curridx[]][:info] = info
irsv.ir.stmts[irsv.curridx][:info] = info
return RTEffects(rt, effects)
end

Expand Down Expand Up @@ -216,7 +217,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
stmts = bbs[bb].stmts
lstmt = last(stmts)
for idx = stmts
irsv.curridx[] = idx
irsv.curridx = idx
inst = ir.stmts[idx][:inst]
typ = ir.stmts[idx][:type]
any_refined = false
Expand Down Expand Up @@ -264,7 +265,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
stmts = bbs[bb].stmts
lstmt = last(stmts)
for idx = stmts
irsv.curridx[] = idx
irsv.curridx = idx
inst = ir.stmts[idx][:inst]
for ur in userefs(inst)
val = ur[]
Expand All @@ -288,7 +289,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
stmts = bbs[bb].stmts
lstmt = last(stmts)
for idx = stmts
irsv.curridx[] = idx
irsv.curridx = idx
inst = ir.stmts[idx][:inst]
for ur in userefs(inst)
val = ur[]
Expand All @@ -308,7 +309,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
end
while !isempty(stmt_ip)
idx = popfirst!(stmt_ip)
irsv.curridx[] = idx
irsv.curridx = idx
inst = ir.stmts[idx][:inst]
typ = ir.stmts[idx][:type]
if reprocess_instruction!(interp,
Expand Down Expand Up @@ -340,7 +341,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
end
end

if last(irsv.valid_worlds[]) >= get_world_counter()
if last(irsv.valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
store_backedges(frame_instance(irsv), irsv.edges)
Expand Down

0 comments on commit 2d79359

Please sign in to comment.