diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index ee19b5cd594..c75ebbe3f10 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -185,6 +185,9 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_2_KERNEL); namespace lower_trs { +GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL() +GKO_NOT_COMPILED(GKO_HOOK_MODULE); + template GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp index 4afeb0eace9..71ea4e0016f 100644 --- a/core/solver/lower_trs.cpp +++ b/core/solver/lower_trs.cpp @@ -53,6 +53,7 @@ namespace solver { namespace lower_trs { +GKO_REGISTER_OPERATION(clear, lower_trs::clear); GKO_REGISTER_OPERATION(generate, lower_trs::generate); GKO_REGISTER_OPERATION(solve, lower_trs::solve); @@ -60,6 +61,13 @@ GKO_REGISTER_OPERATION(solve, lower_trs::solve); } // namespace lower_trs +template +void LowerTrs::clear_data() const +{ + this->get_executor()->run(lower_trs::make_clear()); +} + + template void LowerTrs::generate() { diff --git a/core/solver/lower_trs_kernels.hpp b/core/solver/lower_trs_kernels.hpp index f98f8bbb59b..77875d7ff26 100644 --- a/core/solver/lower_trs_kernels.hpp +++ b/core/solver/lower_trs_kernels.hpp @@ -47,6 +47,10 @@ namespace kernels { namespace lower_trs { +#define GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL() \ + void clear(std::shared_ptr exec) + + #define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \ void generate(std::shared_ptr exec, \ const matrix::Csr<_vtype, _itype> *matrix, \ @@ -60,6 +64,7 @@ namespace lower_trs { #define GKO_DECLARE_ALL_AS_TEMPLATES \ + GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL(); \ template \ GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \ template \ diff --git a/cuda/solver/lower_trs_kernels.cu b/cuda/solver/lower_trs_kernels.cu index f2f6816ea17..a403de64390 100644 --- a/cuda/solver/lower_trs_kernels.cu +++ b/cuda/solver/lower_trs_kernels.cu @@ -78,6 +78,25 @@ static cusp_csrsm_data cusp_csrsm_data{}; #endif +void clear(std::shared_ptr exec) +{ +#if (defined(CUDA_VERSION) && (CUDA_VERSION > 9100)) + cusparse::destroy(cusp_csrsm2_data.factor_descr); + if (cusp_csrsm2_data.solve_info) { + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseDestroyCsrsm2Info(cusp_csrsm2_data.solve_info)); + } + if (cusp_csrsm2_data.factor_work_vec != nullptr) { + exec->free(cusp_csrsm2_data.factor_work_vec); + } +#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9200)) + cusparse::destroy(cusp_csrsm_data.factor_descr); + GKO_ASSERT_NO_CUSPARSE_ERRORS( + cusparseDestroySolveAnalysisInfo(cusp_csrsm_data.solve_info)); +#endif +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, diff --git a/cuda/test/solver/lower_trs_kernels.cpp b/cuda/test/solver/lower_trs_kernels.cpp index f8fc1c988ee..c47074536dc 100644 --- a/cuda/test/solver/lower_trs_kernels.cpp +++ b/cuda/test/solver/lower_trs_kernels.cpp @@ -75,16 +75,16 @@ class LowerTrs : public ::testing::Test { std::unique_ptr gen_mtx(int num_rows, int num_cols) { - return gko::test::generate_random_lower_triangular_matrix( - num_rows, num_cols, false, + return gko::test::generate_random_matrix( + num_rows, num_cols, std::uniform_int_distribution<>(num_cols, num_cols), std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } std::unique_ptr gen_l_mtx(int num_rows, int num_cols) { - return gko::test::generate_random_matrix( - num_rows, num_cols, + return gko::test::generate_random_lower_triangular_matrix( + num_rows, num_cols, false, std::uniform_int_distribution<>(num_cols, num_cols), std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } diff --git a/include/ginkgo/core/solver/lower_trs.hpp b/include/ginkgo/core/solver/lower_trs.hpp index 9707e6cb1d0..8c52b5ea17a 100644 --- a/include/ginkgo/core/solver/lower_trs.hpp +++ b/include/ginkgo/core/solver/lower_trs.hpp @@ -117,18 +117,26 @@ class LowerTrs : public EnableLinOp>, GKO_ENABLE_LIN_OP_FACTORY(LowerTrs, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + ~LowerTrs() { this->clear_data(); } + protected: void apply_impl(const LinOp *b, LinOp *x) const override; void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta, LinOp *x) const override; + /** + * Clears the held data. + */ + void clear_data() const; + /** * Generates the analysis structure from the system matrix and the right * hand side needed for the level solver. */ void generate(); + explicit LowerTrs(std::shared_ptr exec) : EnableLinOp(std::move(exec)) {} diff --git a/omp/solver/lower_trs_kernels.cpp b/omp/solver/lower_trs_kernels.cpp index 258695bd955..f7d4938964d 100644 --- a/omp/solver/lower_trs_kernels.cpp +++ b/omp/solver/lower_trs_kernels.cpp @@ -58,6 +58,13 @@ namespace omp { namespace lower_trs { +void clear(std::shared_ptr exec) +{ + // This clear kernel is here to allow for a more sophisticated + // implementation as for other executors. +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, diff --git a/reference/solver/lower_trs_kernels.cpp b/reference/solver/lower_trs_kernels.cpp index 3b39f66f892..f9fd8b60e04 100644 --- a/reference/solver/lower_trs_kernels.cpp +++ b/reference/solver/lower_trs_kernels.cpp @@ -54,6 +54,13 @@ namespace reference { namespace lower_trs { +void clear(std::shared_ptr exec) +{ + // This clear kernel is here to allow for a more sophisticated + // implementation as for other executors. +} + + template void generate(std::shared_ptr exec, const matrix::Csr *matrix, diff --git a/reference/test/solver/lower_trs.cpp b/reference/test/solver/lower_trs.cpp index f7004c9f570..16f2eb96b6c 100644 --- a/reference/test/solver/lower_trs.cpp +++ b/reference/test/solver/lower_trs.cpp @@ -135,8 +135,8 @@ TEST_F(LowerTrs, CanBeCleared) auto solver_mtx = lower_trs_solver->get_system_matrix(); - ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0)); ASSERT_EQ(solver_mtx, nullptr); + ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0)); }