Skip to content

Commit

Permalink
add allocate in cache and improve the doc
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Aug 22, 2019
1 parent f9f0549 commit c923f48
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 32 deletions.
22 changes: 3 additions & 19 deletions core/base/perturbation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,9 @@ void Perturbation<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
// x = 1 * x + scalar * basis * temp : basis->apply(scalar, temp, 1, x)
using vec = gko::matrix::Dense<ValueType>;
auto exec = this->get_executor();
if (cache_.one == nullptr) {
cache_.one = initialize<vec>({gko::one<ValueType>()}, exec);
}
auto intermediate_size =
gko::dim<2>(projector_->get_size()[0], b->get_size()[1]);
if (cache_.intermediate == nullptr ||
cache_.intermediate->get_size() != intermediate_size) {
cache_.intermediate = vec::create(exec, intermediate_size);
}
cache_.allocate(exec, intermediate_size);
projector_->apply(b, lend(cache_.intermediate));
x->copy_from(b);
basis_->apply(lend(scalar_), lend(cache_.intermediate), lend(cache_.one),
Expand All @@ -77,24 +71,14 @@ void Perturbation<ValueType>::apply_impl(const LinOp *alpha, const LinOp *b,
// : basis->apply(alpha * scalar, temp, 1, x)
using vec = gko::matrix::Dense<ValueType>;
auto exec = this->get_executor();
if (cache_.one == nullptr) {
cache_.one = initialize<vec>({gko::one<ValueType>()}, exec);
}
auto intermediate_size =
gko::dim<2>(projector_->get_size()[0], b->get_size()[1]);
if (cache_.intermediate == nullptr ||
cache_.intermediate->get_size() != intermediate_size) {
cache_.intermediate = vec::create(exec, intermediate_size);
}
cache_.allocate(exec, intermediate_size);
projector_->apply(b, lend(cache_.intermediate));
auto vec_x = as<vec>(x);
vec_x->scale(beta);
vec_x->add_scaled(alpha, b);
if (cache_.alpha_scalar == nullptr) {
cache_.alpha_scalar = vec::create(exec, gko::dim<2>(1));
}
cache_.alpha_scalar->copy_from(alpha);
cache_.alpha_scalar->scale(lend(scalar_));
alpha->apply(lend(scalar_), lend(cache_.alpha_scalar));
basis_->apply(lend(cache_.alpha_scalar), lend(cache_.intermediate),
lend(cache_.one), vec_x);
}
Expand Down
45 changes: 32 additions & 13 deletions include/ginkgo/core/base/perturbation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,20 @@ namespace gko {

/**
* The Perturbation class can be used to construct a LinOp to represent the
* `(identity + scalar * basis * projector)` This operator adds a movement along
* a direction constructed by `basis` and `projector` on the LinOp. `projector`
* gives the coefficient of `basis` to decide the direction.
* For example, Householder matrix can be represented in Perturbation.
* u is the householder factor and then we can generate the Householder matrix =
* (I - 2 u u*). In this case, the parameters of Perturbation class are
* operation `(identity + scalar * basis * projector)`. This operator adds a
* movement along a direction constructed by `basis` and `projector` on the
* LinOp. `projector` gives the coefficient of `basis` to decide the direction.
*
* For example, the Householder matrix can be represented with the Perturbation
* operator as follows.
* If u is the Householder factor then we can generate the [Householder
* transformation](https://en.wikipedia.org/wiki/Householder_transformation),
* H = (I - 2 u u*). In this case, the parameters of Perturbation class are
* scalar = -2, basis = u, and projector = u*.
*
* @tparam ValueType precision of input and result vectors
*
* @note the apply operations pf Perturbation class are not thread safe
* @note the apply operations of Perturbation class are not thread safe
*
* @ingroup LinOp
*/
Expand Down Expand Up @@ -101,7 +104,7 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,

protected:
/**
* Creates an empty operator perturbation (0x0 operator).
* Creates an empty perturbation operator (0x0 operator).
*
* @param exec Executor associated to the perturbation
*/
Expand Down Expand Up @@ -152,10 +155,9 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
LinOp *x) const override;

/**
* validate_perturbation check the dimension of scalar, basis, projector.
* scalar must be 1 by 1.
* The dimension of basis should be same as the dimension of conjugate
* transpose of projector.
* Validates the dimensions of the `scalar`, `basis` and `projector`
* parameters for the `apply`. scalar must be 1 by 1. The dimension of basis
* should be same as the dimension of conjugate transpose of projector.
*/
void validate_perturbation()
{
Expand All @@ -176,9 +178,26 @@ class Perturbation : public EnableLinOp<Perturbation<ValueType>>,
cache_struct(const cache_struct &other) {}
cache_struct &operator=(const cache_struct &other) { return *this; }

// allocate linops of cache. The dimenstion of `intermediate` is
// (the number of rows of projector, the number of columns of b). Others
// are 1x1 scalar.
void allocate(std::shared_ptr<const Executor> exec, dim<2> size)
{
using vec = gko::matrix::Dense<ValueType>;
if (one == nullptr) {
one = initialize<vec>({gko::one<ValueType>()}, exec);
}
if (alpha_scalar == nullptr) {
alpha_scalar = vec::create(exec, gko::dim<2>(1));
}
if (intermediate == nullptr || intermediate->get_size() != size) {
intermediate = vec::create(exec, size);
}
}

std::unique_ptr<LinOp> intermediate;
std::unique_ptr<LinOp> one;
std::unique_ptr<matrix::Dense<ValueType>> alpha_scalar;
std::unique_ptr<LinOp> alpha_scalar;
} cache_;
};

Expand Down

0 comments on commit c923f48

Please sign in to comment.