Skip to content

Commit

Permalink
Pass arrays instead of pointers when possible in blas.jl (JuliaLang#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Feb 19, 2020
1 parent b0d1c1a commit 663ab4a
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,23 +333,23 @@ function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T}
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
GC.@preserve DX DY dot(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
return dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
# Whole-array convenience method for `dotc`.
# Checks that both arguments use one-based indexing and have equal length, then forwards
# to the five-argument `dotc(n, X, incx, Y, incy)` with the element count and each
# array's stride along its first dimension.
# Throws `DimensionMismatch` when the lengths differ.
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
    require_one_based_indexing(DX, DY)
    n = length(DX)
    if n != length(DY)
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    end
    # The arrays themselves are passed on (not raw pointers), so this wrapper does not
    # need `pointer`/`GC.@preserve`, and the kernel is invoked exactly once.
    return dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
# Whole-array convenience method for `dotu`.
# Checks that both arguments use one-based indexing and have equal length, then forwards
# to the five-argument `dotu(n, X, incx, Y, incy)` with the element count and each
# array's stride along its first dimension.
# Throws `DimensionMismatch` when the lengths differ.
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
    require_one_based_indexing(DX, DY)
    n = length(DX)
    if n != length(DY)
        throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
    end
    # The arrays themselves are passed on (not raw pointers), so this wrapper does not
    # need `pointer`/`GC.@preserve`, and the kernel is invoked exactly once.
    return dotu(n, DX, stride(DX, 1), DY, stride(DY, 1))
end

## nrm2
Expand Down Expand Up @@ -383,7 +383,7 @@ for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
end
end
end
# Whole-array convenience method: forwards `length(x)`, the array itself and `stride1(x)`
# to the three-argument `nrm2`. Only one definition is kept for this signature; the array
# is passed directly, so no `pointer`/`GC.@preserve` is needed in this wrapper.
nrm2(x::Union{AbstractVector,DenseArray}) = nrm2(length(x), x, stride1(x))

## asum

Expand Down Expand Up @@ -416,7 +416,7 @@ for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
end
end
end
# Whole-array convenience method: forwards `length(x)`, the array itself and `stride1(x)`
# to the three-argument `asum`. Only one definition is kept for this signature; the array
# is passed directly, so no `pointer`/`GC.@preserve` is needed in this wrapper.
asum(x::Union{AbstractVector,DenseArray}) = asum(length(x), x, stride1(x))

## axpy

Expand Down Expand Up @@ -464,8 +464,7 @@ function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
GC.@preserve x y axpy!(length(x), convert(T,alpha), pointer(x), stride(x, 1), pointer(y), stride(y, 1))
y
return axpy!(length(x), convert(T,alpha), x, stride(x, 1), y, stride(y, 1))
end

function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange{Ti}},
Expand All @@ -479,8 +478,15 @@ function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange
if minimum(ry) < 1 || maximum(ry) > length(y)
throw(ArgumentError("range out of bounds for y, of length $(length(y))"))
end
GC.@preserve x y axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
y
GC.@preserve x y axpy!(
length(rx),
convert(T, alpha),
pointer(x) + (first(rx) - 1)*sizeof(T),
step(rx),
pointer(y) + (first(ry) - 1)*sizeof(T),
step(ry))

return y
end

"""
Expand Down Expand Up @@ -529,8 +535,7 @@ function axpby!(alpha::Number, x::Union{DenseArray{T},AbstractVector{T}}, beta::
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
GC.@preserve x y axpby!(length(x), convert(T,alpha), pointer(x), stride(x, 1), convert(T,beta), pointer(y), stride(y, 1))
y
return axpby!(length(x), convert(T, alpha), x, stride(x, 1), convert(T, beta), y, stride(y, 1))
end

## iamax
Expand All @@ -546,7 +551,7 @@ for (fname, elty) in ((:idamax_,:Float64),
end
end
end
# Whole-array convenience method: forwards `length(dx)`, the array itself and
# `stride1(dx)` to the three-argument `iamax`. Only one definition is kept for this
# signature; the array is passed directly, so no `pointer`/`GC.@preserve` is needed in
# this wrapper.
iamax(dx::Union{AbstractVector,DenseArray}) = iamax(length(dx), dx, stride1(dx))

"""
iamax(n, dx, incx)
Expand Down Expand Up @@ -837,7 +842,7 @@ for (fname, elty) in ((:zhpmv_, :ComplexF64),
# * .. Array Arguments ..
# DOUBLE PRECISION A(N,N),X(N),Y(N)
function hpmv!(uplo::AbstractChar,
n::BlasInt,
n::Integer,
α::$elty,
AP::Union{Ptr{$elty}, AbstractArray{$elty}},
x::Union{Ptr{$elty}, AbstractArray{$elty}},
Expand Down Expand Up @@ -880,8 +885,7 @@ function hpmv!(uplo::AbstractChar,
if length(AP) < Int64(N*(N+1)/2)
throw(DimensionMismatch("Packed Hermitian matrix A has size smaller than length(x) = $(N)."))
end
GC.@preserve x y AP hpmv!(uplo, BlasInt(N), convert(T, α), AP, pointer(x), BlasInt(stride(x, 1)), convert(T, β), pointer(y), BlasInt(stride(y, 1)))
y
return hpmv!(uplo, N, convert(T, α), AP, x, stride(x, 1), convert(T, β), y, stride(y, 1))
end

"""
Expand Down Expand Up @@ -1800,10 +1804,12 @@ function copyto!(dest::Array{T}, rdest::Union{UnitRange{Ti},AbstractRange{Ti}},
if length(rdest) != length(rsrc)
throw(DimensionMismatch("ranges must be of the same length"))
end
GC.@preserve src dest BLAS.blascopy!(length(rsrc),
pointer(src) + (first(rsrc) - 1) * sizeof(T),
step(rsrc),
pointer(dest) + (first(rdest) - 1) * sizeof(T),
step(rdest))
dest
GC.@preserve src dest BLAS.blascopy!(
length(rsrc),
pointer(src) + (first(rsrc) - 1) * sizeof(T),
step(rsrc),
pointer(dest) + (first(rdest) - 1) * sizeof(T),
step(rdest))

return dest
end

0 comments on commit 663ab4a

Please sign in to comment.