Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix remove Cast fuse (#21086)
Browse files Browse the repository at this point in the history
* Fix remove Cast fuse

* Fix Reset()

* Fix review

* Fix sanity
  • Loading branch information
bartekkuncer committed Jul 15, 2022
1 parent f6d1ed1 commit cf15e0a
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 {
}

void Reset() override {
status_ = kFail;
status_ = kExpand;
castDtype = -1;
}
};
Expand All @@ -105,7 +105,7 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
SgDNNLRemoveCastsProperty() {}

static SubgraphPropertyPtr Create() {
static const std::string& name = "Remove casts optimization pass";
static const std::string& name = "Remove Casts optimization pass";
auto property = std::make_shared<SgDNNLRemoveCastsProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
Expand Down Expand Up @@ -137,6 +137,14 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
auto selector = std::make_shared<SgDNNLRemoveCastsSelector>();
return selector;
}

void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
std::vector<nnvm::NodeEntry*>* output_entries) const override {
// Connect all extern output entries to output[0]
for (size_t i = 0; i < output_entries->size(); ++i) {
*output_entries->at(i) = nnvm::NodeEntry{subgraph_node, 0, 0};
}
}
};

} // namespace op
Expand Down

0 comments on commit cf15e0a

Please sign in to comment.