Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transpose for arrays #9466

Merged
merged 1 commit into from
Dec 30, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add (c)transpose for Givens and Rotation types. Let AbstractMatrix th…
…row on (c)transpose to avoid silent errors when (c)transpose is not implemented for a matrix type. Add some tests. Remove r property from Givens type and return it from givens method.
  • Loading branch information
andreasnoack committed Dec 26, 2014
commit 2cdb63c97f0825e9c98ccbbb133255219ce01410
2 changes: 2 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ endof(a::AbstractArray) = length(a)
first(a::AbstractArray) = a[1]
first(a) = next(a,start(a))[1]
last(a) = a[end]
ctranspose(a::AbstractArray) = error("ctranspose not implemented for $(typeof(a)). Consider adding parentheses, e.g. A*(B*C') instead of A*B*C' to avoid explicit calculation of the transposed matrix.")
transpose(a::AbstractArray) = error("transpose not implemented for $(typeof(a)). Consider adding parentheses, e.g. A*(B*C.') instead of A*B*C' to avoid explicit calculation of the transposed matrix.")

function stride(a::AbstractArray, i::Integer)
if i > ndims(a)
Expand Down
2 changes: 2 additions & 0 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
abstract Factorization{T}

eltype{T}(F::Factorization{T}) = T
transpose(F::Factorization) = error("transpose not implemented for $(typeof(F))")
ctranspose(F::Factorization) = error("ctranspose not implemented for $(typeof(F))")

macro assertposdef(A, info)
:(($info)==0 ? $A : throw(PosDefException($info)))
Expand Down
38 changes: 28 additions & 10 deletions base/linalg/givens.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
immutable Givens{T}
abstract AbstractRotation{T}

transpose(R::AbstractRotation) = error("transpose not implemented for $(typeof(R)). Consider using conjugate transpose (') instead of transpose (.').")

function *{T,S}(R::AbstractRotation{T}, A::AbstractMatrix{S})
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
A_mul_B!(convert(AbstractRotation{TS}, R), TS == S ? copy(A) : convert(AbstractArray{TS}, A))
end
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, R::AbstractRotation{S})
TS = typeof(zero(T)*zero(S) + zero(T)*zero(S))
A_mul_Bc!(TS == T ? copy(A) : convert(AbstractArray{TS}, A), convert(AbstractRotation{TS}, R))
end

immutable Givens{T} <: AbstractRotation{T}
i1::Int
i2::Int
c::T
s::T
r::T
end
type Rotation{T}
rotations::Vector{T}
type Rotation{T} <: AbstractRotation{T}
rotations::Vector{Givens{T}}
end
typealias AbstractRotation Union(Givens, Rotation)

convert{T}(::Type{Givens{T}}, G::Givens{T}) = G
convert{T}(::Type{Givens{T}}, G::Givens) = Givens(G.i1, G.i2, convert(T, G.c), convert(T, G.s))
convert{T}(::Type{Rotation{T}}, R::Rotation{T}) = R
convert{T}(::Type{Rotation{T}}, R::Rotation) = Rotation{T}([convert(Givens{T}, g) for g in R.rotations])
convert{T}(::Type{AbstractRotation{T}}, G::Givens) = convert(Givens{T}, G)
convert{T}(::Type{AbstractRotation{T}}, R::Rotation) = convert(Rotation{T}, R)

ctranspose(G::Givens) = Givens(G.i1, G.i2, conj(G.c), -G.s)
ctranspose{T}(R::Rotation{T}) = Rotation{T}(reverse!([ctranspose(r) for r in R.rotations]))

realmin2(::Type{Float32}) = reinterpret(Float32, 0x26000000)
realmin2(::Type{Float64}) = reinterpret(Float64, 0x21a0000000000000)
Expand Down Expand Up @@ -187,13 +208,13 @@ end
function givens{T}(f::T, g::T, i1::Integer, i2::Integer)
i1 < i2 || error("second index must be larger than the first")
c, s, r = givensAlgorithm(f, g)
Givens(i1, i2, convert(T, c), convert(T, s), convert(T, r))
Givens(i1, i2, convert(T, c), convert(T, s)), r
end

function givens{T}(A::AbstractMatrix{T}, i1::Integer, i2::Integer, col::Integer)
i1 < i2 || error("second index must be larger than the first")
c, s, r = givensAlgorithm(A[i1,col], A[i2,col])
Givens(i1, i2, convert(T, c), convert(T, s), convert(T, r))
Givens(i1, i2, convert(T, c), convert(T, s)), r
end

getindex(G::Givens, i::Integer, j::Integer) = i == j ? (i == G.i1 || i == G.i2 ? G.c : one(G.c)) : (i == G.i1 && j == G.i2 ? G.s : (i == G.i2 && j == G.i1 ? -G.s : zero(G.s)))
Expand Down Expand Up @@ -236,6 +257,3 @@ function A_mul_Bc!(A::AbstractMatrix, R::Rotation)
return A
end
*{T}(G1::Givens{T}, G2::Givens{T}) = Rotation(push!(push!(Givens{T}[], G2), G1))
*(R::AbstractRotation, A::AbstractMatrix) = A_mul_B!(R, copy(A))

A_mul_Bc(A::AbstractMatrix, R::AbstractRotation) = A_mul_Bc!(copy(A), R)
34 changes: 34 additions & 0 deletions test/linalg/givens.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
let
debug = false

# Test givens rotations
for elty in (Float32, Float64, Complex64, Complex128)

debug && println("elty is $elty")

if elty <: Real
A = convert(Matrix{elty}, randn(10,10))
else
A = convert(Matrix{elty}, complex(randn(10,10),randn(10,10)))
end
Ac = copy(A)
R = Base.LinAlg.Rotation(Base.LinAlg.Givens{elty}[])
for j = 1:8
for i = j+2:10
G, _ = givens(A, j+1, i, j)
A_mul_B!(G, A)
A_mul_Bc!(A, G)
A_mul_B!(G, R)

# test transposes
@test_approx_eq ctranspose(G)*G*eye(10) eye(elty, 10)
@test_approx_eq ctranspose(R)*(R*eye(10)) eye(elty, 10)
@test_throws ErrorException transpose(G)
@test_throws ErrorException transpose(R)
end
end
@test_approx_eq abs(A) abs(hessfact(Ac)[:H])
@test_approx_eq norm(R*eye(elty, 10)) one(elty)

end
end #let
21 changes: 0 additions & 21 deletions test/linalg2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,6 @@ for elty in (Complex64, Complex128)
@test_approx_eq triu(LinAlg.BLAS.her2k('U','C',U,V)) triu(U'*V + V'*U)
end

# Test givens rotations
for elty in (Float32, Float64, Complex64, Complex128)
if elty <: Real
A = convert(Matrix{elty}, randn(10,10))
else
A = convert(Matrix{elty}, complex(randn(10,10),randn(10,10)))
end
Ac = copy(A)
R = Base.LinAlg.Rotation(Base.LinAlg.Givens{elty}[])
for j = 1:8
for i = j+2:10
G = givens(A, j+1, i, j)
A_mul_B!(G, A)
A_mul_Bc!(A, G)
A_mul_B!(G, R)
end
end
@test_approx_eq abs(A) abs(hessfact(Ac)[:H])
@test_approx_eq norm(R*eye(elty, 10)) one(elty)
end

# Test gradient
for elty in (Int32, Int64, Float32, Float64, Complex64, Complex128)
if elty <: Real
Expand Down
10 changes: 10 additions & 0 deletions test/linalg4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,16 @@ for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
end
end

# Because transpose(x) == x
@test_throws ErrorException transpose(qrfact(randn(3,3)))
@test_throws ErrorException ctranspose(qrfact(randn(3,3)))
@test_throws ErrorException transpose(qrfact(randn(3,3), pivot = false))
@test_throws ErrorException ctranspose(qrfact(randn(3,3), pivot = false))
@test_throws ErrorException transpose(qrfact(big(randn(3,3))))
@test_throws ErrorException ctranspose(qrfact(big(randn(3,3))))
@test_throws ErrorException transpose(sub(sprandn(10, 10, 0.3), 1:4, 1:4))
@test_throws ErrorException ctranspose(sub(sprandn(10, 10, 0.3), 1:4, 1:4))

# Issue #7933
A7933 = [1 2; 3 4]
B7933 = copy(A7933)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ tests = (ARGS==["all"] || isempty(ARGS)) ? testnames : ARGS
if "linalg" in tests
# specifically selected case
filter!(x -> x != "linalg", tests)
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/lapack", "linalg/triangular", "linalg/tridiag", "linalg/pinv", "linalg/cholmod", "linalg/umfpack"])
prepend!(tests, ["linalg1", "linalg2", "linalg3", "linalg4", "linalg/lapack", "linalg/triangular", "linalg/tridiag", "linalg/pinv", "linalg/cholmod", "linalg/umfpack", "linalg/givens"])
end

net_required_for = ["socket", "parallel"]
Expand Down