Skip to content

Commit

Permalink
Clean up juliaBLAS.jl and test it (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Aug 29, 2022
1 parent 887fbd3 commit fbbdb28
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 100 deletions.
128 changes: 28 additions & 100 deletions src/juliaBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ function rankUpdate!(A::StridedMatrix, x::StridedVector, y::StridedVector, α::N
end
end

# Deprecated 11 October 2018
Base.@deprecate rankUpdate!::Number, x::StridedVector, y::StridedVector, A::StridedMatrix) rankUpdate!(A, x, y, α)

## Hermitian
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syr!(A.uplo, α, a, A.data)
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasReal,S<:StridedMatrix} = rankUpdate!(one(T), a, A)
function rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix}
BLAS.syr!(A.uplo, α, a, A.data)
return A
end
function rankUpdate!(A::Hermitian{Complex{T},S}, a::StridedVector{Complex{T}}, α::T) where {T<:BlasReal,S<:StridedMatrix}
BLAS.her!(A.uplo, α, a, A.data)
return A
end
rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasFloat,S<:StridedMatrix} = rankUpdate!(A, a, one(real(T)))

### Generic
function rankUpdate!(A::Hermitian, a::StridedVector, α::Real)
Expand All @@ -44,15 +48,18 @@ function rankUpdate!(A::Hermitian, a::StridedVector, α::Real)
return A
end

# Deprecated 11 October 2018
Base.@deprecate rankUpdate!::Real, a::StridedVector, A::Hermitian) rankUpdate!(A, a, α)

# Rank k update
## Real
rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
function rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix}
BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
return C
end

## Complex
rankUpdate!(C::Hermitian{T,S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.herk!(C.uplo, 'N', α, A, β, C.data)
function rankUpdate!(C::Hermitian{Complex{T},S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix}
BLAS.herk!(C.uplo, 'N', α, A, β, C.data)
return C
end

### Generic
function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real)
Expand Down Expand Up @@ -80,93 +87,14 @@ function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real)
return C
end

# Deprecated 11 October 2018
Base.@deprecate rankUpdate!::Real, A::StridedVecOrMat, C::Hermitian) rankUpdate!(C, A, α)
Base.@deprecate rankUpdate!::Real, A::StridedVecOrMat, β::Real, C::Hermitian) rankUpdate!(C, A, α, β)

if VERSION < v"1.3.0-alpha.115" # Project.toml has julia = "1.6" so this block should no longer be necessary
# BLAS style mul!
## gemv
mul!(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('N', α, A, x, β, y)
mul!(y::StridedVector{T}, A::Adjoint{T,<:StridedMatrix{T}}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('C', α, parent(adjA), x, β, y)

## gemm
mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('N', 'N', α, A, B, β, C)
mul!(C::StridedMatrix{T}, adjA::Adjoint{T,<:StridedMatrix{T}}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('C', 'N', α, parent(adjA), B, β, C)
# Not optimized since it is a generic fallback. Can probably soon be removed when the signatures in base have been updated.
function mul!(C::StridedVecOrMat,
A::StridedMatrix,
B::StridedVecOrMat,
α::Number,
β::Number)

m, n = size(C, 1), size(C, 2)
k = size(A, 2)

if β != 1
if β == 0
fill!(C, 0)
else
rmul!(C, β)
end
end
for j = 1:n
for i = 1:m
for l = 1:k
C[i,j] += α*A[i,l]*B[l,j]
end
end
end
return C
end
function mul!(C::StridedVecOrMat,
adjA::Adjoint{<:Number,<:StridedMatrix},
B::StridedVecOrMat,
α::Number,
β::Number)

A = parent(adjA)
m, n = size(C, 1), size(C, 2)
k = size(A, 1)

if β != 1
if β == 0
fill!(C, 0)
else
rmul!(C, β)
end
end
for j = 1:n
for i = 1:m
for l = 1:k
C[i,j] += α*A[l,i]'*B[l,j]
end
end
end
return C
end

## trmm like
### BLAS versions
mul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'N', α, A.data, B)
mul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'N', α, A.data, B)
mul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'U', α, A.data, B)
mul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'U', α, A.data, B)
mul!(A::Adjoint{T,UpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'N', α, parent(A).data, B)
mul!(A::Adjoint{T,LowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'N', α, parent(A).data, B)
mul!(A::Adjoint{T,UnitUpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'U', α, parent(A).data, B)
mul!(A::Adjoint{T,UnitLowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'U', α, parent(A).data, B)

end # VERSION

### Generic fallbacks
function lmul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
AA = A.data
m, n = size(B)
for i = 1:m
for j = 1:n
for i 1:m
for j 1:n
B[i,j] = α*AA[i,i]*B[i,j]
for l = i + 1:m
for l (i + 1):m
B[i,j] += α*AA[i,l]*B[l,j]
end
end
Expand All @@ -176,10 +104,10 @@ end
function lmul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
AA = A.data
m, n = size(B)
for i = m:-1:1
for j = 1:n
for i m:-1:1
for j 1:n
B[i,j] = α*AA[i,i]*B[i,j]
for l = 1:i - 1
for l 1:(i - 1)
B[i,j] += α*AA[i,l]*B[l,j]
end
end
Expand All @@ -189,11 +117,11 @@ end
function lmul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S}
AA = A.data
m, n = size(B)
for i = 1:m
for j = 1:n
for i 1:m
for j 1:n
B[i,j] = α*B[i,j]
for l = i + 1:m
B[i,j] = α*AA[i,l]*B[l,j]
for l (i + 1):m
B[i,j] += α*AA[i,l]*B[l,j]
end
end
end
Expand All @@ -205,7 +133,7 @@ function lmul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T
for i = m:-1:1
for j = 1:n
B[i,j] = α*B[i,j]
for l = 1:i - 1
for l = 1:(i - 1)
B[i,j] += α*AA[i,l]*B[l,j]
end
end
Expand Down
29 changes: 29 additions & 0 deletions test/juliaBLAS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using Test, GenericLinearAlgebra, LinearAlgebra

@testset "rankUpdate!" begin
A, B, x = (Hermitian(randn(5, 5)), randn(5, 2), randn(5))
Ac, Bc, xc = (
Hermitian(complex.(randn(5, 5), randn(5, 5))),
complex.(randn(5, 2), randn(5, 2)),
complex.(randn(5), randn(5)),
)
@test rankUpdate!(copy(A), x) A .+ x.*x'
@test rankUpdate!(copy(Ac), xc) Ac .+ xc.*xc'

@test rankUpdate!(copy(A), B, 0.5, 0.5) 0.5*A + 0.5*B*B'
@test rankUpdate!(copy(Ac), Bc, 0.5, 0.5) 0.5*Ac + 0.5*Bc*Bc'

@test invoke(rankUpdate!, Tuple{Hermitian,StridedVecOrMat,Real}, copy(Ac), Bc, 1.0)
rankUpdate!(copy(Ac), Bc, 1.0, 1.0)
end

@testset "triangular multiplication: $(typeof(T))" for T (
UpperTriangular(complex.(randn(5, 5), randn(5, 5))),
UnitUpperTriangular(complex.(randn(5, 5), randn(5, 5))),
LowerTriangular(complex.(randn(5, 5), randn(5, 5))),
UnitLowerTriangular(complex.(randn(5, 5), randn(5, 5))),
)
B = complex.(randn(5, 5), randn(5, 5))
@test lmul!(T, copy(B), complex(0.5, 0.5)) T*B*complex(0.5, 0.5)
@test lmul!(T', copy(B), complex(0.5, 0.5)) T'*B*complex(0.5, 0.5)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test

# @testset "The LinearAlgebra Test Suite" begin
include("juliaBLAS.jl")
include("cholesky.jl")
include("qr.jl")
include("eigenselfadjoint.jl")
Expand Down

0 comments on commit fbbdb28

Please sign in to comment.