Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sparse] Add sparse matrix slicing operator implementation #6208

Merged
merged 31 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
87448c8
Fix description and not emplement error.
xiangyuzhi Aug 23, 2023
83d060d
Fix description and not emplement error.
xiangyuzhi Aug 23, 2023
ead3874
merge API to one 'select'
xiangyuzhi Aug 28, 2023
9a6cd71
fix
xiangyuzhi Aug 29, 2023
35f6d6e
fix typecheck
xiangyuzhi Aug 29, 2023
6f3f1ee
update comments
xiangyuzhi Aug 29, 2023
7975131
fix type error
xiangyuzhi Aug 29, 2023
9ba1046
Update python/dgl/sparse/sparse_matrix.py
xiangyuzhi Aug 29, 2023
2564ddb
try to fix CI error
xiangyuzhi Aug 29, 2023
22f0daa
Update python/dgl/sparse/sparse_matrix.py
xiangyuzhi Aug 29, 2023
758ecc4
fix comment and CI
xiangyuzhi Aug 29, 2023
f4ce204
fix comments and CI
xiangyuzhi Aug 29, 2023
2f8b625
split API
xiangyuzhi Aug 29, 2023
710fb9f
fix CI
xiangyuzhi Aug 29, 2023
24e04cc
fix input type
xiangyuzhi Aug 30, 2023
94b296b
fix description
xiangyuzhi Aug 30, 2023
2289369
Add one row slice implement
xiangyuzhi Aug 25, 2023
14bc4c0
add col and unit test
xiangyuzhi Aug 28, 2023
eb37ecc
little fix
xiangyuzhi Aug 28, 2023
b792fd9
New API implementation
xiangyuzhi Aug 29, 2023
6613c37
fix description
xiangyuzhi Aug 29, 2023
0891d60
update new API
xiangyuzhi Aug 30, 2023
84dc943
fix update
xiangyuzhi Aug 30, 2023
2084dc0
fix bug and extend test
xiangyuzhi Sep 2, 2023
9d1f1bd
fix urange select
xiangyuzhi Sep 2, 2023
fd3a735
concise code and extend test
xiangyuzhi Sep 4, 2023
957f771
lint fix
xiangyuzhi Sep 4, 2023
1ddf91a
optimize code and change test
xiangyuzhi Sep 4, 2023
7caa5e5
add coo test and remove coo slice
xiangyuzhi Sep 5, 2023
822b21f
fix comments
xiangyuzhi Sep 5, 2023
539122c
add comment
xiangyuzhi Sep 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions dgl_sparse/include/sparse/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,51 @@ class SparseMatrix : public torch::CustomClassHolder {
static c10::intrusive_ptr<SparseMatrix> FromDiag(
torch::Tensor value, const std::vector<int64_t>& shape);

/**
* @brief Create a SparseMatrix by selecting rows or columns based on provided
* indices.
*
* This function allows you to create a new SparseMatrix by selecting specific
* rows or columns from the original SparseMatrix based on the provided
* indices. The selection can be performed either row-wise or column-wise,
* determined by the 'dim' parameter.
*
* @param dim Select rows (dim=0) or columns (dim=1).
* @param ids A tensor containing the indices of the selected rows or columns.
*
* @return A new SparseMatrix containing the selected rows or columns.
*
* @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
* (for column-wise selection).
* @note The 'ids' tensor should contain valid indices within the range of the
* original SparseMatrix's dimensions.
*/
c10::intrusive_ptr<SparseMatrix> IndexSelect(int64_t dim, torch::Tensor ids);

/**
* @brief Create a SparseMatrix by selecting a range of rows or columns based
* on provided indices.
*
* This function allows you to create a new SparseMatrix by selecting a range
* of specific rows or columns from the original SparseMatrix based on the
* provided indices. The selection can be performed either row-wise or
* column-wise, determined by the 'dim' parameter.
*
* @param dim Select rows (dim=0) or columns (dim=1).
* @param start The starting index (inclusive) of the range.
* @param end The ending index (exclusive) of the range.
*
* @return A new SparseMatrix containing the selected range of rows or
* columns.
*
* @note The 'dim' parameter should be either 0 (for row-wise selection) or 1
* (for column-wise selection).
* @note The 'start' and 'end' indices should be valid indices within
* the valid range of the original SparseMatrix's dimensions.
*/
c10::intrusive_ptr<SparseMatrix> RangeSelect(
frozenbugs marked this conversation as resolved.
Show resolved Hide resolved
int64_t dim, int64_t start, int64_t end);

/**
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix
Expand Down
4 changes: 3 additions & 1 deletion dgl_sparse/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("transpose", &SparseMatrix::Transpose)
.def("coalesce", &SparseMatrix::Coalesce)
.def("has_duplicate", &SparseMatrix::HasDuplicate)
.def("is_diag", &SparseMatrix::HasDiag);
.def("is_diag", &SparseMatrix::HasDiag)
.def("index_select", &SparseMatrix::IndexSelect)
.def("range_select", &SparseMatrix::RangeSelect);
m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC)
Expand Down
45 changes: 45 additions & 0 deletions dgl_sparse/src/sparse_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "./utils.h"

namespace dgl {
namespace sparse {

Expand Down Expand Up @@ -122,6 +124,49 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
return SparseMatrix::FromDiagPointer(diag, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::IndexSelect(
int64_t dim, torch::Tensor ids) {
auto id_array = TorchTensorToDGLArray(ids);
bool rowwise = dim == 0;
auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), id_array);
auto slice_value =
this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
// To prevent potential errors in future conversions to the COO format,
// where this array might be used as an initialization array for
// constructing COO representations, it is necessary to clear this array.
slice_csr.data = dgl::aten::NullArray();
auto ret = CSRFromOldDGLCSR(slice_csr);
if (rowwise) {
return SparseMatrix::FromCSRPointer(
ret, slice_value, {ret->num_rows, ret->num_cols});
} else {
return SparseMatrix::FromCSCPointer(
ret, slice_value, {ret->num_cols, ret->num_rows});
}
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::RangeSelect(
int64_t dim, int64_t start, int64_t end) {
bool rowwise = dim == 0;
auto csr = rowwise ? this->CSRPtr() : this->CSCPtr();
auto slice_csr = dgl::aten::CSRSliceRows(CSRToOldDGLCSR(csr), start, end);
auto slice_value =
this->value().index_select(0, DGLArrayToTorchTensor(slice_csr.data));
// To prevent potential errors in future conversions to the COO format,
// where this array might be used as an initialization array for
// constructing COO representations, it is necessary to clear this array.
slice_csr.data = dgl::aten::NullArray();
auto ret = CSRFromOldDGLCSR(slice_csr);
if (rowwise) {
return SparseMatrix::FromCSRPointer(
ret, slice_value, {ret->num_rows, ret->num_cols});
} else {
return SparseMatrix::FromCSCPointer(
ret, slice_value, {ret->num_cols, ret->num_rows});
}
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK(
Expand Down
14 changes: 11 additions & 3 deletions python/dgl/sparse/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def index_select(self, dim: int, index: torch.Tensor):
dim : int
The dim to select from matrix, should be 0 or 1. `dim = 0` for
rowwise selection and `dim = 1` for columnwise selection.
index : tensor.Tensor
index : torch.Tensor
The selection index indicates which IDs from the `dim` should
be chosen from the matrix.
Note that duplicated ids are allowed.
Expand Down Expand Up @@ -527,7 +527,7 @@ def index_select(self, dim: int, index: torch.Tensor):
if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, torch.Tensor):
raise NotImplementedError
return SparseMatrix(self.c_sparse_matrix.index_select(dim, index))
raise TypeError(f"{type(index).__name__} is unsupported input type.")

def range_select(self, dim: int, index: slice):
Expand Down Expand Up @@ -575,7 +575,15 @@ def range_select(self, dim: int, index: slice):
if dim not in (0, 1):
raise ValueError("The selection dimension should be 0 or 1.")
if isinstance(index, slice):
raise NotImplementedError
if index.step not in (None, 1):
raise NotImplementedError(
"Slice with step other than 1 are not supported yet."
)
start = 0 if index.start is None else index.start
end = index.stop
frozenbugs marked this conversation as resolved.
Show resolved Hide resolved
return SparseMatrix(
self.c_sparse_matrix.range_select(dim, start, end)
)
raise TypeError(f"{type(index).__name__} is unsupported input type.")


Expand Down
54 changes: 54 additions & 0 deletions tests/python/pytorch/sparse/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@
val_like,
)

from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)


def _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape):
with warnings.catch_warnings():
Expand Down Expand Up @@ -450,6 +458,52 @@ def test_has_duplicate():
assert csc_A.has_duplicate()


@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("index", [(0, 1, 3), (1, 2)])
def test_index_select(create_func, shape, dense_dim, select_dim, index):
ctx = F.ctx()
A = create_func(shape, 20, ctx, dense_dim)
index = torch.tensor(index).to(ctx)
A_select = A.index_select(select_dim, index)

dense = sparse_matrix_to_dense(A)
dense_select = torch.index_select(dense, select_dim, index)

A_select_to_dense = sparse_matrix_to_dense(A_select)

assert A_select_to_dense.shape == dense_select.shape
assert torch.allclose(A_select_to_dense, dense_select)


@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("shape", [(5, 5), (6, 4)])
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("rang", [slice(0, 2), slice(1, 3)])
def test_range_select(create_func, shape, dense_dim, select_dim, rang):
ctx = F.ctx()
A = create_func(shape, 20, ctx, dense_dim)
A_select = A.range_select(select_dim, rang)

dense = sparse_matrix_to_dense(A)
if select_dim == 0:
dense_select = dense[rang, :]
else:
dense_select = dense[:, rang]

A_select_to_dense = sparse_matrix_to_dense(A_select)

assert A_select_to_dense.shape == dense_select.shape
assert torch.allclose(A_select_to_dense, dense_select)


def test_print():
ctx = F.ctx()

Expand Down