Skip to content

Commit

Permalink
infer/optimize splatting of NamedTuples, similar to Tuples (JuliaLang…
Browse files Browse the repository at this point in the history
…#30561)

This allows removing one of the generated functions for NamedTuples,
and a couple others could be removed as well.
  • Loading branch information
JeffBezanson committed Jan 12, 2019
1 parent 0ecaffb commit 7772486
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 40 deletions.
6 changes: 6 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,12 @@ function precise_container_type(@nospecialize(typ), vtypes::VarTable, sv::Infere

tti0 = widenconst(typ)
tti = unwrap_unionall(tti0)
if isa(tti, DataType) && tti.name === NamedTuple_typename
tti0 = tti.parameters[2]
while isa(tti0, TypeVar)
tti0 = tti0.ub
end
end
if isa(tti, Union)
utis = uniontypes(tti)
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
Expand Down
12 changes: 11 additions & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,11 @@ function rewrite_apply_exprargs!(ir::IRCode, idx::Int, argexprs::Vector{Any}, at
push!(def_atypes, Const(p))
end
else
for p in widenconst(def_type).parameters
ti = widenconst(def_type)
if ti.name === NamedTuple_typename
ti = ti.parameters[2]
end
for p in ti.parameters
if isa(p, DataType) && isdefined(p, :instance)
# replace singleton types with their equivalent Const object
p = Const(p.instance)
Expand Down Expand Up @@ -827,6 +831,12 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
typ = atypes[i]
typ = widenconst(typ)
# TODO: We could basically run the iteration protocol here
if isa(typ, DataType) && typ.name === NamedTuple_typename
typ = typ.parameters[2]
while isa(typ, TypeVar)
typ = typ.ub
end
end
if !isa(typ, DataType) || typ.name !== Tuple.name ||
isvatuple(typ) || length(typ.parameters) > sv.params.MAX_TUPLE_SPLAT
ok = false
Expand Down
26 changes: 19 additions & 7 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -643,19 +643,31 @@ function append_any(xs...)
end
for j in 1:lx
y = @inbounds x[j]
arrayset(true, out, y, i)
arrayset(false, out, y, i)
i += 1
end
elseif x isa Tuple
lx = length(x)
lx = nfields(x)
if i + lx - 1 > l
ladd = lx > 16 ? lx : 16
_growend!(out, ladd)
l += ladd
end
for j in 1:lx
y = @inbounds x[j]
arrayset(true, out, y, i)
y = getfield(x, j, false)
arrayset(false, out, y, i)
i += 1
end
elseif x isa NamedTuple
lx = nfields(x)
if i + lx - 1 > l
ladd = lx > 16 ? lx : 16
_growend!(out, ladd)
l += ladd
end
for j in 1:lx
y = getfield(x, j, false)
arrayset(false, out, y, i)
i += 1
end
elseif x isa Array
Expand All @@ -666,8 +678,8 @@ function append_any(xs...)
l += ladd
end
for j in 1:lx
y = arrayref(true, x, j)
arrayset(true, out, y, i)
y = arrayref(false, x, j)
arrayset(false, out, y, i)
i += 1
end
else
Expand All @@ -676,7 +688,7 @@ function append_any(xs...)
_growend!(out, 16)
l += 16
end
arrayset(true, out, y, i)
arrayset(false, out, y, i)
i += 1
end
end
Expand Down
35 changes: 5 additions & 30 deletions base/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,10 @@ Construct a named tuple with the given `names` (a tuple of Symbols) and field ty
(a `Tuple` type) from a tuple of values.
"""
function NamedTuple{names,T}(args::Tuple) where {names, T <: Tuple}
if length(args) == length(names)
if @generated
N = length(names)
types = T.parameters
Expr(:new, :(NamedTuple{names,T}), Any[ :(convert($(types[i]), args[$i])) for i in 1:N ]...)
else
N = length(names)
NT = NamedTuple{names,T}
types = T.parameters
fields = Any[ convert(types[i], args[i]) for i = 1:N ]
ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), NT, fields, N)::NT
end
else
if length(args) != length(names)
throw(ArgumentError("Wrong number of arguments to named tuple constructor."))
end
NamedTuple{names,T}(T(args))
end

"""
Expand Down Expand Up @@ -112,13 +101,7 @@ function convert(::Type{NamedTuple{names,T}}, nt::NamedTuple{names}) where {name
end

if nameof(@__MODULE__) === :Base
function Tuple(nt::NamedTuple{names}) where {names}
if @generated
return Expr(:tuple, Any[:(getfield(nt, $(QuoteNode(n)))) for n in names]...)
else
return tuple(nt...)
end
end
Tuple(nt::NamedTuple) = (nt...,)
(::Type{T})(nt::NamedTuple) where {T <: Tuple} = convert(T, Tuple(nt))
end

Expand Down Expand Up @@ -173,20 +156,12 @@ isless(a::NamedTuple{n}, b::NamedTuple{n}) where {n} = isless(Tuple(a), Tuple(b)
same_names(::NamedTuple{names}...) where {names} = true
same_names(::NamedTuple...) = false

# NOTE: this method signature makes sure we don't define map(f)
function map(f, nt::NamedTuple{names}, nts::NamedTuple...) where names
if !same_names(nt, nts...)
throw(ArgumentError("Named tuple names do not match."))
end
# this method makes sure we don't define a map(f) method
NT = NamedTuple{names}
if @generated
N = length(names)
M = length(nts)
args = Expr[:(f($(Expr[:(getfield(nt, $j)), (:(getfield(nts[$i], $j)) for i = 1:M)...]...))) for j = 1:N]
:( NT(($(args...),)) )
else
NT(map(f, map(Tuple, (nt, nts...))...))
end
NamedTuple{names}(map(f, map(Tuple, (nt, nts...))...))
end

# a version of `in` for the older world these generated functions run in
Expand Down
4 changes: 2 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ JL_CALLABLE(jl_f__apply)
if (jl_is_svec(args[i])) {
n += jl_svec_len(args[i]);
}
else if (jl_is_tuple(args[i])) {
else if (jl_is_tuple(args[i]) || jl_is_namedtuple(args[i])) {
n += jl_nfields(args[i]);
}
else if (jl_is_array(args[i])) {
Expand Down Expand Up @@ -524,7 +524,7 @@ JL_CALLABLE(jl_f__apply)
for(j=0; j < al; j++)
newargs[n++] = jl_svecref(t, j);
}
else if (jl_is_tuple(ai)) {
else if (jl_is_tuple(ai) || jl_is_namedtuple(ai)) {
size_t al = jl_nfields(ai);
for(j=0; j < al; j++) {
// jl_fieldref may allocate.
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ static inline int jl_is_layout_opaque(const jl_datatype_layout_t *l) JL_NOTSAFEP
// basic predicates -----------------------------------------------------------
#define jl_is_nothing(v) (((jl_value_t*)(v)) == ((jl_value_t*)jl_nothing))
#define jl_is_tuple(v) (((jl_datatype_t*)jl_typeof(v))->name == jl_tuple_typename)
#define jl_is_namedtuple(v) (((jl_datatype_t*)jl_typeof(v))->name == jl_namedtuple_typename)
#define jl_is_svec(v) jl_typeis(v,jl_simplevector_type)
#define jl_is_simplevector(v) jl_is_svec(v)
#define jl_is_datatype(v) jl_typeis(v,jl_datatype_type)
Expand Down
3 changes: 3 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2193,3 +2193,6 @@ j30385(T, y) = k30385(f30385(T, y))

@test @inferred(j30385(AbstractFloat, 1)) == 1
@test @inferred(j30385(:dummy, 1)) == "dummy"

@test Base.return_types(Tuple, (NamedTuple{<:Any,Tuple{Any,Int}},)) == Any[Tuple{Any,Int}]
@test Base.return_types(Base.splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}]

0 comments on commit 7772486

Please sign in to comment.