Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] fix style in ir_mutator and ir_visitor #4561

Merged
merged 1 commit into from
Dec 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions src/pass/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class IRTransformer final : public IRMutator {
}

private:
template<typename T>
template <typename T>
T MutateInternal(T node) {
if (only_enable_.size() &&
!only_enable_.count(node->type_index())) {
Expand Down Expand Up @@ -89,11 +89,11 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*)
static FMutateStmt inst; return inst;
}

inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); });
inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator* m) {
return UpdateArray(arr, [&m](const Expr& e) { return m->Mutate(e); });
}

inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator* m) {
std::vector<IterVar> new_dom(rdom.size());
bool changed = false;
for (size_t i = 0; i < rdom.size(); i++) {
Expand Down Expand Up @@ -133,7 +133,7 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const LetStmt* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Stmt body = this->Mutate(op->body);
if (value.same_as(op->value) &&
Expand All @@ -144,7 +144,7 @@ Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const For* op, const Stmt& s) {
Expr min = this->Mutate(op->min);
Expr extent = this->Mutate(op->extent);
Stmt body = this->Mutate(op->body);
Expand Down Expand Up @@ -185,7 +185,7 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const IfThenElse* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
Expand All @@ -201,7 +201,7 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Store* op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
Expand Down Expand Up @@ -233,7 +233,7 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
Expand Down Expand Up @@ -263,7 +263,7 @@ Stmt IRMutator::Mutate_(const Prefetch* op, const Stmt& s) {
Expr old_extent = op->bounds[i]->extent;
Expr new_min = m->Mutate(old_min);
Expr new_extent = m->Mutate(old_extent);
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_min.same_as(old_min)) bounds_changed = true;
if (!new_extent.same_as(old_extent)) bounds_changed = true;
new_bounds.push_back(
Range::make_by_min_extent(new_min, new_extent));
Expand All @@ -288,7 +288,7 @@ Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const AssertStmt* op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Expr message = this->Mutate(op->message);
Stmt body = this->Mutate(op->body);
Expand All @@ -302,7 +302,7 @@ Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const ProducerConsumer* op, const Stmt& s) {
Stmt body = this->Mutate(op->body);
if (body.same_as(op->body)) {
return s;
Expand All @@ -311,7 +311,7 @@ Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Evaluate* op, const Stmt& s) {
Expr v = this->Mutate(op->value);
if (v.same_as(op->value)) {
return s;
Expand All @@ -320,7 +320,7 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
Stmt IRMutator::Mutate_(const Free* op, const Stmt& s) {
return s;
}

Expand Down Expand Up @@ -348,11 +348,11 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return m->Mutate_(static_cast<const OP*>(node.get()), e); \
})

Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
Expr IRMutator::Mutate_(const Variable* op, const Expr& e) {
return e;
}

Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
Expr IRMutator::Mutate_(const Load* op, const Expr& e) {
Expr index = this->Mutate(op->index);
Expr pred = this->Mutate(op->predicate);
if (index.same_as(op->index) && pred.same_as(op->predicate)) {
Expand All @@ -362,7 +362,7 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
Expr IRMutator::Mutate_(const Let* op, const Expr& e) {
Expr value = this->Mutate(op->value);
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
Expand Down Expand Up @@ -413,8 +413,8 @@ DEFINE_BIOP_EXPR_MUTATE_(GE)
DEFINE_BIOP_EXPR_MUTATE_(And)
DEFINE_BIOP_EXPR_MUTATE_(Or)

Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Expr IRMutator::Mutate_(const Reduce* op, const Expr& e) {
Array<IterVar> new_axis = MutateIterVarArr(op->axis, this);
Array<Expr> new_source = MutateArray(op->source, this);
Expr new_cond = this->Mutate(op->condition);
if (op->axis.same_as(new_axis) &&
Expand All @@ -427,7 +427,7 @@ Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
Expr IRMutator::Mutate_(const Cast* op, const Expr& e) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
Expand All @@ -436,7 +436,7 @@ Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
Expr IRMutator::Mutate_(const Not* op, const Expr& e) {
Expr a = this->Mutate(op->a);
if (a.same_as(op->a)) {
return e;
Expand All @@ -445,7 +445,7 @@ Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
Expr IRMutator::Mutate_(const Select* op, const Expr& e) {
Expr cond = this->Mutate(op->condition);
Expr t = this->Mutate(op->true_value);
Expr f = this->Mutate(op->false_value);
Expand All @@ -458,7 +458,7 @@ Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
Expr IRMutator::Mutate_(const Ramp* op, const Expr& e) {
Expr base = this->Mutate(op->base);
Expr stride = this->Mutate(op->stride);
if (base.same_as(op->base) &&
Expand All @@ -469,7 +469,7 @@ Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
Expr IRMutator::Mutate_(const Broadcast* op, const Expr& e) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
Expand All @@ -478,7 +478,7 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Shuffle *op, const Expr& e) {
Expr IRMutator::Mutate_(const Shuffle* op, const Expr& e) {
auto new_vec = MutateArray(op->vectors, this);
if (new_vec.same_as(op->vectors)) {
return e;
Expand Down
41 changes: 20 additions & 21 deletions src/pass/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class IRApplyVisit : public IRVisitor {
std::unordered_set<const Node*> visited_;
};


void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
Expand All @@ -68,7 +67,7 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {

void IRVisitor::Visit_(const Variable* op) {}

void IRVisitor::Visit_(const LetStmt *op) {
void IRVisitor::Visit_(const LetStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}
Expand All @@ -78,14 +77,14 @@ void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->body);
}

void IRVisitor::Visit_(const For *op) {
void IRVisitor::Visit_(const For* op) {
IRVisitor* v = this;
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
}

void IRVisitor::Visit_(const Allocate *op) {
void IRVisitor::Visit_(const Allocate* op) {
IRVisitor* v = this;
for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]);
Expand All @@ -97,33 +96,33 @@ void IRVisitor::Visit_(const Allocate *op) {
}
}

void IRVisitor::Visit_(const Load *op) {
void IRVisitor::Visit_(const Load* op) {
this->Visit(op->index);
this->Visit(op->predicate);
}

void IRVisitor::Visit_(const Store *op) {
void IRVisitor::Visit_(const Store* op) {
this->Visit(op->value);
this->Visit(op->index);
this->Visit(op->predicate);
}

void IRVisitor::Visit_(const IfThenElse *op) {
void IRVisitor::Visit_(const IfThenElse* op) {
this->Visit(op->condition);
this->Visit(op->then_case);
if (op->else_case.defined()) {
this->Visit(op->else_case);
}
}

void IRVisitor::Visit_(const Let *op) {
void IRVisitor::Visit_(const Let* op) {
this->Visit(op->value);
this->Visit(op->body);
}

void IRVisitor::Visit_(const Free* op) {}

void IRVisitor::Visit_(const Call *op) {
void IRVisitor::Visit_(const Call* op) {
VisitArray(op->args, this);
}

Expand Down Expand Up @@ -171,38 +170,38 @@ void IRVisitor::Visit_(const Select* op) {
this->Visit(op->false_value);
}

void IRVisitor::Visit_(const Ramp *op) {
void IRVisitor::Visit_(const Ramp* op) {
this->Visit(op->base);
this->Visit(op->stride);
}

void IRVisitor::Visit_(const Shuffle *op) {
for (const auto &elem : op->indices)
void IRVisitor::Visit_(const Shuffle* op) {
for (const auto& elem : op->indices)
this->Visit(elem);
for (const auto &elem : op->vectors)
for (const auto& elem : op->vectors)
this->Visit(elem);
}

void IRVisitor::Visit_(const Broadcast *op) {
void IRVisitor::Visit_(const Broadcast* op) {
this->Visit(op->value);
}

void IRVisitor::Visit_(const AssertStmt *op) {
void IRVisitor::Visit_(const AssertStmt* op) {
this->Visit(op->condition);
this->Visit(op->message);
this->Visit(op->body);
}

void IRVisitor::Visit_(const ProducerConsumer *op) {
void IRVisitor::Visit_(const ProducerConsumer* op) {
this->Visit(op->body);
}

void IRVisitor::Visit_(const Provide *op) {
void IRVisitor::Visit_(const Provide* op) {
VisitArray(op->args, this);
this->Visit(op->value);
}

void IRVisitor::Visit_(const Realize *op) {
void IRVisitor::Visit_(const Realize* op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
Expand All @@ -212,19 +211,19 @@ void IRVisitor::Visit_(const Realize *op) {
this->Visit(op->condition);
}

void IRVisitor::Visit_(const Prefetch *op) {
void IRVisitor::Visit_(const Prefetch* op) {
for (size_t i = 0; i < op->bounds.size(); i++) {
this->Visit(op->bounds[i]->min);
this->Visit(op->bounds[i]->extent);
}
}

void IRVisitor::Visit_(const Block *op) {
void IRVisitor::Visit_(const Block* op) {
this->Visit(op->first);
this->Visit(op->rest);
}

void IRVisitor::Visit_(const Evaluate *op) {
void IRVisitor::Visit_(const Evaluate* op) {
this->Visit(op->value);
}

Expand Down