Skip to content

Commit

Permalink
Add the basic TRS algorithm and start with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 10, 2019
1 parent f8c0695 commit 44c0c0c
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 716 deletions.
17 changes: 3 additions & 14 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,13 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CG_STEP_2_KERNEL);
} // namespace cg


// TODO (script): adapt this block as needed
namespace trs {


template <typename ValueType>
GKO_DECLARE_TRS_INITIALIZE_KERNEL(ValueType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_TRS_INITIALIZE_KERNEL);

template <typename ValueType>
GKO_DECLARE_TRS_STEP_1_KERNEL(ValueType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_TRS_STEP_1_KERNEL);

template <typename ValueType>
GKO_DECLARE_TRS_STEP_2_KERNEL(ValueType)
template <typename ValueType, typename IndexType>
GKO_DECLARE_TRS_SOLVE_KERNEL(ValueType, IndexType)
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_TRS_STEP_2_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_TRS_SOLVE_KERNEL);


} // namespace trs
Expand Down
162 changes: 61 additions & 101 deletions core/solver/trs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,123 +39,83 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/name_demangling.hpp>
#include <ginkgo/core/base/utils.hpp>
#include <ginkgo/core/matrix/csr.hpp>


#include <iostream>
#include "core/matrix/csr_kernels.hpp"
#include "core/solver/trs_kernels.hpp"


namespace gko {
namespace solver {


namespace trs {


GKO_REGISTER_OPERATION(initialize, trs::initialize);
GKO_REGISTER_OPERATION(step_1, trs::step_1);
GKO_REGISTER_OPERATION(step_2, trs::step_2);
GKO_REGISTER_OPERATION(solve, trs::solve);


} // namespace trs


template <typename ValueType>
void Trs<ValueType>::apply_impl(const LinOp *b,
LinOp *x) const GKO_NOT_IMPLEMENTED;
//{
// TODO (script): change the code imported from solver/cg if needed
// using std::swap;
// using Vector = matrix::Dense<ValueType>;
//
// constexpr uint8 RelativeStoppingId{1};
//
// auto exec = this->get_executor();
//
// auto one_op = initialize<Vector>({one<ValueType>()}, exec);
// auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);
//
// auto dense_b = as<const Vector>(b);
// auto dense_x = as<Vector>(x);
// auto r = Vector::create_with_config_of(dense_b);
// auto z = Vector::create_with_config_of(dense_b);
// auto p = Vector::create_with_config_of(dense_b);
// auto q = Vector::create_with_config_of(dense_b);
//
// auto alpha = Vector::create(exec, dim<2>{1, dense_b->get_size()[1]});
// auto beta = Vector::create_with_config_of(alpha.get());
// auto prev_rho = Vector::create_with_config_of(alpha.get());
// auto rho = Vector::create_with_config_of(alpha.get());
//
// bool one_changed{};
// Array<stopping_status> stop_status(alpha->get_executor(),
// dense_b->get_size()[1]);
//
// // TODO: replace this with automatic merged kernel generator
// exec->run(trs::make_initialize(dense_b, r.get(), z.get(), p.get(),
// q.get(),
// prev_rho.get(), rho.get(), &stop_status));
// // r = dense_b
// // rho = 0.0
// // prev_rho = 1.0
// // z = p = q = 0
//
// system_matrix_->apply(neg_one_op.get(), dense_x, one_op.get(), r.get());
// auto stop_criterion = stop_criterion_factory_->generate(
// system_matrix_, std::shared_ptr<const LinOp>(b, [](const LinOp *) {}),
// x, r.get());
//
// int iter = -1;
// while (true) {
// preconditioner_->apply(r.get(), z.get());
// r->compute_dot(z.get(), rho.get());
//
// ++iter;
// this->template log<log::Logger::iteration_complete>(this, iter,
// r.get(),
// dense_x);
// if (stop_criterion->update()
// .num_iterations(iter)
// .residual(r.get())
// .solution(dense_x)
// .check(RelativeStoppingId, true, &stop_status, &one_changed))
// {
// break;
// }
//
// exec->run(trs::make_step_1(p.get(), z.get(), rho.get(),
// prev_rho.get(),
// &stop_status));
// // tmp = rho / prev_rho
// // p = z + tmp * p
// system_matrix_->apply(p.get(), q.get());
// p->compute_dot(q.get(), beta.get());
// exec->run(trs::make_step_2(dense_x, r.get(), p.get(), q.get(),
// beta.get(), rho.get(), &stop_status));
// // tmp = rho / beta
// // x = x + tmp * p
// // r = r - tmp * q
// swap(prev_rho, rho);
// }
//}


template <typename ValueType>
void Trs<ValueType>::apply_impl(const LinOp *alpha, const LinOp *b,
const LinOp *beta,
LinOp *x) const GKO_NOT_IMPLEMENTED;
//{
// TODO (script): change the code imported from solver/cg if needed
// auto dense_x = as<matrix::Dense<ValueType>>(x);
//
// auto x_clone = dense_x->clone();
// this->apply(b, x_clone.get());
// dense_x->scale(beta);
// dense_x->add_scaled(alpha, x_clone.get());
//}


#define GKO_DECLARE_TRS(_type) class Trs<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_TRS);
template <typename ValueType, typename IndexType>
void Trs<ValueType, IndexType>::apply_impl(const LinOp *b, LinOp *x) const
{
using Vector = matrix::Dense<ValueType>;
using CsrMatrix = matrix::Csr<ValueType, IndexType>;

GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix_);
const auto exec = this->get_executor();
const auto host_exec = exec->get_master();

auto dense_x = as<Vector>(x);
auto dense_b = as<const Vector>(b);

// If required, it is also possible to make this a Factory parameter
auto csr_strategy = std::make_shared<typename CsrMatrix::cusparse>();

// Only copies the matrix if it is not on the same executor or was not in
// the right format. Throws an exception if it is not convertable.
std::unique_ptr<CsrMatrix> csr_system_matrix_unique_ptr{};
auto csr_system_matrix =
dynamic_cast<const CsrMatrix *>(system_matrix_.get());
if (csr_system_matrix == nullptr ||
csr_system_matrix->get_executor() != exec) {
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
as<ConvertibleTo<CsrMatrix>>(system_matrix_.get())
->convert_to(csr_system_matrix_unique_ptr.get());
csr_system_matrix = csr_system_matrix_unique_ptr.get();
}
If it needs to be sorted,
copy it if necessary and sort it if (csr_system_matrix_unique_ptr ==
nullptr)
{
csr_system_matrix_unique_ptr = CsrMatrix::create(exec);
csr_system_matrix_unique_ptr->copy_from(csr_system_matrix);
}
csr_system_matrix_unique_ptr->sort_by_column_index();
csr_system_matrix = csr_system_matrix_unique_ptr.get();

exec->run(trs::make_solve(gko::lend(csr_system_matrix), dense_b, dense_x));
}


template <typename ValueType, typename IndexType>
void Trs<ValueType, IndexType>::apply_impl(const LinOp *alpha, const LinOp *b,
const LinOp *beta, LinOp *x) const
{
auto dense_x = as<matrix::Dense<ValueType>>(x);

auto x_clone = dense_x->clone();
this->apply(b, x_clone.get());
dense_x->scale(beta);
dense_x->add_scaled(alpha, x_clone.get());
}


#define GKO_DECLARE_TRS(_vtype, _itype) class Trs<_vtype, _itype>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_TRS);


} // namespace solver
Expand Down
39 changes: 8 additions & 31 deletions core/solver/trs_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/stop/stopping_status.hpp>

Expand All @@ -45,39 +46,15 @@ namespace kernels {
namespace trs {


#define GKO_DECLARE_TRS_INITIALIZE_KERNEL(_type) \
void initialize(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Dense<_type> *b, matrix::Dense<_type> *r, \
matrix::Dense<_type> *z, matrix::Dense<_type> *p, \
matrix::Dense<_type> *q, matrix::Dense<_type> *prev_rho, \
matrix::Dense<_type> *rho, \
Array<stopping_status> *stop_status)
#define GKO_DECLARE_TRS_SOLVE_KERNEL(_vtype, _itype) \
void solve(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<_vtype, _itype> *matrix, \
const matrix::Dense<_vtype> *b, matrix::Dense<_vtype> *x)


#define GKO_DECLARE_TRS_STEP_1_KERNEL(_type) \
void step_1(std::shared_ptr<const DefaultExecutor> exec, \
matrix::Dense<_type> *p, const matrix::Dense<_type> *z, \
const matrix::Dense<_type> *rho, \
const matrix::Dense<_type> *prev_rho, \
const Array<stopping_status> *stop_status)


#define GKO_DECLARE_TRS_STEP_2_KERNEL(_type) \
void step_2(std::shared_ptr<const DefaultExecutor> exec, \
matrix::Dense<_type> *x, matrix::Dense<_type> *r, \
const matrix::Dense<_type> *p, const matrix::Dense<_type> *q, \
const matrix::Dense<_type> *beta, \
const matrix::Dense<_type> *rho, \
const Array<stopping_status> *stop_status)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType> \
GKO_DECLARE_TRS_INITIALIZE_KERNEL(ValueType); \
template <typename ValueType> \
GKO_DECLARE_TRS_STEP_1_KERNEL(ValueType); \
template <typename ValueType> \
GKO_DECLARE_TRS_STEP_2_KERNEL(ValueType)
#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_TRS_SOLVE_KERNEL(ValueType, IndexType)


} // namespace trs
Expand Down
Loading

0 comments on commit 44c0c0c

Please sign in to comment.