Skip to content

Commit

Permalink
some more improvements
Browse files Browse the repository at this point in the history
- avoid assuming the tuples have the same eltype
- coerce the inputs to simpler representations before printing, if
  unambiguous
- update a couple more error messages with the same text
- use string builder, instead of repeated concatenation
  • Loading branch information
vtjnash committed Feb 3, 2024
1 parent 74a626c commit ce7a16e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
33 changes: 22 additions & 11 deletions base/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,33 @@ IndexStyle(::IndexStyle, ::IndexStyle) = IndexCartesian()

promote_shape(::Tuple{}, ::Tuple{}) = ()

# Consistent error message for promote_shape mismatch, hiding implementation details like
# Consistent error message for promote_shape mismatch, hiding type details like
# OneTo. When b ≡ nothing, it is omitted; i can be supplied for an index.
function throw_promote_shape_mismatch(a::Tuple{T,Vararg{T}},
b::Union{Nothing,Tuple{T,Vararg{T}}},
i = nothing) where {T}
_has_axes = T <: AbstractUnitRange
_normalize(d) = map(x -> _has_axes ? (firstindex(x):lastindex(x)) : x, d)
_things = _has_axes ? "axes" : "size"
msg = "a has $(_things) $(_normalize(a))"
function throw_promote_shape_mismatch(a::Tuple, b::Union{Nothing,Tuple}, i = nothing)
if a isa Tuple{Vararg{Base.OneTo}} && (b === nothing || b isa Tuple{Vararg{Base.OneTo}})
a = map(lastindex, a)::Dims
b === nothing || (b = map(lastindex, b)::Dims)
end
_has_axes = !(a isa Dims && (b === nothing || b isa Dims))
if _has_axes
_normalize(d) = map(x -> firstindex(x):lastindex(x), d)
a = _normalize(a)
b === nothing || (b = _normalize(b))
_things = "axes "
else
_things = "size "
end
msg = IOBuffer()
print(msg, "a has ", _things)
print(msg, a)
if b nothing
msg *= ", b has $(_things) $(_normalize(b))"
print(msg, ", b has ", _things)
print(msg, b)
end
if i nothing
msg *= ", mismatch at $(i)"
print(msg, ", mismatch at dim ", i)
end
throw(DimensionMismatch(msg))
throw(DimensionMismatch(String(take!(msg))))
end

function promote_shape(a::Tuple{Int,}, b::Tuple{Int,})
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, Ab
end
for d in ndims(Ay)+1:ndims(z)
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
Ay .+ z
Expand All @@ -197,7 +197,7 @@ function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, Ab
end
for d in 3:ndims(z)
# Similar error to (u*v) + z:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
size(z,d) > 1 && throw(DimensionMismatch(string("z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
(u .* v) .+ z
Expand Down
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ end
# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
function map(f, A::StructuredMatrix, Bs::StructuredMatrix...)
sz = size(A)
all(map(B->size(B)==sz, Bs)) || throw(DimensionMismatch("dimensions must match"))
for B in Bs
size(B) == sz || Base.throw_promote_shape_mismatch(sz, size(B))
end
return f.(A, Bs...)
end
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ end
@test map!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end
# these would be valid for broadcast, but not for map
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)))
@test_throws DimensionMismatch map(+, D, Diagonal(rand(1)), D)
@test_throws DimensionMismatch map(+, D, D, Diagonal(rand(1)))
@test_throws DimensionMismatch map(+, Diagonal(rand(1)), D, D)
end

@testset "Issue #33397" begin
Expand Down

0 comments on commit ce7a16e

Please sign in to comment.