Skip to content

Commit

Permalink
include review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored and pull[bot] committed Jul 19, 2023
1 parent 8d84c4a commit b6189c7
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,19 +419,18 @@ promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix

_us2number(A) = A
_us2number(J::UniformScaling) = J.λ
szfun(::UniformScaling, _) = -1
szfun(A, dim) = (require_one_based_indexing(A); return size(A, dim))

for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"))
@eval begin
@inline $f(A::Union{AbstractVecOrMat,UniformScaling}...) = $_f(A...)
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $f(map(_us2number, A)...)
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $_f(A...)
function $_f(A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
n = -1
for a in A
if !isa(a, UniformScaling)
require_one_based_indexing(a)
na = size(a,$dim)
sizes = map(a -> szfun(a, $dim), A)
for na in sizes
if na != -1
n >= 0 && n != na &&
throw(DimensionMismatch(string("number of ", $name,
" of each array must match (got ", n, " and ", na, ")")))
Expand All @@ -455,9 +454,9 @@ function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScali
j = 0
for i = 1:nr # infer UniformScaling sizes from row counts, if possible:
ni = -1 # number of rows in this block-row, -1 indicates unknown
for k = 1:rows[i]
if !isa(A[j+k], UniformScaling)
na = size(A[j+k], 1)
sizes = map(a -> szfun(a, 1), A[j+1:j+rows[i]])
for na in sizes
if na != -1
ni >= 0 && ni != na &&
throw(DimensionMismatch("mismatch in number of rows"))
ni = na
Expand Down

0 comments on commit b6189c7

Please sign in to comment.