Skip to content

Commit

Permalink
[workflow] Enhance dataset tests (ray-project#25876)
Browse files Browse the repository at this point in the history
  • Loading branch information
suquark committed Jun 17, 2022
1 parent ce02ac0 commit fea8dd0
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions python/ray/workflow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,35 @@

@ray.remote
def gen_dataset():
# TODO(ekl) seems checkpointing hangs with nested refs of
# LazyBlockList.
return ray.data.range(1000).map(lambda x: x)


@ray.remote
def gen_dataset_1():
return ray.data.range(1000)


@ray.remote
def gen_dataset_2():
return ray.data.range_table(1000)


@ray.remote
def transform_dataset(in_data):
return in_data.map(lambda x: x * 2)


@ray.remote
def transform_dataset_1(in_data):
return in_data.map(lambda r: {"v2": r["value"] * 2})


@ray.remote
def sum_dataset(ds):
return ds.sum()


def test_dataset(workflow_start_regular):
def test_dataset(workflow_start_regular_shared):
ds_ref = gen_dataset.bind()
transformed_ref = transform_dataset.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)
Expand All @@ -32,6 +45,24 @@ def test_dataset(workflow_start_regular):
assert result == 2 * sum(range(1000))


def test_dataset_1(workflow_start_regular_shared):
ds_ref = gen_dataset_1.bind()
transformed_ref = transform_dataset.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)

result = workflow.create(output_ref).run()
assert result == 2 * sum(range(1000))


def test_dataset_2(workflow_start_regular_shared):
ds_ref = gen_dataset_2.bind()
transformed_ref = transform_dataset_1.bind(ds_ref)
output_ref = sum_dataset.bind(transformed_ref)

result = workflow.create(output_ref).run()
assert result == 2 * sum(range(1000))


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit fea8dd0

Please sign in to comment.