Skip to content

Commit

Permalink
Improve _apply(apply_type, ::Tuple, ::SimpleVector)
Browse files Browse the repository at this point in the history
This is another one of those patterns that doesn't generally
come up except in Cassette/Zygote code. There's two related
changes here:

1. Expand ininling's _apply rewrite to also handle constant
   svecs (under the restriction that they must be no longer than
   the splatting cutoff and the elements must be individually eligible
   for inlining into the IR)
2. Move the _apply rewrite before the special case inliner for builtins
   such that _apply(apply_type, ...) gets eliminated early.
  • Loading branch information
Keno committed Mar 6, 2019
1 parent c034b2f commit 1c8c8e7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 32 deletions.
82 changes: 50 additions & 32 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, at
def_atypes = def_type.fields
else
def_atypes = Any[]
if isa(def_type, Const) # && isa(def_type.val, Tuple) is implied
if isa(def_type, Const) # && isa(def_type.val, Union{Tuple, SimpleVector}) is implied
for p in def_type.val
push!(def_atypes, Const(p))
end
Expand All @@ -611,8 +611,12 @@ function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, at
# now push flattened types into new_atypes and getfield exprs into new_argexprs
for j in 1:length(def_atypes)
def_atype = def_atypes[j]
new_call = Expr(:call, Core.getfield, def, j)
new_argexpr = insert_node!(ir, idx, def_atype, new_call)
if isa(def_atype, Const) && is_inlineable_constant(def_atype.val)
new_argexpr = quoted(def_atype.val)
else
new_call = Expr(:call, Core.getfield, def, j)
new_argexpr = insert_node!(ir, idx, def_atype, new_call)
end
push!(new_argexprs, new_argexpr)
push!(new_atypes, def_atype)
end
Expand Down Expand Up @@ -787,6 +791,29 @@ function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(cas
nothing
end

function is_valid_type_for_apply_rewrite(@nospecialize(typ), sv)
if isa(typ, Const) && isa(typ.val, SimpleVector)
length(typ.val) > sv.params.MAX_TUPLE_SPLAT && return false
for p in typ.val
is_inlineable_constant(p) || return false
end
return true
end
typ = widenconst(typ)
if isa(typ, DataType) && typ.name === NamedTuple_typename
typ = typ.parameters[2]
while isa(typ, TypeVar)
typ = typ.ub
end
end
isa(typ, DataType) || return false
if typ.name === Tuple.name
return !isvatuple(typ) && length(typ.parameters) <= sv.params.MAX_TUPLE_SPLAT
else
return false
end
end

function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Any[]
Expand Down Expand Up @@ -838,42 +865,14 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
end
ok || continue

# Check if we match any of the early inliners
calltype = ir.types[idx]
res = early_inline_special_case(ir, f, ft, stmt, atypes, sv, calltype)
if res !== nothing
ir.stmts[idx] = res
continue
end

if f !== Core.invoke && f !== Core._apply &&
(isa(f, IntrinsicFunction) || ft IntrinsicFunction || isa(f, Builtin) || ft Builtin)
# No inlining for builtins (other than what's handled in the early inliner)
# TODO: this test is wrong if we start to handle Unions of function types later
continue
end

# Special handling for Core.invoke and Core._apply, which can follow the normal inliner
# logic with modified inlining target
isinvoke = false

# Handle _apply
ok = true
while f === Core._apply
# Try to figure out the signature of the function being called
# and if rewrite_apply_exprargs can deal with this form
for i = 3:length(atypes)
typ = atypes[i]
typ = widenconst(typ)
# TODO: We could basically run the iteration protocol here
if isa(typ, DataType) && typ.name === NamedTuple_typename
typ = typ.parameters[2]
while isa(typ, TypeVar)
typ = typ.ub
end
end
if !isa(typ, DataType) || typ.name !== Tuple.name ||
isvatuple(typ) || length(typ.parameters) > sv.params.MAX_TUPLE_SPLAT
if !is_valid_type_for_apply_rewrite(atypes[i], sv)
ok = false
break
end
Expand All @@ -895,6 +894,25 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
end
ok || continue

# Check if we match any of the early inliners
calltype = ir.types[idx]
res = early_inline_special_case(ir, f, ft, stmt, atypes, sv, calltype)
if res !== nothing
ir.stmts[idx] = res
continue
end

if f !== Core.invoke && f !== Core._apply &&
(isa(f, IntrinsicFunction) || ft IntrinsicFunction || isa(f, Builtin) || ft Builtin)
# No inlining for builtins (other than what's handled in the early inliner)
# TODO: this test is wrong if we start to handle Unions of function types later
continue
end

# Special handling for Core.invoke and Core._apply, which can follow the normal inliner
# logic with modified inlining target
isinvoke = false

if f !== Core.invoke && (isa(f, IntrinsicFunction) || ft IntrinsicFunction || isa(f, Builtin) || ft Builtin)
# TODO: this test is wrong if we start to handle Unions of function types later
continue
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,13 @@ let code = code_typed(f_pointerref, Tuple{Type{Int}})[1][1].code
end
@test !any_ptrref
end

# Test that inlining can inline _applys of builtins/_applys on SimpleVectors
function foo_apply_apply_type_svec()
A = (Tuple, Float32)
B = Tuple{Float32, Float32}
Core.apply_type(A..., B.types...)
end
let ci = code_typed(foo_apply_apply_type_svec, Tuple{})[1].first
@test length(ci.code) == 1 && ci.code[1] == Expr(:return, NTuple{3, Float32})
end

0 comments on commit 1c8c8e7

Please sign in to comment.