Skip to content

Commit

Permalink
Fix JuliaLang#9504. Check second stride before calling BLAS.
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Jan 5, 2015
1 parent 5233c05 commit 2101028
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
21 changes: 10 additions & 11 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ function copytri!(A::StridedMatrix, uplo::Char, conjugate::Bool=false)
end

function gemv!{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedVecOrMat{T}, x::StridedVector{T})
stride(A, 1)==1 || return generic_matvecmul!(y, tA, A, x)
mA, nA = lapack_size(tA, A)
nA==length(x) || throw(DimensionMismatch())
mA==length(y) || throw(DimensionMismatch())
nA == length(x) || throw(DimensionMismatch())
mA == length(y) || throw(DimensionMismatch())
mA == 0 && return y
nA == 0 && return fill!(y,0)
return BLAS.gemv!(tA, one(T), A, x, zero(T), y)
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) && return BLAS.gemv!(tA, one(T), A, x, zero(T), y)
return generic_matvecmul!(y, tA, A, x)
end

function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVecOrMat{T})
Expand All @@ -227,8 +227,8 @@ function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVe
if mA == 2 && nA == 2; return matmul2x2!(C,tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3!(C,tA,tAt,A,A); end

stride(A, 1) == 1 || (return generic_matmatmul!(C, tA, tAt, A, A))
copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U')
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) && return copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U')
return generic_matmatmul!(C, tA, tAt, A, A)
end

function herk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVecOrMat{T})
Expand All @@ -245,12 +245,11 @@ function herk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedVe
if mA == 2 && nA == 2; return matmul2x2!(C,tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3!(C,tA,tAt,A,A); end

stride(A, 1) == 1 || (return generic_matmatmul!(C,tA, tAt, A, A))

# Result array does not need to be initialized as long as beta==0
# C = Array(T, mA, mA)

copytri!(BLAS.herk!('U', tA, one(T), A, zero(T), C), 'U', true)
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) && return copytri!(BLAS.herk!('U', tA, one(T), A, zero(T), C), 'U', true)
return generic_matmatmul!(C,tA, tAt, A, A)
end

function gemm_wrapper{T<:BlasFloat}(tA::Char, tB::Char,
Expand All @@ -277,8 +276,8 @@ function gemm_wrapper!{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char,
if mA == 2 && nA == 2 && nB == 2; return matmul2x2!(C,tA,tB,A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3!(C,tA,tB,A,B); end

stride(A, 1)==stride(B, 1)==1 || (return generic_matmatmul!(C, tA, tB, A, B))
BLAS.gemm!(tA, tB, one(T), A, B, zero(T), C)
stride(A, 1) == stride(B, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && BLAS.gemm!(tA, tB, one(T), A, B, zero(T), C)
return generic_matmatmul!(C, tA, tB, A, B)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down
2 changes: 1 addition & 1 deletion test/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ d[5,1:2:4,8] = 19
AA = rand(4,2)
A = convert(SharedArray, AA)
B = convert(SharedArray, AA')
@test B*A == AA'*AA
@test B*A == ctranspose(AA)*AA

d=SharedArray(Int64, (10,10); init = D->fill!(D.loc_subarr_1d, myid()), pids=[id_me, id_other])
d2 = map(x->1, d)
Expand Down

0 comments on commit 2101028

Please sign in to comment.