Skip to content

Commit

Permalink
Add a BLAS wrapper for axpby (JuliaLang#23291)
Browse files Browse the repository at this point in the history
* Wrap axpby as well

* Update method description for noncommutative operators *

* Use Ref over Ptr + line length

* Add generic axpby! with tests
  • Loading branch information
haampie authored and andreasnoack committed Aug 22, 2017
1 parent bc66f02 commit ce402ec
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 9 deletions.
49 changes: 49 additions & 0 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,55 @@ function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},Range{Ti}},
y
end

"""
axpby!(a, X, b, Y)
Overwrite `Y` with `X*a + Y*b`, where `a` and `b` are scalars. Return `Y`.
# Examples
```jldoctest
julia> x = [1., 2, 3];
julia> y = [4., 5, 6];
julia> Base.BLAS.axpby!(2., x, 3., y)
3-element Array{Float64,1}:
14.0
19.0
24.0
```
"""
function axpby! end

for (fname, elty) in ((:daxpby_,:Float64), (:saxpby_,:Float32),
(:zaxpby_,:Complex128), (:caxpby_,:Complex64))
@eval begin
# SUBROUTINE DAXPBY(N,DA,DX,INCX,DB,DY,INCY)
# DY <- DA*DX + DB*DY
#* .. Scalar Arguments ..
# DOUBLE PRECISION DA,DB
# INTEGER INCX,INCY,N
#* .. Array Arguments ..
# DOUBLE PRECISION DX(*),DY(*)
function axpby!(n::Integer, alpha::($elty), dx::Union{Ptr{$elty},
DenseArray{$elty}}, incx::Integer, beta::($elty),
dy::Union{Ptr{$elty}, DenseArray{$elty}}, incy::Integer)
ccall((@blasfunc($fname), libblas), Void, (Ref{BlasInt}, Ref{$elty}, Ptr{$elty},
Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BlasInt}),
n, alpha, dx, incx, beta, dy, incy)
dy
end
end
end

function axpby!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, beta::Number, y::Union{DenseArray{T},StridedVector{T}}) where T<:BlasFloat
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
axpby!(length(x), convert(T,alpha), pointer(x), stride(x, 1), convert(T,beta), pointer(y), stride(y, 1))
y
end

## iamax
for (fname, elty) in ((:idamax_,:Float64),
(:isamax_,:Float32),
Expand Down
10 changes: 10 additions & 0 deletions base/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,16 @@ function axpy!(α, x::AbstractArray, rx::AbstractArray{<:Integer}, y::AbstractAr
y
end

function axpby!(α, x::AbstractArray, β, y::AbstractArray)
if _length(x) != _length(y)
throw(DimensionMismatch("x has length $(_length(x)), but y has length $(_length(y))"))
end
for (IX, IY) in zip(eachindex(x), eachindex(y))
@inbounds y[IY] = x[IX]*α + y[IY]*β
end
y
end


# Elementary reflection similar to LAPACK. The reflector is not Hermitian but
# ensures that tridiagonalization of Hermitian matrices become real. See lawn72
Expand Down
1 change: 1 addition & 0 deletions base/linalg/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export

# Functions
axpy!,
axpby!,
bkfact,
bkfact!,
chol,
Expand Down
7 changes: 5 additions & 2 deletions test/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,16 @@ srand(100)
@test BLAS.iamax(z) == indmax(map(x -> abs(real(x)) + abs(imag(x)), z))
end
end
@testset "axpy" begin
@testset "axp(b)y" begin
if elty <: Real
x1 = convert(Vector{elty}, randn(n))
x2 = convert(Vector{elty}, randn(n))
α = rand(elty)
@test BLAS.axpy!(α,copy(x1),copy(x2)) x2 + α*x1
β = rand(elty)
@test BLAS.axpy!(α,copy(x1),copy(x2)) α*x1 + x2
@test BLAS.axpby!(α,copy(x1),β,copy(x2)) α*x1 + β*x2
@test_throws DimensionMismatch BLAS.axpy!(α, copy(x1), rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.axpby!(α, copy(x1), β, rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.axpy!(α, copy(x1), 1:div(n,2), copy(x2), 1:n)
@test_throws ArgumentError BLAS.axpy!(α, copy(x1), 0:div(n,2), copy(x2), 1:(div(n, 2) + 1))
@test_throws ArgumentError BLAS.axpy!(α, copy(x1), 1:div(n,2), copy(x2), 0:(div(n, 2) - 1))
Expand Down
20 changes: 13 additions & 7 deletions test/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ end
@testset "generic axpy" begin
x = ['a','b','c','d','e']
y = ['a','b','c','d','e']
α = 'f'
α, β = 'f', 'g'
@test_throws DimensionMismatch Base.LinAlg.axpy!(α,x,['g'])
@test_throws DimensionMismatch Base.LinAlg.axpby!(α,x,β,['g'])
@test_throws BoundsError Base.LinAlg.axpy!(α,x,collect(-1:5),y,collect(1:7))
@test_throws BoundsError Base.LinAlg.axpy!(α,x,collect(1:7),y,collect(-1:5))
@test_throws BoundsError Base.LinAlg.axpy!(α,x,collect(1:7),y,collect(1:7))
Expand Down Expand Up @@ -276,12 +277,17 @@ end
@test norm(x, 3) cbrt(sqrt(125)+125)
end

@testset "LinAlg.axpy! for element type without commutative multiplication" begin
α = ones(Int, 2, 2)
x = fill([1 0; 1 1], 3)
y = fill(zeros(Int, 2, 2), 3)
@test LinAlg.axpy!(α, x, deepcopy(y)) == x .* Matrix{Int}[α]
@test LinAlg.axpy!(α, x, deepcopy(y)) != Matrix{Int}[α] .* x
@testset "LinAlg.axp(b)y! for element type without commutative multiplication" begin
α = [1 2; 3 4]
β = [5 6; 7 8]
x = fill([ 9 10; 11 12], 3)
y = fill([13 14; 15 16], 3)
axpy = LinAlg.axpy!(α, x, deepcopy(y))
axpby = LinAlg.axpby!(α, x, β, deepcopy(y))
@test axpy == x .* [α] .+ y
@test axpy != [α] .* x .+ y
@test axpby == x .* [α] .+ y .* [β]
@test axpby != [α] .* x .+ [β] .* y
end

@testset "LinAlg.axpy! for x and y of different dimensions" begin
Expand Down

0 comments on commit ce402ec

Please sign in to comment.