Skip to content

Commit

Permalink
Derive mul! from addmul! using macro
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Nov 18, 2018
1 parent 3b72e3b commit d43d34d
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 67 deletions.
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export
lu!,
lyap,
mul!,
addmul!,
lmul!,
rmul!,
norm,
Expand Down
180 changes: 113 additions & 67 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,50 @@ Short-circuiting multiplication. `x` is returned as-is when `alpha` is 1.
"""
mul1(alpha, x) = isone(alpha) ? x : alpha * x

"""
@defaddmul!(function addmul!(C::TC, A::TA, B::TB, α, β) ... end)
Define `addmul!` and a 3-argument shortcut `mul!(C, A, B)`.
"""
macro defaddmul!(expr)
if expr.head == :block
args = filter(x -> x isa Expr, expr.args)
if length(args) != 1
error("`@defaddmul!` only supports single function. Got: ", expr)
end
expr, = args
end
if expr.head (:function, :(=))
error("Not a function definition: ", expr)
end
function getsigcall(addmul)
if addmul.head == :where
sig, call = getsigcall(addmul.args[1])
return Expr(:where, sig, addmul.args[2]), call
elseif addmul.head == :call
if addmul.args[1] != :addmul!
error("Expected function name `addmul!`. Got: ", addmul.args[1])
end
if length(addmul.args) != 6
error("Expected 5 arguments (C, A, B, α, β). Got: ", addmul.args[2:end])
end
sig = Expr(:call, :mul!, addmul.args[2:4]...)
call = Expr(:call, :addmul!,
[(a isa Expr && a.head == :(::) ? a.args[1] : a)
for a in addmul.args[2:4]]...,
true, false)
return (sig, call)
else
error("Expected `where` or `call` expression. Got: ", addmul)
end
end
sig, call = getsigcall(expr.args[1])
esc(quote
$expr
$(Expr(:function, sig, call))
end)
end

matprod(x, y) = x*y + x*y

# dot products
Expand Down Expand Up @@ -87,23 +131,23 @@ function *(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix})
end
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B

mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat} =
@defaddmul! addmul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat} =
gemv!(y, 'N', A, x, alpha, beta)
# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
for elty in (Float32,Float64)
@eval begin
function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
alpha::Union{$elty, Bool} = true, beta::Union{$elty, Bool} = false)
@eval @defaddmul! begin
function addmul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
alpha::Union{$elty, Bool}, beta::Union{$elty, Bool})
Afl = reinterpret($elty,A)
yfl = reinterpret($elty,y)
mul!(yfl, Afl, x, alpha, beta)
addmul!(yfl, Afl, x, alpha, beta)
return y
end
end
end
mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number = true, beta::Number = false) =
@defaddmul! addmul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number = true, beta::Number = false) =
generic_matvecmul!(y, 'N', A, x, alpha, beta)

function *(transA::Transpose{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S}
Expand All @@ -116,13 +160,13 @@ function *(transA::Transpose{<:Any,<:AbstractMatrix{T}}, x::AbstractVector{S}) w
TS = promote_op(matprod, T, S)
mul!(similar(x,TS,size(A,2)), transpose(A), x)
end
function mul!(y::StridedVector{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
@defaddmul! function addmul!(y::StridedVector{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat}
A = transA.parent
return gemv!(y, 'T', A, x, alpha, beta)
end
function mul!(y::AbstractVector, transA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(y::AbstractVector, transA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number = true, beta::Number = false)
A = transA.parent
return generic_matvecmul!(y, 'T', A, x, alpha, beta)
end
Expand All @@ -138,18 +182,18 @@ function *(adjA::Adjoint{<:Any,<:AbstractMatrix{T}}, x::AbstractVector{S}) where
mul!(similar(x,TS,size(A,2)), adjoint(A), x)
end

function mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasReal}
@defaddmul! function addmul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasReal}
A = adjA.parent
return mul!(y, transpose(A), x, alpha, beta)
return addmul!(y, transpose(A), x, alpha, beta)
end
function mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasComplex}
@defaddmul! function addmul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasComplex}
A = adjA.parent
return gemv!(y, 'C', A, x, alpha, beta)
end
function mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector,
alpha::Number = true, beta::Number = false)
A = adjA.parent
return generic_matvecmul!(y, 'C', A, x, alpha, beta)
end
Expand Down Expand Up @@ -177,18 +221,18 @@ function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), A, B)
end
mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat} =
@defaddmul! addmul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat} =
gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta)
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
for elty in (Float32,Float64)
@eval begin
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
alpha::Union{$elty, Bool} = true, beta::Union{$elty, Bool} = false)
@eval @defaddmul! begin
function addmul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
alpha::Union{$elty, Bool}, beta::Union{$elty, Bool})
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
mul!(Cfl, Afl, B, alpha, beta)
addmul!(Cfl, Afl, B, alpha, beta)
return C
end
end
Expand All @@ -211,8 +255,10 @@ julia> Y
7.0 7.0
```
"""
mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false) =
mul!

@defaddmul! addmul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false) =
generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta)

"""
Expand Down Expand Up @@ -257,23 +303,23 @@ julia> B
"""
lmul!(A, B)

function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
@defaddmul! function addmul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat}
A = transA.parent
if A===B
return syrk_wrapper!(C, 'T', A, alpha, beta)
else
return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta)
end
end
function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false)
A = transA.parent
return generic_matmatmul!(C, 'T', 'N', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
@defaddmul! function addmul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat}
B = transB.parent
if A===B
return syrk_wrapper!(C, 'N', A, alpha, beta)
Expand All @@ -283,104 +329,104 @@ function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, transB::Transpose{<:An
end
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
for elty in (Float32,Float64)
@eval begin
function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
alpha::Union{$elty, Bool} = true, beta::Union{$elty, Bool} = false)
@eval @defaddmul! begin
function addmul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, transB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
alpha::Union{$elty, Bool}, beta::Union{$elty, Bool})
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
mul!(Cfl, Afl, transB, alpha, beta)
addmul!(Cfl, Afl, transB, alpha, beta)
return C
end
end
end
# collapsing the following two defs with C::AbstractVecOrMat yields ambiguities
mul!(C::AbstractVector, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false) =
@defaddmul! addmul!(C::AbstractVector, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false) =
generic_matmatmul!(C, 'N', 'T', A, transB.parent, alpha, beta)
mul!(C::AbstractMatrix, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat},
@defaddmul! addmul!(C::AbstractMatrix, A::AbstractVecOrMat, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false) =
generic_matmatmul!(C, 'N', 'T', A, transB.parent, alpha, beta)

function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Transpose{<:Any,<:StridedVecOrMat{T}},
@defaddmul! function addmul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number = true, beta::Number = false) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'T', A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
A = transA.parent
B = transB.parent
return generic_matmatmul!(C, 'T', 'T', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
@defaddmul! function addmul!(C::StridedMatrix{T}, transA::Transpose{<:Any,<:StridedVecOrMat{T}}, transB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat}
A = transA.parent
B = transB.parent
return gemm_wrapper!(C, 'T', 'C', A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, transA::Transpose{<:Any,<:AbstractVecOrMat}, transB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
A = transA.parent
B = transB.parent
return generic_matmatmul!(C, 'T', 'C', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasReal}
@defaddmul! function addmul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasReal}
A = adjA.parent
return mul!(C, transpose(A), B, alpha, beta)
return addmul!(C, transpose(A), B, alpha, beta)
end
function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasComplex}
@defaddmul! function addmul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasComplex}
A = adjA.parent
if A===B
return herk_wrapper!(C, 'C', A, alpha, beta)
else
return gemm_wrapper!(C, 'C', 'N', A, B, alpha, beta)
end
end
function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat,
alpha::Number = true, beta::Number = false)
A = adjA.parent
return generic_matmatmul!(C, 'C', 'N', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{<:BlasReal}},
alpha::Number = true, beta::Number = false) where {T<:BlasFloat}
@defaddmul! function addmul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{<:BlasReal}},
alpha::Number = true, beta::Number = false) where {T<:BlasFloat}
B = adjB.parent
return mul!(C, A, transpose(B), alpha, beta)
return addmul!(C, A, transpose(B), alpha, beta)
end
function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasComplex}
@defaddmul! function addmul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasComplex}
B = adjB.parent
if A === B
return herk_wrapper!(C, 'N', A, alpha, beta)
else
return gemm_wrapper!(C, 'N', 'C', A, B, alpha, beta)
end
end
function mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
B = adjB.parent
return generic_matmatmul!(C, 'N', 'C', A, B, alpha, beta)
end

function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool} = true, beta::Union{T, Bool} = false) where {T<:BlasFloat}
@defaddmul! function addmul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat}
A = adjA.parent
B = adjB.parent
return gemm_wrapper!(C, 'C', 'C', A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
A = adjA.parent
B = adjB.parent
return generic_matmatmul!(C, 'C', 'C', A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
@defaddmul! function addmul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number = true, beta::Number = false)
A = adjA.parent
B = transB.parent
return generic_matmatmul!(C, 'C', 'T', A, B, alpha, beta)
Expand Down

0 comments on commit d43d34d

Please sign in to comment.