Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Composition intermediate storage #540

Merged
merged 5 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 61 additions & 35 deletions core/base/composition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,63 +33,89 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/composition.hpp>


#include <ginkgo/core/matrix/dense.hpp>
#include <algorithm>


namespace gko {
namespace {
#include <ginkgo/core/matrix/dense.hpp>


/**
 * Ensures that every slot of the output range holds a single-column Dense
 * vector with as many rows as the matching operator.
 *
 * A slot is only (re)allocated when it is empty or when its row count differs
 * from the operator's row count; otherwise the existing vector is kept.
 *
 * @tparam ValueType  value type of the Dense vectors that get created
 *
 * @param begin  iterator to the first operator a vector is needed for
 * @param end  iterator one past the last operator
 * @param res  iterator to the first vector slot, advanced once per operator
 */
template <typename ValueType, typename OpIterator, typename VecIterator>
inline void allocate_vectors(OpIterator begin, OpIterator end, VecIterator res)
{
    for (auto op_it = begin; op_it != end; ++op_it, ++res) {
        const auto num_rows = (*op_it)->get_size()[0];
        const bool reusable =
            *res != nullptr && (*res)->get_size()[0] == num_rows;
        if (reusable) {
            continue;
        }
        // vectors are created on the executor of the operator they belong to
        *res = matrix::Dense<ValueType>::create((*op_it)->get_executor(),
                                                gko::dim<2>{num_rows, 1});
    }
}
namespace gko {


inline const LinOp *apply_inner_operators(
template <typename ValueType>
std::unique_ptr<LinOp> apply_inner_operators(
const std::vector<std::shared_ptr<const LinOp>> &operators,
const std::vector<std::unique_ptr<LinOp>> &intermediate, const LinOp *rhs)
Array<ValueType> &storage, const LinOp *rhs)
{
for (auto i = operators.size() - 1; i > 0u; --i) {
auto solution = lend(intermediate[i - 1]);
operators[i]->apply(rhs, solution);
rhs = solution;
using Dense = matrix::Dense<ValueType>;
// determine amount of necessary storage:
// maximum sum of two subsequent intermediate vectors
// (and the out dimension of the last op if we only have one operator)
auto num_rhs = rhs->get_size()[1];
auto max_intermediate_size = std::accumulate(
begin(operators) + 1, end(operators) - 1,
operators.back()->get_size()[0],
[](size_type acc, std::shared_ptr<const LinOp> op) {
return std::max(acc, op->get_size()[0] + op->get_size()[1]);
thoasm marked this conversation as resolved.
Show resolved Hide resolved
});
auto storage_size = max_intermediate_size * num_rhs;
yhmtsai marked this conversation as resolved.
Show resolved Hide resolved
storage.resize_and_reset(storage_size);

// apply inner vectors
auto exec = rhs->get_executor();
auto data = storage.get_data();
// apply last operator
auto out_dim = gko::dim<2>{operators.back()->get_size()[0], num_rhs};
auto out = Dense::create(
exec, out_dim, Array<ValueType>::view(exec, out_dim[0] * num_rhs, data),
num_rhs);
operators.back()->apply(rhs, lend(out));
// apply following operators
// alternate intermediate vectors between beginning/end of storage
auto reversed_storage = true;
for (auto i = operators.size() - 2; i > 0; --i) {
// swap in and out
auto in = std::move(out);
// build new intermediate vector
out_dim[0] = operators[i]->get_size()[0];
auto out_size = out_dim[0] * num_rhs;
auto out_data =
data + (reversed_storage ? storage_size - out_size : size_type{});
reversed_storage = !reversed_storage;
out = Dense::create(exec, out_dim,
Array<ValueType>::view(exec, out_size, out_data),
num_rhs);
// apply operator
operators[i]->apply(lend(in), lend(out));
}
return rhs;
}


} // namespace
return std::move(out);
thoasm marked this conversation as resolved.
Show resolved Hide resolved
}


/**
 * Computes x = op_0 * op_1 * ... * op_{n-1} * b.
 *
 * With more than one operator, the inner operators are applied first using
 * `storage_` as scratch space, and the first operator then writes the final
 * result to `x`. With a single operator, it is applied to `b` directly.
 *
 * NOTE: `operators_[0]` is accessed unconditionally, so the composition must
 * contain at least one operator.
 */
template <typename ValueType>
void Composition<ValueType>::apply_impl(const LinOp *b, LinOp *x) const
{
    if (operators_.size() > 1) {
        // the temporary returned by apply_inner_operators lives until the end
        // of the full expression, so the lent pointer stays valid during the
        // outer apply
        operators_[0]->apply(
            lend(apply_inner_operators(operators_, storage_, b)), x);
    } else {
        operators_[0]->apply(b, x);
    }
}


/**
 * Advanced apply: forwards `alpha` and `beta` to the first operator, which
 * is applied to the result of the inner operators.
 *
 * With more than one operator, the inner operators are applied to `b` using
 * `storage_` as scratch space; only the first operator receives the scaling
 * factors. With a single operator, the call is forwarded unchanged.
 *
 * NOTE: `operators_[0]` is accessed unconditionally, so the composition must
 * contain at least one operator.
 */
template <typename ValueType>
void Composition<ValueType>::apply_impl(const LinOp *alpha, const LinOp *b,
                                        const LinOp *beta, LinOp *x) const
{
    if (operators_.size() > 1) {
        // the temporary returned by apply_inner_operators lives until the end
        // of the full expression, so the lent pointer stays valid during the
        // outer apply
        operators_[0]->apply(
            alpha, lend(apply_inner_operators(operators_, storage_, b)), beta,
            x);
    } else {
        operators_[0]->apply(alpha, b, beta, x);
    }
}


Expand Down
23 changes: 7 additions & 16 deletions include/ginkgo/core/base/composition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class Composition : public EnableLinOp<Composition<ValueType>>,
*
* @return a list of operators
*/
const std::vector<std::shared_ptr<const LinOp>> &get_operators()
        const noexcept
    {
        // read-only access; the vector itself stays owned by the composition
        return operators_;
    }
Expand All @@ -79,7 +79,7 @@ class Composition : public EnableLinOp<Composition<ValueType>>,
* @param exec Executor associated to the composition
*/
explicit Composition(std::shared_ptr<const Executor> exec)
        // storage_ starts as an empty array bound to the same executor
        : EnableLinOp<Composition>(exec), storage_{exec}
    {}

/**
Expand All @@ -101,6 +101,7 @@ class Composition : public EnableLinOp<Composition<ValueType>>,
}
return (*begin)->get_executor();
}()),
storage_{(*begin)->get_executor()},
operators_(begin, end)
{
this->set_size(gko::dim<2>{operators_.front()->get_size()[0],
Expand Down Expand Up @@ -138,7 +139,8 @@ class Composition : public EnableLinOp<Composition<ValueType>>,
*/
explicit Composition(std::shared_ptr<const LinOp> oper)
        : EnableLinOp<Composition>(oper->get_executor(), oper->get_size()),
          operators_{oper},
          // scratch buffer lives on the executor of the wrapped operator
          storage_{oper->get_executor()}
    {}

void apply_impl(const LinOp *b, LinOp *x) const override;
Expand All @@ -148,18 +150,7 @@ class Composition : public EnableLinOp<Composition<ValueType>>,

private:
    // operators of the composition in the order they were passed in; apply
    // processes them from the last one to the first one
    std::vector<std::shared_ptr<const LinOp>> operators_;

    // TODO: solve race conditions when multithreading
    // Scratch buffer for the intermediate vectors of an apply. It is mutable
    // because the const apply_impl overloads resize and overwrite it, so
    // concurrent applies on the same Composition would share this buffer.
    mutable Array<ValueType> storage_;
};


Expand Down
174 changes: 172 additions & 2 deletions reference/test/base/composition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,68 @@ class Composition : public ::testing::Test {

Composition()
: exec{gko::ReferenceExecutor::create()},
operators{gko::initialize<Mtx>(I<T>({2.0, 1.0}), exec),
gko::initialize<Mtx>({I<T>({3.0, 2.0})}, exec)}
operators{
gko::initialize<Mtx>(I<T>({2.0, 1.0}), exec),
gko::initialize<Mtx>({I<T>({3.0, 2.0})}, exec),
gko::initialize<Mtx>(
{I<T>({-1.0, 1.0, 2.0}), I<T>({5.0, -3.0, 0.0})}, exec),
gko::initialize<Mtx>(
{I<T>({9.0, 4.0}), I<T>({6.0, -2.0}), I<T>({-3.0, 2.0})},
exec),
gko::initialize<Mtx>({I<T>({1.0, 0.0}), I<T>({0.0, 1.0})}, exec),
gko::initialize<Mtx>({I<T>({1.0, 0.0}), I<T>({0.0, 1.0})}, exec)},
identity{
gko::initialize<Mtx>({I<T>({1.0, 0.0}), I<T>({0.0, 1.0})}, exec)},
product{gko::initialize<Mtx>({I<T>({-9.0, -2.0}), I<T>({27.0, 26.0})},
exec)}
{}

// executor all test matrices are created on
std::shared_ptr<const gko::Executor> exec;
// not set by the constructor shown here, so it starts out empty
std::vector<std::shared_ptr<gko::LinOp>> coefficients;
// operators the tested compositions are built from
std::vector<std::shared_ptr<gko::LinOp>> operators;
// 2x2 identity matrix, used as a two-column right-hand side
std::shared_ptr<Mtx> identity;
// 2x2 matrix [ -9 -2 ; 27 26 ]
std::shared_ptr<Mtx> product;
};

TYPED_TEST_CASE(Composition, gko::test::ValueTypes);


TYPED_TEST(Composition, AppliesSingleToVector)
{
    // A composition that wraps a single operator:
    //     cmp = [ -9 -2 ]
    //           [ 27 26 ]
    // applied to (1, 2)^T gives (-13, 79)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto composition = gko::Composition<TypeParam>::create(this->product);

    composition->apply(lend(input), lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({-13.0, 79.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesSingleLinearCombinationToVector)
{
    // A composition that wraps a single operator:
    //     cmp = [ -9 -2 ]
    //           [ 27 26 ]
    // The expected value is 3 * cmp * (1, 2)^T - (1, 2)^T = (-40, 235)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto scale_op = gko::initialize<Mtx>({3.0}, this->exec);
    auto scale_res = gko::initialize<Mtx>({-1.0}, this->exec);
    auto composition = gko::Composition<TypeParam>::create(this->product);

    composition->apply(lend(scale_op), lend(input), lend(scale_res),
                       lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({-40.0, 235.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesToVector)
{
/*
Expand Down Expand Up @@ -105,4 +155,124 @@ TYPED_TEST(Composition, AppliesLinearCombinationToVector)
}


TYPED_TEST(Composition, AppliesLongerToVector)
{
    // Three operators, which needs one intermediate vector pair:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -9 -2 ]
    //           [ 1 ]             [ 27 26 ]
    // applied to (1, 2)^T gives (238, 119)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators[0], this->operators[1], this->product);

    composition->apply(lend(input), lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({238.0, 119.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesLongerLinearCombinationToVector)
{
    // Three operators:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -9 -2 ]
    //           [ 1 ]             [ 27 26 ]
    // The expected value is 3 * cmp * (1, 2)^T - (1, 2)^T = (713, 355)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto scale_op = gko::initialize<Mtx>({3.0}, this->exec);
    auto scale_res = gko::initialize<Mtx>({-1.0}, this->exec);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators[0], this->operators[1], this->product);

    composition->apply(lend(scale_op), lend(input), lend(scale_res),
                       lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({713.0, 355.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesLongestToVector)
{
    // All six fixture operators, with intermediate vectors of varying size:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -1  1 2 ] * [  9  4 ] * [ 1 0 ]^2
    //           [ 1 ]             [  5 -3 0 ]   [  6 -2 ]   [ 0 1 ]
    //                                           [ -3  2 ]
    // applied to (1, 2)^T gives (238, 119)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators.begin(), this->operators.end());

    composition->apply(lend(input), lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({238.0, 119.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesLongestLinearCombinationToVector)
{
    // All six fixture operators:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -1  1 2 ] * [  9  4 ] * [ 1 0 ]^2
    //           [ 1 ]             [  5 -3 0 ]   [  6 -2 ]   [ 0 1 ]
    //                                           [ -3  2 ]
    // The expected value is 3 * cmp * (1, 2)^T - (1, 2)^T = (713, 355)^T.
    using Mtx = typename TestFixture::Mtx;
    auto input = gko::initialize<Mtx>({1.0, 2.0}, this->exec);
    auto res = clone(input);
    auto scale_op = gko::initialize<Mtx>({3.0}, this->exec);
    auto scale_res = gko::initialize<Mtx>({-1.0}, this->exec);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators.begin(), this->operators.end());

    composition->apply(lend(scale_op), lend(input), lend(scale_res),
                       lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({713.0, 355.0}), r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesLongestToVectorMultipleRhs)
{
    // All six fixture operators, applied to two right-hand sides at once:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -1  1 2 ] * [  9  4 ] * [ 1 0 ]^2
    //           [ 1 ]             [  5 -3 0 ]   [  6 -2 ]   [ 0 1 ]
    //                                           [ -3  2 ]
    // The input is the 2x2 identity, so the result is cmp itself.
    using Mtx = typename TestFixture::Mtx;
    auto input = clone(this->identity);
    auto res = clone(input);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators.begin(), this->operators.end());

    composition->apply(lend(input), lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({{54.0, 92.0}, {27.0, 46.0}}),
                        r<TypeParam>::value);
}


TYPED_TEST(Composition, AppliesLongestLinearCombinationToVectorMultipleRhs)
{
    // All six fixture operators, applied to two right-hand sides at once:
    //     cmp = [ 2 ] * [ 3 2 ] * [ -1  1 2 ] * [  9  4 ] * [ 1 0 ]^2
    //           [ 1 ]             [  5 -3 0 ]   [  6 -2 ]   [ 0 1 ]
    //                                           [ -3  2 ]
    // The input is the 2x2 identity I, so the expected value is 3 * cmp - I.
    using Mtx = typename TestFixture::Mtx;
    auto input = clone(this->identity);
    auto res = clone(input);
    auto scale_op = gko::initialize<Mtx>({3.0}, this->exec);
    auto scale_res = gko::initialize<Mtx>({-1.0}, this->exec);
    auto composition = gko::Composition<TypeParam>::create(
        this->operators.begin(), this->operators.end());

    composition->apply(lend(scale_op), lend(input), lend(scale_res),
                       lend(res));

    GKO_ASSERT_MTX_NEAR(res, l({{161.0, 276.0}, {81.0, 137.0}}),
                        r<TypeParam>::value);
}
upsj marked this conversation as resolved.
Show resolved Hide resolved

yhmtsai marked this conversation as resolved.
Show resolved Hide resolved

} // namespace