Skip to content

Commit

Permalink
Merge pull request tensor-compiler#339 from RawnH/workspace_reuse
Browse files Browse the repository at this point in the history
Workspace reuse
  • Loading branch information
stephenchouca committed Jan 13, 2021
2 parents f051a8f + 8471869 commit cb4731d
Show file tree
Hide file tree
Showing 18 changed files with 698 additions and 51 deletions.
4 changes: 4 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,10 @@ std::vector<TensorVar> getArguments(IndexStmt stmt);
/// Returns the temporaries in the index statement, in the order they appear.
std::vector<TensorVar> getTemporaries(IndexStmt stmt);

// [Olivia]
/// Returns a map from foralls to the where statements (which introduce
/// temporaries) found while traversing the index statement.
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt);

/// Returns the tensors in the index statement.
std::vector<TensorVar> getTensorVars(IndexStmt stmt);

Expand Down
10 changes: 9 additions & 1 deletion include/taco/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ enum class IRNodeType {
BlankLine,
Print,
GetProperty,
Break
Break,
Sort
};

enum class TensorProperty {
Expand Down Expand Up @@ -725,6 +726,13 @@ struct Break : public StmtNode<Break> {
static const IRNodeType _type_info = IRNodeType::Break;
};

/// A sort statement. The IR printer emits it as a call to `qsort`, passing
/// `args` as the comma-separated argument list.
struct Sort : public StmtNode<Sort> {
/// Argument expressions passed to the sort call.
std::vector<Expr> args;
/// Creates a Sort statement that holds the given arguments.
static Stmt make(std::vector<Expr> args);

/// Node type tag identifying this statement as IRNodeType::Sort.
static const IRNodeType _type_info = IRNodeType::Sort;
};

/** A print statement.
* Takes in a printf-style format string and Exprs to pass
* for the values.
Expand Down
1 change: 1 addition & 0 deletions include/taco/ir/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class IRPrinter : public IRVisitorStrict {
virtual void visit(const Break*);
virtual void visit(const Print*);
virtual void visit(const GetProperty*);
virtual void visit(const Sort*);

std::ostream &stream;
int indent;
Expand Down
1 change: 1 addition & 0 deletions include/taco/ir/ir_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class IRRewriter : public IRVisitorStrict {
virtual void visit(const Break* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort *op);
};

}}
Expand Down
3 changes: 3 additions & 0 deletions include/taco/ir/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct BlankLine;
struct Break;
struct Print;
struct GetProperty;
struct Sort;

/// Extend this class to visit every node in the IR.
class IRVisitorStrict {
Expand Down Expand Up @@ -98,6 +99,7 @@ class IRVisitorStrict {
virtual void visit(const Break*) = 0;
virtual void visit(const Print*) = 0;
virtual void visit(const GetProperty*) = 0;
virtual void visit(const Sort*) = 0;
};


Expand Down Expand Up @@ -151,6 +153,7 @@ class IRVisitor : public IRVisitorStrict {
virtual void visit(const Break* op);
virtual void visit(const Print* op);
virtual void visit(const GetProperty* op);
virtual void visit(const Sort* op);
};

}}
Expand Down
38 changes: 36 additions & 2 deletions include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ class LowererImpl : public util::Uncopyable {
std::set<Access> reducedAccesses,
ir::Stmt recoveryStmt);

/// Lower a forall that iterates over all the coordinates in the forall index
/// var's dimension, and locates tensor positions from the locate iterators.
virtual ir::Stmt lowerForallDenseAcceleration(Forall forall,
std::vector<Iterator> locaters,
std::vector<Iterator> inserters,
std::vector<Iterator> appenders,
std::set<Access> reducedAccesses,
ir::Stmt recoveryStmt);


/// Lower a forall that iterates over the coordinates in the iterator, and
/// locates tensor positions from the locate iterators.
virtual ir::Stmt lowerForallCoordinate(Forall forall, Iterator iterator,
Expand Down Expand Up @@ -333,17 +343,29 @@ class LowererImpl : public util::Uncopyable {
ir::Stmt codeToInitializeIteratorVars(std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coord, IndexVar coordinateVar);
ir::Stmt codeToInitializeIteratorVar(Iterator iterator, std::vector<Iterator> iterators, std::vector<Iterator> rangers, std::vector<Iterator> mergers, ir::Expr coordinate, IndexVar coordinateVar);

/// Returns true iff the temporary used in the where statement is dense and sparse iteration over that
/// temporary can be automatically supported by the compiler.
bool canAccelerateDenseTemp(Where where);

/// Initializes a temporary workspace
std::vector<ir::Stmt> codeToInitializeTemporary(Where where);

/// Gets the size of a temporary tensorVar in the where statement
ir::Expr getTemporarySize(Where where);

/// Initializes helper arrays to give dense workspaces sparse acceleration
std::vector<ir::Stmt> codeToInitializeDenseAcceleratorArrays(Where where);

/// Recovers a derived indexvar from an underived variable.
ir::Stmt codeToRecoverDerivedIndexVar(IndexVar underived, IndexVar indexVar, bool emitVarDecl);

/// Conditionally increment iterator position variables.
/// Conditionally increment iterator position variables.
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
std::vector<Iterator> iterators, std::vector<Iterator> mergers);

ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);

/// Create statements to append coordinate to result modes.
/// Create statements to append coordinate to result modes.
ir::Stmt appendCoordinate(std::vector<Iterator> appenders, ir::Expr coord);

/// Create statements to append positions to result modes.
Expand All @@ -363,6 +385,9 @@ class LowererImpl : public util::Uncopyable {
int markAssignsAtomicDepth = 0;
ParallelUnit atomicParallelUnit;

/// Map used to hoist temporary workspace initialization
std::map<Forall, Where> temporaryInitialization;

/// Map from tensor variables in index notation to variables in the IR
std::map<TensorVar, ir::Expr> tensorVars;

Expand All @@ -371,6 +396,15 @@ class LowererImpl : public util::Uncopyable {
};
std::map<TensorVar, TemporaryArrays> temporaryArrays;

/// Map from temporary to indexList var if accelerating dense workspace
std::map<TensorVar, ir::Expr> tempToIndexList;

/// Map from temporary to indexListSize if accelerating dense workspace
std::map<TensorVar, ir::Expr> tempToIndexListSize;

/// Map from temporary to bitGuard var if accelerating dense workspace
std::map<TensorVar, ir::Expr> tempToBitGuard;

/// Map from result tensors to variables tracking values array capacity.
std::map<ir::Expr, ir::Expr> capacityVars;

Expand Down
3 changes: 2 additions & 1 deletion src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,15 @@ class CodeGen_C::FindVars : public IRVisitor {

virtual void visit(const Var *op) {
if (varMap.count(op) == 0) {
varMap[op] = codeGen->genUniqueName(op->name);
varMap[op] = op->is_ptr? op->name : codeGen->genUniqueName(op->name);
}
}

virtual void visit(const VarDecl *op) {
  // Record the declared variable as a local the first time it is seen.
  const bool alreadyLocal = util::contains(localVars, op->var);
  if (!alreadyLocal) {
    localVars.push_back(op->var);
  }
  // Visit the declared variable itself, then its initializer.
  op->var.accept(this);
  op->rhs.accept(this);
}

Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class CodeGen_CUDA::FindVars : public IRVisitor {

virtual void visit(const Var *op) {
if (varMap.count(op) == 0 && !inBlock) {
varMap[op] = codeGen->genUniqueName(op->name);
varMap[op] = op->is_ptr? op->name : codeGen->genUniqueName(op->name);
}
}

Expand Down
33 changes: 32 additions & 1 deletion src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2118,8 +2118,23 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) {
return;
}

// Handles derived vars on RHS with underived vars on LHS.
Assignment assignPtrWrapper = Assignment(op);
std::vector<IndexVar> possibleReductionVars = assignPtrWrapper.getReductionVars();
std::vector<IndexVar> freeVars = assignPtrWrapper.getFreeVars();
std::set<IndexVar> freeVarsSet(freeVars.begin(), freeVars.end());

int numReductionVars = 0;
for(const auto& reductionVar : possibleReductionVars) {
std::vector<IndexVar> underivedParents = provGraph.getUnderivedAncestors(reductionVar);
for(const auto& parent : underivedParents) {
if(!util::contains(freeVarsSet, parent)) {
++numReductionVars;
}
}
}
// allow introducing precompute loops where we set a temporary to values instead of +=
if (Assignment(op).getReductionVars().size() > 0 &&
if (numReductionVars > 0 &&
op->op == IndexExpr() && !inWhereProducer) {
*reason = "reduction variables in concrete notation must be dominated "
"by compound assignments (such as +=)";
Expand Down Expand Up @@ -2342,6 +2357,22 @@ vector<TensorVar> getArguments(IndexStmt stmt) {
return result;
}

/// Collects a map from foralls to where statements found in `stmt`.
///
/// Each matched where statement is stored under the forall that was most
/// recently visited before it. Where statements matched before any forall
/// has been visited are not recorded.
std::map<Forall, Where> getTemporaryLocations(IndexStmt stmt) {
  // Most recently visited forall; compares equal to an empty IndexStmt until
  // the first forall is matched.
  Forall lastForall = Forall();
  map<Forall, Where> locations;

  match(stmt,
    // Remember the forall, then continue matching inside its body.
    function<void(const ForallNode*, Matcher*)>([&](const ForallNode* forallNode, Matcher* ctx) {
      lastForall = forallNode;
      ctx->match(forallNode->stmt);
    }),
    // Record the where under the last forall seen. This handler does not call
    // ctx->match itself, and map::insert keeps an existing entry, so only the
    // first where recorded for a given forall is retained.
    function<void(const WhereNode*, Matcher*)>([&](const WhereNode* whereNode, Matcher*) {
      if (lastForall == IndexStmt()) {
        return;
      }
      locations.insert({lastForall, Where(whereNode)});
    })
  );

  return locations;
}

std::vector<TensorVar> getTemporaries(IndexStmt stmt) {
vector<TensorVar> temporaries;
bool firstAssignment = true;
Expand Down
7 changes: 6 additions & 1 deletion src/index_notation/index_notation_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ void IndexNotationPrinter::visit(const NegNode* op) {
Precedence precedence = Precedence::NEG;
bool parenthesize = precedence > parentPrecedence;
parentPrecedence = precedence;
os << "-";
if(op->getDataType().isBool()) {
os << "!";
} else {
os << "-";
}

if (parenthesize) {
os << "(";
}
Expand Down
10 changes: 6 additions & 4 deletions src/index_notation/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,17 +1114,19 @@ static IndexStmt optimizeSpMM(IndexStmt stmt) {
return stmt;
}

// I think we can do a linear combination of rows as long as there are no permutations in the format and the
// level formats are ordered. The i -> k -> j loops should iterate over the data structures without issue.
TensorVar B = Baccess.getTensorVar();
if (B.getFormat().getModeFormats()[0].getName() != "dense" ||
B.getFormat().getModeFormats()[1].getName() != "compressed" ||
if (!B.getFormat().getModeFormats()[0].isOrdered() ||
!B.getFormat().getModeFormats()[1].isOrdered() ||
B.getFormat().getModeOrdering()[0] != 0 ||
B.getFormat().getModeOrdering()[1] != 1) {
return stmt;
}

TensorVar C = Caccess.getTensorVar();
if (C.getFormat().getModeFormats()[0].getName() != "dense" ||
C.getFormat().getModeFormats()[1].getName() != "compressed" ||
if (!C.getFormat().getModeFormats()[0].isOrdered() ||
!C.getFormat().getModeFormats()[1].isOrdered() ||
C.getFormat().getModeOrdering()[0] != 0 ||
C.getFormat().getModeOrdering()[1] != 1) {
return stmt;
Expand Down
9 changes: 9 additions & 0 deletions src/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,13 @@ Expr GetProperty::make(Expr tensor, TensorProperty property, int mode,
return gp;
}

// Sort
/// Builds a Sort statement whose argument list is `args`.
Stmt Sort::make(std::vector<Expr> args) {
  // Allocate the node and attach the arguments; the raw node pointer is
  // converted to a Stmt on return.
  Sort* node = new Sort;
  node->args = args;
  return node;
}


// GetProperty
Expr GetProperty::make(Expr tensor, TensorProperty property, int mode) {
Expand Down Expand Up @@ -953,6 +960,8 @@ template<> void StmtNode<Print>::accept(IRVisitorStrict *v)
const { v->visit((const Print*)this); }
template<> void ExprNode<GetProperty>::accept(IRVisitorStrict *v)
const { v->visit((const GetProperty*)this); }
template<> void StmtNode<Sort>::accept(IRVisitorStrict *v)
const { v->visit((const Sort*)this); }

// printing methods
std::ostream& operator<<(std::ostream& os, const Stmt& stmt) {
Expand Down
16 changes: 15 additions & 1 deletion src/ir/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ void IRPrinter::visit(const Var* op) {
}

void IRPrinter::visit(const Neg* op) {
stream << "-";
if(op->type.isBool()) {
stream << "!";
} else {
stream << "-";
}
parentPrecedence = Precedence::NEG;
op->a.accept(this);
}
Expand Down Expand Up @@ -575,6 +579,16 @@ void IRPrinter::visit(const GetProperty* op) {
stream << op->name;
}

/// Prints a Sort statement as one indented line of the form `qsort(<args>);`,
/// with the arguments separated by ", ".
void IRPrinter::visit(const Sort* op) {
  doIndent();
  // Arguments are printed in call context.
  parentPrecedence = Precedence::CALL;
  stream << "qsort(";
  acceptJoin(this, stream, op->args, ", ");
  stream << ");" << endl;
}


void IRPrinter::resetNameCounters() {
// seed the unique names with all C99 keywords
// from: https://en.cppreference.com/w/c/keyword
Expand Down
18 changes: 18 additions & 0 deletions src/ir/ir_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,5 +479,23 @@ void IRRewriter::visit(const GetProperty* op) {
}
}

void IRRewriter::visit(const Sort* op) {
std::vector<Expr> args;
bool rewritten = false;
for (auto& arg : op->args) {
Expr rewrittenArg = rewrite(arg);
args.push_back(rewrittenArg);
if (rewrittenArg != arg) {
rewritten = true;
}
}
if (rewritten) {
stmt = Sort::make(args);
}
else {
stmt = op;
}
}


}}
5 changes: 5 additions & 0 deletions src/ir/ir_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,10 @@ void IRVisitor::visit(const Print* op) {
e.accept(this);
}

/// Visits each argument expression of a Sort statement in order.
void IRVisitor::visit(const Sort* op) {
  // Iterate by const reference: copying each Expr handle per iteration is
  // unnecessary because `op` keeps the arguments alive for the whole loop.
  for (const auto& arg : op->args) {
    arg.accept(this);
  }
}

} // namespace ir
} // namespace taco
1 change: 1 addition & 0 deletions src/lower/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "taco/ir/ir.h"
#include "taco/ir/simplify.h"
#include "ir/ir_generators.h"
#include "taco/ir/ir_printer.h"

#include "taco/lower/lowerer_impl.h"
#include "taco/lower/iterator.h"
Expand Down
Loading

0 comments on commit cb4731d

Please sign in to comment.