Skip to content

Commit

Permalink
implemented error checking for {h,v,hv, }cat (closes #183)
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeXing committed Oct 12, 2011
1 parent dab5a0c commit c82e159
Showing 1 changed file with 79 additions and 2 deletions.
81 changes: 79 additions & 2 deletions j/abstractarray.j
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ end

## Concatenation ##

#TODO: ERROR CHECK
cat(catdim::Int) = Array(None, 0)

vcat() = Array(None, 0)
Expand All @@ -571,7 +572,13 @@ hcat() = Array(None, 0)
hcat{T}(X::T...) = [ X[j] | i=1, j=1:length(X) ]
vcat{T}(X::T...) = [ X[i] | i=1:length(X) ]

hcat{T}(V::AbstractVector{T}...) = [ V[j][i] | i=1:length(V[1]), j=1:length(V) ]
function hcat{T}(V::AbstractVector{T}...)
height = length(V[1])
for j = 2:length(V)
if length(V[j]) != height; error("hcat: mismatched dimensions"); end
end
[ V[j][i] | i=1:length(V[1]), j=1:length(V) ]
end

function vcat{T}(V::AbstractVector{T}...)
n = 0
Expand All @@ -594,6 +601,10 @@ function hcat{T}(A::AbstractMatrix{T}...)
nargs = length(A)
ncols = sum(a->size(a, 2), A)::Size
nrows = size(A[1], 1)
for j = 2:nargs
if size(A[j], 1) != nrows; error("hcat: mismatched dimensions"); end
end

B = similar(A[1], nrows, ncols)
pos = 1
for k=1:nargs
Expand All @@ -610,6 +621,10 @@ function vcat{T}(A::AbstractMatrix{T}...)
nargs = length(A)
nrows = sum(a->size(a, 1), A)::Size
ncols = size(A[1], 2)
for j = 2:nargs
if size(A[j], 2) != ncols; error("vcat: mismatched dimensions"); end
end

B = similar(A[1], nrows, ncols)
pos = 1
for k=1:nargs
Expand All @@ -633,7 +648,18 @@ function cat(catdim::Int, X...)
if catdim > d_max + 1
for i=1:nargs
if dimsX[1] != dimsX[i]
error("all inputs must have same dimensions when concatenating along a higher dimension");
error("cat: all inputs must have same dimensions when concatenating along a higher dimension");
end
end
elseif nargs >= 2
for d=1:d_max
if d == catdim; continue; end
len = d <= ndimsX[1] ? dimsX[1][d] : 1
for i = 2:nargs
if len != (d <= ndimsX[i] ? dimsX[i][d] : 1)
error("cat: dimension mismatch on dimension", d)
#error("lala $d")
end
end
end
end
Expand Down Expand Up @@ -682,6 +708,24 @@ function cat(catdim::Int, A::AbstractArray...)
ndimsA = map(ndims, A)
d_max = max(ndimsA)

if catdim > d_max + 1
for i=1:nargs
if dimsA[1] != dimsA[i]
error("cat: all inputs must have same dimensions when concatenating along a higher dimension");
end
end
elseif nargs >= 2
for d=1:d_max
if d == catdim; continue; end
len = d <= ndimsA[1] ? dimsA[1][d] : 1
for i = 2:nargs
if len != (d <= ndimsA[i] ? dimsA[i][d] : 1)
error("cat: dimension mismatch on dimension ", d)
end
end
end
end

cat_ranges = ntuple(nargs, i->(catdim <= ndimsA[i] ? dimsA[i][catdim] : 1))

function compute_dims(d)
Expand Down Expand Up @@ -726,10 +770,30 @@ function hvcat{T}(rows::(Size...), as::AbstractMatrix{T}...)
nc = mapreduce(+, a->size(a,2), as[1:rows[1]])::Size
nr = 0

a_index = cumsum(rows)
a = 1
for i = 1:nbr
nr += size(as[a],1)
a += rows[i]
#error checking
#num rows in each block row
if rows[i] > 1
first_row_index = (i == 1 ? 1 : a_index[i-1] + 1)
first_height = size(as[first_row_index],1)
for j = (first_row_index + 1):a_index[i]
if size(as[j],1) != first_height
error("hvcat: mismatched height in block row ", i)
end
end
end

if i != 1
#check num columns
nci = mapreduce(+, b->size(b,2), as[(a_index[i-1]+1):a_index[i]])::Size
if nc != nci
error("hvcat: block row ", i, " has mismatched number of columns")
end
end
end

out = similar(as[1], T, nr, nc)
Expand All @@ -755,6 +819,13 @@ hvcat(rows::(Size...)) = []
function hvcat{T<:Number}(rows::(Size...), xs::T...)
nr = length(rows)
nc = rows[1]
#error check
for i = 2:nr
if nc != rows[i]
error("hvcat: row ", i, " has mismatched number of columns")
end
end

a = Array(T, nr, nc)
k = 1
for i=1:nr
Expand All @@ -781,6 +852,12 @@ end
function hvcat(rows::(Size...), xs::Number...)
nr = length(rows)
nc = rows[1]
#error check
for i = 2:nr
if nc != rows[i]
error("hvcat: row ", i, " has mismatched number of columns")
end
end
T = typeof(xs[1])
for i=2:length(xs)
T = promote_type(T,typeof(xs[i]))
Expand Down

0 comments on commit c82e159

Please sign in to comment.