Skip to content

Commit

Permalink
Merge pull request #30878 from JuliaLang/jb/sptypes
Browse files Browse the repository at this point in the history
some improvements to static parameter handling in inference
  • Loading branch information
JeffBezanson committed Feb 2, 2019
2 parents e7e726b + 36d490a commit 5aa2462
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 101 deletions.
25 changes: 5 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -863,21 +863,6 @@ function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
nothing
end

# convert an inferred static parameter value to the inferred type of a static_parameter expression
function sparam_type(@nospecialize(val))
if isa(val, TypeVar)
if Any <: val.ub
# static param bound to typevar
# if the tvar is not known to refer to anything more specific than Any,
# the static param might actually be an integer, symbol, etc.
return Any
else
return UnionAll(val, Type{val})
end
end
return AbstractEvalConstant(val)
end

function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
if isa(e, QuoteNode)
return AbstractEvalConstant((e::QuoteNode).value)
Expand Down Expand Up @@ -940,8 +925,8 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
elseif e.head === :static_parameter
n = e.args[1]
t = Any
if 1 <= n <= length(sv.sp)
t = sparam_type(sv.sp[n])
if 1 <= n <= length(sv.sptypes)
t = sv.sptypes[n]
end
elseif e.head === :method
t = (length(e.args) == 1) ? Any : Nothing
Expand Down Expand Up @@ -975,9 +960,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
end
elseif isa(sym, Expr) && sym.head === :static_parameter
n = sym.args[1]
if 1 <= n <= length(sv.sp)
val = sv.sp[n]
if !isa(val, TypeVar)
if 1 <= n <= length(sv.sptypes)
spty = sv.sptypes[n]
if isa(spty, Const)
t = Const(true)
end
end
Expand Down
71 changes: 45 additions & 26 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ const LineNum = Int
mutable struct InferenceState
params::Params # describes how to compute the result
result::InferenceResult # remember where to put the result
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
sp::SimpleVector # static parameters
linfo::MethodInstance # used here for the tuple (specTypes, env, Method) and world-age validity
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
mod::Module
currpc::LineNum
Expand Down Expand Up @@ -48,7 +48,7 @@ mutable struct InferenceState
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)

sp = spvals_from_meth_instance(linfo::MethodInstance)
sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
Expand Down Expand Up @@ -120,7 +120,7 @@ function InferenceState(result::InferenceResult, cached::Bool, params::Params)
return InferenceState(result, src, cached, params)
end

function spvals_from_meth_instance(linfo::MethodInstance)
function sptypes_from_meth_instance(linfo::MethodInstance)
toplevel = !isa(linfo.def, Method)
if !toplevel && isempty(linfo.sparam_vals) && !isempty(linfo.def.sparam_syms)
# linfo is unspecialized
Expand All @@ -130,35 +130,54 @@ function spvals_from_meth_instance(linfo::MethodInstance)
push!(sp, sig.var)
sig = sig.body
end
sp = svec(sp...)
else
sp = linfo.sparam_vals
if _any(t->isa(t,TypeVar), sp)
sp = collect(Any, sp)
end
sp = collect(Any, linfo.sparam_vals)
end
if !isa(sp, SimpleVector)
for i = 1:length(sp)
v = sp[i]
if v isa TypeVar
ub = v.ub
while ub isa TypeVar
ub = ub.ub
end
if has_free_typevars(ub)
ub = Any
for i = 1:length(sp)
v = sp[i]
if v isa TypeVar
ub = v.ub
while ub isa TypeVar
ub = ub.ub
end
if has_free_typevars(ub)
ub = Any
end
lb = v.lb
while lb isa TypeVar
lb = lb.lb
end
if has_free_typevars(lb)
lb = Bottom
end
if Any <: ub && lb <: Bottom
ty = Any
# if this parameter came from arg::Type{T}, we know that T::Type
sig = linfo.def.sig
temp = sig
for j = 1:i-1
temp = temp.body
end
lb = v.lb
while lb isa TypeVar
lb = lb.lb
Pi = temp.var
while temp isa UnionAll
temp = temp.body
end
if has_free_typevars(lb)
lb = Bottom
sigtypes = temp.parameters
for j = 1:length(sigtypes)
tj = sigtypes[j]
if isType(tj) && tj.parameters[1] === Pi
ty = Type
break
end
end
sp[i] = TypeVar(v.name, lb, ub)
else
tv = TypeVar(v.name, lb, ub)
ty = UnionAll(tv, Type{tv})
end
else
ty = Const(v)
end
sp = svec(sp...)
sp[i] = ty
end
return sp
end
Expand Down
32 changes: 16 additions & 16 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mutable struct OptimizationState
min_valid::UInt
max_valid::UInt
params::Params
sp::SimpleVector # static parameters
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
function OptimizationState(frame::InferenceState)
Expand All @@ -27,7 +27,7 @@ mutable struct OptimizationState
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
frame.min_valid, frame.max_valid,
frame.params, frame.sp, frame.slottypes, false)
frame.params, frame.sptypes, frame.slottypes, false)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
params::Params)
Expand All @@ -54,7 +54,7 @@ mutable struct OptimizationState
s_edges::Vector{Any},
src, inmodule, nargs,
min_world(linfo), max_world(linfo),
params, spvals_from_meth_instance(linfo), slottypes, false)
params, sptypes_from_meth_instance(linfo), slottypes, false)
end
end

Expand Down Expand Up @@ -135,7 +135,7 @@ function isinlineable(m::Method, me::OptimizationState, bonus::Int=0)
end
end
if !inlineable
inlineable = inline_worthy(me.src.code, me.src, me.sp, me.slottypes, me.params, cost_threshold + bonus)
inlineable = inline_worthy(me.src.code, me.src, me.sptypes, me.slottypes, me.params, cost_threshold + bonus)
end
return inlineable
end
Expand All @@ -148,7 +148,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return false
end
if isa(stmt, GotoIfNot)
t = argextype(stmt.cond, ir, ir.spvals)
t = argextype(stmt.cond, ir, ir.sptypes)
return !(t Bool)
end
if isa(stmt, Expr)
Expand All @@ -175,7 +175,7 @@ function optimize(opt::OptimizationState, @nospecialize(result))
proven_pure = true
for i in 1:length(ir.stmts)
stmt = ir.stmts[i]
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.spvals)
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, ir.types[i], ir, ir.sptypes)
proven_pure = false
break
end
Expand Down Expand Up @@ -268,19 +268,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
# known return type
isknowntype(@nospecialize T) = (T == Union{}) || isconcretetype(T)

function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any}, params::Params)
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::Params)
head = ex.head
if is_meta_expr_head(head)
return 0
elseif head === :call
farg = ex.args[1]
ftyp = argextype(farg, src, spvals, slottypes)
ftyp = argextype(farg, src, sptypes, slottypes)
if ftyp === IntrinsicFunction && farg isa SSAValue
# if this comes from code that was already inlined into another function,
# Consts have been widened. try to recover in simple cases.
farg = src.code[farg.id]
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
ftyp = argextype(farg, src, spvals, slottypes)
ftyp = argextype(farg, src, sptypes, slottypes)
end
end
f = singleton_type(ftyp)
Expand All @@ -302,7 +302,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
return 0
elseif f === Main.Core.arrayref && length(ex.args) >= 3
atyp = argextype(ex.args[3], src, spvals, slottypes)
atyp = argextype(ex.args[3], src, sptypes, slottypes)
return isknowntype(atyp) ? 4 : params.inline_nonleaf_penalty
end
fidx = find_tfunc(f)
Expand All @@ -325,7 +325,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
elseif head === :return
a = ex.args[1]
if a isa Expr
return statement_cost(a, -1, src, spvals, slottypes, params)
return statement_cost(a, -1, src, sptypes, slottypes, params)
end
return 0
elseif head === :(=)
Expand All @@ -336,7 +336,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
end
a = ex.args[2]
if a isa Expr
cost = plus_saturate(cost, statement_cost(a, -1, src, spvals, slottypes, params))
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params))
end
return cost
elseif head === :copyast
Expand All @@ -357,13 +357,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, spvals::SimpleVector
return 0
end

function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector, slottypes::Vector{Any},
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
params::Params, cost_threshold::Integer=params.inline_cost_threshold)
bodycost::Int = 0
for line = 1:length(body)
stmt = body[line]
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, spvals, slottypes, params)::Int
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params)::Int
elseif stmt isa GotoNode
# loops are generally always expensive
# but assume that forward jumps are already counted for from
Expand All @@ -378,11 +378,11 @@ function inline_worthy(body::Array{Any,1}, src::CodeInfo, spvals::SimpleVector,
return true
end

function is_known_call(e::Expr, @nospecialize(func), src, spvals::SimpleVector, slottypes::Vector{Any} = empty_slottypes)
function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = empty_slottypes)
if e.head !== :call
return false
end
f = argextype(e.args[1], src, spvals, slottypes)
f = argextype(e.args[1], src, sptypes, slottypes)
return isa(f, Const) && f.val === func
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ function just_construct_ssa(ci::CodeInfo, code::Vector{Any}, nargs::Int, sv::Opt
@timeit "domtree 1" domtree = construct_domtree(cfg)
ir = let code = Any[nothing for _ = 1:length(code)]
argtypes = sv.slottypes[1:(nargs+1)]
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sp)
IRCode(code, Any[], ci.codelocs, flags, cfg, collect(LineInfoNode, ci.linetable), argtypes, meta, sv.sptypes)
end
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sp, sv.slottypes)
@timeit "construct_ssa" ir = construct_ssa!(ci, code, ir, domtree, defuse_insts, nargs, sv.sptypes, sv.slottypes)
return ir
end

Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
isempty(eargs) && continue
arg1 = eargs[1]

ft = argextype(arg1, ir, sv.sp)
ft = argextype(arg1, ir, sv.sptypes)
has_free_typevars(ft) && continue
f = singleton_type(ft)
f === Core.Intrinsics.llvmcall && continue
Expand All @@ -797,7 +797,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
atypes[1] = ft
ok = true
for i = 2:length(stmt.args)
a = argextype(stmt.args[i], ir, sv.sp)
a = argextype(stmt.args[i], ir, sv.sptypes)
(a === Bottom || isvarargtype(a)) && (ok = false; break)
atypes[i] = a
end
Expand Down
10 changes: 5 additions & 5 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,20 +213,20 @@ struct IRCode
lines::Vector{Int32}
flags::Vector{UInt8}
argtypes::Vector{Any}
spvals::SimpleVector
sptypes::Vector{Any}
linetable::Vector{LineInfoNode}
cfg::CFG
new_nodes::Vector{NewNode}
meta::Vector{Any}

function IRCode(stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any},
spvals::SimpleVector)
return new(stmts, types, lines, flags, argtypes, spvals, linetable, cfg, NewNode[], meta)
sptypes::Vector{Any})
return new(stmts, types, lines, flags, argtypes, sptypes, linetable, cfg, NewNode[], meta)
end
function IRCode(ir::IRCode, stmts::Vector{Any}, types::Vector{Any}, lines::Vector{Int32}, flags::Vector{UInt8},
cfg::CFG, new_nodes::Vector{NewNode})
return new(stmts, types, lines, flags, ir.argtypes, ir.spvals, ir.linetable, cfg, new_nodes, ir.meta)
return new(stmts, types, lines, flags, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta)
end
end
copy(code::IRCode) = IRCode(code, copy(code.stmts), copy(code.types),
Expand Down Expand Up @@ -1143,7 +1143,7 @@ function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing
if compact_exprtype(compact, SSAValue(idx)) === Bottom
effect_free = false
else
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.spvals)
effect_free = stmt_effect_free(stmt, compact.result_types[idx], compact, compact.ir.sptypes)
end
if effect_free
for ops in userefs(stmt)
Expand Down
10 changes: 5 additions & 5 deletions base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

inflate_ir(ci::CodeInfo) = inflate_ir(ci, Core.svec(), Any[ Any for i = 1:length(ci.slotnames) ])
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ Any for i = 1:length(ci.slotnames) ])

function inflate_ir(ci::CodeInfo, linfo::MethodInstance)
spvals = spvals_from_meth_instance(linfo)
sptypes = sptypes_from_meth_instance(linfo)
if ci.inferred
argtypes, _ = matching_cache_argtypes(linfo, nothing)
else
argtypes = Any[ Any for i = 1:length(ci.slotnames) ]
end
return inflate_ir(ci, spvals, argtypes)
return inflate_ir(ci, sptypes, argtypes)
end

function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
code = copy_exprargs(ci.code)
for i = 1:length(code)
if isa(code[i], Expr)
Expand Down Expand Up @@ -46,7 +46,7 @@ function inflate_ir(ci::CodeInfo, spvals::SimpleVector, argtypes::Vector{Any})
end
ssavaluetypes = ci.ssavaluetypes isa Vector{Any} ? copy(ci.ssavaluetypes) : Any[ Any for i = 1:(ci.ssavaluetypes::Int) ]
ir = IRCode(code, ssavaluetypes, copy(ci.codelocs), copy(ci.ssaflags), cfg, collect(LineInfoNode, ci.linetable),
argtypes, Any[], spvals)
argtypes, Any[], sptypes)
return ir
end

Expand Down
Loading

0 comments on commit 5aa2462

Please sign in to comment.