From 3c4198d31e09eb6778a8df42758a78b12dbf1d6e Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Mon, 29 Aug 2022 13:56:25 +0200 Subject: [PATCH] Clean up juliaBLAS.jl and test it --- src/juliaBLAS.jl | 128 ++++++++++------------------------------------ test/juliaBLAS.jl | 29 +++++++++++ test/runtests.jl | 1 + 3 files changed, 58 insertions(+), 100 deletions(-) create mode 100644 test/juliaBLAS.jl diff --git a/src/juliaBLAS.jl b/src/juliaBLAS.jl index 8c04bef..2ab734d 100644 --- a/src/juliaBLAS.jl +++ b/src/juliaBLAS.jl @@ -24,12 +24,16 @@ function rankUpdate!(A::StridedMatrix, x::StridedVector, y::StridedVector, α::N end end -# Deprecated 11 October 2018 -Base.@deprecate rankUpdate!(α::Number, x::StridedVector, y::StridedVector, A::StridedMatrix) rankUpdate!(A, x, y, α) - ## Hermitian -rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syr!(A.uplo, α, a, A.data) -rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasReal,S<:StridedMatrix} = rankUpdate!(one(T), a, A) +function rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}, α::T) where {T<:BlasReal,S<:StridedMatrix} + BLAS.syr!(A.uplo, α, a, A.data) + return A +end +function rankUpdate!(A::Hermitian{T,S}, a::StridedVector{Complex{T}}, α::T) where {T<:BlasReal,S<:StridedMatrix} + BLAS.syr!(A.uplo, α, a, A.data) + return A +end +rankUpdate!(A::HermOrSym{T,S}, a::StridedVector{T}) where {T<:BlasFloat,S<:StridedMatrix} = rankUpdate!(A, a, one(real(T))) ### Generic function rankUpdate!(A::Hermitian, a::StridedVector, α::Real) @@ -44,15 +48,18 @@ function rankUpdate!(A::Hermitian, a::StridedVector, α::Real) return A end -# Deprecated 11 October 2018 -Base.@deprecate rankUpdate!(α::Real, a::StridedVector, A::Hermitian) rankUpdate!(A, a, α) - # Rank k update ## Real -rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syrk!(C.uplo, 'N', α, A, β, C.data) +function rankUpdate!(C::HermOrSym{T,S}, A::StridedMatrix{T}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} + BLAS.syrk!(C.uplo, 'N', α, A, β, C.data) + return C +end ## Complex -rankUpdate!(C::Hermitian{T,S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} = BLAS.herk!(C.uplo, 'N', α, A, β, C.data) +function rankUpdate!(C::Hermitian{Complex{T},S}, A::StridedMatrix{Complex{T}}, α::T, β::T) where {T<:BlasReal,S<:StridedMatrix} + BLAS.herk!(C.uplo, 'N', α, A, β, C.data) + return C +end ### Generic function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real) @@ -80,93 +87,14 @@ function rankUpdate!(C::Hermitian, A::StridedVecOrMat, α::Real) return C end -# Deprecated 11 October 2018 -Base.@deprecate rankUpdate!(α::Real, A::StridedVecOrMat, C::Hermitian) rankUpdate!(C, A, α) -Base.@deprecate rankUpdate!(α::Real, A::StridedVecOrMat, β::Real, C::Hermitian) rankUpdate!(C, A, α, β) - -if VERSION < v"1.3.0-alpha.115" # Project.toml has julia = "1.6" so this block should no longer be necessary -# BLAS style mul! -## gemv -mul!(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('N', α, A, x, β, y) -mul!(y::StridedVector{T}, A::Adjoint{T,<:StridedMatrix{T}}, x::StridedVector{T}, α::T, β::T) where {T<:BlasFloat} = gemv!('C', α, parent(adjA), x, β, y) - -## gemm -mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('N', 'N', α, A, B, β, C) -mul!(C::StridedMatrix{T}, adjA::Adjoint{T,<:StridedMatrix{T}}, B::StridedMatrix{T}, α::T, β::T) where {T<:BlasFloat} = BLAS.gemm!('C', 'N', α, parent(adjA), B, β, C) -# Not optimized since it is a generic fallback. Can probably soon be removed when the signatures in base have been updated. -function mul!(C::StridedVecOrMat, - A::StridedMatrix, - B::StridedVecOrMat, - α::Number, - β::Number) - - m, n = size(C, 1), size(C, 2) - k = size(A, 2) - - if β != 1 - if β == 0 - fill!(C, 0) - else - rmul!(C, β) - end - end - for j = 1:n - for i = 1:m - for l = 1:k - C[i,j] += α*A[i,l]*B[l,j] - end - end - end - return C -end -function mul!(C::StridedVecOrMat, - adjA::Adjoint{<:Number,<:StridedMatrix}, - B::StridedVecOrMat, - α::Number, - β::Number) - - A = parent(adjA) - m, n = size(C, 1), size(C, 2) - k = size(A, 1) - - if β != 1 - if β == 0 - fill!(C, 0) - else - rmul!(C, β) - end - end - for j = 1:n - for i = 1:m - for l = 1:k - C[i,j] += α*A[l,i]'*B[l,j] - end - end - end - return C -end - -## trmm like -### BLAS versions -mul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'N', α, A.data, B) -mul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'N', α, A.data, B) -mul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'N', 'U', α, A.data, B) -mul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'N', 'U', α, A.data, B) -mul!(A::Adjoint{T,UpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'N', α, parent(A).data, B) -mul!(A::Adjoint{T,LowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'N', α, parent(A).data, B) -mul!(A::Adjoint{T,UnitUpperTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'U', 'C', 'U', α, parent(A).data, B) -mul!(A::Adjoint{T,UnitLowerTriangular{T,S}}, B::StridedMatrix{T}, α::T) where {T<:BlasFloat,S} = trmm!('L', 'L', 'C', 'U', α, parent(A).data, B) - -end # VERSION - ### Generic fallbacks function lmul!(A::UpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S} AA = A.data m, n = size(B) - for i = 1:m - for j = 1:n + for i ∈ 1:m + for j ∈ 1:n B[i,j] = α*AA[i,i]*B[i,j] - for l = i + 1:m + for l ∈ (i + 1):m B[i,j] += α*AA[i,l]*B[l,j] end end @@ -176,10 +104,10 @@ end function lmul!(A::LowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S} AA = A.data m, n = size(B) - for i = m:-1:1 - for j = 1:n + for i ∈ m:-1:1 + for j ∈ 1:n B[i,j] = α*AA[i,i]*B[i,j] - for l = 1:i - 1 + for l ∈ 1:(i - 1) B[i,j] += α*AA[i,l]*B[l,j] end end @@ -189,11 +117,11 @@ end function lmul!(A::UnitUpperTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T<:Number,S} AA = A.data m, n = size(B) - for i = 1:m - for j = 1:n + for i ∈ 1:m + for j ∈ 1:n B[i,j] = α*B[i,j] - for l = i + 1:m - B[i,j] = α*AA[i,l]*B[l,j] + for l ∈ (i + 1):m + B[i,j] += α*AA[i,l]*B[l,j] end end end @@ -205,7 +133,7 @@ function lmul!(A::UnitLowerTriangular{T,S}, B::StridedMatrix{T}, α::T) where {T for i = m:-1:1 for j = 1:n B[i,j] = α*B[i,j] - for l = 1:i - 1 + for l = 1:(i - 1) B[i,j] += α*AA[i,l]*B[l,j] end end diff --git a/test/juliaBLAS.jl b/test/juliaBLAS.jl new file mode 100644 index 0000000..8ee4157 --- /dev/null +++ b/test/juliaBLAS.jl @@ -0,0 +1,29 @@ +using Test, GenericLinearAlgebra, LinearAlgebra + +@testset "rankUpdate!" begin + A, B, x = (Hermitian(randn(5, 5)), randn(5, 2), randn(5)) + Ac, Bc, xc = ( + Hermitian(complex.(randn(5, 5), randn(5, 5))), + complex.(randn(5, 2), randn(5, 2)), + complex.(randn(5), randn(5)), + ) + @test rankUpdate!(copy(A), x) ≈ A .+ x.*x' + @test rankUpdate!(copy(Ac), xc) ≈ Ac .+ xc.*xc' + + @test rankUpdate!(copy(A), B, 0.5, 0.5) ≈ 0.5*A + 0.5*B*B' + @test rankUpdate!(copy(Ac), Bc, 0.5, 0.5) ≈ 0.5*Ac + 0.5*Bc*Bc' + + @test invoke(rankUpdate!, Tuple{Hermitian,StridedVecOrMat,Real}, copy(Ac), Bc, 1.0) ≈ + rankUpdate!(copy(Ac), Bc, 1.0, 1.0) +end + +@testset "triangular multiplication: $(typeof(T))" for T ∈ ( + UpperTriangular(complex.(randn(5, 5), randn(5, 5))), + UnitUpperTriangular(complex.(randn(5, 5), randn(5, 5))), + LowerTriangular(complex.(randn(5, 5), randn(5, 5))), + UnitLowerTriangular(complex.(randn(5, 5), randn(5, 5))), +) + B = complex.(randn(5, 5), randn(5, 5)) + @test lmul!(T, copy(B), complex(0.5, 0.5)) ≈ T*B*complex(0.5, 0.5) + @test lmul!(T', copy(B), complex(0.5, 0.5)) ≈ T'*B*complex(0.5, 0.5) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 3f10df1..a81e88a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test # @testset "The LinearAlgebra Test Suite" begin + include("juliaBLAS.jl") include("cholesky.jl") include("qr.jl") include("eigenselfadjoint.jl")