From dc71419e8695e81b07c6e1db9e87d3073dc2e6d3 Mon Sep 17 00:00:00 2001
From: Cheng Su
Date: Fri, 2 Feb 2024 14:25:48 -0800
Subject: [PATCH] Change to call if type_spec not provided

Signed-off-by: Cheng Su
---
 python/ray/data/iterator.py | 52 ++++++++++++++++++++++++-------------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py
index e7083cf9f90316..4965419261dfba 100644
--- a/python/ray/data/iterator.py
+++ b/python/ray/data/iterator.py
@@ -776,32 +776,45 @@ def to_tf(
             A ``tf.data.Dataset`` that yields inputs and targets.
         """  # noqa: E501
 
-        from ray.air._internal.tensorflow_utils import convert_ndarray_to_tf_tensor
+        from ray.air._internal.tensorflow_utils import (
+            convert_ndarray_to_tf_tensor,
+            get_type_spec,
+        )
 
         try:
             import tensorflow as tf
         except ImportError:
             raise ValueError("tensorflow must be installed!")
 
+        def validate_column(column: str) -> None:
+            if column not in valid_columns:
+                raise ValueError(
+                    f"You specified '{column}' in `feature_columns` or "
+                    f"`label_columns`, but there's no column named '{column}' in the "
+                    f"dataset. Valid column names are: {valid_columns}."
+                )
+
+        def validate_columns(columns: Union[str, List]) -> None:
+            if isinstance(columns, list):
+                for column in columns:
+                    validate_column(column)
+            else:
+                validate_column(columns)
+
         def convert_batch_to_tensors(
             batch: Dict[str, np.ndarray],
             *,
             columns: Union[str, List[str]],
-            type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]] = None,
+            type_spec: Union[tf.TypeSpec, Dict[str, tf.TypeSpec]],
         ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
             if isinstance(columns, str):
                 return convert_ndarray_to_tf_tensor(batch[columns], type_spec=type_spec)
-            else:
-                tensors = {}
-                for column in columns:
-                    if type_spec is not None:
-                        column_type_spec = type_spec[column]
-                    else:
-                        column_type_spec = None
-                    tensors[column] = convert_ndarray_to_tf_tensor(
-                        batch[column], type_spec=column_type_spec
-                    )
-                return tensors
+            return {
+                column: convert_ndarray_to_tf_tensor(
+                    batch[column], type_spec=type_spec[column]
+                )
+                for column in columns
+            }
 
         def generator():
             for batch in self.iter_batches(
@@ -821,13 +834,16 @@ def generator():
                 )
                 yield features, labels
 
-        if feature_type_spec is not None and label_type_spec is not None:
-            output_signature = (feature_type_spec, label_type_spec)
-        else:
-            output_signature = None
+        if feature_type_spec is None or label_type_spec is None:
+            schema = self.schema()
+            valid_columns = schema.names
+            validate_columns(feature_columns)
+            validate_columns(label_columns)
+            feature_type_spec = get_type_spec(schema, columns=feature_columns)
+            label_type_spec = get_type_spec(schema, columns=label_columns)
 
         dataset = tf.data.Dataset.from_generator(
-            generator, output_signature=output_signature
+            generator, output_signature=(feature_type_spec, label_type_spec)
         )
 
         options = tf.data.Options()