Skip to content

Commit

Permalink
modification according to the suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Aug 22, 2019
1 parent cf06242 commit f677ddf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 28 deletions.
42 changes: 29 additions & 13 deletions core/base/perturbation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@ void Perturbation<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
// x = 1 * x + scalar * basis * temp : basis->apply(scalar, temp, 1, x)
using vec = gko::matrix::Dense<ValueType>;
auto exec = this->get_executor();
auto temp = vec::create(
exec, gko::dim<2>(this->projector_->get_size()[0], b->get_size()[1]));
this->projector_->apply(b, lend(temp));
if (cache_.one == nullptr) {
cache_.one = initialize<vec>({gko::one<ValueType>()}, exec);
}
auto intermediate_size =
gko::dim<2>(projector_->get_size()[0], b->get_size()[1]);
if (cache_.intermediate == nullptr ||
cache_.intermediate->get_size() != intermediate_size) {
cache_.intermediate = vec::create(exec, intermediate_size);
}
projector_->apply(b, lend(cache_.intermediate));
x->copy_from(b);
auto one = gko::initialize<vec>({1.0}, exec);
this->basis_->apply(lend(this->scalar_), lend(temp), lend(one), x);
basis_->apply(lend(scalar_), lend(cache_.intermediate), lend(cache_.one),
x);
}


Expand All @@ -70,17 +77,26 @@ void Perturbation<ValueType>::apply_impl(const LinOp *alpha, const LinOp *b,
// : basis->apply(alpha * scalar, temp, 1, x)
using vec = gko::matrix::Dense<ValueType>;
auto exec = this->get_executor();
auto temp = vec::create(
exec, gko::dim<2>(this->projector_->get_size()[0], b->get_size()[1]));
this->projector_->apply(b, lend(temp));
if (cache_.one == nullptr) {
cache_.one = initialize<vec>({gko::one<ValueType>()}, exec);
}
auto intermediate_size =
gko::dim<2>(projector_->get_size()[0], b->get_size()[1]);
if (cache_.intermediate == nullptr ||
cache_.intermediate->get_size() != intermediate_size) {
cache_.intermediate = vec::create(exec, intermediate_size);
}
projector_->apply(b, lend(cache_.intermediate));
auto vec_x = as<vec>(x);
vec_x->scale(beta);
vec_x->add_scaled(alpha, b);
auto one = gko::initialize<vec>({1.0}, exec);
auto alpha_scalar = vec::create(exec, gko::dim<2>(1));
alpha_scalar->copy_from(alpha);
alpha_scalar->scale(lend(this->scalar_));
this->basis_->apply(lend(alpha_scalar), lend(temp), lend(one), vec_x);
if (cache_.alpha_scalar == nullptr) {
cache_.alpha_scalar = vec::create(exec, gko::dim<2>(1));
}
cache_.alpha_scalar->copy_from(alpha);
cache_.alpha_scalar->scale(lend(scalar_));
basis_->apply(lend(cache_.alpha_scalar), lend(cache_.intermediate),
lend(cache_.one), vec_x);
}


Expand Down
27 changes: 21 additions & 6 deletions include/ginkgo/core/base/perturbation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/matrix/dense.hpp>


namespace gko {
Expand All @@ -49,11 +50,14 @@ namespace gko {
* a direction constructed by `basis` and `projector` on the LinOp. `projector`
* gives the coefficient of `basis` to decide the direction.
* For example, Householder matrix can be represented in Perturbation.
* Householder matrix = (I - 2 u u*), u is the housholder factor
* scalar = -2, basis = u, and projector = u*
* u is the housholder factor and then we can generate the Householder matrix =
* (I - 2 u u*). In this case, the parameters of Perturbation class are scalar =
* -2, basis = u, and projector = u*.
*
* @tparam ValueType precision of input and result vectors
*
* @note the apply operations pf Perturbation class are not thread safe
*
* @ingroup LinOp
*/
template <typename ValueType = default_precision>
Expand All @@ -68,7 +72,7 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
/**
* Returns the basis of the perturbation.
*
* @return the basis
* @return the basis of the perturbation
*/
const std::shared_ptr<const LinOp> get_basis() const noexcept
{
Expand All @@ -78,7 +82,7 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
/**
* Returns the projector of the perturbation.
*
* @return the projector
* @return the projector of the perturbation
*/
const std::shared_ptr<const LinOp> get_projector() const noexcept
{
Expand All @@ -88,7 +92,7 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
/**
* Returns the scalar of the perturbation.
*
* @return the scalar
* @return the scalar of the perturbation
*/
const std::shared_ptr<const LinOp> get_scalar() const noexcept
{
Expand All @@ -108,7 +112,7 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
/**
* Creates a perturbation with scalar and basis by setting projector to the
* conjugate transpose of basis. Basis must be transposable. Perturbation
* will throw GKO_NOT_SUPPORT if basis is not transposable.
* will throw gko::NotSupported if basis is not transposable.
*
* @param scalar scaling of the movement
* @param basis the direction basis
Expand Down Expand Up @@ -164,6 +168,17 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
std::shared_ptr<const LinOp> basis_;
std::shared_ptr<const LinOp> projector_;
std::shared_ptr<const LinOp> scalar_;

// TODO: solve race conditions when multithreading
mutable struct cache_struct {
cache_struct() = default;
cache_struct(const cache_struct &other) {}
cache_struct &operator=(const cache_struct &other) { return *this; }

std::unique_ptr<LinOp> intermediate;
std::unique_ptr<LinOp> one;
std::unique_ptr<matrix::Dense<ValueType>> alpha_scalar;
} cache_;
};


Expand Down
16 changes: 8 additions & 8 deletions reference/test/base/perturbation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ TEST_F(Perturbation, AppliesToVector)
*/
auto cmp = gko::Perturbation<>::create(scalar, basis, projector);
auto x = gko::initialize<mtx>({1.0, 2.0}, exec);
auto res = clone(x);
auto res = mtx::create_with_config_of(gko::lend(x));

cmp->apply(lend(x), lend(res));
cmp->apply(gko::lend(x), gko::lend(res));

GKO_ASSERT_MTX_NEAR(res, l({29.0, 16.0}), 1e-15);
}
Expand All @@ -92,9 +92,9 @@ TEST_F(Perturbation, AppliesLinearCombinationToVector)
auto alpha = gko::initialize<mtx>({3.0}, exec);
auto beta = gko::initialize<mtx>({-1.0}, exec);
auto x = gko::initialize<mtx>({1.0, 2.0}, exec);
auto res = clone(x);
auto res = gko::clone(x);

cmp->apply(lend(alpha), lend(x), lend(beta), lend(res));
cmp->apply(gko::lend(alpha), gko::lend(x), gko::lend(beta), gko::lend(res));

GKO_ASSERT_MTX_NEAR(res, l({86.0, 46.0}), 1e-15);
}
Expand All @@ -108,9 +108,9 @@ TEST_F(Perturbation, ConstructionByBasisAppliesToVector)
*/
auto cmp = gko::Perturbation<>::create(scalar, basis);
auto x = gko::initialize<mtx>({1.0, 2.0}, exec);
auto res = clone(x);
auto res = mtx::create_with_config_of(gko::lend(x));

cmp->apply(lend(x), lend(res));
cmp->apply(gko::lend(x), gko::lend(res));

GKO_ASSERT_MTX_NEAR(res, l({17.0, 10.0}), 1e-15);
}
Expand All @@ -126,9 +126,9 @@ TEST_F(Perturbation, ConstructionByBasisAppliesLinearCombinationToVector)
auto alpha = gko::initialize<mtx>({3.0}, exec);
auto beta = gko::initialize<mtx>({-1.0}, exec);
auto x = gko::initialize<mtx>({1.0, 2.0}, exec);
auto res = clone(x);
auto res = gko::clone(x);

cmp->apply(lend(alpha), lend(x), lend(beta), lend(res));
cmp->apply(gko::lend(alpha), gko::lend(x), gko::lend(beta), gko::lend(res));

GKO_ASSERT_MTX_NEAR(res, l({50.0, 28.0}), 1e-15);
}
Expand Down
2 changes: 1 addition & 1 deletion test_install/test_install.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ int main(int, char **)
using type1 = int;
static_assert(
std::is_same<gko::Perturbation<type1>::value_type, type1>::value,
"Perturbation.hpp not included properly");
"perturbation.hpp not included properly");
}

// core/base/std_extensions.hpp
Expand Down

0 comments on commit f677ddf

Please sign in to comment.