Skip to content

Commit

Permalink
update to match 2nd PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Nov 20, 2020
1 parent 4e5dbbd commit e9421b9
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 51 deletions.
65 changes: 48 additions & 17 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,8 @@ if VERSION < v"1.6.0-DEV.292" # 6cd329c371c1db3d9876bc337e82e274e50420e8
end

# https://github.com/JuliaLang/julia/pull/37065
if VERSION < v"1.6.0-DEV.1368" # 72971c41160720d4182a6486cc155ee7645b5bb1
# https://github.com/JuliaLang/julia/pull/38250
if VERSION < v"1.6.0-DEV.1536" # 5be3e27e029835cb56dd6934d302680c26f6e21b
using LinearAlgebra: mul!, AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

"""
Expand All @@ -840,33 +841,63 @@ if VERSION < v"1.6.0-DEV.1368" # 72971c41160720d4182a6486cc155ee7645b5bb1
107.0 107.0
```
"""
function Base.muladd(A::AbstractMatrix{TA}, y::AbstractVector{Ty}, z) where {TA, Ty}
T = promote_type(TA, Ty, eltype(z))
function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray})
Ay = A * y
for d in 1:ndims(Ay)
# Same error as Ay .+= z would give, to match StridedMatrix method:
size(z,d) > size(Ay,d) && throw(DimensionMismatch("array could not be broadcast to match destination"))
end
for d in ndims(Ay)+1:ndims(z)
# Similar error to what Ay + z would give, to match (Any,Any,Any) method:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
Ay .+ z
end

function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray})
if size(z,1) > length(u) || size(z,2) > length(v)
# Same error as (u*v) .+= z:
throw(DimensionMismatch("array could not be broadcast to match destination"))
end
for d in 3:ndims(z)
# Similar error to (u*v) + z:
size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ",
axes(z), ", must have singleton at dim ", d)))
end
(u .* v) .+ z
end

Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
muladd(A', x', z')'
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
transpose(muladd(transpose(A), transpose(x), transpose(z)))

StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
T = promote_type(eltype(A), eltype(y), eltype(z))
C = similar(A, T, axes(A,1))
C .= z
mul!(C, A, y, true, true)
end

function Base.muladd(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}, z) where {TA, TB}
T = promote_type(TA, TB, eltype(z))
function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat})
T = promote_type(eltype(A), eltype(B), eltype(z))
C = similar(A, T, axes(A,1), axes(B,2))
C .= z
mul!(C, A, B, true, true)
end

Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z) = muladd(A', x', z')'
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z) = transpose(muladd(transpose(A), transpose(x), transpose(z)))
Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal) =
Diagonal(A.diag .* B.diag .+ z.diag)
Base.muladd(A::UniformScaling, B::UniformScaling, z::UniformScaling) =
UniformScaling(A.λ * B.λ + z.λ)
Base.muladd(A::Union{Diagonal, UniformScaling}, B::Union{Diagonal, UniformScaling}, z::Union{Diagonal, UniformScaling}) =
Diagonal(_diag_or_value(A) .* _diag_or_value(B) .+ _diag_or_value(z))

function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z)
ndims(z) > 2 && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
(u .* v) .+ z
end

function Base.muladd(u::AdjOrTransAbsVec, v::AbstractVector, z)
uv = _dot_nonrecursive(u, v)
ndims(z) > ndims(uv) && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
uv .+ z
end
_diag_or_value(A::Diagonal) = A.diag
_diag_or_value(A::UniformScaling) = A.λ

function _dot_nonrecursive(u, v) # in LinearAlgebra on Julia 1.5
lu = length(u)
Expand Down
130 changes: 96 additions & 34 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,48 +774,110 @@ end
@test sincospi(0.13im) == (sinpi(0.13im), cospi(0.13im))
end

# https://github.com/JuliaLang/julia/pull/37065
# https://github.com/JuliaLang/julia/pull/38250
@testset "muladd" begin
A23 = reshape(1:6, 2,3)
A23 = reshape(1:6, 2,3) .+ 0
B34 = reshape(1:12, 3,4) .+ im
u2 = [10,20]
v3 = [3,5,7] .+ im
w4 = [11,13,17,19im]

@test muladd(A23, B34, 100) == A23 * B34 .+ 100
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
@test_throws DimensionMismatch muladd(B34, A23, 1)
@test_throws DimensionMismatch muladd(A23, B34, ones(2,4,1))

@test muladd(A23, v3, 100) == A23 * v3 .+ 100
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))

@test muladd(v3', B34, 0) isa Adjoint
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4))

@test muladd(u2, v3', 0) isa Matrix
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
@test muladd(u2, v3', A23) == u2 * v3' .+ A23
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))

@test muladd(u2', u2, 0) isa Number
@test muladd(v3', v3, im) == dot(v3,v3) + im
@test_throws DimensionMismatch muladd(v3', v3, [1])

vofm = [rand(1:9,2,2) for _ in 1:3]
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]
if VERSION >= v"1.5"
# Julia 1.4 gets vofm' * vofm wrong, gives a scalar
@testset "matrix-matrix" begin
@test muladd(A23, B34, 0) == A23 * B34
@test muladd(A23, B34, 100) == A23 * B34 .+ 100
@test muladd(A23, B34, u2) == A23 * B34 .+ u2
@test muladd(A23, B34, w4') == A23 * B34 .+ w4'
@test_throws DimensionMismatch muladd(B34, A23, 1)
@test muladd(ones(1,3), ones(3,4), ones(1,4)) == fill(4.0,1,4)
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4))

# broadcasting fallback method allows trailing dims
@test muladd(A23, B34, ones(2,4,1)) == A23 * B34 + ones(2,4,1)
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4,1))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(1,4,9))
# and catches z::Array{T,0}
@test muladd(A23, B34, fill(0)) == A23 * B34
end
@testset "matrix-vector" begin
@test muladd(A23, v3, 0) == A23 * v3
@test muladd(A23, v3, 100) == A23 * v3 .+ 100
@test muladd(A23, v3, u2) == A23 * v3 .+ u2
@test muladd(A23, v3, im) isa Vector{Complex{Int}}
@test muladd(ones(1,3), ones(3), ones(1)) == [4]
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7))

# fallback
@test muladd(A23, v3, ones(2,1,1)) == A23 * v3 + ones(2,1,1)
@test_throws DimensionMismatch muladd(A23, v3, ones(2,2))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7,1))
@test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(1,7))
@test muladd(A23, v3, fill(0)) == A23 * v3
end
@testset "adjoint-matrix" begin
@test muladd(v3', B34, 0) isa Adjoint
@test muladd(v3', B34, 2im) == v3' * B34 .+ 2im
@test muladd(v3', B34, w4') == v3' * B34 .+ w4'

# via fallback
@test muladd(v3', B34, ones(1,4)) == (B34' * v3 + ones(4,1))'
@test_throws DimensionMismatch muladd(v3', B34, ones(7,4))
@test_throws DimensionMismatch muladd(v3', B34, ones(1,4,7))
@test muladd(v3', B34, fill(0)) == v3' * B34 # does not make an Adjoint
end
@testset "vector-adjoint" begin
@test muladd(u2, v3', 0) isa Matrix
@test muladd(u2, v3', 99) == u2 * v3' .+ 99
@test muladd(u2, v3', A23) == u2 * v3' .+ A23

@test muladd(u2, v3', ones(2,3,1)) == u2 * v3' + ones(2,3,1)
@test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4))
@test_throws DimensionMismatch muladd([1], v3', ones(7,3))
@test muladd(u2, v3', fill(0)) == u2 * v3'
end
@testset "dot" begin # all use muladd(::Any, ::Any, ::Any)
@test muladd(u2', u2, 0) isa Number
@test muladd(v3', v3, im) == dot(v3,v3) + im
@test muladd(u2', u2, [1]) == [dot(u2,u2) + 1]
@test_throws DimensionMismatch muladd(u2', u2, [1,1]) == [dot(u2,u2) + 1]
@test muladd(u2', u2, fill(0)) == dot(u2,u2)
end
@testset "arrays of arrays" begin
vofm = [rand(1:9,2,2) for _ in 1:3]
Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3]

@test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat
@test muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec
end
@test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer
@test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat
@test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat

# end
# @testset "muladd & structured matrices" begin

A33 = reshape(1:9, 3,3) .+ im
v3 = [3,5,7im]

# no special treatment
@test muladd(Symmetric(A33), Symmetric(A33), 1) == Symmetric(A33) * Symmetric(A33) .+ 1
@test muladd(Hermitian(A33), Hermitian(A33), v3) == Hermitian(A33) * Hermitian(A33) .+ v3
@test muladd(adjoint(A33), transpose(A33), A33) == A33' * transpose(A33) .+ A33

u1 = muladd(UpperTriangular(A33), UpperTriangular(A33), Diagonal(v3))
@test u1 isa UpperTriangular
@test u1 == UpperTriangular(A33) * UpperTriangular(A33) + Diagonal(v3)

# diagonal
@test muladd(Diagonal(v3), Diagonal(A33), Diagonal(v3)).diag == ([1,5,9] .+ im .+ 1) .* v3

# uniformscaling
@test muladd(Diagonal(v3), I, I).diag == v3 .+ 1
@test muladd(2*I, 3*I, I).λ == 7
@test muladd(A33, A33', I) == A33 * A33' + I

# https://github.com/JuliaLang/julia/issues/38426
@test @evalpoly(A33, 1.0*I, 1.0*I) == I + A33
@test @evalpoly(A33, 1.0*I, 1.0*I, 1.0*I) == I + A33 + A33^2
end

include("iterators.jl")
Expand Down

0 comments on commit e9421b9

Please sign in to comment.