Skip to content

Commit

Permalink
introduce runtime representation of broadcast fusion
Browse files Browse the repository at this point in the history
fix #21094
fix #22060
fix #22053
replaces #22063
  • Loading branch information
vtjnash committed Sep 12, 2017
1 parent fce0a3c commit a13aacd
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 126 deletions.
225 changes: 225 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,4 +609,229 @@ macro __dot__(x)
esc(__dot__(x))
end

############################################################
## The parser turns dotted calls into the equivalent Fusion expression.
## Effectively, this turns the Expr tree into a runtime AST,
## for a limited subset of expression types.
#
## For example, in the expression:
# d = sin.((a .+ (b .* c))...)
## The kernel becomes
# d' = Fusion{3}(
# FusionApply(
# sin,
# ( FusionCall(
# +,
# ( FusionArg{1}(),
# FusionCall(
# *,
# ( FusionArg{2}(),
# FusionArg{3}() )), )), )),
# (:a, :b, :c))
## and then the final expansion becomes:
# d = broadcast(d', a, b, c)

struct Fusion{N, vararg#=::Bool=#, T}
f::T
# Debugging Metadata:
# names::NTuple{N, Symbol}
# source::LineNumberNode
function Fusion{N, vararg}(f) where {N, vararg}
return new{N, vararg::Bool, typeof(f)}(f)
end
end

struct FusionArg{N}
end

struct FusionConstant{T}
c::T
function FusionConstant(c) where {}
return new{typeof(c)}(c)
end
end

struct FusionCall{F, Args<:Tuple}
f::F
args::Args
function FusionCall(f, args::Tuple) where {}
return new{typeof(f), typeof(args)}(f, args)
end
end

struct FusionApply{N, F, Args<:NTuple{N, Any}}
f::F
args::Args
function FusionApply(f, args::NTuple{N, Any}) where {N}
return new{N, typeof(f), typeof(args)}(f, args)
end
end

function kw_to_vec(kws::Vector{Any})
kwargs = Vector{Any}(2 * length(kws))
for i in 1:2:length(kws)
kw = kws[i]::Tuple{Any, Any}
kwargs[i] = getfield(kw, 1)
kwargs[i + 1] = getfield(kw, 2)
end
return kwargs
end

struct FusionKWCall{F, Args<:Tuple}
f::F
args::Args
kwargs::Vector{Any}
function FusionKWCall(f, args::Tuple; kwargs...) where {}
return new{typeof(f), typeof(args)}(f, args, kw_to_vec(kwargs))
end
end

struct FusionKWApply{F, Args<:Tuple}
f::F
args::Args
kwargs::Vector{Any}
function FusionKWApply(f, args::Tuple; kwargs...) where {}
return new{typeof(f), typeof(args)}(f, args, kw_to_vec(kwargs))
end
end

function tuplehead(t::Tuple, N::Val)
return ntuple(i -> t[i], N)
end
@generated function tupletail(t::NTuple{M, Any}, ::Val{N}) where {N, M}
# alternative, non-generated versions,
# enable when inference is improved:
#tupletail(t, Nreq) = ntuple(i -> t[i + Nreq], length(t) - Nreq)
#tupletail(t, Nreq) = t[(Nreq + 1):end]
args = Any[ :(getfield(t, $i)) for i in (N + 1):M ]
tpl = Expr(:tuple)
tpl.args = args
return tpl
end

@inline (f::Fusion{N, false})(args::Vararg{Any, N}) where {N} = f.f(args...)
function (f::Fusion{Nreq, true})(args::Vararg{Any, M}) where {Nreq, M}
M >= Nreq || throw(MethodError(f, args))
fargs = tuplehead(args, Val(Nreq))
vararg = tupletail(args, Val(Nreq))
return f.f((fargs..., vararg)...)
end
@inline (f::FusionArg{N})(args...) where {N} = args[N]
@inline (f::FusionConstant)(args...) = f.c
@inline (f::FusionCall)(args...) = f.f(map(a -> a(args...), f.args)...)
# TODO: calling _apply on map _apply is not handled by inference
# for now, we unroll some cases and generate others, to help it out
#@inline (f::FusionApply)(args...) = Core._apply(f.f, map(a -> a(args...), f.args)...)
@inline (f::FusionApply{0})(args...) = f.f()
@inline (f::FusionApply{1})(args...) = f.f(f.args[1](args...)...)
@inline (f::FusionApply{2})(args...) = f.f(f.args[1](args...)..., f.args[2](args...)...)
@inline (f::FusionApply{3})(args...) = f.f(f.args[1](args...)..., f.args[2](args...)..., f.args[3](args...)...)
@generated function (f::FusionApply{N})(args...) where {N}
fargs = Any[ :(getfield(f.args, $i)(args...)) for i in 1:N ]
return Expr(:call, GlobalRef(Core, :_apply), :(f.f), fargs...)
end
@inline function (f::FusionKWCall)(args...)
fargs = map(a -> a(args...), f.args)
# return f.f(args...; kwargs...)
if isempty(f.kwargs)
return f.f(fargs...)
else
return Core.kwfunc(f.f)(f.kwargs, f.f, fargs...)
end
end
@inline function (f::FusionKWApply)(args...)
fargs = map(a -> a(args...), f.args)
# return Core._apply(f.f, args...; kwargs...)
if isempty(f.kwargs)
return Core._apply(f.f, fargs...)
else
return Core._apply(Core.kwfunc(f.f), (f.kwargs,), (f.f,), fargs...)
end
end

function Base.show(io::IO, f::Fusion{N, vararg}) where {N, vararg}
nargs = (vararg ? N + 1 : N)
names = String[ "a_$i" for i in 1:nargs ] # f.names
print(io, "(")
join(io, names, ", ")
vararg && print(io, "...")
print(io, ") -> ")
show_fusion(io, f.f, names)
end

function show_fusion(io::IO, f::FusionArg{N}, names) where N
print(io, names[N])
nothing
end

function show_fusion(io::IO, f::FusionConstant{N}, names) where N
print(io, f.c)
nothing
end

function show_fusion(io::IO, f::FusionCall, names)
Base.show(io, f.f)
print(io, '(')
first = true
for i in f.args
first || print(io, ", ")
first = false
show_fusion(io, i, names)
end
print(io, ')')
nothing
end

function show_fusion(io::IO, f::FusionApply, names)
print(io, "Core._apply(")
Base.show(io, f.f)
for i in f.args
print(io, ", ")
show_fusion(io, i, names)
end
print(io, ')')
nothing
end

function show_fusion(io::IO, f::FusionKWCall, names)
Base.show(io, f.f)
print(io, '(')
first = true
for i in f.args
first || print(io, ", ")
first = false
show_fusion(io, i, names)
end
print(io, "; ")
first = true
for i in 1:2:length(f.kwargs)
first || print(io, ", ")
first = false
print(io, f.kwargs[i])
print(io, "=")
end
print(io, ')')
nothing
end


function show_fusion(io::IO, f::FusionKWApply, names)
print(io, "Core._apply(")
Base.show(io, f.f)
for i in f.args
print(io, ", ")
show_fusion(io, i, names)
end
print(io, "; #=kwargs=#...)")
nothing
end


function show_fusion(io::IO, @nospecialize(f), names)
print(io, "#= unexpected expression ")
show(io, f)
print(io, " =#")
nothing
end

end # module
Loading

0 comments on commit a13aacd

Please sign in to comment.