diff --git a/python/ray/data/datasource/lance_datasource.py b/python/ray/data/datasource/lance_datasource.py index a082b725f9f2e..d28bfbdb747ac 100644 --- a/python/ray/data/datasource/lance_datasource.py +++ b/python/ray/data/datasource/lance_datasource.py @@ -34,17 +34,16 @@ def __init__( self.columns = columns self.filter = filter self.storage_options = storage_options - self.lance_ds = lance.dataset(uri=uri, storage_options=storage_options) - self.fragments = self.lance_ds.get_fragments() def get_read_tasks(self, parallelism: int) -> List[ReadTask]: read_tasks = [] - for fragments in np.array_split(self.fragments, parallelism): + for fragments in np.array_split(self.lance_ds.get_fragments(), parallelism): if len(fragments) <= 0: continue - num_rows = sum([f.count_rows() for f in fragments]) + fragment_ids = [f.metadata.id for f in fragments] + num_rows = sum(f.count_rows() for f in fragments) input_files = [ data_file.path() for f in fragments for data_file in f.data_files() ] @@ -59,9 +58,12 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]: ) columns = self.columns row_filter = self.filter + lance_ds = self.lance_ds read_task = ReadTask( - lambda f=fragments: _read_fragments(f, columns, row_filter), + lambda f=fragment_ids: _read_fragments( + f, lance_ds, columns, row_filter + ), metadata, ) read_tasks.append(read_task) @@ -73,11 +75,17 @@ def estimate_inmemory_data_size(self) -> Optional[int]: return None -def _read_fragments(fragments, columns, row_filter) -> Iterator["pyarrow.Table"]: - """Read Lance fragments in batches.""" +def _read_fragments( + fragment_ids, lance_ds, columns, row_filter +) -> Iterator["pyarrow.Table"]: + """Read Lance fragments in batches. + + NOTE: Use fragment ids, instead of fragments as parameter, because pickling + LanceFragment is expensive. + """ import pyarrow - for fragment in fragments: - batches = fragment.to_batches(columns=columns, filter=row_filter) - for batch in batches: - yield pyarrow.Table.from_batches([batch]) + fragments = [lance_ds.get_fragment(id) for id in fragment_ids] + scanner = lance_ds.scanner(columns, filter=row_filter, fragments=fragments) + for batch in scanner.to_reader(): + yield pyarrow.Table.from_batches([batch]) diff --git a/python/ray/data/tests/test_lance.py b/python/ray/data/tests/test_lance.py index c92646d873999..148c6962000ed 100644 --- a/python/ray/data/tests/test_lance.py +++ b/python/ray/data/tests/test_lance.py @@ -7,6 +7,7 @@ from pytest_lazyfixture import lazy_fixture import ray +from ray._private.test_utils import wait_for_condition from ray._private.utils import _get_pyarrow_version from ray.data.datasource.path_util import _unwrap_protocol @@ -78,11 +79,33 @@ def test_lance_read_basic(fs, data_path): # Test column projection. ds = ray.data.read_lance(path, columns=["one"]) - values = [s["one"] for s in ds.take()] + values = [s["one"] for s in ds.take_all()] assert sorted(values) == [1, 2, 3, 4, 5, 6] assert ds.schema().names == ["one", "two", "three", "four"] +@pytest.mark.parametrize("data_path", [lazy_fixture("local_path")]) +def test_lance_read_many_files(data_path): + # NOTE: Lance only works with PyArrow 12 or above. + pyarrow_version = _get_pyarrow_version() + if pyarrow_version is not None: + pyarrow_version = parse_version(pyarrow_version) + if pyarrow_version is not None and pyarrow_version < parse_version("12.0.0"): + return + + setup_data_path = _unwrap_protocol(data_path) + path = os.path.join(setup_data_path, "test.lance") + num_rows = 1024 + data = pa.table({"id": pa.array(range(num_rows))}) + lance.write_dataset(data, path, max_rows_per_file=1) + + def test_lance(): + ds = ray.data.read_lance(path) + return ds.count() == num_rows + + wait_for_condition(test_lance, timeout=10) + + if __name__ == "__main__": import sys