Skip to content

Commit

Permalink
Improve partial inference of key container operations (#37084)
Browse files Browse the repository at this point in the history
`resize!` and `sizehint!` are big targets for invalidation, and quite
a lot of code was calling these methods with poor inference about the
size. Since the size is just `Int`, asserting that eliminates numerous
vulnerabilities. Likewise, bounds-checking always works with `Bool`.
Finally, add a couple of short-circuits for `eltype(::Tuple{T,...})`,
including one that bypasses `typejoin`.
  • Loading branch information
timholy committed Aug 26, 2020
1 parent 69eadbc commit 1c9c241
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 12 deletions.
8 changes: 4 additions & 4 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
Return `true` if the indices of `A` start with something other than 1 along any axis.
If multiple arguments are passed, equivalent to `has_offset_axes(A) | has_offset_axes(B) | ...`.
"""
has_offset_axes(A) = _tuple_any(x->first(x)!=1, axes(A))
has_offset_axes(A) = _tuple_any(x->Int(first(x))::Int != 1, axes(A))
has_offset_axes(A...) = _tuple_any(has_offset_axes, A)
has_offset_axes(::Colon) = false

Expand Down Expand Up @@ -588,11 +588,11 @@ See also [`checkbounds`](@ref).
"""
function checkbounds_indices(::Type{Bool}, IA::Tuple, I::Tuple)
@_inline_meta
checkindex(Bool, IA[1], I[1]) & checkbounds_indices(Bool, tail(IA), tail(I))
checkindex(Bool, IA[1], I[1])::Bool & checkbounds_indices(Bool, tail(IA), tail(I))
end
function checkbounds_indices(::Type{Bool}, ::Tuple{}, I::Tuple)
@_inline_meta
checkindex(Bool, OneTo(1), I[1]) & checkbounds_indices(Bool, (), tail(I))
checkindex(Bool, OneTo(1), I[1])::Bool & checkbounds_indices(Bool, (), tail(I))
end
checkbounds_indices(::Type{Bool}, IA::Tuple, ::Tuple{}) = (@_inline_meta; all(x->unsafe_length(x)==1, IA))
checkbounds_indices(::Type{Bool}, ::Tuple{}, ::Tuple{}) = true
Expand Down Expand Up @@ -1065,7 +1065,7 @@ end
pointer(x::AbstractArray{T}) where {T} = unsafe_convert(Ptr{T}, x)
function pointer(x::AbstractArray{T}, i::Integer) where T
@_inline_meta
unsafe_convert(Ptr{T}, x) + _memory_offset(x, i)
unsafe_convert(Ptr{T}, x) + Int(_memory_offset(x, i))::Int
end

# The distance from pointer(x) to the element at x[I...] in bytes
Expand Down
2 changes: 1 addition & 1 deletion base/abstractset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ max_values(::Type{Bool}) = 2
max_values(::Type{Nothing}) = 1

function union!(s::AbstractSet{T}, itr) where T
haslength(itr) && sizehint!(s, length(s) + length(itr))
haslength(itr) && sizehint!(s, length(s) + Int(length(itr))::Int)
for x in itr
push!(s, x)
length(s) == max_values(T) && break
Expand Down
2 changes: 1 addition & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ push!(a::AbstractVector, iter...) = append!(a, iter)
function _append!(a, ::Union{HasLength,HasShape}, iter)
n = length(a)
i = lastindex(a)
resize!(a, n+length(iter))
resize!(a, n+Int(length(iter))::Int)
@inbounds for (i, item) in zip(i+1:lastindex(a), iter)
a[i] = item
end
Expand Down
6 changes: 3 additions & 3 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mutable struct Dict{K,V} <: AbstractDict{K,V}
end
function Dict{K,V}(kv) where V where K
h = Dict{K,V}()
haslength(kv) && sizehint!(h, length(kv))
haslength(kv) && sizehint!(h, Int(length(kv))::Int)
for (k,v) in kv
h[k] = v
end
Expand Down Expand Up @@ -166,7 +166,7 @@ end

empty(a::AbstractDict, ::Type{K}, ::Type{V}) where {K, V} = Dict{K, V}()

hashindex(key, sz) = (((hash(key)%Int) & (sz-1)) + 1)::Int
hashindex(key, sz) = (((hash(key)::UInt % Int) & (sz-1)) + 1)::Int

@propagate_inbounds isslotempty(h::Dict, i::Int) = h.slots[i] == 0x0
@propagate_inbounds isslotfilled(h::Dict, i::Int) = h.slots[i] == 0x1
Expand Down Expand Up @@ -239,7 +239,7 @@ function sizehint!(d::Dict{T}, newsz) where T
end
# grow at least 25%
newsz = min(max(newsz, (oldsz*5)>>2),
max_values(T))
max_values(T)::Int)
rehash!(d, newsz)
end

Expand Down
2 changes: 1 addition & 1 deletion base/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ function get!(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V
if val === secret_table_token
val = default()
if !isa(val, V)
val = convert(V, val)
val = convert(V, val)::V
end
setindex!(d, val, key)
return val
Expand Down
2 changes: 1 addition & 1 deletion base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int, quote_level::In
end
print(io, ")")
else
escape_string(io, x, "\"\$")
escape_string(io, String(x)::String, "\"\$")
end
end
print(io, '"')
Expand Down
5 changes: 5 additions & 0 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,16 @@ function _compute_eltype(t::Type{<:Tuple})
@_pure_meta
@nospecialize t
t isa Union && return promote_typejoin(eltype(t.a), eltype(t.b))
# Given t = Tuple{Vararg{S}} where S<:Real, the various
# unwrapping/wrapping/va-handling here will return Real
= unwrap_unionall(t)
# TODO: handle Union/UnionAll correctly here
# For Tuple{T}, short-circuit promote_typejoin
length(t´.parameters) == 1 && return rewrap_unionall(unwrapva(t´.parameters[1]), t)
r = Union{}
for ti in.parameters
r = promote_typejoin(r, rewrap_unionall(unwrapva(ti), t))
r === Any && break # if we've already reached Any, it can't widen any more
end
return r
end
Expand Down
9 changes: 8 additions & 1 deletion test/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ mutable struct T10647{T}; x::T; end
Base.show(Base.IOContext(IOBuffer(), :limit => true), a)
end

@testset "IdDict{Any,Any}" begin
@testset "IdDict{Any,Any} and partial inference" begin
a = IdDict{Any,Any}()
a[1] = a
a[a] = 2
Expand Down Expand Up @@ -478,6 +478,13 @@ end
@test isa(d, IdDict{Any,Any})
@test d == IdDict{Any,Any}(1=>1, 2=>2, 3=>3)
@test eltype(d) == Pair{Any,Any}

d = IdDict{Any,Int32}(:hi => 7)
let c = Ref{Any}(1.5)
f() = c[]
@test @inferred(get!(f, d, :hi)) === Int32(7)
@test_throws InexactError(:Int32, Int32, 1.5) get!(f, d, :hello)
end
end

@testset "IdDict" begin
Expand Down

0 comments on commit 1c9c241

Please sign in to comment.