Skip to content

Commit

Permalink
Let muladd accept arrays (#37065)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 29, 2020
1 parent 5c47690 commit 72971c4
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
47 changes: 47 additions & 0 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,53 @@ for elty in (Float32,Float64)
end
end

"""
muladd(A, y, z)
Combined multiply-add, `A*y .+ z`, for matrix-matrix or matrix-vector multiplication.
The result is always the same size as `A*y`, but `z` may be smaller, or a scalar.
!!! compat "Julia 1.6"
These methods require Julia 1.6 or later.
# Examples
```jldoctest
julia> A=[1.0 2.0; 3.0 4.0]; B=[1.0 1.0; 1.0 1.0]; z=[0, 100];
julia> muladd(A, B, z)
2×2 Matrix{Float64}:
3.0 3.0
107.0 107.0
```
"""
function Base.muladd(A::AbstractMatrix{TA}, y::AbstractVector{Ty}, z) where {TA, Ty}
T = promote_type(TA, Ty, eltype(z))
C = similar(A, T, axes(A,1))
C .= z
mul!(C, A, y, true, true)
end

function Base.muladd(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}, z) where {TA, TB}
T = promote_type(TA, TB, eltype(z))
C = similar(A, T, axes(A,1), axes(B,2))
C .= z
mul!(C, A, B, true, true)
end

Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z) = muladd(A', x', z')'
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z) = transpose(muladd(transpose(A), transpose(x), transpose(z)))

function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z)
ndims(z) > 2 && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
(u .* v) .+ z
end

function Base.muladd(u::AdjOrTransAbsVec, v::AbstractVector, z)
uv = _dot_nonrecursive(u, v)
ndims(z) > ndims(uv) && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
uv .+ z
end

"""
mul!(Y, A, B) -> Y
Expand Down
42 changes: 42 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,48 @@ end
end
end

@testset "muladd" begin
A23 = reshape(1:6, 2,3)
B34 = reshape(1:12, 3,4) .+ im
u2 = [10,20]
v3 = [3,5,7] .+ im
w4 = [11,13,17,19im]

@test muladd(A23, B34, 100) == A23 * B34 .+ 100
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
@test_throws DimensionMismatch muladd(B34, A23, 1)
@test_throws DimensionMismatch muladd(A23, B34, ones(2,4,1))

@test muladd(A23, v3, 100) == A23 * v3 .+ 100
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))

@test muladd(v3', B34, 0) isa Adjoint
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4))

@test muladd(u2, v3', 0) isa Matrix
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
@test muladd(u2, v3', A23) == u2 * v3' .+ A23
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))

@test muladd(u2', u2, 0) isa Number
@test muladd(v3', v3, im) == dot(v3,v3) + im
@test_throws DimensionMismatch muladd(v3', v3, [1])

vofm = [rand(1:9,2,2) for _ in 1:3]
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
@test_broken muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
end

# issue #6450
@test dot(Any[1.0,2.0], Any[3.5,4.5]) === 12.5

Expand Down

0 comments on commit 72971c4

Please sign in to comment.