Skip to content

Commit

Permalink
Switch contrib/rnn ops to use C++ shape functions.
Browse files Browse the repository at this point in the history
Change: 133193463
  • Loading branch information
tensorflower-gardener committed Sep 15, 2016
1 parent 9287c2f commit 6b104af
Show file tree
Hide file tree
Showing 8 changed files with 430 additions and 85 deletions.
44 changes: 43 additions & 1 deletion tensorflow/contrib/rnn/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_cc_test",
)
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)

py_library(
name = "rnn_py",
Expand Down Expand Up @@ -90,6 +98,40 @@ cuda_py_tests(
],
)

tf_cc_test(
name = "ops/gru_ops_test",
size = "small",
srcs = ["ops/gru_ops_test.cc"],
data = [":python/ops/_gru_ops.so"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_cc_test(
name = "ops/lstm_ops_test",
size = "small",
srcs = ["ops/lstm_ops_test.cc"],
data = [":python/ops/_lstm_ops.so"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

filegroup(
name = "all_files",
srcs = glob(
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/contrib/rnn/ops/gru_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

REGISTER_OP("GRUBlockCell")
.Attr("T: {float}")
Expand All @@ -27,6 +32,19 @@ REGISTER_OP("GRUBlockCell")
.Output("u: T")
.Output("c: T")
.Output("h: T")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, h_prev;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size = c->Dim(h_prev, 1);
ShapeHandle output = c->Matrix(batch_size, cell_size);
for (int i = 0; i < 4; ++i) {
c->set_output(i, output);
}
return tensorflow::Status::OK();
})
.Doc(R"doc(
Computes the GRU cell forward propagation for 1 time step.
Expand Down Expand Up @@ -92,6 +110,23 @@ REGISTER_OP("GRUBlockCellGrad")
.Output("d_h_prev: T")
.Output("d_c_bar: T")
.Output("d_r_bar_u_bar: T")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, h_prev, w_ru;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size = c->Dim(h_prev, 1);
DimensionHandle twice_cell_size = c->Dim(w_ru, 1);
ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size);

c->set_output(0, x);
c->set_output(1, batch_cell_shape);
c->set_output(2, batch_cell_shape);
c->set_output(3, c->Matrix(batch_size, twice_cell_size));
return tensorflow::Status::OK();
})
.Doc(R"doc(
Computes the GRU cell back-propagation for 1 time step.
Expand Down
63 changes: 63 additions & 0 deletions tensorflow/contrib/rnn/ops/gru_ops_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
You may obtain a copy of the License at
http:https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

class GruOpsTest : public ::testing::Test {
public:
static void SetUpTestCase() {
TF_Status* status = TF_NewStatus();
auto* lib = TF_LoadLibrary(
"tensorflow/contrib/rnn/python/ops/_gru_ops.so", status);
CHECK_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
TF_DeleteLibraryHandle(lib);
}
};

TEST_F(GruOpsTest, GRUBlockCell_ShapeFn) {
ShapeInferenceTestOp op("GRUBlockCell");

// Rank checks.
INFER_ERROR("must be rank 2", op, "[?];?;?;?;?;?");
INFER_ERROR("must be rank 2", op, "?;[?];?;?;?;?");

// Output
INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?,?];[?,?];[?,?]");
INFER_OK(op, "[?,?];[?,?];?;?;?;?",
"[d0_0,d1_1];[d0_0,d1_1];[d0_0,d1_1];[d0_0,d1_1]");
}

TEST_F(GruOpsTest, GRUBlockCellGrad_ShapeFn) {
ShapeInferenceTestOp op("GRUBlockCellGrad");

// Rank checks.
INFER_ERROR("must be rank 2", op, "[?];?;?;?;?;?;?;?;?;?");
INFER_ERROR("must be rank 2", op, "?;[?];?;?;?;?;?;?;?;?");
INFER_ERROR("must be rank 2", op, "?;?;[?];?;?;?;?;?;?;?");

// Output
INFER_OK(op, "?;?;?;?;?;?;?;?;?;?", "[?,?];[?,?];[?,?];[?,?]");
INFER_OK(op, "[?,?];[?,?];[?,?];?;?;?;?;?;?;?",
"in0;[d0_0,d1_1];[d0_0,d1_1];[d0_0,d2_1]");
}

} // namespace tensorflow
84 changes: 84 additions & 0 deletions tensorflow/contrib/rnn/ops/lstm_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("LSTMBlockCell")
.Input("x: T")
.Input("cs_prev: T")
Expand All @@ -37,6 +42,19 @@ REGISTER_OP("LSTMBlockCell")
.Attr("cell_clip: float = 3.0")
.Attr("use_peephole: bool = false")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, cs_prev;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size = c->Dim(cs_prev, 1);
ShapeHandle output = c->Matrix(batch_size, cell_size);
for (int i = 0; i < 7; ++i) {
c->set_output(i, output);
}
return tensorflow::Status::OK();
})
.Doc(R"doc(
Computes the LSTM cell forward propagation for 1 time step.
Expand Down Expand Up @@ -98,6 +116,24 @@ REGISTER_OP("LSTMBlockCellGrad")
.Output("wco_grad: T")
.Attr("use_peephole: bool")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, cs_prev;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size = c->Dim(cs_prev, 1);
DimensionHandle cell_size_times_4;
TF_RETURN_IF_ERROR(c->Multiply(cell_size, 4, &cell_size_times_4));
ShapeHandle cell_size_vec = c->Vector(cell_size);

c->set_output(0, c->Matrix(batch_size, cell_size));
c->set_output(1, c->Matrix(batch_size, cell_size_times_4));
c->set_output(2, cell_size_vec);
c->set_output(3, cell_size_vec);
c->set_output(4, cell_size_vec);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Computes the LSTM cell backward propagation for 1 timestep.
Expand Down Expand Up @@ -141,6 +177,28 @@ REGISTER_OP("BlockLSTM")
.Attr("cell_clip: float = 3.0")
.Attr("use_peephole: bool = false")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, b;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));

DimensionHandle batch_size = c->Dim(x, 0);
DimensionHandle cell_size;
TF_RETURN_IF_ERROR(
c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));

int64 max_len;
TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len));

DCHECK_EQ(max_len * 7, c->num_outputs());
ShapeHandle output = c->Matrix(batch_size, cell_size);
for (int i = 0; i < max_len; ++i) {
for (int j = 0; j < 7; ++j) {
c->set_output(i * 7 + j, output);
}
}
return Status::OK();
})
.Doc(R"doc(
)doc");

Expand Down Expand Up @@ -174,6 +232,32 @@ REGISTER_OP("BlockLSTMGrad")
.Attr("max_len: int")
.Attr("use_peephole: bool")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
int64 max_len;
TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len));

ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + max_len), 2, &cs_prev));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2 + max_len), 2, &h_prev));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3 + max_len), 2, &w));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4 + max_len), 1, &wci));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5 + max_len), 1, &wco));
TF_RETURN_IF_ERROR(c->WithRank(c->input(6 + max_len), 1, &wcf));
TF_RETURN_IF_ERROR(c->WithRank(c->input(7 + max_len), 1, &b));

int out_idx = 0;
for (int i = 0; i < max_len; ++i) c->set_output(out_idx++, x);
c->set_output(out_idx++, cs_prev);
c->set_output(out_idx++, h_prev);
c->set_output(out_idx++, w);
c->set_output(out_idx++, wci);
c->set_output(out_idx++, wco);
c->set_output(out_idx++, wcf);
c->set_output(out_idx++, b);

return Status::OK();
})
.Doc(R"doc(
)doc");

Expand Down
Loading

0 comments on commit 6b104af

Please sign in to comment.