[air] Pass on KMS-related kwargs for s3fs (ray-project#35938)
We currently only parse and pass a limited selection of options to s3fs. One recent request was related to passing KMS settings. This PR extends the S3 URI string so that the signature version, server-side encryption (SSE) method, SSE-KMS key ID, and ACL grants can be configured via query parameters when s3fs is used.
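
For illustration, a minimal sketch of a URI using the new query options (bucket, path, and all option values below are hypothetical placeholders):

    from ray.air._internal.remote_storage import get_fs_and_path

    # All values are made up for illustration; the query keys are forwarded
    # to s3fs as botocore client/config kwargs and s3_additional_kwargs.
    fs, path = get_fs_and_path(
        "s3:https://my-bucket/my-path"
        "?signature_version=s3v4"
        "&ServerSideEncryption=aws:kms"
        "&SSEKMSKeyId=1234abcd-12ab-34cd-56ef-1234567890ab"
        "&GrantFullControl=id=1234567890"
    )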

This PR also changes the fs caching logic, which is required so that options are parsed again, e.g. when a key ID changes in subsequent calls. FS cache keys now include the query string, and cache items become stale after 5 minutes, after which the filesystem object is re-created. As a side effect, this should also fix problems caused by long-lived cached filesystems, e.g. expiring credentials.
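
A minimal sketch of the new cache-key behavior (URI and values are placeholders):

    import urllib.parse

    parsed = urllib.parse.urlparse("s3:https://my-bucket/my-path?SSEKMSKeyId=abc")
    # The query string is now part of the cache key, so a changed key ID
    # yields a fresh filesystem object instead of a stale cached one.
    cache_key = (parsed.scheme, parsed.netloc, parsed.query)
    # -> ("s3", "my-bucket", "SSEKMSKeyId=abc")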

Signed-off-by: Kai Fricke <[email protected]>
krfricke committed Jun 1, 2023
1 parent 9c17d90 commit ba31fdf
Showing 2 changed files with 147 additions and 18 deletions.
84 changes: 66 additions & 18 deletions python/ray/air/_internal/remote_storage.py
@@ -2,6 +2,7 @@
import os
import pathlib
import sys
import time
import urllib.parse
from pathlib import Path
from pkg_resources import packaging
@@ -29,6 +30,10 @@
from ray import logger


# Re-create fs objects after this many seconds
_CACHE_VALIDITY_S = 300


class _ExcludingLocalFilesystem(LocalFileSystem):
"""LocalFileSystem wrapper to exclude files according to patterns.
@@ -133,8 +138,26 @@ def is_non_local_path_uri(uri: str) -> bool:
    return False


-# Cache fs objects
-_cached_fs = {}
+# Cache fs objects. Map from cache_key --> timestamp, fs
+_cached_fs: Dict[tuple, Tuple[float, pyarrow.fs.FileSystem]] = {}


def _get_cache(cache_key: tuple) -> Optional[pyarrow.fs.FileSystem]:
    ts, fs = _cached_fs.get(cache_key, (0, None))
    if not fs:
        return None

    now = time.monotonic()
    if now - ts >= _CACHE_VALIDITY_S:
        _cached_fs.pop(cache_key)
        return None

    return fs


def _put_cache(cache_key: tuple, fs: pyarrow.fs.FileSystem):
    now = time.monotonic()
    _cached_fs[cache_key] = (now, fs)


def _get_network_mounts() -> List[str]:
@@ -182,6 +205,18 @@ def _is_local_windows_path(path: str) -> bool:
    return False


def _translate_options(
    option_map: Dict[str, str], options: Dict[str, List[str]]
) -> Dict[str, str]:
    """Given a mapping of old_name -> new_name in ``option_map``, rename keys."""
    translated = {}
    for opt, target in option_map.items():
        if opt in options:
            translated[target] = options[opt][0]

    return translated


def _translate_s3_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
"""Translate pyarrow s3 query options into s3fs ``storage_kwargs``.
@@ -199,22 +234,38 @@ def _translate_s3_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
"""
# Map from s3 query keys --> botocore client arguments
# client_kwargs
option_map = {
"endpoint_override": "endpoint_url",
"region": "region_name",
"access_key": "aws_access_key_id",
"secret_key": "aws_secret_access_key",
}
client_kwargs = _translate_options(option_map, options)

client_kwargs = {}
for opt, target in option_map.items():
if opt in options:
client_kwargs[target] = options[opt][0]
# config_kwargs
option_map = {
"signature_version": "signature_version",
}
config_kwargs = _translate_options(option_map, options)

# s3_additional_kwargs
option_map = {
"ServerSideEncryption": "ServerSideEncryption",
"SSEKMSKeyId": "SSEKMSKeyId",
"GrantFullControl": "GrantFullControl",
}
s3_additional_kwargs = _translate_options(option_map, options)

# s3fs directory cache does not work correctly, so we pass
# `use_listings_cache` to disable it. See https://github.com/fsspec/s3fs/issues/657
# We should keep this for s3fs versions <= 2023.4.0.
return {"client_kwargs": client_kwargs, "use_listings_cache": False}
return {
"use_listings_cache": False,
"client_kwargs": client_kwargs,
"config_kwargs": config_kwargs,
"s3_additional_kwargs": s3_additional_kwargs,
}


def _translate_gcs_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
Expand All @@ -234,10 +285,7 @@ def _translate_gcs_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
"endpoint_override": "endpoint_url",
}

-    storage_kwargs = {}
-    for opt, target in option_map.items():
-        if opt in options:
-            storage_kwargs[target] = options[opt][0]
+    storage_kwargs = _translate_options(option_map, options)

    return storage_kwargs

@@ -281,9 +329,9 @@ def _get_fsspec_fs_and_path(uri: str) -> Optional["pyarrow.fs.FileSystem"]:
    parsed = urllib.parse.urlparse(uri)

    storage_kwargs = {}
-    if parsed.scheme in ["s3", "s3a"] and parsed.query:
+    if parsed.scheme in ["s3", "s3a"]:
        storage_kwargs = _translate_s3_options(urllib.parse.parse_qs(parsed.query))
-    elif parsed.scheme in ["gs", "gcs"] and parsed.query:
+    elif parsed.scheme in ["gs", "gcs"]:
        if not _has_compatible_gcsfs_version():
            # If gcsfs is incompatible, fallback to pyarrow.fs.
            return None
@@ -329,17 +377,17 @@ def get_fs_and_path(
    else:
        path = parsed.netloc + parsed.path

-    cache_key = (parsed.scheme, parsed.netloc)
+    cache_key = (parsed.scheme, parsed.netloc, parsed.query)

-    if cache_key in _cached_fs:
-        fs = _cached_fs[cache_key]
+    fs = _get_cache(cache_key)
+    if fs:
        return fs, path

    # Prefer fsspec over native pyarrow.
    if fsspec:
        fs = _get_fsspec_fs_and_path(uri)
        if fs:
-            _cached_fs[cache_key] = fs
+            _put_cache(cache_key, fs)
            return fs, path

    # In case of hdfs filesystem, if uri does not have the netloc part below, it will
@@ -355,7 +403,7 @@
    # If no fsspec filesystem was found, use pyarrow native filesystem.
    try:
        fs, path = pyarrow.fs.FileSystem.from_uri(uri)
-        _cached_fs[cache_key] = fs
+        _put_cache(cache_key, fs)
        return fs, path
    except (pyarrow.lib.ArrowInvalid, pyarrow.lib.ArrowNotImplementedError):
        # Raised when URI not recognized
81 changes: 81 additions & 0 deletions python/ray/air/tests/test_remote_storage.py
@@ -5,15 +5,20 @@
import pytest
import shutil
import tempfile
import urllib.parse

from ray.air._internal.remote_storage import (
    upload_to_uri,
    download_from_uri,
    get_fs_and_path,
    _is_network_mount,
    _translate_s3_options,
    _CACHE_VALIDITY_S,
)
from ray.tune.utils.file_transfer import _get_recursive_files_and_stats

from freezegun import freeze_time


@pytest.fixture
def temp_data_dirs():
@@ -235,6 +240,82 @@ def test_is_network_mount(tmp_path, monkeypatch):
    assert not _is_network_mount("")  # cwd


def test_resolve_aws_kwargs():
    def _uri_to_opt(uri: str):
        parsed = urllib.parse.urlparse(uri)
        return urllib.parse.parse_qs(parsed.query)

    # client_kwargs
    assert (
        _translate_s3_options(_uri_to_opt("s3:https://some/where?endpoint_override=EP"))[
            "client_kwargs"
        ]["endpoint_url"]
        == "EP"
    )

    # config_kwargs
    assert (
        _translate_s3_options(_uri_to_opt("s3:https://some/where?signature_version=abc"))[
            "config_kwargs"
        ]["signature_version"]
        == "abc"
    )

    # s3_additional_kwargs
    assert (
        _translate_s3_options(_uri_to_opt("s3:https://some/where?SSEKMSKeyId=abc"))[
            "s3_additional_kwargs"
        ]["SSEKMSKeyId"]
        == "abc"
    )

    # no kwargs
    assert (
        _translate_s3_options(_uri_to_opt("s3:https://some/where"))["s3_additional_kwargs"]
        == {}
    )


def test_cache_time_eviction():
    """We use a time-based cache for filesystem objects.

    This test asserts that the cache is evicted after _CACHE_VALIDITY_S
    seconds.
    """
    with freeze_time() as frozen:
        fs, path = get_fs_and_path("s3:https://some/where")
        fs2, path = get_fs_and_path("s3:https://some/where")

        assert id(fs) == id(fs2)

        frozen.tick(_CACHE_VALIDITY_S - 10)

        # Cache not expired yet
        fs2, path = get_fs_and_path("s3:https://some/where")
        assert id(fs) == id(fs2)

        frozen.tick(10)

        # Cache expired
        fs2, path = get_fs_and_path("s3:https://some/where")
        assert id(fs) != id(fs2)


def test_cache_uri_query():
    """We cache fs objects, but different query parameters should have
    different cached objects."""
    fs, path = get_fs_and_path("s3:https://some/where?only=we")
    fs2, path = get_fs_and_path("s3:https://some/where?only=we")

    # Same query parameters, so same object
    assert id(fs) == id(fs2)

    fs3, path = get_fs_and_path("s3:https://some/where?we=know")

    # Different query parameters, so different object
    assert id(fs) != id(fs3)


if __name__ == "__main__":
    import sys

