Skip to content

Commit

Permalink
Diagonal inverse scaling (JuliaLang#22230)
Browse files Browse the repository at this point in the history
* Diagonal inverse scaling

As described in  JuliaStats/MixedModels.jl#85, I
use in-place `_rdiv_` and `_ldiv_` methods.  This would move the methods
for Base types into Base.

* Don't use commutativity

Use division instead multipliciation with the reciprocal
  • Loading branch information
dmbates authored and stevengj committed Jun 14, 2017
1 parent 17d5c55 commit 516617d
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 12 deletions.
43 changes: 33 additions & 10 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,36 +236,59 @@ At_mul_B!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix) = out .= transpo


(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)
function A_ldiv_B!(D::Diagonal{T}, v::AbstractVector{T}) where T
function A_ldiv_B!(D::Diagonal{T}, v::AbstractVector{T}) where {T}
if length(v) != length(D.diag)
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $(length(v)) rows"))
end
for i=1:length(D.diag)
for i = 1:length(D.diag)
d = D.diag[i]
if d == zero(T)
if iszero(d)
throw(SingularException(i))
end
v[i] *= inv(d)
v[i] = d\v[i]
end
v
end
function A_ldiv_B!(D::Diagonal{T}, V::AbstractMatrix{T}) where T
function A_ldiv_B!(D::Diagonal{T}, V::AbstractMatrix{T}) where {T}
if size(V,1) != length(D.diag)
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $(size(V,1)) rows"))
end
for i=1:length(D.diag)
for i = 1:length(D.diag)
d = D.diag[i]
if d == zero(T)
if iszero(d)
throw(SingularException(i))
end
d⁻¹ = inv(d)
for j=1:size(V,2)
@inbounds V[i,j] *= d⁻¹
for j = 1:size(V,2)
@inbounds V[i,j] = d\V[i,j]
end
end
V
end

Ac_ldiv_B!(D::Diagonal{T}, B::AbstractVecOrMat{T}) where {T} = A_ldiv_B!(conj(D), B)
At_ldiv_B!(D::Diagonal{T}, B::AbstractVecOrMat{T}) where {T} = A_ldiv_B!(D, B)

function A_rdiv_B!(A::AbstractMatrix{T}, D::Diagonal{T}) where {T}
dd = D.diag
m, n = size(A)
if (k = length(dd)) n
throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
end
@inbounds for j in 1:n
ddj = dd[j]
if iszero(ddj)
throw(SingularException(j))
end
for i in 1:m
A[i, j] /= ddj
end
end
A
end

A_rdiv_Bc!(A::AbstractMatrix{T}, D::Diagonal{T}) where {T} = A_rdiv_B!(A, conj(D))
A_rdiv_Bt!(A::AbstractMatrix{T}, D::Diagonal{T}) where {T} = A_rdiv_B!(A, D)

# Methods to resolve ambiguities with `Diagonal`
@inline *(rowvec::RowVector, D::Diagonal) = transpose(D * transpose(rowvec))
@inline A_mul_Bt(D::Diagonal, rowvec::RowVector) = D*transpose(rowvec)
Expand Down
21 changes: 21 additions & 0 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,27 @@ A_ldiv_B!(U::UpperTriangular{T,<:SparseMatrixCSC{T}}, B::StridedVecOrMat) where
(\)(L::LowerTriangular{T,<:SparseMatrixCSC{T}}, B::SparseMatrixCSC) where {T} = A_ldiv_B!(L, Array(B))
(\)(U::UpperTriangular{T,<:SparseMatrixCSC{T}}, B::SparseMatrixCSC) where {T} = A_ldiv_B!(U, Array(B))

function A_rdiv_B!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T
dd = D.diag
if (k = length(dd)) A.n
throw(DimensionMismatch("size(A, 2)=$(A.n) should be size(D, 1)=$k"))
end
nonz = nonzeros(A)
@inbounds for j in 1:k
ddj = dd[j]
if iszero(ddj)
throw(SingularException(j))
end
for k in nzrange(A, j)
nonz[k] /= ddj
end
end
A
end

A_rdiv_Bc!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T = A_rdiv_B!(A, conj(D))
A_rdiv_Bt!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T = A_rdiv_B!(A, D)

## triu, tril

function triu(S::SparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Base: +, -, *, \, /, &, |, xor, ==
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!
import Base: A_mul_Bc, A_mul_Bt, Ac_mul_Bc, At_mul_Bt
import Base: At_ldiv_B, Ac_ldiv_B, A_ldiv_B!
import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!
import Base.LinAlg: At_ldiv_B!, Ac_ldiv_B!, A_rdiv_B!, A_rdiv_Bc!

import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
atan, atand, atanh, broadcast!, chol, conj!, cos, cosc, cosd, cosh, cospi, cot,
Expand Down
12 changes: 11 additions & 1 deletion test/linalg/diagonal.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using Base.Test
import Base.LinAlg: BlasFloat, BlasComplex, SingularException
import Base.LinAlg: BlasFloat, BlasComplex, SingularException, A_rdiv_B!, A_rdiv_Bt!,
A_rdiv_Bc!

n=12 #Size of matrix problem to test
srand(1)
Expand Down Expand Up @@ -82,7 +83,16 @@ srand(1)
@test D\v DM\v atol=2n^2*eps(relty)*(1+(elty<:Complex))
@test D\U DM\U atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test A_ldiv_B!(D,copy(v)) DM\v atol=2n^2*eps(relty)*(1+(elty<:Complex))
@test At_ldiv_B!(D,copy(v)) DM\v atol=2n^2*eps(relty)*(1+(elty<:Complex))
@test Ac_ldiv_B!(conj(D),copy(v)) DM\v atol=2n^2*eps(relty)*(1+(elty<:Complex))
@test A_ldiv_B!(D,copy(U)) DM\U atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test At_ldiv_B!(D,copy(U)) DM\U atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test Ac_ldiv_B!(conj(D),copy(U)) DM\U atol=2n^3*eps(relty)*(1+(elty<:Complex))
Uc = ctranspose(U)
target = scale!(Uc,inv.(D.diag))
@test A_rdiv_B!(Uc,D) target atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test A_rdiv_Bt!(Uc,D) target atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test A_rdiv_Bc!(Uc,conj(D)) target atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test A_ldiv_B!(D,eye(D)) D\eye(D) atol=2n^3*eps(relty)*(1+(elty<:Complex))
@test_throws DimensionMismatch A_ldiv_B!(D, ones(elty, n + 1))
@test_throws SingularException A_ldiv_B!(Diagonal(zeros(relty,n)),copy(v))
Expand Down
9 changes: 9 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ b = randn(7)
@test scale!(sC, 0.5, sA) == scale!(sC, sA, 0.5)
end

@testset "inverse scale!" begin
bi = inv.(b)
dAt = transpose(dA)
sAt = transpose(sA)
@test scale!(copy(dAt), bi) Base.LinAlg.A_rdiv_B!(copy(sAt), Diagonal(b))
@test scale!(copy(dAt), bi) Base.LinAlg.A_rdiv_Bt!(copy(sAt), Diagonal(b))
@test scale!(copy(dAt), conj(bi)) Base.LinAlg.A_rdiv_Bc!(copy(sAt), Diagonal(b))
end

@testset "copy!" begin
A = sprand(5, 5, 0.2)
B = sprand(5, 5, 0.2)
Expand Down

0 comments on commit 516617d

Please sign in to comment.