Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nick/tenmat docs #294

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add copy constructor for tenmat.
  • Loading branch information
ntjohnson1 committed Nov 24, 2023
commit e5d45ec081e09caf6a14a366cb695c9bf5ed0546
56 changes: 39 additions & 17 deletions pyttb/sptenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def from_data( # noqa: PLR0913
newsubs, loc = np.unique(subs, axis=0, return_inverse=True)
# Sum the corresponding values
# Squeeze to convert from column vector to row vector
newvals = accumarray(loc, np.squeeze(vals), size=newsubs.shape[0], func=sum)
newvals = accumarray(
loc, np.squeeze(vals, axis=1), size=newsubs.shape[0], func=sum
)

# Find the nonzero indices of the new values
nzidx = np.nonzero(newvals)
Expand All @@ -115,27 +117,16 @@ def from_data( # noqa: PLR0913
@classmethod
def from_tensor_type( # noqa: PLR0912
cls,
source: Union[ttb.sptensor, ttb.sptenmat],
source: Union[ttb.sptensor],
rdims: Optional[np.ndarray] = None,
cdims: Optional[np.ndarray] = None,
cdims_cyclic: Optional[
Union[Literal["fc"], Literal["bc"], Literal["t"]]
] = None,
):
valid_sources = (sptenmat, ttb.sptensor)
assert isinstance(source, valid_sources), (
"Can only generate sptenmat from "
f"{[src.__name__ for src in valid_sources]} but received {type(source)}."
assert isinstance(source, ttb.sptensor), (
"Can only generate sptenmat from " f"sptensor but received {type(source)}."
)
# Copy Constructor
if isinstance(source, sptenmat):
return cls().from_data(
source.subs.copy(),
source.vals.copy(),
source.rdims.copy(),
source.cdims.copy(),
source.tshape,
)

if isinstance(source, ttb.sptensor):
n = source.ndims
Expand Down Expand Up @@ -242,6 +233,37 @@ def from_array(
subs = np.vstack(array.nonzero()).transpose()
return ttb.sptenmat.from_data(subs, vals, rdims, cdims, tshape)

def copy(self) -> sptenmat:
"""
Return a deep copy of the :class:`pyttb.sptenmat`.

Examples
--------
Create a :class:`pyttb.sptenmat` (ST1) and make a deep copy. Verify
the deep copy (ST3) is not just a reference (like ST2) to the original.

>>> S1 = ttb.sptensor(shape=(2,2))
>>> S1[0,0] = 1
>>> ST1 = ttb.sptenmat.from_tensor_type(S1, np.array([0]))
>>> ST2 = ST1
>>> ST3 = ST1.copy()
>>> ST1[0,0] = 3
>>> ST1.to_sptensor().isequal(ST2.to_sptensor())
True
>>> ST1.to_sptensor().isequal(ST3.to_sptensor())
False
"""
return sptenmat().from_data(
self.subs.copy(),
self.vals.copy(),
self.rdims.copy(),
self.cdims.copy(),
self.tshape,
)

def __deepcopy__(self, memo):
return self.copy()

def to_sptensor(self) -> ttb.sptensor:
"""
Contruct a :class:`pyttb.sptensor` from `:class:pyttb.sptenmat`
Expand Down Expand Up @@ -312,13 +334,13 @@ def __pos__(self):
"""
Unary plus operator (+).
"""
return self.from_tensor_type(self)
return self.copy()

def __neg__(self):
"""
Unary minus operator (-).
"""
result = self.from_tensor_type(self)
result = self.copy()
result.vals *= -1
return result

Expand Down
60 changes: 40 additions & 20 deletions pyttb/tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,6 @@ def from_tensor_type( # noqa: PLR0912
-------
Constructed tenmat
"""
# Case 0b: Copy Constructor
if isinstance(source, tenmat):
# Create tenmat
tenmatInstance = cls()
tenmatInstance.tshape = source.tshape
tenmatInstance.rindices = source.rindices.copy()
tenmatInstance.cindices = source.cindices.copy()
tenmatInstance.data = source.data.copy()
return tenmatInstance

# Case III: Convert a tensor to a tenmat
if isinstance(source, ttb.tensor):
n = source.ndims
Expand Down Expand Up @@ -210,9 +200,39 @@ def from_tensor_type( # noqa: PLR0912
tenmatInstance.data = data.copy()
return tenmatInstance
raise ValueError(
f"Can only create tenmat from tensor or tenmat but recieved {type(source)}"
f"Can only create tenmat from tensor but recieved {type(source)}"
)

def copy(self) -> tenmat:
"""
Return a deep copy of the :class:`pyttb.tenmat`.

Examples
--------
Create a :class:`pyttb.tenmat` (TM1) and make a deep copy. Verify
the deep copy (TM3) is not just a reference (like TM2) to the original.

>>> T1 = ttb.tensor(np.ones((3,2)))
>>> TM1 = ttb.tenmat.from_tensor_type(T1, np.array([0]))
>>> TM2 = TM1
>>> TM3 = TM1.copy()
>>> TM1[0,0] = 3
>>> TM1[0,0] == TM2[0,0]
True
>>> TM1[0,0] == TM3[0,0]
False
"""
# Create tenmat
tenmatInstance = tenmat()
tenmatInstance.tshape = self.tshape
tenmatInstance.rindices = self.rindices.copy()
tenmatInstance.cindices = self.cindices.copy()
tenmatInstance.data = self.data.copy()
return tenmatInstance

def __deepcopy__(self, memo):
return self.copy()

def to_tensor(self) -> ttb.tensor:
"""Return copy of tenmat data as a tensor"""
# RESHAPE TENSOR-AS-MATRIX
Expand Down Expand Up @@ -309,7 +329,7 @@ def __mul__(self, other):
"""
# One argument is a scalar
if np.isscalar(other):
Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = Z.data * other
return Z
if isinstance(other, tenmat):
Expand Down Expand Up @@ -370,15 +390,15 @@ def __add__(self, other):

# One argument is a scalar
if np.isscalar(other):
Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = Z.data + other
return Z
if isinstance(other, tenmat):
# Check that data shapes agree
if not self.shape == other.shape:
assert False, "tenmat shape mismatch."

Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = Z.data + other.data
return Z
assert False, "tenmat addition only valid with scalar or tenmat objects."
Expand Down Expand Up @@ -412,15 +432,15 @@ def __sub__(self, other):

# One argument is a scalar
if np.isscalar(other):
Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = Z.data - other
return Z
if isinstance(other, tenmat):
# Check that data shapes agree
if not self.shape == other.shape:
assert False, "tenmat shape mismatch."

Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = Z.data - other.data
return Z
assert False, "tenmat subtraction only valid with scalar or tenmat objects."
Expand All @@ -440,15 +460,15 @@ def __rsub__(self, other):

# One argument is a scalar
if np.isscalar(other):
Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = other - Z.data
return Z
if isinstance(other, tenmat):
# Check that data shapes agree
if not self.shape == other.shape:
assert False, "tenmat shape mismatch."

Z = ttb.tenmat.from_tensor_type(self)
Z = self.copy()
Z.data = other.data - Z.data
return Z
assert False, "tenmat subtraction only valid with scalar or tenmat objects."
Expand All @@ -463,7 +483,7 @@ def __pos__(self):
copy of tenmat
"""

T = ttb.tenmat.from_tensor_type(self)
T = self.copy()

return T

Expand All @@ -477,7 +497,7 @@ def __neg__(self):
copy of tenmat
"""

T = ttb.tenmat.from_tensor_type(self)
T = self.copy()
T.data = -1 * T.data

return T
Expand Down
12 changes: 11 additions & 1 deletion tests/test_sptenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from copy import deepcopy

import numpy as np
import pytest
from scipy import sparse
Expand Down Expand Up @@ -134,7 +136,15 @@ def test_sptenmat_initialization_from_tensor_type(
params, sptenmatInstance = sample_sptenmat
params3, sptensorInstance = sample_sptensor_3way
# Copy constructor
S = ttb.sptenmat.from_tensor_type(sptenmatInstance)
S = sptenmatInstance.copy()
assert S is not sptenmatInstance
np.testing.assert_array_equal(S.subs, sptenmatInstance.subs)
np.testing.assert_array_equal(S.vals, sptenmatInstance.vals)
np.testing.assert_array_equal(S.rdims, sptenmatInstance.rdims)
np.testing.assert_array_equal(S.cdims, sptenmatInstance.cdims)
np.testing.assert_array_equal(S.tshape, sptenmatInstance.tshape)

S = deepcopy(sptenmatInstance)
assert S is not sptenmatInstance
np.testing.assert_array_equal(S.subs, sptenmatInstance.subs)
np.testing.assert_array_equal(S.vals, sptenmatInstance.vals)
Expand Down
13 changes: 11 additions & 2 deletions tests/test_tenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

from copy import deepcopy

import numpy as np
import pytest

Expand Down Expand Up @@ -231,7 +233,14 @@ def test_tenmat_initialization_from_tensor_type(
data = params["data"]

# Copy Constructor
tenmatCopy = ttb.tenmat.from_tensor_type(tenmatInstance)
tenmatCopy = tenmatInstance.copy()
assert (tenmatCopy.data == data).all()
assert (tenmatCopy.rindices == rdims).all()
assert (tenmatCopy.cindices == cdims).all()
assert tenmatCopy.shape == data.shape
assert tenmatCopy.tshape == tshape

tenmatCopy = deepcopy(tenmatInstance)
assert (tenmatCopy.data == data).all()
assert (tenmatCopy.rindices == rdims).all()
assert (tenmatCopy.cindices == cdims).all()
Expand Down Expand Up @@ -427,7 +436,7 @@ def test_tenmat__setitem__():
tenmatInstance = ttb.tenmat.from_tensor_type(tensorInstance, rdims=np.array([0, 1]))

# single element -> scalar
tenmatInstance2 = ttb.tenmat.from_tensor_type(tenmatInstance)
tenmatInstance2 = tenmatInstance.copy()
for i in range(4):
for j in range(4):
tenmatInstance2[i, j] = i * 4 + j + 10
Expand Down