Skip to content

Commit

Permalink
Seperate LinearAlgebra.axp(b)y! and BLAS.axp(b)y!. (JuliaLang#44758)
Browse files Browse the repository at this point in the history
* Seperate `LinearAlgebra.axpy!` and `BLAS.axpy!`
* Make more `dot` based on `BLAS`.
* Doc fix.
  • Loading branch information
N5N3 authored and pull[bot] committed Aug 20, 2022
1 parent d67b230 commit 4e74e26
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 18 deletions.
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ LinearAlgebra.ZeroPivotException
LinearAlgebra.dot
LinearAlgebra.dot(::Any, ::Any, ::Any)
LinearAlgebra.cross
LinearAlgebra.axpy!
LinearAlgebra.axpby!
LinearAlgebra.factorize
LinearAlgebra.Diagonal
LinearAlgebra.Bidiagonal
Expand Down Expand Up @@ -532,8 +534,8 @@ LinearAlgebra.BLAS.dotc
LinearAlgebra.BLAS.blascopy!
LinearAlgebra.BLAS.nrm2
LinearAlgebra.BLAS.asum
LinearAlgebra.axpy!
LinearAlgebra.axpby!
LinearAlgebra.BLAS.axpy!
LinearAlgebra.BLAS.axpby!
LinearAlgebra.BLAS.scal!
LinearAlgebra.BLAS.scal
LinearAlgebra.BLAS.iamax
Expand Down
17 changes: 6 additions & 11 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Interface to BLAS subroutines.
"""
module BLAS

import ..axpy!, ..axpby!
import Base: copyto!
using Base: require_one_based_indexing, USE_BLAS64

Expand Down Expand Up @@ -456,15 +455,13 @@ Overwrite `Y` with `X*a + Y`, where `a` is a scalar. Return `Y`.
# Examples
```jldoctest
julia> x = [1; 2; 3];
julia> x = [1.; 2; 3];
julia> y = [4; 5; 6];
julia> y = [4. ;; 5 ;; 6];
julia> BLAS.axpy!(2, x, y)
3-element Vector{Int64}:
6
9
12
1×3 Matrix{Float64}:
6.0 9.0 12.0
```
"""
function axpy! end
Expand All @@ -490,8 +487,7 @@ for (fname, elty) in ((:daxpy_,:Float64),
end
end

#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpy!` and `LinearAlgebra.axpy!`
function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where T<:BlasFloat
function axpy!(alpha::Number, x::AbstractArray{T}, y::AbstractArray{T}) where T<:BlasFloat
if length(x) != length(y)
throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))"))
end
Expand Down Expand Up @@ -563,8 +559,7 @@ for (fname, elty) in ((:daxpby_,:Float64), (:saxpby_,:Float32),
end
end

#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpby!` and `LinearAlgebra.axpby!`
function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::Number, y::Union{DenseArray{T},AbstractVector{T}},) where T<:BlasFloat
function axpby!(alpha::Number, x::AbstractArray{T}, beta::Number, y::AbstractArray{T}) where T<:BlasFloat
require_one_based_indexing(x, y)
if length(x) != length(y)
throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))"))
Expand Down
61 changes: 59 additions & 2 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1397,9 +1397,27 @@ true
isdiag(A::AbstractMatrix) = isbanded(A, 0, 0)
isdiag(x::Number) = true

"""
axpy!(α, x::AbstractArray, y::AbstractArray)
Overwrite `y` with `x * α + y` and return `y`.
If `x` and `y` have the same axes, it's equivalent with `y .+= x .* a`
See also [`BLAS.axpy!`](@ref)
# Examples
```jldoctest
julia> x = [1; 2; 3];
julia> y = [4; 5; 6];
# BLAS-like in-place y = x*α+y function (see also the version in blas.jl
# for BlasFloat Arrays)
julia> axpy!(2, x, y)
3-element Vector{Int64}:
6
9
12
```
"""
function axpy!(α, x::AbstractArray, y::AbstractArray)
n = length(x)
if n != length(y)
Expand All @@ -1425,6 +1443,27 @@ function axpy!(α, x::AbstractArray, rx::AbstractArray{<:Integer}, y::AbstractAr
y
end

"""
axpby!(α, x::AbstractArray, β, y::AbstractArray)
Overwrite `y` with `x * α + y * β` and return `y`.
If `x` and `y` have the same axes, it's equivalent with `y .= x .* a .+ y .* β`
See also [`BLAS.axpby!`](@ref)
# Examples
```jldoctest
julia> x = [1; 2; 3];
julia> y = [4; 5; 6];
julia> axpby!(2, x, 2, y)
3-element Vector{Int64}:
10
14
18
```
"""
function axpby!(α, x::AbstractArray, β, y::AbstractArray)
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
Expand All @@ -1435,6 +1474,24 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray)
y
end

DenseLike{T} = Union{DenseArray{T}, Base.StridedReshapedArray{T}, Base.StridedReinterpretArray{T}}
StridedVecLike{T} = Union{DenseLike{T}, Base.FastSubArray{T,<:Any,<:DenseLike{T}}}
axpy!::Number, x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasFloat} = BLAS.axpy!(α, x, y)
axpby!::Number, x::StridedVecLike{T}, β::Number, y::StridedVecLike{T}) where {T<:BlasFloat} = BLAS.axpby!(α, x, β, y)
function axpy!::Number,
x::StridedVecLike{T}, rx::AbstractRange{<:Integer},
y::StridedVecLike{T}, ry::AbstractRange{<:Integer},
) where {T<:BlasFloat}
if Base.has_offset_axes(rx, ry)
return Base.@invoke axpy!(α,
x::AbstractArray, rx::AbstractArray{<:Integer},
y::AbstractArray, ry::AbstractArray{<:Integer},
)
end
@views BLAS.axpy!(α, x[rx], y[ry])
return y
end

"""
rotate!(x, y, c, s)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ matprod(x, y) = x*y + x*y

# dot products

dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasReal} = BLAS.dot(x, y)
dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasComplex} = BLAS.dotc(x, y)
dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasReal} = BLAS.dot(x, y)
dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasComplex} = BLAS.dotc(x, y)

function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer}
if length(rx) != length(ry)
Expand Down
13 changes: 12 additions & 1 deletion stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,22 @@ end
@testset "LinearAlgebra.axp(b)y! for non strides input" begin
a = rand(5, 5)
@test LinearAlgebra.axpby!(1, Hermitian(a), 1, zeros(size(a))) == Hermitian(a)
@test_broken LinearAlgebra.axpby!(1, 1.:5, 1, zeros(5)) == 1.:5
@test LinearAlgebra.axpby!(1, 1.:5, 1, zeros(5)) == 1.:5
@test LinearAlgebra.axpy!(1, Hermitian(a), zeros(size(a))) == Hermitian(a)
@test LinearAlgebra.axpy!(1, 1.:5, zeros(5)) == 1.:5
end

@testset "LinearAlgebra.axp(b)y! for stride-vector like input" begin
for T in (Float32, Float64, ComplexF32, ComplexF64)
a = rand(T, 5, 5)
@test LinearAlgebra.axpby!(1, view(a, :, 1:5), 1, zeros(T, size(a))) == a
@test LinearAlgebra.axpy!(1, view(a, :, 1:5), zeros(T, size(a))) == a
b = view(a, 25:-2:1)
@test LinearAlgebra.axpby!(1, b, 1, zeros(T, size(b))) == b
@test LinearAlgebra.axpy!(1, b, zeros(T, size(b))) == b
end
end

@testset "norm and normalize!" begin
vr = [3.0, 4.0]
for Tr in (Float32, Float64)
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ end
end
end

@testset "dot product of stride-vector like input" begin
for T in (Float32, Float64, ComplexF32, ComplexF64)
a = randn(T, 10)
b = view(a, 1:10)
c = reshape(b, 5, 2)
d = view(c, :, 1:2)
r = sum(abs2, a)
for x in (a,b,c,d), y in (a,b,c,d)
@test dot(x, y) r
end
end
end

@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64)
A0 = randn(complex(T), 10, 10)
B0 = randn(T, 10, 10)
Expand Down

0 comments on commit 4e74e26

Please sign in to comment.