Skip to content

Commit

Permalink
make one(::AbstractMatrix) use similar instead of zeros (JuliaL…
Browse files Browse the repository at this point in the history
  • Loading branch information
nsajko committed Apr 20, 2024
1 parent 2aa55e2 commit aad7245
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
17 changes: 17 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,23 @@ zero(x::AbstractArray{T}) where {T<:Number} = fill!(similar(x, typeof(zero(T))),
zero(x::AbstractArray{S}) where {S<:Union{Missing, Number}} = fill!(similar(x, typeof(zero(S))), zero(S))
zero(x::AbstractArray) = map(zero, x)

function _one(unit::T, mat::AbstractMatrix) where {T}
(rows, cols) = axes(mat)
(length(rows) == length(cols)) ||
throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
zer = zero(unit)::T
require_one_based_indexing(mat)
I = similar(mat, T)
fill!(I, zer)
for i rows
I[i, i] = unit
end
I
end

one(x::AbstractMatrix{T}) where {T} = _one(one(T), x)
oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)

## iteration support for arrays by iterating over `eachindex` in the array ##
# Allows fast iteration by default for both IndexLinear and IndexCartesian arrays

Expand Down
15 changes: 0 additions & 15 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -598,21 +598,6 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
end
end

function _one(unit::T, x::AbstractMatrix) where T
require_one_based_indexing(x)
m,n = size(x)
m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
# Matrix{T}(I, m, m)
I = zeros(T, m, m)
for i in 1:m
I[i,i] = unit
end
I
end

one(x::AbstractMatrix{T}) where {T} = _one(one(T), x)
oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)

## Conversions ##

convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a)::T
Expand Down
17 changes: 17 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2119,3 +2119,20 @@ end
end
end
end

@testset "one" begin
@test one([1 2; 3 4]) == [1 0; 0 1]
@test one([1 2; 3 4]) isa Matrix{Int}

struct Mat <: AbstractMatrix{Int}
p::Matrix{Int}
end
Base.size(m::Mat) = size(m.p)
Base.IndexStyle(::Type{<:Mat}) = IndexLinear()
Base.getindex(m::Mat, i::Int) = m.p[i]
Base.setindex!(m::Mat, v, i::Int) = m.p[i] = v
Base.similar(::Mat, ::Type{Int}, size::NTuple{2,Int}) = Mat(Matrix{Int}(undef, size))

@test one(Mat([1 2; 3 4])) == Mat([1 0; 0 1])
@test one(Mat([1 2; 3 4])) isa Mat
end

0 comments on commit aad7245

Please sign in to comment.