From f989767cf0b5293931ca41decc4460b5a5dd69ed Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 29 Nov 2022 23:13:15 +0100 Subject: [PATCH 01/17] Init --- .../raft/sparse/solver/detail/lobpcg.cuh | 111 ++++++++++ cpp/include/raft/sparse/solver/lobpcg.cuh | 21 ++ cpp/test/sparse/lobpcg.cu | 206 ++++++++++++++++++ 3 files changed, 338 insertions(+) create mode 100644 cpp/include/raft/sparse/solver/detail/lobpcg.cuh create mode 100644 cpp/include/raft/sparse/solver/lobpcg.cuh create mode 100644 cpp/test/sparse/lobpcg.cu diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh new file mode 100644 index 0000000000..231fbd9eb7 --- /dev/null +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +namespace raft::sparse::solver::detail { + +template +void spmm(raft::spectral::matrix::sparse_matrix_t A, + raft::device_matrix_view B, + raft::device_matrix_view C, + bool transpose_a, + bool transpose_b) +{} + +template +void b_orthonormalize( + const raft::handle_t& handle, + raft::spectral::matrix::sparse_matrix_t A, + raft::device_matrix_view V, + raft::device_matrix_view BV, + bool retInvR=false) +{ + auto V_max = raft::make_device_vector_view(handle, V.extent(1)); + normalization = raft::linalg::reduce(V, axis=0); +} + + + +) + +template +void lobpcg(const raft::handle_t& handle, + // IN + const raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) + raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors + raft::device_vector_view W, // shape=(k) OUT Eigvals + std::optional> B, // shape=(n,n) + std::optional> M, // shape=(n,n) + std::optional> Y, // Constraint matrix shape=(n,Y) + value_t tol=0, + std::uint32_t max_iter=20, + bool largest=true) +{ + cudaStream_t stream = handle.get_stream(); + auto size_y = 0; + if (Y.has_value()) size_y = Y.extent(1); + auto n = X.nrows_; + auto size_x = X.ncols_; + + if ((n - size_y) < (5 * size_x)) + { + // DENSE SOLUTION + return; + } + if (tol == 0) + { + tol = raft::mySqrt(1e-15) * n; + } + // Apply constraints to X + if (Y.has_value()) + { + cusparseDnMatDescr_t denseY; + RAFT_CUSPARSE_TRY(cusparsecreatednmat(&denseY, n, size_y, n, Y.value().data_handle(), CUSPARSE_ORDER_COL))); + auto* ptr_BY = Y.value().data_handle(); + if (B.has_value()) + { + cusparseSpMatDescr_t sparseB; + cusparseDnMatDescr_t dense_BY; + RAFT_CUSPARSE_TRY(cusparsecreatecsr(&sparseB, n, n, B.nnz_, B.row_offsets, B.col_indices_, B.values_)); + auto matrix_BY = raft::make_device_matrix(handle, n, size_y); + RAFT_CUSPARSE_TRY(cusparsecreatednmat(&dense_BY, n, size_y, n, matrix_BY.data_handle(), CUSPARSE_ORDER_COL))); + // B * Y + value_t alpha = 1; + value_t beta = 0; + size_t buff_size = 0; + cusparsespmm_bufferSize(handle.get_cusparse_handle(), + CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, sparseB, denseY, &beta, dense_BY, CUSPARSE_SPMM_ALG_DEFAULT, &buff_size, stream); + rmm::device_uvector dev_buffer(buff_size, stream); + cusparsespmm_bufferSize(handle.get_cusparse_handle(), + CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, sparseB, denseY, &beta, dense_BY, CUSPARSE_SPMM_ALG_DEFAULT, stream); + cusparseDestroyDnMat(dense_B); + cusparseDestroyDnMat(dense_BY); + cusparseDestroySpMat(sparseB); + // CONTINUE + } + + // GramYBY + // ApplyConstraints + } + +} +} // raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/lobpcg.cuh b/cpp/include/raft/sparse/solver/lobpcg.cuh new file mode 100644 index 0000000000..e130c1b369 --- /dev/null +++ b/cpp/include/raft/sparse/solver/lobpcg.cuh @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace raft::sparse::solver { +} \ No newline at end of file diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu new file mode 100644 index 0000000000..d018a6ff2e --- /dev/null +++ b/cpp/test/sparse/lobpcg.cu @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include + + #include + #include + #include + + #include "../test_utils.h" + #include + + #include + #include + + namespace raft { + namespace sparse { + + template + struct CSRMatrixVal { + std::vector row_ind; + std::vector row_ind_ptr; + std::vector values; + }; + + template + struct CSRAddInputs { + CSRMatrixVal matrix_a; + CSRMatrixVal matrix_b; + CSRMatrixVal matrix_verify; + }; + + template + class CSRAddTest : public ::testing::TestWithParam> { + public: + CSRAddTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + ind_a(params.matrix_a.row_ind.size(), stream), + ind_ptr_a(params.matrix_a.row_ind_ptr.size(), stream), + values_a(params.matrix_a.row_ind_ptr.size(), stream), + ind_b(params.matrix_a.row_ind.size(), stream), + ind_ptr_b(params.matrix_b.row_ind_ptr.size(), stream), + values_b(params.matrix_b.row_ind_ptr.size(), stream), + ind_verify(params.matrix_a.row_ind.size(), stream), + ind_ptr_verify(params.matrix_verify.row_ind_ptr.size(), stream), + values_verify(params.matrix_verify.row_ind_ptr.size(), stream), + ind_result(params.matrix_a.row_ind.size(), stream), + ind_ptr_result(params.matrix_verify.row_ind_ptr.size(), stream), + values_result(params.matrix_verify.row_ind_ptr.size(), stream) + { + } + + protected: + void SetUp() override + { + n_rows = params.matrix_a.row_ind.size(); + nnz_a = params.matrix_a.row_ind_ptr.size(); + nnz_b = params.matrix_b.row_ind_ptr.size(); + nnz_result = params.matrix_verify.row_ind_ptr.size(); + } + + void Run() + { + raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows, stream); + raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); + raft::update_device(values_a.data(), params.matrix_a.values.data(), nnz_a, stream); + + raft::update_device(ind_b.data(), params.matrix_b.row_ind.data(), n_rows, stream); + raft::update_device(ind_ptr_b.data(), params.matrix_b.row_ind_ptr.data(), nnz_b, stream); + raft::update_device(values_b.data(), params.matrix_b.values.data(), nnz_b, stream); + + raft::update_device(ind_verify.data(), params.matrix_verify.row_ind.data(), n_rows, stream); + raft::update_device( + ind_ptr_verify.data(), params.matrix_verify.row_ind_ptr.data(), nnz_result, stream); + raft::update_device( + values_verify.data(), params.matrix_verify.values.data(), nnz_result, stream); + + Index_ nnz = linalg::csr_add_calc_inds(ind_a.data(), + ind_ptr_a.data(), + values_a.data(), + nnz_a, + ind_b.data(), + ind_ptr_b.data(), + values_b.data(), + nnz_b, + n_rows, + ind_result.data(), + stream); + + ASSERT_TRUE(nnz == nnz_result); + ASSERT_TRUE(raft::devArrMatch( + ind_verify.data(), ind_result.data(), n_rows, raft::Compare(), stream)); + + linalg::csr_add_finalize(ind_a.data(), + ind_ptr_a.data(), + values_a.data(), + nnz_a, + ind_b.data(), + ind_ptr_b.data(), + values_b.data(), + nnz_b, + n_rows, + ind_result.data(), + ind_ptr_result.data(), + values_result.data(), + stream); + + ASSERT_TRUE(raft::devArrMatch( + ind_ptr_verify.data(), ind_ptr_result.data(), nnz, raft::Compare(), stream)); + ASSERT_TRUE(raft::devArrMatch( + values_verify.data(), values_result.data(), nnz, raft::Compare(), stream)); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + CSRAddInputs params; + Index_ n_rows, nnz_a, nnz_b, nnz_result; + rmm::device_uvector ind_a, ind_b, ind_verify, ind_result, ind_ptr_a, ind_ptr_b, + ind_ptr_verify, ind_ptr_result; + rmm::device_uvector values_a, values_b, values_verify, values_result; + }; + + using CSRAddTestF = CSRAddTest; + TEST_P(CSRAddTestF, Result) { Run(); } + + using CSRAddTestD = CSRAddTest; + TEST_P(CSRAddTestD, Result) { Run(); } + + const std::vector> csradd_inputs_f = { + {{{0, 4, 8, 9}, + {1, 2, 3, 4, 1, 2, 3, 5, 0, 1}, + {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, + {{0, 4, 8, 9}, + {1, 2, 5, 4, 0, 2, 3, 5, 1, 0}, + {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, + {{0, 5, 10, 12}, + {1, 2, 3, 4, 5, 1, 2, 3, 5, 0, 0, 1, 1, 0}, + {2.0, 2.0, 0.5, 1.0, 0.5, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}}}, + }; + const std::vector> csradd_inputs_d = { + {{{0, 4, 8, 9}, + {1, 2, 3, 4, 1, 2, 3, 5, 0, 1}, + {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, + {{0, 4, 8, 9}, + {1, 2, 5, 4, 0, 2, 3, 5, 1, 0}, + {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, + {{0, 5, 10, 12}, + {1, 2, 3, 4, 5, 1, 2, 3, 5, 0, 0, 1, 1, 0}, + {2.0, 2.0, 0.5, 1.0, 0.5, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}}}, + }; + + INSTANTIATE_TEST_CASE_P(SparseAddTest, CSRAddTestF, ::testing::ValuesIn(csradd_inputs_f)); + INSTANTIATE_TEST_CASE_P(SparseAddTest, CSRAddTestD, ::testing::ValuesIn(csradd_inputs_d)); + + } // namespace sparse + } // namespace raft + + + +/* + +a=cupyx.scipy.sparse.random(6,6, 0.8,'csr') +a.indptr = array([ 0, 4, 10, 14, 19, 24, 28], dtype=int32) + +a.indices = array([0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3, + 4, 5, 0, 2, 3, 4], dtype=int32) + +a.data = array([0.37911922, 0.11567201, 0.5135106 , 0.08968836, 0.73450965, + 0.26432646, 0.21985123, 0.74888277, 0.34753734, 0.11204864, + 0.82902676, 0.53023521, 0.24047095, 0.37913592, 0.60975031, + 0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847, + 0.32938903, 0.82477561, 0.20858375, 0.24755519, 0.23677223, + 0.73957246, 0.09050876, 0.86530489]) + +x = np.random.rand(6,2) +x = array([[0.08319983, 0.35005079], + [0.17758466, 0.56035486], + [0.93301819, 0.64176631], + [0.67171826, 0.93904784], + [0.19967821, 0.38935935], + [0.30873092, 0.97182089]]) + +lobpcg(a, x) = (array([2.61153278, 0.85782948]), + array([[-0.38272064, -0.39778489], + [-0.25160901, 0.2539629 ], + [-0.48684676, -0.37506003], + [-0.50752949, 0.72637041], + [-0.43005954, 0.02727131], + [-0.33265696, -0.32900198]])) + */ \ No newline at end of file From 3f83bf44cc8727c0c75d8fbc4a43504403be5edf Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 2 Dec 2022 02:30:15 +0100 Subject: [PATCH 02/17] Add b_orthonormalize and helper fonction --- .../raft/linalg/detail/cusolver_wrappers.hpp | 60 ++++ .../raft/sparse/solver/detail/lobpcg.cuh | 287 +++++++++++---- cpp/include/raft/sparse/solver/lobpcg.cuh | 22 +- cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/lobpcg.cu | 333 +++++++++--------- 5 files changed, 465 insertions(+), 238 deletions(-) diff --git a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp index 3eff920dd8..5772c5e86d 100644 --- a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp @@ -774,6 +774,66 @@ inline cusolverStatus_t cusolverDnpotrf(cusolverDnHandle_t handle, // NOLINT } /** @} */ +/** + * @defgroup potri cusolver potri operations: inverse of a matrix A using Cholesky + * @{ + */ +template +cusolverStatus_t cusolverDnpotri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, T* A, int lda, int* Lwork); +template <> +inline cusolverStatus_t cusolverDnpotri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* A, int lda, int* Lwork) +{ + return cusolverDnSpotri_bufferSize(handle, uplo, n, A, lda, Lwork); +} +template <> +inline cusolverStatus_t cusolverDnpotri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* A, int lda, int* Lwork) +{ + return cusolverDnDpotri_bufferSize(handle, uplo, n, A, lda, Lwork); +} + +template +cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + T* A, + int lda, + T* Workspace, + int Lwork, + int* devInfo, + cudaStream_t stream); +template <> +inline cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + float* A, + int lda, + float* Workspace, + int Lwork, + int* devInfo, + cudaStream_t stream) +{ + RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream)); + return cusolverDnSpotri(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} +template <> +inline cusolverStatus_t cusolverDnpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, + double* A, + int lda, + double* Workspace, + int Lwork, + int* devInfo, + cudaStream_t stream) +{ + RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream)); + return cusolverDnDpotri(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} +/** @} */ + /** * @defgroup potrs cusolver potrs operations * @{ diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 231fbd9eb7..b9b604642b 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -16,96 +16,233 @@ #pragma once +#include #include +#include +#include +#include +#include +#include +#include +#include + namespace raft::sparse::solver::detail { +// C = A * B template -void spmm(raft::spectral::matrix::sparse_matrix_t A, +void spmm(const raft::handle_t& handle, + raft::spectral::matrix::sparse_matrix_t A, raft::device_matrix_view B, raft::device_matrix_view C, - bool transpose_a, - bool transpose_b) -{} + bool transpose_a = false, + bool transpose_b = false) +{ + auto stream = handle.get_stream(); + auto* A_values_ = const_cast(A.values_); + auto* A_row_offsets_ = const_cast(A.row_offsets_); + auto* A_col_indices_ = const_cast(A.col_indices_); + cusparseSpMatDescr_t sparse_A; + cusparseDnMatDescr_t dense_B; + cusparseDnMatDescr_t dense_C; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &sparse_A, A.nrows_, A.ncols_, A.nnz_, A_row_offsets_, A_col_indices_, A_values_)); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &dense_B, B.extent(0), B.extent(1), B.extent(0), B.data_handle(), CUSPARSE_ORDER_COL)); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &dense_C, C.extent(0), C.extent(1), C.extent(0), C.data_handle(), CUSPARSE_ORDER_COL)); + // a * b + value_t alpha = 1; + value_t beta = 0; + size_t buff_size = 0; + auto opA = transpose_a ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto opB = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), + opA, + opB, + &alpha, + sparse_A, + dense_B, + &beta, + dense_C, + CUSPARSE_SPMM_ALG_DEFAULT, + &buff_size, + stream); + rmm::device_uvector dev_buffer(buff_size / sizeof(value_t), stream); + raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), + opA, + opB, + &alpha, + sparse_A, + dense_B, + &beta, + dense_C, + CUSPARSE_SPMM_ALG_DEFAULT, + dev_buffer.data(), + stream); + + cusparseDestroySpMat(sparse_A); + cusparseDestroyDnMat(dense_B); + cusparseDestroyDnMat(dense_C); +} template -void b_orthonormalize( - const raft::handle_t& handle, - raft::spectral::matrix::sparse_matrix_t A, - raft::device_matrix_view V, - raft::device_matrix_view BV, - bool retInvR=false) +void cholesky(const raft::handle_t& handle, + raft::device_matrix_view P, + bool lower = true) { - auto V_max = raft::make_device_vector_view(handle, V.extent(1)); - normalization = raft::linalg::reduce(V, axis=0); + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( + handle.get_cusolver_dn_handle(), uplo, dim, P.data_handle(), lda, &Lwork)); + + rmm::device_uvector workspace_decomp(Lwork / sizeof(value_t), stream); + rmm::device_uvector info(1, stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(handle.get_cusolver_dn_handle(), + uplo, + dim, + P.data_handle(), + lda, + workspace_decomp.data(), + Lwork, + info.data(), + stream)); } - +template +void inverse(const raft::handle_t& handle, + raft::device_matrix_view P, + bool lower = true) +{ + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotri_bufferSize( + handle.get_cusolver_dn_handle(), uplo, dim, P.data_handle(), lda, &Lwork)); -) + rmm::device_uvector workspace_decomp(Lwork / sizeof(value_t), stream); + rmm::device_uvector info(1, stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotri(handle.get_cusolver_dn_handle(), + uplo, + dim, + P.data_handle(), + lda, + workspace_decomp.data(), + Lwork, + info.data(), + stream)); +} template -void lobpcg(const raft::handle_t& handle, - // IN - const raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) - raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors - raft::device_vector_view W, // shape=(k) OUT Eigvals - std::optional> B, // shape=(n,n) - std::optional> M, // shape=(n,n) - std::optional> Y, // Constraint matrix shape=(n,Y) - value_t tol=0, - std::uint32_t max_iter=20, - bool largest=true) +void b_orthonormalize( + const raft::handle_t& handle, + std::optional> B_opt, + raft::device_matrix_view V, + raft::device_matrix_view + BV, /// < Can't be optional because this is an OUT arg. + std::optional> VBV_opt = std::nullopt, + std::optional> V_max_opt = + std::nullopt, /// < normalization + bool bv_is_empty = true) { - cudaStream_t stream = handle.get_stream(); - auto size_y = 0; - if (Y.has_value()) size_y = Y.extent(1); - auto n = X.nrows_; - auto size_x = X.ncols_; + auto stream = handle.get_stream(); + auto V_max_buffer = rmm::device_uvector(0, stream); + value_t* V_max_ptr = nullptr; + if (!V_max_opt) { // allocate normalization buffer + V_max_buffer.resize(V.extent(1), stream); + V_max_ptr = V_max_buffer.data(); + } else { + V_max_ptr = V_max_opt.value().data_handle(); + } + auto V_max = raft::make_device_vector_view(V_max_ptr, V.extent(1)); + auto V_max_const = raft::make_device_vector_view(V_max_ptr, V.extent(1)); + raft::linalg::reduce(handle, + raft::make_device_matrix_view( + V.data_handle(), V.extent(0), V.extent(1)), + V_max, + value_t(0), + raft::linalg::Apply::ALONG_ROWS, + false, + raft::Nop(), + [] __device__(value_t a, value_t b) { return raft::maxPrim(a, b); }); + raft::linalg::binary_div_skip_zero(handle, V, V_max_const, raft::linalg::Apply::ALONG_ROWS); - if ((n - size_y) < (5 * size_x)) - { - // DENSE SOLUTION - return; - } - if (tol == 0) - { - tol = raft::mySqrt(1e-15) * n; - } - // Apply constraints to X - if (Y.has_value()) - { - cusparseDnMatDescr_t denseY; - RAFT_CUSPARSE_TRY(cusparsecreatednmat(&denseY, n, size_y, n, Y.value().data_handle(), CUSPARSE_ORDER_COL))); - auto* ptr_BY = Y.value().data_handle(); - if (B.has_value()) - { - cusparseSpMatDescr_t sparseB; - cusparseDnMatDescr_t dense_BY; - RAFT_CUSPARSE_TRY(cusparsecreatecsr(&sparseB, n, n, B.nnz_, B.row_offsets, B.col_indices_, B.values_)); - auto matrix_BY = raft::make_device_matrix(handle, n, size_y); - RAFT_CUSPARSE_TRY(cusparsecreatednmat(&dense_BY, n, size_y, n, matrix_BY.data_handle(), CUSPARSE_ORDER_COL))); - // B * Y - value_t alpha = 1; - value_t beta = 0; - size_t buff_size = 0; - cusparsespmm_bufferSize(handle.get_cusparse_handle(), - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, sparseB, denseY, &beta, dense_BY, CUSPARSE_SPMM_ALG_DEFAULT, &buff_size, stream); - rmm::device_uvector dev_buffer(buff_size, stream); - cusparsespmm_bufferSize(handle.get_cusparse_handle(), - CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, sparseB, denseY, &beta, dense_BY, CUSPARSE_SPMM_ALG_DEFAULT, stream); - cusparseDestroyDnMat(dense_B); - cusparseDestroyDnMat(dense_BY); - cusparseDestroySpMat(sparseB); - // CONTINUE - } + if (!bv_is_empty) { + raft::linalg::binary_div_skip_zero(handle, BV, V_max_const, raft::linalg::Apply::ALONG_ROWS); + } else { + if (B_opt) + spmm(handle, B_opt.value(), V, BV); + else + raft::copy(BV.data_handle(), V.data_handle(), V.size(), stream); + } + auto VBV_buffer = rmm::device_uvector(0, stream); + value_t* VBV_ptr = nullptr; + if (!VBV_opt) { // allocate normalization buffer + VBV_buffer.resize(V.extent(1) * V.extent(1), stream); + VBV_ptr = VBV_buffer.data(); + } else { + VBV_ptr = VBV_opt.value().data_handle(); + } + auto VBV = raft::make_device_matrix_view( + VBV_ptr, V.extent(1), V.extent(1)); + raft::linalg::gemm(handle, V, BV, VBV); + cholesky(handle, VBV); + raft::linalg::transpose(handle, VBV, VBV); + inverse(handle, VBV); + raft::linalg::gemm(handle, V, VBV, V); + if (B_opt) raft::linalg::gemm(handle, BV, VBV, BV); +} + +template +void lobpcg( + const raft::handle_t& handle, + // IN + raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) + raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors + raft::device_vector_view W, // shape=(k) OUT Eigvals + std::optional> B_opt, // shape=(n,n) + std::optional> M_opt, // shape=(n,n) + std::optional> Y_opt, // Constraint + // matrix shape=(n,Y) + value_t tol = 0, + std::uint32_t max_iter = 20, + bool largest = true) +{ + cudaStream_t stream = handle.get_stream(); + auto size_y = 0; + // if (Y_opt.has_value()) size_y = Y_opt.value().extent(1); + auto n = X.extent(0); + auto size_x = X.extent(1); - // GramYBY - // ApplyConstraints - } - + if ((n - size_y) < (5 * size_x)) { + // DENSE SOLUTION + return; + } + if (tol == 0) { tol = raft::mySqrt(1e-15) * n; } + // Apply constraints to X + /* + auto matrix_BY = raft::make_device_matrix(handle, n, size_y); + if (Y_opt.has_value()) + { + if (B_opt.has_value()) + { + auto B = B_opt.value(); + spmm(handle, Y_opt.value(), B, matrix_BY.view(), false, false); + // TODO + } else { + raft::copy(matrix_BY.data_handle(), Y_opt.value().data_handle(), n * size_y, + handle.get_stream()); + } + cusparseDestroyDnMat(denseY); + // GramYBY + // ApplyConstraints + }*/ + auto BX = raft::make_device_matrix(handle, n, size_x); + b_orthonormalize(handle, B_opt, X, BX.view()); + // TODO } -} // raft::sparse::solver::detail \ No newline at end of file +}; // namespace raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/lobpcg.cuh b/cpp/include/raft/sparse/solver/lobpcg.cuh index e130c1b369..10fa95c6bc 100644 --- a/cpp/include/raft/sparse/solver/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/lobpcg.cuh @@ -18,4 +18,24 @@ #include namespace raft::sparse::solver { -} \ No newline at end of file + +template +void lobpcg( + const raft::handle_t& handle, + // IN + raft::spectral::matrix::sparse_matrix_t A, // shape=(n,n) + raft::device_matrix_view X, // shape=(n,k) IN OUT Eigvectors + raft::device_vector_view W, // shape=(k) OUT Eigvals + std::optional> B = + std::nullopt, // shape=(n,n) + std::optional> M = + std::nullopt, // shape=(n,n) + std::optional> Y = + std::nullopt, // Constraint matrix shape=(n,Y) + value_t tol = 0, + std::uint32_t max_iter = 20, + bool largest = true) +{ + detail::lobpcg(handle, A, X, W, B, M, Y, tol, max_iter, largest); +} +}; // namespace raft::sparse::solver \ No newline at end of file diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0f5ebabcb9..8760a6f555 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -212,6 +212,7 @@ if(BUILD_TESTS) test/sparse/csr_transpose.cu test/sparse/degree.cu test/sparse/filter.cu + test/sparse/lobpcg.cu test/sparse/norm.cu test/sparse/reduce.cu test/sparse/row_op.cu diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index d018a6ff2e..75763e86c0 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -14,164 +14,173 @@ * limitations under the License. */ - #include - - #include - #include - #include - - #include "../test_utils.h" - #include - - #include - #include - - namespace raft { - namespace sparse { - - template - struct CSRMatrixVal { - std::vector row_ind; - std::vector row_ind_ptr; - std::vector values; - }; - - template - struct CSRAddInputs { - CSRMatrixVal matrix_a; - CSRMatrixVal matrix_b; - CSRMatrixVal matrix_verify; - }; - - template - class CSRAddTest : public ::testing::TestWithParam> { - public: - CSRAddTest() - : params(::testing::TestWithParam>::GetParam()), - stream(handle.get_stream()), - ind_a(params.matrix_a.row_ind.size(), stream), - ind_ptr_a(params.matrix_a.row_ind_ptr.size(), stream), - values_a(params.matrix_a.row_ind_ptr.size(), stream), - ind_b(params.matrix_a.row_ind.size(), stream), - ind_ptr_b(params.matrix_b.row_ind_ptr.size(), stream), - values_b(params.matrix_b.row_ind_ptr.size(), stream), - ind_verify(params.matrix_a.row_ind.size(), stream), - ind_ptr_verify(params.matrix_verify.row_ind_ptr.size(), stream), - values_verify(params.matrix_verify.row_ind_ptr.size(), stream), - ind_result(params.matrix_a.row_ind.size(), stream), - ind_ptr_result(params.matrix_verify.row_ind_ptr.size(), stream), - values_result(params.matrix_verify.row_ind_ptr.size(), stream) - { - } - - protected: - void SetUp() override - { - n_rows = params.matrix_a.row_ind.size(); - nnz_a = params.matrix_a.row_ind_ptr.size(); - nnz_b = params.matrix_b.row_ind_ptr.size(); - nnz_result = params.matrix_verify.row_ind_ptr.size(); - } - - void Run() - { - raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows, stream); - raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); - raft::update_device(values_a.data(), params.matrix_a.values.data(), nnz_a, stream); - - raft::update_device(ind_b.data(), params.matrix_b.row_ind.data(), n_rows, stream); - raft::update_device(ind_ptr_b.data(), params.matrix_b.row_ind_ptr.data(), nnz_b, stream); - raft::update_device(values_b.data(), params.matrix_b.values.data(), nnz_b, stream); - - raft::update_device(ind_verify.data(), params.matrix_verify.row_ind.data(), n_rows, stream); - raft::update_device( - ind_ptr_verify.data(), params.matrix_verify.row_ind_ptr.data(), nnz_result, stream); - raft::update_device( - values_verify.data(), params.matrix_verify.values.data(), nnz_result, stream); - - Index_ nnz = linalg::csr_add_calc_inds(ind_a.data(), - ind_ptr_a.data(), - values_a.data(), - nnz_a, - ind_b.data(), - ind_ptr_b.data(), - values_b.data(), - nnz_b, - n_rows, - ind_result.data(), - stream); - - ASSERT_TRUE(nnz == nnz_result); - ASSERT_TRUE(raft::devArrMatch( - ind_verify.data(), ind_result.data(), n_rows, raft::Compare(), stream)); - - linalg::csr_add_finalize(ind_a.data(), - ind_ptr_a.data(), - values_a.data(), - nnz_a, - ind_b.data(), - ind_ptr_b.data(), - values_b.data(), - nnz_b, - n_rows, - ind_result.data(), - ind_ptr_result.data(), - values_result.data(), - stream); - - ASSERT_TRUE(raft::devArrMatch( - ind_ptr_verify.data(), ind_ptr_result.data(), nnz, raft::Compare(), stream)); - ASSERT_TRUE(raft::devArrMatch( - values_verify.data(), values_result.data(), nnz, raft::Compare(), stream)); - } - - protected: - raft::handle_t handle; - cudaStream_t stream; - - CSRAddInputs params; - Index_ n_rows, nnz_a, nnz_b, nnz_result; - rmm::device_uvector ind_a, ind_b, ind_verify, ind_result, ind_ptr_a, ind_ptr_b, - ind_ptr_verify, ind_ptr_result; - rmm::device_uvector values_a, values_b, values_verify, values_result; - }; - - using CSRAddTestF = CSRAddTest; - TEST_P(CSRAddTestF, Result) { Run(); } - - using CSRAddTestD = CSRAddTest; - TEST_P(CSRAddTestD, Result) { Run(); } - - const std::vector> csradd_inputs_f = { - {{{0, 4, 8, 9}, - {1, 2, 3, 4, 1, 2, 3, 5, 0, 1}, - {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, - {{0, 4, 8, 9}, - {1, 2, 5, 4, 0, 2, 3, 5, 1, 0}, - {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, - {{0, 5, 10, 12}, - {1, 2, 3, 4, 5, 1, 2, 3, 5, 0, 0, 1, 1, 0}, - {2.0, 2.0, 0.5, 1.0, 0.5, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}}}, - }; - const std::vector> csradd_inputs_d = { - {{{0, 4, 8, 9}, - {1, 2, 3, 4, 1, 2, 3, 5, 0, 1}, - {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, - {{0, 4, 8, 9}, - {1, 2, 5, 4, 0, 2, 3, 5, 1, 0}, - {1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0}}, - {{0, 5, 10, 12}, - {1, 2, 3, 4, 5, 1, 2, 3, 5, 0, 0, 1, 1, 0}, - {2.0, 2.0, 0.5, 1.0, 0.5, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}}}, - }; - - INSTANTIATE_TEST_CASE_P(SparseAddTest, CSRAddTestF, ::testing::ValuesIn(csradd_inputs_f)); - INSTANTIATE_TEST_CASE_P(SparseAddTest, CSRAddTestD, ::testing::ValuesIn(csradd_inputs_d)); - - } // namespace sparse - } // namespace raft - +#include +#include +#include +#include +#include +#include + +#include "../test_utils.h" +#include + +#include +#include + +namespace raft { +namespace sparse { + +template +struct CSRMatrixVal { + std::vector row_ind; + std::vector row_ind_ptr; + std::vector values; +}; + +template +struct LOBPCGInputs { + CSRMatrixVal matrix_a; + std::vector init_eigvecs; + std::vector exp_eigvals; + std::vector exp_eigvecs; + idx_t n_components; +}; + +template +class LOBPCGTest : public ::testing::TestWithParam> { + public: + LOBPCGTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + ind_a(params.matrix_a.row_ind.size(), stream), + ind_ptr_a(params.matrix_a.row_ind_ptr.size(), stream), + values_a(params.matrix_a.row_ind_ptr.size(), stream), + exp_eigvals(params.exp_eigvals.size(), stream), + exp_eigvecs(params.exp_eigvecs.size(), stream), + act_eigvals(params.exp_eigvals.size(), stream), + act_eigvecs(params.exp_eigvecs.size(), stream) + { + } + + protected: + void SetUp() override + { + n_rows_a = params.matrix_a.row_ind.size(); + nnz_a = params.matrix_a.row_ind_ptr.size(); + } + + void Run() + { + raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows_a, stream); + raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); + raft::update_device(values_a.data(), params.matrix_a.values.data(), nnz_a, stream); + + raft::update_device(act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream); + + auto matA = raft::spectral::matrix::sparse_matrix_t( + handle, ind_ptr_a.data(), ind_a.data(), values_a.data(), n_rows_a, n_rows_a, nnz_a); + raft::sparse::solver::lobpcg( + handle, + matA, + raft::make_device_matrix_view( + act_eigvecs.data(), n_rows_a, params.n_components), + raft::make_device_vector_view(act_eigvals.data(), n_rows_a)); + + ASSERT_TRUE(raft::devArrMatch( + exp_eigvecs.data(), act_eigvecs.data(), exp_eigvecs.size(), raft::Compare(), stream)); + ASSERT_TRUE(raft::devArrMatch( + exp_eigvals.data(), act_eigvals.data(), exp_eigvals.size(), raft::Compare(), stream)); + } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + LOBPCGInputs params; + idx_t n_rows_a, nnz_a; + rmm::device_uvector ind_a, ind_ptr_a; + rmm::device_uvector values_a, exp_eigvals, exp_eigvecs, act_eigvals, act_eigvecs; +}; + +using LOBPCGTestF = LOBPCGTest; +TEST_P(LOBPCGTestF, Result) { Run(); } + +using LOBPCGTestD = LOBPCGTest; +TEST_P(LOBPCGTestD, Result) { Run(); } + +const std::vector> lobpcg_inputs_f = { + {{{0, 4, 10, 14, 19, 24, 28}, + {0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 2, 3, 4}, + {0.37911922, 0.11567201, 0.5135106, 0.08968836, 0.73450965, 0.26432646, 0.21985123, + 0.74888277, 0.34753734, 0.11204864, 0.82902676, 0.53023521, 0.24047095, 0.37913592, + 0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847, 0.32938903, + 0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489}}, + {0.08319983, + 0.17758466, + 0.93301819, + 0.67171826, + 0.19967821, + 0.30873092, + 0.35005079, + 0.56035486, + 0.64176631, + 0.93904784, + 0.38935935, + 0.97182089}, + {2.61153278, 0.85782948}, + {-0.38272064, + -0.25160901, + -0.48684676, + -0.50752949, + -0.43005954, + -0.33265696, + -0.39778489, + 0.2539629, + -0.37506003, + 0.72637041, + 0.02727131, + -0.32900198}, + 2}}; +const std::vector> lobpcg_inputs_d = { + {{{0, 4, 10, 14, 19, 24, 28}, + {0, 2, 3, 5, 0, 1, 2, 3, 4, 5, 0, 2, 3, 5, 1, 2, 3, 4, 5, 0, 2, 3, 4, 5, 0, 2, 3, 4}, + {0.37911922, 0.11567201, 0.5135106, 0.08968836, 0.73450965, 0.26432646, 0.21985123, + 0.74888277, 0.34753734, 0.11204864, 0.82902676, 0.53023521, 0.24047095, 0.37913592, + 0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955, 0.43530847, 0.32938903, + 0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489}}, + {0.08319983, + 0.17758466, + 0.93301819, + 0.67171826, + 0.19967821, + 0.30873092, + 0.35005079, + 0.56035486, + 0.64176631, + 0.93904784, + 0.38935935, + 0.97182089}, + {2.61153278, 0.85782948}, + {-0.38272064, + -0.25160901, + -0.48684676, + -0.50752949, + -0.43005954, + -0.33265696, + -0.39778489, + 0.2539629, + -0.37506003, + 0.72637041, + 0.02727131, + -0.32900198}, + 2}}; + +INSTANTIATE_TEST_CASE_P(SparseAddTest, LOBPCGTestF, ::testing::ValuesIn(lobpcg_inputs_f)); +INSTANTIATE_TEST_CASE_P(SparseAddTest, LOBPCGTestD, ::testing::ValuesIn(lobpcg_inputs_d)); + +} // namespace sparse +} // namespace raft /* @@ -190,11 +199,11 @@ a.data = array([0.37911922, 0.11567201, 0.5135106 , 0.08968836, 0.73450965, x = np.random.rand(6,2) x = array([[0.08319983, 0.35005079], - [0.17758466, 0.56035486], - [0.93301819, 0.64176631], - [0.67171826, 0.93904784], - [0.19967821, 0.38935935], - [0.30873092, 0.97182089]]) + [0.17758466, 0.56035486], + [0.93301819, 0.64176631], + [0.67171826, 0.93904784], + [0.19967821, 0.38935935], + [0.30873092, 0.97182089]]) lobpcg(a, x) = (array([2.61153278, 0.85782948]), array([[-0.38272064, -0.39778489], From d6806c86101cf8d784dd86fbbeea6f21b8962c00 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Sun, 25 Dec 2022 18:53:58 +0100 Subject: [PATCH 03/17] Fix b_orthonormalize --- .../raft/sparse/solver/detail/lobpcg.cuh | 138 ++++++++++++++---- cpp/test/sparse/lobpcg.cu | 19 ++- 2 files changed, 127 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index b9b604642b..23f7e2a8c8 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -23,11 +23,22 @@ #include #include #include +#include +#include +#include #include #include namespace raft::sparse::solver::detail { +/** + * @brief stucture that defines the reduction Lambda to find minimum between elements + */ +template +struct MaxOp { + HDI DataT operator()(DataT a, DataT b) { return maxPrim(a, b); } +}; + // C = A * B template void spmm(const raft::handle_t& handle, @@ -109,45 +120,91 @@ void cholesky(const raft::handle_t& handle, Lwork, info.data(), stream)); + int info_h = 0; + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in potrf, info=%d | expected=0", info_h); } template void inverse(const raft::handle_t& handle, raft::device_matrix_view P, + raft::device_matrix_view Pinv, bool lower = true) { - auto stream = handle.get_stream(); - int Lwork = 0; - auto lda = P.extent(0); - auto dim = P.extent(0); - cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; - RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotri_bufferSize( - handle.get_cusolver_dn_handle(), uplo, dim, P.data_handle(), lda, &Lwork)); + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + int info_h = 0; + cublasOperation_t trans = CUBLAS_OP_N; + // make Pinv an identity matrix + auto diag = raft::make_device_vector(handle, dim); + raft::matrix::fill(handle, diag.view(), value_t(1)); + raft::matrix::fill(handle, Pinv, value_t(0)); + raft::matrix::set_diagonal( + handle, raft::make_device_vector_view(diag.data_handle(), dim), Pinv); - rmm::device_uvector workspace_decomp(Lwork / sizeof(value_t), stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf_bufferSize( + handle.get_cusolver_dn_handle(), dim, dim, P.data_handle(), lda, &Lwork)); + + rmm::device_uvector workspace_decomp(Lwork, stream); rmm::device_uvector info(1, stream); - RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotri(handle.get_cusolver_dn_handle(), - uplo, + auto ipiv = raft::make_device_vector(handle, dim); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf(handle.get_cusolver_dn_handle(), + dim, dim, P.data_handle(), lda, workspace_decomp.data(), - Lwork, + ipiv.data_handle(), + info.data(), + stream)); + + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in getrf, info=%d | expected=0", info_h); + + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrs(handle.get_cusolver_dn_handle(), + trans, + dim, + dim, + P.data_handle(), + lda, + ipiv.data_handle(), + Pinv.data_handle(), + lda, info.data(), stream)); + + raft::update_host(&info_h, info.data(), 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT(info_h == 0, "lobpcg: error in getrs, info=%d | expected=0", info_h); } +/** + * B-orthonormalize the given block vector using Cholesky + * + * @tparam value_t floating point type used for elements + * @tparam index_t integer type used for indexing + * @param[in] handle: raft handle + * @param[in] B_opt: optional sparse matrix for normalization + * @param[inout] V: dense matrix to normalize + * @param[inout] BV: dense matrix. Use with parameter `bv_is_empty`. + * @param[out] VBV_opt: optional dense matrix containing inverse matrix + * @param[out] V_max_opt: optional vector containing normalization of V + * @param[in] bv_is_empty: True if BV is used as input + */ template void b_orthonormalize( const raft::handle_t& handle, std::optional> B_opt, raft::device_matrix_view V, - raft::device_matrix_view - BV, /// < Can't be optional because this is an OUT arg. + raft::device_matrix_view BV, std::optional> VBV_opt = std::nullopt, - std::optional> V_max_opt = - std::nullopt, /// < normalization - bool bv_is_empty = true) + std::optional> V_max_opt = std::nullopt, + bool bv_is_empty = true) { auto stream = handle.get_stream(); auto V_max_buffer = rmm::device_uvector(0, stream); @@ -160,15 +217,29 @@ void b_orthonormalize( } auto V_max = raft::make_device_vector_view(V_max_ptr, V.extent(1)); auto V_max_const = raft::make_device_vector_view(V_max_ptr, V.extent(1)); - raft::linalg::reduce(handle, + + // + /*raft::linalg::reduce(handle, raft::make_device_matrix_view( - V.data_handle(), V.extent(0), V.extent(1)), + V.data_handle(), V.extent(1), V.extent(0)), V_max, value_t(0), raft::linalg::Apply::ALONG_ROWS, false, raft::Nop(), - [] __device__(value_t a, value_t b) { return raft::maxPrim(a, b); }); + MaxOp()); + */ + raft::linalg::reduce(V_max.data_handle(), + V.data_handle(), + V.extent(0), + V.extent(1), + value_t(0), + false, + true, + handle.get_stream(), + false, + raft::Nop(), + MaxOp()); raft::linalg::binary_div_skip_zero(handle, V, V_max_const, raft::linalg::Apply::ALONG_ROWS); if (!bv_is_empty) { @@ -189,10 +260,24 @@ void b_orthonormalize( } auto VBV = raft::make_device_matrix_view( VBV_ptr, V.extent(1), V.extent(1)); - raft::linalg::gemm(handle, V, BV, VBV); - cholesky(handle, VBV); - raft::linalg::transpose(handle, VBV, VBV); - inverse(handle, VBV); + auto VBVBuffer = raft::make_device_matrix( + handle, VBV.extent(0), VBV.extent(1)); + auto VT = + raft::make_device_matrix(handle, V.extent(1), V.extent(0)); + raft::linalg::transpose(handle, V, VT.view()); + + raft::linalg::gemm(handle, VT.view(), BV, VBVBuffer.view()); + cholesky(handle, VBVBuffer.view(), false); + // Reset VBV before copying upper triangular + raft::matrix::fill(handle, VBV, value_t(0)); + raft::matrix::upper_triangular( + handle, + raft::make_device_matrix_view( + VBVBuffer.data_handle(), VBVBuffer.extent(0), VBV.extent(1)), + VBV); + + inverse(handle, VBV, VBVBuffer.view()); + raft::copy(VBV.data_handle(), VBVBuffer.data_handle(), VBV.size(), stream); raft::linalg::gemm(handle, V, VBV, V); if (B_opt) raft::linalg::gemm(handle, BV, VBV, BV); } @@ -213,15 +298,15 @@ void lobpcg( bool largest = true) { cudaStream_t stream = handle.get_stream(); - auto size_y = 0; + // auto size_y = 0; // if (Y_opt.has_value()) size_y = Y_opt.value().extent(1); auto n = X.extent(0); auto size_x = X.extent(1); + /* DENSE SOLUTION if ((n - size_y) < (5 * size_x)) { - // DENSE SOLUTION return; - } + } */ if (tol == 0) { tol = raft::mySqrt(1e-15) * n; } // Apply constraints to X /* @@ -243,6 +328,7 @@ void lobpcg( }*/ auto BX = raft::make_device_matrix(handle, n, size_x); b_orthonormalize(handle, B_opt, X, BX.view()); + return; // TODO } }; // namespace raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index 75763e86c0..3b0be585a4 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -66,7 +66,7 @@ class LOBPCGTest : public ::testing::TestWithParam> protected: void SetUp() override { - n_rows_a = params.matrix_a.row_ind.size(); + n_rows_a = params.matrix_a.row_ind.size() - 1; nnz_a = params.matrix_a.row_ind_ptr.size(); } @@ -87,6 +87,10 @@ class LOBPCGTest : public ::testing::TestWithParam> act_eigvecs.data(), n_rows_a, params.n_components), raft::make_device_vector_view(act_eigvals.data(), n_rows_a)); + std::vector X_CPU(n_rows_a * params.n_components); + std::vector W_CPU(n_rows_a); + raft::copy(X_CPU.data(), act_eigvecs.data(), X_CPU.size(), stream); + raft::copy(W_CPU.data(), act_eigvals.data(), W_CPU.size(), stream); ASSERT_TRUE(raft::devArrMatch( exp_eigvecs.data(), act_eigvecs.data(), exp_eigvecs.size(), raft::Compare(), stream)); ASSERT_TRUE(raft::devArrMatch( @@ -176,8 +180,8 @@ const std::vector> lobpcg_inputs_d = { -0.32900198}, 2}}; -INSTANTIATE_TEST_CASE_P(SparseAddTest, LOBPCGTestF, ::testing::ValuesIn(lobpcg_inputs_f)); -INSTANTIATE_TEST_CASE_P(SparseAddTest, LOBPCGTestD, ::testing::ValuesIn(lobpcg_inputs_d)); +INSTANTIATE_TEST_CASE_P(LOBPCGTest, LOBPCGTestF, ::testing::ValuesIn(lobpcg_inputs_f)); +INSTANTIATE_TEST_CASE_P(LOBPCGTest, LOBPCGTestD, ::testing::ValuesIn(lobpcg_inputs_d)); } // namespace sparse } // namespace raft @@ -197,8 +201,15 @@ a.data = array([0.37911922, 0.11567201, 0.5135106 , 0.08968836, 0.73450965, 0.32938903, 0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489]) +a.todense() = +np.matrix([[0.37911922, 0. , 0.11567201, 0.5135106 , 0. , 0.08968836], + [0.73450965, 0.26432646, 0.21985123, 0.74888277, 0.34753734, 0.11204864], + [0.82902676, 0. , 0.53023521, 0.24047095, 0. , 0.37913592], + [0. , 0.60975031, 0.60746519, 0.96833343, 0.30845102, 0.88653955], + [0.43530847, 0. , 0.32938903, 0.82477561, 0.20858375, 0.24755519], + [0.23677223, 0. , 0.73957246, 0.09050876, 0.86530489, 0. ]]) x = np.random.rand(6,2) -x = array([[0.08319983, 0.35005079], +x = np.array([[0.08319983, 0.35005079], [0.17758466, 0.56035486], [0.93301819, 0.64176631], [0.67171826, 0.93904784], From f982793c641362dff9b83d276572b7a9e071bbd8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 27 Dec 2022 02:23:12 +0100 Subject: [PATCH 04/17] Add eye and lower triangular functions --- cpp/include/raft/linalg/eig.cuh | 2 +- cpp/include/raft/matrix/detail/matrix.cuh | 54 +++++++++++++++++++++++ cpp/include/raft/matrix/init.cuh | 7 +++ cpp/include/raft/matrix/triangular.cuh | 18 ++++++++ cpp/test/matrix/matrix.cu | 14 ++++++ 5 files changed, 94 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh index 2ad222d42d..ba2e54aad1 100644 --- a/cpp/include/raft/linalg/eig.cuh +++ b/cpp/include/raft/linalg/eig.cuh @@ -133,7 +133,7 @@ void eig_dc(const raft::handle_t& handle, raft::device_vector_view eig_vals) { RAFT_EXPECTS(in.size() == eig_vectors.size(), "Size mismatch between Input and Eigen Vectors"); - RAFT_EXPECTS(eig_vals.size() == in.extent(1), "Size mismatch between Input and Eigen Values"); + RAFT_EXPECTS(eig_vals.extent(0) == in.extent(1), "Size mismatch between Input and Eigen Values"); eigDC(handle, in.data_handle(), diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 17a40be5d6..35ff6a3692 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -226,6 +226,60 @@ void copyUpperTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, c getUpperTriangular<<>>(src, dst, m, n, k); } +/** + * @brief Kernel for copying the lower triangular part of a matrix to another + * @param src: input matrix with a size of mxn + * @param dst: output matrix with a size of kxk + * @param n_rows: number of rows of input matrix + * @param n_cols: number of columns of input matrix + * @param k: min(n_rows, n_cols) + */ +template +__global__ void getLowerTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, idx_t k) +{ + idx_t idx = threadIdx.x + blockDim.x * blockIdx.x; + idx_t m = n_rows, n = n_cols; + if (idx < m * n) { + idx_t i = idx % m, j = idx / m; + if (i < k && j < k && j <= i) { dst[i + j * k] = src[idx]; } + } +} + +template +void copyLowerTriangular(const m_t* src, m_t* dst, idx_t n_rows, idx_t n_cols, cudaStream_t stream) +{ + idx_t m = n_rows, n = n_cols; + idx_t k = std::min(m, n); + dim3 block(64); + dim3 grid((m * n + block.x - 1) / block.x); + getLowerTriangular<<>>(src, dst, m, n, k); +} + +/** + * @brief Create a diagonal identity matrix + * @param matrix: matrix of size n_rows x n_cols + * @param n_rows: number of rows of the matrix + * @param n_cols: number of columns of the matrix + */ +template +__global__ void createEyeKernel(m_t* matrix, idx_t n_rows, idx_t n_cols) +{ + idx_t idx = threadIdx.x + blockDim.x * blockIdx.x; + if (idx < n_rows * n_cols) { + idx_t i = idx % n_rows, j = idx / n_rows; + matrix[idx] = m_t(j == i); + } +} + +template +void createEye(m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream) +{ + idx_t m = n_rows, n = n_cols; + dim3 block(64); + dim3 grid((m * n + block.x - 1) / block.x); + createEyeKernel<<>>(matrix, n_rows, n_cols); + } + /** * @brief Copy a vector to the diagonal of a matrix * @param vec: vector of length k = min(n_rows, n_cols) diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index caee2555a9..e6447e90d7 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -63,4 +63,11 @@ void fill(const raft::handle_t& handle, detail::setValue( inout.data_handle(), inout.data_handle(), scalar, inout.size(), handle.get_stream()); } + +template +void eye(const raft::handle_t& handle, + raft::device_matrix_view inout) +{ + detail::createEye(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index fad3dd77af..78aac5d6d4 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -38,4 +38,22 @@ void upper_triangular(const raft::handle_t& handle, detail::copyUpperTriangular( src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); } + +/** + * @brief Copy the Lower triangular part of a matrix to another + * @param[in] handle: raft handle + * @param[in] src: input matrix with a size of n_rows x n_cols + * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) + */ + template + void lower_triangular(const raft::handle_t& handle, + raft::device_matrix_view src, + raft::device_matrix_view dst) + { + auto k = std::min(src.extent(0), src.extent(1)); + RAFT_EXPECTS(k == dst.extent(0) && k == dst.extent(1), + "dst should be of size kxk, k = min(n_rows, n_cols)"); + detail::copyLowerTriangular( + src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); + } } // namespace raft::matrix \ No newline at end of file diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 78391d5ff2..123e381612 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -57,8 +58,21 @@ class MatrixTest : public ::testing::TestWithParam> { } protected: + + void test_eye() + { + auto eyemat = raft::make_device_matrix(handle, 4, 5); + raft::matrix::eye(handle, eyemat.view()); + std::vector eye_exp{1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0}; + std::vector eye_act(20); + raft::copy(eye_act.data(), eyemat.data_handle(), eye_act.size(), stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT_TRUE(hostVecMatch(eye_exp, eye_act, raft::Compare())); + } + void SetUp() override { + test_eye(); raft::random::RngState r(params.seed); int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); From db281d75f5d7b59ec4aaeee1d3d42baa5bdf6966 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 30 Dec 2022 17:02:02 +0100 Subject: [PATCH 05/17] Add b_orthonormalize test and eigh function --- .../raft/sparse/solver/detail/lobpcg.cuh | 179 +++++++++++++++--- cpp/include/raft/sparse/solver/lobpcg.cuh | 2 +- cpp/test/sparse/lobpcg.cu | 30 ++- 3 files changed, 184 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 23f7e2a8c8..d15634d670 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -16,15 +16,24 @@ #pragma once +#include +#include + +#include + #include #include #include #include +#include #include #include +#include +#include #include #include #include +#include #include #include #include @@ -39,6 +48,28 @@ struct MaxOp { HDI DataT operator()(DataT a, DataT b) { return maxPrim(a, b); } }; +template +struct isnan_test { + HDA bool operator()(const DataT a) { return isnan(a); } +}; + +template +void truncEig(const raft::handle_t& handle, + raft::device_matrix_view eigVector, + raft::device_vector_view eigLambda, + index_t size_x, + bool largest) +{ + // The eigenvalues are already sorted in ascending order with syevd + if (largest) + { + auto nrows = eigVector.extent(0); + auto ncols = eigVector.extent(1); + raft::matrix::col_reverse(handle, eigVector); + raft::matrix::col_reverse(handle, raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0))); + } +} + // C = A * B template void spmm(const raft::handle_t& handle, @@ -106,15 +137,19 @@ void cholesky(const raft::handle_t& handle, auto lda = P.extent(0); auto dim = P.extent(0); cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + + auto P_copy = raft::make_device_matrix(handle, P.extent(0), P.extent(1)); + raft::copy(P_copy.data_handle(), P.data_handle(), P.size(), stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( - handle.get_cusolver_dn_handle(), uplo, dim, P.data_handle(), lda, &Lwork)); + handle.get_cusolver_dn_handle(), uplo, dim, P_copy.data_handle(), lda, &Lwork)); rmm::device_uvector workspace_decomp(Lwork / sizeof(value_t), stream); rmm::device_uvector info(1, stream); RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf(handle.get_cusolver_dn_handle(), uplo, dim, - P.data_handle(), + P_copy.data_handle(), lda, workspace_decomp.data(), Lwork, @@ -124,6 +159,27 @@ void cholesky(const raft::handle_t& handle, raft::update_host(&info_h, info.data(), 1, stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); ASSERT(info_h == 0, "lobpcg: error in potrf, info=%d | expected=0", info_h); + + bool h_hasnan = thrust::reduce(P_copy.data_handle(), P_copy.data_handle() + P_copy.size(), isnan_test(), 0, thrust::plus()); + ASSERT(h_hasnan == 0, "lobpcg: error in cholesky, NaN in outputs", info_h); + + raft::matrix::fill(handle, P, value_t(0)); + if (lower) + { + raft::matrix::lower_triangular( + handle, + raft::make_device_matrix_view( + P_copy.data_handle(), P.extent(0), P.extent(1)), + P); + } + else + { + raft::matrix::upper_triangular( + handle, + raft::make_device_matrix_view( + P_copy.data_handle(), P.extent(0), P.extent(1)), + P); + } } template @@ -138,12 +194,7 @@ void inverse(const raft::handle_t& handle, auto dim = P.extent(0); int info_h = 0; cublasOperation_t trans = CUBLAS_OP_N; - // make Pinv an identity matrix - auto diag = raft::make_device_vector(handle, dim); - raft::matrix::fill(handle, diag.view(), value_t(1)); - raft::matrix::fill(handle, Pinv, value_t(0)); - raft::matrix::set_diagonal( - handle, raft::make_device_vector_view(diag.data_handle(), dim), Pinv); + raft::matrix::eye(handle, Pinv); RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf_bufferSize( handle.get_cusolver_dn_handle(), dim, dim, P.data_handle(), lda, &Lwork)); @@ -183,15 +234,58 @@ void inverse(const raft::handle_t& handle, ASSERT(info_h == 0, "lobpcg: error in getrs, info=%d | expected=0", info_h); } +/** + * Helper function for converting a generalized eigenvalue problem + * A(X) = lambda(B(X)) to standard eigen value problem using cholesky + * transformation + */ +template +void eigh(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_matrix_view eigVecs, + raft::device_vector_view eigVals, + std::optional> B_opt = std::nullopt) +{ + if (B_opt.has_value()) + { + raft::linalg::eig_dc(handle, + raft::make_device_matrix_view(A.data_handle(), A.extent(0), A.extent(1)), + eigVecs, eigVals); + return; + } + auto dim = A.extent(0); + auto RTi = raft::make_device_matrix(handle, dim, dim); + auto Ri = raft::make_device_matrix(handle, dim, dim); + auto RT = raft::make_device_matrix(handle, dim, dim); + auto F = raft::make_device_matrix(handle, dim, dim); + auto B = B_opt.value(); + cholesky(handle, B, false); + + raft::linalg::transpose(handle, B, RT.view()); + inverse(handle, RT.view(), Ri.view()); + inverse(handle, B, RTi.view()); + + // Reuse the memory of matrix + auto& ARi = B; + auto& Fvecs = RT; + raft::linalg::gemm(handle, A, Ri.view(), ARi); + raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); + + raft::linalg::eig_dc(handle, + raft::make_device_matrix_view(F.data_handle(), F.extent(0), F.extent(1)), + Fvecs.view(), eigVals); + raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); +} + /** * B-orthonormalize the given block vector using Cholesky * * @tparam value_t floating point type used for elements * @tparam index_t integer type used for indexing * @param[in] handle: raft handle - * @param[in] B_opt: optional sparse matrix for normalization * @param[inout] V: dense matrix to normalize * @param[inout] BV: dense matrix. Use with parameter `bv_is_empty`. + * @param[in] B_opt: optional sparse matrix for normalization * @param[out] VBV_opt: optional dense matrix containing inverse matrix * @param[out] V_max_opt: optional vector containing normalization of V * @param[in] bv_is_empty: True if BV is used as input @@ -199,9 +293,9 @@ void inverse(const raft::handle_t& handle, template void b_orthonormalize( const raft::handle_t& handle, - std::optional> B_opt, raft::device_matrix_view V, raft::device_matrix_view BV, + std::optional> B_opt = std::nullopt, std::optional> VBV_opt = std::nullopt, std::optional> V_max_opt = std::nullopt, bool bv_is_empty = true) @@ -266,15 +360,8 @@ void b_orthonormalize( raft::make_device_matrix(handle, V.extent(1), V.extent(0)); raft::linalg::transpose(handle, V, VT.view()); - raft::linalg::gemm(handle, VT.view(), BV, VBVBuffer.view()); - cholesky(handle, VBVBuffer.view(), false); - // Reset VBV before copying upper triangular - raft::matrix::fill(handle, VBV, value_t(0)); - raft::matrix::upper_triangular( - handle, - raft::make_device_matrix_view( - VBVBuffer.data_handle(), VBVBuffer.extent(0), VBV.extent(1)), - VBV); + raft::linalg::gemm(handle, VT.view(), BV, VBV); + cholesky(handle, VBV, false); inverse(handle, VBV, VBVBuffer.view()); raft::copy(VBV.data_handle(), VBVBuffer.data_handle(), VBV.size(), stream); @@ -294,7 +381,7 @@ void lobpcg( std::optional> Y_opt, // Constraint // matrix shape=(n,Y) value_t tol = 0, - std::uint32_t max_iter = 20, + std::int32_t max_iter = 20, bool largest = true) { cudaStream_t stream = handle.get_stream(); @@ -303,11 +390,11 @@ void lobpcg( auto n = X.extent(0); auto size_x = X.extent(1); - /* DENSE SOLUTION + /* TODO: DENSE SOLUTION if ((n - size_y) < (5 * size_x)) { return; } */ - if (tol == 0) { tol = raft::mySqrt(1e-15) * n; } + if (tol <= 0) { tol = raft::mySqrt(1e-15) * n; } // Apply constraints to X /* auto matrix_BY = raft::make_device_matrix(handle, n, size_y); @@ -322,13 +409,57 @@ void lobpcg( raft::copy(matrix_BY.data_handle(), Y_opt.value().data_handle(), n * size_y, handle.get_stream()); } - cusparseDestroyDnMat(denseY); // GramYBY // ApplyConstraints }*/ auto BX = raft::make_device_matrix(handle, n, size_x); - b_orthonormalize(handle, B_opt, X, BX.view()); + b_orthonormalize(handle, X, BX.view(), B_opt); + // Compute the initial Ritz vectors: solve the eigenproblem. + auto AX = raft::make_device_matrix(handle, n, size_x); + spmm(handle, A, X, AX.view()); + auto gramXAX = raft::make_device_matrix(handle, size_x, size_x); + auto XT = raft::make_device_matrix(handle, size_x, n); + raft::linalg::transpose(handle, X, XT.view()); + raft::linalg::gemm(handle, XT.view(), AX.view(), gramXAX.view()); + auto eigVector = raft::make_device_matrix(handle, size_x, size_x); + auto eigLambda = raft::make_device_vector(handle, size_x); + eigh(handle, gramXAX.view(), eigVector.view(), eigLambda.view()); + truncEig(handle, eigVector.view(), eigLambda.view(), size_x, largest); + // Slice not needed for first eigh + // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, eigVectorFull.extent(0), size_x)); + + raft::linalg::gemm(handle, X, eigVector.view(), X); + raft::linalg::gemm(handle, AX.view(), eigVector.view(), AX.view()); + if (B_opt) raft::linalg::gemm(handle, BX.view(), eigVector.view(), BX.view()); + + // Active index set + auto mask = raft::make_device_vector(handle, size_x); + auto previousBlockSize = size_x; + + auto ident = raft::make_device_matrix(handle, size_x, size_x); + auto ident0 = raft::make_device_matrix(handle, size_x, size_x); + raft::matrix::eye(handle, ident.view()); + raft::matrix::eye(handle, ident0.view()); + + std::int32_t iteration_number = -1; return; // TODO } + +// Helper for b_orthonormalize optional arguments +template +void b_orthonormalize( + const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view BV, + b_opt_t&& B_opt = std::nullopt, + vbv_opt_t&& VBV_opt = std::nullopt, + v_max_opt_t&& V_max_opt = std::nullopt, + bool bv_is_empty = true) +{ + std::optional> b = std::forward(B_opt); + std::optional> vbv = std::forward(VBV_opt); + std::optional> v_max = std::forward(V_max_opt); + b_orthonormalize(handle, V, BV, b, vbv, v_max, bv_is_empty); +} }; // namespace raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/lobpcg.cuh b/cpp/include/raft/sparse/solver/lobpcg.cuh index 10fa95c6bc..d798a42597 100644 --- a/cpp/include/raft/sparse/solver/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/lobpcg.cuh @@ -33,7 +33,7 @@ void lobpcg( std::optional> Y = std::nullopt, // Constraint matrix shape=(n,Y) value_t tol = 0, - std::uint32_t max_iter = 20, + std::int32_t max_iter = 20, bool largest = true) { detail::lobpcg(handle, A, X, W, B, M, Y, tol, max_iter, largest); diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index 3b0be585a4..b238c5e5b7 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -70,8 +70,34 @@ class LOBPCGTest : public ::testing::TestWithParam> nnz_a = params.matrix_a.row_ind_ptr.size(); } + void test_b_orthonormalize() + { + idx_t n_rows_v = n_rows_a; + idx_t n_features_v = params.n_components; + raft::update_device(act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream); + auto v = raft::make_device_matrix_view( + act_eigvecs.data(), n_rows_v, n_features_v); + auto bv = raft::make_device_matrix(handle, n_rows_v, n_features_v); + auto vbv = raft::make_device_matrix(handle, n_features_v, n_features_v); + raft::sparse::solver::detail::b_orthonormalize(handle, + v, + bv.view(), + std::nullopt, + std::make_optional(vbv.view()), + std::nullopt, + true + ); + std::vector vbv_inv_expected{0.76298383, 0.0, -1.20276028, 1.0791533}; + std::vector vbv_inv_actual(4); + raft::copy(vbv_inv_actual.data(), vbv.data_handle(), vbv_inv_actual.size(), stream); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + ASSERT_TRUE(hostVecMatch(vbv_inv_expected, vbv_inv_actual, raft::CompareApprox(0.0001))); + } + void Run() { + test_b_orthonormalize(); raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows_a, stream); raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); raft::update_device(values_a.data(), params.matrix_a.values.data(), nnz_a, stream); @@ -92,9 +118,9 @@ class LOBPCGTest : public ::testing::TestWithParam> raft::copy(X_CPU.data(), act_eigvecs.data(), X_CPU.size(), stream); raft::copy(W_CPU.data(), act_eigvals.data(), W_CPU.size(), stream); ASSERT_TRUE(raft::devArrMatch( - exp_eigvecs.data(), act_eigvecs.data(), exp_eigvecs.size(), raft::Compare(), stream)); + exp_eigvecs.data(), act_eigvecs.data(), exp_eigvecs.size(), raft::CompareApprox(0.0001), stream)); ASSERT_TRUE(raft::devArrMatch( - exp_eigvals.data(), act_eigvals.data(), exp_eigvals.size(), raft::Compare(), stream)); + exp_eigvals.data(), act_eigvals.data(), exp_eigvals.size(), raft::CompareApprox(0.0001), stream)); } protected: From ccbd735778c00dd530f71ca2be3777108aaf0d17 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 2 Jan 2023 15:16:41 +0100 Subject: [PATCH 06/17] Init iterations --- .../raft/sparse/solver/detail/lobpcg.cuh | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index d15634d670..acc4324525 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -442,6 +442,43 @@ void lobpcg( raft::matrix::eye(handle, ident0.view()); std::int32_t iteration_number = -1; + while (iteration_number < max_iter + 1) + { + iteration_number += 1 + //auto lambda_matrix = raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0)); + auto aux = raft::make_device_matrix(handle, BX.extent(0), eigLambda.extent(0)); + if (B_opt) + { + raft::matrix::copy(handle, + raft::make_device_matrix_view(BX.data_handle(), BX.extent(0), BX.extent(1)), + aux.view()); + } + else + { + raft::matrix::copy(handle, + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), + aux.view()); + } + raft::linalg::binary_mult_skip_zero( + handle, aux.view(), + raft::make_device_vector_view(eigLambda.data_handle(), eigLambda.extent(0)), + Apply::ALONG_ROWS); + + auto R = raft::make_device_matrix(handle, n, size_x); + raft::linalg::substract(handle, AX.view(), aux.view(), R.view()); + + auto aux_sum = raft::make_device_vector(handle, size_x); + raft::linalg::reduce( // Could be done in-place in aux buffer + aux_sum.data_handle(), + R.data_handle(), size_x, n, value_t(0), + false, true, stream, false, + raft::L2Op()); + + auto residual_norms = raft::make_device_vector(handle, size_x); + raft::linalg::sqrt(handle, aux_sum, residual_norms); + + // cupy where & activemask + } return; // TODO } From e44a765d7ba1d7218455724e3b018acbdac1a6fc Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 23 Jan 2023 18:50:11 +0100 Subject: [PATCH 07/17] select_cols_if --- cpp/include/raft/matrix/detail/matrix.cuh | 10 +- cpp/include/raft/matrix/triangular.cuh | 22 +- .../raft/sparse/solver/detail/lobpcg.cuh | 289 ++++++++++++------ cpp/include/raft/sparse/solver/lobpcg.cuh | 6 +- cpp/test/matrix/matrix.cu | 1 - cpp/test/sparse/lobpcg.cu | 79 +++-- 6 files changed, 276 insertions(+), 131 deletions(-) diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 35ff6a3692..ae97243466 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -274,11 +274,11 @@ __global__ void createEyeKernel(m_t* matrix, idx_t n_rows, idx_t n_cols) template void createEye(m_t* matrix, idx_t n_rows, idx_t n_cols, cudaStream_t stream) { - idx_t m = n_rows, n = n_cols; - dim3 block(64); - dim3 grid((m * n + block.x - 1) / block.x); - createEyeKernel<<>>(matrix, n_rows, n_cols); - } + idx_t m = n_rows, n = n_cols; + dim3 block(64); + dim3 grid((m * n + block.x - 1) / block.x); + createEyeKernel<<>>(matrix, n_rows, n_cols); +} /** * @brief Copy a vector to the diagonal of a matrix diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index fec1bcfa40..792bb4cff4 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -50,17 +50,17 @@ void upper_triangular(const raft::handle_t& handle, * @param[in] src: input matrix with a size of n_rows x n_cols * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) */ - template - void lower_triangular(const raft::handle_t& handle, - raft::device_matrix_view src, - raft::device_matrix_view dst) - { - auto k = std::min(src.extent(0), src.extent(1)); - RAFT_EXPECTS(k == dst.extent(0) && k == dst.extent(1), - "dst should be of size kxk, k = min(n_rows, n_cols)"); - detail::copyLowerTriangular( - src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); - } +template +void lower_triangular(const raft::handle_t& handle, + raft::device_matrix_view src, + raft::device_matrix_view dst) +{ + auto k = std::min(src.extent(0), src.extent(1)); + RAFT_EXPECTS(k == dst.extent(0) && k == dst.extent(1), + "dst should be of size kxk, k = min(n_rows, n_cols)"); + detail::copyLowerTriangular( + src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); +} /** @} */ // end group matrix_triangular diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index acc4324525..d160758375 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -19,17 +19,23 @@ #include #include +#include +#include #include +#include +#include #include #include #include -#include #include +#include +#include +#include #include #include #include -#include +#include #include #include #include @@ -37,6 +43,7 @@ #include #include #include +#include namespace raft::sparse::solver::detail { @@ -50,23 +57,76 @@ struct MaxOp { template struct isnan_test { - HDA bool operator()(const DataT a) { return isnan(a); } + HDI int operator()(const DataT a) { return isnan(a); } }; +/* Modification of copyRows to reindex columns, col_major only + * On a 4x3 matrix, indices could be [0, 2] to select col 0 and 2 + */ +template +void selectCols(const m_t* in, + idx_t n_rows, + idx_t n_cols, + m_t* out, + const idx_array_t* indices, + idx_t n_cols_indices, + cudaStream_t stream) +{ + idx_t size = n_cols_indices * n_rows; + auto counting = thrust::make_counting_iterator(0); + + thrust::for_each(rmm::exec_policy(stream), counting, counting + size, [=] __device__(idx_t idx) { + idx_t row = idx % n_rows; + idx_t new_col = idx / n_rows; + idx_t old_col = indices[new_col]; + out[new_col * n_rows + row] = in[old_col * n_rows + row]; + }); +} + +template +void selectColsIf(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view mask, + raft::device_matrix_view out) +{ + auto stream = handle.get_stream(); + auto in_n_cols = in.extent(1); + auto out_n_cols = out.extent(1); + auto rangeVec = raft::make_device_vector(handle, in_n_cols); + raft::linalg::range(rangeVec.data_handle(), in_n_cols, stream); + raft::linalg::map( + handle, + raft::make_device_vector_view(mask.data_handle(), mask.extent(0)), + rangeVec.view(), + [] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; }, + rangeVec.view()); + thrust::sort(rmm::exec_policy(stream), + rangeVec.data_handle(), + rangeVec.data_handle() + rangeVec.size(), + thrust::less()); + selectCols(in.data_handle(), + in.extent(0), + in.extent(1), + out.data_handle(), + rangeVec.data_handle() + rangeVec.size() - out_n_cols, + out_n_cols, + stream); +} + template void truncEig(const raft::handle_t& handle, - raft::device_matrix_view eigVector, - raft::device_vector_view eigLambda, - index_t size_x, - bool largest) + raft::device_matrix_view eigVector, + raft::device_vector_view eigLambda, + index_t size_x, + bool largest) { // The eigenvalues are already sorted in ascending order with syevd - if (largest) - { + if (largest) { auto nrows = eigVector.extent(0); auto ncols = eigVector.extent(1); raft::matrix::col_reverse(handle, eigVector); - raft::matrix::col_reverse(handle, raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0))); + raft::matrix::col_reverse( + handle, raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0))); } } @@ -132,13 +192,15 @@ void cholesky(const raft::handle_t& handle, raft::device_matrix_view P, bool lower = true) { - auto stream = handle.get_stream(); - int Lwork = 0; - auto lda = P.extent(0); - auto dim = P.extent(0); - cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = handle.get_stream(); + int Lwork = 0; + auto lda = P.extent(0); + auto dim = P.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; - auto P_copy = raft::make_device_matrix(handle, P.extent(0), P.extent(1)); + auto P_copy = + raft::make_device_matrix(handle, P.extent(0), P.extent(1)); raft::copy(P_copy.data_handle(), P.data_handle(), P.size(), stream); RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf_bufferSize( @@ -160,20 +222,22 @@ void cholesky(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); ASSERT(info_h == 0, "lobpcg: error in potrf, info=%d | expected=0", info_h); - bool h_hasnan = thrust::reduce(P_copy.data_handle(), P_copy.data_handle() + P_copy.size(), isnan_test(), 0, thrust::plus()); - ASSERT(h_hasnan == 0, "lobpcg: error in cholesky, NaN in outputs", info_h); + int h_hasnan = thrust::transform_reduce(thrust_exec_policy, + P_copy.data_handle(), + P_copy.data_handle() + P_copy.size(), + isnan_test(), + 0, + thrust::plus()); + ASSERT(h_hasnan == 0, "lobpcg: error in cholesky, NaN in outputs"); raft::matrix::fill(handle, P, value_t(0)); - if (lower) - { + if (lower) { raft::matrix::lower_triangular( handle, raft::make_device_matrix_view( P_copy.data_handle(), P.extent(0), P.extent(1)), P); - } - else - { + } else { raft::matrix::upper_triangular( handle, raft::make_device_matrix_view( @@ -240,25 +304,27 @@ void inverse(const raft::handle_t& handle, * transformation */ template -void eigh(const raft::handle_t& handle, - raft::device_matrix_view A, - raft::device_matrix_view eigVecs, - raft::device_vector_view eigVals, - std::optional> B_opt = std::nullopt) +void eigh( + const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_matrix_view eigVecs, + raft::device_vector_view eigVals, + std::optional> B_opt = std::nullopt) { - if (B_opt.has_value()) - { - raft::linalg::eig_dc(handle, - raft::make_device_matrix_view(A.data_handle(), A.extent(0), A.extent(1)), - eigVecs, eigVals); + if (B_opt.has_value()) { + raft::linalg::eig_dc(handle, + raft::make_device_matrix_view( + A.data_handle(), A.extent(0), A.extent(1)), + eigVecs, + eigVals); return; } auto dim = A.extent(0); auto RTi = raft::make_device_matrix(handle, dim, dim); - auto Ri = raft::make_device_matrix(handle, dim, dim); - auto RT = raft::make_device_matrix(handle, dim, dim); - auto F = raft::make_device_matrix(handle, dim, dim); - auto B = B_opt.value(); + auto Ri = raft::make_device_matrix(handle, dim, dim); + auto RT = raft::make_device_matrix(handle, dim, dim); + auto F = raft::make_device_matrix(handle, dim, dim); + auto B = B_opt.value(); cholesky(handle, B, false); raft::linalg::transpose(handle, B, RT.view()); @@ -266,14 +332,16 @@ void eigh(const raft::handle_t& handle, inverse(handle, B, RTi.view()); // Reuse the memory of matrix - auto& ARi = B; + auto& ARi = B; auto& Fvecs = RT; raft::linalg::gemm(handle, A, Ri.view(), ARi); raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); - raft::linalg::eig_dc(handle, - raft::make_device_matrix_view(F.data_handle(), F.extent(0), F.extent(1)), - Fvecs.view(), eigVals); + raft::linalg::eig_dc(handle, + raft::make_device_matrix_view( + F.data_handle(), F.extent(0), F.extent(1)), + Fvecs.view(), + eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); } @@ -320,7 +388,7 @@ void b_orthonormalize( value_t(0), raft::linalg::Apply::ALONG_ROWS, false, - raft::Nop(), + raft::identity_op(), MaxOp()); */ raft::linalg::reduce(V_max.data_handle(), @@ -332,7 +400,7 @@ void b_orthonormalize( true, handle.get_stream(), false, - raft::Nop(), + raft::identity_op(), MaxOp()); raft::linalg::binary_div_skip_zero(handle, V, V_max_const, raft::linalg::Apply::ALONG_ROWS); @@ -380,9 +448,10 @@ void lobpcg( std::optional> M_opt, // shape=(n,n) std::optional> Y_opt, // Constraint // matrix shape=(n,Y) - value_t tol = 0, - std::int32_t max_iter = 20, - bool largest = true) + value_t tol = 0, + std::int32_t max_iter = 20, + bool largest = true, + int verbosityLevel = 0) { cudaStream_t stream = handle.get_stream(); // auto size_y = 0; @@ -417,86 +486,118 @@ void lobpcg( // Compute the initial Ritz vectors: solve the eigenproblem. auto AX = raft::make_device_matrix(handle, n, size_x); spmm(handle, A, X, AX.view()); - auto gramXAX = raft::make_device_matrix(handle, size_x, size_x); + auto gramXAX = + raft::make_device_matrix(handle, size_x, size_x); auto XT = raft::make_device_matrix(handle, size_x, n); raft::linalg::transpose(handle, X, XT.view()); raft::linalg::gemm(handle, XT.view(), AX.view(), gramXAX.view()); - auto eigVector = raft::make_device_matrix(handle, size_x, size_x); + auto eigVector = + raft::make_device_matrix(handle, size_x, size_x); auto eigLambda = raft::make_device_vector(handle, size_x); eigh(handle, gramXAX.view(), eigVector.view(), eigLambda.view()); truncEig(handle, eigVector.view(), eigLambda.view(), size_x, largest); // Slice not needed for first eigh - // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, eigVectorFull.extent(0), size_x)); + // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, + // eigVectorFull.extent(0), size_x)); raft::linalg::gemm(handle, X, eigVector.view(), X); raft::linalg::gemm(handle, AX.view(), eigVector.view(), AX.view()); if (B_opt) raft::linalg::gemm(handle, BX.view(), eigVector.view(), BX.view()); - + // Active index set - auto mask = raft::make_device_vector(handle, size_x); + // TODO: use uint8_t + auto active_mask = raft::make_device_vector(handle, size_x); auto previousBlockSize = size_x; - auto ident = raft::make_device_matrix(handle, size_x, size_x); + auto ident = raft::make_device_matrix(handle, size_x, size_x); auto ident0 = raft::make_device_matrix(handle, size_x, size_x); + // TODO: Maybe initialization here is not needed? raft::matrix::eye(handle, ident.view()); raft::matrix::eye(handle, ident0.view()); std::int32_t iteration_number = -1; - while (iteration_number < max_iter + 1) - { - iteration_number += 1 - //auto lambda_matrix = raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0)); - auto aux = raft::make_device_matrix(handle, BX.extent(0), eigLambda.extent(0)); - if (B_opt) - { + while (iteration_number < max_iter + 1) { + iteration_number += 1; + // auto lambda_matrix = raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0)); + auto aux = raft::make_device_matrix( + handle, BX.extent(0), eigLambda.extent(0)); + if (B_opt) { raft::matrix::copy(handle, - raft::make_device_matrix_view(BX.data_handle(), BX.extent(0), BX.extent(1)), - aux.view()); - } - else - { + raft::make_device_matrix_view( + BX.data_handle(), BX.extent(0), BX.extent(1)), + aux.view()); + } else { raft::matrix::copy(handle, - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), - aux.view()); + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), + aux.view()); } - raft::linalg::binary_mult_skip_zero( - handle, aux.view(), - raft::make_device_vector_view(eigLambda.data_handle(), eigLambda.extent(0)), - Apply::ALONG_ROWS); + raft::linalg::binary_mult_skip_zero(handle, + aux.view(), + raft::make_device_vector_view( + eigLambda.data_handle(), eigLambda.extent(0)), + raft::linalg::Apply::ALONG_ROWS); auto R = raft::make_device_matrix(handle, n, size_x); - raft::linalg::substract(handle, AX.view(), aux.view(), R.view()); + raft::linalg::subtract(handle, + raft::make_device_matrix_view( + AX.data_handle(), AX.extent(0), AX.extent(1)), + raft::make_device_matrix_view( + aux.data_handle(), aux.extent(0), aux.extent(1)), + R.view()); auto aux_sum = raft::make_device_vector(handle, size_x); - raft::linalg::reduce( // Could be done in-place in aux buffer + raft::linalg::reduce( // Could be done in-place in aux buffer aux_sum.data_handle(), - R.data_handle(), size_x, n, value_t(0), - false, true, stream, false, - raft::L2Op()); + R.data_handle(), + size_x, + n, + value_t(0), + false, + true, + stream, + false, + raft::sq_op()); auto residual_norms = raft::make_device_vector(handle, size_x); - raft::linalg::sqrt(handle, aux_sum, residual_norms); + raft::linalg::sqrt( + handle, + raft::make_device_vector_view(aux_sum.data_handle(), aux_sum.size()), + residual_norms.view()); + + // cupy where & active_mask + raft::linalg::unary_op(handle, + raft::make_device_vector_view( + residual_norms.data_handle(), residual_norms.size()), + active_mask.view(), + [tol] __device__(value_t rn) { return rn > tol; }); + if (verbosityLevel > 2) { + print_device_vector("active_mask", active_mask.data_handle(), active_mask.size(), std::cout); + } + index_t currentBlockSize = thrust::count(thrust::cuda::par.on(stream), + active_mask.data_handle(), + active_mask.data_handle() + active_mask.size(), + 0); + auto identView = raft::make_device_matrix_view( + ident.data_handle(), previousBlockSize, previousBlockSize); + if (currentBlockSize != previousBlockSize) { + previousBlockSize = currentBlockSize; + identView = raft::make_device_matrix_view( + ident.data_handle(), currentBlockSize, currentBlockSize); + raft::matrix::eye(handle, identView); + } - // cupy where & activemask + if (currentBlockSize == 0) break; + if (verbosityLevel > 0) { + // TODO add verb + } + auto activeblockR = + raft::make_device_matrix(handle, R.extent(0), currentBlockSize); + + selectColsIf(handle, R.view(), active_mask.view(), activeblockR.view()); } return; // TODO } - -// Helper for b_orthonormalize optional arguments -template -void b_orthonormalize( - const raft::handle_t& handle, - raft::device_matrix_view V, - raft::device_matrix_view BV, - b_opt_t&& B_opt = std::nullopt, - vbv_opt_t&& VBV_opt = std::nullopt, - v_max_opt_t&& V_max_opt = std::nullopt, - bool bv_is_empty = true) -{ - std::optional> b = std::forward(B_opt); - std::optional> vbv = std::forward(VBV_opt); - std::optional> v_max = std::forward(V_max_opt); - b_orthonormalize(handle, V, BV, b, vbv, v_max, bv_is_empty); -} }; // namespace raft::sparse::solver::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/lobpcg.cuh b/cpp/include/raft/sparse/solver/lobpcg.cuh index d798a42597..36aa909012 100644 --- a/cpp/include/raft/sparse/solver/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/lobpcg.cuh @@ -32,9 +32,9 @@ void lobpcg( std::nullopt, // shape=(n,n) std::optional> Y = std::nullopt, // Constraint matrix shape=(n,Y) - value_t tol = 0, - std::int32_t max_iter = 20, - bool largest = true) + value_t tol = 0, + std::int32_t max_iter = 20, + bool largest = true) { detail::lobpcg(handle, A, X, W, B, M, Y, tol, max_iter, largest); } diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index d87446c163..798f591ce6 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -58,7 +58,6 @@ class MatrixTest : public ::testing::TestWithParam> { } protected: - void test_eye() { auto eyemat = raft::make_device_matrix(handle, 4, 5); diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index b238c5e5b7..81d3d0d30b 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -22,6 +22,7 @@ #include #include +#include "../test_utils.cuh" #include "../test_utils.h" #include @@ -47,6 +48,29 @@ struct LOBPCGInputs { idx_t n_components; }; +// Helper for b_orthonormalize optional arguments +template +void b_orthonormalize(const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view BV, + b_opt_t&& B_opt = std::nullopt, + vbv_opt_t&& VBV_opt = std::nullopt, + v_max_opt_t&& V_max_opt = std::nullopt, + bool bv_is_empty = true) +{ + std::optional> b = + std::forward(B_opt); + std::optional> vbv = + std::forward(VBV_opt); + std::optional> v_max = + std::forward(V_max_opt); + raft::sparse::solver::detail::b_orthonormalize(handle, V, BV, b, vbv, v_max, bv_is_empty); +} + template class LOBPCGTest : public ::testing::TestWithParam> { public: @@ -70,33 +94,48 @@ class LOBPCGTest : public ::testing::TestWithParam> nnz_a = params.matrix_a.row_ind_ptr.size(); } + void test_selectcolsif() + { + auto a = raft::make_device_matrix(handle, 5, 8); + auto c = raft::make_device_matrix(handle, 5, 4); + auto b = raft::make_device_vector(handle, 8); + raft::linalg::range(a.data_handle(), a.size(), handle.get_stream()); + std::vector select_h{0, 1, 1, 1, 0, 0, 0, 1}; + raft::copy(b.data_handle(), select_h.data(), 8, handle.get_stream()); + raft::sparse::solver::detail::selectColsIf(handle, a.view(), b.view(), c.view()); + std::vector res(c.size()); + std::vector expected{5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39}; + raft::copy(res.data(), c.data_handle(), c.size(), handle.get_stream()); + + ASSERT_TRUE(hostVecMatch(expected, res, raft::CompareApprox(0.0001))); + } + void test_b_orthonormalize() { - idx_t n_rows_v = n_rows_a; + idx_t n_rows_v = n_rows_a; idx_t n_features_v = params.n_components; raft::update_device(act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream); auto v = raft::make_device_matrix_view( act_eigvecs.data(), n_rows_v, n_features_v); - auto bv = raft::make_device_matrix(handle, n_rows_v, n_features_v); - auto vbv = raft::make_device_matrix(handle, n_features_v, n_features_v); - raft::sparse::solver::detail::b_orthonormalize(handle, - v, - bv.view(), - std::nullopt, - std::make_optional(vbv.view()), - std::nullopt, - true - ); + auto bv = + raft::make_device_matrix(handle, n_rows_v, n_features_v); + auto vbv = + raft::make_device_matrix(handle, n_features_v, n_features_v); + b_orthonormalize( + handle, v, bv.view(), std::nullopt, std::make_optional(vbv.view()), std::nullopt, true); std::vector vbv_inv_expected{0.76298383, 0.0, -1.20276028, 1.0791533}; std::vector vbv_inv_actual(4); raft::copy(vbv_inv_actual.data(), vbv.data_handle(), vbv_inv_actual.size(), stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - ASSERT_TRUE(hostVecMatch(vbv_inv_expected, vbv_inv_actual, raft::CompareApprox(0.0001))); + ASSERT_TRUE( + hostVecMatch(vbv_inv_expected, vbv_inv_actual, raft::CompareApprox(0.0001))); } void Run() { + test_selectcolsif(); test_b_orthonormalize(); raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows_a, stream); raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); @@ -117,10 +156,16 @@ class LOBPCGTest : public ::testing::TestWithParam> std::vector W_CPU(n_rows_a); raft::copy(X_CPU.data(), act_eigvecs.data(), X_CPU.size(), stream); raft::copy(W_CPU.data(), act_eigvals.data(), W_CPU.size(), stream); - ASSERT_TRUE(raft::devArrMatch( - exp_eigvecs.data(), act_eigvecs.data(), exp_eigvecs.size(), raft::CompareApprox(0.0001), stream)); - ASSERT_TRUE(raft::devArrMatch( - exp_eigvals.data(), act_eigvals.data(), exp_eigvals.size(), raft::CompareApprox(0.0001), stream)); + ASSERT_TRUE(raft::devArrMatch(exp_eigvecs.data(), + act_eigvecs.data(), + exp_eigvecs.size(), + raft::CompareApprox(0.0001), + stream)); + ASSERT_TRUE(raft::devArrMatch(exp_eigvals.data(), + act_eigvals.data(), + exp_eigvals.size(), + raft::CompareApprox(0.0001), + stream)); } protected: @@ -227,7 +272,7 @@ a.data = array([0.37911922, 0.11567201, 0.5135106 , 0.08968836, 0.73450965, 0.32938903, 0.82477561, 0.20858375, 0.24755519, 0.23677223, 0.73957246, 0.09050876, 0.86530489]) -a.todense() = +a.todense() = np.matrix([[0.37911922, 0. , 0.11567201, 0.5135106 , 0. , 0.08968836], [0.73450965, 0.26432646, 0.21985123, 0.74888277, 0.34753734, 0.11204864], [0.82902676, 0. , 0.53023521, 0.24047095, 0. , 0.37913592], From c9224d51cdf380f6df1d85264352e06df32eadb4 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 26 Jan 2023 17:29:07 +0100 Subject: [PATCH 08/17] Add cho_solve --- cpp/include/raft/core/mdspan.hpp | 20 ++ cpp/include/raft/matrix/triangular.cuh | 2 +- .../raft/sparse/solver/detail/lobpcg.cuh | 199 +++++++++++++++--- 3 files changed, 188 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 786ce69f89..f3e0aa23e2 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -304,4 +304,24 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +/** + * @brief Create a copy of the given mdspan with const element type + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy + */ +template > +auto make_const_mdspan(mdspan_type mds) +{ + using const_element_t = std::add_const_t; + using const_accessor_t = + host_device_accessor, + mdspan_type::accessor_type::mem_type>; + return std::experimental::mdspan(mds); +} + } // namespace raft diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index 792bb4cff4..f82b10a3bb 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -45,7 +45,7 @@ void upper_triangular(const raft::handle_t& handle, } /** - * @brief Copy the Lower triangular part of a matrix to another + * @brief Copy the lower triangular part of a matrix to another * @param[in] handle: raft handle * @param[in] src: input matrix with a size of n_rows x n_cols * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index d160758375..8029cf5b5a 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -96,7 +96,7 @@ void selectColsIf(const raft::handle_t& handle, raft::linalg::range(rangeVec.data_handle(), in_n_cols, stream); raft::linalg::map( handle, - raft::make_device_vector_view(mask.data_handle(), mask.extent(0)), + raft::make_const_mdspan(mask), rangeVec.view(), [] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; }, rangeVec.view()); @@ -187,8 +187,37 @@ void spmm(const raft::handle_t& handle, cusparseDestroyDnMat(dense_C); } +/** + * Solve the linear equation A x = b, given the Cholesky factorization of A + * The operation is in-place, i.e. matrix X overwrites matrix B. + */ template -void cholesky(const raft::handle_t& handle, +void cho_solve(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_matrix_view B, + bool lower = true) +{ + auto thrust_exec_policy = handle.get_thrust_policy(); + auto stream = handle.get_stream(); + auto lda = A.extent(0); + auto dim = A.extent(0); + cublasFillMode_t uplo = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + + rmm::device_uvector info(1, stream); + RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrs(handle.get_cusolver_dn_handle(), + uplo, + dim, + B.extent(1), + A.data_handle(), + lda, + B.data_handle(), + dim, + info.data(), + stream)); +} + +template +int cholesky(const raft::handle_t& handle, raft::device_matrix_view P, bool lower = true) { @@ -228,22 +257,23 @@ void cholesky(const raft::handle_t& handle, isnan_test(), 0, thrust::plus()); - ASSERT(h_hasnan == 0, "lobpcg: error in cholesky, NaN in outputs"); + + if (h_hasnan != 0) // "lobpcg: error in cholesky, NaN in outputs" + return 1; raft::matrix::fill(handle, P, value_t(0)); if (lower) { raft::matrix::lower_triangular( handle, - raft::make_device_matrix_view( - P_copy.data_handle(), P.extent(0), P.extent(1)), + raft::make_const_mdspan(P_copy.view()), P); } else { raft::matrix::upper_triangular( handle, - raft::make_device_matrix_view( - P_copy.data_handle(), P.extent(0), P.extent(1)), + raft::make_const_mdspan(P_copy.view()), P); } + return 0; } template @@ -298,6 +328,31 @@ void inverse(const raft::handle_t& handle, ASSERT(info_h == 0, "lobpcg: error in getrs, info=%d | expected=0", info_h); } +template +void apply_constraints(const raft::handle_t& handle, + raft::device_matrix_view V, + raft::device_matrix_view YBY, + raft::device_matrix_view BY, + raft::device_matrix_view Y) +{ + auto stream = handle.get_stream(); + auto YBY_copy = raft::make_device_matrix(handle, YBY.extent(0), YBY.extent(1)); + raft::copy(YBY_copy.data_handle(), YBY.data_handle(), YBY.size(), stream); + // TODO: Using raw-pointer because no gemm function with mdspan accept transpose op + auto YBV = raft::make_device_matrix(handle, BY.extent(1), V.extent(1)); + value_t zero = 0; + value_t one = 1; + raft::linalg::gemm(handle, + true, false, YBV.extent(0), YBV.extent(1), BY.extent(0), &one, BY.data_handle(), + BY.extent(0), V.data_handle(), V.extent(0), &zero, YBV.data_handle(), YBV.extent(0), stream); + + cholesky(handle, YBY_copy.view()); + cho_solve(handle, raft::make_const_mdspan(YBY_copy.view()), YBV.view()); + auto BV = raft::make_device_matrix(handle, V.extent(0), YBV.extent(1)); + raft::linalg::gemm(handle, Y, YBV.view(), BV.view()); + raft::linalg::subtract(handle, raft::make_const_mdspan(Y), raft::make_const_mdspan(BV.view()), Y); +} + /** * Helper function for converting a generalized eigenvalue problem * A(X) = lambda(B(X)) to standard eigen value problem using cholesky @@ -313,8 +368,7 @@ void eigh( { if (B_opt.has_value()) { raft::linalg::eig_dc(handle, - raft::make_device_matrix_view( - A.data_handle(), A.extent(0), A.extent(1)), + raft::make_const_mdspan(A), eigVecs, eigVals); return; @@ -338,8 +392,7 @@ void eigh( raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); raft::linalg::eig_dc(handle, - raft::make_device_matrix_view( - F.data_handle(), F.extent(0), F.extent(1)), + raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); @@ -354,12 +407,12 @@ void eigh( * @param[inout] V: dense matrix to normalize * @param[inout] BV: dense matrix. Use with parameter `bv_is_empty`. * @param[in] B_opt: optional sparse matrix for normalization - * @param[out] VBV_opt: optional dense matrix containing inverse matrix - * @param[out] V_max_opt: optional vector containing normalization of V + * @param[out] VBV_opt: optional dense matrix containing inverse matrix (shape v[1] * v[1]) + * @param[out] V_max_opt: optional vector containing normalization of V (shape v[1]) * @param[in] bv_is_empty: True if BV is used as input */ template -void b_orthonormalize( +int b_orthonormalize( const raft::handle_t& handle, raft::device_matrix_view V, raft::device_matrix_view BV, @@ -378,7 +431,7 @@ void b_orthonormalize( V_max_ptr = V_max_opt.value().data_handle(); } auto V_max = raft::make_device_vector_view(V_max_ptr, V.extent(1)); - auto V_max_const = raft::make_device_vector_view(V_max_ptr, V.extent(1)); + auto V_max_const = raft::make_const_mdspan(V_max); // /*raft::linalg::reduce(handle, @@ -429,12 +482,14 @@ void b_orthonormalize( raft::linalg::transpose(handle, V, VT.view()); raft::linalg::gemm(handle, VT.view(), BV, VBV); - cholesky(handle, VBV, false); + int result_status = cholesky(handle, VBV, false); + if (result_status != 0) { return result_status; } inverse(handle, VBV, VBVBuffer.view()); raft::copy(VBV.data_handle(), VBVBuffer.data_handle(), VBV.size(), stream); raft::linalg::gemm(handle, V, VBV, V); if (B_opt) raft::linalg::gemm(handle, BV, VBV, BV); + return 0; } template @@ -511,11 +566,20 @@ void lobpcg( auto ident = raft::make_device_matrix(handle, size_x, size_x); auto ident0 = raft::make_device_matrix(handle, size_x, size_x); - // TODO: Maybe initialization here is not needed? + // TODO: Maybe initialization of ident here is not needed? raft::matrix::eye(handle, ident.view()); raft::matrix::eye(handle, ident0.view()); + auto Pbuffer = rmm::device_uvector(0, stream); + auto APbuffer = rmm::device_uvector(0, stream); + auto BPbuffer = rmm::device_uvector(0, stream); + auto activePView = raft::make_device_matrix_view(Pbuffer.data(), 0, 0); + auto activeAPView = raft::make_device_matrix_view(APbuffer.data(), 0, 0); + auto activeBPView = raft::make_device_matrix_view(BPbuffer.data(), 0, 0); + std::int32_t iteration_number = -1; + bool restart = true; + //bool explicitGramFlag = false; while (iteration_number < max_iter + 1) { iteration_number += 1; // auto lambda_matrix = raft::make_device_matrix_view( - BX.data_handle(), BX.extent(0), BX.extent(1)), + raft::make_const_mdspan(BX.view()), aux.view()); } else { raft::matrix::copy(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), + raft::make_const_mdspan(X), aux.view()); } raft::linalg::binary_mult_skip_zero(handle, aux.view(), - raft::make_device_vector_view( - eigLambda.data_handle(), eigLambda.extent(0)), + raft::make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS); auto R = raft::make_device_matrix(handle, n, size_x); raft::linalg::subtract(handle, - raft::make_device_matrix_view( - AX.data_handle(), AX.extent(0), AX.extent(1)), - raft::make_device_matrix_view( - aux.data_handle(), aux.extent(0), aux.extent(1)), + raft::make_const_mdspan(AX.view()), + raft::make_const_mdspan(aux.view()), R.view()); auto aux_sum = raft::make_device_vector(handle, size_x); @@ -563,13 +622,12 @@ void lobpcg( auto residual_norms = raft::make_device_vector(handle, size_x); raft::linalg::sqrt( handle, - raft::make_device_vector_view(aux_sum.data_handle(), aux_sum.size()), + raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); // cupy where & active_mask raft::linalg::unary_op(handle, - raft::make_device_vector_view( - residual_norms.data_handle(), residual_norms.size()), + raft::make_const_mdspan(residual_norms.view()), active_mask.view(), [tol] __device__(value_t rn) { return rn > tol; }); if (verbosityLevel > 2) { @@ -592,10 +650,87 @@ void lobpcg( if (verbosityLevel > 0) { // TODO add verb } - auto activeblockR = + auto activeBlockVectorR = raft::make_device_matrix(handle, R.extent(0), currentBlockSize); - selectColsIf(handle, R.view(), active_mask.view(), activeblockR.view()); + selectColsIf(handle, R.view(), active_mask.view(), activeBlockVectorR.view()); + + if (iteration_number > 0) + { + activePView = raft::make_device_matrix_view(Pbuffer.data(), n, currentBlockSize); + activeAPView = raft::make_device_matrix_view(APbuffer.data(), n, currentBlockSize); + if (B_opt.has_value()) { + activeBPView = raft::make_device_matrix_view(BPbuffer.data(), n, currentBlockSize); + } + } + if (M_opt.has_value()) + { + // Apply preconditioner T to the active residuals. + auto MRtemp = + raft::make_device_matrix(handle, R.extent(0), currentBlockSize); + spmm(handle, M_opt.value(), activeBlockVectorR.view(), MRtemp.view()); + raft::copy(activeBlockVectorR.data_handle(), MRtemp.data_handle(), MRtemp.size(), stream); + } + // Apply constraints to the preconditioned residuals. + if (Y_opt.has_value()) + { + // TODO Constraint + //apply_constraints(handle, X, gramYBY.view(), BY.view(), Y_opt.value()); + } + // B-orthogonalize the preconditioned residuals to X. + if (B_opt.has_value()) + { + auto BXT = raft::make_device_matrix(handle, BX.extent(1), BX.extent(0)); + auto BXTR = raft::make_device_matrix(handle, BXT.extent(0), activeBlockVectorR.extent(1)); + auto XBXTR = raft::make_device_matrix(handle, X.extent(0), BXTR.extent(1)); + + raft::linalg::transpose(handle, BX.view(), BXT.view()); + raft::linalg::gemm(handle, BXT.view(), activeBlockVectorR.view(), BXTR.view()); + raft::linalg::gemm(handle, X, BXTR.view(), XBXTR.view()); + raft::linalg::subtract(handle, + raft::make_const_mdspan(activeBlockVectorR.view()), + raft::make_const_mdspan(XBXTR.view()), + activeBlockVectorR.view()); + } else { + + auto XTR = raft::make_device_matrix(handle, XT.extent(0), activeBlockVectorR.extent(1)); + auto XXTR = raft::make_device_matrix(handle, X.extent(0), XTR.extent(1)); + + raft::linalg::gemm(handle, XT.view(), activeBlockVectorR.view(), XTR.view()); + raft::linalg::gemm(handle, X, XTR.view(), XXTR.view()); + raft::linalg::subtract(handle, + raft::make_const_mdspan(activeBlockVectorR.view()), + raft::make_const_mdspan(XXTR.view()), + activeBlockVectorR.view()); + } + // B-orthonormalize the preconditioned residuals. + auto BR = raft::make_device_matrix(handle, activeBlockVectorR.extent(0), activeBlockVectorR.extent(1)); + b_orthonormalize(handle, activeBlockVectorR.view(), BR.view(), B_opt); + + auto AR = raft::make_device_matrix(handle, n, activeBlockVectorR.extent(1)); + spmm(handle, A, activeBlockVectorR.view(), AR.view()); + + if (iteration_number > 0) + { + auto invR = raft::make_device_matrix(handle, activePView.extent(1), activePView.extent(1)); + auto normal = raft::make_device_vector(handle, activePView.extent(1)); + int b_orth_status = 0; + if (B_opt.has_value()) + { + b_orth_status = b_orthonormalize(handle, activePView, activeBPView, B_opt, std::make_optional(invR.view()), std::make_optional(normal.view()), false); + } else + { + auto BP = raft::make_device_matrix(handle, activePView.extent(0), activePView.extent(1)); + b_orth_status = b_orthonormalize(handle, activePView, BP.view(), B_opt, std::make_optional(invR.view()), std::make_optional(normal.view())); + } + if (b_orth_status != 0) { + restart = true; + } else { + raft::linalg::binary_div_skip_zero(handle, activeAPView, raft::make_const_mdspan(normal.view()), raft::linalg::Apply::ALONG_ROWS); + raft::linalg::gemm(handle, activeAPView, invR.view(), activeAPView); + restart = false; + } + } } return; // TODO From a30abc02ec50c1301d7aa4b7413c4ef5a84924f7 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 1 Mar 2023 11:18:25 +0100 Subject: [PATCH 09/17] Gram matrices computation --- .../raft/sparse/solver/detail/lobpcg.cuh | 140 ++++++++++++++---- 1 file changed, 114 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 8029cf5b5a..7f1a2ac3fc 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -508,7 +508,8 @@ void lobpcg( bool largest = true, int verbosityLevel = 0) { - cudaStream_t stream = handle.get_stream(); + cudaStream_t stream = handle.get_stream(); + auto thrust_exec_policy = handle.get_thrust_policy(); // auto size_y = 0; // if (Y_opt.has_value()) size_y = Y_opt.value().extent(1); auto n = X.extent(0); @@ -537,7 +538,8 @@ void lobpcg( // ApplyConstraints }*/ auto BX = raft::make_device_matrix(handle, n, size_x); - b_orthonormalize(handle, X, BX.view(), B_opt); + auto BXView = BX.view(); + b_orthonormalize(handle, X, BXView, B_opt); // Compute the initial Ritz vectors: solve the eigenproblem. auto AX = raft::make_device_matrix(handle, n, size_x); spmm(handle, A, X, AX.view()); @@ -557,7 +559,7 @@ void lobpcg( raft::linalg::gemm(handle, X, eigVector.view(), X); raft::linalg::gemm(handle, AX.view(), eigVector.view(), AX.view()); - if (B_opt) raft::linalg::gemm(handle, BX.view(), eigVector.view(), BX.view()); + if (B_opt) raft::linalg::gemm(handle, BXView, eigVector.view(), BXView); // Active index set // TODO: use uint8_t @@ -579,7 +581,7 @@ void lobpcg( std::int32_t iteration_number = -1; bool restart = true; - //bool explicitGramFlag = false; + bool explicitGramFlag = false; while (iteration_number < max_iter + 1) { iteration_number += 1; // auto lambda_matrix = raft::make_device_matrix_view 0) { // TODO add verb } - auto activeBlockVectorR = + auto activeR = raft::make_device_matrix(handle, R.extent(0), currentBlockSize); - selectColsIf(handle, R.view(), active_mask.view(), activeBlockVectorR.view()); + selectColsIf(handle, R.view(), active_mask.view(), activeR.view()); if (iteration_number > 0) { + /* TODO + Pbuffer.resize(n * currentBlockSize, stream); + APbuffer.resize(n * currentBlockSize, stream); + BPbuffer.resize(n * currentBlockSize, stream); activePView = raft::make_device_matrix_view(Pbuffer.data(), n, currentBlockSize); activeAPView = raft::make_device_matrix_view(APbuffer.data(), n, currentBlockSize); + selectColsIf(handle, R.view(), active_mask.view(), activePView); + selectColsIf(handle, AP.view(), active_mask.view(), activeAPView); if (B_opt.has_value()) { activeBPView = raft::make_device_matrix_view(BPbuffer.data(), n, currentBlockSize); - } + selectColsIf(handle, BP.view(), active_mask.view(), activeBPView); + }*/ } if (M_opt.has_value()) { // Apply preconditioner T to the active residuals. auto MRtemp = raft::make_device_matrix(handle, R.extent(0), currentBlockSize); - spmm(handle, M_opt.value(), activeBlockVectorR.view(), MRtemp.view()); - raft::copy(activeBlockVectorR.data_handle(), MRtemp.data_handle(), MRtemp.size(), stream); + spmm(handle, M_opt.value(), activeR.view(), MRtemp.view()); + raft::copy(activeR.data_handle(), MRtemp.data_handle(), MRtemp.size(), stream); } // Apply constraints to the preconditioned residuals. if (Y_opt.has_value()) @@ -681,34 +690,35 @@ void lobpcg( if (B_opt.has_value()) { auto BXT = raft::make_device_matrix(handle, BX.extent(1), BX.extent(0)); - auto BXTR = raft::make_device_matrix(handle, BXT.extent(0), activeBlockVectorR.extent(1)); + auto BXTR = raft::make_device_matrix(handle, BXT.extent(0), activeR.extent(1)); auto XBXTR = raft::make_device_matrix(handle, X.extent(0), BXTR.extent(1)); raft::linalg::transpose(handle, BX.view(), BXT.view()); - raft::linalg::gemm(handle, BXT.view(), activeBlockVectorR.view(), BXTR.view()); + raft::linalg::gemm(handle, BXT.view(), activeR.view(), BXTR.view()); raft::linalg::gemm(handle, X, BXTR.view(), XBXTR.view()); raft::linalg::subtract(handle, - raft::make_const_mdspan(activeBlockVectorR.view()), + raft::make_const_mdspan(activeR.view()), raft::make_const_mdspan(XBXTR.view()), - activeBlockVectorR.view()); + activeR.view()); } else { - auto XTR = raft::make_device_matrix(handle, XT.extent(0), activeBlockVectorR.extent(1)); + auto XTR = raft::make_device_matrix(handle, XT.extent(0), activeR.extent(1)); auto XXTR = raft::make_device_matrix(handle, X.extent(0), XTR.extent(1)); - raft::linalg::gemm(handle, XT.view(), activeBlockVectorR.view(), XTR.view()); + raft::linalg::gemm(handle, XT.view(), activeR.view(), XTR.view()); raft::linalg::gemm(handle, X, XTR.view(), XXTR.view()); raft::linalg::subtract(handle, - raft::make_const_mdspan(activeBlockVectorR.view()), + raft::make_const_mdspan(activeR.view()), raft::make_const_mdspan(XXTR.view()), - activeBlockVectorR.view()); + activeR.view()); } // B-orthonormalize the preconditioned residuals. - auto BR = raft::make_device_matrix(handle, activeBlockVectorR.extent(0), activeBlockVectorR.extent(1)); - b_orthonormalize(handle, activeBlockVectorR.view(), BR.view(), B_opt); + auto BR = raft::make_device_matrix(handle, activeR.extent(0), activeR.extent(1)); + auto BRView = BR.view(); + b_orthonormalize(handle, activeR.view(), BRView, B_opt); - auto AR = raft::make_device_matrix(handle, n, activeBlockVectorR.extent(1)); - spmm(handle, A, activeBlockVectorR.view(), AR.view()); + auto activeAR = raft::make_device_matrix(handle, n, activeR.extent(1)); + spmm(handle, A, activeR.view(), activeAR.view()); if (iteration_number > 0) { @@ -716,13 +726,13 @@ void lobpcg( auto normal = raft::make_device_vector(handle, activePView.extent(1)); int b_orth_status = 0; if (B_opt.has_value()) - { - b_orth_status = b_orthonormalize(handle, activePView, activeBPView, B_opt, std::make_optional(invR.view()), std::make_optional(normal.view()), false); - } else { auto BP = raft::make_device_matrix(handle, activePView.extent(0), activePView.extent(1)); b_orth_status = b_orthonormalize(handle, activePView, BP.view(), B_opt, std::make_optional(invR.view()), std::make_optional(normal.view())); - } + } else + { + b_orth_status = b_orthonormalize(handle, activePView, activeBPView, B_opt, std::make_optional(invR.view()), std::make_optional(normal.view()), false); + } if (b_orth_status != 0) { restart = true; } else { @@ -730,6 +740,84 @@ void lobpcg( raft::linalg::gemm(handle, activeAPView, invR.view(), activeAPView); restart = false; } + + // Perform the Rayleigh Ritz Procedure: + // Compute symmetric Gram matrices: + value_t myeps = 1; // TODO: std::is_same_t ? 1e-4 : 1e-8; + if (!explicitGramFlag) + { + value_t* residual_norms_max_elem = thrust::max_element(thrust_exec_policy, + residual_norms.data_handle(), + residual_norms.data_handle() + residual_norms.size()); + value_t residual_norms_max = 0; + raft::copy(&residual_norms_max, residual_norms_max_elem, 1, stream); + explicitGramFlag = residual_norms_max <= myeps; + } + + if (!B_opt.has_value()) + { + // Shared memory assignments to simplify the code + BXView = raft::make_device_matrix_view(X.data_handle(), n, currentBlockSize); + BRView = raft::make_device_matrix_view(activeR.data_handle(), n, currentBlockSize); + //if (!restart) TODO + //activeBPView = raft::make_device_matrix_view(P.data_handle(), n, currentBlockSize); + } + } + // Common submatrices + auto gramXAR = raft::make_device_matrix(handle, X.extent(1), activeAR.extent(1)); + auto gramRAR = raft::make_device_matrix(handle, R.extent(1), activeAR.extent(1)); + auto gramXBX = raft::make_device_matrix(handle, X.extent(1), BX.extent(1)); + auto gramRBR = raft::make_device_matrix(handle, R.extent(1), BRView.extent(1)); + auto gramXBR = raft::make_device_matrix(handle, X.extent(1), BRView.extent(1)); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + activeAR.view(), gramXAR.view()); + + raft::linalg::gemm(handle, + raft::make_device_matrix_view(activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + activeAR.view(), gramRAR.view()); + + if (explicitGramFlag) + { + auto device_half = raft::make_device_scalar(handle, 0.5); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(gramRAR.data_handle(), gramRAR.extent(0), gramRAR.extent(1)), // transpose for gemm + identView, + gramRAR.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + AX.view(), gramXAX.view()); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(gramXAX.data_handle(), gramXAX.extent(0), gramXAX.extent(1)), // transpose for gemm + identView, + gramXAX.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + + raft::linalg::gemm(handle, + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + BX.view(), gramXBX.view()); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + BRView, gramRBR.view()); + raft::linalg::gemm(handle, + raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + BRView, gramXBR.view()); + } + else + { + raft::matrix::fill(handle, gramXAX.view(), value_t(0)); + raft::matrix::set_diagonal(handle, make_const_mdspan(eigLambda.view()), gramXAX.view()); + + raft::matrix::eye(handle, gramXBX.view()); + raft::matrix::eye(handle, gramRBR.view()); + raft::matrix::fill(handle, gramXBR.view(), value_t(0)); + } + if (!restart) + { + //raft::linalg::gemm(handle) } } return; From 9094650f444783a69aee92a962adca132c7eff8c Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 1 Mar 2023 19:15:55 +0100 Subject: [PATCH 10/17] Add block construction function --- cpp/include/raft/matrix/detail/matrix.cuh | 41 ++ cpp/include/raft/matrix/slice.cuh | 34 ++ .../raft/sparse/solver/detail/lobpcg.cuh | 426 ++++++++++++------ cpp/test/sparse/lobpcg.cu | 41 ++ 4 files changed, 406 insertions(+), 136 deletions(-) diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 7be3f3b017..749acab573 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -197,6 +197,47 @@ void sliceMatrix(const m_t* in, slice<<>>(in, n_rows, n_cols, out, x1, y1, x2, y2); } +/** + * @brief Kernel for copying a small matrix inside of a bigger matrix with a + * size matches that slice + * @param src_d: input matrix + * @param m: number of rows of input matrix + * @param n: number of columns of input matrix + * @param dst_d: output matrix + * @param x1, y1: coordinate of the top-left point of the wanted area (0-based) + * @param x2, y2: coordinate of the bottom-right point of the wanted area + * (1-based) + */ +template +__global__ void slice_insert( + const m_t* src_d, idx_t n_rows, idx_t n_cols, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2) +{ + idx_t idx = threadIdx.x + blockDim.x * blockIdx.x; + idx_t dm = x2 - x1, dn = y2 - y1; + if (idx < dm * dn) { + idx_t i = idx % dm, j = idx / dm; + idx_t is = i + x1, js = j + y1; + dst_d[is + js * n_rows] = src_d[idx]; + } +} + +template +void sliceMatrix_insert(const m_t* in, + idx_t n_rows, + idx_t n_cols, + m_t* out, + idx_t x1, + idx_t y1, + idx_t x2, + idx_t y2, + cudaStream_t stream) +{ + // Slicing + dim3 block(64); + dim3 grid(((x2 - x1) * (y2 - y1) + block.x - 1) / block.x); + slice_insert<<>>(in, n_rows, n_cols, out, x1, y1, x2, y2); +} + /** * @brief Kernel for copying the upper triangular part of a matrix to another * @param src: input matrix with a size of mxn diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index bb92b2b86f..aa0dacf6eb 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -74,6 +74,40 @@ void slice(raft::device_resources const& handle, handle.get_stream()); } +/** + * @brief Insert a small matrix into a bigger matrix using a slice (in-place) + * @tparam m_t type of matrix elements + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft handle + * @param[in] in: input matrix (column-major) + * @param[out] out: output matrix (column-major) + * @param[in] coords: coordinates of the insertion slice + * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice(handle, in, out, {0, 1, 4, 3}); + */ +template +void slice_insert(raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + slice_coordinates coords) +{ + RAFT_EXPECTS(coords.row2 > coords.row1, "row2 must be > row1"); + RAFT_EXPECTS(coords.col2 > coords.col1, "col2 must be > col1"); + RAFT_EXPECTS(coords.row1 >= 0, "row1 must be >= 0"); + RAFT_EXPECTS(coords.row2 <= out.extent(0), "row2 must be <= number of rows in the output matrix"); + RAFT_EXPECTS(coords.col1 >= 0, "col1 must be >= 0"); + RAFT_EXPECTS(coords.col2 <= out.extent(1), + "col2 must be <= number of columns in the output matrix"); + + detail::sliceMatrix_insert(in.data_handle(), + out.extent(0), + out.extent(1), + out.data_handle(), + coords.row1, + coords.col1, + coords.row2, + coords.col2, + handle.get_stream()); +} /** @} */ // end group matrix_slice } // namespace raft::matrix \ No newline at end of file diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 7f1a2ac3fc..c1c2b64d7a 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -60,6 +61,36 @@ struct isnan_test { HDI int operator()(const DataT a) { return isnan(a); } }; +/** + * @tparam value_t floating point type used for elements + * @tparam index_t integer type used for indexing + * Assemble a matrix from a list of blocks + */ +template +void bmat(const raft::handle_t& handle, + raft::device_matrix_view out, + std::vector> ins, + index_t n_blocks) +{ + RAFT_EXPECTS(n_blocks * n_blocks == ins.size(), "inconsistent number of blocks"); + std::vector cumulative_row(n_blocks); + std::vector cumulative_col(n_blocks); + for (index_t i = 0; i < n_blocks; i++) { + for (index_t j = 0; j < n_blocks; j++) { + raft::matrix::slice_insert( + handle, + ins[j + i * n_blocks], + out, + raft::matrix::slice_coordinates(cumulative_row[j], + cumulative_col[i], + cumulative_row[j] + ins[j + i * n_blocks].extent(0), + cumulative_col[i] + ins[j + i * n_blocks].extent(1))); + cumulative_col[i] += ins[j + i * n_blocks].extent(0); + cumulative_row[j] += ins[j + i * n_blocks].extent(1); + } + } +} + /* Modification of copyRows to reindex columns, col_major only * On a 4x3 matrix, indices could be [0, 2] to select col 0 and 2 */ @@ -218,8 +249,8 @@ void cho_solve(const raft::handle_t& handle, template int cholesky(const raft::handle_t& handle, - raft::device_matrix_view P, - bool lower = true) + raft::device_matrix_view P, + bool lower = true) { auto thrust_exec_policy = handle.get_thrust_policy(); auto stream = handle.get_stream(); @@ -258,20 +289,14 @@ int cholesky(const raft::handle_t& handle, 0, thrust::plus()); - if (h_hasnan != 0) // "lobpcg: error in cholesky, NaN in outputs" + if (h_hasnan != 0) // "lobpcg: error in cholesky, NaN in outputs" return 1; raft::matrix::fill(handle, P, value_t(0)); if (lower) { - raft::matrix::lower_triangular( - handle, - raft::make_const_mdspan(P_copy.view()), - P); + raft::matrix::lower_triangular(handle, raft::make_const_mdspan(P_copy.view()), P); } else { - raft::matrix::upper_triangular( - handle, - raft::make_const_mdspan(P_copy.view()), - P); + raft::matrix::upper_triangular(handle, raft::make_const_mdspan(P_copy.view()), P); } return 0; } @@ -335,20 +360,35 @@ void apply_constraints(const raft::handle_t& handle, raft::device_matrix_view BY, raft::device_matrix_view Y) { - auto stream = handle.get_stream(); - auto YBY_copy = raft::make_device_matrix(handle, YBY.extent(0), YBY.extent(1)); + auto stream = handle.get_stream(); + auto YBY_copy = raft::make_device_matrix( + handle, YBY.extent(0), YBY.extent(1)); raft::copy(YBY_copy.data_handle(), YBY.data_handle(), YBY.size(), stream); // TODO: Using raw-pointer because no gemm function with mdspan accept transpose op - auto YBV = raft::make_device_matrix(handle, BY.extent(1), V.extent(1)); + auto YBV = + raft::make_device_matrix(handle, BY.extent(1), V.extent(1)); value_t zero = 0; - value_t one = 1; + value_t one = 1; raft::linalg::gemm(handle, - true, false, YBV.extent(0), YBV.extent(1), BY.extent(0), &one, BY.data_handle(), - BY.extent(0), V.data_handle(), V.extent(0), &zero, YBV.data_handle(), YBV.extent(0), stream); - + true, + false, + YBV.extent(0), + YBV.extent(1), + BY.extent(0), + &one, + BY.data_handle(), + BY.extent(0), + V.data_handle(), + V.extent(0), + &zero, + YBV.data_handle(), + YBV.extent(0), + stream); + cholesky(handle, YBY_copy.view()); cho_solve(handle, raft::make_const_mdspan(YBY_copy.view()), YBV.view()); - auto BV = raft::make_device_matrix(handle, V.extent(0), YBV.extent(1)); + auto BV = + raft::make_device_matrix(handle, V.extent(0), YBV.extent(1)); raft::linalg::gemm(handle, Y, YBV.view(), BV.view()); raft::linalg::subtract(handle, raft::make_const_mdspan(Y), raft::make_const_mdspan(BV.view()), Y); } @@ -367,10 +407,7 @@ void eigh( std::optional> B_opt = std::nullopt) { if (B_opt.has_value()) { - raft::linalg::eig_dc(handle, - raft::make_const_mdspan(A), - eigVecs, - eigVals); + raft::linalg::eig_dc(handle, raft::make_const_mdspan(A), eigVecs, eigVals); return; } auto dim = A.extent(0); @@ -391,10 +428,7 @@ void eigh( raft::linalg::gemm(handle, A, Ri.view(), ARi); raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); - raft::linalg::eig_dc(handle, - raft::make_const_mdspan(F.view()), - Fvecs.view(), - eigVals); + raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); } @@ -537,7 +571,7 @@ void lobpcg( // GramYBY // ApplyConstraints }*/ - auto BX = raft::make_device_matrix(handle, n, size_x); + auto BX = raft::make_device_matrix(handle, n, size_x); auto BXView = BX.view(); b_orthonormalize(handle, X, BXView, B_opt); // Compute the initial Ritz vectors: solve the eigenproblem. @@ -572,16 +606,19 @@ void lobpcg( raft::matrix::eye(handle, ident.view()); raft::matrix::eye(handle, ident0.view()); - auto Pbuffer = rmm::device_uvector(0, stream); + auto Pbuffer = rmm::device_uvector(0, stream); auto APbuffer = rmm::device_uvector(0, stream); auto BPbuffer = rmm::device_uvector(0, stream); - auto activePView = raft::make_device_matrix_view(Pbuffer.data(), 0, 0); - auto activeAPView = raft::make_device_matrix_view(APbuffer.data(), 0, 0); - auto activeBPView = raft::make_device_matrix_view(BPbuffer.data(), 0, 0); + auto activePView = + raft::make_device_matrix_view(Pbuffer.data(), 0, 0); + auto activeAPView = + raft::make_device_matrix_view(APbuffer.data(), 0, 0); + auto activeBPView = + raft::make_device_matrix_view(BPbuffer.data(), 0, 0); std::int32_t iteration_number = -1; - bool restart = true; - bool explicitGramFlag = false; + bool restart = true; + bool explicitGramFlag = false; while (iteration_number < max_iter + 1) { iteration_number += 1; // auto lambda_matrix = raft::make_device_matrix_view( handle, BX.extent(0), eigLambda.extent(0)); if (B_opt) { - raft::matrix::copy(handle, - raft::make_const_mdspan(BXView), - aux.view()); + raft::matrix::copy(handle, raft::make_const_mdspan(BXView), aux.view()); } else { - raft::matrix::copy(handle, - raft::make_const_mdspan(X), - aux.view()); + raft::matrix::copy(handle, raft::make_const_mdspan(X), aux.view()); } raft::linalg::binary_mult_skip_zero(handle, aux.view(), @@ -603,10 +636,8 @@ void lobpcg( raft::linalg::Apply::ALONG_ROWS); auto R = raft::make_device_matrix(handle, n, size_x); - raft::linalg::subtract(handle, - raft::make_const_mdspan(AX.view()), - raft::make_const_mdspan(aux.view()), - R.view()); + raft::linalg::subtract( + handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); auto aux_sum = raft::make_device_vector(handle, size_x); raft::linalg::reduce( // Could be done in-place in aux buffer @@ -622,10 +653,7 @@ void lobpcg( raft::sq_op()); auto residual_norms = raft::make_device_vector(handle, size_x); - raft::linalg::sqrt( - handle, - raft::make_const_mdspan(aux_sum.view()), - residual_norms.view()); + raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); // cupy where & active_mask raft::linalg::unary_op(handle, @@ -656,42 +684,41 @@ void lobpcg( raft::make_device_matrix(handle, R.extent(0), currentBlockSize); selectColsIf(handle, R.view(), active_mask.view(), activeR.view()); - - if (iteration_number > 0) - { + + if (iteration_number > 0) { /* TODO Pbuffer.resize(n * currentBlockSize, stream); APbuffer.resize(n * currentBlockSize, stream); BPbuffer.resize(n * currentBlockSize, stream); - activePView = raft::make_device_matrix_view(Pbuffer.data(), n, currentBlockSize); - activeAPView = raft::make_device_matrix_view(APbuffer.data(), n, currentBlockSize); - selectColsIf(handle, R.view(), active_mask.view(), activePView); - selectColsIf(handle, AP.view(), active_mask.view(), activeAPView); - if (B_opt.has_value()) { - activeBPView = raft::make_device_matrix_view(BPbuffer.data(), n, currentBlockSize); - selectColsIf(handle, BP.view(), active_mask.view(), activeBPView); + activePView = raft::make_device_matrix_view(Pbuffer.data(), n, + currentBlockSize); activeAPView = raft::make_device_matrix_view(APbuffer.data(), n, currentBlockSize); selectColsIf(handle, R.view(), + active_mask.view(), activePView); selectColsIf(handle, AP.view(), active_mask.view(), + activeAPView); if (B_opt.has_value()) { activeBPView = raft::make_device_matrix_view(BPbuffer.data(), n, currentBlockSize); selectColsIf(handle, BP.view(), + active_mask.view(), activeBPView); }*/ } - if (M_opt.has_value()) - { + if (M_opt.has_value()) { // Apply preconditioner T to the active residuals. - auto MRtemp = - raft::make_device_matrix(handle, R.extent(0), currentBlockSize); + auto MRtemp = raft::make_device_matrix( + handle, R.extent(0), currentBlockSize); spmm(handle, M_opt.value(), activeR.view(), MRtemp.view()); raft::copy(activeR.data_handle(), MRtemp.data_handle(), MRtemp.size(), stream); } // Apply constraints to the preconditioned residuals. - if (Y_opt.has_value()) - { + if (Y_opt.has_value()) { // TODO Constraint - //apply_constraints(handle, X, gramYBY.view(), BY.view(), Y_opt.value()); + // apply_constraints(handle, X, gramYBY.view(), BY.view(), Y_opt.value()); } // B-orthogonalize the preconditioned residuals to X. - if (B_opt.has_value()) - { - auto BXT = raft::make_device_matrix(handle, BX.extent(1), BX.extent(0)); - auto BXTR = raft::make_device_matrix(handle, BXT.extent(0), activeR.extent(1)); - auto XBXTR = raft::make_device_matrix(handle, X.extent(0), BXTR.extent(1)); + if (B_opt.has_value()) { + auto BXT = raft::make_device_matrix( + handle, BX.extent(1), BX.extent(0)); + auto BXTR = raft::make_device_matrix( + handle, BXT.extent(0), activeR.extent(1)); + auto XBXTR = raft::make_device_matrix( + handle, X.extent(0), BXTR.extent(1)); raft::linalg::transpose(handle, BX.view(), BXT.view()); raft::linalg::gemm(handle, BXT.view(), activeR.view(), BXTR.view()); @@ -701,9 +728,10 @@ void lobpcg( raft::make_const_mdspan(XBXTR.view()), activeR.view()); } else { - - auto XTR = raft::make_device_matrix(handle, XT.extent(0), activeR.extent(1)); - auto XXTR = raft::make_device_matrix(handle, X.extent(0), XTR.extent(1)); + auto XTR = raft::make_device_matrix( + handle, XT.extent(0), activeR.extent(1)); + auto XXTR = raft::make_device_matrix( + handle, X.extent(0), XTR.extent(1)); raft::linalg::gemm(handle, XT.view(), activeR.view(), XTR.view()); raft::linalg::gemm(handle, X, XTR.view(), XXTR.view()); @@ -713,101 +741,138 @@ void lobpcg( activeR.view()); } // B-orthonormalize the preconditioned residuals. - auto BR = raft::make_device_matrix(handle, activeR.extent(0), activeR.extent(1)); + auto BR = raft::make_device_matrix( + handle, activeR.extent(0), activeR.extent(1)); auto BRView = BR.view(); b_orthonormalize(handle, activeR.view(), BRView, B_opt); - auto activeAR = raft::make_device_matrix(handle, n, activeR.extent(1)); + auto activeAR = + raft::make_device_matrix(handle, n, activeR.extent(1)); spmm(handle, A, activeR.view(), activeAR.view()); - if (iteration_number > 0) - { - auto invR = raft::make_device_matrix(handle, activePView.extent(1), activePView.extent(1)); - auto normal = raft::make_device_vector(handle, activePView.extent(1)); + if (iteration_number > 0) { + auto invR = raft::make_device_matrix( + handle, activePView.extent(1), activePView.extent(1)); + auto normal = raft::make_device_vector(handle, activePView.extent(1)); int b_orth_status = 0; - if (B_opt.has_value()) - { - auto BP = raft::make_device_matrix(handle, activePView.extent(0), activePView.extent(1)); - b_orth_status = b_orthonormalize(handle, activePView, BP.view(), B_opt, std::make_optional(invR.view()), std::make_optional(normal.view())); - } else - { - b_orth_status = b_orthonormalize(handle, activePView, activeBPView, B_opt, std::make_optional(invR.view()), std::make_optional(normal.view()), false); - } + if (B_opt.has_value()) { + auto BP = raft::make_device_matrix( + handle, activePView.extent(0), activePView.extent(1)); + b_orth_status = b_orthonormalize(handle, + activePView, + BP.view(), + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view())); + } else { + b_orth_status = b_orthonormalize(handle, + activePView, + activeBPView, + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view()), + false); + } if (b_orth_status != 0) { restart = true; } else { - raft::linalg::binary_div_skip_zero(handle, activeAPView, raft::make_const_mdspan(normal.view()), raft::linalg::Apply::ALONG_ROWS); + raft::linalg::binary_div_skip_zero(handle, + activeAPView, + raft::make_const_mdspan(normal.view()), + raft::linalg::Apply::ALONG_ROWS); raft::linalg::gemm(handle, activeAPView, invR.view(), activeAPView); restart = false; } - + // Perform the Rayleigh Ritz Procedure: // Compute symmetric Gram matrices: - value_t myeps = 1; // TODO: std::is_same_t ? 1e-4 : 1e-8; - if (!explicitGramFlag) - { - value_t* residual_norms_max_elem = thrust::max_element(thrust_exec_policy, - residual_norms.data_handle(), - residual_norms.data_handle() + residual_norms.size()); + value_t myeps = 1; // TODO: std::is_same_t ? 1e-4 : 1e-8; + if (!explicitGramFlag) { + value_t* residual_norms_max_elem = + thrust::max_element(thrust_exec_policy, + residual_norms.data_handle(), + residual_norms.data_handle() + residual_norms.size()); value_t residual_norms_max = 0; raft::copy(&residual_norms_max, residual_norms_max_elem, 1, stream); explicitGramFlag = residual_norms_max <= myeps; } - if (!B_opt.has_value()) - { + if (!B_opt.has_value()) { // Shared memory assignments to simplify the code - BXView = raft::make_device_matrix_view(X.data_handle(), n, currentBlockSize); - BRView = raft::make_device_matrix_view(activeR.data_handle(), n, currentBlockSize); - //if (!restart) TODO - //activeBPView = raft::make_device_matrix_view(P.data_handle(), n, currentBlockSize); + BXView = raft::make_device_matrix_view( + X.data_handle(), n, currentBlockSize); + BRView = raft::make_device_matrix_view( + activeR.data_handle(), n, currentBlockSize); + // if (!restart) TODO + // activeBPView = raft::make_device_matrix_view(P.data_handle(), n, currentBlockSize); } } // Common submatrices - auto gramXAR = raft::make_device_matrix(handle, X.extent(1), activeAR.extent(1)); - auto gramRAR = raft::make_device_matrix(handle, R.extent(1), activeAR.extent(1)); - auto gramXBX = raft::make_device_matrix(handle, X.extent(1), BX.extent(1)); - auto gramRBR = raft::make_device_matrix(handle, R.extent(1), BRView.extent(1)); - auto gramXBR = raft::make_device_matrix(handle, X.extent(1), BRView.extent(1)); - raft::linalg::gemm(handle, - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm - activeAR.view(), gramXAR.view()); - - raft::linalg::gemm(handle, - raft::make_device_matrix_view(activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm - activeAR.view(), gramRAR.view()); - - if (explicitGramFlag) - { - auto device_half = raft::make_device_scalar(handle, 0.5); - raft::linalg::gemm(handle, - raft::make_device_matrix_view(gramRAR.data_handle(), gramRAR.extent(0), gramRAR.extent(1)), // transpose for gemm + auto gramXAR = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRAR = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBX = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRBR = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBR = + raft::make_device_matrix(handle, size_x, currentBlockSize); + raft::linalg::gemm(handle, + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + activeAR.view(), + gramXAR.view()); + + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + activeAR.view(), + gramRAR.view()); + + auto device_half = raft::make_device_scalar(handle, 0.5); + if (explicitGramFlag) { + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + gramRAR.data_handle(), gramRAR.extent(0), gramRAR.extent(1)), // transpose for gemm identView, gramRAR.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); - raft::linalg::gemm(handle, - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm - AX.view(), gramXAX.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view(gramXAX.data_handle(), gramXAX.extent(0), gramXAX.extent(1)), // transpose for gemm + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + AX.view(), + gramXAX.view()); + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + gramXAX.data_handle(), gramXAX.extent(0), gramXAX.extent(1)), // transpose for gemm identView, gramXAX.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); - raft::linalg::gemm(handle, - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm - BX.view(), gramXBX.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view(activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm - BRView, gramRBR.view()); + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + BX.view(), + gramXBX.view()); + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + BRView, + gramRBR.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view(X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm - BRView, gramXBR.view()); - } - else - { + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + BRView, + gramXBR.view()); + } else { raft::matrix::fill(handle, gramXAX.view(), value_t(0)); raft::matrix::set_diagonal(handle, make_const_mdspan(eigLambda.view()), gramXAX.view()); @@ -815,9 +880,98 @@ void lobpcg( raft::matrix::eye(handle, gramRBR.view()); raft::matrix::fill(handle, gramXBR.view(), value_t(0)); } - if (!restart) - { - //raft::linalg::gemm(handle) + if (!restart) { + auto gramA = raft::make_device_matrix(handle, n, n); + auto gramB = raft::make_device_matrix(handle, n, n); + auto gramXAP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + + raft::linalg::gemm(handle, + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + activeAPView, + gramXAP.view()); + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + activeAPView, + gramRAP.view()); + raft::linalg::gemm(handle, + raft::make_device_matrix_view( + activePView.data_handle(), + activePView.extent(0), + activePView.extent(1)), // transpose for gemm + activeAPView, + gramPAP.view()); + raft::linalg::gemm(handle, + raft::make_device_matrix_view( + X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + activeBPView, + gramXBP.view()); + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + activeBPView, + gramRBP.view()); + + if (explicitGramFlag) { + raft::linalg::gemm( + handle, + raft::make_device_matrix_view( + gramPAP.data_handle(), gramPAP.extent(0), gramPAP.extent(1)), // transpose for gemm + ident.view(), + gramPAP.view(), + std::make_optional(device_half.view()), + std::make_optional(device_half.view())); + raft::linalg::gemm(handle, + raft::make_device_matrix_view( + activePView.data_handle(), + activePView.extent(0), + activePView.extent(1)), // transpose for gemm + activeBPView, + gramPBP.view()); + } else { + raft::matrix::eye(handle, gramPBP.view()); + } + + // create transpose mat + auto gramXAPT = raft::make_device_matrix( + handle, gramXAPT.extent(1), gramXAPT.extent(0)); + auto gramXART = raft::make_device_matrix( + handle, gramXART.extent(1), gramXART.extent(0)); + auto gramRAPT = raft::make_device_matrix( + handle, gramRAPT.extent(1), gramRAPT.extent(0)); + auto gramXBPT = raft::make_device_matrix( + handle, gramXBPT.extent(1), gramXBPT.extent(0)); + auto gramXBRT = raft::make_device_matrix( + handle, gramXBRT.extent(1), gramXBRT.extent(0)); + auto gramRBPT = raft::make_device_matrix( + handle, gramRBPT.extent(1), gramRBPT.extent(0)); + raft::linalg::transpose(handle, gramXAP.view(), gramXAPT.view()); + raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); + raft::linalg::transpose(handle, gramRAP.view(), gramRAPT.view()); + raft::linalg::transpose(handle, gramXBP.view(), gramXBPT.view()); + raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view()); + raft::linalg::transpose(handle, gramRBP.view(), gramRBPT.view()); + + std::vector> A_blocks = { + gramXAX, gramXAR, gramXAP, gramXART, gramRAR, gramRAP, gramXAPT, gramRAPT, gramPAP}; + std::vector> B_blocks = { + gramXBX, gramXBR, gramXBP, gramXBRT, gramRBR, gramRBP, gramXBPT, gramRBPT, gramPBP}; + bmat(handle, gramA, A_blocks); + bmat(handle, gramB, B_blocks); } } return; diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index 81d3d0d30b..2d5fe4109b 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -111,6 +111,46 @@ class LOBPCGTest : public ::testing::TestWithParam> ASSERT_TRUE(hostVecMatch(expected, res, raft::CompareApprox(0.0001))); } + void test_bmat() + { + auto total = raft::make_device_matrix(handle, 6, 6); + auto x1 = raft::make_device_matrix(handle, 2, 2); + auto x2 = raft::make_device_matrix(handle, 2, 2); + auto x3 = raft::make_device_matrix(handle, 2, 2); + auto x4 = raft::make_device_matrix(handle, 2, 2); + auto x5 = raft::make_device_matrix(handle, 2, 2); + auto x6 = raft::make_device_matrix(handle, 2, 2); + auto x7 = raft::make_device_matrix(handle, 2, 2); + auto x8 = raft::make_device_matrix(handle, 2, 2); + auto x9 = raft::make_device_matrix(handle, 2, 2); + raft::linalg::range(x1.data_handle(), 0, 4, handle.get_stream()); + raft::linalg::range(x2.data_handle(), 4, 8, handle.get_stream()); + raft::linalg::range(x3.data_handle(), 8, 12, handle.get_stream()); + raft::linalg::range(x4.data_handle(), 12, 16, handle.get_stream()); + raft::linalg::range(x5.data_handle(), 16, 20, handle.get_stream()); + raft::linalg::range(x6.data_handle(), 20, 24, handle.get_stream()); + raft::linalg::range(x7.data_handle(), 24, 28, handle.get_stream()); + raft::linalg::range(x8.data_handle(), 28, 32, handle.get_stream()); + raft::linalg::range(x9.data_handle(), 32, 36, handle.get_stream()); + std::vector> xs = {x1.view(), + x2.view(), + x3.view(), + x4.view(), + x5.view(), + x6.view(), + x7.view(), + x8.view(), + x9.view()}; + raft::sparse::solver::detail::bmat(handle, total.view(), xs, 3); + std::vector res(total.size()); + std::vector expected{0, 1, 12, 13, 24, 25, 2, 3, 14, 15, 26, 27, + 4, 5, 16, 17, 28, 29, 6, 7, 18, 19, 30, 31, + 8, 9, 20, 21, 32, 33, 10, 11, 22, 23, 34, 35}; + raft::copy(res.data(), total.data_handle(), total.size(), handle.get_stream()); + handle.sync_stream(); + ASSERT_TRUE(hostVecMatch(expected, res, raft::CompareApprox(0.0001))); + } + void test_b_orthonormalize() { idx_t n_rows_v = n_rows_a; @@ -135,6 +175,7 @@ class LOBPCGTest : public ::testing::TestWithParam> void Run() { + test_bmat(); test_selectcolsif(); test_b_orthonormalize(); raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows_a, stream); From a7d205cb03a6a47096729dd5cd10365934a86ab3 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 14 Mar 2023 23:06:58 +0100 Subject: [PATCH 11/17] Eig iteration computation --- .../raft/sparse/solver/detail/lobpcg.cuh | 233 ++++++++++++------ 1 file changed, 151 insertions(+), 82 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index c1c2b64d7a..3f0f0d096e 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -69,10 +69,17 @@ struct isnan_test { template void bmat(const raft::handle_t& handle, raft::device_matrix_view out, - std::vector> ins, + const std::vector>& ins, index_t n_blocks) { RAFT_EXPECTS(n_blocks * n_blocks == ins.size(), "inconsistent number of blocks"); + index_t n_rows = 0; + index_t n_cols = 0; + for (const auto& inView : ins) { + n_rows += inView.extent(0); + n_cols += inView.extent(1); + } + RAFT_EXPECTS(n_rows == out.extent(0) n_cols == out.extent(1), "input/output dimension mismatch"); std::vector cumulative_row(n_blocks); std::vector cumulative_col(n_blocks); for (index_t i = 0; i < n_blocks; i++) { @@ -144,21 +151,32 @@ void selectColsIf(const raft::handle_t& handle, stream); } +/** + * Reverse if needed the eigenvalues/vectors and truncate the columns to fit eigVectorTrunc + */ template -void truncEig(const raft::handle_t& handle, - raft::device_matrix_view eigVector, - raft::device_vector_view eigLambda, - index_t size_x, - bool largest) +void truncEig( + const raft::handle_t& handle, + raft::device_matrix_view eigVectorin, + std::optional> eigVectorTrunc, + raft::device_vector_view eigLambda, + bool largest) { // The eigenvalues are already sorted in ascending order with syevd + auto nrows = eigVectorin.extent(0); + auto ncols = eigVectorin.extent(1); if (largest) { - auto nrows = eigVector.extent(0); - auto ncols = eigVector.extent(1); - raft::matrix::col_reverse(handle, eigVector); + raft::matrix::col_reverse(handle, eigVectorin); raft::matrix::col_reverse( handle, raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0))); } + if (eigVectorTrunc.has_value() && ncols > eigVectorTrunc->extent(1)) + raft::matrix::truncZeroOrigin(eigVectorin.data_handle(), + n_rows, + eigVectorTrunc->data_handle(), + nrows, + eigVectorTrunc->extent(1), + stream); } // C = A * B @@ -248,9 +266,9 @@ void cho_solve(const raft::handle_t& handle, } template -int cholesky(const raft::handle_t& handle, - raft::device_matrix_view P, - bool lower = true) +bool cholesky(const raft::handle_t& handle, + raft::device_matrix_view P, + bool lower = true) { auto thrust_exec_policy = handle.get_thrust_policy(); auto stream = handle.get_stream(); @@ -290,7 +308,7 @@ int cholesky(const raft::handle_t& handle, thrust::plus()); if (h_hasnan != 0) // "lobpcg: error in cholesky, NaN in outputs" - return 1; + return false; raft::matrix::fill(handle, P, value_t(0)); if (lower) { @@ -298,7 +316,7 @@ int cholesky(const raft::handle_t& handle, } else { raft::matrix::upper_triangular(handle, raft::make_const_mdspan(P_copy.view()), P); } - return 0; + return true; } template @@ -399,24 +417,23 @@ void apply_constraints(const raft::handle_t& handle, * transformation */ template -void eigh( - const raft::handle_t& handle, - raft::device_matrix_view A, - raft::device_matrix_view eigVecs, - raft::device_vector_view eigVals, - std::optional> B_opt = std::nullopt) +bool eigh(const raft::handle_t& handle, + raft::device_matrix_view A, + std::optional> B_opt, + raft::device_matrix_view eigVecs, + raft::device_vector_view eigVals) { - if (B_opt.has_value()) { + if (!B_opt.has_value()) { raft::linalg::eig_dc(handle, raft::make_const_mdspan(A), eigVecs, eigVals); - return; + return true; } - auto dim = A.extent(0); - auto RTi = raft::make_device_matrix(handle, dim, dim); - auto Ri = raft::make_device_matrix(handle, dim, dim); - auto RT = raft::make_device_matrix(handle, dim, dim); - auto F = raft::make_device_matrix(handle, dim, dim); - auto B = B_opt.value(); - cholesky(handle, B, false); + auto dim = A.extent(0); + auto RTi = raft::make_device_matrix(handle, dim, dim); + auto Ri = raft::make_device_matrix(handle, dim, dim); + auto RT = raft::make_device_matrix(handle, dim, dim); + auto F = raft::make_device_matrix(handle, dim, dim); + auto B = B_opt.value(); + bool cho_success = cholesky(handle, B, false); raft::linalg::transpose(handle, B, RT.view()); inverse(handle, RT.view(), Ri.view()); @@ -430,6 +447,7 @@ void eigh( raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); + return cho_success } /** @@ -444,9 +462,10 @@ void eigh( * @param[out] VBV_opt: optional dense matrix containing inverse matrix (shape v[1] * v[1]) * @param[out] V_max_opt: optional vector containing normalization of V (shape v[1]) * @param[in] bv_is_empty: True if BV is used as input + * @return success status */ template -int b_orthonormalize( +bool b_orthonormalize( const raft::handle_t& handle, raft::device_matrix_view V, raft::device_matrix_view BV, @@ -516,14 +535,14 @@ int b_orthonormalize( raft::linalg::transpose(handle, V, VT.view()); raft::linalg::gemm(handle, VT.view(), BV, VBV); - int result_status = cholesky(handle, VBV, false); - if (result_status != 0) { return result_status; } + bool cholesky_success = cholesky(handle, VBV, false); + if (!cholesky_success) { return cholesky_success; } inverse(handle, VBV, VBVBuffer.view()); raft::copy(VBV.data_handle(), VBVBuffer.data_handle(), VBV.size(), stream); raft::linalg::gemm(handle, V, VBV, V); if (B_opt) raft::linalg::gemm(handle, BV, VBV, BV); - return 0; + return true; } template @@ -753,27 +772,27 @@ void lobpcg( if (iteration_number > 0) { auto invR = raft::make_device_matrix( handle, activePView.extent(1), activePView.extent(1)); - auto normal = raft::make_device_vector(handle, activePView.extent(1)); - int b_orth_status = 0; + auto normal = raft::make_device_vector(handle, activePView.extent(1)); + bool b_orth_success = true; if (B_opt.has_value()) { auto BP = raft::make_device_matrix( handle, activePView.extent(0), activePView.extent(1)); - b_orth_status = b_orthonormalize(handle, - activePView, - BP.view(), - B_opt, - std::make_optional(invR.view()), - std::make_optional(normal.view())); + b_orth_success = b_orthonormalize(handle, + activePView, + BP.view(), + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view())); } else { - b_orth_status = b_orthonormalize(handle, - activePView, - activeBPView, - B_opt, - std::make_optional(invR.view()), - std::make_optional(normal.view()), - false); + b_orth_success = b_orthonormalize(handle, + activePView, + activeBPView, + B_opt, + std::make_optional(invR.view()), + std::make_optional(normal.view()), + false); } - if (b_orth_status != 0) { + if (!b_orth_success) { restart = true; } else { raft::linalg::binary_div_skip_zero(handle, @@ -880,22 +899,45 @@ void lobpcg( raft::matrix::eye(handle, gramRBR.view()); raft::matrix::fill(handle, gramXBR.view(), value_t(0)); } - if (!restart) { - auto gramA = raft::make_device_matrix(handle, n, n); - auto gramB = raft::make_device_matrix(handle, n, n); - auto gramXAP = - raft::make_device_matrix(handle, size_x, currentBlockSize); - auto gramRAP = raft::make_device_matrix( - handle, currentBlockSize, currentBlockSize); - auto gramPAP = raft::make_device_matrix( - handle, currentBlockSize, currentBlockSize); - auto gramXBP = - raft::make_device_matrix(handle, size_x, currentBlockSize); - auto gramRBP = raft::make_device_matrix( - handle, currentBlockSize, currentBlockSize); - auto gramPBP = raft::make_device_matrix( - handle, currentBlockSize, currentBlockSize); + auto gramDim = gramXAX.extent(1) + gramXAR.extent(1) + gramXAP.extent(1); + auto gramA = raft::make_device_matrix(handle, gramDim, gramDim); + auto gramB = raft::make_device_matrix(handle, gramDim, gramDim); + auto gramAView = gramA.view(); + auto gramBView = gramB.view(); + auto eigLambdaTemp = raft::make_device_vector_view(handle, gramDim); + auto eigVectorTemp = + raft::make_device_matrix_view(handle, gramDim, gramDim); + auto eigLambdaTempView = eigLambdaTemp.view(); + auto eigVectorTempView = eigVectorTemp.view(); + auto gramXAP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPAP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramXBP = + raft::make_device_matrix(handle, size_x, currentBlockSize); + auto gramRBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + auto gramPBP = raft::make_device_matrix( + handle, currentBlockSize, currentBlockSize); + // create transpose mat + auto gramXAPT = raft::make_device_matrix( + handle, gramXAPT.extent(1), gramXAPT.extent(0)); + auto gramXART = raft::make_device_matrix( + handle, gramXART.extent(1), gramXART.extent(0)); + auto gramRAPT = raft::make_device_matrix( + handle, gramRAPT.extent(1), gramRAPT.extent(0)); + auto gramXBPT = raft::make_device_matrix( + handle, gramXBPT.extent(1), gramXBPT.extent(0)); + auto gramXBRT = raft::make_device_matrix( + handle, gramXBRT.extent(1), gramXBRT.extent(0)); + auto gramRBPT = raft::make_device_matrix( + handle, gramRBPT.extent(1), gramRBPT.extent(0)); + raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); + raft::linalg::transpose(handle, gramXVR.view(), gramXBRT.view()); + if (!restart) { raft::linalg::gemm(handle, raft::make_device_matrix_view( X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm @@ -945,33 +987,60 @@ void lobpcg( } else { raft::matrix::eye(handle, gramPBP.view()); } - - // create transpose mat - auto gramXAPT = raft::make_device_matrix( - handle, gramXAPT.extent(1), gramXAPT.extent(0)); - auto gramXART = raft::make_device_matrix( - handle, gramXART.extent(1), gramXART.extent(0)); - auto gramRAPT = raft::make_device_matrix( - handle, gramRAPT.extent(1), gramRAPT.extent(0)); - auto gramXBPT = raft::make_device_matrix( - handle, gramXBPT.extent(1), gramXBPT.extent(0)); - auto gramXBRT = raft::make_device_matrix( - handle, gramXBRT.extent(1), gramXBRT.extent(0)); - auto gramRBPT = raft::make_device_matrix( - handle, gramRBPT.extent(1), gramRBPT.extent(0)); raft::linalg::transpose(handle, gramXAP.view(), gramXAPT.view()); - raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); raft::linalg::transpose(handle, gramRAP.view(), gramRAPT.view()); raft::linalg::transpose(handle, gramXBP.view(), gramXBPT.view()); - raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view()); raft::linalg::transpose(handle, gramRBP.view(), gramRBPT.view()); std::vector> A_blocks = { gramXAX, gramXAR, gramXAP, gramXART, gramRAR, gramRAP, gramXAPT, gramRAPT, gramPAP}; std::vector> B_blocks = { gramXBX, gramXBR, gramXBP, gramXBRT, gramRBR, gramRBP, gramXBPT, gramRBPT, gramPBP}; - bmat(handle, gramA, A_blocks); - bmat(handle, gramB, B_blocks); + gramAView = + raft::make_device_matrix_view(gramA.data_handle(), n, n); + gramBView = + raft::make_device_matrix_view(gramB.data_handle(), n, n); + + bmat(handle, gramAView, A_blocks); + bmat(handle, gramBView, B_blocks); + + bool eig_sucess = + eigh(handle, gramA, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); + if (!eig_sucess) restart = true; + } + if (restart) { + gramDim = gramXAX.extent(1) + gramXAR.extent(1); + std::vector> A_blocks = { + gramXAX, gramXAR, gramXART, gramRAR}; + std::vector> B_blocks = { + gramXBX, gramXBR, gramXBRT, gramRBR}; + gramAView = raft::make_device_matrix_view( + gramA.data_handle(), gramDim, gramDim); + gramBView = raft::make_device_matrix_view( + gramB.data_handle(), gramDim, gramDim); + eigLambdaTempView = + raft::make_device_vector_view(eigLambdaTempView.data_handle(), gramDim); + eigVectorTempView = raft::make_device_matrix_view( + eigVectorTempView.data_handle(), gramDim, gramDim); + bmat(handle, gramAView, A_blocks); + bmat(handle, gramBView, B_blocks); + bool eig_sucess = eigh( + handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); + ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations"); + } + truncEig( + handle, eigVectorTempView, std::make_optional(eigVector.view()), eigLambdaTempView, largest); + raft::copy(eigLambda.data_handle(), eigLambdaTempView.data_handle(), size_x, stream); + + // Verbosity print + + // Compute Ritz vectors. + if (B_opt.has_value()) { + auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); + auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); + if (!restart) { + raft::matrix::truncZeroOrigin(eigVector.data_handle(), ) + } } } return; From 03c1760ba6c755346ed68effe7ab7a27f100b770 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 30 Mar 2023 02:03:27 +0200 Subject: [PATCH 12/17] Add update on the main iteration loop --- cpp/include/raft/linalg/gemv.cuh | 4 +- .../raft/sparse/solver/detail/lobpcg.cuh | 225 ++++++++++++------ 2 files changed, 156 insertions(+), 73 deletions(-) diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh index 96846003f6..df916d6d12 100644 --- a/cpp/include/raft/linalg/gemv.cuh +++ b/cpp/include/raft/linalg/gemv.cuh @@ -233,8 +233,8 @@ void gemv(raft::device_resources const& handle, * @tparam LayoutPolicyZ layout of Z * @param[in] handle raft handle * @param[in] A input raft::device_matrix_view of size (M, N) - * @param[in] x input raft::device_matrix_view of size (N, 1) if A is raft::col_major, else (M, 1) - * @param[out] y output raft::device_matrix_view of size (M, 1) if A is raft::col_major, else (N, 1) + * @param[in] x input raft::device_vector_view of size (N, 1) if A is raft::col_major, else (M, 1) + * @param[out] y output raft::device_vector_view of size (M, 1) if A is raft::col_major, else (N, 1) * @param[in] alpha optional raft::host_scalar_view or raft::device_scalar_view, default 1.0 * @param[in] beta optional raft::host_scalar_view or raft::device_scalar_view, default 0.0 */ diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 3f0f0d096e..6c016a54ec 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -79,7 +79,7 @@ void bmat(const raft::handle_t& handle, n_rows += inView.extent(0); n_cols += inView.extent(1); } - RAFT_EXPECTS(n_rows == out.extent(0) n_cols == out.extent(1), "input/output dimension mismatch"); + RAFT_EXPECTS(n_rows == out.extent(0) && n_cols == out.extent(1), "input/output dimension mismatch"); std::vector cumulative_row(n_blocks); std::vector cumulative_col(n_blocks); for (index_t i = 0; i < n_blocks; i++) { @@ -382,7 +382,7 @@ void apply_constraints(const raft::handle_t& handle, auto YBY_copy = raft::make_device_matrix( handle, YBY.extent(0), YBY.extent(1)); raft::copy(YBY_copy.data_handle(), YBY.data_handle(), YBY.size(), stream); - // TODO: Using raw-pointer because no gemm function with mdspan accept transpose op + // TODO: Use mdspan gemm with row-major to transpose auto YBV = raft::make_device_matrix(handle, BY.extent(1), V.extent(1)); value_t zero = 0; @@ -601,49 +601,56 @@ void lobpcg( auto XT = raft::make_device_matrix(handle, size_x, n); raft::linalg::transpose(handle, X, XT.view()); raft::linalg::gemm(handle, XT.view(), AX.view(), gramXAX.view()); - auto eigVector = - raft::make_device_matrix(handle, size_x, size_x); + auto eigVectorBuffer = rmm::device_uvector(size_x * size_x, stream); // rmm because of resize + auto eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), size_x, size_x); auto eigLambda = raft::make_device_vector(handle, size_x); - eigh(handle, gramXAX.view(), eigVector.view(), eigLambda.view()); - truncEig(handle, eigVector.view(), eigLambda.view(), size_x, largest); + eigh(handle, gramXAX.view(), eigVectorView, eigLambda.view()); + truncEig(handle, eigVectorView, eigLambda.view(), size_x, largest); // Slice not needed for first eigh // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, // eigVectorFull.extent(0), size_x)); - raft::linalg::gemm(handle, X, eigVector.view(), X); - raft::linalg::gemm(handle, AX.view(), eigVector.view(), AX.view()); - if (B_opt) raft::linalg::gemm(handle, BXView, eigVector.view(), BXView); + raft::linalg::gemm(handle, X, eigVectorView, X); + raft::linalg::gemm(handle, AX.view(), eigVectorView, AX.view()); + if (B_opt) raft::linalg::gemm(handle, BXView, eigVectorView, BXView); // Active index set // TODO: use uint8_t auto active_mask = raft::make_device_vector(handle, size_x); auto previousBlockSize = size_x; - auto ident = raft::make_device_matrix(handle, size_x, size_x); - auto ident0 = raft::make_device_matrix(handle, size_x, size_x); - // TODO: Maybe initialization of ident here is not needed? - raft::matrix::eye(handle, ident.view()); - raft::matrix::eye(handle, ident0.view()); + auto ident = rmm::device_uvector(size_x * size_x, stream); + auto identView = raft::make_device_matrix_view( + ident.data(), size_x, size_x); + raft::matrix::eye(handle, identView); auto Pbuffer = rmm::device_uvector(0, stream); auto APbuffer = rmm::device_uvector(0, stream); auto BPbuffer = rmm::device_uvector(0, stream); - auto activePView = + auto PView = raft::make_device_matrix_view(Pbuffer.data(), 0, 0); - auto activeAPView = + auto APView = raft::make_device_matrix_view(APbuffer.data(), 0, 0); - auto activeBPView = + auto BPView = raft::make_device_matrix_view(BPbuffer.data(), 0, 0); + auto activePbuffer = rmm::device_uvector(0, stream); + auto activeAPbuffer = rmm::device_uvector(0, stream); + auto activeBPbuffer = rmm::device_uvector(0, stream); + auto activePView = + raft::make_device_matrix_view(activePbuffer.data(), 0, 0); + auto activeAPView = + raft::make_device_matrix_view(activeAPbuffer.data(), 0, 0); + auto activeBPView = + raft::make_device_matrix_view(activeBPbuffer.data(), 0, 0); + auto R = raft::make_device_matrix(handle, n, size_x); + auto aux = raft::make_device_matrix( + handle, n, size_x); std::int32_t iteration_number = -1; bool restart = true; bool explicitGramFlag = false; while (iteration_number < max_iter + 1) { iteration_number += 1; - // auto lambda_matrix = raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0)); - auto aux = raft::make_device_matrix( - handle, BX.extent(0), eigLambda.extent(0)); if (B_opt) { raft::matrix::copy(handle, raft::make_const_mdspan(BXView), aux.view()); } else { @@ -654,12 +661,11 @@ void lobpcg( raft::make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS); - auto R = raft::make_device_matrix(handle, n, size_x); raft::linalg::subtract( handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); auto aux_sum = raft::make_device_vector(handle, size_x); - raft::linalg::reduce( // Could be done in-place in aux buffer + raft::linalg::reduce( aux_sum.data_handle(), R.data_handle(), size_x, @@ -686,12 +692,12 @@ void lobpcg( active_mask.data_handle(), active_mask.data_handle() + active_mask.size(), 0); - auto identView = raft::make_device_matrix_view( - ident.data_handle(), previousBlockSize, previousBlockSize); + handle.sync_stream(); if (currentBlockSize != previousBlockSize) { previousBlockSize = currentBlockSize; + ident.resize(currentBlockSize * currentBlockSize, stream); identView = raft::make_device_matrix_view( - ident.data_handle(), currentBlockSize, currentBlockSize); + ident.data(), currentBlockSize, currentBlockSize); raft::matrix::eye(handle, identView); } @@ -700,23 +706,22 @@ void lobpcg( // TODO add verb } auto activeR = - raft::make_device_matrix(handle, R.extent(0), currentBlockSize); + raft::make_device_matrix(handle, n, currentBlockSize); selectColsIf(handle, R.view(), active_mask.view(), activeR.view()); if (iteration_number > 0) { - /* TODO - Pbuffer.resize(n * currentBlockSize, stream); - APbuffer.resize(n * currentBlockSize, stream); - BPbuffer.resize(n * currentBlockSize, stream); - activePView = raft::make_device_matrix_view(Pbuffer.data(), n, - currentBlockSize); activeAPView = raft::make_device_matrix_view(APbuffer.data(), n, currentBlockSize); selectColsIf(handle, R.view(), - active_mask.view(), activePView); selectColsIf(handle, AP.view(), active_mask.view(), - activeAPView); if (B_opt.has_value()) { activeBPView = raft::make_device_matrix_view(BPbuffer.data(), n, currentBlockSize); selectColsIf(handle, BP.view(), - active_mask.view(), activeBPView); - }*/ + activePbuffer.resize(n * currentBlockSize, stream); + activeAPbuffer.resize(n * currentBlockSize, stream); + activeBPbuffer.resize(n * currentBlockSize, stream); + activePView = raft::make_device_matrix_view(activePbuffer.data(), n, currentBlockSize); + activeAPView = raft::make_device_matrix_view(activeAPbuffer.data(), n, currentBlockSize); + selectColsIf(handle, PView, active_mask.view(), activePView); + selectColsIf(handle, APView, active_mask.view(), activeAPView); + if (B_opt.has_value()) { + activeBPView = raft::make_device_matrix_view(activeBPbuffer.data(), n, currentBlockSize); + selectColsIf(handle, BPbuffer.view(), active_mask.view(), activeBPView); + } } if (M_opt.has_value()) { // Apply preconditioner T to the active residuals. @@ -760,10 +765,10 @@ void lobpcg( activeR.view()); } // B-orthonormalize the preconditioned residuals. - auto BR = raft::make_device_matrix( + auto activeBR = raft::make_device_matrix( handle, activeR.extent(0), activeR.extent(1)); - auto BRView = BR.view(); - b_orthonormalize(handle, activeR.view(), BRView, B_opt); + auto activeBRView = activeBR.view(); + b_orthonormalize(handle, activeR.view(), activeBRView, B_opt); auto activeAR = raft::make_device_matrix(handle, n, activeR.extent(1)); @@ -818,13 +823,10 @@ void lobpcg( if (!B_opt.has_value()) { // Shared memory assignments to simplify the code - BXView = raft::make_device_matrix_view( - X.data_handle(), n, currentBlockSize); - BRView = raft::make_device_matrix_view( - activeR.data_handle(), n, currentBlockSize); - // if (!restart) TODO - // activeBPView = raft::make_device_matrix_view(P.data_handle(), n, currentBlockSize); + BXView = X.view(); + activeBRView = activeR.view(); + if (!restart) + activeBPView = activePView; } } // Common submatrices @@ -884,12 +886,12 @@ void lobpcg( handle, raft::make_device_matrix_view( activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm - BRView, + activeBRView, gramRBR.view()); raft::linalg::gemm(handle, raft::make_device_matrix_view( X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm - BRView, + activeBRView, gramXBR.view()); } else { raft::matrix::fill(handle, gramXAX.view(), value_t(0)); @@ -899,7 +901,7 @@ void lobpcg( raft::matrix::eye(handle, gramRBR.view()); raft::matrix::fill(handle, gramXBR.view(), value_t(0)); } - auto gramDim = gramXAX.extent(1) + gramXAR.extent(1) + gramXAP.extent(1); + auto gramDim = gramXAX.extent(1) + gramXAR.extent(1) + currentBlockSize; auto gramA = raft::make_device_matrix(handle, gramDim, gramDim); auto gramB = raft::make_device_matrix(handle, gramDim, gramDim); auto gramAView = gramA.view(); @@ -909,6 +911,8 @@ void lobpcg( raft::make_device_matrix_view(handle, gramDim, gramDim); auto eigLambdaTempView = eigLambdaTemp.view(); auto eigVectorTempView = eigVectorTemp.view(); + eigVectorBuffer.resize(gramDim * size_x, stream); + eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), gramDim, size_x); auto gramXAP = raft::make_device_matrix(handle, size_x, currentBlockSize); auto gramRAP = raft::make_device_matrix( @@ -922,17 +926,17 @@ void lobpcg( auto gramPBP = raft::make_device_matrix( handle, currentBlockSize, currentBlockSize); // create transpose mat - auto gramXAPT = raft::make_device_matrix( + auto gramXAPT = raft::make_device_matrix( handle, gramXAPT.extent(1), gramXAPT.extent(0)); - auto gramXART = raft::make_device_matrix( + auto gramXART = raft::make_device_matrix( handle, gramXART.extent(1), gramXART.extent(0)); - auto gramRAPT = raft::make_device_matrix( + auto gramRAPT = raft::make_device_matrix( handle, gramRAPT.extent(1), gramRAPT.extent(0)); - auto gramXBPT = raft::make_device_matrix( + auto gramXBPT = raft::make_device_matrix( handle, gramXBPT.extent(1), gramXBPT.extent(0)); - auto gramXBRT = raft::make_device_matrix( + auto gramXBRT = raft::make_device_matrix( handle, gramXBRT.extent(1), gramXBRT.extent(0)); - auto gramRBPT = raft::make_device_matrix( + auto gramRBPT = raft::make_device_matrix( handle, gramRBPT.extent(1), gramRBPT.extent(0)); raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); raft::linalg::transpose(handle, gramXVR.view(), gramXBRT.view()); @@ -973,7 +977,7 @@ void lobpcg( handle, raft::make_device_matrix_view( gramPAP.data_handle(), gramPAP.extent(0), gramPAP.extent(1)), // transpose for gemm - ident.view(), + identView, gramPAP.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); @@ -992,10 +996,10 @@ void lobpcg( raft::linalg::transpose(handle, gramXBP.view(), gramXBPT.view()); raft::linalg::transpose(handle, gramRBP.view(), gramRBPT.view()); - std::vector> A_blocks = { - gramXAX, gramXAR, gramXAP, gramXART, gramRAR, gramRAP, gramXAPT, gramRAPT, gramPAP}; - std::vector> B_blocks = { - gramXBX, gramXBR, gramXBP, gramXBRT, gramRBR, gramRBP, gramXBPT, gramRBPT, gramPBP}; + std::vector> A_blocks = { + gramXAX.view(), gramXAR.view(), gramXAP.view(), gramXART.view(), gramRAR.view(), gramRAP.view(), gramXAPT.view(), gramRAPT.view(), gramPAP.view()}; + std::vector> B_blocks = { + gramXBX.view(), gramXBR.view(), gramXBP.view(), gramXBRT.view(), gramRBR.view(), gramRBP.view(), gramXBPT.view(), gramRBPT.view(), gramPBP.view()}; gramAView = raft::make_device_matrix_view(gramA.data_handle(), n, n); gramBView = @@ -1010,9 +1014,9 @@ void lobpcg( } if (restart) { gramDim = gramXAX.extent(1) + gramXAR.extent(1); - std::vector> A_blocks = { + std::vector> A_blocks = { gramXAX, gramXAR, gramXART, gramRAR}; - std::vector> B_blocks = { + std::vector> B_blocks = { gramXBX, gramXBR, gramXBRT, gramRBR}; gramAView = raft::make_device_matrix_view( gramA.data_handle(), gramDim, gramDim); @@ -1029,21 +1033,100 @@ void lobpcg( ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations"); } truncEig( - handle, eigVectorTempView, std::make_optional(eigVector.view()), eigLambdaTempView, largest); + handle, eigVectorTempView, std::make_optional(eigVectorView), eigLambdaTempView, largest); raft::copy(eigLambda.data_handle(), eigLambdaTempView.data_handle(), size_x, stream); // Verbosity print // Compute Ritz vectors. + auto d_one = raft::make_device_scalar(handle, 1); + auto one = std::make_optional(d_one.view()); + auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); + auto eigBlockVectorR = raft::make_device_matrix(handle, currentBlockSize, size_x); + auto eigBlockVectorP = raft::make_device_matrix(handle, gramDim - (size_x + currentBlockSize), size_x); + auto pp = raft::make_device_matrix(handle, n, size_x); + auto app = raft::make_device_matrix(handle, n, size_x); if (B_opt.has_value()) { - auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); - auto eigBlockVectorX = raft::make_device_matrix(handle, size_x, size_x); + auto bpp = raft::make_device_matrix(handle, n, size_x); + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice_coordinates(0, 0, size_x, size_x)); + if (!restart) { + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); + } else { + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); + } + + raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view()); + raft::linalg::gemm(handle, activeBRView, eigBlockVectorR.view(), bpp.view()); + if (!restart) { + raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); + raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one); + raft::linalg::gemm(handle, activeBPView, eigBlockVectorP.view(), bpp.view(), one, one); + } + Pbuffer.resize(n * size_x, stream); + APbuffer.resize(n * size_x, stream); + BPbuffer.resize(n * size_x, stream); + PView = raft::make_device_matrix_view(Pbuffer.data(), n, size_x); + APView = raft::make_device_matrix_view(APbuffer.data(), n, size_x); + BPView = raft::make_device_matrix_view(BPbuffer.data(), n, size_x); + + raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); + raft::copy(BPView.data_handle(), bpp.data_handle(), bpp.size(), stream); + + raft::linalg::gemm(handle, X, eigBlockVectorX.view(), pp.view(), one, one); + raft::linalg::gemm(handle, AX.view(), eigBlockVectorX.view(), app.view(), one, one); + raft::linalg::gemm(handle, BXView, eigBlockVectorX.view(), bpp.view(), one, one); + + raft::copy(X.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream); + raft::copy(BXView.data_handle(), bpp.data_handle(), bpp.size(), stream); + } else { + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice_coordinates(0, 0, size_x, size_x)); if (!restart) { - raft::matrix::truncZeroOrigin(eigVector.data_handle(), ) + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); + } else { + raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); } + + raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view()); + if (!restart) { + raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); + raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one); + } + Pbuffer.resize(n * size_x, stream); + APbuffer.resize(n * size_x, stream); + PView = raft::make_device_matrix_view(Pbuffer.data(), n, size_x); + APView = raft::make_device_matrix_view(APbuffer.data(), n, size_x); + + raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); + + raft::linalg::gemm(handle, X, eigBlockVectorX.view(), pp.view(), one, one); + raft::linalg::gemm(handle, AX.view(), eigBlockVectorX.view(), app.view(), one, one); + + raft::copy(X.data_handle(), pp.data_handle(), pp.size(), stream); + raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream); } } - return; - // TODO + + if (B_opt.has_value()) { // Using blockVectorR instead of aux + raft::copy(R.data_handle(), BXView.data_handle(), BXView.size(), stream); + } else { + raft::copy(R.data_handle(), X.data_handle(), X.size(), stream); + } + raft::linalg::binary_mult_skip_zero(handle, R.view(), make_const_mdspan(eigLambda.view()), linalg::Apply::ALONG_ROWS); + raft::linalg::gemm(handle, AX.view(),) } }; // namespace raft::sparse::solver::detail \ No newline at end of file From 6694ae4ffa6dc4df5ca149b6e66cfb645b337062 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 19 Apr 2023 17:01:55 +0200 Subject: [PATCH 13/17] Fix compilation issues --- .../raft/sparse/solver/detail/lobpcg.cuh | 116 +++++++++++------- 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 6c016a54ec..0fbf615967 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -135,9 +135,9 @@ void selectColsIf(const raft::handle_t& handle, raft::linalg::map( handle, raft::make_const_mdspan(mask), + raft::make_const_mdspan(rangeVec.view()), rangeVec.view(), - [] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; }, - rangeVec.view()); + [] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; }); thrust::sort(rmm::exec_policy(stream), rangeVec.data_handle(), rangeVec.data_handle() + rangeVec.size(), @@ -172,11 +172,11 @@ void truncEig( } if (eigVectorTrunc.has_value() && ncols > eigVectorTrunc->extent(1)) raft::matrix::truncZeroOrigin(eigVectorin.data_handle(), - n_rows, + nrows, eigVectorTrunc->data_handle(), nrows, eigVectorTrunc->extent(1), - stream); + handle.get_stream()); } // C = A * B @@ -447,7 +447,7 @@ bool eigh(const raft::handle_t& handle, raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); - return cho_success + return cho_success; } /** @@ -604,8 +604,10 @@ void lobpcg( auto eigVectorBuffer = rmm::device_uvector(size_x * size_x, stream); // rmm because of resize auto eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), size_x, size_x); auto eigLambda = raft::make_device_vector(handle, size_x); - eigh(handle, gramXAX.view(), eigVectorView, eigLambda.view()); - truncEig(handle, eigVectorView, eigLambda.view(), size_x, largest); + std::optional> empty_matrix_opt = std::nullopt; + eigh(handle, gramXAX.view(), empty_matrix_opt, eigVectorView, eigLambda.view()); + + truncEig(handle, eigVectorView, empty_matrix_opt, eigLambda.view(), largest); // Slice not needed for first eigh // raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0, // eigVectorFull.extent(0), size_x)); @@ -623,6 +625,9 @@ void lobpcg( auto identView = raft::make_device_matrix_view( ident.data(), size_x, size_x); raft::matrix::eye(handle, identView); + auto identSizeX = raft::make_device_matrix( + handle, size_x, size_x); + raft::matrix::eye(handle, identSizeX.view()); auto Pbuffer = rmm::device_uvector(0, stream); auto APbuffer = rmm::device_uvector(0, stream); @@ -646,6 +651,8 @@ void lobpcg( auto aux = raft::make_device_matrix( handle, n, size_x); + //auto aux_sum = raft::make_device_vector(handle, size_x); + auto residual_norms = raft::make_device_vector(handle, size_x); std::int32_t iteration_number = -1; bool restart = true; bool explicitGramFlag = false; @@ -664,9 +671,8 @@ void lobpcg( raft::linalg::subtract( handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); - auto aux_sum = raft::make_device_vector(handle, size_x); raft::linalg::reduce( - aux_sum.data_handle(), + residual_norms.data_handle(), R.data_handle(), size_x, n, @@ -677,8 +683,7 @@ void lobpcg( false, raft::sq_op()); - auto residual_norms = raft::make_device_vector(handle, size_x); - raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); + // TODO check sqop of reduce raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); // cupy where & active_mask raft::linalg::unary_op(handle, @@ -720,7 +725,7 @@ void lobpcg( selectColsIf(handle, APView, active_mask.view(), activeAPView); if (B_opt.has_value()) { activeBPView = raft::make_device_matrix_view(activeBPbuffer.data(), n, currentBlockSize); - selectColsIf(handle, BPbuffer.view(), active_mask.view(), activeBPView); + selectColsIf(handle, BPView, active_mask.view(), activeBPView); } } if (M_opt.has_value()) { @@ -823,7 +828,7 @@ void lobpcg( if (!B_opt.has_value()) { // Shared memory assignments to simplify the code - BXView = X.view(); + BXView = X; activeBRView = activeR.view(); if (!restart) activeBPView = activePView; @@ -906,9 +911,9 @@ void lobpcg( auto gramB = raft::make_device_matrix(handle, gramDim, gramDim); auto gramAView = gramA.view(); auto gramBView = gramB.view(); - auto eigLambdaTemp = raft::make_device_vector_view(handle, gramDim); + auto eigLambdaTemp = raft::make_device_vector(handle, gramDim); auto eigVectorTemp = - raft::make_device_matrix_view(handle, gramDim, gramDim); + raft::make_device_matrix(handle, gramDim, gramDim); auto eigLambdaTempView = eigLambdaTemp.view(); auto eigVectorTempView = eigVectorTemp.view(); eigVectorBuffer.resize(gramDim * size_x, stream); @@ -927,19 +932,19 @@ void lobpcg( handle, currentBlockSize, currentBlockSize); // create transpose mat auto gramXAPT = raft::make_device_matrix( - handle, gramXAPT.extent(1), gramXAPT.extent(0)); + handle, gramXAP.extent(1), gramXAP.extent(0)); auto gramXART = raft::make_device_matrix( - handle, gramXART.extent(1), gramXART.extent(0)); + handle, gramXAR.extent(1), gramXAR.extent(0)); auto gramRAPT = raft::make_device_matrix( - handle, gramRAPT.extent(1), gramRAPT.extent(0)); + handle, gramRAP.extent(1), gramRAP.extent(0)); auto gramXBPT = raft::make_device_matrix( - handle, gramXBPT.extent(1), gramXBPT.extent(0)); + handle, gramXBP.extent(1), gramXBP.extent(0)); auto gramXBRT = raft::make_device_matrix( - handle, gramXBRT.extent(1), gramXBRT.extent(0)); + handle, gramXBR.extent(1), gramXBR.extent(0)); auto gramRBPT = raft::make_device_matrix( - handle, gramRBPT.extent(1), gramRBPT.extent(0)); + handle, gramRBP.extent(1), gramRBP.extent(0)); raft::linalg::transpose(handle, gramXAR.view(), gramXART.view()); - raft::linalg::transpose(handle, gramXVR.view(), gramXBRT.view()); + raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view()); if (!restart) { raft::linalg::gemm(handle, @@ -1005,19 +1010,19 @@ void lobpcg( gramBView = raft::make_device_matrix_view(gramB.data_handle(), n, n); - bmat(handle, gramAView, A_blocks); - bmat(handle, gramBView, B_blocks); + bmat(handle, gramAView, A_blocks, 3); + bmat(handle, gramBView, B_blocks, 3); bool eig_sucess = - eigh(handle, gramA, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); + eigh(handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); if (!eig_sucess) restart = true; } if (restart) { gramDim = gramXAX.extent(1) + gramXAR.extent(1); std::vector> A_blocks = { - gramXAX, gramXAR, gramXART, gramRAR}; + gramXAX.view(), gramXAR.view(), gramXART.view(), gramRAR.view()}; std::vector> B_blocks = { - gramXBX, gramXBR, gramXBRT, gramRBR}; + gramXBX.view(), gramXBR.view(), gramXBRT.view(), gramRBR.view()}; gramAView = raft::make_device_matrix_view( gramA.data_handle(), gramDim, gramDim); gramBView = raft::make_device_matrix_view( @@ -1026,8 +1031,8 @@ void lobpcg( raft::make_device_vector_view(eigLambdaTempView.data_handle(), gramDim); eigVectorTempView = raft::make_device_matrix_view( eigVectorTempView.data_handle(), gramDim, gramDim); - bmat(handle, gramAView, A_blocks); - bmat(handle, gramBView, B_blocks); + bmat(handle, gramAView, A_blocks, 2); + bmat(handle, gramBView, B_blocks, 2); bool eig_sucess = eigh( handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations"); @@ -1048,20 +1053,20 @@ void lobpcg( auto app = raft::make_device_matrix(handle, n, size_x); if (B_opt.has_value()) { auto bpp = raft::make_device_matrix(handle, n, size_x); - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(), raft::matrix::slice_coordinates(0, 0, size_x, size_x)); if (!restart) { - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(), raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); } else { - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); } - raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view()); - raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view()); + raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view()); raft::linalg::gemm(handle, activeBRView, eigBlockVectorR.view(), bpp.view()); if (!restart) { raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); @@ -1087,20 +1092,20 @@ void lobpcg( raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream); raft::copy(BXView.data_handle(), bpp.data_handle(), bpp.size(), stream); } else { - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(), raft::matrix::slice_coordinates(0, 0, size_x, size_x)); if (!restart) { - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), raft::matrix::slice_coordinates(size_x, 0, size_x + currentBlockSize, size_x)); - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(), raft::matrix::slice_coordinates(size_x + currentBlockSize, 0, gramDim, size_x)); } else { - raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(), + raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(), raft::matrix::slice_coordinates(size_x, 0, gramDim, size_x)); } - raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view()); - raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view()); + raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view()); + raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view()); if (!restart) { raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one); raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one); @@ -1121,12 +1126,31 @@ void lobpcg( } } - if (B_opt.has_value()) { // Using blockVectorR instead of aux - raft::copy(R.data_handle(), BXView.data_handle(), BXView.size(), stream); + if (B_opt.has_value()) { + raft::copy(aux.data_handle(), BXView.data_handle(), BXView.size(), stream); } else { - raft::copy(R.data_handle(), X.data_handle(), X.size(), stream); + raft::copy(aux.data_handle(), X.data_handle(), X.size(), stream); + } + raft::linalg::binary_mult_skip_zero(handle, aux.view(), make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS); + + raft::linalg::subtract( + handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); + + raft::linalg::reduce( + residual_norms.data_handle(), + R.data_handle(), + size_x, + n, + value_t(0), + false, + true, + stream, + false, + raft::sq_op()); + // TODO check reduce sqrt postop raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); + + if (verbosityLevel > 0) { + /// TODO add verb } - raft::linalg::binary_mult_skip_zero(handle, R.view(), make_const_mdspan(eigLambda.view()), linalg::Apply::ALONG_ROWS); - raft::linalg::gemm(handle, AX.view(),) } }; // namespace raft::sparse::solver::detail \ No newline at end of file From d55b32bd64b731df747c9df13d3bf6773221e1b8 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 24 Apr 2023 19:02:14 +0200 Subject: [PATCH 14/17] Fix bmat --- .../raft/sparse/solver/detail/lobpcg.cuh | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 0fbf615967..35c215d0ec 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -73,13 +73,6 @@ void bmat(const raft::handle_t& handle, index_t n_blocks) { RAFT_EXPECTS(n_blocks * n_blocks == ins.size(), "inconsistent number of blocks"); - index_t n_rows = 0; - index_t n_cols = 0; - for (const auto& inView : ins) { - n_rows += inView.extent(0); - n_cols += inView.extent(1); - } - RAFT_EXPECTS(n_rows == out.extent(0) && n_cols == out.extent(1), "input/output dimension mismatch"); std::vector cumulative_row(n_blocks); std::vector cumulative_col(n_blocks); for (index_t i = 0; i < n_blocks; i++) { @@ -693,10 +686,10 @@ void lobpcg( if (verbosityLevel > 2) { print_device_vector("active_mask", active_mask.data_handle(), active_mask.size(), std::cout); } - index_t currentBlockSize = thrust::count(thrust::cuda::par.on(stream), - active_mask.data_handle(), - active_mask.data_handle() + active_mask.size(), - 0); + index_t currentBlockSize = thrust::count_if(thrust::cuda::par.on(stream), + active_mask.data_handle(), + active_mask.data_handle() + active_mask.size(), + [] __device__(value_t v) {return v > 0; }); handle.sync_stream(); if (currentBlockSize != previousBlockSize) { previousBlockSize = currentBlockSize; @@ -847,7 +840,7 @@ void lobpcg( raft::make_device_matrix(handle, size_x, currentBlockSize); raft::linalg::gemm(handle, raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + X.data_handle(), X.extent(1), X.extent(0)), // transpose for gemm? activeAR.view(), gramXAR.view()); From eba5f1c3193b260dd2912517c18e05825312ca1d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 11 May 2023 02:28:28 +0200 Subject: [PATCH 15/17] Use common transpose function --- .../raft/sparse/solver/detail/lobpcg.cuh | 89 +++++++++---------- cpp/test/sparse/lobpcg.cu | 32 +++++-- 2 files changed, 66 insertions(+), 55 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 35c215d0ec..bf70696e98 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -46,6 +46,17 @@ #include #include +template +auto make_transpose_layout_view(raft::device_matrix_view mds) +{ + return raft::make_device_matrix_view(mds.data_handle(), mds.extent(1), mds.extent(0)); +} +template +auto make_transpose_layout_view(raft::device_matrix_view mds) +{ + return raft::make_device_matrix_view(mds.data_handle(), mds.extent(1), mds.extent(0)); +} + namespace raft::sparse::solver::detail { /** @@ -416,11 +427,13 @@ bool eigh(const raft::handle_t& handle, raft::device_matrix_view eigVecs, raft::device_vector_view eigVals) { + auto dim = A.extent(0); + auto AT = raft::make_device_matrix(handle, dim, dim); + raft::linalg::transpose(handle, A, AT.view()); if (!B_opt.has_value()) { - raft::linalg::eig_dc(handle, raft::make_const_mdspan(A), eigVecs, eigVals); + raft::linalg::eig_dc(handle, raft::make_const_mdspan(AT.view()), eigVecs, eigVals); return true; } - auto dim = A.extent(0); auto RTi = raft::make_device_matrix(handle, dim, dim); auto Ri = raft::make_device_matrix(handle, dim, dim); auto RT = raft::make_device_matrix(handle, dim, dim); @@ -435,7 +448,7 @@ bool eigh(const raft::handle_t& handle, // Reuse the memory of matrix auto& ARi = B; auto& Fvecs = RT; - raft::linalg::gemm(handle, A, Ri.view(), ARi); + raft::linalg::gemm(handle, AT.view(), Ri.view(), ARi); raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); @@ -591,9 +604,10 @@ void lobpcg( spmm(handle, A, X, AX.view()); auto gramXAX = raft::make_device_matrix(handle, size_x, size_x); - auto XT = raft::make_device_matrix(handle, size_x, n); - raft::linalg::transpose(handle, X, XT.view()); - raft::linalg::gemm(handle, XT.view(), AX.view(), gramXAX.view()); + auto XTRowView = make_transpose_layout_view(X); + raft::linalg::gemm(handle, + XTRowView, + AX.view(), gramXAX.view()); auto eigVectorBuffer = rmm::device_uvector(size_x * size_x, stream); // rmm because of resize auto eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), size_x, size_x); auto eigLambda = raft::make_device_vector(handle, size_x); @@ -735,15 +749,14 @@ void lobpcg( } // B-orthogonalize the preconditioned residuals to X. if (B_opt.has_value()) { - auto BXT = raft::make_device_matrix( - handle, BX.extent(1), BX.extent(0)); auto BXTR = raft::make_device_matrix( - handle, BXT.extent(0), activeR.extent(1)); + handle, BX.extent(1), activeR.extent(1)); auto XBXTR = raft::make_device_matrix( handle, X.extent(0), BXTR.extent(1)); - raft::linalg::transpose(handle, BX.view(), BXT.view()); - raft::linalg::gemm(handle, BXT.view(), activeR.view(), BXTR.view()); + raft::linalg::gemm(handle, + make_transpose_layout_view(BX.view()), + activeR.view(), BXTR.view()); raft::linalg::gemm(handle, X, BXTR.view(), XBXTR.view()); raft::linalg::subtract(handle, raft::make_const_mdspan(activeR.view()), @@ -751,11 +764,10 @@ void lobpcg( activeR.view()); } else { auto XTR = raft::make_device_matrix( - handle, XT.extent(0), activeR.extent(1)); + handle, X.extent(1), activeR.extent(1)); auto XXTR = raft::make_device_matrix( handle, X.extent(0), XTR.extent(1)); - - raft::linalg::gemm(handle, XT.view(), activeR.view(), XTR.view()); + raft::linalg::gemm(handle, XTRowView, activeR.view(), XTR.view()); raft::linalg::gemm(handle, X, XTR.view(), XXTR.view()); raft::linalg::subtract(handle, raft::make_const_mdspan(activeR.view()), @@ -839,15 +851,13 @@ void lobpcg( auto gramXBR = raft::make_device_matrix(handle, size_x, currentBlockSize); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(1), X.extent(0)), // transpose for gemm? + XTRowView, activeAR.view(), gramXAR.view()); raft::linalg::gemm( handle, - raft::make_device_matrix_view( - activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + make_transpose_layout_view(activeR.view()), activeAR.view(), gramRAR.view()); @@ -855,40 +865,34 @@ void lobpcg( if (explicitGramFlag) { raft::linalg::gemm( handle, - raft::make_device_matrix_view( - gramRAR.data_handle(), gramRAR.extent(0), gramRAR.extent(1)), // transpose for gemm + make_transpose_layout_view(gramRAR.view()), identView, gramRAR.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + XTRowView, AX.view(), gramXAX.view()); raft::linalg::gemm( handle, - raft::make_device_matrix_view( - gramXAX.data_handle(), gramXAX.extent(0), gramXAX.extent(1)), // transpose for gemm + make_transpose_layout_view(gramXAX.view()), identView, gramXAX.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + XTRowView, BX.view(), gramXBX.view()); raft::linalg::gemm( handle, - raft::make_device_matrix_view( - activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + make_transpose_layout_view(activeR.view()), activeBRView, gramRBR.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + XTRowView, activeBRView, gramXBR.view()); } else { @@ -941,49 +945,38 @@ void lobpcg( if (!restart) { raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + XTRowView, activeAPView, gramXAP.view()); raft::linalg::gemm( handle, - raft::make_device_matrix_view( - activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + make_transpose_layout_view(activeR.view()), activeAPView, gramRAP.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - activePView.data_handle(), - activePView.extent(0), - activePView.extent(1)), // transpose for gemm + make_transpose_layout_view(activePView), activeAPView, gramPAP.view()); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - X.data_handle(), X.extent(0), X.extent(1)), // transpose for gemm + XTRowView, activeBPView, gramXBP.view()); raft::linalg::gemm( handle, - raft::make_device_matrix_view( - activeR.data_handle(), activeR.extent(0), activeR.extent(1)), // transpose for gemm + make_transpose_layout_view(activeR.view()), activeBPView, gramRBP.view()); if (explicitGramFlag) { raft::linalg::gemm( handle, - raft::make_device_matrix_view( - gramPAP.data_handle(), gramPAP.extent(0), gramPAP.extent(1)), // transpose for gemm + make_transpose_layout_view(gramPAP.view()), identView, gramPAP.view(), std::make_optional(device_half.view()), std::make_optional(device_half.view())); raft::linalg::gemm(handle, - raft::make_device_matrix_view( - activePView.data_handle(), - activePView.extent(0), - activePView.extent(1)), // transpose for gemm + make_transpose_layout_view(activePView), activeBPView, gramPBP.view()); } else { diff --git a/cpp/test/sparse/lobpcg.cu b/cpp/test/sparse/lobpcg.cu index 2d5fe4109b..e226772b9e 100644 --- a/cpp/test/sparse/lobpcg.cu +++ b/cpp/test/sparse/lobpcg.cu @@ -34,8 +34,8 @@ namespace sparse { template struct CSRMatrixVal { - std::vector row_ind; std::vector row_ind_ptr; + std::vector row_ind; std::vector values; }; @@ -79,7 +79,7 @@ class LOBPCGTest : public ::testing::TestWithParam> stream(handle.get_stream()), ind_a(params.matrix_a.row_ind.size(), stream), ind_ptr_a(params.matrix_a.row_ind_ptr.size(), stream), - values_a(params.matrix_a.row_ind_ptr.size(), stream), + values_a(params.matrix_a.values.size(), stream), exp_eigvals(params.exp_eigvals.size(), stream), exp_eigvecs(params.exp_eigvecs.size(), stream), act_eigvals(params.exp_eigvals.size(), stream), @@ -90,8 +90,8 @@ class LOBPCGTest : public ::testing::TestWithParam> protected: void SetUp() override { - n_rows_a = params.matrix_a.row_ind.size() - 1; - nnz_a = params.matrix_a.row_ind_ptr.size(); + n_rows_a = params.matrix_a.row_ind_ptr.size() - 1; + nnz_a = params.matrix_a.values.size(); } void test_selectcolsif() @@ -173,14 +173,32 @@ class LOBPCGTest : public ::testing::TestWithParam> hostVecMatch(vbv_inv_expected, vbv_inv_actual, raft::CompareApprox(0.0001))); } + void test_eigh() + { + std::vector in_cpu{1.73969722, 0.98719877, 0.73374337, 0.211756781}; + std::vector lambda_cpu{-0.27255666, 2.22401026}; + std::vector vector_cpu{-0.44044489, 0.89777965, 0.89777965, 0.44044489}; + auto in_gpu = raft::make_device_matrix(handle, 2, 2); + auto lambda_gpu = raft::make_device_vector(handle, 2); + auto vector_gpu = raft::make_device_matrix(handle, 2, 2); + std::optional> empty_matrix_opt = std::nullopt; + + raft::copy(in_gpu.data_handle(), in_cpu.data(), 4, handle.get_stream()); + raft::sparse::solver::detail::eigh(handle, in_gpu.view(), empty_matrix_opt, vector_gpu.view(), lambda_gpu.view()); + + ASSERT_TRUE(devArrMatchHost(lambda_cpu.data(), lambda_gpu.data_handle(), lambda_cpu.size(), raft::CompareApprox(0.0001), handle.get_stream())); + ASSERT_TRUE(devArrMatchHost(vector_cpu.data(), vector_gpu.data_handle(), vector_cpu.size(), raft::CompareApprox(0.0001), handle.get_stream())); + } + void Run() { + test_eigh(); test_bmat(); test_selectcolsif(); test_b_orthonormalize(); - raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), n_rows_a, stream); - raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), nnz_a, stream); - raft::update_device(values_a.data(), params.matrix_a.values.data(), nnz_a, stream); + raft::update_device(ind_a.data(), params.matrix_a.row_ind.data(), params.matrix_a.row_ind.size(), stream); + raft::update_device(ind_ptr_a.data(), params.matrix_a.row_ind_ptr.data(), params.matrix_a.row_ind_ptr.size(), stream); + raft::update_device(values_a.data(), params.matrix_a.values.data(), params.matrix_a.values.size(), stream); raft::update_device(act_eigvecs.data(), params.init_eigvecs.data(), act_eigvecs.size(), stream); From f6a7f02c767a8f9795969ef333529cefb437eea6 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 15 May 2023 15:57:14 +0200 Subject: [PATCH 16/17] Add verbosity print --- .../raft/sparse/solver/detail/lobpcg.cuh | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index bf70696e98..1ced904336 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -715,9 +716,19 @@ void lobpcg( if (currentBlockSize == 0) break; if (verbosityLevel > 0) { - // TODO add verb + printf("Iteration: %i\n", iteration_number); + printf("current block size: %d\n", currentBlockSize); + raft::matrix::print_separators ps{}; + printf("lambda:\n"); + raft::matrix::print(handle, raft::make_device_matrix_view(eigLambda.data_handle(), 1, eigLambda.extent(0)), ps); + printf("residual norms:\n"); + raft::matrix::print(handle, raft::make_device_matrix_view(residual_norms.data_handle(), 1, residual_norms.extent(0)), ps); + if (verbosityLevel > 10) { + printf("eigBlockVector:\n"); + raft::matrix::print(handle, make_const_mdspan(eigVectorView), ps); + + } } - auto activeR = raft::make_device_matrix(handle, n, currentBlockSize); selectColsIf(handle, R.view(), active_mask.view(), activeR.view()); @@ -789,7 +800,7 @@ void lobpcg( handle, activePView.extent(1), activePView.extent(1)); auto normal = raft::make_device_vector(handle, activePView.extent(1)); bool b_orth_success = true; - if (B_opt.has_value()) { + if (!B_opt.has_value()) { auto BP = raft::make_device_matrix( handle, activePView.extent(0), activePView.extent(1)); b_orth_success = b_orthonormalize(handle, @@ -1028,6 +1039,16 @@ void lobpcg( raft::copy(eigLambda.data_handle(), eigLambdaTempView.data_handle(), size_x, stream); // Verbosity print + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("gramA:\n"); + raft::matrix::print(handle, make_const_mdspan(gramAView), ps); + printf("gramB:\n"); + raft::matrix::print(handle, make_const_mdspan(gramBView), ps); + printf("lambdaPostGram:\n"); + raft::matrix::print(handle, raft::make_device_matrix_view(eigLambdaTempView.data_handle(), 1, eigLambdaTempView.extent(0)), ps); + + } // Compute Ritz vectors. auto d_one = raft::make_device_scalar(handle, 1); From 59374f0789fa8ac24abc4e7aad9d03e7880c8c7e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 29 May 2023 17:55:07 +0200 Subject: [PATCH 17/17] Fix inverse and eigh function, verbose print --- .../raft/sparse/solver/detail/lobpcg.cuh | 109 ++++++++++-------- 1 file changed, 64 insertions(+), 45 deletions(-) diff --git a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh index 1ced904336..323e431e29 100644 --- a/cpp/include/raft/sparse/solver/detail/lobpcg.cuh +++ b/cpp/include/raft/sparse/solver/detail/lobpcg.cuh @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -328,19 +329,21 @@ template void inverse(const raft::handle_t& handle, raft::device_matrix_view P, raft::device_matrix_view Pinv, - bool lower = true) + bool transposeP = false) { auto stream = handle.get_stream(); int Lwork = 0; auto lda = P.extent(0); auto dim = P.extent(0); int info_h = 0; - cublasOperation_t trans = CUBLAS_OP_N; + cublasOperation_t trans = transposeP ? CUBLAS_OP_T : CUBLAS_OP_N; raft::matrix::eye(handle, Pinv); RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf_bufferSize( handle.get_cusolver_dn_handle(), dim, dim, P.data_handle(), lda, &Lwork)); + auto P_copy = raft::make_device_matrix(handle, P.extent(0), P.extent(1)); + raft::copy(P_copy.data_handle(), P.data_handle(), P.size(), stream); rmm::device_uvector workspace_decomp(Lwork, stream); rmm::device_uvector info(1, stream); auto ipiv = raft::make_device_vector(handle, dim); @@ -348,7 +351,7 @@ void inverse(const raft::handle_t& handle, RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDngetrf(handle.get_cusolver_dn_handle(), dim, dim, - P.data_handle(), + P_copy.data_handle(), lda, workspace_decomp.data(), ipiv.data_handle(), @@ -363,7 +366,7 @@ void inverse(const raft::handle_t& handle, trans, dim, dim, - P.data_handle(), + P_copy.data_handle(), lda, ipiv.data_handle(), Pinv.data_handle(), @@ -437,22 +440,25 @@ bool eigh(const raft::handle_t& handle, } auto RTi = raft::make_device_matrix(handle, dim, dim); auto Ri = raft::make_device_matrix(handle, dim, dim); - auto RT = raft::make_device_matrix(handle, dim, dim); + auto R = raft::make_device_matrix(handle, dim, dim); auto F = raft::make_device_matrix(handle, dim, dim); auto B = B_opt.value(); bool cho_success = cholesky(handle, B, false); + raft::linalg::transpose(handle, B, R.view()); - raft::linalg::transpose(handle, B, RT.view()); - inverse(handle, RT.view(), Ri.view()); - inverse(handle, B, RTi.view()); + inverse(handle, R.view(), Ri.view(), true); + inverse(handle, R.view(), RTi.view(), false); // Reuse the memory of matrix auto& ARi = B; - auto& Fvecs = RT; - raft::linalg::gemm(handle, AT.view(), Ri.view(), ARi); + auto& Fvecs = R; + raft::linalg::gemm(handle, A, Ri.view(), ARi); raft::linalg::gemm(handle, RTi.view(), ARi, F.view()); - raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals); + auto FT = raft::make_device_matrix(handle, dim, dim); + raft::linalg::transpose(handle, F.view(), FT.view()); + + raft::linalg::eig_dc(handle, raft::make_const_mdspan(FT.view()), Fvecs.view(), eigVals); raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs); return cho_success; } @@ -491,34 +497,32 @@ bool b_orthonormalize( V_max_ptr = V_max_opt.value().data_handle(); } auto V_max = raft::make_device_vector_view(V_max_ptr, V.extent(1)); - auto V_max_const = raft::make_const_mdspan(V_max); - // /*raft::linalg::reduce(handle, raft::make_device_matrix_view( - V.data_handle(), V.extent(1), V.extent(0)), + V.data_handle(), V.extent(0), V.extent(1)), V_max, value_t(0), raft::linalg::Apply::ALONG_ROWS, false, raft::identity_op(), - MaxOp()); - */ + MaxOp());*/ + // Coalesced reduction raft::linalg::reduce(V_max.data_handle(), V.data_handle(), - V.extent(0), V.extent(1), + V.extent(0), value_t(0), false, - true, + false, handle.get_stream(), false, raft::identity_op(), MaxOp()); - raft::linalg::binary_div_skip_zero(handle, V, V_max_const, raft::linalg::Apply::ALONG_ROWS); + raft::linalg::binary_div_skip_zero(handle, V, raft::make_const_mdspan(V_max), raft::linalg::Apply::ALONG_ROWS); if (!bv_is_empty) { - raft::linalg::binary_div_skip_zero(handle, BV, V_max_const, raft::linalg::Apply::ALONG_ROWS); + raft::linalg::binary_div_skip_zero(handle, BV, raft::make_const_mdspan(V_max), raft::linalg::Apply::ALONG_ROWS); } else { if (B_opt) spmm(handle, B_opt.value(), V, BV); @@ -537,11 +541,9 @@ bool b_orthonormalize( VBV_ptr, V.extent(1), V.extent(1)); auto VBVBuffer = raft::make_device_matrix( handle, VBV.extent(0), VBV.extent(1)); - auto VT = - raft::make_device_matrix(handle, V.extent(1), V.extent(0)); - raft::linalg::transpose(handle, V, VT.view()); + auto VT = make_transpose_layout_view(V); - raft::linalg::gemm(handle, VT.view(), BV, VBV); + raft::linalg::gemm(handle, VT, BV, VBV); bool cholesky_success = cholesky(handle, VBV, false); if (!cholesky_success) { return cholesky_success; } @@ -679,19 +681,7 @@ void lobpcg( raft::linalg::subtract( handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view()); - raft::linalg::reduce( - residual_norms.data_handle(), - R.data_handle(), - size_x, - n, - value_t(0), - false, - true, - stream, - false, - raft::sq_op()); - - // TODO check sqop of reduce raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view()); + raft::linalg::norm(handle, make_const_mdspan(R.view()), residual_norms.view(), raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_COLUMNS, raft::sqrt_op()); // cupy where & active_mask raft::linalg::unary_op(handle, @@ -839,7 +829,8 @@ void lobpcg( residual_norms.data_handle() + residual_norms.size()); value_t residual_norms_max = 0; raft::copy(&residual_norms_max, residual_norms_max_elem, 1, stream); - explicitGramFlag = residual_norms_max <= myeps; + handle.sync_stream(); + explicitGramFlag = residual_norms_max > myeps; } if (!B_opt.has_value()) { @@ -924,8 +915,6 @@ void lobpcg( raft::make_device_matrix(handle, gramDim, gramDim); auto eigLambdaTempView = eigLambdaTemp.view(); auto eigVectorTempView = eigVectorTemp.view(); - eigVectorBuffer.resize(gramDim * size_x, stream); - eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), gramDim, size_x); auto gramXAP = raft::make_device_matrix(handle, size_x, currentBlockSize); auto gramRAP = raft::make_device_matrix( @@ -1010,6 +999,14 @@ void lobpcg( bmat(handle, gramAView, A_blocks, 3); bmat(handle, gramBView, B_blocks, 3); + // Verbosity print + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("gramA:\n"); + raft::matrix::print(handle, make_const_mdspan(gramAView), ps); + printf("gramB:\n"); + raft::matrix::print(handle, make_const_mdspan(gramBView), ps); + } bool eig_sucess = eigh(handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); if (!eig_sucess) restart = true; @@ -1030,10 +1027,19 @@ void lobpcg( eigVectorTempView.data_handle(), gramDim, gramDim); bmat(handle, gramAView, A_blocks, 2); bmat(handle, gramBView, B_blocks, 2); + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("gramA:\n"); + raft::matrix::print(handle, make_const_mdspan(gramAView), ps); + printf("gramB:\n"); + raft::matrix::print(handle, make_const_mdspan(gramBView), ps); + } bool eig_sucess = eigh( handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView); ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations"); } + eigVectorBuffer.resize(gramDim * size_x, stream); + eigVectorView = raft::make_device_matrix_view(eigVectorBuffer.data(), gramDim, size_x); truncEig( handle, eigVectorTempView, std::make_optional(eigVectorView), eigLambdaTempView, largest); raft::copy(eigLambda.data_handle(), eigLambdaTempView.data_handle(), size_x, stream); @@ -1041,10 +1047,6 @@ void lobpcg( // Verbosity print if (verbosityLevel > 10) { raft::matrix::print_separators ps{}; - printf("gramA:\n"); - raft::matrix::print(handle, make_const_mdspan(gramAView), ps); - printf("gramB:\n"); - raft::matrix::print(handle, make_const_mdspan(gramBView), ps); printf("lambdaPostGram:\n"); raft::matrix::print(handle, raft::make_device_matrix_view(eigLambdaTempView.data_handle(), 1, eigLambdaTempView.extent(0)), ps); @@ -1086,7 +1088,16 @@ void lobpcg( PView = raft::make_device_matrix_view(Pbuffer.data(), n, size_x); APView = raft::make_device_matrix_view(APbuffer.data(), n, size_x); BPView = raft::make_device_matrix_view(BPbuffer.data(), n, size_x); - + + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("pp:\n"); + raft::matrix::print(handle, make_const_mdspan(pp.view()), ps); + printf("app:\n"); + raft::matrix::print(handle, make_const_mdspan(app.view()), ps); + printf("bpp:\n"); + raft::matrix::print(handle, make_const_mdspan(bpp.view()), ps); + } raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); raft::copy(BPView.data_handle(), bpp.data_handle(), bpp.size(), stream); @@ -1124,6 +1135,14 @@ void lobpcg( raft::copy(PView.data_handle(), pp.data_handle(), pp.size(), stream); raft::copy(APView.data_handle(), app.data_handle(), app.size(), stream); + + if (verbosityLevel > 10) { + raft::matrix::print_separators ps{}; + printf("pp:\n"); + raft::matrix::print(handle, make_const_mdspan(pp.view()), ps); + printf("app:\n"); + raft::matrix::print(handle, make_const_mdspan(app.view()), ps); + } raft::linalg::gemm(handle, X, eigBlockVectorX.view(), pp.view(), one, one); raft::linalg::gemm(handle, AX.view(), eigBlockVectorX.view(), app.view(), one, one);