Skip to content

Commit

Permalink
Make sparse_reshape work well with output of tf.shape.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 155902266
  • Loading branch information
ispirmustafa authored and tensorflower-gardener committed May 12, 2017
1 parent 7d785f1 commit 063bdbe
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
13 changes: 3 additions & 10 deletions tensorflow/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,27 +1330,20 @@ def _get_sparse_tensors(self,
pass


def _sparse_reshape(inputs, shape):
# Satisfies sparse_reshape assumptions such as dtype int64.
# shape is a list.
return sparse_ops.sparse_reshape(inputs,
math_ops.cast(shape, dtypes.int64))


def _create_categorical_column_weighted_sum(
column, builder, units, sparse_combiner, weight_collections, trainable):
"""Create a weighted sum of a categorical column for linear_model."""
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
trainable=trainable)
id_tensor = _sparse_reshape(sparse_tensors.id_tensor, [
id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
array_ops.shape(sparse_tensors.id_tensor)[0], -1
])
weight_tensor = sparse_tensors.weight_tensor
if weight_tensor is not None:
weight_tensor = _sparse_reshape(weight_tensor,
[array_ops.shape(weight_tensor)[0], -1])
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])

weight = variable_scope.get_variable(
name='weight',
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/kernel_tests/sparse_reshape_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ def testFeedSameShape(self):
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)

def testWorksWellWithTfShape(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
input_val = self._SparseTensorValue_5x6()
shape = array_ops.shape(sp_input) # tf.shape generates int32 output
sp_output = sparse_ops.sparse_reshape(sp_input, shape)

output_val = sess.run(sp_output, {sp_input: input_val})
self.assertAllEqual(output_val.indices, input_val.indices)
self.assertAllEqual(output_val.values, input_val.values)
self.assertAllEqual(output_val.dense_shape, input_val.dense_shape)

def testFeedSameShapeWithInferredDim(self):
with self.test_session(use_gpu=False) as sess:
sp_input = self._SparseTensorPlaceholder()
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/ops/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def sparse_reshape(sp_input, shape, name=None):
number of elements than `sp_input`.
"""
sp_input = _convert_to_sparse_tensor(sp_input)
shape = ops.convert_to_tensor(shape, dtype=dtypes.int64)
shape = math_ops.cast(shape, dtype=dtypes.int64)

with ops.name_scope(name, "SparseReshape", [sp_input]) as name:
reshaped_ind, reshaped_shape = gen_sparse_ops._sparse_reshape(
Expand Down

0 comments on commit 063bdbe

Please sign in to comment.