diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 17a4dab3b..b7570bdeb 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -417,6 +417,24 @@ Stmt LowererImpl::lowerForall(Forall forall) Expr recoveredValue = provGraph.recoverVariable(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators); taco_iassert(indexVarToExprMap.count(varToRecover)); recoverySteps.push_back(VarDecl::make(indexVarToExprMap[varToRecover], recoveredValue)); + + // After we've recovered this index variable, some iterators are now + // accessible for use when declaring locator access variables. So, generate + // the accessors for those locator variables as part of the recovery process. + // This is necessary after a fuse transformation, for example: If we fuse + // two index variables (i, j) into f, then after we've generated the loop for + // f, all locate accessors for i and j are now available for use. + std::vector<Iterator> itersForVar; + for (auto& iters : iterators.levelIterators()) { + // Collect all level iterators that have locate and iterate over + // the recovered index variable. + if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) { + itersForVar.push_back(iters.second); + } + } + // Finally, declare all of the collected iterators' position access variables. 
+ recoverySteps.push_back(this->declLocatePosVars(itersForVar)); + // place underived guard std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators); if (forallNeedsUnderivedGuards && underivedBounds.count(varToRecover) && @@ -2275,7 +2293,6 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) { if (locateIterator.isLeaf()) { break; } - locateIterator = locateIterator.getChild(); } while (accessibleIterators.contains(locateIterator)); } diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index 0fa117be3..a6f32d06d 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -72,6 +72,56 @@ TEST(scheduling, splitIndexStmt) { ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt())); } +TEST(scheduling, fuseDenseLoops) { + auto dim = 4; + Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense}); + IndexVar f("f"), g("g"); + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + A.insert({i, j, k}, i + j + k); + B.insert({i, j, k}, i + j + k); + expected.insert({i, j, k}, 2 * (i + j + k)); + } + } + } + A.pack(); + B.pack(); + expected.pack(); + + // Helper function to evaluate the target statement and verify the results. + // It takes in a function that applies some scheduling transforms to the + // input IndexStmt, and applies it to the point-wise tensor addition below. + // The test is structured this way as TACO does its best to avoid re-compilation + // whenever possible. I.e. changing the stmt that a tensor is compiled with + // doesn't cause compilation to occur again. 
+ auto testFn = [&](IndexStmt modifier(IndexStmt)) { + Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense}); + C(i, j, k) = A(i, j, k) + B(i, j, k); + auto stmt = C.getAssignment().concretize(); + C.compile(modifier(stmt)); + C.evaluate(); + ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl; + }; + + // First, a sanity check with no transformations. + testFn([](IndexStmt stmt) { return stmt; }); + // Next, fuse the outer two loops. This tests the original bug in #355. + testFn([](IndexStmt stmt) { + IndexVar f("f"); + return stmt.fuse(i, j, f); + }); + // Lastly, fuse all of the loops into a single loop. This ensures that + // locators with a chain of ancestors have all of their dependencies + // generated in a valid ordering. + testFn([](IndexStmt stmt) { + IndexVar f("f"), g("g"); + return stmt.fuse(i, j, f).fuse(f, k, g); + }); +} + TEST(scheduling, lowerDenseMatrixMul) { Tensor<double> A("A", {4, 4}, {Dense, Dense}); Tensor<double> B("B", {4, 4}, {Dense, Dense}); diff --git a/test/tests-transpose.cpp b/test/tests-transpose.cpp index b97e1936c..0a852e0bb 100644 --- a/test/tests-transpose.cpp +++ b/test/tests-transpose.cpp @@ -76,3 +76,26 @@ TEST(DISABLED_lower, transpose3) { &reason)); ASSERT_EQ(error::expr_transposition, reason); } + +// denseIterationTranspose tests a dense iteration that contains a transposition +// of one of the tensors. 
+TEST(lower, denseIterationTranspose) { + auto dim = 4; + Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense}); + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + A.insert({i, j, k}, i + j + k); + B.insert({i, j, k}, i + j + k); + expected.insert({i, j, k}, 2 * (i + j + k)); + } + } + } + A.pack(); B.pack(); expected.pack(); + C(i, j, k) = A(i, j, k) + B(k, j, i); + C.evaluate(); + ASSERT_TRUE(equals(C, expected)); +}