Skip to content

Commit

Permalink
Complete size checks in BLAS.[sy/he]mm! (#45605)
Browse files Browse the repository at this point in the history
(cherry picked from commit da13d78)
  • Loading branch information
dkarrasch authored and KristofferC committed Dec 21, 2022
1 parent 1de58c4 commit ef9ad67
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
53 changes: 42 additions & 11 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
Interface to BLAS subroutines.
"""
Expand Down Expand Up @@ -1509,11 +1508,27 @@ for (mfname, elty) in ((:dsymm_,:Float64),
require_one_based_indexing(A, B, C)
m, n = size(C)
j = checksquare(A)
if j != (side == 'L' ? m : n)
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
end
if size(B,2) != n
throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
M, N = size(B)
if side == 'L'
if j != m
throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m"))
end
if N != n
throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n"))
end
if j != M
throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M"))
end
else
if j != n
throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n"))
end
if N != j
throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j"))
end
if M != m
throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m"))
end
end
chkstride1(A)
chkstride1(B)
Expand Down Expand Up @@ -1582,11 +1597,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64),
require_one_based_indexing(A, B, C)
m, n = size(C)
j = checksquare(A)
if j != (side == 'L' ? m : n)
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
end
if size(B,2) != n
throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n"))
M, N = size(B)
if side == 'L'
if j != m
throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m"))
end
if N != n
throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n"))
end
if j != M
throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M"))
end
else
if j != n
throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n"))
end
if N != j
throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j"))
end
if M != m
throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m"))
end
end
chkstride1(A)
chkstride1(B)
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,19 @@ Random.seed!(100)
@test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn)
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn)
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm)
@test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn)
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn)
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm)
@test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn)
if elty <: BlasComplex
@test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn)
@test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn)
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn)
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm)
@test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn)
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn)
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm)
@test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn)
end
end
end
Expand Down
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ end
C = zeros(eltya,n,n)
@test Hermitian(aherm) * a aherm * a
@test a * Hermitian(aherm) a * aherm
# rectangular multiplication
@test [a; a] * Hermitian(aherm) [a; a] * aherm
@test Hermitian(aherm) * [a a] aherm * [a a]
@test Hermitian(aherm) * Hermitian(aherm) aherm*aherm
@test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1)
LinearAlgebra.mul!(C,a,Hermitian(aherm))
Expand All @@ -348,6 +351,9 @@ end
@test Symmetric(asym) * Symmetric(asym) asym*asym
@test Symmetric(asym) * a asym * a
@test a * Symmetric(asym) a * asym
# rectangular multiplication
@test Symmetric(asym) * [a a] asym * [a a]
@test [a; a] * Symmetric(asym) [a; a] * asym
@test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1)
LinearAlgebra.mul!(C,a,Symmetric(asym))
@test C a*asym
Expand Down

0 comments on commit ef9ad67

Please sign in to comment.