Skip to content

Commit

Permalink
Improve data processing in TensorForest. Extract name correctly when …
Browse files Browse the repository at this point in the history
…transforming, and support sparse categorical.

PiperOrigin-RevId: 160424501
  • Loading branch information
tensorflower-gardener committed Jun 28, 2017
1 parent ea39a29 commit 54a1e73
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions tensorflow/contrib/tensor_forest/python/ops/data_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def SerializeToString(self):
self.size)


def GetColumnName(column_key, col_num):
if isinstance(column_key, str):
return column_key
else:
return getattr(column_key, 'column_name', str(col_num))


def ParseDataTensorOrDict(data):
"""Return a tensor to use for input data.
Expand All @@ -119,14 +126,13 @@ def ParseDataTensorOrDict(data):
for k in sorted(data.keys()):
is_sparse = isinstance(data[k], sparse_tensor.SparseTensor)
if is_sparse:
# TODO(gilberth): support sparse categorical.
if data[k].dtype == dtypes.string:
logging.info('TensorForest does not support sparse categorical. '
'Transform it into a number with hash buckets.')
# TODO(gilberth): support sparse continuous.
if data[k].dtype == dtypes.float32:
logging.info('TensorForest does not support sparse continuous.')
continue
elif data_spec.sparse.size() == 0:
col_spec = data_spec.sparse.add()
col_spec.original_type = DATA_FLOAT
col_spec.original_type = DATA_CATEGORICAL
col_spec.name = 'all_sparse'
col_spec.size = -1
sparse_features.append(
Expand All @@ -136,7 +142,7 @@ def ParseDataTensorOrDict(data):
col_spec = data_spec.dense.add()

col_spec.original_type = DTYPE_TO_FTYPE[data[k].dtype]
col_spec.name = k
col_spec.name = GetColumnName(k, len(dense_features))
# the second dimension of get_shape should always be known.
shape = data[k].get_shape()
if len(shape) == 1:
Expand Down

0 comments on commit 54a1e73

Please sign in to comment.