Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng Su <[email protected]>
  • Loading branch information
c21 committed Feb 7, 2024
1 parent dc71419 commit e96d8ba
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ def test_element_spec_type(self):
assert isinstance(feature_spec, tf.TypeSpec)
assert isinstance(label_spec, tf.TypeSpec)

def test_element_spec_user_provided(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

dataset1 = ds.to_tf(feature_columns=["spam", "ham"], label_columns="eggs")
feature_spec, label_spec = dataset1.element_spec
dataset2 = ds.to_tf(
feature_columns=["spam", "ham"],
label_columns="eggs",
feature_type_spec=feature_spec,
label_spec=label_spec,
)
feature_output_spec, label_output_spec = dataset2.element_spec
assert isinstance(label_output_spec, tf.TypeSpec)
assert isinstance(feature_output_spec, dict)
assert feature_output_spec.keys() == {"spam", "ham"}
assert all(
isinstance(value, tf.TypeSpec)
for value in feature_output_spec.values()
)

def test_element_spec_type_with_multiple_columns(self):
ds = ray.data.from_items([{"spam": 0, "ham": 0, "eggs": 0}])

Expand Down

0 comments on commit e96d8ba

Please sign in to comment.