Skip to content

Commit

Permalink
fix: replace deprecated ray parallelism arg with override_num_blocks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kukushking committed Jun 27, 2024
1 parent 30276b2 commit 3c38d63
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 20 deletions.
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _read_orc_distributed(
schema: pa.schema | None,
columns: list[str] | None,
use_threads: bool | int,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, Any] | None,
Expand All @@ -43,7 +43,7 @@ def _read_orc_distributed(
)
ray_dataset = read_datasource(
datasource,
parallelism=parallelism,
override_num_blocks=override_num_blocks,
)
to_pandas_kwargs = _data_types.pyarrow2pandas_defaults(
use_threads=use_threads,
Expand Down
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _read_parquet_distributed(
columns: list[str] | None,
coerce_int96_timestamp_unit: str | None,
use_threads: bool | int,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, Any] | None,
Expand All @@ -60,7 +60,7 @@ def _read_parquet_distributed(
"dataset_kwargs": dataset_kwargs,
},
),
parallelism=parallelism,
override_num_blocks=override_num_blocks,
)
return _to_modin(
dataset=dataset,
Expand Down
4 changes: 2 additions & 2 deletions awswrangler/distributed/ray/modin/s3/_read_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _read_text_distributed(
s3_additional_kwargs: dict[str, str] | None,
dataset: bool,
ignore_index: bool,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
pandas_kwargs: dict[str, Any],
) -> pd.DataFrame:
Expand Down Expand Up @@ -172,6 +172,6 @@ def _read_text_distributed(
meta_provider=FastFileMetadataProvider(),
**configuration,
),
parallelism=parallelism,
override_num_blocks=override_num_blocks,
)
return _to_modin(dataset=ray_dataset, ignore_index=ignore_index)
18 changes: 18 additions & 0 deletions awswrangler/s3/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from awswrangler.catalog._utils import _catalog_id
from awswrangler.distributed.ray import ray_get
from awswrangler.s3._list import _path2list, _prefix_cleanup
from awswrangler.typing import RaySettings

if TYPE_CHECKING:
from mypy_boto3_glue.type_defs import GetTableResponseTypeDef
Expand Down Expand Up @@ -377,3 +378,20 @@ def _get_paths_for_glue_table(
)

return paths, path_root, res


def _get_num_output_blocks(
ray_args: RaySettings | None = None,
) -> int:
ray_args = ray_args or {}
parallelism = ray_args.get("parallelism", -1)
override_num_blocks = ray_args.get("override_num_blocks")
if parallelism != -1:
pass
_logger.warning(
"The argument ``parallelism`` is deprecated and will be removed in the next major release. "
"Please specify ``override_num_blocks`` instead."
)
elif override_num_blocks is not None:
parallelism = override_num_blocks
return parallelism
7 changes: 3 additions & 4 deletions awswrangler/s3/_read_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_apply_partition_filter,
_check_version_id,
_extract_partitions_dtypes_from_table_details,
_get_num_output_blocks,
_get_path_ignore_suffix,
_get_path_root,
_get_paths_for_glue_table,
Expand Down Expand Up @@ -137,7 +138,7 @@ def _read_orc(
schema: pa.schema | None,
columns: list[str] | None,
use_threads: bool | int,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, Any] | None,
Expand Down Expand Up @@ -283,8 +284,6 @@ def read_orc(
>>> df = wr.s3.read_orc(path, dataset=True, partition_filter=my_filter)
"""
ray_args = ray_args if ray_args else {}

s3_client = _utils.client(service_name="s3", session=boto3_session)
paths: list[str] = _path2list(
path=path,
Expand Down Expand Up @@ -330,7 +329,7 @@ def read_orc(
schema=schema,
columns=columns,
use_threads=use_threads,
parallelism=ray_args.get("parallelism", -1),
override_num_blocks=_get_num_output_blocks(ray_args),
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
arrow_kwargs=arrow_kwargs,
Expand Down
5 changes: 3 additions & 2 deletions awswrangler/s3/_read_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_apply_partition_filter,
_check_version_id,
_extract_partitions_dtypes_from_table_details,
_get_num_output_blocks,
_get_path_ignore_suffix,
_get_path_root,
_get_paths_for_glue_table,
Expand Down Expand Up @@ -285,7 +286,7 @@ def _read_parquet(
columns: list[str] | None,
coerce_int96_timestamp_unit: str | None,
use_threads: bool | int,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, Any] | None,
Expand Down Expand Up @@ -562,7 +563,7 @@ def read_parquet(
columns=columns,
coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
use_threads=use_threads,
parallelism=ray_args.get("parallelism", -1),
override_num_blocks=_get_num_output_blocks(ray_args),
s3_client=s3_client,
s3_additional_kwargs=s3_additional_kwargs,
arrow_kwargs=arrow_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/s3/_read_parquet.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _read_parquet(
columns: list[str] | None,
coerce_int96_timestamp_unit: str | None,
use_threads: bool | int,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
s3_client: "S3Client" | None,
s3_additional_kwargs: dict[str, Any] | None,
Expand Down
6 changes: 3 additions & 3 deletions awswrangler/s3/_read_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from awswrangler.s3._read import (
_apply_partition_filter,
_check_version_id,
_get_num_output_blocks,
_get_path_ignore_suffix,
_get_path_root,
_union,
Expand Down Expand Up @@ -52,7 +53,7 @@ def _read_text(
s3_additional_kwargs: dict[str, str] | None,
dataset: bool,
ignore_index: bool,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
pandas_kwargs: dict[str, Any],
) -> pd.DataFrame:
Expand Down Expand Up @@ -131,7 +132,6 @@ def _read_text_format(
**args,
)

ray_args = ray_args if ray_args else {}
return _read_text(
read_format,
paths=paths,
Expand All @@ -141,7 +141,7 @@ def _read_text_format(
s3_additional_kwargs=s3_additional_kwargs,
dataset=dataset,
ignore_index=ignore_index,
parallelism=ray_args.get("parallelism", -1),
override_num_blocks=_get_num_output_blocks(ray_args),
version_ids=version_ids,
pandas_kwargs=pandas_kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion awswrangler/s3/_read_text.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _read_text(
s3_additional_kwargs: dict[str, str] | None,
dataset: bool,
ignore_index: bool,
parallelism: int,
override_num_blocks: int,
version_ids: dict[str, str] | None,
pandas_kwargs: dict[str, Any],
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
Expand Down
7 changes: 7 additions & 0 deletions awswrangler/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ class RaySettings(TypedDict):
Parallelism may be limited by the number of files of the dataset.
Auto-detect by default.
"""
override_num_blocks: NotRequired[int]
"""
Override the number of output blocks from all read tasks.
By default, the number of output blocks is dynamically decided based on
input data size and available resources. You shouldn't manually set this
value in most cases.
"""


class RayReadParquetSettings(RaySettings):
Expand Down
2 changes: 1 addition & 1 deletion tests/glue_scripts/ray_read_small_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
import awswrangler as wr

paths = wr.s3.list_objects(f"s3:https://{os.environ['data-gen-bucket']}/parquet/small/partitioned/")
ray.data.read_parquet_bulk(paths=paths, parallelism=1000).to_modin()
ray.data.read_parquet_bulk(paths=paths, override_num_blocks=1000).to_modin()
2 changes: 1 addition & 1 deletion tests/glue_scripts/wrangler_read_small_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@

wr.s3.read_parquet(
path=f"s3:https://{os.environ['data-gen-bucket']}/parquet/small/partitioned/",
ray_args={"parallelism": 1000, "bulk_read": True},
ray_args={"override_num_blocks": 1000, "bulk_read": True},
)
2 changes: 1 addition & 1 deletion tests/glue_scripts/wrangler_write_partitioned_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

df = wr.s3.read_parquet(
path=f"s3:https://{os.environ['data-gen-bucket']}/parquet/medium/partitioned/",
ray_args={"parallelism": 1000},
ray_args={"override_num_blocks": 1000},
)

wr.s3.to_parquet(
Expand Down

0 comments on commit 3c38d63

Please sign in to comment.