Skip to content

Commit

Permalink
FeatureColumn: Reshape sparse tensors for linear models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 155890727
  • Loading branch information
ispirmustafa authored and tensorflower-gardener committed May 12, 2017
1 parent 0489108 commit 1f50f3d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
19 changes: 17 additions & 2 deletions tensorflow/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,13 +1330,28 @@ def _get_sparse_tensors(self,
pass


def _sparse_reshape(inputs, shape):
  """Reshapes a `SparseTensor`, adapting `shape` for `sparse_reshape`.

  `sparse_ops.sparse_reshape` requires its shape argument to be int64, so
  the given shape (a Python list, possibly holding int32 tensors) is cast
  before the reshape is applied.

  Args:
    inputs: The `SparseTensor` to reshape.
    shape: The target shape, as a list (may contain -1 for an inferred dim).

  Returns:
    A `SparseTensor` with the requested shape.
  """
  int64_shape = math_ops.cast(shape, dtypes.int64)
  return sparse_ops.sparse_reshape(inputs, int64_shape)


def _create_categorical_column_weighted_sum(
column, builder, units, sparse_combiner, weight_collections, trainable):
"""Create a weighted sum of a categorical column for linear_model."""
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
trainable=trainable)
id_tensor = _sparse_reshape(sparse_tensors.id_tensor, [
array_ops.shape(sparse_tensors.id_tensor)[0], -1
])
weight_tensor = sparse_tensors.weight_tensor
if weight_tensor is not None:
weight_tensor = _sparse_reshape(weight_tensor,
[array_ops.shape(weight_tensor)[0], -1])

weight = variable_scope.get_variable(
name='weight',
shape=(column._num_buckets, units), # pylint: disable=protected-access
Expand All @@ -1345,8 +1360,8 @@ def _create_categorical_column_weighted_sum(
collections=weight_collections)
return _safe_embedding_lookup_sparse(
weight,
sparse_tensors.id_tensor,
sparse_weights=sparse_tensors.weight_tensor,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
name='weighted_sum')

Expand Down
21 changes: 21 additions & 0 deletions tensorflow/python/feature_column/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,27 @@ def test_dense_multi_dimension(self):
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())

def test_sparse_multi_rank(self):
  """linear_model handles a sparse categorical input of rank > 2.

  Feeds a rank-3 SparseTensorValue through a placeholder with unspecified
  shape; the model is expected to produce one prediction per leading
  (batch) row, i.e. output shape (2, 1).
  """
  wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
  with ops.Graph().as_default():
    # Placeholder with no static shape, so the fed value's rank (3) is
    # only known at run time.
    wire_tensor = array_ops.sparse_placeholder(dtypes.string)
    wire_value = sparse_tensor.SparseTensorValue(
        values=['omar', 'stringer', 'marlo', 'omar'],  # hashed = [2, 0, 3, 2]
        indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
        dense_shape=[2, 2, 2])
    features = {'wire_cast': wire_tensor}
    predictions = fc.linear_model(features, [wire_cast])
    wire_cast_var = get_linear_model_column_var(wire_cast)
    with _initialized_session() as sess:
      # Before assignment the column weights are zero, so both the
      # variable and the predictions evaluate to zeros.
      self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
      self.assertAllClose(
          np.zeros((2, 1)),
          predictions.eval(feed_dict={wire_tensor: wire_value}))
      sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
      # Row 0 hits buckets [2, 0] -> 1000 + 10 = 1010.
      # Row 1 hits buckets [3, 2] -> 10000 + 1000 = 11000.
      self.assertAllClose(
          [[1010.], [11000.]],
          predictions.eval(feed_dict={wire_tensor: wire_value}))

def test_sparse_combiner(self):
wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
Expand Down

0 comments on commit 1f50f3d

Please sign in to comment.