Skip to content

Commit

Permalink
add if @generated ... else ... end inside functions to provide opti…
Browse files Browse the repository at this point in the history
…onal optimizers (#23168)

use meta nodes instead of `stagedfunction` expression head
  • Loading branch information
JeffBezanson committed Oct 25, 2017
1 parent 20c416d commit 80a2c2f
Show file tree
Hide file tree
Showing 24 changed files with 426 additions and 254 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ New language features
* The macro call syntax `@macroname[args]` is now available and is parsed
as `@macroname([args])` ([#23519]).

* The construct `if @generated ...; else ...; end` can be used to provide both
`@generated` and normal implementations of part of a function. Surrounding code
will be common to both versions ([#23168]).

Language changes
----------------

Expand Down
28 changes: 28 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -430,4 +430,32 @@ show(@nospecialize a) = show(STDOUT, a)
print(@nospecialize a...) = print(STDOUT, a...)
println(@nospecialize a...) = println(STDOUT, a...)

struct GeneratedFunctionStub
gen
argnames::Array{Any,1}
spnames::Union{Void, Array{Any,1}}
line::Int
file::Symbol
end

# invoke and wrap the results of @generated
function (g::GeneratedFunctionStub)(@nospecialize args...)
body = g.gen(args...)
if body isa CodeInfo
return body
end
lam = Expr(:lambda, g.argnames,
Expr(Symbol("scope-block"),
Expr(:block,
LineNumberNode(g.line, g.file),
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
Expr(:return, body),
Expr(:meta, :pop_loc))))
if g.spnames === nothing
return lam
else
return Expr(Symbol("with-static-parameters"), lam, g.spnames...)
end
end

ccall(:jl_set_istopmod, Void, (Any, Bool), Core, true)
2 changes: 1 addition & 1 deletion base/docs/Docs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ finddoc(λ, def) = false

# Predicates and helpers for `docm` expression selection:

const FUNC_HEADS = [:function, :stagedfunction, :macro, :(=)]
const FUNC_HEADS = [:function, :macro, :(=)]
const BINDING_HEADS = [:typealias, :const, :global, :(=)] # deprecation: remove `typealias` post-0.6
# For the special `:@mac` / `:(Base.@mac)` syntax for documenting a macro after definition.
isquotedmacrocall(x) =
Expand Down
19 changes: 16 additions & 3 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,23 @@ function remove_linenums!(ex::Expr)
return ex
end

macro generated()
return Expr(:generated)
end

macro generated(f)
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
f.head = :stagedfunction
return Expr(:escape, f)
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
body = f.args[2]
lno = body.args[1]
return Expr(:escape,
Expr(f.head, f.args[1],
Expr(:block,
lno,
Expr(:if, Expr(:generated),
body,
Expr(:block,
Expr(:meta, :generated_only),
Expr(:return, nothing))))))
else
error("invalid syntax; @generated must be used with a function definition")
end
Expand Down
16 changes: 11 additions & 5 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,12 +577,18 @@ _valuefields(::Type{<:AbstractTriangular}) = [:data]

const SpecialArrays = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,AbstractTriangular}

@generated function fillslots!(A::SpecialArrays, x)
ex = :(xT = convert(eltype(A), x))
for field in _valuefields(A)
ex = :($ex; fill!(A.$field, xT))
function fillslots!(A::SpecialArrays, x)
xT = convert(eltype(A), x)
if @generated
quote
$([ :(fill!(A.$field, xT)) for field in _valuefields(A) ]...)
end
else
for field in _valuefields(A)
fill!(getfield(A, field), xT)
end
end
:($ex;return A)
return A
end

# for historical reasons:
Expand Down
12 changes: 10 additions & 2 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar
return s, string_with_env(env, t)
end

function method_argnames(m::Method)
if !isdefined(m, :source) && isdefined(m, :generator)
return m.generator.argnames
end
argnames = Vector{Any}(m.nargs)
ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames)
return argnames
end

function arg_decl_parts(m::Method)
tv = Any[]
sig = m.sig
Expand All @@ -52,8 +61,7 @@ function arg_decl_parts(m::Method)
file = m.file
line = m.line
if isdefined(m, :source) || isdefined(m, :generator)
argnames = Vector{Any}(m.nargs)
ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames)
argnames = method_argnames(m)
show_env = ImmutableDict{Symbol, Any}()
for t in tv
show_env = ImmutableDict(show_env, :unionall_env => t)
Expand Down
52 changes: 25 additions & 27 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,11 @@ end
@noinline throw_checksize_error(A, sz) = throw(DimensionMismatch("output array is the wrong size; expected $sz, got $(size(A))"))

## setindex! ##
@generated function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...)
N = length(I)
quote
@_inline_meta
@boundscheck checkbounds(A, I...)
_unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...)
A
end
function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...)
@_inline_meta
@boundscheck checkbounds(A, I...)
_unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...)
A
end

_iterable(v::AbstractArray) = v
Expand Down Expand Up @@ -916,28 +913,29 @@ function copy!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N}
dest
end

@generated function copy!(dest::AbstractArray{T1,N},
Rdest::CartesianRange{N},
src::AbstractArray{T2,N},
Rsrc::CartesianRange{N}) where {T1,T2,N}
quote
isempty(Rdest) && return dest
if size(Rdest) != size(Rsrc)
throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
function copy!(dest::AbstractArray{T1,N}, Rdest::CartesianRange{N},
src::AbstractArray{T2,N}, Rsrc::CartesianRange{N}) where {T1,T2,N}
isempty(Rdest) && return dest
if size(Rdest) != size(Rsrc)
throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
end
@boundscheck checkbounds(dest, first(Rdest))
@boundscheck checkbounds(dest, last(Rdest))
@boundscheck checkbounds(src, first(Rsrc))
@boundscheck checkbounds(src, last(Rsrc))
ΔI = first(Rdest) - first(Rsrc)
if @generated
quote
@nloops $N i (n->Rsrc.indices[n]) begin
@inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i)
end
end
@boundscheck checkbounds(dest, first(Rdest))
@boundscheck checkbounds(dest, last(Rdest))
@boundscheck checkbounds(src, first(Rsrc))
@boundscheck checkbounds(src, last(Rsrc))
ΔI = first(Rdest) - first(Rsrc)
# TODO: restore when #9080 is fixed
# for I in Rsrc
# @inbounds dest[I+ΔI] = src[I]
@nloops $N i (n->Rsrc.indices[n]) begin
@inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i)
else
for I in Rsrc
@inbounds dest[I + ΔI] = src[I]
end
dest
end
dest
end

"""
Expand Down
19 changes: 11 additions & 8 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,13 @@ end
"""
code_lowered(f, types, expand_generated = true)
Return an array of lowered ASTs for the methods matching the given generic function and type signature.
Return an array of the lowered forms (IR) for the methods matching the given generic function
and type signature.
If `expand_generated` is `false`, then the `CodeInfo` instances returned for `@generated`
methods will correspond to the generators' lowered ASTs. If `expand_generated` is `true`,
these `CodeInfo` instances will correspond to the lowered ASTs of the method bodies yielded
by expanding the generators.
If `expand_generated` is `false`, the returned `CodeInfo` instances will correspond to fallback
implementations. An error is thrown if no fallback implementation exists.
If `expand_generated` is `true`, these `CodeInfo` instances will correspond to the method bodies
yielded by expanding the generators.
Note that an error will be thrown if `types` are not leaf types when `expand_generated` is
`true` and the corresponding method is a `@generated` method.
Expand Down Expand Up @@ -737,7 +738,9 @@ function length(mt::MethodTable)
end
isempty(mt::MethodTable) = (mt.defs === nothing)

uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m, :source) ? m.source : m.generator.inferred)
uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) :
isdefined(m,:generator) ? error("Method is @generated; try `code_lowered` instead.") :
error("Code for this Method is not available.")
uncompressed_ast(m::Method, s::CodeInfo) = s
uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo
uncompressed_ast(m::Core.MethodInstance) = uncompressed_ast(m.def)
Expand Down Expand Up @@ -851,7 +854,7 @@ code_native(::IO, ::Any, ::Symbol) = error("illegal code_native call") # resolve

# give a decent error message if we try to instantiate a staged function on non-leaf types
function func_for_method_checked(m::Method, @nospecialize types)
if isdefined(m,:generator) && !isdefined(m,:source) && !_isleaftype(types)
if isdefined(m,:generator) && !_isleaftype(types)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
Expand All @@ -861,7 +864,7 @@ end
"""
code_typed(f, types; optimize=true)
Returns an array of lowered and type-inferred ASTs for the methods matching the given
Returns an array of type-inferred lowered form (IR) for the methods matching the given
generic function and type signature. The keyword argument `optimize` controls whether
additional optimizations, such as inlining, are also applied.
"""
Expand Down
32 changes: 19 additions & 13 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,21 +236,27 @@ include("broadcast.jl")
using .Broadcast

# define the real ntuple functions
@generated function ntuple(f::F, ::Val{N}) where {F,N}
Core.typeassert(N, Int)
(N >= 0) || return :(throw($(ArgumentError(string("tuple length should be ≥0, got ", N)))))
return quote
$(Expr(:meta, :inline))
@nexprs $N i -> t_i = f(i)
@ncall $N tuple t
@inline function ntuple(f::F, ::Val{N}) where {F,N}
N::Int
(N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N)))
if @generated
quote
@nexprs $N i -> t_i = f(i)
@ncall $N tuple t
end
else
Tuple(f(i) for i = 1:N)
end
end
@generated function fill_to_length(t::Tuple, val, ::Val{N}) where {N}
M = length(t.parameters)
M > N && return :(throw($(ArgumentError("input tuple of length $M, requested $N"))))
return quote
$(Expr(:meta, :inline))
(t..., $(Any[ :val for i = (M + 1):N ]...))
@inline function fill_to_length(t::Tuple, val, ::Val{N}) where {N}
M = length(t)
M > N && throw(ArgumentError("input tuple of length $M, requested $N"))
if @generated
quote
(t..., $(fill(:val, N-length(t.parameters))...))
end
else
(t..., fill(val, N-M)...)
end
end

Expand Down
Loading

0 comments on commit 80a2c2f

Please sign in to comment.