diff --git a/python/ray/data/BUILD b/python/ray/data/BUILD
index a390e8e828039..f750667250d1a 100644
--- a/python/ray/data/BUILD
+++ b/python/ray/data/BUILD
@@ -169,6 +169,14 @@ py_test(
     deps = ["//:ray_lib", ":conftest"],
 )
 
+py_test(
+    name = "test_file_based_datasource",
+    size = "small",
+    srcs = ["tests/test_file_based_datasource.py"],
+    tags = ["team:data", "exclusive"],
+    deps = ["//:ray_lib", ":conftest"],
+)
+
 py_test(
     name = "test_image",
     size = "small",
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 3d92b3be958d1..096617400e110 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -3317,6 +3317,7 @@ def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
         try:
             import pandas as pd
 
+            datasource.on_write_start(**write_args)
             self._write_ds = Dataset(
                 plan, self._epoch, self._lazy, logical_plan
             ).materialize()
@@ -3326,7 +3327,7 @@ def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
                 for block in blocks
             )
             write_results = [block["write_result"][0] for block in blocks]
-            datasource.on_write_complete(write_results)
+            datasource.on_write_complete(write_results, **write_args)
         except Exception as e:
             datasource.on_write_failed([], e)
             raise
diff --git a/python/ray/data/datasource/datasource.py b/python/ray/data/datasource/datasource.py
index 2c6e72b24d414..88cefd3dfb19d 100644
--- a/python/ray/data/datasource/datasource.py
+++ b/python/ray/data/datasource/datasource.py
@@ -51,6 +51,17 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask"]:
         """Deprecated: Please implement create_reader() instead."""
         raise NotImplementedError
 
+    def on_write_start(self, **write_args) -> None:
+        """Callback for when a write job starts.
+
+        Use this method to perform setup for write tasks, such as creating a
+        staging bucket in S3.
+
+        Args:
+            write_args: Additional kwargs to pass to the datasource implementation.
+        """
+        pass
+
     def write(
         self,
         blocks: Iterable[Block],
diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py
index 25a2e78156795..b9af6bc901df5 100644
--- a/python/ray/data/datasource/file_based_datasource.py
+++ b/python/ray/data/datasource/file_based_datasource.py
@@ -261,6 +261,32 @@ def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args) -> Block
             "Subclasses of FileBasedDatasource must implement _read_file()."
         )
 
+    def on_write_start(
+        self,
+        path: str,
+        try_create_dir: bool = True,
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        **write_args,
+    ) -> None:
+        """Create a directory to write files to.
+
+        If ``try_create_dir`` is ``False``, this method is a no-op.
+        """
+        from pyarrow.fs import FileType
+
+        self.has_created_dir = False
+        if try_create_dir:
+            paths, filesystem = _resolve_paths_and_filesystem(path, filesystem)
+            assert len(paths) == 1, len(paths)
+            path = paths[0]
+
+            if filesystem.get_file_info(path).type is FileType.NotFound:
+                # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
+                # add a query arg enabling bucket creation if an S3 URI is provided.
+                tmp = _add_creatable_buckets_param_if_s3_uri(path)
+                filesystem.create_dir(tmp, recursive=True)
+                self.has_created_dir = True
+
     def write(
         self,
         blocks: Iterable[Block],
@@ -306,15 +332,6 @@ def write(
             if block.num_rows() == 0:
                 continue
 
-            if block_idx == 0:
-                # On the first non-empty block, try to create the directory.
-                if try_create_dir:
-                    # Arrow's S3FileSystem doesn't allow creating buckets by
-                    # default, so we add a query arg enabling bucket creation
-                    # if an S3 URI is provided.
-                    tmp = _add_creatable_buckets_param_if_s3_uri(path)
-                    filesystem.create_dir(tmp, recursive=True)
-
             fs = _unwrap_s3_serialization_workaround(filesystem)
 
             if self._WRITE_FILE_PER_ROW:
@@ -367,6 +384,23 @@ def write(
         # succeeds.
         return "ok"
 
+    def on_write_complete(
+        self,
+        write_results: List[WriteResult],
+        path: Optional[str] = None,
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        **kwargs,
+    ) -> None:
+        if not self.has_created_dir:
+            return
+
+        paths, filesystem = _resolve_paths_and_filesystem(path, filesystem)
+        assert len(paths) == 1, len(paths)
+        path = paths[0]
+
+        if all(write_result == "skip" for write_result in write_results):
+            filesystem.delete_dir(path)
+
     def _write_block(
         self,
         f: "pyarrow.NativeFile",
diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py
new file mode 100644
index 0000000000000..0f79e2f3daceb
--- /dev/null
+++ b/python/ray/data/tests/test_file_based_datasource.py
@@ -0,0 +1,43 @@
+import os
+
+import pyarrow
+import pytest
+
+import ray
+from ray.data.block import BlockAccessor
+from ray.data.datasource import FileBasedDatasource
+
+
+class MockFileBasedDatasource(FileBasedDatasource):
+    def _write_block(
+        self, f: "pyarrow.NativeFile", block: BlockAccessor, **writer_args
+    ):
+        f.write(b"")
+
+
+@pytest.mark.parametrize("num_rows", [0, 1])
+def test_write_preserves_user_directory(num_rows, tmp_path, ray_start_regular_shared):
+    ds = ray.data.range(num_rows)
+    path = os.path.join(tmp_path, "test")
+    os.mkdir(path)  # User-created directory
+
+    ds.write_datasource(MockFileBasedDatasource(), dataset_uuid=ds._uuid, path=path)
+
+    assert os.path.isdir(path)
+
+
+def test_write_creates_dir(tmp_path, ray_start_regular_shared):
+    ds = ray.data.range(1)
+    path = os.path.join(tmp_path, "test")
+
+    ds.write_datasource(
+        MockFileBasedDatasource(), dataset_uuid=ds._uuid, path=path, try_create_dir=True
+    )
+
+    assert os.path.isdir(path)
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
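
For reference, here is a minimal sketch (not part of this PR) of how a custom `Datasource` can use the new write lifecycle: `on_write_start` runs once on the driver before any write tasks launch, and `on_write_complete` runs once afterward with every task's result, with write args such as `path` forwarded to all three hooks per the `dataset.py` change above. `LoggingDatasource` is hypothetical, and the `write` signature is assumed to match the `Datasource` base class in this Ray version (which threads a task context through, as `write_fn_wrapper` suggests).

```python
import os
from typing import Iterable, List, Optional

import ray
from ray.data.block import Block
from ray.data.datasource import Datasource
from ray.data.datasource.datasource import WriteResult


class LoggingDatasource(Datasource):
    """Hypothetical datasource that traces the new write lifecycle."""

    def on_write_start(self, path: str, **write_args) -> None:
        # Runs once on the driver, before any write tasks are launched.
        os.makedirs(path, exist_ok=True)

    def write(
        self, blocks: Iterable[Block], ctx, path: str, **write_args
    ) -> WriteResult:
        # Runs once per write task; "ok" mirrors FileBasedDatasource.write.
        return "ok"

    def on_write_complete(
        self, write_results: List[WriteResult], path: Optional[str] = None, **kwargs
    ) -> None:
        # Runs once on the driver with the results of every write task.
        print(f"{len(write_results)} write task(s) finished for {path}")


# `path` is passed as a write arg, so it reaches all three hooks.
ray.data.range(10).write_datasource(LoggingDatasource(), path="/tmp/logging_out")
```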
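
The `all(write_result == "skip" ...)` branch in `on_write_complete` implies that a directory created by `on_write_start` is removed again when no task wrote anything. A hypothetical companion test (not in this PR) sketches that behavior, reusing `MockFileBasedDatasource` and the fixtures from the test file above and assuming that write tasks for empty blocks report `"skip"`, as the check implies:

```python
def test_write_removes_dir_for_empty_dataset(tmp_path, ray_start_regular_shared):
    ds = ray.data.range(0)
    path = os.path.join(tmp_path, "test")  # Not pre-created by the user.

    ds.write_datasource(MockFileBasedDatasource(), dataset_uuid=ds._uuid, path=path)

    # on_write_start created `path`; on_write_complete deleted it again because
    # every write result was "skip".
    assert not os.path.isdir(path)
```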