Skip to content

Commit

Permalink
Add a clear kernel to clear the analysis data.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Sep 2, 2019
1 parent c371d9b commit d89c1a0
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 5 deletions.
3 changes: 3 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_2_KERNEL);
namespace lower_trs {


GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

template <typename ValueType, typename IndexType>
GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
Expand Down
8 changes: 8 additions & 0 deletions core/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,21 @@ namespace solver {
namespace lower_trs {


GKO_REGISTER_OPERATION(clear, lower_trs::clear);
GKO_REGISTER_OPERATION(generate, lower_trs::generate);
GKO_REGISTER_OPERATION(solve, lower_trs::solve);


} // namespace lower_trs


template <typename ValueType, typename IndexType>
void LowerTrs<ValueType, IndexType>::clear_data() const
{
this->get_executor()->run(lower_trs::make_clear());
}


template <typename ValueType, typename IndexType>
void LowerTrs<ValueType, IndexType>::generate()
{
Expand Down
5 changes: 5 additions & 0 deletions core/solver/lower_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ namespace kernels {
namespace lower_trs {


#define GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL() \
void clear(std::shared_ptr<const DefaultExecutor> exec)


#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
Expand All @@ -60,6 +64,7 @@ namespace lower_trs {


#define GKO_DECLARE_ALL_AS_TEMPLATES \
GKO_DECLARE_LOWER_TRS_CLEAR_KERNEL(); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_LOWER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
Expand Down
19 changes: 19 additions & 0 deletions cuda/solver/lower_trs_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,25 @@ static cusp_csrsm_data cusp_csrsm_data{};
#endif


void clear(std::shared_ptr<const CudaExecutor> exec)
{
#if (defined(CUDA_VERSION) && (CUDA_VERSION > 9100))
cusparse::destroy(cusp_csrsm2_data.factor_descr);
if (cusp_csrsm2_data.solve_info) {
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseDestroyCsrsm2Info(cusp_csrsm2_data.solve_info));
}
if (cusp_csrsm2_data.factor_work_vec != nullptr) {
exec->free(cusp_csrsm2_data.factor_work_vec);
}
#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9200))
cusparse::destroy(cusp_csrsm_data.factor_descr);
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseDestroySolveAnalysisInfo(cusp_csrsm_data.solve_info));
#endif
}


template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const CudaExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
Expand Down
8 changes: 4 additions & 4 deletions cuda/test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,16 @@ class LowerTrs : public ::testing::Test {

std::unique_ptr<Mtx> gen_mtx(int num_rows, int num_cols)
{
return gko::test::generate_random_lower_triangular_matrix<Mtx>(
num_rows, num_cols, false,
return gko::test::generate_random_matrix<Mtx>(
num_rows, num_cols,
std::uniform_int_distribution<>(num_cols, num_cols),
std::normal_distribution<>(-1.0, 1.0), rand_engine, ref);
}

std::unique_ptr<Mtx> gen_l_mtx(int num_rows, int num_cols)
{
return gko::test::generate_random_matrix<Mtx>(
num_rows, num_cols,
return gko::test::generate_random_lower_triangular_matrix<Mtx>(
num_rows, num_cols, false,
std::uniform_int_distribution<>(num_cols, num_cols),
std::normal_distribution<>(-1.0, 1.0), rand_engine, ref);
}
Expand Down
8 changes: 8 additions & 0 deletions include/ginkgo/core/solver/lower_trs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,26 @@ class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
GKO_ENABLE_LIN_OP_FACTORY(LowerTrs, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

~LowerTrs() { this->clear_data(); }

protected:
void apply_impl(const LinOp *b, LinOp *x) const override;

void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta,
LinOp *x) const override;

/**
* Clears the held data.
*/
void clear_data() const;

/**
* Generates the analysis structure from the system matrix and the right
* hand side needed for the level solver.
*/
void generate();


explicit LowerTrs(std::shared_ptr<const Executor> exec)
: EnableLinOp<LowerTrs>(std::move(exec))
{}
Expand Down
7 changes: 7 additions & 0 deletions omp/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ namespace omp {
namespace lower_trs {


void clear(std::shared_ptr<const OmpExecutor> exec)
{
// This clear kernel is here to allow for a more sophisticated
// implementation as for other executors.
}


template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const OmpExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
Expand Down
7 changes: 7 additions & 0 deletions reference/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ namespace reference {
namespace lower_trs {


void clear(std::shared_ptr<const ReferenceExecutor> exec)
{
// This clear kernel is here to allow for a more sophisticated
// implementation as for other executors.
}


template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const ReferenceExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
Expand Down
2 changes: 1 addition & 1 deletion reference/test/solver/lower_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ TEST_F(LowerTrs, CanBeCleared)

auto solver_mtx = lower_trs_solver->get_system_matrix();

ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0));
ASSERT_EQ(solver_mtx, nullptr);
ASSERT_EQ(lower_trs_solver->get_size(), gko::dim<2>(0, 0));
}


Expand Down

0 comments on commit d89c1a0

Please sign in to comment.