Skip to content

Commit

Permalink
Add indirection to inplace matrix scaling (#52840)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Jan 10, 2024
1 parent bf13a56 commit 124ce94
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
10 changes: 8 additions & 2 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,21 @@ match the length of the second, $(length(X))."))
C
end

@inline function mul!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
@inline mul!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number) =
_lscale_add!(C, s, X, alpha, beta)

@inline function _lscale_add!(C::AbstractArray, s::Number, X::AbstractArray, alpha::Number, beta::Number)
if axes(C) == axes(X)
C .= (s .* X) .*ₛ alpha .+ C .*ₛ beta
else
generic_mul!(C, s, X, MulAddMul(alpha, beta))
end
return C
end
@inline function mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
@inline mul!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number) =
_rscale_add!(C, X, s, alpha, beta)

@inline function _rscale_add!(C::AbstractArray, X::AbstractArray, s::Number, alpha::Number, beta::Number)
if axes(C) == axes(X)
C .= (X .* s) .*ₛ alpha .+ C .*ₛ beta
else
Expand Down
6 changes: 2 additions & 4 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,9 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular
return A
end

# Define `mul!` for (Unit){Upper,Lower}Triangular matrices times a number.
# be permissive here and require compatibility later in _triscale!
@inline mul!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
_triscale!(A, B, C, MulAddMul(alpha, beta))
@inline mul!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =
@inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =
_triscale!(A, B, C, MulAddMul(alpha, beta))

function checksize1(A, B)
Expand Down

0 comments on commit 124ce94

Please sign in to comment.