Skip to content

Commit

Permalink
Type-stability fixes for matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Dec 9, 2015
1 parent 2e3ee26 commit 6db0a4f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
23 changes: 16 additions & 7 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ function copy!{R,S}(B::AbstractVecOrMat{R}, ir_dest::UnitRange{Int}, jr_dest::Un
Base.copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src)
tM == 'C' && conj!(B)
end
B
end

function copy_transpose!{R,S}(B::AbstractMatrix{R}, ir_dest::UnitRange{Int}, jr_dest::UnitRange{Int}, tM::Char, M::AbstractVecOrMat{S}, ir_src::UnitRange{Int}, jr_src::UnitRange{Int})
Expand Down Expand Up @@ -435,22 +436,30 @@ const Abuf = Array(UInt8, tilebufsize)
const Bbuf = Array(UInt8, tilebufsize)
const Cbuf = Array(UInt8, tilebufsize)

function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S})
function generic_matmatmul!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if mB != nA
throw(DimensionMismatch("matrix A has dimensions ($mA, $nB), matrix B has dimensions ($mB, $nB)"))
end
if size(C,1) != mA || size(C,2) != nB
throw(DimensionMismatch("result C has dimensions $(size(C)), needs ($mA, $nB)"))
end

if mA == nA == nB == 2
return matmul2x2!(C, tA, tB, A, B)
end
if mA == nA == nB == 3
return matmul3x3!(C, tA, tB, A, B)
end
_generic_matmatmul!(C, tA, tB, A, B)
end

generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) = _generic_matmatmul!(C, tA, tB, A, B)

function _generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if mB != nA
throw(DimensionMismatch("matrix A has dimensions ($mA, $nB), matrix B has dimensions ($mB, $nB)"))
end
if size(C,1) != mA || size(C,2) != nB
throw(DimensionMismatch("result C has dimensions $(size(C)), needs ($mA, $nB)"))
end

tile_size = 0
if isbits(R) && isbits(T) && isbits(S)
Expand Down
3 changes: 3 additions & 0 deletions test/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ A = rand(1:20, 5, 5) .- 10
B = rand(1:20, 5, 5) .- 10
@test At_mul_B(A, B) == A'*B
@test A_mul_Bt(A, B) == A*B'
v = [1,2]
C = Array(Int, 2, 2)
@test @inferred(A_mul_Bc!(C, v, v)) == [1 2; 2 4]

# Preallocated
C = Array(Int, size(A, 1), size(B, 2))
Expand Down

0 comments on commit 6db0a4f

Please sign in to comment.