Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nick/tenmat docs #294

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
SPTENMAT: Add copy argument to give option to try and optimize views.
* Less critical for sparse tensors, so copy not added to to_sptenmat.
  • Loading branch information
ntjohnson1 committed Dec 16, 2023
commit ae7ef7803c7ba69b6581913755e947b9b2539439
50 changes: 32 additions & 18 deletions pyttb/sptenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__( # noqa: PLR0913
rdims: Optional[np.ndarray] = None,
cdims: Optional[np.ndarray] = None,
tshape: Tuple[int, ...] = (),
copy: bool = True,
):
"""
Construct a :class:`pyttb.sptenmat` from a set of 2D subscripts (subs)
Expand All @@ -50,6 +51,9 @@ def __init__( # noqa: PLR0913
Mapping of column indices.
tshape:
Shape of the original tensor.
copy:
Whether to make a copy of provided data or just reference it.
Skips error checking when just setting reference.

Examples
--------
Expand Down Expand Up @@ -121,11 +125,13 @@ def __init__( # noqa: PLR0913
), "Invalid column index."

# Sum any duplicates
newsubs = subs
newvals = vals
if vals.size == 0:
assert vals.size == 0, "Empty subs requires empty vals"
newsubs = np.array([])
newvals = np.array([])
else:
elif copy:
# Identify only the unique indices
newsubs, loc = np.unique(subs, axis=0, return_inverse=True)
# Sum the corresponding values
Expand All @@ -134,19 +140,26 @@ def __init__( # noqa: PLR0913
loc, np.squeeze(vals, axis=1), size=newsubs.shape[0], func=sum
)

# Find the nonzero indices of the new values
nzidx = np.nonzero(newvals)
newsubs = newsubs[nzidx]
# None index to convert from row back to column vector
newvals = newvals[nzidx]
if newvals.size > 0:
newvals = newvals[:, None]

self.tshape = tshape
self.rdims = rdims.copy().astype(int)
self.cdims = cdims.copy().astype(int)
self.subs = newsubs
self.vals = newvals
if copy:
# Find the nonzero indices of the new values
nzidx = np.nonzero(newvals)
newsubs = newsubs[nzidx]
# None index to convert from row back to column vector
newvals = newvals[nzidx]
if newvals.size > 0:
newvals = newvals[:, None]

self.tshape = tshape
self.rdims = rdims.copy().astype(int)
self.cdims = cdims.copy().astype(int)
self.subs = newsubs
self.vals = newvals
else:
self.tshape = tshape
self.rdims = rdims
self.cdims = cdims
self.subs = newsubs
self.vals = newvals

@classmethod
def from_array(
Expand Down Expand Up @@ -227,11 +240,12 @@ def copy(self) -> sptenmat:
False
"""
return sptenmat(
self.subs.copy(),
self.vals.copy(),
self.rdims.copy(),
self.cdims.copy(),
self.subs,
self.vals,
self.rdims,
self.cdims,
self.tshape,
copy=True,
)

def __deepcopy__(self, memo):
Expand Down
2 changes: 1 addition & 1 deletion pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
subs: Optional[np.ndarray] = None,
vals: Optional[np.ndarray] = None,
shape: Optional[Tuple[int, ...]] = None,
copy=True,
copy: bool = True,
):
"""
Construct a :class:`pyttb.sptensor` from a set of `subs` (subscripts),
Expand Down
5 changes: 5 additions & 0 deletions tests/test_sptenmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def test_sptenmat_initialization_from_data(sample_sptenmat):
np.testing.assert_array_equal(S.tshape, tshape)
np.testing.assert_array_equal(S.shape, shape)

# Constructor from data as reference
S = ttb.sptenmat(subs, vals, rdims, cdims, tshape, copy=False)
assert np.may_share_memory(S.subs, subs)
assert np.may_share_memory(S.vals, vals)

# Constructor from data: rdims, cdims, and tshape
S = ttb.sptenmat(rdims=rdims, cdims=cdims, tshape=tshape)
np.testing.assert_array_equal(S.subs, np.array([]))
Expand Down