Skip to content

Commit

Permalink
Enable tf.strided_slice GPU kernel for tf.complex64.
Browse files Browse the repository at this point in the history
Change: 147881489
  • Loading branch information
rryan authored and tensorflower-gardener committed Feb 17, 2017
1 parent 24b1cdd commit c0df7c6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 24 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/kernels/strided_slice_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ REGISTER_STRIDED_SLICE(bfloat16);
StridedSliceAssignOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::StridedSliceAssign<GPUDevice, T, 7>; \
template struct functor::StridedSliceAssignScalar<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
DEFINE_GPU_KERNELS(int32);

#undef DEFINE_GPU_KERNELS
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/strided_slice_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ class HandleStridedSliceAssignCase<Device, T, 0> {

#if GOOGLE_CUDA
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_complex64(PREVENT_FOR_N_GPU);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
#endif // END GOOGLE_CUDA

Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/strided_slice_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ static void BM_SliceFloat(int iters, int dim2) {

BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000);

static void BM_SliceComplex64(int iters, int dim2) {
SliceHelper<std::complex<float>>(iters, dim2);
}

BENCHMARK(BM_SliceComplex64)->Arg(100)->Arg(1000)->Arg(10000);

static void BM_SliceBFloat16(int iters, int dim2) {
SliceHelper<bfloat16>(iters, dim2);
}
Expand Down
51 changes: 27 additions & 24 deletions tensorflow/python/kernel_tests/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,15 @@ def eval_if_tensor(x):
return tensor


STRIDED_SLICE_TYPES = [dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8,
dtypes.float32, dtypes.float64, dtypes.complex64]


class StridedSliceTest(test_util.TensorFlowTestCase):
"""Test the strided slice operation with variants of slices."""

def test_basic_slice(self):
for tensor_type in [
dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8, dtypes.float32,
dtypes.float64
]:
for tensor_type in STRIDED_SLICE_TYPES:
for use_gpu in [False, True]:
with self.test_session(use_gpu=use_gpu):
checker = StridedSliceChecker(
Expand All @@ -463,7 +464,7 @@ def test_basic_slice(self):
_ = checker[-2::-1, :, ::2]

# Check rank-0 examples
checker2 = StridedSliceChecker(self, 5, tensor_type=dtypes.int32)
checker2 = StridedSliceChecker(self, 5, tensor_type=tensor_type)
_ = checker2[None]
_ = checker2[...]
_ = checker2[tuple()]
Expand Down Expand Up @@ -847,26 +848,28 @@ def testInvalidSlice(self):
sess.run(bar)

def testSliceAssign(self):
checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]])
# Check if equal
checker[:] = [[10, 20, 30], [40, 50, 60]]
# Check trivial (1,1) shape tensor
checker[1:2, 1:2] = [[666]]
# shrinks shape changes
checker[1:2, 1] = [666]
checker[1, 1:2] = [666]
checker[1, 1] = 666
# newaxis shape changes
checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
# shrink and newaxis
checker[None, None, 0, 0:1] = [[[999]]]
# Non unit strides
checker[::1, ::-2] = [[33, 333], [44, 444]]
# degenerate interval
checker[8:10, 0] = []
checker[8:10, 8:10] = [[]]
for dtype in STRIDED_SLICE_TYPES:
checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
tensor_type=dtype)
# Check if equal
checker[:] = [[10, 20, 30], [40, 50, 60]]
# Check trivial (1,1) shape tensor
checker[1:2, 1:2] = [[66]]
# shrinks shape changes
checker[1:2, 1] = [66]
checker[1, 1:2] = [66]
checker[1, 1] = 66
# newaxis shape changes
checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]]
# shrink and newaxis
checker[None, None, 0, 0:1] = [[[99]]]
# Non unit strides
checker[::1, ::-2] = [[3, 33], [4, 44]]
# degenerate interval
checker[8:10, 0] = []
checker[8:10, 8:10] = [[]]
# Assign vector to scalar (rank-0) using newaxis
checker2 = StridedSliceAssignChecker(self, 2225)
checker2 = StridedSliceAssignChecker(self, 222)
checker2[()] = 6 # no indices
checker2[...] = 6 # ellipsis
checker2[None] = [6] # new axis
Expand Down

0 comments on commit c0df7c6

Please sign in to comment.