Skip to content

Commit

Permalink
Let muladd accept a more restricted set of arrays (JuliaLang#38250)
Browse files Browse the repository at this point in the history
This adjusts JuliaLang#37065 to be much more cautious about what arrays it acts on: it calls mul! on StridedArrays, treats a few special types like Diagonal, UpperTriangular, and UniformScaling, and sends anything else to muladd(A,y,z) = A*y .+ z.

However this broadcasting restricts the shape of z, mostly such that A*y .= z would work. That ensures you should get the same error from the mul!(::StridedMatrix, ...) method, as from the fallback broadcasting one. Both allow z of lower dimension than the existing muladd(x,y,z) = x*y+z.

But x*y+z also allows z to have trailing dimensions, as long as they are of size 1. I made the broadcasting method allow these too, which I think should make this non-breaking. (I presume this is rarely used, and thus not worth sending to the fast method.)

Structured matrices such as UpperTriangular should all go to x*y+z. Some combinations could be made more efficient but it gets complicated. Only the case of 3 diagonals is handled.
  • Loading branch information
mcabbott committed Nov 20, 2020
1 parent 4fa9e32 commit 5be3e27
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 52 deletions.
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -752,3 +752,7 @@ function logabsdet(A::Diagonal)
mapreduce(x -> (log(abs(x)), sign(x)), ((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2),
A.diag)
end

function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
Diagonal(A.diag .* B.diag .+ z.diag)
end
56 changes: 38 additions & 18 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,34 +201,54 @@ julia> muladd(A, B, z)
107.0 107.0
```
"""
function Base.muladd(A::AbstractMatrix{TA}, y::AbstractVector{Ty}, z) where {TA, Ty}
T = promote_type(TA, Ty, eltype(z))
function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray})
Ay = A * y
for d in 1:ndims(Ay)
# Same error as Ay .+= z would give, to match StridedMatrix method:
size(z,d) > size(Ay,d) && throw(DimensionMismatch("array could not be broadcast to match destination"))
end
for d in ndims(Ay)+1:ndims(z)
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
Ay .+ z
end

function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray})
if size(z,1) > length(u) || size(z,2) > length(v)
# Same error as (u*v) .+= z:
throw(DimensionMismatch("array could not be broadcast to match destination"))
end
for d in 3:ndims(z)
# Similar error to (u*v) + z:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
(u .* v) .+ z
end

Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
muladd(A', x', z')'
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
transpose(muladd(transpose(A), transpose(x), transpose(z)))

StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
T = promote_type(eltype(A), eltype(y), eltype(z))
C = similar(A, T, axes(A,1))
C .= z
mul!(C, A, y, true, true)
end

function Base.muladd(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}, z) where {TA, TB}
T = promote_type(TA, TB, eltype(z))
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat})
T = promote_type(eltype(A), eltype(B), eltype(z))
C = similar(A, T, axes(A,1), axes(B,2))
C .= z
mul!(C, A, B, true, true)
end

Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z) = muladd(A', x', z')'
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z) = transpose(muladd(transpose(A), transpose(x), transpose(z)))

function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z)
ndims(z) > 2 && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
(u .* v) .+ z
end

function Base.muladd(u::AdjOrTransAbsVec, v::AbstractVector, z)
uv = _dot_nonrecursive(u, v)
ndims(z) > ndims(uv) && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
uv .+ z
end

"""
mul!(Y, A, B) -> Y
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,12 @@ Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m)
dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y)
dot(x::AbstractVector, a::Number, y::AbstractVector) = sum(t -> dot(t[1], a, t[2]), zip(x, y))
dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y)

# muladd
Base.muladd(A::UniformScaling, B::UniformScaling, z::UniformScaling) =
UniformScaling(A.λ * B.λ + z.λ)
Base.muladd(A::Union{Diagonal, UniformScaling}, B::Union{Diagonal, UniformScaling}, z::Union{Diagonal, UniformScaling}) =
Diagonal(_diag_or_value(A) .* _diag_or_value(B) .+ _diag_or_value(z))

_diag_or_value(A::Diagonal) = A.diag
_diag_or_value(A::UniformScaling) = A.λ
130 changes: 96 additions & 34 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,45 +293,107 @@ end
end

@testset "muladd" begin
A23 = reshape(1:6, 2,3)
A23 = reshape(1:6, 2,3) .+ 0
B34 = reshape(1:12, 3,4) .+ im
u2 = [10,20]
v3 = [3,5,7] .+ im
w4 = [11,13,17,19im]

@test muladd(A23, B34, 100) == A23 * B34 .+ 100
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
@test_throws DimensionMismatch muladd(B34, A23, 1)
@test_throws DimensionMismatch muladd(A23, B34, ones(2,4,1))

@test muladd(A23, v3, 100) == A23 * v3 .+ 100
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))

@test muladd(v3', B34, 0) isa Adjoint
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4))

@test muladd(u2, v3', 0) isa Matrix
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
@test muladd(u2, v3', A23) == u2 * v3' .+ A23
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))

@test muladd(u2', u2, 0) isa Number
@test muladd(v3', v3, im) == dot(v3,v3) + im
@test_throws DimensionMismatch muladd(v3', v3, [1])

vofm = [rand(1:9,2,2) for _ in 1:3]
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
@test_broken muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
@testset "matrix-matrix" begin
@test muladd(A23, B34, 0) == A23 * B34
@test muladd(A23, B34, 100) == A23 * B34 .+ 100
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
@test_throws DimensionMismatch muladd(B34, A23, 1)
@test muladd(ones(1,3), ones(3,4), ones(1,4)) == fill(4.0,1,4)
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4))

# broadcasting fallback method allows trailing dims
@test muladd(A23, B34, ones(2,4,1)) == A23 * B34 + ones(2,4,1)
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4,1))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(1,4,9))
# and catches z::Array{T,0}
@test muladd(A23, B34, fill(0)) == A23 * B34
end
@testset "matrix-vector" begin
@test muladd(A23, v3, 0) == A23 * v3
@test muladd(A23, v3, 100) == A23 * v3 .+ 100
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
@test muladd(ones(1,3), ones(3), ones(1)) == [4]
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7))

# fallback
@test muladd(A23, v3, ones(2,1,1)) == A23 * v3 + ones(2,1,1)
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7,1))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(1,7))
@test muladd(A23, v3, fill(0)) == A23 * v3
end
@testset "adjoint-matrix" begin
@test muladd(v3', B34, 0) isa Adjoint
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'

# via fallback
@test muladd(v3', B34, ones(1,4)) == (B34' * v3 + ones(4,1))'
@test_throws DimensionMismatch muladd(v3', B34, ones(7,4))
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4,7))
@test muladd(v3', B34, fill(0)) == v3' * B34 # does not make an Adjoint
end
@testset "vector-adjoint" begin
@test muladd(u2, v3', 0) isa Matrix
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
@test muladd(u2, v3', A23) == u2 * v3' .+ A23

@test muladd(u2, v3', ones(2,3,1)) == u2 * v3' + ones(2,3,1)
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))
@test_throws DimensionMismatch muladd([1], v3', ones(7,3))
@test muladd(u2, v3', fill(0)) == u2 * v3'
end
@testset "dot" begin # all use muladd(::Any, ::Any, ::Any)
@test muladd(u2', u2, 0) isa Number
@test muladd(v3', v3, im) == dot(v3,v3) + im
@test muladd(u2', u2, [1]) == [dot(u2,u2) + 1]
@test_throws DimensionMismatch muladd(u2', u2, [1,1]) == [dot(u2,u2) + 1]
@test muladd(u2', u2, fill(0)) == dot(u2,u2)
end
@testset "arrays of arrays" begin
vofm = [rand(1:9,2,2) for _ in 1:3]
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
@test muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
end
end

@testset "muladd & structured matrices" begin
A33 = reshape(1:9, 3,3) .+ im
v3 = [3,5,7im]

# no special treatment
@test muladd(Symmetric(A33), Symmetric(A33), 1) == Symmetric(A33) * Symmetric(A33) .+ 1
@test muladd(Hermitian(A33), Hermitian(A33), v3) == Hermitian(A33) * Hermitian(A33) .+ v3
@test muladd(adjoint(A33), transpose(A33), A33) == A33' * transpose(A33) .+ A33

u1 = muladd(UpperTriangular(A33), UpperTriangular(A33), Diagonal(v3))
@test u1 isa UpperTriangular
@test u1 == UpperTriangular(A33) * UpperTriangular(A33) + Diagonal(v3)

# diagonal
@test muladd(Diagonal(v3), Diagonal(A33), Diagonal(v3)).diag == ([1,5,9] .+ im .+ 1) .* v3

# uniformscaling
@test muladd(Diagonal(v3), I, I).diag == v3 .+ 1
@test muladd(2*I, 3*I, I).λ == 7
@test muladd(A33, A33', I) == A33 * A33' + I

# https://github.com/JuliaLang/julia/issues/38426
@test @evalpoly(A33, 1.0*I, 1.0*I) == I + A33
@test @evalpoly(A33, 1.0*I, 1.0*I, 1.0*I) == I + A33 + A33^2
end

# issue #6450
Expand Down

0 comments on commit 5be3e27

Please sign in to comment.