Skip to content

Commit

Permalink
Merge pull request #9466 from JuliaLang/anj/notr2
Browse files Browse the repository at this point in the history
Fix transpose for arrays
  • Loading branch information
andreasnoack committed Dec 30, 2014
2 parents 31f46d2 + 2cdb63c commit 89b0aa6
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 32 deletions.
2 changes: 2 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ endof(a::AbstractArray) = length(a)
first(a::AbstractArray) = a[1]
first(a) = next(a,start(a))[1]
last(a) = a[end]
ctranspose(a::AbstractArray) = error("ctranspose not implemented for $(typeof(a)). Consider adding parentheses, e.g. A*(B*C') instead of A*B*C' to avoid explicit calculation of the transposed matrix.")
transpose(a::AbstractArray) = error("transpose not implemented for $(typeof(a)). Consider adding parentheses, e.g. A*(B*C.') instead of A*B*C' to avoid explicit calculation of the transposed matrix.")

function stride(a::AbstractArray, i::Integer)
if i > ndims(a)
Expand Down
2 changes: 2 additions & 0 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
abstract Factorization{T}

eltype{T}(F::Factorization{T}) = T
transpose(F::Factorization) = error("transpose not implemented for $(typeof(F))")
ctranspose(F::Factorization) = error("ctranspose not implemented for $(typeof(F))")

macro assertposdef(A, info)
:(($info)==0 ? $A : throw(PosDefException($info)))
Expand Down
38 changes: 28 additions & 10 deletions base/linalg/givens.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
immutable Givens{T}
abstract AbstractRotation{T}

transpose(R::AbstractRotation) = error("transpose not implemented for $(typeof(R)). Consider using conjugate transpose (') instead of transpose (.').")

function *{T,S}(R::AbstractRotation{T}, A::AbstractMatrix{S})
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
A_mul_B!(convert(AbstractRotation{TS}, R), TS == S ? copy(A) : convert(AbstractArray{TS}, A))
end
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, R::AbstractRotation{S})
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
A_mul_Bc!(TS == T ? copy(A) : convert(AbstractArray{TS}, A), convert(AbstractRotation{TS}, R))
end

immutable Givens{T} <: AbstractRotation{T}
i1::Int
i2::Int
c::T
s::T
r::T
end
type Rotation{T}
rotations::Vector{T}
type Rotation{T} <: AbstractRotation{T}
rotations::Vector{Givens{T}}
end
typealias AbstractRotation Union(Givens, Rotation)

convert{T}(::Type{Givens{T}}, G::Givens{T}) = G
convert{T}(::Type{Givens{T}}, G::Givens) = Givens(G.i1, G.i2, convert(T, G.c), convert(T, G.s))
convert{T}(::Type{Rotation{T}}, R::Rotation{T}) = R
convert{T}(::Type{Rotation{T}}, R::Rotation) = Rotation{T}([convert(Givens{T}, g) for g in R.rotations])
convert{T}(::Type{AbstractRotation{T}}, G::Givens) = convert(Givens{T}, G)
convert{T}(::Type{AbstractRotation{T}}, R::Rotation) = convert(Rotation{T}, R)

ctranspose(G::Givens) = Givens(G.i1, G.i2, conj(G.c), -G.s)
ctranspose{T}(R::Rotation{T}) = Rotation{T}(reverse!([ctranspose(r) for r in R.rotations]))

realmin2(::Type{Float32}) = reinterpret(Float32, 0x26000000)
realmin2(::Type{Float64}) = reinterpret(Float64, 0x21a0000000000000)
Expand Down Expand Up @@ -187,13 +208,13 @@ end
function givens{T}(f::T, g::T, i1::Integer, i2::Integer)
i1 < i2 || error("second index must be larger than the first")
c, s, r = givensAlgorithm(f, g)
Givens(i1, i2, convert(T, c), convert(T, s), convert(T, r))
Givens(i1, i2, convert(T, c), convert(T, s)), r
end

function givens{T}(A::AbstractMatrix{T}, i1::Integer, i2::Integer, col::Integer)
i1 < i2 || error("second index must be larger than the first")
c, s, r = givensAlgorithm(A[i1,col], A[i2,col])
Givens(i1, i2, convert(T, c), convert(T, s), convert(T, r))
Givens(i1, i2, convert(T, c), convert(T, s)), r
end

getindex(G::Givens, i::Integer, j::Integer) = i == j ? (i == G.i1 || i == G.i2 ? G.c : one(G.c)) : (i == G.i1 && j == G.i2 ? G.s : (i == G.i2 && j == G.i1 ? -G.s : zero(G.s)))
Expand Down Expand Up @@ -236,6 +257,3 @@ function A_mul_Bc!(A::AbstractMatrix, R::Rotation)
return A
end
*{T}(G1::Givens{T}, G2::Givens{T}) = Rotation(push!(push!(Givens{T}[], G2), G1))
*(R::AbstractRotation, A::AbstractMatrix) = A_mul_B!(R, copy(A))

A_mul_Bc(A::AbstractMatrix, R::AbstractRotation) = A_mul_Bc!(copy(A), R)
34 changes: 34 additions & 0 deletions test/linalg/givens.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
let
debug = false

# Test givens rotations
for elty in (Float32, Float64, Complex64, Complex128)

debug && println("elty is $elty")

if elty <: Real
A = convert(Matrix{elty}, randn(10,10))
else
A = convert(Matrix{elty}, complex(randn(10,10),randn(10,10)))
end
Ac = copy(A)
R = Base.LinAlg.Rotation(Base.LinAlg.Givens{elty}[])
for j = 1:8
for i = j+2:10
G, _ = givens(A, j+1, i, j)
A_mul_B!(G, A)
A_mul_Bc!(A, G)
A_mul_B!(G, R)

# test transposes
@test_approx_eq ctranspose(G)*G*eye(10) eye(elty, 10)
@test_approx_eq ctranspose(R)*(R*eye(10)) eye(elty, 10)
@test_throws ErrorException transpose(G)
@test_throws ErrorException transpose(R)
end
end
@test_approx_eq abs(A) abs(hessfact(Ac)[:H])
@test_approx_eq norm(R*eye(elty, 10)) one(elty)

end
end #let
21 changes: 0 additions & 21 deletions test/linalg2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,6 @@ for elty in (Complex64, Complex128)
@test_approx_eq triu(LinAlg.BLAS.her2k('U','C',U,V)) triu(U'*V + V'*U)
end

# Test givens rotations
for elty in (Float32, Float64, Complex64, Complex128)
if elty <: Real
A = convert(Matrix{elty}, randn(10,10))
else
A = convert(Matrix{elty}, complex(randn(10,10),randn(10,10)))
end
Ac = copy(A)
R = Base.LinAlg.Rotation(Base.LinAlg.Givens{elty}[])
for j = 1:8
for i = j+2:10
G = givens(A, j+1, i, j)
A_mul_B!(G, A)
A_mul_Bc!(A, G)
A_mul_B!(G, R)
end
end
@test_approx_eq abs(A) abs(hessfact(Ac)[:H])
@test_approx_eq norm(R*eye(elty, 10)) one(elty)
end

# Test gradient
for elty in (Int32, Int64, Float32, Float64, Complex64, Complex128)
if elty <: Real
Expand Down
10 changes: 10 additions & 0 deletions test/linalg4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,16 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
end
end

# Because transpose(x) == x
@test_throws ErrorException transpose(qrfact(randn(3,3)))
@test_throws ErrorException ctranspose(qrfact(randn(3,3)))
@test_throws ErrorException transpose(qrfact(randn(3,3), pivot = false))
@test_throws ErrorException ctranspose(qrfact(randn(3,3), pivot = false))
@test_throws ErrorException transpose(qrfact(big(randn(3,3))))
@test_throws ErrorException ctranspose(qrfact(big(randn(3,3))))
@test_throws ErrorException transpose(sub(sprandn(10, 10, 0.3), 1:4, 1:4))
@test_throws ErrorException ctranspose(sub(sprandn(10, 10, 0.3), 1:4, 1:4))

# Issue #7933
A7933 = [1 2; 3 4]
B7933 = copy(A7933)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ tests = (ARGS==["all"] || isempty(ARGS)) ? testnames : ARGS
if "linalg" in tests
# specifically selected case
filter!(x -> x != "linalg", tests)
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/lapack", "linalg/triangular", "linalg/tridiag", "linalg/pinv", "linalg/cholmod", "linalg/umfpack"])
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/lapack", "linalg/triangular", "linalg/tridiag", "linalg/pinv", "linalg/cholmod", "linalg/umfpack", "linalg/givens"])
end

net_required_for = ["socket", "parallel"]
Expand Down

0 comments on commit 89b0aa6

Please sign in to comment.