Skip to content

Commit

Permalink
Merge pull request #11131 from kshyatt/symmmult
Browse files Browse the repository at this point in the history
Fix #11127 Hermitian matrix multiplication
  • Loading branch information
andreasnoack committed May 5, 2015
2 parents 06965c4 + 740db66 commit c4ab817
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
35 changes: 35 additions & 0 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ export
gemm,
symm!,
symm,
hemm!,
hemm,
syrk!,
syrk,
syr2k!,
Expand Down Expand Up @@ -614,6 +616,39 @@ for (mfname, elty) in ((:dsymm_,:Float64),
end
end

## (HE) Hermitian matrix-matrix and matrix-vector multiplication
for (mfname, elty) in ((:zhemm_,:Complex128),
(:chemm_,:Complex64))
@eval begin
# SUBROUTINE DHEMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
# .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER LDA,LDB,LDC,M,N
# CHARACTER SIDE,UPLO
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function hemm!(side::Char, uplo::Char, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
j = chksquare(A)
if j != (side == 'L' ? m : n) || size(B,2) != n throw(DimensionMismatch()) end
ccall(($(blasfunc(mfname)), libblas), Void,
(Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&side, &uplo, &m, &n,
&alpha, A, &max(1,stride(A,2)), B,
&max(1,stride(B,2)), &beta, C, &max(1,stride(C,2)))
C
end
function hemm(side::Char, uplo::Char, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
hemm!(side, uplo, alpha, A, B, zero($elty), similar(B))
end
function hemm(side::Char, uplo::Char, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
hemm(side, uplo, one($elty), A, B)
end
end
end

## syrk
for (fname, elty) in ((:dsyrk_,:Float64),
(:ssyrk_,:Float32),
Expand Down
6 changes: 4 additions & 2 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ ctranspose(A::Hermitian) = A
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(y::StridedVector{T}, A::Symmetric{T,S}, x::StridedVector{T}) = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(y::StridedVector{T}, A::Hermitian{T,S}, x::StridedVector{T}) = BLAS.hemv!(A.uplo, one(T), A.data, x, zero(T), y)
##Matmat
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(C::StridedMatrix{T}, A::Symmetric{T,S}, B::StridedMatrix{T}) = BLAS.symm!(A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(y::StridedMatrix{T}, A::Hermitian{T,S}, x::StridedMatrix{T}) = BLAS.hemm!(A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(C::StridedMatrix{T}, A::Symmetric{T,S}, B::StridedMatrix{T}) = BLAS.symm!('L', A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Symmetric{T,S}) = BLAS.symm!('R', B.uplo, one(T), B.data, A, zero(T), C)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(C::StridedMatrix{T}, A::Hermitian{T,S}, B::StridedMatrix{T}) = BLAS.hemm!('L', A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,S}) = BLAS.hemm!('R', B.uplo, one(T), B.data, A, zero(T), C)

*(A::HermOrSym, B::HermOrSym) = full(A)*full(B)
*(A::StridedMatrix, B::HermOrSym) = A*full(B)
Expand Down
4 changes: 4 additions & 0 deletions test/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ let n=10

# mat * mat
if eltya <: Complex
@test_approx_eq Hermitian(asym) * a asym * a
@test_approx_eq a * Hermitian(asym) a * asym
@test_approx_eq Hermitian(asym) * Hermitian(asym) asym*asym
end
if eltya <: Real && eltya != Int
@test_approx_eq Symmetric(asym) * Symmetric(asym) asym*asym
@test_approx_eq Symmetric(asym) * a asym * a
@test_approx_eq a * Symmetric(asym) a * asym
end

# solver
Expand Down

0 comments on commit c4ab817

Please sign in to comment.