diff --git a/python/ray/data/tests/test_tf.py b/python/ray/data/tests/test_tf.py
index 39b74458611907..43819614ce0fd1 100644
--- a/python/ray/data/tests/test_tf.py
+++ b/python/ray/data/tests/test_tf.py
@@ -31,6 +31,28 @@ def test_element_spec_type(self):
         assert isinstance(feature_spec, tf.TypeSpec)
         assert isinstance(label_spec, tf.TypeSpec)
 
+    def test_element_spec_user_provided(self):
+        """Feed specs from one ``to_tf`` dataset into another and check the result."""
+        ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])
+
+        # The first dataset exists only to supply specs for the second call.
+        dataset1 = ds.to_tf(feature_columns=["spam", "ham"], label_columns="eggs")
+        feature_spec, label_spec = dataset1.element_spec
+        dataset2 = ds.to_tf(
+            feature_columns=["spam", "ham"],
+            label_columns="eggs",
+            feature_type_spec=feature_spec,
+            label_type_spec=label_spec,
+        )
+        feature_output_spec, label_output_spec = dataset2.element_spec
+        assert isinstance(label_output_spec, tf.TypeSpec)
+        assert isinstance(feature_output_spec, dict)
+        assert feature_output_spec.keys() == {"spam", "ham"}
+        assert all(
+            isinstance(value, tf.TypeSpec)
+            for value in feature_output_spec.values()
+        )
+
     def test_element_spec_type_with_multiple_columns(self):
         ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])
 