Skip to content

Commit

Permalink
Merge pull request #49123 from JuliaLang/jn/limit-converts
Browse files Browse the repository at this point in the history
Avoid calling no-op convert most of the time, syntactically.

This is not quite as elegant, but dispatch can have a notable cost for
latency, so doing something which doesn't generate edges can improve
the compiler (since it is then impossible for the user to introduce
dispatch changes into it). This is enough of an apparent latency /
memory improvement due to the simplification of the abstract model that
I think it is worth doing.
  • Loading branch information
vtjnash committed Apr 10, 2023
2 parents 35e4a1f + ffca15a commit f48194c
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 97 deletions.
40 changes: 30 additions & 10 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ getproperty(x::Tuple, f::Int) = (@inline; getfield(x, f))
setproperty!(x::Tuple, f::Int, v) = setfield!(x, f, v) # to get a decent error

getproperty(x, f::Symbol) = (@inline; getfield(x, f))
setproperty!(x, f::Symbol, v) = setfield!(x, f, convert(fieldtype(typeof(x), f), v))
function setproperty!(x, f::Symbol, v)
ty = fieldtype(typeof(x), f)
val = v isa ty ? v : convert(ty, v)
return setfield!(x, f, val)
end

dotgetproperty(x, f) = getproperty(x, f)

getproperty(x::Module, f::Symbol, order::Symbol) = (@inline; getglobal(x, f, order))
function setproperty!(x::Module, f::Symbol, v, order::Symbol=:monotonic)
@inline
val::Core.get_binding_type(x, f) = v
ty = Core.get_binding_type(x, f)
val = v isa ty ? v : convert(ty, v)
return setglobal!(x, f, val, order)
end
getproperty(x::Type, f::Symbol, order::Symbol) = (@inline; getfield(x, f, order))
Expand All @@ -51,14 +56,29 @@ getproperty(x::Tuple, f::Int, order::Symbol) = (@inline; getfield(x, f, order))
setproperty!(x::Tuple, f::Int, v, order::Symbol) = setfield!(x, f, v, order) # to get a decent error

getproperty(x, f::Symbol, order::Symbol) = (@inline; getfield(x, f, order))
setproperty!(x, f::Symbol, v, order::Symbol) = (@inline; setfield!(x, f, convert(fieldtype(typeof(x), f), v), order))
function setproperty!(x, f::Symbol, v, order::Symbol)
@inline
ty = fieldtype(typeof(x), f)
val = v isa ty ? v : convert(ty, v)
return setfield!(x, f, val, order)
end

swapproperty!(x, f::Symbol, v, order::Symbol=:not_atomic) =
(@inline; Core.swapfield!(x, f, convert(fieldtype(typeof(x), f), v), order))
modifyproperty!(x, f::Symbol, op, v, order::Symbol=:not_atomic) =
(@inline; Core.modifyfield!(x, f, op, v, order))
replaceproperty!(x, f::Symbol, expected, desired, success_order::Symbol=:not_atomic, fail_order::Symbol=success_order) =
(@inline; Core.replacefield!(x, f, expected, convert(fieldtype(typeof(x), f), desired), success_order, fail_order))
function swapproperty!(x, f::Symbol, v, order::Symbol=:not_atomic)
@inline
ty = fieldtype(typeof(x), f)
val = v isa ty ? v : convert(ty, v)
return Core.swapfield!(x, f, val, order)
end
function modifyproperty!(x, f::Symbol, op, v, order::Symbol=:not_atomic)
@inline
return Core.modifyfield!(x, f, op, v, order)
end
function replaceproperty!(x, f::Symbol, expected, desired, success_order::Symbol=:not_atomic, fail_order::Symbol=success_order)
@inline
ty = fieldtype(typeof(x), f)
val = desired isa ty ? desired : convert(ty, desired)
return Core.replacefield!(x, f, expected, val, success_order, fail_order)
end

convert(::Type{Any}, Core.@nospecialize x) = x
convert(::Type{T}, x::T) where {T} = x
Expand Down Expand Up @@ -149,7 +169,7 @@ include("refpointer.jl")
delete_method(which(Pair{Any,Any}, (Any, Any)))
@eval function (P::Type{Pair{A, B}})(@nospecialize(a), @nospecialize(b)) where {A, B}
@inline
return $(Expr(:new, :P, :(convert(A, a)), :(convert(B, b))))
return $(Expr(:new, :P, :(a isa A ? a : convert(A, a)), :(b isa B ? b : convert(B, b))))
end

# The REPL stdlib hooks into Base using this Ref
Expand Down
18 changes: 9 additions & 9 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ function vect(X...)
return T[X...]
end

size(a::Array, d::Integer) = arraysize(a, convert(Int, d))
size(a::Array, d::Integer) = arraysize(a, d isa Int ? d : convert(Int, d))
size(a::Vector) = (arraysize(a,1),)
size(a::Matrix) = (arraysize(a,1), arraysize(a,2))
size(a::Array{<:Any,N}) where {N} = (@inline; ntuple(M -> size(a, M), Val(N))::Dims)
Expand Down Expand Up @@ -383,7 +383,7 @@ copyto!(dest::Array{T}, src::Array{T}) where {T} = copyto!(dest, 1, src, 1, leng
# N.B: The generic definition in multidimensional.jl covers, this, this is just here
# for bootstrapping purposes.
function fill!(dest::Array{T}, x) where T
xT = convert(T, x)
xT = x isa T ? x : convert(T, x)::T
for i in eachindex(dest)
@inbounds dest[i] = xT
end
Expand Down Expand Up @@ -475,7 +475,7 @@ end
getindex(::Type{Any}) = Vector{Any}()

function fill!(a::Union{Array{UInt8}, Array{Int8}}, x::Integer)
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), a, convert(eltype(a), x), length(a))
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), a, x isa eltype(a) ? x : convert(eltype(a), x), length(a))
return a
end

Expand Down Expand Up @@ -1020,9 +1020,9 @@ Dict{String, Int64} with 2 entries:
function setindex! end

@eval setindex!(A::Array{T}, x, i1::Int) where {T} =
arrayset($(Expr(:boundscheck)), A, convert(T,x)::T, i1)
arrayset($(Expr(:boundscheck)), A, x isa T ? x : convert(T,x)::T, i1)
@eval setindex!(A::Array{T}, x, i1::Int, i2::Int, I::Int...) where {T} =
(@inline; arrayset($(Expr(:boundscheck)), A, convert(T,x)::T, i1, i2, I...))
(@inline; arrayset($(Expr(:boundscheck)), A, x isa T ? x : convert(T,x)::T, i1, i2, I...))

__inbounds_setindex!(A::Array{T}, x, i1::Int) where {T} =
arrayset(false, A, convert(T,x)::T, i1)
Expand Down Expand Up @@ -1116,7 +1116,7 @@ function push! end

function push!(a::Vector{T}, item) where T
# convert first so we don't grow the array if the assignment won't work
itemT = convert(T, item)
itemT = item isa T ? item : convert(T, item)::T
_growend!(a, 1)
@_safeindex a[length(a)] = itemT
return a
Expand Down Expand Up @@ -1466,7 +1466,7 @@ julia> pushfirst!([1, 2, 3, 4], 5, 6)
```
"""
function pushfirst!(a::Vector{T}, item) where T
item = convert(T, item)
item = item isa T ? item : convert(T, item)::T
_growbeg!(a, 1)
@_safeindex a[1] = item
return a
Expand Down Expand Up @@ -1553,7 +1553,7 @@ julia> insert!(Any[1:6;], 3, "here")
"""
function insert!(a::Array{T,1}, i::Integer, item) where T
# Throw convert error before changing the shape of the array
_item = convert(T, item)
_item = item isa T ? item : convert(T, item)::T
_growat!(a, i, 1)
# _growat! already did bound check
@inbounds a[i] = _item
Expand Down Expand Up @@ -2194,7 +2194,7 @@ findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::AbstractUnitR
function findfirst(p::Union{Fix2{typeof(isequal),T},Fix2{typeof(==),T}}, r::StepRange{T,S}) where {T,S}
isempty(r) && return nothing
minimum(r) <= p.x <= maximum(r) || return nothing
d = convert(S, p.x - first(r))
d = convert(S, p.x - first(r))::S
iszero(d % step(r)) || return nothing
return d ÷ step(r) + 1
end
Expand Down
16 changes: 14 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1736,22 +1736,34 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
elseif has_conditional(𝕃ᵢ, sv) && (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any})
# perform very limited back-propagation of type information for `is` and `isa`
if f === isa
# try splitting value argument, based on types
a = ssa_def_slot(fargs[2], sv)
a2 = argtypes[2]
a3 = argtypes[3]
if isa(a, SlotNumber)
cndt = isa_condition(a2, argtypes[3], InferenceParams(interp).max_union_splitting, rt)
cndt = isa_condition(a2, a3, InferenceParams(interp).max_union_splitting, rt)
if cndt !== nothing
return Conditional(a, cndt.thentype, cndt.elsetype)
end
end
if isa(a2, MustAlias)
if !isa(rt, Const) # skip refinement when the field is known precisely (just optimization)
cndt = isa_condition(a2, argtypes[3], InferenceParams(interp).max_union_splitting)
cndt = isa_condition(a2, a3, InferenceParams(interp).max_union_splitting)
if cndt !== nothing
return form_mustalias_conditional(a2, cndt.thentype, cndt.elsetype)
end
end
end
# try splitting type argument, based on value
if isdispatchelem(widenconst(a2)) && a3 isa Union && !has_free_typevars(a3) && !isa(rt, Const)
b = ssa_def_slot(fargs[3], sv)
if isa(b, SlotNumber)
# !(x isa T) implies !(Type{a2} <: T)
# TODO: complete splitting, based on which portions of the Union a3 for which isa_tfunc returns Const(true) or Const(false) instead of Bool
elsetype = typesubtract(a3, Type{widenconst(a2)}, InferenceParams(interp).max_union_splitting)
return Conditional(b, a3, elsetype)
end
end
elseif f === (===)
a = ssa_def_slot(fargs[2], sv)
b = ssa_def_slot(fargs[3], sv)
Expand Down
40 changes: 29 additions & 11 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,19 @@ ht_keyindex2!(h::Dict, key) = ht_keyindex2_shorthash!(h, key)[1]
end

function setindex!(h::Dict{K,V}, v0, key0) where V where K
key = convert(K, key0)
if !(isequal(key, key0)::Bool)
throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K"))
if key0 isa K
key = key0
else
key = convert(K, key0)::K
if !(isequal(key, key0)::Bool)
throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K"))
end
end
setindex!(h, v0, key)
end

function setindex!(h::Dict{K,V}, v0, key::K) where V where K
v = convert(V, v0)
v = v0 isa V ? v0 : convert(V, v0)::V
index, sh = ht_keyindex2_shorthash!(h, key)

if index > 0
Expand Down Expand Up @@ -453,9 +457,13 @@ Dict{Int64, Int64} with 1 entry:
get!(f::Callable, collection, key)

function get!(default::Callable, h::Dict{K,V}, key0) where V where K
key = convert(K, key0)
if !isequal(key, key0)
throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K"))
if key0 isa K
key = key0
else
key = convert(K, key0)::K
if !isequal(key, key0)
throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K"))
end
end
return get!(default, h, key)
end
Expand All @@ -466,7 +474,10 @@ function get!(default::Callable, h::Dict{K,V}, key::K) where V where K
index > 0 && return h.vals[index]

age0 = h.age
v = convert(V, default())
v = default()
if !isa(v, V)
v = convert(V, v)::V
end
if h.age != age0
index, sh = ht_keyindex2_shorthash!(h, key)
end
Expand Down Expand Up @@ -756,10 +767,17 @@ function mergewith!(combine, d1::Dict{K, V}, d2::AbstractDict) where {K, V}
if i > 0
d1.vals[i] = combine(d1.vals[i], v)
else
if !isequal(k, convert(K, k))
throw(ArgumentError("$(limitrepr(k)) is not a valid key for type $K"))
if !(k isa K)
k1 = convert(K, k)::K
if !isequal(k, k1)
throw(ArgumentError("$(limitrepr(k)) is not a valid key for type $K"))
end
k = k1
end
if !isa(v, V)
v = convert(V, v)::V
end
@inbounds _setindex!(d1, convert(V, v), k, -i, sh)
@inbounds _setindex!(d1, v, k, -i, sh)
end
end
return d1
Expand Down
14 changes: 10 additions & 4 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ end
@eval struct Pairs{K, V, I, A} <: AbstractDict{K, V}
data::A
itr::I
Pairs{K, V, I, A}(data, itr) where {K, V, I, A} = $(Expr(:new, :(Pairs{K, V, I, A}), :(convert(A, data)), :(convert(I, itr))))
Pairs{K, V, I, A}(data, itr) where {K, V, I, A} = $(Expr(:new, :(Pairs{K, V, I, A}), :(data isa A ? data : convert(A, data)), :(itr isa I ? itr : convert(I, itr))))
Pairs{K, V}(data::A, itr::I) where {K, V, I, A} = $(Expr(:new, :(Pairs{K, V, I, A}), :data, :itr))
Pairs{K}(data::A, itr::I) where {K, I, A} = $(Expr(:new, :(Pairs{K, eltype(A), I, A}), :data, :itr))
Pairs(data::A, itr::I) where {I, A} = $(Expr(:new, :(Pairs{eltype(I), eltype(A), I, A}), :data, :itr))
Expand Down Expand Up @@ -459,7 +459,13 @@ function convert(::Type{T}, x::NTuple{N,Any}) where {N, T<:Tuple}
if typeintersect(NTuple{N,Any}, T) === Union{}
_tuple_error(T, x)
end
cvt1(n) = (@inline; convert(fieldtype(T, n), getfield(x, n, #=boundscheck=#false)))
function cvt1(n)
@inline
Tn = fieldtype(T, n)
xn = getfield(x, n, #=boundscheck=#false)
xn isa Tn && return xn
return convert(Tn, xn)
end
return ntuple(cvt1, Val(N))::NTuple{N,Any}
end

Expand Down Expand Up @@ -512,7 +518,7 @@ julia> oftype(y, x)
4.0
```
"""
oftype(x, y) = convert(typeof(x), y)
oftype(x, y) = y isa typeof(x) ? y : convert(typeof(x), y)::typeof(x)

unsigned(x::Int) = reinterpret(UInt, x)
signed(x::UInt) = reinterpret(Int, x)
Expand All @@ -533,7 +539,7 @@ Neither `convert` nor `cconvert` should take a Julia object and turn it into a `
"""
function cconvert end

cconvert(T::Type, x) = convert(T, x) # do the conversion eagerly in most cases
cconvert(T::Type, x) = x isa T ? x : convert(T, x) # do the conversion eagerly in most cases
cconvert(::Type{<:Ptr}, x) = x # but defer the conversion to Ptr to unsafe_convert
unsafe_convert(::Type{T}, x::T) where {T} = x # unsafe_convert (like convert) defaults to assuming the convert occurred
unsafe_convert(::Type{T}, x::T) where {T<:Ptr} = x # to resolve ambiguity with the next method
Expand Down
4 changes: 2 additions & 2 deletions base/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ end
function setindex!(d::IdDict{K,V}, @nospecialize(val), @nospecialize(key)) where {K, V}
!isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K"))
if !(val isa V) # avoid a dynamic call
val = convert(V, val)
val = convert(V, val)::V
end
if d.ndel >= ((3*length(d.ht))>>2)
rehash!(d, max((length(d.ht)%UInt)>>1, 32))
Expand Down Expand Up @@ -155,7 +155,7 @@ copy(d::IdDict) = typeof(d)(d)
function get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V}
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
if val === secret_table_token
val = isa(default, V) ? default : convert(V, default)
val = isa(default, V) ? default : convert(V, default)::V
setindex!(d, val, key)
return val
else
Expand Down
3 changes: 2 additions & 1 deletion base/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ function logmsg_code(_module, file, line, level, message, exs...)
return quote
let
level = $level
std_level = convert(LogLevel, level)
# simplify std_level code emitted, if we know it is one of our global constants
std_level = $(level isa Symbol ? :level : :(level isa LogLevel ? level : convert(LogLevel, level)::LogLevel))
if std_level >= _min_enabled_level[]
group = $(log_data._group)
_module = $(log_data._module)
Expand Down
7 changes: 5 additions & 2 deletions base/namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ end
function NamedTuple{names, T}(nt::NamedTuple) where {names, T <: Tuple}
if @generated
Expr(:new, :(NamedTuple{names, T}),
Any[ :(convert(fieldtype(T, $n), getfield(nt, $(QuoteNode(names[n]))))) for n in 1:length(names) ]...)
Any[ :(let Tn = fieldtype(T, $n),
ntn = getfield(nt, $(QuoteNode(names[n])))
ntn isa Tn ? ntn : convert(Tn, ntn)
end) for n in 1:length(names) ]...)
else
NamedTuple{names, T}(map(Fix1(getfield, nt), names))
end
Expand Down Expand Up @@ -195,7 +198,7 @@ end

if nameof(@__MODULE__) === :Base
Tuple(nt::NamedTuple) = (nt...,)
(::Type{T})(nt::NamedTuple) where {T <: Tuple} = convert(T, Tuple(nt))
(::Type{T})(nt::NamedTuple) where {T <: Tuple} = (t = Tuple(nt); t isa T ? t : convert(T, t)::T)
end

function show(io::IO, t::NamedTuple)
Expand Down
6 changes: 5 additions & 1 deletion base/pair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ last(p::Pair) = p.second

convert(::Type{Pair{A,B}}, x::Pair{A,B}) where {A,B} = x
function convert(::Type{Pair{A,B}}, x::Pair) where {A,B}
Pair{A,B}(convert(A, x[1]), convert(B, x[2]))::Pair{A,B}
a = getfield(x, :first)
a isa A || (a = convert(A, a))
b = getfield(x, :second)
b isa B || (b = convert(B, b))
return Pair{A,B}(a, b)::Pair{A,B}
end

promote_rule(::Type{Pair{A1,B1}}, ::Type{Pair{A2,B2}}) where {A1,B1,A2,B2} =
Expand Down
4 changes: 2 additions & 2 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ reducedim_init(f, op::typeof(|), A::AbstractArrayOrBroadcasted, region) = reduce
let
BitIntFloat = Union{BitInteger, IEEEFloat}
T = Union{
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}
Any[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
Any[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}

global function reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region)
z = zero(f(zero(eltype(A))))
Expand Down
6 changes: 4 additions & 2 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ function tuple_type_tail(T::Type)
end
end

(::Type{T})(x::Tuple) where {T<:Tuple} = convert(T, x) # still use `convert` for tuples
(::Type{T})(x::Tuple) where {T<:Tuple} = x isa T ? x : convert(T, x) # still use `convert` for tuples

Tuple(x::Ref) = tuple(getindex(x)) # faster than iterator for one element
Tuple(x::Array{T,0}) where {T} = tuple(getindex(x))
Expand All @@ -395,7 +395,9 @@ function _totuple(::Type{T}, itr, s::Vararg{Any,N}) where {T,N}
@inline
y = iterate(itr, s...)
y === nothing && _totuple_err(T)
t1 = convert(fieldtype(T, 1), y[1])
T1 = fieldtype(T, 1)
y1 = y[1]
t1 = y1 isa T1 ? y1 : convert(T1, y1)::T1
# inference may give up in recursive calls, so annotate here to force accurate return type to be propagated
rT = tuple_type_tail(T)
ts = _totuple(rT, itr, y[2])::rT
Expand Down
Loading

2 comments on commit f48194c

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected.
A full report can be found here.

Please sign in to comment.