diff --git a/stdlib/LinearAlgebra/docs/src/index.md b/stdlib/LinearAlgebra/docs/src/index.md index 07cdded9eae28..88e700685a0d3 100644 --- a/stdlib/LinearAlgebra/docs/src/index.md +++ b/stdlib/LinearAlgebra/docs/src/index.md @@ -322,6 +322,8 @@ LinearAlgebra.ZeroPivotException LinearAlgebra.dot LinearAlgebra.dot(::Any, ::Any, ::Any) LinearAlgebra.cross +LinearAlgebra.axpy! +LinearAlgebra.axpby! LinearAlgebra.factorize LinearAlgebra.Diagonal LinearAlgebra.Bidiagonal @@ -532,8 +534,8 @@ LinearAlgebra.BLAS.dotc LinearAlgebra.BLAS.blascopy! LinearAlgebra.BLAS.nrm2 LinearAlgebra.BLAS.asum -LinearAlgebra.axpy! -LinearAlgebra.axpby! +LinearAlgebra.BLAS.axpy! +LinearAlgebra.BLAS.axpby! LinearAlgebra.BLAS.scal! LinearAlgebra.BLAS.scal LinearAlgebra.BLAS.iamax diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 66201249eab52..2710559e57d6b 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -5,7 +5,6 @@ Interface to BLAS subroutines. """ module BLAS -import ..axpy!, ..axpby! import Base: copyto! using Base: require_one_based_indexing, USE_BLAS64 @@ -456,15 +455,13 @@ Overwrite `Y` with `X*a + Y`, where `a` is a scalar. Return `Y`. # Examples ```jldoctest -julia> x = [1; 2; 3]; +julia> x = [1.; 2; 3]; -julia> y = [4; 5; 6]; +julia> y = [4. ;; 5 ;; 6]; julia> BLAS.axpy!(2, x, y) -3-element Vector{Int64}: - 6 - 9 - 12 +1×3 Matrix{Float64}: + 6.0 9.0 12.0 ``` """ function axpy! end @@ -490,8 +487,7 @@ for (fname, elty) in ((:daxpy_,:Float64), end end -#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpy!` and `LinearAlgebra.axpy!` -function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where T<:BlasFloat +function axpy!(alpha::Number, x::AbstractArray{T}, y::AbstractArray{T}) where T<:BlasFloat if length(x) != length(y) throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) end @@ -563,8 +559,7 @@ for (fname, elty) in ((:daxpby_,:Float64), (:saxpby_,:Float32), end end -#TODO: replace with `x::AbstractArray{T}` once we separate `BLAS.axpby!` and `LinearAlgebra.axpby!` -function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::Number, y::Union{DenseArray{T},AbstractVector{T}},) where T<:BlasFloat +function axpby!(alpha::Number, x::AbstractArray{T}, beta::Number, y::AbstractArray{T}) where T<:BlasFloat require_one_based_indexing(x, y) if length(x) != length(y) throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 8a610bb528e25..2449a78fda317 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -1397,9 +1397,27 @@ true isdiag(A::AbstractMatrix) = isbanded(A, 0, 0) isdiag(x::Number) = true +""" + axpy!(α, x::AbstractArray, y::AbstractArray) + +Overwrite `y` with `x * α + y` and return `y`. +If `x` and `y` have the same axes, it's equivalent with `y .+= x .* a` + +See also [`BLAS.axpy!`](@ref) + +# Examples +```jldoctest +julia> x = [1; 2; 3]; + +julia> y = [4; 5; 6]; -# BLAS-like in-place y = x*α+y function (see also the version in blas.jl -# for BlasFloat Arrays) +julia> axpy!(2, x, y) +3-element Vector{Int64}: + 6 + 9 + 12 +``` +""" function axpy!(α, x::AbstractArray, y::AbstractArray) n = length(x) if n != length(y) @@ -1425,6 +1443,27 @@ function axpy!(α, x::AbstractArray, rx::AbstractArray{<:Integer}, y::AbstractAr y end +""" + axpby!(α, x::AbstractArray, β, y::AbstractArray) + +Overwrite `y` with `x * α + y * β` and return `y`. +If `x` and `y` have the same axes, it's equivalent with `y .= x .* a .+ y .* β` + +See also [`BLAS.axpby!`](@ref) + +# Examples +```jldoctest +julia> x = [1; 2; 3]; + +julia> y = [4; 5; 6]; + +julia> axpby!(2, x, 2, y) +3-element Vector{Int64}: + 10 + 14 + 18 +``` +""" function axpby!(α, x::AbstractArray, β, y::AbstractArray) if length(x) != length(y) throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))")) @@ -1435,6 +1474,24 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray) y end +DenseLike{T} = Union{DenseArray{T}, Base.StridedReshapedArray{T}, Base.StridedReinterpretArray{T}} +StridedVecLike{T} = Union{DenseLike{T}, Base.FastSubArray{T,<:Any,<:DenseLike{T}}} +axpy!(α::Number, x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasFloat} = BLAS.axpy!(α, x, y) +axpby!(α::Number, x::StridedVecLike{T}, β::Number, y::StridedVecLike{T}) where {T<:BlasFloat} = BLAS.axpby!(α, x, β, y) +function axpy!(α::Number, + x::StridedVecLike{T}, rx::AbstractRange{<:Integer}, + y::StridedVecLike{T}, ry::AbstractRange{<:Integer}, +) where {T<:BlasFloat} + if Base.has_offset_axes(rx, ry) + return Base.@invoke axpy!(α, + x::AbstractArray, rx::AbstractArray{<:Integer}, + y::AbstractArray, ry::AbstractArray{<:Integer}, + ) + end + @views BLAS.axpy!(α, x[rx], y[ry]) + return y +end + """ rotate!(x, y, c, s) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 0ac3a8daef7fb..7646aae29d1b9 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -11,8 +11,8 @@ matprod(x, y) = x*y + x*y # dot products -dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasReal} = BLAS.dot(x, y) -dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasComplex} = BLAS.dotc(x, y) +dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasReal} = BLAS.dot(x, y) +dot(x::StridedVecLike{T}, y::StridedVecLike{T}) where {T<:BlasComplex} = BLAS.dotc(x, y) function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer} if length(rx) != length(ry) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index e2dcc30791900..359f1fea9085c 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -327,11 +327,22 @@ end @testset "LinearAlgebra.axp(b)y! for non strides input" begin a = rand(5, 5) @test LinearAlgebra.axpby!(1, Hermitian(a), 1, zeros(size(a))) == Hermitian(a) - @test_broken LinearAlgebra.axpby!(1, 1.:5, 1, zeros(5)) == 1.:5 + @test LinearAlgebra.axpby!(1, 1.:5, 1, zeros(5)) == 1.:5 @test LinearAlgebra.axpy!(1, Hermitian(a), zeros(size(a))) == Hermitian(a) @test LinearAlgebra.axpy!(1, 1.:5, zeros(5)) == 1.:5 end +@testset "LinearAlgebra.axp(b)y! for stride-vector like input" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + a = rand(T, 5, 5) + @test LinearAlgebra.axpby!(1, view(a, :, 1:5), 1, zeros(T, size(a))) == a + @test LinearAlgebra.axpy!(1, view(a, :, 1:5), zeros(T, size(a))) == a + b = view(a, 25:-2:1) + @test LinearAlgebra.axpby!(1, b, 1, zeros(T, size(b))) == b + @test LinearAlgebra.axpy!(1, b, zeros(T, size(b))) == b + end +end + @testset "norm and normalize!" begin vr = [3.0, 4.0] for Tr in (Float32, Float64) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index b65314d5abe43..cf0295ce552b5 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -226,6 +226,19 @@ end end end +@testset "dot product of stride-vector like input" begin + for T in (Float32, Float64, ComplexF32, ComplexF64) + a = randn(T, 10) + b = view(a, 1:10) + c = reshape(b, 5, 2) + d = view(c, :, 1:2) + r = sum(abs2, a) + for x in (a,b,c,d), y in (a,b,c,d) + @test dot(x, y) ≈ r + end + end +end + @testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64) A0 = randn(complex(T), 10, 10) B0 = randn(T, 10, 10)