Skip to content

Commit

Permalink
Init struct and transpose check flag for the Upper Trs solver.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Sep 10, 2019
1 parent 802811e commit 994fd4b
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 33 deletions.
5 changes: 4 additions & 1 deletion core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,10 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
namespace upper_trs {


GKO_DECLARE_UPPER_TRS_CLEAR_KERNEL()
// Hook-module stand-in for the upper_trs `perform_transpose` kernel (the
// function declared by the CHECK_TRANSPOSABILITY macro). Its body is whatever
// GKO_NOT_COMPILED(GKO_HOOK_MODULE) expands to.
GKO_DECLARE_UPPER_TRS_CHECK_TRANSPOSABILITY_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

// Hook-module stand-in for the upper_trs `init_struct` kernel. Its body is
// whatever GKO_NOT_COMPILED(GKO_HOOK_MODULE) expands to.
GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL()
GKO_NOT_COMPILED(GKO_HOOK_MODULE);

template <typename ValueType, typename IndexType>
Expand Down
28 changes: 21 additions & 7 deletions core/solver/upper_trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,28 @@ namespace solver {
namespace upper_trs {


// Executor operations wrapping the upper_trs kernels. The solver code in this
// file launches them through the matching `upper_trs::make_<name>(...)`
// helpers passed to `Executor::run`.
GKO_REGISTER_OPERATION(clear, upper_trs::clear);
GKO_REGISTER_OPERATION(generate, upper_trs::generate);
GKO_REGISTER_OPERATION(init_struct, upper_trs::init_struct);
GKO_REGISTER_OPERATION(perform_transpose, upper_trs::perform_transpose);
GKO_REGISTER_OPERATION(solve, upper_trs::solve);


} // namespace upper_trs


/**
 * Runs the `init_struct` kernel on this solver's executor.
 *
 * The kernel is handed `solve_struct_` itself (the shared pointer, by
 * reference, as declared by the kernel signature), so an executor-specific
 * implementation may create or replace the struct the solver holds.
 */
template <typename ValueType, typename IndexType>
void UpperTrs<ValueType, IndexType>::init_trs_solve_struct()
{
    this->get_executor()->run(upper_trs::make_init_struct(this->solve_struct_));
}


/**
 * Runs the `generate` kernel on this solver's executor.
 *
 * The kernel receives the system matrix, a non-owning pointer to the solve
 * struct held by this solver, and the number of right-hand sides taken from
 * the factory parameters.
 */
template <typename ValueType, typename IndexType>
void UpperTrs<ValueType, IndexType>::generate()
{
    this->get_executor()->run(upper_trs::make_generate(
        gko::lend(system_matrix_), this->solve_struct_.get(),
        parameters_.num_rhs));
}


Expand All @@ -84,9 +86,21 @@ void UpperTrs<ValueType, IndexType>::apply_impl(const LinOp *b, LinOp *x) const

auto dense_b = as<const Vector>(b);
auto dense_x = as<Vector>(x);

exec->run(
upper_trs::make_solve(gko::lend(system_matrix_), dense_b, dense_x));
bool transposability = false;
std::shared_ptr<Vector> trans_b;
std::shared_ptr<Vector> trans_x;
this->get_executor()->run(
upper_trs::make_perform_transpose(transposability));
if (transposability) {
trans_b = Vector::create(exec, gko::transpose(dense_b->get_size()));
trans_x = Vector::create(exec, gko::transpose(dense_x->get_size()));
} else {
trans_b = Vector::create(exec);
trans_x = Vector::create(exec);
}
exec->run(upper_trs::make_solve(
gko::lend(system_matrix_), this->solve_struct_.get(),
gko::lend(trans_b), gko::lend(trans_x), dense_b, dense_x));
}


Expand Down
23 changes: 17 additions & 6 deletions core/solver/upper_trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,42 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/solver/upper_trs.hpp>


namespace gko {
namespace kernels {
namespace upper_trs {


#define GKO_DECLARE_UPPER_TRS_CLEAR_KERNEL() \
void clear(std::shared_ptr<const DefaultExecutor> exec)
// Declares the `perform_transpose` kernel. It takes the executor and writes
// its answer into the `transposability` out-parameter.
// NOTE(review): the macro says CHECK_TRANSPOSABILITY while the function is
// named perform_transpose; the kernel only sets a flag.
#define GKO_DECLARE_UPPER_TRS_CHECK_TRANSPOSABILITY_KERNEL() \
void perform_transpose(std::shared_ptr<const DefaultExecutor> exec, \
bool &transposability)


// Declares the `init_struct` kernel. The solve struct is passed as a
// reference to the owning shared pointer, so an implementation is able to
// assign a new struct to it.
#define GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL() \
void init_struct(std::shared_ptr<const DefaultExecutor> exec, \
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)


// Declares the `generate` kernel for a CSR matrix with value type `_vtype`
// and index type `_itype`. Besides the matrix it receives a non-owning
// pointer to the solve struct and the number of right-hand sides.
#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype) \
void generate(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
gko::solver::SolveStruct *solve_struct, \
const gko::size_type num_rhs)


// Declares the `solve` kernel for a CSR matrix with value type `_vtype` and
// index type `_itype`. In addition to the matrix, right-hand side `b` and
// solution `x`, the kernel receives a non-owning pointer to the solve struct
// and two mutable dense work vectors, `trans_b` and `trans_x`.
#define GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(_vtype, _itype)                     \
    void solve(std::shared_ptr<const DefaultExecutor> exec,                    \
               const matrix::Csr<_vtype, _itype> *matrix,                      \
               gko::solver::SolveStruct *solve_struct,                         \
               matrix::Dense<_vtype> *trans_b, matrix::Dense<_vtype> *trans_x, \
               const matrix::Dense<_vtype> *b, matrix::Dense<_vtype> *x)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
GKO_DECLARE_UPPER_TRS_CLEAR_KERNEL(); \
GKO_DECLARE_UPPER_TRS_CHECK_TRANSPOSABILITY_KERNEL(); \
GKO_DECLARE_UPPER_TRS_INIT_STRUCT_KERNEL(); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_UPPER_TRS_SOLVE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
Expand Down
14 changes: 13 additions & 1 deletion cuda/solver/upper_trs_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/solver/upper_trs.hpp>


#include "core/matrix/dense_kernels.hpp"
#include "core/solver/lower_trs_kernels.hpp"
#include "cuda/base/cusparse_bindings.hpp"
#include "cuda/base/math.hpp"
#include "cuda/base/types.hpp"
Expand All @@ -53,12 +56,19 @@ namespace cuda {
namespace upper_trs {


void clear(std::shared_ptr<const CudaExecutor> exec) {}
// CUDA `perform_transpose` kernel. There is no hand-written body here: the
// definition is whatever GKO_NOT_IMPLEMENTED expands to (presumably a stub
// that reports the kernel as not implemented — confirm in
// exception_helpers.hpp).
void perform_transpose(std::shared_ptr<const CudaExecutor> exec,
bool &transposability) GKO_NOT_IMPLEMENTED;


// CUDA `init_struct` kernel. There is no hand-written body here: the
// definition is whatever GKO_NOT_IMPLEMENTED expands to.
void init_struct(std::shared_ptr<const CudaExecutor> exec,
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)
GKO_NOT_IMPLEMENTED;


// CUDA `generate` kernel. The signature matches the shared kernel
// declaration (matrix, solve struct, number of right-hand sides); the
// definition is whatever GKO_NOT_IMPLEMENTED expands to.
template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const CudaExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
const gko::size_type num_rhs) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
Expand All @@ -68,6 +78,8 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
// CUDA `solve` kernel. The signature carries the solve struct and the
// `trans_b` / `trans_x` work vectors in addition to `b` and `x`; the
// definition is whatever GKO_NOT_IMPLEMENTED expands to.
template <typename ValueType, typename IndexType>
void solve(std::shared_ptr<const CudaExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
matrix::Dense<ValueType> *trans_b, matrix::Dense<ValueType> *trans_x,
const matrix::Dense<ValueType> *b,
matrix::Dense<ValueType> *x) GKO_NOT_IMPLEMENTED;

Expand Down
24 changes: 17 additions & 7 deletions include/ginkgo/core/solver/upper_trs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ namespace gko {
namespace solver {


struct SolveStruct;


/**
* UpperTrs is the triangular solver which solves the system U x = b, when U is
* an upper triangular matrix. It works best when passing in a matrix in CSR
Expand Down Expand Up @@ -97,6 +100,16 @@ class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
return preconditioner_;
}

/**
 * Returns the triangular solve struct held by this solver.
 *
 * The pointer is non-owning: ownership stays with the solver's
 * `solve_struct_` shared pointer, so the result must not outlive the solver.
 *
 * @return a raw pointer to the solve struct, obtained from
 *         `solve_struct_.get()`; null if no struct has been assigned
 */
gko::solver::SolveStruct *get_solve_struct() const
{
return solve_struct_.get();
}

GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
Expand All @@ -116,19 +129,14 @@ class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
GKO_ENABLE_LIN_OP_FACTORY(UpperTrs, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);

~UpperTrs() { this->clear_data(); }

protected:
void init_trs_solve_struct();

void apply_impl(const LinOp *b, LinOp *x) const override;

void apply_impl(const LinOp *alpha, const LinOp *b, const LinOp *beta,
LinOp *x) const override;

/**
* Clears the held data.
*/
void clear_data() const;

/**
* Generates the analysis structure from the system matrix and the right
* hand side(only dimensional info needed) needed for the level solver.
Expand Down Expand Up @@ -165,12 +173,14 @@ class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
preconditioner_ = matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()[0]);
}
this->init_trs_solve_struct();
this->generate();
}

private:
std::shared_ptr<const matrix::Csr<ValueType, IndexType>> system_matrix_{};
std::shared_ptr<const LinOp> preconditioner_{};
std::shared_ptr<gko::solver::SolveStruct> solve_struct_;
};


Expand Down
18 changes: 15 additions & 3 deletions omp/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/solver/upper_trs.hpp>


namespace gko {
Expand All @@ -58,16 +59,25 @@ namespace omp {
namespace upper_trs {


void clear(std::shared_ptr<const OmpExecutor> exec)
void perform_transpose(std::shared_ptr<const OmpExecutor> exec,
bool &transposability)
{
// The OpenMP kernel always reports `false`; `exec` is unused.
transposability = false;
}


void init_struct(std::shared_ptr<const OmpExecutor> exec,
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)
{
// Intentionally empty: `solve_struct` is left untouched. The kernel exists
// so that other executors can initialize the solve struct with a more
// sophisticated implementation.
}


template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const OmpExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
const gko::size_type num_rhs)
{
// This generate kernel is here to allow for a more sophisticated
Expand All @@ -82,6 +92,8 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
template <typename ValueType, typename IndexType>
void solve(std::shared_ptr<const OmpExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
matrix::Dense<ValueType> *trans_b, matrix::Dense<ValueType> *trans_x,
const matrix::Dense<ValueType> *b, matrix::Dense<ValueType> *x)
{
auto row_ptrs = matrix->get_const_row_ptrs();
Expand Down
33 changes: 29 additions & 4 deletions omp/test/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,18 @@ class UpperTrs : public ::testing::Test {
{
b = gen_mtx(m, n);
x = gen_mtx(m, n);
t_b = Mtx::create(ref);
t_x = Mtx::create(ref);
t_b->copy_from(b.get());
t_x->copy_from(x.get());
d_b = Mtx::create(omp);
d_b->copy_from(b.get());
d_x = Mtx::create(omp);
d_x->copy_from(x.get());
dt_b = Mtx::create(omp);
dt_b->copy_from(b.get());
dt_x = Mtx::create(omp);
dt_x->copy_from(x.get());
mat = gen_u_mtx(m, m);
csr_mat = CsrMtx::create(ref);
mat->convert_to(csr_mat.get());
Expand All @@ -113,23 +121,40 @@ class UpperTrs : public ::testing::Test {

std::shared_ptr<Mtx> b;
std::shared_ptr<Mtx> x;
std::shared_ptr<Mtx> t_b;
std::shared_ptr<Mtx> t_x;
std::shared_ptr<Mtx> mat;
std::shared_ptr<CsrMtx> csr_mat;
std::shared_ptr<Mtx> d_b;
std::shared_ptr<Mtx> d_x;
std::shared_ptr<Mtx> dt_b;
std::shared_ptr<Mtx> dt_x;
std::shared_ptr<Mtx> d_mat;
std::shared_ptr<CsrMtx> d_csr_mat;
std::shared_ptr<gko::solver::SolveStruct> solve_struct;
};


TEST_F(UpperTrs, OmpUpperTrsFlagCheckIsCorrect)
{
    // Seed the flag with the opposite of the expected value so the test
    // fails if the kernel leaves it untouched.
    const bool expected = false;
    bool reported = true;

    gko::kernels::omp::upper_trs::perform_transpose(omp, reported);

    ASSERT_EQ(expected, reported);
}


TEST_F(UpperTrs, OmpUpperTrsSolveIsEquivalentToRef)
{
    initialize_data(59, 43);

    // Both kernels take the solve struct and the trans_b / trans_x work
    // vectors in addition to b and x.
    // NOTE(review): `solve_struct` is not assigned anywhere in the visible
    // fixture code, so `.get()` presumably passes a null pointer — confirm
    // that neither kernel dereferences it.
    gko::kernels::reference::upper_trs::solve(ref, csr_mat.get(),
                                              solve_struct.get(), t_b.get(),
                                              t_x.get(), b.get(), x.get());
    gko::kernels::omp::upper_trs::solve(omp, d_csr_mat.get(),
                                        solve_struct.get(), dt_b.get(),
                                        dt_x.get(), d_b.get(), d_x.get());

    GKO_ASSERT_MTX_NEAR(d_x, x, 1e-14);
}
Expand Down
20 changes: 16 additions & 4 deletions reference/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <iostream>
#include <ginkgo/core/solver/upper_trs.hpp>


namespace gko {
namespace kernels {
Expand All @@ -54,16 +55,25 @@ namespace reference {
namespace upper_trs {


void clear(std::shared_ptr<const ReferenceExecutor> exec)
// Reference `perform_transpose` kernel: always reports `false`; `exec` is
// unused.
void perform_transpose(std::shared_ptr<const ReferenceExecutor> exec,
bool &transposability)
{
transposability = false;
}


void init_struct(std::shared_ptr<const ReferenceExecutor> exec,
std::shared_ptr<gko::solver::SolveStruct> &solve_struct)
{
// Intentionally empty: `solve_struct` is left untouched. The kernel exists
// so that other executors can initialize the solve struct with a more
// sophisticated implementation.
}


template <typename ValueType, typename IndexType>
void generate(std::shared_ptr<const ReferenceExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
const gko::size_type num_rhs)
{
// This generate kernel is here to allow for a more sophisticated
Expand All @@ -78,6 +88,8 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
template <typename ValueType, typename IndexType>
void solve(std::shared_ptr<const ReferenceExecutor> exec,
const matrix::Csr<ValueType, IndexType> *matrix,
gko::solver::SolveStruct *solve_struct,
matrix::Dense<ValueType> *trans_b, matrix::Dense<ValueType> *trans_x,
const matrix::Dense<ValueType> *b, matrix::Dense<ValueType> *x)
{
auto row_ptrs = matrix->get_const_row_ptrs();
Expand Down
Loading

0 comments on commit 994fd4b

Please sign in to comment.