Skip to content

Commit

Permalink
keep undefinedness of static parameters as a separate field
Browse files Browse the repository at this point in the history
This is an alternative to #46791.
Will be filed as a PR to check the performance difference.
  • Loading branch information
aviatesk committed Feb 7, 2023
1 parent 790e667 commit 99a272a
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 54 deletions.
8 changes: 2 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2192,11 +2192,7 @@ function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::
nothrow = false
if 1 <= n <= length(sv.sptypes)
rt = sv.sptypes[n]
if is_maybeundefsp(rt)
rt = unwrap_maybeundefsp(rt)
else
nothrow = true
end
nothrow = !sv.spundefs[n]
end
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; nothrow))
return rt
Expand Down Expand Up @@ -2460,7 +2456,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
elseif isexpr(sym, :static_parameter)
n = sym.args[1]::Int
if 1 <= n <= length(sv.sptypes)
if !is_maybeundefsp(sv.sptypes, n)
if !sv.spundefs[n]
t = Const(true)
end
end
Expand Down
31 changes: 7 additions & 24 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ mutable struct InferenceState
world::UInt
mod::Module
sptypes::Vector{Any}
spundefs::BitVector
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG
Expand Down Expand Up @@ -141,7 +142,7 @@ mutable struct InferenceState
world = get_world_counter(interp)
def = linfo.def
mod = isa(def, Method) ? def.module : def
sptypes = sptypes_from_meth_instance(linfo)
sptypes, spundefs = sptypes_from_meth_instance(linfo)
code = src.code::Vector{Any}
cfg = compute_basic_blocks(code)

Expand Down Expand Up @@ -185,7 +186,7 @@ mutable struct InferenceState
cached = cache === :global

frame = new(
linfo, world, mod, sptypes, slottypes, src, cfg,
linfo, world, mod, sptypes, spundefs, slottypes, src, cfg,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
Expand Down Expand Up @@ -401,25 +402,7 @@ function constrains_param(var::TypeVar, @nospecialize(typ), covariant::Bool)
return false
end

"""
MaybeUndefSP(typ)
is_maybeundefsp(typ) -> Bool
unwrap_maybeundefsp(typ) -> Any
A special wrapper that represents a static parameter that could be undefined at runtime.
This does not participate in the native type system nor the inference lattice,
and it thus should be always unwrapped when performing any type or lattice operations on it.
"""
struct MaybeUndefSP
typ
MaybeUndefSP(@nospecialize typ) = new(typ)
end
is_maybeundefsp(@nospecialize typ) = isa(typ, MaybeUndefSP)
unwrap_maybeundefsp(@nospecialize typ) = isa(typ, MaybeUndefSP) ? typ.typ : typ
is_maybeundefsp(sptypes::Vector{Any}, idx::Int) = is_maybeundefsp(sptypes[idx])
unwrap_maybeundefsp(sptypes::Vector{Any}, idx::Int) = unwrap_maybeundefsp(sptypes[idx])

const EMPTY_SPTYPES = Any[]
const EMPTY_SPTYPES = (Any[], falses(0))

function sptypes_from_meth_instance(linfo::MethodInstance)
def = linfo.def
Expand All @@ -437,10 +420,11 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
else
sp = collect(Any, linfo.sparam_vals)
end
spundefs = falses(length(sp))
for i = 1:length(sp)
v = sp[i]
if v isa TypeVar
maybe_undef = !constrains_param(v, linfo.specTypes, true)
spundefs[i] = !constrains_param(v, linfo.specTypes, true)
temp = sig
for j = 1:i-1
temp = temp.body
Expand Down Expand Up @@ -480,15 +464,14 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
ty = UnionAll(tv, Type{tv})
end
@label ty_computed
maybe_undef && (ty = MaybeUndefSP(ty))
elseif isvarargtype(v)
ty = Int
else
ty = Const(v)
end
sp[i] = ty
end
return sp
return sp, spundefs
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
Expand Down
15 changes: 8 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
stmt_info::Vector{CallInfo}
mod::Module
sptypes::Vector{Any}
spundefs::BitVector
slottypes::Vector{Any}
inlining::InliningState{Interp}
cfg::Union{Nothing,CFG}
Expand All @@ -155,7 +156,7 @@ function OptimizationState(frame::InferenceState, params::OptimizationParams,
inlining = InliningState(frame, params, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg)
frame.sptypes, frame.spundefs, frame.slottypes, inlining, cfg)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
Expand All @@ -166,7 +167,7 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
else
nssavalues = length(src.ssavaluetypes::Vector{Any})
end
sptypes = sptypes_from_meth_instance(linfo)
sptypes, spundefs = sptypes_from_meth_instance(linfo)
nslots = length(src.slotflags)
slottypes = src.slottypes
if slottypes === nothing
Expand All @@ -179,7 +180,7 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(params, interp)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, spundefs, slottypes, inlining, nothing)
end
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
src = retrieve_code_info(linfo)
Expand Down Expand Up @@ -268,8 +269,8 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
(; head, args) = stmt
if head === :static_parameter
# if we aren't certain enough about the type, it might be an UndefVarError at runtime
sptypes = isa(src, IRCode) ? src.sptypes : src.ir.sptypes
nothrow = !is_maybeundefsp(sptypes, args[1]::Int)
spundefs = isa(src, IRCode) ? src.spundefs : src.ir.spundefs
nothrow = !spundefs[args[1]::Int]
return (true, nothrow, nothrow)
end
if head === :call
Expand Down Expand Up @@ -377,7 +378,7 @@ function argextype(
sptypes::Vector{Any}, slottypes::Vector{Any})
if isa(x, Expr)
if x.head === :static_parameter
return unwrap_maybeundefsp(sptypes, x.args[1]::Int)
return sptypes[x.args[1]::Int]
elseif x.head === :boundscheck
return Bool
elseif x.head === :copyast
Expand Down Expand Up @@ -699,7 +700,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
if cfg === nothing
cfg = compute_basic_blocks(code)
end
return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes, sv.spundefs)
end

function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
Expand Down
12 changes: 7 additions & 5 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,20 @@ struct IRCode
stmts::InstructionStream
argtypes::Vector{Any}
sptypes::Vector{Any}
spundefs::BitVector
linetable::Vector{LineInfoNode}
cfg::CFG
new_nodes::NewNodeStream
meta::Vector{Expr}

function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{Any})
return new(stmts, argtypes, sptypes, linetable, cfg, NewNodeStream(), meta)
function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode},
argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{Any}, spundefs::BitVector)
return new(stmts, argtypes, sptypes, spundefs, linetable, cfg, NewNodeStream(), meta)
end
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
return new(stmts, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta)
return new(stmts, ir.argtypes, ir.sptypes, ir.spundefs, ir.linetable, cfg, new_nodes, ir.meta)
end
global copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes),
global copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes), copy(ir.spundefs),
copy(ir.linetable), copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta))
end

Expand All @@ -358,7 +360,7 @@ for debugging and unit testing of IRCode APIs. The compiler itself should genera
from the frontend or one of the caches.
"""
function IRCode()
ir = IRCode(InstructionStream(1), CFG([BasicBlock(1:1, Int[], Int[])], Int[1]), LineInfoNode[], Any[], Expr[], Any[])
ir = IRCode(InstructionStream(1), CFG([BasicBlock(1:1, Int[], Int[])], Int[1]), LineInfoNode[], Any[], Expr[], Any[], falses(0))
ir[SSAValue(1)][:inst] = ReturnNode(nothing)
ir[SSAValue(1)][:type] = Nothing
ir[SSAValue(1)][:flag] = 0x00
Expand Down
12 changes: 6 additions & 6 deletions base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ This should be used with caution as it is a in-place transformation where the fi
the original `ci::CodeInfo` are modified.
"""
function inflate_ir!(ci::CodeInfo, linfo::MethodInstance)
sptypes = sptypes_from_meth_instance(linfo)
sptypes, spundefs = sptypes_from_meth_instance(linfo)
if ci.inferred
argtypes, _ = matching_cache_argtypes(fallback_lattice, linfo)
else
argtypes = Any[ Any for i = 1:length(ci.slotflags) ]
end
return inflate_ir!(ci, sptypes, argtypes)
return inflate_ir!(ci, sptypes, spundefs, argtypes)
end
function inflate_ir!(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
function inflate_ir!(ci::CodeInfo, sptypes::Vector{Any}, spundefs::BitVector, argtypes::Vector{Any})
code = ci.code
cfg = compute_basic_blocks(code)
for i = 1:length(code)
Expand Down Expand Up @@ -46,7 +46,7 @@ function inflate_ir!(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
linetable = collect(LineInfoNode, linetable::Vector{Any})::Vector{LineInfoNode}
end
meta = Expr[]
return IRCode(stmts, cfg, linetable, argtypes, meta, sptypes)
return IRCode(stmts, cfg, linetable, argtypes, meta, sptypes, spundefs)
end

"""
Expand All @@ -58,8 +58,8 @@ Non-destructive version of `inflate_ir!`.
Mainly used for testing or interactive use.
"""
inflate_ir(ci::CodeInfo, linfo::MethodInstance) = inflate_ir!(copy(ci), linfo)
inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) = inflate_ir!(copy(ci), sptypes, argtypes)
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ ci.slottypes === nothing ? Any : (ci.slottypes::Vector{Any})[i] for i = 1:length(ci.slotflags) ])
inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, spundefs::BitVector, argtypes::Vector{Any}) = inflate_ir!(copy(ci), sptypes, spundefs, argtypes)
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], falses(0), Any[ ci.slottypes === nothing ? Any : (ci.slottypes::Vector{Any})[i] for i = 1:length(ci.slotflags) ])

function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
@assert isempty(ir.new_nodes)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ end
function typ_for_val(@nospecialize(x), ci::CodeInfo, sptypes::Vector{Any}, idx::Int, slottypes::Vector{Any})
if isa(x, Expr)
if x.head === :static_parameter
return unwrap_maybeundefsp(sptypes, x.args[1]::Int)
return sptypes[x.args[1]::Int]
elseif x.head === :boundscheck
return Bool
elseif x.head === :copyast
Expand Down
2 changes: 1 addition & 1 deletion stdlib/InteractiveUtils/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ end

@testset "code_llvm on opaque_closure" begin
let ci = code_typed(+, (Int, Int))[1][1]
ir = Core.Compiler.inflate_ir(ci, Any[], Any[Tuple{}, Int, Int])
ir = Core.Compiler.inflate_ir(ci, Any[], Core.Compiler.falses(0), Any[Tuple{}, Int, Int])
oc = Core.OpaqueClosure(ir)
@test (code_llvm(devnull, oc, Tuple{Int, Int}); true)
let io = IOBuffer()
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ let m = Meta.@lower 1 + 1
nstmts = length(src.code)
src.codelocs = fill(Int32(1), nstmts)
src.ssaflags = fill(Int32(0), nstmts)
ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any])
ir = Core.Compiler.inflate_ir(src, Any[], Core.Compiler.falses(0), Any[Any, Any])
@test Core.Compiler.verify_ir(ir) === nothing
ir = @test_nowarn Core.Compiler.sroa_pass!(ir)
@test Core.Compiler.verify_ir(ir) === nothing
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ let cfg = CFG(BasicBlock[
make_bb([2, 3] , [] ),
], Int[])
insts = Compiler.InstructionStream([], [], Any[], Int32[], UInt8[])
code = Compiler.IRCode(insts, cfg, LineInfoNode[], [], Expr[], [])
code = Compiler.IRCode(insts, cfg, LineInfoNode[], [], Expr[], [], Core.Compiler.falses(0))
compact = Compiler.IncrementalCompact(code, true)
@test length(compact.result_bbs) == 4 && 0 in compact.result_bbs[3].preds
end
Expand Down
4 changes: 2 additions & 2 deletions test/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ let ci = code_typed(+, (Int, Int))[1][1]
@test OpaqueClosure(ir; nargs=2, isva=false)(40, 2) == 42
@test OpaqueClosure(ci)(40, 2) == 42

ir = Core.Compiler.inflate_ir(ci, Any[], Any[Tuple{}, Int, Int])
ir = Core.Compiler.inflate_ir(ci, Any[], Core.Compiler.falses(0), Any[Tuple{}, Int, Int])
@test OpaqueClosure(ir; nargs=2, isva=false)(40, 2) == 42
@test isa(OpaqueClosure(ir; nargs=2, isva=false), Core.OpaqueClosure{Tuple{Int, Int}, Int})
@test_throws TypeError OpaqueClosure(ir; nargs=2, isva=false)(40.0, 2)
Expand All @@ -264,7 +264,7 @@ let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
@test_throws MethodError oc(1,2,3)
end

ir = Core.Compiler.inflate_ir(ci, Any[], Any[Tuple{}, Int, Tuple{Int}])
ir = Core.Compiler.inflate_ir(ci, Any[], Core.Compiler.falses(0), Any[Tuple{}, Int, Tuple{Int}])
let oc = OpaqueClosure(ir; nargs=2, isva=true)
@test oc(40, 2) === (40, (2,))
@test_throws MethodError oc(1,2,3)
Expand Down

0 comments on commit 99a272a

Please sign in to comment.