Skip to content

Commit

Permalink
Merge reusing tmp storage for reductions in solvers
Browse files Browse the repository at this point in the history
This uses the new reduction interface with temporary storage introduced in #990 in our iterative solvers.

Related PR: #1013
  • Loading branch information
upsj committed Apr 12, 2022
2 parents 029e0f8 + 849dfbc commit fda151c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 17 deletions.
6 changes: 4 additions & 2 deletions core/solver/bicg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ void Bicg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,

auto exec = this->get_executor();

Array<char> reduction_tmp{exec};

auto one_op = initialize<Vector>({one<ValueType>()}, exec);
auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);

Expand Down Expand Up @@ -208,7 +210,7 @@ void Bicg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
while (true) {
get_preconditioner()->apply(r.get(), z.get());
conj_trans_preconditioner->apply(r2.get(), z2.get());
z->compute_conj_dot(r2.get(), rho.get());
z->compute_conj_dot(r2.get(), rho.get(), reduction_tmp);

++iter;
this->template log<log::Logger::iteration_complete>(
Expand All @@ -229,7 +231,7 @@ void Bicg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
rho.get(), prev_rho.get(), &stop_status));
system_matrix_->apply(p.get(), q.get());
conj_trans_A->apply(p2.get(), q2.get());
p2->compute_conj_dot(q.get(), beta.get());
p2->compute_conj_dot(q.get(), beta.get(), reduction_tmp);
// tmp = rho / beta
// x = x + tmp * p
// r = r - tmp * q
Expand Down
10 changes: 6 additions & 4 deletions core/solver/bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ void Bicgstab<ValueType>::apply_dense_impl(

auto exec = this->get_executor();

Array<char> reduction_tmp{exec};

auto one_op = initialize<Vector>({one<ValueType>()}, exec);
auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);

Expand Down Expand Up @@ -168,7 +170,7 @@ void Bicgstab<ValueType>::apply_dense_impl(
++iter;
this->template log<log::Logger::iteration_complete>(
this, iter, r.get(), dense_x, nullptr, rho.get());
rr->compute_conj_dot(r.get(), rho.get());
rr->compute_conj_dot(r.get(), rho.get(), reduction_tmp);

if (stop_criterion->update()
.num_iterations(iter)
Expand All @@ -187,7 +189,7 @@ void Bicgstab<ValueType>::apply_dense_impl(

get_preconditioner()->apply(p.get(), y.get());
system_matrix_->apply(y.get(), v.get());
rr->compute_conj_dot(v.get(), beta.get());
rr->compute_conj_dot(v.get(), beta.get(), reduction_tmp);
// alpha = rho / beta
// s = r - alpha * v
exec->run(bicgstab::make_step_2(r.get(), s.get(), v.get(), rho.get(),
Expand All @@ -210,8 +212,8 @@ void Bicgstab<ValueType>::apply_dense_impl(

get_preconditioner()->apply(s.get(), z.get());
system_matrix_->apply(z.get(), t.get());
s->compute_conj_dot(t.get(), gamma.get());
t->compute_conj_dot(t.get(), beta.get());
s->compute_conj_dot(t.get(), gamma.get(), reduction_tmp);
t->compute_conj_dot(t.get(), beta.get(), reduction_tmp);
// omega = gamma / beta
// x = x + alpha * y + omega * z
// r = s - omega * t
Expand Down
6 changes: 4 additions & 2 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ void Cg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,

auto exec = this->get_executor();

Array<char> reduction_tmp{exec};

auto one_op = initialize<Vector>({one<ValueType>()}, exec);
auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);

Expand Down Expand Up @@ -151,7 +153,7 @@ void Cg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
*/
while (true) {
get_preconditioner()->apply(r.get(), z.get());
r->compute_conj_dot(z.get(), rho.get());
r->compute_conj_dot(z.get(), rho.get(), reduction_tmp);

++iter;
this->template log<log::Logger::iteration_complete>(
Expand All @@ -170,7 +172,7 @@ void Cg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
exec->run(cg::make_step_1(p.get(), z.get(), rho.get(), prev_rho.get(),
&stop_status));
system_matrix_->apply(p.get(), q.get());
p->compute_conj_dot(q.get(), beta.get());
p->compute_conj_dot(q.get(), beta.get(), reduction_tmp);
// tmp = rho / beta
// x = x + tmp * p
// r = r - tmp * q
Expand Down
6 changes: 4 additions & 2 deletions core/solver/cgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ void Cgs<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
auto exec = this->get_executor();
size_type num_vectors = dense_b->get_size()[1];

Array<char> reduction_tmp{exec};

auto one_op = initialize<Vector>({one<ValueType>()}, exec);
auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);

Expand Down Expand Up @@ -161,7 +163,7 @@ void Cgs<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
* 1x norm2 residual n
*/
while (true) {
r->compute_conj_dot(r_tld.get(), rho.get());
r->compute_conj_dot(r_tld.get(), rho.get(), reduction_tmp);

++iter;
this->template log<log::Logger::iteration_complete>(
Expand All @@ -183,7 +185,7 @@ void Cgs<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
&stop_status));
get_preconditioner()->apply(p.get(), t.get());
system_matrix_->apply(t.get(), v_hat.get());
r_tld->compute_conj_dot(v_hat.get(), gamma.get());
r_tld->compute_conj_dot(v_hat.get(), gamma.get(), reduction_tmp);
// alpha = rho / gamma
// q = u - alpha * v_hat
// t = u + q
Expand Down
8 changes: 5 additions & 3 deletions core/solver/fcg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ void Fcg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,

auto exec = this->get_executor();

Array<char> reduction_tmp{exec};

auto one_op = initialize<Vector>({one<ValueType>()}, exec);
auto neg_one_op = initialize<Vector>({-one<ValueType>()}, exec);

Expand Down Expand Up @@ -155,8 +157,8 @@ void Fcg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
*/
while (true) {
get_preconditioner()->apply(r.get(), z.get());
r->compute_conj_dot(z.get(), rho.get());
t->compute_conj_dot(z.get(), rho_t.get());
r->compute_conj_dot(z.get(), rho.get(), reduction_tmp);
t->compute_conj_dot(z.get(), rho_t.get(), reduction_tmp);

++iter;
this->template log<log::Logger::iteration_complete>(
Expand All @@ -175,7 +177,7 @@ void Fcg<ValueType>::apply_dense_impl(const matrix::Dense<ValueType>* dense_b,
exec->run(fcg::make_step_1(p.get(), z.get(), rho_t.get(),
prev_rho.get(), &stop_status));
system_matrix_->apply(p.get(), q.get());
p->compute_conj_dot(q.get(), beta.get());
p->compute_conj_dot(q.get(), beta.get(), reduction_tmp);
// tmp = rho / beta
// [prev_r = r] in registers
// x = x + tmp * p
Expand Down
10 changes: 6 additions & 4 deletions core/solver/idr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ void Idr<ValueType>::iterate(const matrix::Dense<SubspaceType>* dense_b,

auto exec = this->get_executor();

Array<char> reduction_tmp{exec};

auto one_op =
initialize<matrix::Dense<ValueType>>({one<ValueType>()}, exec);
auto neg_one_op =
Expand Down Expand Up @@ -156,7 +158,7 @@ void Idr<ValueType>::iterate(const matrix::Dense<SubspaceType>* dense_b,
residual->copy_from(dense_b);
system_matrix_->apply(neg_one_op.get(), dense_x, one_op.get(),
residual.get());
residual->compute_norm2(residual_norm.get());
residual->compute_norm2(residual_norm.get(), reduction_tmp);

// g = u = 0
exec->run(idr::make_fill_array(
Expand Down Expand Up @@ -248,9 +250,9 @@ void Idr<ValueType>::iterate(const matrix::Dense<SubspaceType>* dense_b,
get_preconditioner()->apply(residual.get(), helper.get());
system_matrix_->apply(helper.get(), t.get());

t->compute_conj_dot(residual.get(), omega.get());
t->compute_conj_dot(t.get(), tht.get());
residual->compute_norm2(residual_norm.get());
t->compute_conj_dot(residual.get(), omega.get(), reduction_tmp);
t->compute_conj_dot(t.get(), tht.get(), reduction_tmp);
residual->compute_norm2(residual_norm.get(), reduction_tmp);

// omega = (t^H * residual) / (t^H * t)
// rho = (t^H * residual) / (norm(t) * norm(residual))
Expand Down

0 comments on commit fda151c

Please sign in to comment.