[Data] Improve filesystem retry coverage (ray-project#46685)
See ray-project#43803 (comment).

Signed-off-by: Balaji Veeramani <[email protected]>
bveeramani committed Jul 19, 2024
1 parent d1ee314 commit 6b1fc0a
Showing 10 changed files with 172 additions and 19 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
@@ -610,3 +610,11 @@ py_test(
    tags = ["team:data", "exclusive"],
    deps = ["//:ray_lib", ":conftest"],
)

py_test(
    name = "test_context",
    size = "small",
    srcs = ["tests/test_context.py"],
    tags = ["team:data", "exclusive"],
    deps = ["//:ray_lib", ":conftest"],
)
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/parquet_datasink.py
@@ -81,7 +81,7 @@ def write_blocks_to_path():
        call_with_retry(
            write_blocks_to_path,
            description=f"write '{write_path}'",
            match=DataContext.get_current().write_file_retry_on_errors,
            match=DataContext.get_current().retried_io_errors,
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )
25 changes: 17 additions & 8 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -24,6 +24,7 @@
    _check_pyarrow_version,
    _is_local_scheme,
    call_with_retry,
    iterate_with_retry,
)
from ray.data.block import Block
from ray.data.context import DataContext
@@ -496,14 +497,22 @@ def _read_fragments(
    use_threads = to_batches_kwargs.pop("use_threads", False)
    batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)
    for fragment in fragments:
        batches = fragment.to_batches(
            use_threads=use_threads,
            columns=columns,
            schema=schema,
            batch_size=batch_size,
            **to_batches_kwargs,
        )
        for batch in batches:

        def get_batch_iterable():
            return fragment.to_batches(
                use_threads=use_threads,
                columns=columns,
                schema=schema,
                batch_size=batch_size,
                **to_batches_kwargs,
            )

        # S3 can raise transient errors during iteration, and PyArrow doesn't
        # expose a way to retry specific batches.
        ctx = ray.data.DataContext.get_current()
        for batch in iterate_with_retry(
            get_batch_iterable, "load batch", match=ctx.retried_io_errors
        ):
            table = pa.Table.from_batches([batch], schema=schema)
            if include_paths:
                table = table.append_column("path", [[fragment.path]] * len(table))
51 changes: 51 additions & 0 deletions python/ray/data/_internal/util.py
@@ -986,6 +986,57 @@ def call_with_retry(
                raise e from None


def iterate_with_retry(
    iterable_factory: Callable[[], Iterable],
    description: str,
    *,
    match: Optional[List[str]] = None,
    max_attempts: int = 10,
    max_backoff_s: int = 32,
) -> Any:
    """Iterate through an iterable with retries.

    If the iterable raises an exception, this function recreates and re-iterates
    through the iterable, while skipping the items that have already been yielded.

    Args:
        iterable_factory: A no-argument function that creates the iterable.
        description: An imperative description of the function being retried. For
            example, "open the file".
        match: A list of strings to match in the exception message. If ``None``, any
            error is retried.
        max_attempts: The maximum number of attempts to retry.
        max_backoff_s: The maximum number of seconds to back off.
    """
    assert max_attempts >= 1, f"`max_attempts` must be positive. Got {max_attempts}."

    num_items_yielded = 0
    for attempt in range(max_attempts):
        try:
            iterable = iterable_factory()
            for item_index, item in enumerate(iterable):
                if item_index < num_items_yielded:
                    # Skip items that have already been yielded.
                    continue

                num_items_yielded += 1
                yield item
            return
        except Exception as e:
            is_retryable = match is None or any(
                pattern in str(e) for pattern in match
            )
            if is_retryable and attempt + 1 < max_attempts:
                # Retry with binary exponential backoff and random jitter.
                backoff = min((2 ** (attempt + 1)), max_backoff_s) * random.random()
                logger.debug(
                    f"Retrying attempt {attempt + 1} to {description} "
                    f"after {backoff} seconds."
                )
                time.sleep(backoff)
            else:
                raise e from None


def create_dataset_tag(dataset_name: Optional[str], *args):
    tag = dataset_name or "dataset"
    for arg in args:
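
For reference, a minimal usage sketch of the new helper outside of Ray Data's internals; the `numbers` generator and its simulated error below are illustrative, not part of this change:

from ray.data._internal.util import iterate_with_retry

failed = {"once": False}


def numbers():
    # Hypothetical flaky iterable: raises one transient error partway through.
    for value in range(5):
        if value == 3 and not failed["once"]:
            failed["once"] = True
            raise RuntimeError("AWS Error SLOW_DOWN")
        yield value


# iterate_with_retry re-creates the iterable after the matching error and skips
# the items that were already yielded, so each number comes out exactly once.
print(list(iterate_with_retry(numbers, "load numbers", match=["AWS Error SLOW_DOWN"])))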
25 changes: 25 additions & 0 deletions python/ray/data/context.py
@@ -1,5 +1,6 @@
import os
import threading
import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

@@ -93,6 +94,13 @@
    "AWS Error UNKNOWN (HTTP status 503)",
)

DEFAULT_RETRIED_IO_ERRORS = (
    "AWS Error INTERNAL_FAILURE",
    "AWS Error NETWORK_CONNECTION",
    "AWS Error SLOW_DOWN",
    "AWS Error UNKNOWN (HTTP status 503)",
)

DEFAULT_WARN_ON_DRIVER_MEMORY_USAGE_BYTES = 2 * 1024 * 1024 * 1024

DEFAULT_ACTOR_TASK_RETRY_ON_ERRORS = False
@@ -231,6 +239,9 @@ class DataContext:
            call is made with an S3 URI.
        wait_for_min_actors_s: The default time to wait for minimum requested
            actors to start before raising a timeout, in seconds.
        retried_io_errors: A list of substrings of error messages that should
            trigger a retry when reading or writing files. This is useful for handling
            transient errors when reading from remote storage systems.
    """

    target_max_block_size: int = DEFAULT_TARGET_MAX_BLOCK_SIZE
@@ -277,6 +288,7 @@ class DataContext:
    print_on_execution_start: bool = True
    s3_try_create_dir: bool = DEFAULT_S3_TRY_CREATE_DIR
    wait_for_min_actors_s: int = DEFAULT_WAIT_FOR_MIN_ACTORS_S
    retried_io_errors: List[str] = DEFAULT_RETRIED_IO_ERRORS

    def __post_init__(self):
        # The additional ray remote args that should be added to
@@ -293,6 +305,19 @@ def __post_init__(self):
            DEFAULT_MAX_NUM_BLOCKS_IN_STREAMING_GEN_BUFFER
        )

    def __setattr__(self, name: str, value: Any) -> None:
        if (
            name == "write_file_retry_on_errors"
            and value != DEFAULT_WRITE_FILE_RETRY_ON_ERRORS
        ):
            warnings.warn(
                "`write_file_retry_on_errors` is deprecated. Configure "
                "`retried_io_errors` instead.",
                DeprecationWarning,
            )

        super().__setattr__(name, value)

    @staticmethod
    def get_current() -> "DataContext":
        """Get or create a singleton context.
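
For context, a brief sketch of how a user could adjust the new setting; the custom error substring here is illustrative, not part of this change:

import ray

ctx = ray.data.DataContext.get_current()

# Extend the retried substrings; reassign rather than mutate, since the default
# value may be stored as a tuple.
ctx.retried_io_errors = list(ctx.retried_io_errors) + ["CUSTOM TRANSIENT ERROR"]

# Assigning the legacy field still works, but now emits a DeprecationWarning.
ctx.write_file_retry_on_errors = list(ctx.retried_io_errors)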
16 changes: 10 additions & 6 deletions python/ray/data/datasource/file_based_datasource.py
@@ -54,9 +54,6 @@
# 16 file size fetches from S3 takes ~1.5 seconds with Arrow's S3FileSystem.
PATHS_PER_FILE_SIZE_FETCH_TASK = 16

# The errors to retry for opening file.
OPEN_FILE_RETRY_ON_ERRORS = ["AWS Error SLOW_DOWN", "AWS Error ACCESS_DENIED"]

# The max retry backoff in seconds for opening file.
OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS = 32

@@ -300,6 +297,8 @@ def _open_input_source(
        import pyarrow as pa
        from pyarrow.fs import HadoopFileSystem

        ctx = DataContext.get_current()

        compression = open_args.get("compression", None)
        if compression is None:
            try:
@@ -319,7 +318,6 @@
        buffer_size = open_args.pop("buffer_size", None)
        if buffer_size is None:
            ctx = DataContext.get_current()
            buffer_size = ctx.streaming_read_buffer_size

        if compression == "snappy":
@@ -330,7 +328,13 @@
        else:
            open_args["compression"] = compression

        file = filesystem.open_input_stream(path, buffer_size=buffer_size, **open_args)
        file = call_with_retry(
            lambda: filesystem.open_input_stream(
                path, buffer_size=buffer_size, **open_args
            ),
            description=f"open file {path}",
            match=ctx.retried_io_errors,
        )

        if compression == "snappy":
            import snappy
@@ -514,7 +518,7 @@ def _open_file_with_retry(
    return call_with_retry(
        open_file,
        description=f"open file {file_path}",
        match=OPEN_FILE_RETRY_ON_ERRORS,
        match=DataContext.get_current().retried_io_errors,
        max_attempts=OPEN_FILE_MAX_ATTEMPTS,
        max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS,
    )
4 changes: 2 additions & 2 deletions python/ray/data/datasource/file_datasink.py
@@ -199,7 +199,7 @@ def write_row_to_path():
        call_with_retry(
            write_row_to_path,
            description=f"write '{write_path}'",
            match=DataContext.get_current().write_file_retry_on_errors,
            match=DataContext.get_current().retried_io_errors,
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )
@@ -254,7 +254,7 @@ def write_block_to_path():
        call_with_retry(
            write_block_to_path,
            description=f"write '{write_path}'",
            match=DataContext.get_current().write_file_retry_on_errors,
            match=DataContext.get_current().retried_io_errors,
            max_attempts=WRITE_FILE_MAX_ATTEMPTS,
            max_backoff_s=WRITE_FILE_RETRY_MAX_BACKOFF_SECONDS,
        )
9 changes: 8 additions & 1 deletion python/ray/data/datasource/file_meta_provider.py
@@ -16,8 +16,10 @@

import numpy as np

import ray
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import call_with_retry
from ray.data.block import BlockMetadata
from ray.data.datasource.partitioning import Partitioning
from ray.util.annotations import DeveloperAPI
@@ -418,7 +420,12 @@ def _get_file_infos(

    file_infos = []
    try:
        file_info = filesystem.get_file_info(path)
        ctx = ray.data.DataContext.get_current()
        file_info = call_with_retry(
            lambda: filesystem.get_file_info(path),
            description="get file info",
            match=ctx.retried_io_errors,
        )
    except OSError as e:
        _handle_read_os_error(e, path)
    if file_info.type == FileType.Directory:
15 changes: 15 additions & 0 deletions python/ray/data/tests/test_context.py
@@ -0,0 +1,15 @@
import pytest

import ray


def test_write_file_retry_on_errors_emits_deprecation_warning(caplog):
    ctx = ray.data.DataContext.get_current()
    with pytest.warns(DeprecationWarning):
        ctx.write_file_retry_on_errors = []


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
36 changes: 35 additions & 1 deletion python/ray/data/tests/test_util.py
@@ -11,7 +11,11 @@
    trace_allocation,
    trace_deallocation,
)
from ray.data._internal.util import _check_pyarrow_version, _split_list
from ray.data._internal.util import (
    _check_pyarrow_version,
    _split_list,
    iterate_with_retry,
)
from ray.data.tests.conftest import * # noqa: F401, F403


@@ -131,6 +135,36 @@ def get_max_concurrency(self):
        return self.max_concurrency


def test_iterate_with_retry():
    has_raised_error = False

    class MockIterable:
        """Iterate over the numbers 0, 1, 2, and raise an error on the first iteration
        attempt.
        """

        def __init__(self):
            self._index = -1

        def __iter__(self):
            return self

        def __next__(self):
            self._index += 1

            if self._index >= 3:
                raise StopIteration

            nonlocal has_raised_error
            if self._index == 1 and not has_raised_error:
                has_raised_error = True
                raise RuntimeError("Transient error")

            return self._index

    assert list(iterate_with_retry(MockIterable, description="get item")) == [0, 1, 2]


if __name__ == "__main__":
    import sys

