[Sparse] Add sparse matrix slicing operator implementation (dmlc#6208)
Co-authored-by: Hongzhi (Steve), Chen <[email protected]>
2 people authored and DominikaJedynak committed Mar 12, 2024
1 parent e74ab92 commit 6d58986
Showing 5 changed files with 158 additions and 4 deletions.
45 changes: 45 additions & 0 deletions dgl_sparse/include/sparse/sparse_matrix.h
@@ -137,6 +137,51 @@ class SparseMatrix : public torch::CustomClassHolder {
static c10::intrusive_ptr<SparseMatrix> FromDiag(
torch::Tensor value, const std::vector<int64_t>& shape);

/**
* @brief Create a SparseMatrix by selecting rows or columns based on provided
* indices.
*
* This function allows you to create a new SparseMatrix by selecting specific
* rows or columns from the original SparseMatrix based on the provided
* indices. The selection can be performed either row-wise or column-wise,
* determined by the 'dim' parameter.
*
* @param dim Select rows (dim=0) or columns (dim=1).
* @param ids A tensor containing the indices of the selected rows or columns.
*
* @return A new SparseMatrix containing the selected rows or columns.
*
* @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
* (for column-wise selection).
* @note The 'ids' tensor should contain valid indices within the range of the
* original SparseMatrix's dimensions.
*/
c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);
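For intuition, here is a minimal sketch (illustrative only, not part of this commit) of the dense-equivalent semantics of IndexSelect, which is also how the new test below checks the operator: selecting along dim=0 picks rows, dim=1 picks columns.

# Illustrative only: dense-equivalent behaviour of IndexSelect.
import torch

dense = torch.tensor([[1.0, 0.0, 2.0],
                      [0.0, 3.0, 0.0],
                      [4.0, 0.0, 5.0]])
ids = torch.tensor([0, 2])

rows = torch.index_select(dense, 0, ids)  # rows 0 and 2, shape (2, 3)
cols = torch.index_select(dense, 1, ids)  # columns 0 and 2, shape (3, 2)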

/**
 * @brief Create a SparseMatrix by selecting a contiguous range of rows or
 * columns bounded by the provided start and end indices.
 *
 * This function allows you to create a new SparseMatrix by selecting a
 * contiguous range of rows or columns from the original SparseMatrix, bounded
 * by the provided start (inclusive) and end (exclusive) indices. The selection
 * can be performed either row-wise or column-wise, determined by the 'dim'
 * parameter.
*
* @param dim Select rows (dim=0) or columns (dim=1).
* @param start The starting index (inclusive) of the range.
* @param end The ending index (exclusive) of the range.
*
* @return A new SparseMatrix containing the selected range of rows or
* columns.
*
* @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
* (for column-wise selection).
 * @note The 'start' and 'end' indices should fall within the valid range of
 * the original SparseMatrix's dimensions.
*/
c10::intrusive_ptr<SparseMatrix> RangeSelect(
int64_t dim, int64_t start, int64_t end);
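Similarly, a small sketch (not part of this commit) of the dense-equivalent behaviour, with start inclusive and end exclusive, i.e. half-open slicing of the dense form.

# Illustrative only: RangeSelect(dim, start, end) matches dense[start:end, :]
# for dim=0 and dense[:, start:end] for dim=1.
import torch

dense = torch.arange(12, dtype=torch.float32).reshape(4, 3)
row_range = dense[1:3, :]   # rows 1 and 2
col_range = dense[:, 0:2]   # columns 0 and 1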

/**
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix
4 changes: 3 additions & 1 deletion dgl_sparse/src/python_binding.cc
@@ -33,7 +33,9 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("transpose", &SparseMatrix::Transpose)
.def("coalesce", &SparseMatrix::Coalesce)
.def("has_duplicate", &SparseMatrix::HasDuplicate)
.def("is_diag", &SparseMatrix::HasDiag);
.def("is_diag", &SparseMatrix::HasDiag)
.def("index_select", &SparseMatrix::IndexSelect)
.def("range_select", &SparseMatrix::RangeSelect);
m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC)
45 changes: 45 additions & 0 deletions dgl_sparse/src/sparse_matrix.cc
@@ -12,6 +12,8 @@
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "./utils.h"

namespace dgl {
namespace sparse {

@@ -122,6 +124,49 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
return SparseMatrix::FromDiagPointer(diag, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
int64_t dim, torch::Tensor ids) {
auto id_array = TorchTensorToDGLArray(ids);
bool rowwise = dim == 0;
auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
auto slice_value =
this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
// Clear the data array so that it is not mistaken for an initialization
// array when this matrix is later converted to the COO format.
slice_csr.data = dgl::aten::NullArray();
auto ret = CSRFromOldDGLCSR(slice_csr);
if (rowwise) {
return SparseMatrix::FromCSRPointer(
ret, slice_value, {ret->num_rows, ret->num_cols});
} else {
return SparseMatrix::FromCSCPointer(
ret, slice_value, {ret->num_cols, ret->num_rows});
}
}
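The row-wise path above delegates the actual slicing to dgl::aten::CSRSliceRows (and runs on the CSC representation when dim == 1). Below is a rough, purely illustrative Python sketch of that kind of CSR row gather; csr_gather_rows is a hypothetical helper, not the actual kernel.

# Hypothetical sketch: gather rows `ids` from a CSR structure. `data` records,
# for each kept nonzero, its position in the original value array (the role
# played by slice_csr.data before it is cleared above).
import torch

def csr_gather_rows(indptr, indices, ids):
    new_indptr = [0]
    new_indices, data = [], []
    for r in ids.tolist():
        lo, hi = indptr[r].item(), indptr[r + 1].item()
        new_indices.extend(indices[lo:hi].tolist())
        data.extend(range(lo, hi))
        new_indptr.append(len(new_indices))
    return (
        torch.tensor(new_indptr, dtype=torch.int64),
        torch.tensor(new_indices, dtype=torch.int64),
        torch.tensor(data, dtype=torch.int64),
    )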

c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
int64_t dim, int64_t start, int64_t end) {
bool rowwise = dim == 0;
auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), start, end);
auto slice_value =
this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
// Clear the data array so that it is not mistaken for an initialization
// array when this matrix is later converted to the COO format.
slice_csr.data = dgl::aten::NullArray();
auto ret = CSRFromOldDGLCSR(slice_csr);
if (rowwise) {
return SparseMatrix::FromCSRPointer(
ret, slice_value, {ret->num_rows, ret->num_cols});
} else {
return SparseMatrix::FromCSCPointer(
ret, slice_value, {ret->num_cols, ret->num_rows});
}
}
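For the contiguous range used here, the same idea simplifies: the new indptr is just a slice of the old one rebased to zero. A hypothetical sketch (not the dgl::aten::CSRSliceRows overload itself) follows.

# Hypothetical sketch: take rows [start, end) of a CSR structure. The kept
# nonzeros occupy positions indptr[start]:indptr[end] of the original value
# array (the role of slice_csr.data above).
import torch

def csr_slice_row_range(indptr, indices, start, end):
    lo, hi = indptr[start].item(), indptr[end].item()
    new_indptr = indptr[start:end + 1] - lo
    new_indices = indices[lo:hi]
    data = torch.arange(lo, hi)
    return new_indptr, new_indices, data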

c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK(
14 changes: 11 additions & 3 deletions python/dgl/sparse/sparse_matrix.py
@@ -487,7 +487,7 @@ def index_select(self, dim: int, index: torch.Tensor):
dim : int
The dim to select from matrix, should be 0 or 1. `dim = 0` for
rowwise selection and `dim = 1` for columnwise selection.
-        index : tensor.Tensor
+        index : torch.Tensor
The selection index indicates which IDs from the `dim` should
be chosen from the matrix.
Note that duplicated ids are allowed.
@@ -527,7 +527,7 @@ def index_select(self, dim: int, index: torch.Tensor):
if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, torch.Tensor):
-            raise NotImplementedError
+            return SparseMatrix(self.c_sparse_matrix.index_select(dim, index))
raise TypeError(f"{type(index).__name__} is unsupported input type.")
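A hedged usage sketch of the Python API wired up here; the dgl.sparse.spmatrix constructor and the dglsp alias are assumed from the surrounding library, not introduced by this commit.

# Usage sketch (spmatrix constructor assumed, not part of this diff).
import torch
import dgl.sparse as dglsp

indices = torch.tensor([[0, 1, 2], [1, 2, 0]])  # row ids, col ids
val = torch.tensor([1.0, 2.0, 3.0])
A = dglsp.spmatrix(indices, val, shape=(3, 3))

A_rows = A.index_select(0, torch.tensor([0, 2]))  # keep rows 0 and 2
A_cols = A.index_select(1, torch.tensor([1]))     # keep column 1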

def range_select(self, dim: int, index: slice):
@@ -575,7 +575,15 @@ def range_select(self, dim: int, index: slice):
if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, slice):
-            raise NotImplementedError
+            if index.step not in (None, 1):
+                raise NotImplementedError(
+                    "Slices with a step other than 1 are not supported yet."
+                )
+            start = 0 if index.start is None else index.start
+            end = index.stop
+            return SparseMatrix(
+                self.c_sparse_matrix.range_select(dim, start, end)
+            )
raise TypeError(f"{type(index).__name__} is unsupported input type.")
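Likewise, a hedged usage sketch for range_select; only step-1 slices are accepted and the stop bound is exclusive. The spmatrix constructor is again assumed from the surrounding library.

# Usage sketch (spmatrix constructor assumed, not part of this diff).
import torch
import dgl.sparse as dglsp

indices = torch.tensor([[0, 1, 3], [1, 2, 0]])
A = dglsp.spmatrix(indices, torch.tensor([1.0, 2.0, 3.0]), shape=(4, 3))

A_rows = A.range_select(0, slice(1, 3))  # rows 1 and 2; stop is exclusive
A_cols = A.range_select(1, slice(0, 2))  # columns 0 and 1; step must be 1 or None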


54 changes: 54 additions & 0 deletions tests/python/pytorch/sparse/test_sparse_matrix.py
@@ -18,6 +18,14 @@
val_like,
)

from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)


def _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape):
with warnings.catch_warnings():
@@ -450,6 +458,52 @@ def test_has_duplicate():
assert csc_A.has_duplicate()


@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("index", [(0, 1, 3), (1, 2)])
def test_index_select(create_func, shape, dense_dim, select_dim, index):
ctx = F.ctx()
A = create_func(shape, 20, ctx, dense_dim)
index = torch.tensor(index).to(ctx)
A_select = A.index_select(select_dim, index)

dense = sparse_matrix_to_dense(A)
dense_select = torch.index_select(dense, select_dim, index)

A_select_to_dense = sparse_matrix_to_dense(A_select)

assert A_select_to_dense.shape == dense_select.shape
assert torch.allclose(A_select_to_dense, dense_select)


@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("rang", [slice(0, 2), slice(1, 3)])
def test_range_select(create_func, shape, dense_dim, select_dim, rang):
ctx = F.ctx()
A = create_func(shape, 20, ctx, dense_dim)
A_select = A.range_select(select_dim, rang)

dense = sparse_matrix_to_dense(A)
if select_dim == 0:
dense_select = dense[rang, :]
else:
dense_select = dense[:, rang]

A_select_to_dense = sparse_matrix_to_dense(A_select)

assert A_select_to_dense.shape == dense_select.shape
assert torch.allclose(A_select_to_dense, dense_select)


def test_print():
ctx = F.ctx()
