Skip to content

Commit

Permalink
Make Broadcast.flatten(bc).f more compiler friendly. (better inferre…
Browse files Browse the repository at this point in the history
…d and inlined) (JuliaLang#43322)

A follow-up attempt to fix JuliaLang#27988. (close JuliaLang#47493 close JuliaLang#50554)
Examples:
```julia
julia> using LazyArrays
julia> bc = @~ @. 1*(1 + 1) + 1*1;
julia> bc2 = @~ 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3;
```
On master:
<details><summary> click for details </summary>
<p>

```julia
julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 4)::Int64
│   %5  = Core.getfield(args, 5)::Int64
│   %6  = invoke Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#15#17", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}, typeof(+)}}(Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#15#17", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}, typeof(+)}(Base.Broadcast.var"JuliaLang#15#17"(), Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}(Base.Broadcast.var"JuliaLang#15#17"())), Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), +))(%1::Int64, %2::Int64, %3::Vararg{Int64}, %4, %5)::Tuple{Int64, Int64, Vararg{Int64}}
│   %7  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), %6)::Tuple{Int64, Int64}
│   %8  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), %6)::Tuple{Vararg{Int64}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#9#11", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}, typeof(*)}(Base.Broadcast.var"JuliaLang#9#11"(), Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}(Base.Broadcast.var"JuliaLang#15#17"())), Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), *), %8)::Tuple{Int64}
│   %10 = Core.getfield(%7, 1)::Int64
│   %11 = Core.getfield(%7, 2)::Int64
│   %12 = Base.mul_int(%10, %11)::Int64
│   %13 = Core.getfield(%9, 1)::Int64
│   %14 = Base.add_int(%12, %13)::Int64
└──       return %14
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = invoke Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#15#17", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}}, typeof(Base.literal_pow)}}(Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#15#17", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"JuliaLang#15#17"(), Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}(Base.Broadcast.var"JuliaLang#15#17"()))), Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"()))), 
Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"()))), Base.literal_pow))(%3::Int64, ^::Function, %4::Vararg{Any}, $(QuoteNode(Val{2}())), %5, %6, ^, %7, $(QuoteNode(Val{3}())))::Tuple{Int64, Any, Vararg{Any}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), %8)::Tuple{Int64, Any}
│   %10 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), %8)::Tuple
│   %11 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#15#17"(), %10)::Tuple
│   %12 = Core.getfield(%9, 1)::Int64
│   %13 = Core.getfield(%9, 2)::Any
│   %14 = (*)(%12, %13)::Any
│   %15 = Core.tuple(%14)::Tuple{Any}
│   %16 = Core._apply_iterate(Base.iterate, Core.tuple, %15, %11)::Tuple{Any, Vararg{Any}}
│   %17 = Base.mul_int(%1, %2)::Int64
│   %18 = Core.tuple(%17)::Tuple{Int64}
│   %19 = Core._apply_iterate(Base.iterate, Core.tuple, %18, %16)::Tuple{Int64, Any, Vararg{Any}}
│   %20 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), %19)::Tuple{Int64, Any}
│   %21 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), %19)::Tuple
│   %22 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#15#17", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}, typeof(*)}(Base.Broadcast.var"JuliaLang#15#17"(), Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}(Base.Broadcast.var"JuliaLang#15#17"())), Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), *), %21)::Tuple{Any, Vararg{Any}}
│   %23 = Core.getfield(%20, 1)::Int64
│   %24 = Core.getfield(%20, 2)::Any
│   %25 = (-)(%23, %24)::Any
│   %26 = Core.tuple(%25)::Tuple{Any}
│   %27 = Core._apply_iterate(Base.iterate, Core.tuple, %26, %22)::Tuple{Any, Any, Vararg{Any}}
│   %28 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"())), %27)::Tuple{Any, Any}
│   %29 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"())), %27)::Tuple
│   %30 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"JuliaLang#16#18"{Base.Broadcast.var"JuliaLang#9#11", Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}}, Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}}, Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"JuliaLang#9#11"(), Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}}(Base.Broadcast.var"JuliaLang#13#14"{Base.Broadcast.var"JuliaLang#15#17"}(Base.Broadcast.var"JuliaLang#15#17"()))), Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}}(Base.Broadcast.var"JuliaLang#23#24"{Base.Broadcast.var"JuliaLang#25#26"}(Base.Broadcast.var"JuliaLang#25#26"()))), Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}}(Base.Broadcast.var"JuliaLang#19#20"{Base.Broadcast.var"JuliaLang#21#22"}(Base.Broadcast.var"JuliaLang#21#22"()))), Base.literal_pow), %29)::Tuple{Any}
│   %31 = Core.getfield(%28, 1)::Any
│   %32 = Core.getfield(%28, 2)::Any
│   %33 = (+)(%31, %32)::Any
│   %34 = Core.getfield(%30, 1)::Any
│   %35 = (+)(%33, %34)::Any
└──       return %35
) => Any
```
</p>

</details>

On this PR:
```julia
julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)
CodeInfo(
1 ─ %1 = Core.getfield(args, 1)::Int64
│   %2 = Core.getfield(args, 2)::Int64
│   %3 = Core.getfield(args, 3)::Int64
│   %4 = Core.getfield(args, 4)::Int64
│   %5 = Core.getfield(args, 5)::Int64
│   %6 = Base.add_int(%2, %3)::Int64
│   %7 = Base.mul_int(%1, %6)::Int64
│   %8 = Base.mul_int(%4, %5)::Int64
│   %9 = Base.add_int(%7, %8)::Int64
└──      return %9
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = Base.mul_int(%1, %2)::Int64
│   %9  = Base.mul_int(%4, %4)::Int64
│   %10 = Base.mul_int(%3, %9)::Int64
│   %11 = Base.sub_int(%8, %10)::Int64
│   %12 = Base.mul_int(%5, %6)::Int64
│   %13 = Base.add_int(%11, %12)::Int64
│   %14 = Base.mul_int(%7, %7)::Int64
│   %15 = Base.mul_int(%14, %7)::Int64
│   %16 = Base.add_int(%13, %15)::Int64
└──       return %16
) => Int64
```
  • Loading branch information
N5N3 committed Jul 15, 2023
2 parents 22ac24a + d406c7e commit f15eb4e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 81 deletions.
121 changes: 43 additions & 78 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,16 @@ function flatten(bc::Broadcasted)
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(bc)
# build a function `makeargs` that takes a "flat" argument list
# and creates the appropriate input arguments for `f`, e.g.,
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
#
# `makeargs` is built recursively and looks a bit like this:
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
# = (w, g(x, y), makeargs2(z)...)
# = (w, g(x, y), z)
let makeargs = make_makeargs(()->(), bc.args), f = bc.f
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted(bc.style, newf, args, bc.axes)
end
# build a tuple of functions `makeargs`. Its elements take
# the whole "flat" argument list and generate the appropriate
# input arguments for the broadcasted function `f`, e.g.,
# makeargs[1] = ((w, x, y, z)) -> w
# makeargs[2] = ((w, x, y, z)) -> g(x, y)
# makeargs[3] = ((w, x, y, z)) -> z
makeargs = make_makeargs(bc.args)
f = Base.maybeconstructor(bc.f)
newf = (args...) -> (@inline; f(prepare_args(makeargs, args)...))
return Broadcasted(bc.style, newf, args, bc.axes)
end

const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}}
Expand All @@ -363,78 +359,47 @@ _isflat(args::NestedTuple) = false
# An empty argument tuple has nothing nested in it, so it is flat.
_isflat(args::Tuple{}) = true
# Otherwise discard the leading argument and keep checking the remaining ones.
_isflat(args::Tuple) = _isflat(tail(args))

# No arguments left: nothing to concatenate.
cat_nested() = ()
# A non-`Broadcasted` leading argument is kept as-is, followed by the flattened remainder.
function cat_nested(leaf::Any, others...)
    return (leaf, cat_nested(others...)...)
end
# A leading `Broadcasted` is replaced by its own flattened arguments, followed by the
# flattened remainder.
function cat_nested(bc::Broadcasted, others...)
    return (cat_nested(bc.args...)..., cat_nested(others...)...)
end
# Flatten a (possibly nested) `Broadcasted` into the tuple of its leaf arguments,
# in left-to-right order. Anything that is not a `Broadcasted` is a leaf.
cat_nested(leaf) = (leaf,)
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)

# Flatten every entry of an argument tuple and concatenate the results.
cat_nested_args(::Tuple{}) = ()
cat_nested_args(args::Tuple{Any}) = cat_nested(args[1])
function cat_nested_args(args::Tuple)
    leading = cat_nested(args[1])
    return (leading..., cat_nested_args(tail(args))...)
end

"""
make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
make_makeargs(t::Tuple) -> Tuple{Vararg{Function}}
Each element of `t` is one (consecutive) node in a broadcast tree.
Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
to return a function that takes in flattened argument list and returns a
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree). As an additional
complication, the passed in tuple may be longer than the number of leaves
in the subtree described by `t`. The `makeargs_tail` function should
be called on such additional arguments (but not the arguments consumed
by `t`).
The elements of the returned `Tuple` are functions which take in the (whole) flattened
list and generate the inputs for the corresponding broadcasted function.
"""
@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail
@inline function make_makeargs(makeargs_tail, t::Tuple)
makeargs = make_makeargs(makeargs_tail, tail(t))
(head, tail...)->(head, makeargs(tail...)...)
make_makeargs(args::Tuple) = _make_makeargs(args, 1)[1]

# Recursively walk the broadcast nodes in `args`, producing one argument-builder per node.
# `n` is the flattened index of the next argument that has not been consumed yet; the
# updated index is returned alongside the tuple of builders.
@inline function _make_makeargs(args::Tuple, n::Int)
    first_maker, after_first = _make_makeargs1(args[1], n)
    other_makers, after_rest = _make_makeargs(tail(args), after_first)
    return (first_maker, other_makers...), after_rest
end
function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}})
bc = t[1]
# c.f. the same expression in the function on leaf nodes above. Here
# we recurse into siblings in the broadcast tree.
let makeargs_tail = make_makeargs(makeargs_tail, tail(t)),
# Here we recurse into children. It would be valid to pass in makeargs_tail
# here, and not use it below. However, in that case, our recursion is no
# longer purely structural because we're building up one argument (the closure)
# while destructuring another.
makeargs_head = make_makeargs((args...)->args, bc.args),
f = bc.f
# Create two functions, one that splits off the first length(bc.args)
# elements from the tuple and one that yields the remaining arguments.
# N.B. We can't call headargs on `args...` directly because
# args is flattened (i.e. our children have not been evaluated
# yet).
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
return @inline function(args::Vararg{Any,N}) where N
args1 = makeargs_head(args...)
a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...)
(f(a...), b...)
end
end
# Base case: no nodes left, so no builders are produced and `n` is handed back unchanged.
_make_makeargs(::Tuple{}, n::Int) = ((), n)

# Helper functor that stores a flattened argument index statically, in its type
# parameter `N`; calling it on a tuple returns that tuple's `N`-th element.
struct Pick{N} <: Function end
function (::Pick{N})(@nospecialize(args::Tuple)) where {N}
    return args[N]
end

# Leaf (non-`Broadcasted`) node: it consumes exactly one flattened argument, so its
# builder is `Pick{n}()` and the next unused index is `n + 1`.
@inline _make_makeargs1(_, n::Int) = (Pick{n}(), n + 1)
# Nested `Broadcasted` node: first create the builders of its children (they consume
# flattened arguments starting at index `n`), then wrap them in a single closure that
# calls the node's function on the arguments those builders produce.
@inline function _make_makeargs1(bc::Broadcasted, n::Int)
    child_makers, next = _make_makeargs(bc.args, n)
    op = Base.maybeconstructor(bc.f)
    node_maker = (flat::Tuple) -> (@inline; op(prepare_args(child_makers, flat)...))
    return node_maker, next
end

# Build a function that returns, as a tuple, the first `length(t)` of the arguments
# it is called with.
@inline function make_headargs(::Tuple{})
    # Nothing left to keep: whatever arguments remain are ignored.
    return @inline function(rest::Vararg{Any,N}) where N
        ()
    end
end
@inline function make_headargs(t::Tuple)
    let keep_more = make_headargs(tail(t))
        # Keep the leading argument and let the shorter selector handle the others.
        return @inline function(first_arg, rest::Vararg{Any,N}) where N
            (first_arg, keep_more(rest...)...)
        end
    end
end

# Build a function that drops the first `length(t)` of the arguments it is called
# with and returns the remaining ones as a tuple.
@inline function make_tailargs(::Tuple{})
    # Nothing left to drop: hand back every remaining argument.
    return @inline function(rest::Vararg{Any,N}) where N
        rest
    end
end
@inline function make_tailargs(t::Tuple)
    let drop_more = make_tailargs(tail(t))
        # Discard the leading argument and let the shorter dropper handle the others.
        return @inline function(first_arg, rest::Vararg{Any,N}) where N
            drop_more(rest...)
        end
    end
end
# Apply every function in `makeargs` to the whole flattened argument tuple `x` and
# collect the results, in order, into a tuple.
prepare_args(::Tuple{}, ::Tuple) = ()
@inline function prepare_args(makeargs::Tuple{Any}, @nospecialize(x::Tuple))
    return (makeargs[1](x),)
end
@inline function prepare_args(makeargs::Tuple, @nospecialize(x::Tuple))
    leading = makeargs[1](x)
    return (leading, prepare_args(tail(makeargs), x)...)
end

## Broadcasting utilities ##

Expand Down
19 changes: 16 additions & 3 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,14 +774,27 @@ let X = zeros(2, 3)
end

# issue #27988: inference of Broadcast.flatten
using .Broadcast: Broadcasted
using .Broadcast: Broadcasted, cat_nested
let
# 1 * 2 + (3 * 4) * 5
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
# 1 * (2.0 / 2.5) + (3 * 4) * 5
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
# @. 1 + 1 * (1 + 1 + 1 + 1)
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1) # `cat_nested` failed to infer this
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
# @. 1 + (1 + 1) + 1 + (1 + 1) + 1 + (1 + 1) + 1
bc = Broadcasted(+, (1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1))
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
# a type (`Float32`) used as the outer broadcasted function
bc = Broadcasted(Float32, (Broadcasted(+, (1, 1)),))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
end

let
Expand Down

0 comments on commit f15eb4e

Please sign in to comment.