Skip to content

Commit

Permalink
flatten recursive call to cat_shape (JuliaLang#36838)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnychen94 committed Aug 14, 2020
1 parent 1a47fce commit 51be63c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
13 changes: 9 additions & 4 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1535,9 +1535,14 @@ cat_indices(A::AbstractArray, d) = axes(A, d)
cat_similar(A, T, shape) = Array{T}(undef, shape)
cat_similar(A::AbstractArray, T, shape) = similar(A, T, shape)

cat_shape(dims, shape::Tuple) = shape
@inline cat_shape(dims, shape::Tuple, nshape::Tuple, shapes::Tuple...) =
cat_shape(dims, _cshp(1, dims, shape, nshape), shapes...)
cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape
function cat_shape(dims, shapes::Tuple)
out_shape = ()
for s in shapes
out_shape = _cshp(1, dims, out_shape, s)
end
return out_shape
end

_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
Expand Down Expand Up @@ -1581,7 +1586,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
@inline function _cat_t(dims, T::Type, X...)
catdims = dims2cat(dims)
shape = cat_shape(catdims, (), map(cat_size, X)...)
shape = cat_shape(catdims, map(cat_size, X))
A = cat_similar(X[1], T, shape)
if count(!iszero, catdims) > 1
fill!(A, zero(T))
Expand Down
2 changes: 1 addition & 1 deletion base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1810,7 +1810,7 @@ end
# general case, specialized for BitArrays and Integers
function _cat(dims::Integer, X::Union{BitArray, Bool}...)
catdims = dims2cat(dims)
shape = cat_shape(catdims, (), map(cat_size, X)...)
shape = cat_shape(catdims, map(cat_size, X))
A = falses(shape)
return __cat(A, shape, catdims, X...)
end
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,8 @@ function tuple_type_cons(::Type{S}, ::Type{T}) where T<:Tuple where S
Tuple{S, T.parameters...}
end

# these were internal functions, but some packages seem to be relying on them
@deprecate cat_shape(dims, shape::Tuple{}, shapes::Tuple...) cat_shape(dims, shapes)
cat_shape(dims, shape::Tuple{}) = () # make sure `cat_shape(dims, ())` do not recursively calls itself

# END 1.6 deprecations

0 comments on commit 51be63c

Please sign in to comment.